
Structure-Aware E(3)-Invariant Molecular Conformer Aggregation Networks

Duy M. H. Nguyen    Nina Lukashina    Tai Nguyen    An T. Le    TrungTin Nguyen    Nhat Ho    Jan Peters    Daniel Sonntag    Viktor Zaverkin    Mathias Niepert
Abstract

A molecule's 2D representation consists of its atoms, their attributes, and the molecule's covalent bonds. A 3D (geometric) representation of a molecule is called a conformer and consists of its atom types and Cartesian coordinates. Every conformer has a potential energy, and the lower this energy, the more likely it occurs in nature. Most existing machine learning methods for molecular property prediction consider either 2D molecular graphs or 3D conformer structure representations in isolation. Inspired by recent work on using ensembles of conformers in conjunction with 2D graph representations, we propose $\mathrm{E}(3)$-invariant molecular conformer aggregation networks. The method integrates a molecule's 2D representation with those of multiple of its conformers. Contrary to prior work, we propose a novel 2D-3D aggregation mechanism based on a differentiable solver for the Fused Gromov-Wasserstein Barycenter problem and the use of an efficient conformer generation method based on distance geometry. We show that the proposed aggregation mechanism is $\mathrm{E}(3)$-invariant and propose an efficient GPU implementation. Moreover, we demonstrate that the aggregation mechanism helps to significantly outperform state-of-the-art molecule property prediction methods on established datasets. Our implementation is available at this link.


1 Introduction

Machine learning is increasingly used for modeling and analyzing properties of atomic systems with important applications in drug discovery and material design (Butler et al., 2018; Vamathevan et al., 2019; Choudhary et al., 2022; Fedik et al., 2022; Batatia et al., 2023). Most existing machine learning approaches to molecular property prediction either incorporate 2D (topological) (Kipf & Welling, 2017; Gilmer et al., 2017b; Xu et al., 2018; Veličković et al., 2018) or 3D (geometric) information of molecular structures (Schütt et al., 2017, 2021; Batzner et al., 2022; Batatia et al., 2022). 2D molecular graphs describe molecular connectivity (covalent bonds) but ignore the spatial arrangement of the atoms in a molecule (molecular conformation). 3D graph representations capture conformational changes but are commonly used to encode an individual conformer. Many molecular properties, such as solubility and binding affinity (Cao et al., 2022), however, inherently depend on the large number of conformations in which a molecule can occur in nature, and employing a single geometry per molecule limits the applicability of machine-learning models. Furthermore, it is challenging to determine the conformers that predominantly contribute to the molecular properties of interest. Thus, developing expressive representations for molecular systems when modeling their properties is an ongoing challenge.

To overcome this, recent work has introduced molecular representations that incorporate both 2D molecular graphs and 3D conformers (Zhu et al., 2023). These methods aim to encode various molecular structures, such as atom types, bond types, and spatial coordinates, leading to more comprehensive feature embeddings. The latest algorithms, including graph neural networks, attention mechanisms (Axelrod & Gómez-Bombarelli, 2023), and long short-term memory networks (Wang et al., 2024b), have demonstrated improved generalization capabilities in various molecular prediction tasks. Despite their effectiveness, these methods struggle to balance model complexity and performance and face scalability challenges mainly due to the computational cost of generating 3D conformers. These problems are exacerbated when using several conformers per system, underscoring the need for strategies to mitigate these limitations.

Contributions. We propose a new message-passing neural network architecture that integrates both 2D and ensembles of 3D molecular structures. The approach introduces a geometry-aware conformer ensemble aggregation strategy using Fused Gromov-Wasserstein (FGW) barycenters (Titouan et al., 2019), in which interactions between atoms across conformers are captured using both latent atom embeddings and conformer structures. The aggregation mechanism is invariant to actions of the group $E(3)$ (the Euclidean group in 3 dimensions), such as translations, rotations, and inversion, as well as to permutations of the input conformers. To make the proposed method applicable to large-scale problems, we accelerate the solvers for the FGW barycenter problem with entropic-based techniques (Rioux et al., 2023), allowing the model to be trained in parallel on multiple GPUs. We also experimentally explore the impact of the number of conformers and demonstrate that, within our framework, a modest number of conformers generated through efficient distance geometry-based sampling achieves state-of-the-art accuracy. We partially explain this through a theoretical analysis showing that the empirical barycenter converges to the target barycenter at a rate of $\mathcal{O}(1/K)$, where $K$ denotes the number of conformers. Finally, we conduct a systematic evaluation of our proposed approaches, comparing their performance to state-of-the-art algorithms. The results show that our method is competitive and frequently surpasses existing methods across a variety of datasets and tasks.

2 Background

We first provide the notation used in the paper. We denote the simplex of histograms with $n$ bins as $\Delta_n := \{\bm{\omega} \in \mathbb{R}_+^n : \sum_i \omega_i = 1\}$ and $\mathbb{S}_n(\mathbb{A})$ as the set of symmetric matrices of size $n$ taking values in $\mathbb{A} \subset \mathbb{R}$. For any $x \in \bm{\Omega}$, $\delta_x$ denotes the Dirac measure at $x$. Let $\mathcal{P}(\bm{\Omega})$ be the set of all probability measures on a space $\bm{\Omega}$. We write $[K] = \{1, 2, \ldots, K\}$ for any $K \in \mathbb{N}$. We denote the matrix scalar product associated with the Frobenius norm by $\langle \cdot, \cdot \rangle$. The tensor-matrix multiplication is denoted by $\otimes$, i.e., given any tensor $\bm{\mathsfit{L}} := (L_{ijkl})$ and matrix $\bm{B} := (B_{kl})$, $\bm{\mathsfit{L}} \otimes \bm{B}$ is the matrix $\bigl(\sum_{kl} L_{ijkl} B_{kl}\bigr)_{ij}$.

A graph $G$ is a pair $(V, E)$ with finite sets of vertices or nodes $V$ and edges $E \subseteq \{\{u, v\} \subseteq V \mid u \neq v\}$. We set $n \coloneqq |V|$ and say that the graph is of order $n$. For ease of notation, we denote the edge $\{u, v\}$ in $E$ by $(u, v)$ or $(v, u)$. The neighborhood of $v$ in $V$ is denoted by $N(v) \coloneqq \{u \in V \mid (v, u) \in E\}$ and the degree of a vertex $v$ is $|N(v)|$. An attributed graph $G$ is a triple $(V, E, \ell_f)$ with a graph $(V, E)$ and a (vertex-)feature (attribute) function $\ell_f \colon V \to \mathbb{R}^{1 \times d}$, for some $d \in \mathbb{N}^\star$. Then $\ell_f(v)$ is an attribute or feature of $v$, for $v$ in $V$. When we have multiple attributes, we have a pair $\bm{G} = (G, \bm{H})$, where $G = (V, E)$ and $\bm{H} \in \mathbb{R}^{n \times d}$ is a node attribute matrix. For a matrix $\bm{H} \in \mathbb{R}^{n \times d}$ and $v \in [n]$, we denote by $\bm{H}_v \in \mathbb{R}^{1 \times d}$ the $v$th row of $\bm{H}$ such that $\bm{H}_v \coloneqq \ell_f(v)$. Analogously, we can define attributes for the edges of the graph. Furthermore, we can encode an $n$-order graph $G$ via an adjacency matrix $\mathbf{A}(G) \in \{0, 1\}^{n \times n}$.

2.1 Message-Passing Neural Networks

Message-passing neural networks (MPNNs) learn $d$-dimensional real-valued vector representations for each vertex in a graph by exchanging and aggregating information from neighboring nodes. Each vertex $v$ is annotated with a feature $\mathbf{h}_v^{(0)} \in \mathbb{R}^d$ representing characteristics such as atom positions and numbers in the case of chemical molecules. In addition, each edge $(u, v)$ is associated with a feature vector $\bm{e}(u, v)$. An MPNN architecture consists of a composition of permutation-equivariant parameterized functions.

Following Gilmer et al. (2017a) and Scarselli et al. (2009), in each layer $\ell > 0$, we compute vertex features

$$\mathbf{h}_v^{(\ell)} \coloneqq \mathsf{UPD}^{(\ell)}\Bigl(\mathbf{h}_v^{(\ell-1)}, \mathsf{AGG}^{(\ell)}\bigl(\{\!\!\{\mathbf{m}_{v,u}^{(\ell)} \mid u \in N(v)\}\!\!\}\bigr)\Bigr),$$
$$\mathbf{m}_{v,u}^{(\ell)} \coloneqq \mathbf{M}^{(\ell)}\bigl(\mathbf{h}_v^{(\ell-1)}, \mathbf{h}_u^{(\ell-1)}, \mathbf{e}_{v,u}\bigr) \in \mathbb{R}^d, \qquad (1)$$

where $\mathsf{UPD}^{(\ell)}$, $\mathbf{M}^{(\ell)}$, and $\mathsf{AGG}^{(\ell)}$ are differentiable parameterized functions. In the case of graph-level regression problems, one uses

$$\mathbf{h}_G \coloneqq \mathsf{READOUT}\bigl(\{\!\!\{\mathbf{h}_v^{(L)} \mid v \in V(G)\}\!\!\}\bigr) \in \mathbb{R}^d, \qquad (2)$$

to compute a single vectorial representation based on the learned vertex features after iteration $L$, where $\mathsf{READOUT}$ can be a differentiable parameterized function.
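
To make Equations (1) and (2) concrete, the following minimal PyTorch sketch implements one such layer with a sum aggregation, an MLP message function, and a GRU update; the layer choices and dimensions are illustrative and not the specific architecture used in this paper.

```python
import torch
import torch.nn as nn

class MPNNLayer(nn.Module):
    """One message-passing layer in the spirit of Eq. (1): messages depend on the
    sender state, receiver state, and edge feature; aggregation is a neighborhood sum."""
    def __init__(self, d: int, d_edge: int):
        super().__init__()
        self.msg = nn.Sequential(nn.Linear(2 * d + d_edge, d), nn.ReLU(), nn.Linear(d, d))
        self.upd = nn.GRUCell(d, d)  # UPD combines the aggregated message with h_v^{(l-1)}

    def forward(self, h, edge_index, e):
        # h: (n, d) node states, edge_index: (2, m) directed edges u -> v, e: (m, d_edge)
        u, v = edge_index
        m = self.msg(torch.cat([h[v], h[u], e], dim=-1))   # messages m_{v,u}^{(l)}
        agg = torch.zeros_like(h).index_add_(0, v, m)      # AGG: sum over u in N(v)
        return self.upd(agg, h)                            # h_v^{(l)}

def readout(h, batch, num_graphs):
    """Sum READOUT (Eq. 2) over the nodes of each graph in a mini-batch."""
    out = torch.zeros(num_graphs, h.size(-1), device=h.device, dtype=h.dtype)
    return out.index_add_(0, batch, h)
```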

Molecules are 3-dimensional structures that can be represented by geometric graphs, capturing each atom's 3D position. To obtain more expressive representations, we also consider geometric input attributes and focus on vectorial features $\vv{\mathbf{v}}_v, \vv{\mathbf{v}}_u$ of nodes. Since we address the problem of molecular property prediction, where we assume the properties to be invariant to actions of the group $E(3)$, we focus on $E(3)$-invariant MPNNs for geometric graphs.

Figure 1: Overview of the proposed conformer aggregation network with alanine dipeptide as example input.

2.2 Fused Gromov-Wasserstein Distance

Fused Gromov-Wasserstein. An undirected attributed graph $G$ of order $n$ in the optimal transport context is defined as a tuple $G := (\bm{H}, \bm{A}, \bm{\omega})$, where $\bm{H} \in \mathbb{R}^{n \times d}$ is a node feature matrix, $\bm{A}$ is a matrix encoding relationships between nodes, and $\bm{\omega} \in \Delta_n$ denotes the probability measure of the nodes within the graph, which can be modeled as the relative importance weights of the graph nodes. Without any prior knowledge, uniform weights can be chosen ($\bm{\omega} = \mathbf{1}_n / n$) (Vincent-Cuaz et al., 2022). The matrix $\bm{A}$ can be the graph adjacency matrix, the shortest-path matrix, or another distance metric based on the graph topology (Peyré et al., 2016; Titouan et al., 2019, 2020). Given two graphs $G_1, G_2$ of orders $n_1, n_2$, respectively, the Fused Gromov-Wasserstein (FGW) distance is defined as follows:

$$\text{FGW}_{p,\alpha}(G_1, G_2) := \min_{\bm{\pi} \in \Pi(\bm{\omega}_1, \bm{\omega}_2)} \bigl\langle (1-\alpha)\,\bm{M} + \alpha\,\bm{\mathsfit{L}}(\bm{A}_1, \bm{A}_2) \otimes \bm{\pi}, \, \bm{\pi} \bigr\rangle. \qquad (3)$$

Here $\bm{M} := \bigl(d_f(\bm{H}_1[i], \bm{H}_2[j])^p\bigr)_{n_1 \times n_2} \in \mathbb{R}^{n_1 \times n_2}$ is the pairwise node distance matrix, $\bm{\mathsfit{L}}(\bm{A}_1, \bm{A}_2) = \bigl(|\bm{A}_1[i,j] - \bm{A}_2[l,m]|^p\bigr)_{ijlm}$ is the 4-tensor representing the alignment cost, and $\bm{\Pi}(\bm{\omega}_1, \bm{\omega}_2) := \{\bm{\pi} \in \mathbb{R}_+^{n_1 \times n_2} \mid \bm{\pi}\mathbf{1}_{n_2} = \bm{\omega}_1,\ \bm{\pi}^{\top}\mathbf{1}_{n_1} = \bm{\omega}_2\}$ is the set of all valid couplings between the node distributions $\bm{\omega}_1$ and $\bm{\omega}_2$. Moreover, $d_f(\cdot, \cdot)$ is the distance metric in the feature space, and $\alpha \in [0,1]$ is a weight that trades off between the Gromov-Wasserstein cost on the graph structure and the Wasserstein cost on the feature signal. In practice, we usually choose $p = 2$, the Euclidean distance for $d_f(\cdot, \cdot)$, and $\alpha = 0.5$ to calculate the FGW distance.
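
As an illustration of Equation (3), the sketch below computes the FGW coupling and distance between two small random attributed graphs; it assumes the POT (Python Optimal Transport) library and uses $p = 2$, squared Euclidean feature costs, and $\alpha = 0.5$, matching the defaults above.

```python
import numpy as np
import ot  # POT: Python Optimal Transport

rng = np.random.default_rng(0)
n1, n2, d = 6, 8, 4

# Node features, symmetric structure matrices, and uniform node weights.
H1, H2 = rng.normal(size=(n1, d)), rng.normal(size=(n2, d))
A1 = rng.random((n1, n1)); A1 = (A1 + A1.T) / 2
A2 = rng.random((n2, n2)); A2 = (A2 + A2.T) / 2
w1, w2 = np.full(n1, 1.0 / n1), np.full(n2, 1.0 / n2)

# Pairwise feature cost M with d_f = Euclidean and p = 2 (squared Euclidean).
M = ot.dist(H1, H2, metric="sqeuclidean")

# FGW coupling pi and distance with alpha = 0.5 (structure vs. feature trade-off).
pi, log = ot.gromov.fused_gromov_wasserstein(
    M, A1, A2, w1, w2, loss_fun="square_loss", alpha=0.5, log=True)
print(pi.shape, log["fgw_dist"])
```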

Entropic Fused Gromov-Wasserstein. The entropic FGW distance adds an entropic term (Cuturi, 2013) as

$$\text{FGW}_{p,\alpha}^{\epsilon}(G_1, G_2) := \text{FGW}_{p,\alpha}(G_1, G_2) - \epsilon \operatorname{H}(\bm{\pi}), \qquad (4)$$

where the entropic scalar $\epsilon$ facilitates a tunable trade-off between solution accuracy and computational performance (corresponding to lower and higher $\epsilon$, respectively). Solving this entropic FGW problem involves iterations of solving the linear entropic OT problem in Equation 37 with (stabilized) Sinkhorn projections (Proposition 2, Peyré et al. (2016)), as described in Appendix C and Algorithm 2.

3 ConAN: Conformer Aggregation Networks via Fused Gromov-Wasserstein Barycenters

In what follows, we refer to the representation of atoms and covalent bonds and their attributes as the 2D structure, and to the atoms, their 3D coordinates, and atom types as the 3D structure. The following subsections describe each part of the framework in detail.

3.1 Conformer Generation

To efficiently generate conformers, we employ distance geometry-based algorithms, which convert distance constraints into Cartesian coordinates. For atomistic systems, constraints typically define lower and upper bounds on interatomic squared distances. In a 2D input graph, covalent bond distances adhere to known ranges, while bond angles are determined by the corresponding geminal distances. Adjacent atoms or functional groups adhere to cis/trans limits for rotatable bonds or to set values for rigid groups. Other distances have hard-sphere lower bounds, usually chosen approximately 10% below the van der Waals radii (Hawkins, 2017). Chirality constraints are applied to every rigid quadruple of atoms.

A distance geometry algorithm then randomly generates a 3-dimensional conformation satisfying the constraints. To bias the generation towards low-energy conformations, a simple and efficient force field is typically applied. We use the efficient implementations from the RDKit package (Landrum, 2016).
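
A minimal sketch of this conformer-generation step with RDKit is shown below, assuming the ETKDG distance-geometry embedder and an MMFF force-field refinement; the SMILES string and the number of conformers are only illustrative.

```python
from rdkit import Chem
from rdkit.Chem import AllChem

# Hypothetical input molecule given as a SMILES string.
mol = Chem.AddHs(Chem.MolFromSmiles("CC(=O)NC(C)C(N)=O"))

# Distance-geometry embedding (ETKDG) generates conformers satisfying the constraints.
params = AllChem.ETKDGv3()
params.randomSeed = 0
conf_ids = AllChem.EmbedMultipleConfs(mol, numConfs=5, params=params)

# Bias towards low-energy conformations with a simple force field (MMFF94).
AllChem.MMFFOptimizeMoleculeConfs(mol)

# Each conformer provides an (N, 3) array of Cartesian coordinates.
coords = [mol.GetConformer(cid).GetPositions() for cid in conf_ids]
print(len(coords), coords[0].shape)
```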

3.2 Conformer Aggregation Network

We propose a new MPNN-based neural network that consists of three parts, as depicted in Figure 1. First, a 2D MPNN model is used to capture general molecular features such as the covalent bond structure and atom features. Second, a novel FGW barycenter-based implicit $E(3)$-invariant aggregation function integrates the representations of molecular 3D conformations computed by geometric message-passing neural networks. Finally, a permutation- and $E(3)$-invariant aggregation function combines the 2D graph and 3D conformer representations of the molecules.

2D Molecular Graph Message-Passing Network. Each molecule is represented by a 2D graph $G = (V, E)$ with nodes $V$ representing its atoms and edges $E$ representing its covalent bonds, annotated with molecular features $\bm{h}_v^{(0)}$ and $\bm{e}_{v,u}$, respectively (see Section 6 for details). To propagate features across a molecule and get 2D molecular representations, we use GAT layers, which utilize a self-attention mechanism in message passing with the following operations:

$$\mathbf{h}_v^{(\ell)} := \mathsf{AGG}^{(\ell)}\bigl(\{\!\!\{\mathbf{m}_{v,u}^{(\ell)} \mid u \in N(v)\}\!\!\}\bigr) = \sum_{u \in N(v)} \mathbf{m}_{v,u}^{(\ell)}, \quad \text{with} \quad \mathbf{m}_{v,u}^{(\ell)} = \alpha_{v,u}\,\mathbf{W}\mathbf{h}_u^{(\ell-1)}, \qquad (5)$$

and where $\alpha_{v,u}$ are the GAT attention coefficients and $\mathbf{W}$ is a learnable parameter matrix. Following Veličković et al. (2018), the attention mechanism is implemented with a single-layer feedforward neural network. To obtain a per-molecule embedding, we compute $\mathbf{h}^{\mathtt{2D}}_G = \sum_{v \in V} \mathbf{h}_v^{(L)}$, where $L$ is the number of message-passing layers.
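
A small sketch of this 2D branch, assuming PyTorch Geometric's GATConv with edge attributes, is given below; the feature sizes and the two-layer depth are illustrative only.

```python
import torch
from torch_geometric.nn import GATConv

class GAT2DEncoder(torch.nn.Module):
    """Sketch of the 2D molecular branch (Eq. 5): GAT message passing followed by
    a per-molecule sum readout h_G^{2D} = sum_v h_v^{(L)}."""
    def __init__(self, in_dim: int, edge_dim: int, hidden: int = 128):
        super().__init__()
        self.conv1 = GATConv(in_dim, hidden, edge_dim=edge_dim)
        self.conv2 = GATConv(hidden, hidden, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr, batch, num_graphs):
        h = self.conv1(x, edge_index, edge_attr).relu()
        h = self.conv2(h, edge_index, edge_attr)
        out = torch.zeros(num_graphs, h.size(-1), device=h.device, dtype=h.dtype)
        return out.index_add_(0, batch, h)  # sum over the atoms of each molecule
```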

3D Conformer Message-Passing Network. A conformer (atomic structure) of a molecule is defined as $S = \{\mathbf{r}_i, Z_i\}_{i=1}^{N}$, where $N$ is the number of atoms, $\mathbf{r}_i \in \mathbb{R}^3$ are the Cartesian coordinates of atom $i$, and $Z_i \in \mathbb{N}$ is the atomic number of atom $i$. We use weighted adjacency matrices $\mathbf{A} \in \mathbb{R}^{n \times n}$ to represent pairwise atom distances. In some cases we apply a cutoff radius to these distances. We employ the geometric MPNN SchNet (Schütt et al., 2017), although it is worth noting that alternative $E(3)$-invariant neural networks could be seamlessly integrated. The selection of SchNet is motivated not only by its proficient balance between model complexity and efficacy but also by its proven utility in previous works (Axelrod & Gómez-Bombarelli, 2023). SchNet performs $E(3)$-invariant message passing by using radial basis functions to incorporate the distances of the geometric node features $\vv{\mathbf{v}}_v, \vv{\mathbf{v}}_u$. We refer the reader to Section D.1 for more details. We denote the matrix whose columns are the atom-wise features of SchNet from the last message-passing layer $L$ by $\mathbf{H}$, that is, $\mathbf{H}[v] = \mathbf{h}^{(L)}_v$.

To compute the vector representation for a conformer $S$, we aggregate the atom-wise embeddings obtained from the last message-passing layer $L$ of SchNet into a single vector representation as $\mathbf{h}_S^{\mathtt{3D}} = \sum_{v \in V}\bigl(\mathbf{A}\mathbf{h}_v^{(L)} + \mathbf{a}\bigr)$, where $V$ is the set of atoms and $\mathbf{A}$ and $\mathbf{a}$ are learned during training. For a set of $K$ conformers, the output of our 3D MPNN models is a matrix whose columns are the embeddings $\mathbf{h}_{S_k}$ for conformer $k$, that is, $\mathbf{H}^{\mathtt{3D}}[k] = \mathbf{h}_{S_k}^{\mathtt{3D}}$.

FGW Barycenter Aggregation. We now introduce an implicit and differentiable neural aggregation function whose output is determined by solving an FGW barycenter optimization problem. Its input is $K$ graphs $G_k = (\mathbf{H}_k, \bm{A}_k, \bm{\omega}_k)$, one for each conformer $S_k = \{\mathbf{r}_{k,i}, Z_{k,i}\}_{i=1}^{N}$, with features $\mathbf{H}_k$ computed by an $E(3)$-invariant MPNN, a weighted adjacency matrix $\bm{A}_k$ of pairwise atomic distances, and the probability mass $\bm{\omega}_k$ of each atom, typically set to $1/N$. The output, the barycenter conformer $\overline{G} = (\overline{\mathbf{H}}, \overline{\bm{A}}, \overline{\bm{\omega}})$, represents the geometric mean of the input conformers, incorporating both their structural characteristics and features (Figure 1). The barycenter $\overline{G}$ is the conformer graph that minimizes the sum of weighted FGW distances to the conformer graphs $(G_k)_{k \in [K]}$ with feature matrices $(\mathbf{H}_k)_{k \in [K]}$, structure matrices $(\bm{A}_k)_{k \in [K]}$, and base histograms $(\bm{\omega}_k)_{k \in [K]} \in \Delta_n^K$. That is, given any fixed $K \in \mathbb{N}$ and any $\bm{\lambda} \in \Delta_K$, the FGW barycenter is defined as

$$\overline{G} := \operatorname*{arg\,min}_{G} \sum_{k=1}^{K} \lambda_k\, \text{FGW}_{p,\alpha}(G, G_k), \qquad (6)$$

where $\text{FGW}_{p,\alpha}(G, G_k)$ is the fused Gromov-Wasserstein distance defined in Equation 3, and where we set, for each pair of conformer graphs $G = (\mathbf{H}, \bm{A}, \bm{\omega})$ and $G_k = (\mathbf{H}_k, \bm{A}_k, \bm{\omega}_k)$, $\bm{M} := \bigl((\mathbf{H}[i] - \mathbf{H}_k[j])^2\bigr)_{n \times n} \in \mathbb{R}^{n \times n}$ as the feature distance matrix and $\bm{\mathsfit{L}}(\bm{A}, \bm{A}_k) = \bigl(|\bm{A}[i,j] - \bm{A}_k[l,m]|\bigr)_{ijlm}$ as the 4-tensor representing the structural distance when aligning atoms $i$ to $l$ and $j$ to $m$ (Figure 2). Solving Equation 6, we obtain a unique FGW barycenter graph $\overline{G} = (\overline{\mathbf{H}}, \overline{\bm{A}}, \overline{\bm{\omega}})$ with representation $\overline{\mathbf{h}}_v = \overline{\mathbf{H}}[v]$ for each atom $v$. We aggregate the atom-wise embeddings obtained from the FGW barycenter $\overline{G}$ into a single vector representation using $\mathbf{h}^{\mathtt{BC}}_{\overline{G}} = \sum_{v \in V}\bigl(\overline{\mathbf{A}}\,\overline{\mathbf{h}}_v + \overline{\mathbf{a}}\bigr)$.
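
The aggregation in Equation (6) can be prototyped with the (non-entropic) FGW barycenter solver of the POT library, as sketched below; this is a reference sketch under the assumption of uniform atom masses and weights $\lambda_k = 1/K$, not the entropic GPU solver of Algorithm 1.

```python
import numpy as np
import ot  # POT: Python Optimal Transport

def conformer_barycenter(Hs, As, alpha=0.5):
    """FGW barycenter of K conformer graphs (Eq. 6): Hs is a list of (n, d) atom
    embeddings, As a list of (n, n) pairwise-distance matrices."""
    K, n = len(Hs), Hs[0].shape[0]                  # all conformers share the atom count n
    ps = [np.full(n, 1.0 / n) for _ in range(K)]    # uniform atom masses omega_k
    lambdas = [1.0 / K] * K                         # uniform barycenter weights
    H_bar, A_bar = ot.gromov.fgw_barycenters(
        n, Hs, As, ps, lambdas, alpha=alpha, loss_fun="square_loss")
    return H_bar, A_bar                             # barycenter features and structure
```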

Figure 2: Illustration of the feature-based and structural distances of conformers (here: alanine dipeptide) we use for the computation of the Fused Gromov-Wasserstein barycenter.

Intuitively, the barycenter-based aggregation in Eq. (6) can be seen as a pooling operation that preserves distances (structure) better than standard mean aggregation. For instance, consider two conformers, where one is a 180-degree rotation of the other. Averaging their coordinates collapses the hydrogen atoms into the same position, creating an unphysical structure. In contrast, employing the FGW barycenter can prevent such issues.

Invariant Aggregation of 2D and 3D Representations. We integrate the representations of the 2D graph and the 3D conformer graphs using both an average aggregation and the barycenter-based aggregation. The requirement for this aggregation is that it is invariant to the order of the input conformers, that is, it treats the conformers as a set, as well as invariant to actions of the group $E(3)$.

Let $\mathbf{H}^{\mathtt{2D}}$ and $\mathbf{H}^{\mathtt{BC}}$ be the matrices whose columns are, respectively, $K$ copies of the 2D and barycenter representations from the previous sections. Using learnable weight matrices $\mathbf{W}^{\mathtt{2D}}$, $\mathbf{W}^{\mathtt{3D}}$, and $\mathbf{W}^{\mathtt{BC}}$, we obtain the final atom-wise feature matrices as

$$\mathbf{H}^{\mathtt{comb}} = \mathbf{W}^{\mathtt{2D}}\mathbf{H}^{\mathtt{2D}} + \mathbf{W}^{\mathtt{3D}}\mathbf{H}^{\mathtt{3D}} + \gamma\,\mathbf{W}^{\mathtt{BC}}\mathbf{H}^{\mathtt{BC}}, \qquad (7)$$

where $\gamma$ is a hyperparameter controlling the contribution of the barycenter-based feature. Intuitively, this aggregation function, in which we use multiple copies of the 2D graph and barycenter representations, provides a balanced contribution of the three types of representations and is empirically highly beneficial. Finally, to predict a molecular property, we apply a linear regression layer to a mean aggregation of the per-conformer embeddings:

$$\hat{y} = \mathbf{W}^{G}\left(\frac{1}{K}\sum_{k=1}^{K}\mathbf{H}^{\mathtt{comb}}[k]\right) + \mathbf{b}^{G}. \qquad (8)$$
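
A minimal PyTorch sketch of Equations (7) and (8) is given below; the embedding size, the use of linear layers for the weight matrices, and the default $\gamma$ are illustrative assumptions.

```python
import torch
import torch.nn as nn

class InvariantCombiner(nn.Module):
    """Combine K copies of the 2D and barycenter embeddings with the K conformer
    embeddings (Eq. 7), then mean-pool over conformers and regress (Eq. 8)."""
    def __init__(self, d: int, gamma: float = 0.2):
        super().__init__()
        self.W2d = nn.Linear(d, d, bias=False)
        self.W3d = nn.Linear(d, d, bias=False)
        self.Wbc = nn.Linear(d, d, bias=False)
        self.head = nn.Linear(d, 1)  # W^G, b^G
        self.gamma = gamma

    def forward(self, h_2d, H_3d, h_bc):
        # h_2d, h_bc: (d,) molecule-level embeddings; H_3d: (K, d) per-conformer embeddings
        K = H_3d.size(0)
        H_2d = h_2d.expand(K, -1)   # K copies of the 2D embedding
        H_bc = h_bc.expand(K, -1)   # K copies of the barycenter embedding
        H_comb = self.W2d(H_2d) + self.W3d(H_3d) + self.gamma * self.Wbc(H_bc)  # Eq. (7)
        return self.head(H_comb.mean(dim=0))                                     # Eq. (8)
```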

We can show that the function defined by Equations (5) through (8) is invariant to actions of the group $E(3)$ and to permutations acting on the sequence of input conformers.

Theorem 3.1.

Let $G$ be the 2D graph and $(S_1, \ldots, S_K)$ with $S_k = \{\mathbf{r}_{k,i}, Z_{k,i}\}_{i=1}^{N}$, $1 \leq k \leq K$, be a sequence of $K$ conformers of a molecule. Let $\hat{y} = f_{\bm{\theta}}(G, (S_1, \ldots, S_K))$ be the function defined by Equations (5) through (8). For any $g_1, \ldots, g_K \in E(3)$ we have that $f_{\bm{\theta}}(G, (g_1 S_1, \ldots, g_K S_K)) = f_{\bm{\theta}}(G, (S_1, \ldots, S_K))$. Moreover, for any $\pi \in \mathrm{Sym}([K])$ we have that $f_{\bm{\theta}}(G, (S_{\pi(1)}, \ldots, S_{\pi(K)})) = f_{\bm{\theta}}(G, (S_1, \ldots, S_K))$.

4 Efficient and Convergent Molecular Conformer Aggregation

In this section, we provide theoretical results to justify our novel FGW barycenter-based implicit $E(3)$-invariant aggregation function, which integrates the representations of molecular 3D conformations computed by geometric message-passing neural networks in Section 3.2. We establish a fast convergence rate of the empirical FGW barycenters to the true barycenters as a function of the number of conformer samples $K$.

Undirected Attributed Graph Space. Let us define a structured object to be a triplet $(\bm{\Omega}, \bm{A}, \mu)$, with $\bm{\Omega} = \bm{\Omega}_s \times \bm{\Omega}_f$, where $(\bm{\Omega}_f, d_f)$ and $(\bm{\Omega}_s, \bm{A})$ are feature and structure metric spaces, respectively, and $\mu$ is a probability measure over $\bm{\Omega}$. By defining $\bm{\omega}$, the probability measure of the nodes, the graph $G$ represents a fully supported probability measure over the feature/structure product space, $\mu = \sum_k \omega_k \delta_{(\bm{x}_k, \bm{a}_k)}$, which describes the entire undirected attributed graph. We denote by $\mathbb{X}$ the set of all metric spaces. The space of all structured objects over $(\bm{\Omega}_f, d_f)$ will be written as $\mathbb{S}(\bm{\Omega})$ and is defined by all the triplets $(\bm{\Omega}, \bm{A}, \mu)$, where $(\bm{\Omega}_f, d_f) \in \mathbb{X}$ and $\mu \in \mathcal{P}(\bm{\Omega})$.

True and Empirical Barycenters. Given $(\bm{\Omega}, \bm{A}, \mu) \in \mathbb{S}(\bm{\Omega})$, the variance functional $\sigma^2$ of a distribution $P \in \mathcal{P}(\mathcal{P}_p(\bm{\Omega}))$ is defined as follows:

$$\sigma_P^2 = \int_{\mathcal{P}_p(\bm{\Omega})} \text{FGW}_{p,\alpha}^{p}(\overline{\mu}_0, \nu)\, dP(\nu), \qquad (9)$$

where $\overline{\mu}_0$ is a true barycenter defined in Equation (10). We then restrict our attention to the subset $\mathcal{P}_p(\mathcal{P}_p(\bm{\Omega})) = \{P \in \mathcal{P}(\mathcal{P}_p(\bm{\Omega})) : \sigma_P^2 < +\infty\}$. Note that $\mathcal{P}_p(\bm{\Omega})$ is the subset of $\mathcal{P}(\bm{\Omega})$ with finite variance, defined in the same way as $\mathcal{P}_p(\mathcal{P}_p(\bm{\Omega}))$ but on $(\bm{\Omega}, \bm{A}, \mu)$. For any $P \in \mathcal{P}_p(\mathcal{P}_p(\bm{\Omega}))$, we define a true barycenter of $P$ as any $\overline{\mu}_0 \in \mathcal{P}_p(\bm{\Omega})$ s.t.

$$\overline{\mu}_0 \in \operatorname*{arg\,min}_{\mu \in \mathcal{P}_p(\bm{\Omega})} \int_{\mathcal{P}_p(\bm{\Omega})} \text{FGW}_{p,\alpha}^{p}(\mu, \nu)\, dP(\nu). \qquad (10)$$

In our context of predicting molecular properties, the true barycenter $\overline{\mu}_0$ is unknown. However, we can still draw $K$ random samples independently from $P$, corresponding to the 3D molecular representations $\{\mu_k\}_{k \in [K]} = \{\sum_{l=1}^{k} \omega_l \delta_{(x_l, a_l)}\}_{k \in [K]}$. Then, an empirical barycenter is defined as a barycenter of the empirical distribution $P_K = (1/K)\sum_k \delta_{\mu_k}$, i.e.,

$$\overline{\mu}_K \in \operatorname*{arg\,min}_{\mu \in \mathcal{P}_p(\bm{\Omega})} \frac{1}{K}\sum_k \text{FGW}_{p,\alpha}^{p}(\mu, \mu_k). \qquad (11)$$

4.1 Fast Convergence of Empirical FGW Barycenter

This work establishes a novel fast convergence rate for empirical barycenters in the FGW space via Theorem 4.1, which is proved in Appendix B. To the best of our knowledge, this result is new in the literature, where only the corresponding result for the Wasserstein space exists (Le Gouic et al., 2022).

Theorem 4.1.

Let $P \in \mathcal{P}_2(\mathcal{P}_2(\bm{\Omega}))$ be a probability measure on the 2-FGW space. Let $\overline{\mu}_0 \in \mathcal{P}_2(\bm{\Omega})$ and $\sigma_P^2$ be the barycenter and variance functional of $P$ satisfying (10) and (9), respectively. Let $\gamma, \beta > 0$ and suppose that every $\mu \in \operatorname{supp}(P)$ is the pushforward of $\overline{\mu}_0$ by the gradient of a $\gamma$-strongly convex and $\beta$-smooth function $\psi_{\overline{\mu}_0 \rightarrow \mu}$, i.e., $\mu = (\nabla \psi_{\overline{\mu}_0 \rightarrow \mu})_{\#} \overline{\mu}_0$. If $\beta - \gamma < 1$, then $\overline{\mu}_0$ is unique and any empirical barycenter $\overline{\mu}_K$ of $P$ satisfies

$$\mathbb{E}\bigl(\text{FGW}_{2,\alpha}^{2}(\overline{\mu}_0, \overline{\mu}_K)\bigr) \leq \frac{4\sigma_P^2}{(1 - \beta + \gamma)^2 K}. \qquad (12)$$

The upper bound in Equation 12 implies that the empirical barycenter converges to the target barycenter at a rate of $\mathcal{O}(1/K)$, where $K$ is the number of 3D conformers. This suggests that small values of $K$, such as $K \in \{5, 10\}$, already yield a satisfactory approximation of $\overline{\mu}_0$. We confirm this empirically in the experiments in Section 6.5.

Algorithm 1 Entropic FGW Barycenter
  Input: $\overline{\bm{\omega}}$, $\{G_s := (\bm{H}_s, \bm{A}_s, \bm{\omega}_s)\}_{s=1}^{K}$, $\epsilon$.
  Optimizing: $\overline{G}$, $\{\bm{\pi}_s \in \Pi(\overline{\bm{\omega}}, \bm{\omega}_s)\}_{s=1}^{K}$.
  repeat
     for $s = 1$ to $K$ do
        Solve $\operatorname*{arg\,min}_{\bm{\pi}_s^{(k)}} \text{FGW}_{p,\alpha}^{\epsilon}(\overline{G}^{(k)}, G_s)$ with Alg. 2.
     end for
     Update $\overline{\bm{A}}^{(k+1)} \leftarrow \frac{1}{\overline{\bm{\omega}}\,\overline{\bm{\omega}}^{\top}} \frac{1}{K}\sum_{s=1}^{K} \bm{\pi}_s^{(k)} \bm{A}_s \bm{\pi}_s^{(k)\top}$.
     Update $\overline{\bm{H}}^{(k+1)} \leftarrow \mathrm{diag}(1/\overline{\bm{\omega}}) \frac{1}{K}\sum_{s=1}^{K} \bm{\pi}_s^{(k)} \bm{H}_s$.
  until the maximum number of outer iterations is reached or convergence

4.2 Empirical Entropic FGW Barycenter

To train on large-scale problems, we propose to solve the entropic relaxation of Equation 6 to take advantage of GPU computing power (Peyré et al., 2019). Given a set of conformer graphs $\{G_s := (\bm{H}_s, \bm{A}_s, \bm{\omega}_s)\}_{s=1}^{K}$, we want to optimize the entropic barycenter $\overline{G}$, where we fix the prior on the nodes $\overline{\bm{\omega}}$:

$$\overline{G} = \operatorname*{arg\,min}_{G} \frac{1}{K}\sum_{s=1}^{K} \text{FGW}_{p,\alpha}^{\epsilon}(G, G_s), \qquad (13)$$

with $\lambda_s = 1/K$ for all $s \in [K]$. Titouan et al. (2019) solve Equation 13 using block coordinate descent, which iteratively minimizes the original FGW distance between the current barycenter and the graphs $G_s$. In our case, we solve for the $K$ couplings of the entropic FGW distances to the graphs at each iteration, and then follow the update rule for the structure matrix (Proposition 4, Peyré et al. (2016))

$$\overline{\bm{A}}^{(k+1)} \leftarrow \frac{1}{\overline{\bm{\omega}}\,\overline{\bm{\omega}}^{\top}} \frac{1}{K}\sum_{s=1}^{K} \bm{\pi}_s^{(k)} \bm{A}_s \bm{\pi}_s^{(k)\top}, \qquad (14)$$

and for the feature matrix (Titouan et al., 2019; Cuturi & Doucet, 2014)

$$\overline{\bm{H}}^{(k+1)} \leftarrow \mathrm{diag}(1/\overline{\bm{\omega}}) \frac{1}{K}\sum_{s=1}^{K} \bm{\pi}_s^{(k)} \bm{H}_s, \qquad (15)$$

leading to Algorithm 1. More details on practical implementations and algorithm complexity are in Appendix C.
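
For reference, one outer iteration of the structure and feature updates in Equations (14) and (15), given the $K$ couplings from the inner entropic FGW solves, can be sketched in NumPy as follows (uniform weights $\lambda_s = 1/K$ assumed).

```python
import numpy as np

def barycenter_update(pis, As, Hs, w_bar):
    """Eqs. (14)-(15): update the barycenter structure A_bar and features H_bar
    from couplings pi_s, conformer structures A_s, and conformer features H_s."""
    K = len(pis)
    # Structure update: A_bar <- (1 / (w_bar w_bar^T)) * (1/K) sum_s pi_s A_s pi_s^T
    A_bar = sum(pi @ A @ pi.T for pi, A in zip(pis, As)) / K
    A_bar = A_bar / np.outer(w_bar, w_bar)
    # Feature update: H_bar <- diag(1 / w_bar) * (1/K) sum_s pi_s H_s
    H_bar = sum(pi @ H for pi, H in zip(pis, Hs)) / K
    H_bar = H_bar / w_bar[:, None]
    return A_bar, H_bar
```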

5 Related Work

Molecular Representation Learning. The traditional approach for molecular representation, referred to as connectivity fingerprints (Morgan, 1965), encodes the presence of different substructures within a compound in the form of a binary vector. Modern molecular representations used in machine learning for molecular property prediction include 1D strings (Ahmad et al., 2022; Wang et al., 2019), 2D topological graphs (Gilmer et al., 2017a; Yang et al., 2019; Rong et al., 2020; Hu et al., 2020b) and 3D geometric graphs (Fang et al., 2021; Zhou et al., 2023; Liu et al., 2022a). The use of an ensemble of molecular conformations remains a relatively unexplored frontier in research, despite early evidence suggesting its efficacy in property prediction (Axelrod & Gómez-Bombarelli, 2023; Wang et al., 2024b). Another line of work uses conformers only at training time in a self-supervised loss to improve a 2D MPNN (Stärk et al., 2022). Contrary to prior work, we introduce a novel and streamlined barycenter-based conformer aggregation technique, seamlessly integrating learned representations from both 2D and 3D MPNNs. Moreover, we show that cost-effective conformers generated through distance-geometry sampling are sufficiently informative.

Geometric Graph Neural Networks. Graph Neural Networks (GNNs) designed for geometric graphs operate based on the message-passing framework, where the features of each node are dynamically updated through a process that respects permutation equivariance. Examples are models such as SphereNet (Liu et al., 2022b), GMNNs (Zaverkin & Kästner, 2020), DimeNet (Gasteiger et al., 2020b), GemNet-T (Gasteiger et al., 2021), SchNet (Schütt et al., 2017), GVP-GNN, PaiNN, E(n)-GNN (Satorras et al., 2021), MACE (Batatia et al., 2022), ICTP (Zaverkin et al., 2024), SEGNN (Brandstetter et al., 2022), SE(3)-Transformer (Fuchs et al., 2020), and VisNet (Wang et al., 2024a).

Optimal Transport in Graph Learning. By modeling graph features/structures as probability measures, the (Fused) GW distance (Titouan et al., 2020) serves as a versatile metric for comparing structured graphs. Previous applications of GW distance include graph node matching (Xu et al., 2019b), partitioning (Xu et al., 2019a; Chowdhury & Needham, 2021), and its use as a loss function for graph metric learning (Vincent-Cuaz et al., 2021, 2022; Chen et al., 2020; Zeng et al., 2023). More recently, FGW has been leveraged as an objective for encoding graphs (Tang et al., 2023) in tasks such as graph prediction (Brogat-Motte et al., 2022) and classification (Ma et al., 2023). To the best of our knowledge, we are the first to introduce the entropic FGW barycenter problem (Peyré et al., 2016; Titouan et al., 2020) for molecular representation learning. By employing the entropic formulation (Cuturi, 2013; Cuturi & Doucet, 2014), our learning pipeline enjoys a tunable trade-off between barycenter accuracy and computational performance, thus enabling an efficient hyperparameter tuning process. Moreover, we also present empirical barycenter-related theories, demonstrating how this entropic FGW barycenter framework effectively captures meaningful underlying structures of 3D conformers, thereby enhancing overall performance.

6 Experiments

6.1 Implementation Details

We encode each molecule in the SMILES format and employ the RDKit package to generate 3D conformers. We set the size of the latent dimensions of GAT (Veličković et al., 2018) to 128/256. Node features are initialized based on atomic properties such as atomic number, chirality, degree, charge, number of hydrogens, radical electrons, hybridization, aromaticity, and ring membership, while edges are represented as one-hot vectors denoting bond type, stereo configuration, and conjugation status. Each 3D conformer generated by RDKit comprises $n$ atoms with corresponding 3D coordinates representing their spatial positions. Subsequently, we establish the graph structure and compute atomic embeddings utilizing the force-field energy-based SchNet model (Schütt et al., 2017), extracting features prior to the $\mathsf{READOUT}$ layer. Our SchNet configuration incorporates three interaction blocks with feature maps of size $F = 128$, employing a radial function defined on Gaussians spaced at intervals of 0.1 Å with a cutoff distance of 10 Å. The output of each conformer $k \in [K]$ forms a graph $G_k$, utilized in solving the FGW barycenter $\overline{G}$ as defined in Eq. (6). Subsequently, we aggregate features from the 2D, 3D, and barycenter molecule graphs using Eqs. (7)-(8), followed by MLP layers. Leveraging Sinkhorn iterations in our barycenter solver (Algorithm 1), we speed up the training process across multiple GPUs using PyTorch's distributed data-parallel technique. Training the entire model employs the Adam optimizer with initial learning rates selected from $\{10^{-3},\, 10^{-3}/2,\, 10^{-4}\}$, halved using ReduceLROnPlateau after 10 epochs without validation-set improvement. Further experimental details are provided in the Appendix.

To accelerate the training process, especially in large-scale settings (e.g., the BDE dataset), we first train the model with 2D and 3D features for several epochs, then load the saved model and continue training with the full configuration as in Eq. (7) until convergence. We empirically set $\gamma$ in Eq. (7) to 0.2.

Table 1: Number of samples for each split on molecular property prediction, classification tasks, and reaction prediction.
          Lipo   ESOL   FreeSolv   BACE   CoV-2 3CL   CoV-2       BDE
Train     2940    789        449   1059    50 (485)   53 (3294)   8280
Valid.     420    112         64    151    15 (157)   17 (1096)   1184
Test       840    227        129    303    11 (162)   22 (1086)   2366
Total     4200   1128        642   1513    76 (804)   92 (5476)  11830

6.2 Molecular Property Prediction Tasks

Dataset. We use the four datasets Lipo, ESOL, FreeSolv, and BACE from the MoleculeNet benchmark (Table 1), spanning various molecular characteristics such as physical chemistry and biophysics. We split the data using random scaffold splitting, as in the baselines, and report the mean and standard deviation of the mean squared error (MSE) over five trials. More information on the datasets is provided in Section D.2 of the Appendix.

Baselines. We compare against various benchmarks, including supervised, pre-training, and multi-modal approaches. (i) The supervised methods are 2D graph neural network models, including 2D-GAT (Veličković et al., 2018), D-MPNN (Yang et al., 2019), and AttentiveFP (Xiong et al., 2019); (ii) the 2D molecule pretraining methods are PretrainGNN (Hu et al., 2020a), GROVER (Rong et al., 2020), MolCLR (Wang et al., 2022), ChemRL-GEM (Fang et al., 2022), ChemBERTa-2 (Ahmad et al., 2022), and MolFormer (Ross et al., 2022). It is important to note that these models are pre-trained on vast amounts of data; for example, MolFormer is trained on 1.1 billion molecules from the PubChem and ZINC datasets. We also compare with (iii) the 2D-3D aggregation model ConfNet (Liu et al., 2021), one of the winners of the KDD Cup OGB Large-Scale Challenge (Hu et al., 2021). Finally, we benchmark against 3D conformer-based models such as UniMol (Zhou et al., 2023), SchNet, and ChemProp3D (Axelrod & Gómez-Bombarelli, 2023). Among these, UniMol is pre-trained on 209M molecular conformations and requires one conformer for each downstream task. We train SchNet with 5 conformers (10 for FreeSolv) and test two versions: (a) taking the output at the final layer and averaging over different conformers (SchNet-scalar), and (b) using the feature node embeddings before the $\mathsf{READOUT}$ layer and aggregating conformers with an MLP layer (SchNet-emb). In ChemProp3D, we replace the classification head with an MLP layer for regression tasks, training with a 2D molecular graph and 10 conformers. For ConfNet, we use 20 conformers in the training step and provide results for 20 and 40 conformers in the evaluation step, following the configurations in Liu et al. (2021).

Table 2: Models evaluation on regression tasks (MSE ↓).
Model           Lipo            ESOL            FreeSolv        BACE
2D-GAT          1.387 ± 0.206   2.288 ± 0.017   8.564 ± 1.345   1.844 ± 0.33
D-MPNN          0.534 ± 0.022   0.923 ± 0.045   4.213 ± 0.068   0.723 ± 0.021
Attentive FP    0.520 ± 0.001   0.771 ± 0.026   4.197 ± 0.193   -
PretrainGNN     0.545 ± 0.003   1.21 ± 0.005    6.392 ± 0.003   -
GROVER_large    0.676 ± 0.012   0.798 ± 0.018   5.162 ± 0.047   -
ChemBERTa-2*    0.639 ± 0.006   0.795 ± 0.033   -               1.858 ± 0.029
ChemRL-GEM      0.486 ± 0.008   0.706 ± 0.061   3.924 ± 0.436   -
MolFormer       0.492 ± 0.012   0.766 ± 0.026   5.485 ± 0.045   1.091 ± 0.021
UniMol          0.374 ± 0.012   0.741 ± 0.014   2.867 ± 0.186   -
SchNet-scalar   0.704 ± 0.032   0.672 ± 0.027   1.608 ± 0.158   0.723 ± 0.1
SchNet-emb      0.589 ± 0.022   0.635 ± 0.057   1.587 ± 0.136   0.692 ± 0.028
ChemProp3D      0.602 ± 0.035   0.681 ± 0.023   2.014 ± 0.182   0.815 ± 0.17
ConfNet         1.360 ± 0.038   2.115 ± 0.484   -               1.329 ± 0.042
ConAN           0.556 ± 0.013   0.571 ± 0.019   1.496 ± 0.158   0.635 ± 0.051
ConAN-FGW       0.422 ± 0.016   0.529 ± 0.022   1.068 ± 0.083   0.549 ± 0.016

Results. Table 2 presents the experimental findings for ConAN alongside competitive methods, with the best results highlighted in bold. Baseline outcomes from prior studies (Zhou et al., 2023; Fang et al., 2022; Chang & Ye, 2023) are included, while the performance of the other models is obtained with their public code. The ConAN version denotes the aggregation of 2D and 3D features as per Eq. (7) without employing the barycenter, whereas ConAN-FGW denotes the full configuration. We use {5, 5, 10, 5} and {5, 5, 5, 5} conformers for ConAN and ConAN-FGW, respectively, based on the validation results for Lipo, ESOL, FreeSolv, and BACE. From the experiments, several observations emerge: (i) ConAN proves more effective than relying solely on 2D or 3D representations, achieving second-best rankings on three datasets compared to models using only 2D (ChemRL-GEM) or 3D representations (UniMol). (ii) ConAN-FGW consistently outperforms the baselines across all datasets, despite employing significantly fewer 3D conformers than ConAN. This underscores the importance of leveraging the barycenter to capture invariant 3D geometric characteristics.

6.3 3D SARS-CoV Molecular Classification Tasks

Dataset. We evaluate ConAN on the two datasets CoV-2 3CL and CoV-2 (Table 1), focusing on molecular classification tasks. We follow the same training/testing split as Axelrod & Gómez-Bombarelli (2023). We also apply CREST (Grimme, 2019) to filter the conformers generated by RDKit, as in Axelrod & Gómez-Bombarelli (2023), for a fair comparison. Model performance is reported with the receiver operating characteristic area under the curve (ROC) and the precision-recall area under the curve (PRC) over three trials.
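
The two reported metrics can be computed, for instance, with scikit-learn as in the illustrative snippet below; the labels and predicted probabilities are dummy values.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# Hypothetical binary labels and predicted probabilities for one trial.
y_true = np.array([0, 1, 0, 1, 1, 0])
y_prob = np.array([0.2, 0.8, 0.4, 0.6, 0.9, 0.1])

roc = roc_auc_score(y_true, y_prob)            # ROC-AUC
prc = average_precision_score(y_true, y_prob)  # PRC (average precision)
print(roc, prc)
```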

Baselines. We compare with three models, namely SchNetFeatures, ChemProp3D, and CP3D-NDU, each with two different attention mechanisms for ensembling 3D conformers and 2D molecular graph feature embeddings, as proposed by Axelrod & Gómez-Bombarelli (2023). These baselines generate 200 conformers for their input algorithms. Additionally, ConfNet (Liu et al., 2021) is evaluated using 20 or 40 conformers at test time.

Table 3: Performance of various models on the two molecular classification tasks.
Method           Num Conformers      Dataset     ROC ↑            PRC ↑
SchNetFeatures   200                 CoV-2 3CL   0.86             0.26
ChemProp3D       200                 CoV-2 3CL   0.66             0.20
CP3D-NDU         200                 CoV-2 3CL   0.901            0.413
SchNetFeatures   average neighbors   CoV-2 3CL   0.84             0.29
ChemProp3D       average neighbors   CoV-2 3CL   0.73             0.31
CP3D-NDU         average neighbors   CoV-2 3CL   0.916            0.467
ConfNet          {20, 40}            CoV-2 3CL   0.493            0.128
ConAN            10                  CoV-2 3CL   0.881 ± 0.009    0.317 ± 0.052
ConAN-FGW        5                   CoV-2 3CL   0.918 ± 0.012    0.423 ± 0.045
SchNetFeatures   200                 CoV-2       0.63             0.037
ChemProp3D       200                 CoV-2       0.53             0.032
CP3D-NDU         200                 CoV-2       0.663            0.06
SchNetFeatures   average neighbors   CoV-2       0.61             0.027
ChemProp3D       average neighbors   CoV-2       0.56             0.10
CP3D-NDU         average neighbors   CoV-2       0.647            0.058
ConfNet          {20, 40}            CoV-2       0.501 ± 0.001    0.36 ± 0.2
ConAN            10                  CoV-2       0.634 ± 0.053    0.031 ± 0.023
ConAN-FGW        10                  CoV-2       0.6735 ± 0.032   0.036 ± 0.014

Results. Table 3 presents the performance of ConAN and ConAN-FGW with 10 or 5 conformers. ConAN-FGW delivers the best performance on the ROC metric on both datasets and holds the second-best rank on PRC for CoV-2 3CL, while requiring only 10 or 5 conformers compared with the 200 conformers used by CP3D-NDU. These results underscore the efficacy of incorporating the barycenter component over merely aggregating 2D and 3D conformer embeddings, as done in ConAN.

6.4 Molecular Conformer Ensemble Benchmark

Dataset. We run ConAN on the BDE dataset (Table 1), the second-largest setting in Zhu et al. (2023), which aims to predict reaction-level molecular properties.
Baselines. ConAN is compared with the state-of-the-art conformer ensemble strategies presented in Zhu et al. (2023), including SchNet (Schütt et al., 2017), DimeNet++ (Gasteiger et al., 2020a), GemNet (Gasteiger et al., 2021), PaiNN (Schütt et al., 2021), ClofNet (Du et al., 2022), and LEFTNet (Du et al., 2024). All these approaches employ 20 conformers in training and testing. We provide two results for ConAN using only 10 conformers, based on two architectures, SchNet and LEFTNet.
Results. Table 4 summarizes the achieved scores, where ConAN-FGW with the LEFTNet backbone ranks second overall while using half the number of conformers. Additionally, ConAN-FGW improves by significant margins over both base models, SchNet (1.9737 → 1.6047) and LEFTNet (1.5276 → 1.4829), demonstrating the generalization ability of the proposed aggregation.

Table 4: Performance of different conformer ensemble strategies on reaction molecule prediction. Results are in Mean Absolute Error (MAE ↓). ConAN-FGW1 and ConAN-FGW2 denote our versions using SchNet and LEFTNet, respectively.
        SchNet   DimeNet++   GemNet   PaiNN    ClofNet   LEFTNet   ConAN-FGW1   ConAN-FGW2
Conf.   20       20          20       20       20        20        10           10
MAE ↓   1.9737   1.4741      1.6059   1.8744   2.0106    1.5276    1.6047       1.4829

6.5 Ablation Study

Contribution of the Number of 3D Conformers. One of the building blocks of our model is the use of multiple 3D conformations of a molecule. Each molecule is represented by $K$ conformations, so the choice of $K$ affects the model's behavior. We treat $K$ as a hyperparameter and conduct experiments to validate its impact on model performance. To this end, we test the ConAN version with different $K$ ($K = 0$ is equivalent to the 2D-GAT baseline) and report the performance in Table 7 in the Appendix. We observe that using 3D conformers with $K \geq 1$ clearly improves performance compared to using only 2D molecular graphs as in 2D-GAT. Furthermore, there is no straightforward dependency between the number of conformations in use and the accuracy of the model. For example, the performance tends to increase when using $K = 10$ (Lipo and FreeSolv), but overall, the best trade-off value is $K = 5$.

Figure 3: Ablation study on the effect of the number of conformers on the FGW barycenter component on the validation sets.

Contribution of FGW Barycenter Aggregation. We examine the effect of barycenter aggregation when varying the number of conformers $K$. Figure 3 summarizes the results for these settings, where we report the average RMSE over four datasets in the MoleculeNet benchmark. We draw the following observations. First, ConAN-FGW shows notable enhancements as the number of conformers increases, with $K$ values ranging within the set $\{3, 5, 10\}$; however, at $K = 20$, the discernible disparities compared to the results obtained at $K = 10$ diminish. We argue that this phenomenon aligns with the theoretical results in Theorem 4.1, suggesting that employing a sufficiently large $K$ facilitates a precise approximation of the target barycenter.

Second, upon examining the various datasets, it becomes evident that ConAN-FGW consistently demonstrates enhanced performance with larger numbers of conformers, a phenomenon not uniformly observed in the case of ConAN. This observation validates the robustness and resilience inherent in ConAN-FGW. We attribute this advantage to the efficacy of its geometry-informed aggregation strategy in ensemble learning with diverse 3D conformers.

Generalization to Other Backbone Models. We investigate the performance of ConAN and ConAN-FGW using the VisNet backbone (Wang et al., 2024a), an equivariant geometry-enhanced graph neural network for 3D conformer embedding extraction. The results in Table 5 confirm that ConAN-FGW still improves over ConAN. Between VisNet and SchNet, there is no universally best choice across datasets.

Table 5: ConAN evaluation using VisNet and SchNet on regression tasks (MSE ↓).
Model            Lipo            ESOL            FreeSolv        BACE
ConAN (VisNet)   0.554 ± 0.448   1.025 ± 0.119   0.692 ± 0.032   0.612 ± 0.148
ConAN-FGW        0.495 ± 0.008   0.552 ± 0.052   0.643 ± 0.015   0.469 ± 0.012
ConAN (SchNet)   0.556 ± 0.013   0.571 ± 0.019   1.496 ± 0.158   0.635 ± 0.051
ConAN-FGW        0.422 ± 0.016   0.529 ± 0.022   1.068 ± 0.083   0.549 ± 0.016

6.6 3D Conformer Distance Distribution

We check the diversity of conformers randomly selected from the set of conformers generated by RDKit. For each pair of 3D conformers, we compute the optimal root mean square distance, which first aligns the two molecules before measuring the distance. Two settings are considered: (i) estimating the distribution of the mean, variance, maximum, and minimum distances for the conformers sampled by ConAN out of 200 conformers generated by RDKit, and (ii) estimating the distribution of the same values when only the top closest conformers are sampled. Figure 4 shows our observations with a box plot on the validation set of FreeSolv.
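
The pairwise optimal RMSD described here can be computed with RDKit's GetBestRMS, which aligns two conformers before measuring the distance; the sketch below uses a hypothetical molecule and conformer count.

```python
from itertools import combinations
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem, rdMolAlign

# Hypothetical molecule; embed a pool of conformers with RDKit's ETKDG.
mol = Chem.AddHs(Chem.MolFromSmiles("CC(=O)NC(C)C(N)=O"))
cids = list(AllChem.EmbedMultipleConfs(mol, numConfs=20, params=AllChem.ETKDGv3()))

# Optimal RMSD for each conformer pair: GetBestRMS aligns before measuring.
rms = [rdMolAlign.GetBestRMS(mol, mol, prbId=i, refId=j)
       for i, j in combinations(cids, 2)]
print(np.mean(rms), np.var(rms), np.max(rms), np.min(rms))
```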

Figure 4: (left) Box-plot distribution of the mean, variance, maximum, and minimum distances among conformers; (right) distribution of the same values when sampling the top-$k$ closest conformers.

We observe that the distribution on the left ranges from 0.1 to 1.5, while in the worst case, the distance is between 0.01 and 0.08. Additionally, there is a large gap between $d_{\mathrm{max}}$ and $d_{\mathrm{min}}$ on the left, whereas on the right, their means are close. It can therefore be seen that ConAN's sampling strategy, given 200 RDKit-generated conformers, remains consistent and diverse.

6.7 FGW Barycenter Algorithm Efficiency

We contrast our entropic solver (Algorithm 1) with FGW-Mixup (Ma et al., 2023) on the $K$-barycenter problem. FGW-Mixup accelerates FGW problem solving by relaxing the coupling feasibility constraints. However, as the number of conformers $K$ increases, FGW-Mixup requires more outer iterations due to compounding marginal errors when solving the $K$ FGW distances. In contrast, our approach employs an entropic-relaxation FGW formulation that ensures the marginal constraints are respected, resulting in a less noisy FGW subgradient. Furthermore, we implement our algorithm with distributed computation on multiple GPUs, as highlighted in Fig. 5. The figure illustrates epoch durations during both the forward and backward steps of training, showcasing the performance across various conformer setups on the FreeSolv and CoV-2 3CL datasets. Using a batch size of 32 conformers, all three algorithms employ early termination upon reaching the error tolerance. Notably, our solver exhibits linear scalability with $K$, while FGW-Mixup shows exponential growth, presenting challenges for large-scale learning tasks. More details are in Section D.5.

Figure 5: Comparing runtimes of FGW-Mixup, ConAN-FGW (single and multi-GPU).

7 Conclusion and Future Works

In this study, we present an $E(3)$-invariant molecular conformer aggregation network that integrates 2D molecular graphs, 3D conformers, and geometry-attributed structures using Fused Gromov-Wasserstein barycenter formulations. The results indicate the effectiveness of this approach, surpassing several baseline methods across diverse downstream tasks, including molecular property prediction and 3D classification. Moreover, we investigate the convergence properties of the empirical barycenter problem, demonstrating that an adequate number of conformers can yield a reliable approximation of the target structure. To enable training on large datasets, we also introduce entropic barycenter solvers, maximizing GPU utilization. Future research directions include exploring the robustness of using RDKit for multiple low-energy scenarios or using more accurate reference methods for atomic structure relaxation, such as density-functional theory. Finally, extending ConAN to learn from large-scale unlabeled multi-modal molecular datasets holds significant promise for advancing the field.

Acknowledgement

The authors thank the International Max Planck Research School for Intelligent Systems (IMPRS-IS) for supporting Duy M. H. Nguyen and Nina Lukashina. Duy M. H. Nguyen and Daniel Sonntag are also supported by the XAINES project (BMBF, 01IW20005), No-IDLE project (BMBF, 01IW23002), and the Endowed Chair of Applied Artificial Intelligence, Oldenburg University. An T. Le was supported by the German Research Foundation project METRIC4IMITATION (PE 2315/11-1). Nhat Ho acknowledges support from the NSF IFML 2019844 and the NSF AI Institute for Foundations of Machine Learning. Mathias Niepert acknowledges funding by Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany’s Excellence Strategy - EXC and support by the Stuttgart Center for Simulation Science (SimTech). Furthermore, we acknowledge the support of the European Laboratory for Learning and Intelligent Systems (ELLIS) Unit Stuttgart.

Impact Statement

This paper presents work whose goal is to advance the field of Machine Learning, focusing on molecular representation learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

References

  • Ahmad et al. (2022) Ahmad, W., Simon, E., Chithrananda, S., Grand, G., and Ramsundar, B. Chemberta-2: Towards chemical foundation models, 2022.
  • Axelrod & Gómez-Bombarelli (2023) Axelrod, S. and Gómez-Bombarelli, R. Molecular machine learning with conformer ensembles. Mach. Learn.: Sci. Technol., 4(3):035025, September 2023. ISSN 2632-2153. doi: 10.1088/2632-2153/acefa7.
  • Batatia et al. (2022) Batatia, I., Kovacs, D. P., Simm, G., Ortner, C., and Csanyi, G. Mace: Higher order equivariant message passing neural networks for fast and accurate force fields. In Koyejo, S., Mohamed, S., Agarwal, A., Belgrave, D., Cho, K., and Oh, A. (eds.), Advances in Neural Information Processing Systems, volume 35, pp.  11423–11436. Curran Associates, Inc., 2022.
  • Batatia et al. (2023) Batatia, I., Benner, P., Chiang, Y., Elena, A. M., Kovács, D. P., Riebesell, J., Advincula, X. R., Asta, M., Baldwin, W. J., Bernstein, N., Bhowmik, A., Blau, S. M., Cărare, V., Darby, J. P., De, S., Pia, F. D., Deringer, V. L., Elijošius, R., El-Machachi, Z., Fako, E., Ferrari, A. C., Genreith-Schriever, A., George, J., Goodall, R. E. A., Grey, C. P., Han, S., Handley, W., Heenen, H. H., Hermansson, K., Holm, C., Jaafar, J., Hofmann, S., Jakob, K. S., Jung, H., Kapil, V., Kaplan, A. D., Karimitari, N., Kroupa, N., Kullgren, J., Kuner, M. C., Kuryla, D., Liepuoniute, G., Margraf, J. T., Magdău, I.-B., Michaelides, A., Moore, J. H., Naik, A. A., Niblett, S. P., Norwood, S. W., O’Neill, N., Ortner, C., Persson, K. A., Reuter, K., Rosen, A. S., Schaaf, L. L., Schran, C., Sivonxay, E., Stenczel, T. K., Svahn, V., Sutton, C., van der Oord, C., Varga-Umbrich, E., Vegge, T., Vondrák, M., Wang, Y., Witt, W. C., Zills, F., and Csányi, G. A foundation model for atomistic materials chemistry, 2023.
  • Batzner et al. (2022) Batzner, S., Musaelian, A., Sun, L., Geiger, M., Mailoa, J. P., Kornbluth, M., Molinari, N., Smidt, T. E., and Kozinsky, B. E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials. Nat. Commun., 13(1):2453, May 2022. ISSN 2041-1723.
  • Brandstetter et al. (2022) Brandstetter, J., Hesselink, R., van der Pol, E., Bekkers, E. J., and Welling, M. Geometric and physical quantities improve e(3) equivariant message passing. In International Conference on Learning Representations, 2022.
  • Brogat-Motte et al. (2022) Brogat-Motte, L., Flamary, R., Brouard, C., Rousu, J., and D’Alché-Buc, F. Learning to Predict Graphs with Fused Gromov-Wasserstein Barycenters. In Chaudhuri, K., Jegelka, S., Song, L., Szepesvari, C., Niu, G., and Sabato, S. (eds.), Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pp.  2321–2335. PMLR, July 2022.
  • Butler et al. (2018) Butler, K. T., Davies, D. W., Cartwright, H., Isayev, O., and Walsh, A. Machine learning for molecular and materials science. Nature, 559(7715):547–555, Jul 2018. ISSN 1476-4687.
  • Cao et al. (2022) Cao, L., Coventry, B., Goreshnik, I., Huang, B., Sheffler, W., Park, J. S., Jude, K. M., Marković, I., Kadam, R. U., Verschueren, K. H., et al. Design of protein-binding proteins from the target structure alone. Nature, 605(7910):551–560, 2022.
  • Chang & Ye (2023) Chang, J. and Ye, J. C. Bidirectional generation of structure and properties through a single molecular foundation model. arXiv preprint arXiv:2211.10590, 2023.
  • Chen et al. (2020) Chen, B., Bécigneul, G., Ganea, O.-E., Barzilay, R., and Jaakkola, T. Optimal transport graph neural networks. arXiv preprint arXiv:2006.04804, 2020.
  • Choudhary et al. (2022) Choudhary, K., DeCost, B., Chen, C., Jain, A., Tavazza, F., Cohn, R., Park, C. W., Choudhary, A., Agrawal, A., Billinge, S. J. L., Holm, E., Ong, S. P., and Wolverton, C. Recent advances and applications of deep learning methods in materials science. npj Comput. Mater., 8(1):59, Apr 2022. ISSN 2057-3960.
  • Chowdhury & Needham (2021) Chowdhury, S. and Needham, T. Generalized spectral clustering via gromov-wasserstein learning. In International Conference on Artificial Intelligence and Statistics, pp.  712–720. PMLR, 2021.
  • Cuturi (2013) Cuturi, M. Sinkhorn distances: Lightspeed computation of optimal transport. Advances in neural information processing systems, 26, 2013.
  • Cuturi & Doucet (2014) Cuturi, M. and Doucet, A. Fast computation of wasserstein barycenters. In International conference on machine learning, pp. 685–693. PMLR, 2014.
  • Du et al. (2022) Du, W., Zhang, H., Du, Y., Meng, Q., Chen, W., Zheng, N., Shao, B., and Liu, T.-Y. Se (3) equivariant graph neural networks with complete local frames. In International Conference on Machine Learning, pp. 5583–5608. PMLR, 2022.
  • Du et al. (2024) Du, Y., Wang, L., Feng, D., Wang, G., Ji, S., Gomes, C. P., Ma, Z.-M., et al. A new perspective on building efficient and expressive 3d equivariant graph neural networks. Advances in Neural Information Processing Systems, 36, 2024.
  • Ellinger et al. (2020) Ellinger, B., Bojkova, D., Zaliani, A., Cinatl, J., Claussen, C., Westhaus, S., Reinshagen, J., Kuzikov, M., Wolf, M., Geisslinger, G., Gribbon, P., and Ciesek, S. Identification of inhibitors of sars-cov-2 in-vitro cellular toxicity in human (caco-2) cells using a large scale drug repurposing collection, 2020.
  • Fang et al. (2021) Fang, X., Liu, L., Lei, J., He, D., Zhang, S., Zhou, J., Wang, F., Wu, H., and Wang, H. Chemrl-gem: Geometry enhanced molecular representation learning for property prediction. Nature Machine Intelligence, 2021. doi: 10.48550/ARXIV.2106.06130.
  • Fang et al. (2022) Fang, X., Liu, L., Lei, J., He, D., Zhang, S., Zhou, J., Wang, F., Wu, H., and Wang, H. Geometry-enhanced molecular representation learning for property prediction. Nature Machine Intelligence, 4(2):127–134, 2022.
  • Fedik et al. (2022) Fedik, N., Zubatyuk, R., Kulichenko, M., Lubbers, N., Smith, J. S., Nebgen, B., Messerly, R., Li, Y. W., Boldyrev, A. I., Barros, K., Isayev, O., and Tretiak, S. Extending machine learning beyond interatomic potentials for predicting molecular properties. Nat. Rev. Chem., 6(9):653–672, Sep 2022. ISSN 2397-3358. doi: 10.1038/s41570-022-00416-3.
  • Feydy et al. (2019) Feydy, J., Séjourné, T., Vialard, F.-X., Amari, S.-i., Trouvé, A., and Peyré, G. Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics, pp.  2681–2690. PMLR, 2019.
  • Fuchs et al. (2020) Fuchs, F., Worrall, D., Fischer, V., and Welling, M. Se(3)-transformers: 3d roto-translation equivariant attention networks. In Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M., and Lin, H. (eds.), Advances in Neural Information Processing Systems, volume 33, pp.  1970–1981. Curran Associates, Inc., 2020.
  • Gasteiger et al. (2020a) Gasteiger, J., Groß, J., and Günnemann, S. Directional message passing for molecular graphs. International Conference on Learning Representations, 2020a.
  • Gasteiger et al. (2020b) Gasteiger, J., Groß, J., and Günnemann, S. Directional message passing for molecular graphs. In International Conference on Learning Representations (ICLR), 2020b.
  • Gasteiger et al. (2021) Gasteiger, J., Becker, F., and Günnemann, S. Gemnet: Universal directional graph neural networks for molecules. In Beygelzimer, A., Dauphin, Y., Liang, P., and Vaughan, J. W. (eds.), Advances in Neural Information Processing Systems, 2021.
  • Gilmer et al. (2017a) Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., and Dahl, G. E. Neural message passing for quantum chemistry. In International Conference on Machine Learning, pp. 1263–1272, 2017a.
  • Gilmer et al. (2017b) Gilmer, J., Schoenholz, S. S., Riley, P. F., Vinyals, O., and Dahl, G. E. Neural message passing for quantum chemistry. In Precup, D. and Teh, Y. W. (eds.), Proceedings of the 34th ICML, volume 70 of Proceedings of Machine Learning Research, pp. 1263–1272. PMLR, 06–11 Aug 2017b.
  • Grimme (2019) Grimme, S. Exploration of chemical compound, conformer, and reaction space with meta-dynamics simulations based on tight-binding quantum chemical calculations. Journal of chemical theory and computation, 15(5):2847–2862, 2019.
  • Hawkins (2017) Hawkins, P. C. Conformation generation: the state of the art. Journal of chemical information and modeling, 57(8):1747–1756, 2017.
  • Hu et al. (2020a) Hu, W., Liu, B., Gomes, J., Zitnik, M., Liang, P., Pande, V., and Leskovec, J. Strategies for pre-training graph neural networks. In International Conference on Learning Representations, 2020a.
  • Hu et al. (2020b) Hu, W., Liu, B., Gomes, J., Zitnik, M., Liang, P., Pande, V., and Leskovec, J. Strategies for pre-training graph neural networks. In International Conference on Learning Representations, 2020b.
  • Hu et al. (2021) Hu, W., Fey, M., Ren, H., Nakata, M., Dong, Y., and Leskovec, J. Ogb-lsc: A large-scale challenge for machine learning on graphs. arXiv preprint arXiv:2103.09430, 2021.
  • Kipf & Welling (2017) Kipf, T. N. and Welling, M. Semi-supervised classification with graph convolutional networks. In International Conference on Learning Representations, 2017.
  • Landrum (2016) Landrum, G. RDKit: Open-source cheminformatics. http://www.rdkit.org, 2016.
  • Le et al. (2022) Le, K., Le, D., Nguyen, H., Do, D., Pham, T., and Ho, N. Entropic Gromov-Wasserstein between Gaussian distributions. In ICML, 2022.
  • Le Gouic et al. (2022) Le Gouic, T., Paris, Q., Rigollet, P., and Stromme, A. J. Fast convergence of empirical barycenters in Alexandrov spaces and the Wasserstein space. Journal of the European Mathematical Society, 25(6):2229–2250, May 2022. ISSN 1435-9855.
  • Lin et al. (2020) Lin, T., Ho, N., Chen, X., Cuturi, M., and Jordan, M. I. Fixed-support Wasserstein barycenters: Computational hardness and fast algorithm. In NeurIPS, pp.  5368–5380, 2020.
  • Liu et al. (2021) Liu, M., Fu, C., Zhang, X., Wang, L., Xie, Y., Yuan, H., Luo, Y., Xu, Z., Xu, S., and Ji, S. Fast quantum property prediction via deeper 2d and 3d graph networks. arXiv preprint arXiv:2106.08551, 2021.
  • Liu et al. (2022a) Liu, S., Wang, H., Liu, W., Lasenby, J., Guo, H., and Tang, J. Pre-training molecular graph representation with 3d geometry. In International Conference on Learning Representations, 2022a.
  • Liu et al. (2022b) Liu, Y., Wang, L., Liu, M., Lin, Y., Zhang, X., Oztekin, B., and Ji, S. Spherical message passing for 3d molecular graphs. In International Conference on Learning Representations, 2022b.
  • Ma et al. (2023) Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., and Zhu, W. Fused gromov-wasserstein graph mixup for graph-level classifications. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • Meyer et al. (2018) Meyer, B., Sawatlon, B., Heinen, S., Von Lilienfeld, O. A., and Corminboeuf, C. Machine learning meets volcano plots: computational discovery of cross-coupling catalysts. Chemical science, 9(35):7069–7077, 2018.
  • Morgan (1965) Morgan, H. L. The generation of a unique machine description for chemical structures-a technique developed at chemical abstracts service. Journal of Chemical Documentation, 5(2):107–113, May 1965. ISSN 1541-5732. doi: 10.1021/c160017a018.
  • Neyshabur et al. (2013) Neyshabur, B., Khadem, A., Hashemifar, S., and Arab, S. S. Netal: a new graph-based method for global alignment of protein–protein interaction networks. Bioinformatics, 29(13):1654–1662, 2013.
  • Paszke et al. (2017) Paszke, A., Gross, S., Chintala, S., Chanan, G., Yang, E., DeVito, Z., Lin, Z., Desmaison, A., Antiga, L., and Lerer, A. Automatic differentiation in pytorch. In NIPS 2017 Workshop on Autodiff, 2017.
  • Peyré (2015) Peyré, G. Entropic approximation of wasserstein gradient flows. SIAM Journal on Imaging Sciences, 8(4):2323–2351, 2015.
  • Peyré et al. (2016) Peyré, G., Cuturi, M., and Solomon, J. Gromov-wasserstein averaging of kernel and distance matrices. In International conference on machine learning, pp. 2664–2672. PMLR, 2016.
  • Peyré et al. (2019) Peyré, G., Cuturi, M., et al. Computational optimal transport: With applications to data science. Foundations and Trends® in Machine Learning, 11(5-6):355–607, 2019.
  • Rioux et al. (2023) Rioux, G., Goldfeld, Z., and Kato, K. Entropic gromov-wasserstein distances: Stability, algorithms, and distributional limits. arXiv preprint arXiv:2306.00182, 2023.
  • Rong et al. (2020) Rong, Y., Bian, Y., Xu, T., Xie, W., Wei, Y., Huang, W., and Huang, J. Self-supervised graph transformer on large-scale molecular data. Advances in Neural Information Processing Systems, 33:12559–12571, 2020.
  • Ross et al. (2022) Ross, J., Belgodere, B., Chenthamarakshan, V., Padhi, I., Mroueh, Y., and Das, P. Large-scale chemical language representations capture molecular structure and properties. Nature Machine Intelligence, 4(12):1256–1264, 2022.
  • Satorras et al. (2021) Satorras, V. G., Hoogeboom, E., and Welling, M. E(n) equivariant graph neural networks. In Meila, M. and Zhang, T. (eds.), Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pp.  9323–9332. PMLR, 18–24 Jul 2021.
  • Scarselli et al. (2009) Scarselli, F., Gori, M., Tsoi, A. C., Hagenbuchner, M., and Monfardini, G. The graph neural network model. IEEE Transactions on Neural Networks, 20(1):61–80, 2009.
  • Schmitzer (2019) Schmitzer, B. Stabilized sparse scaling algorithms for entropy regularized transport problems. SIAM Journal on Scientific Computing, 41(3):A1443–A1481, 2019.
  • Schütt et al. (2017) Schütt, K., Kindermans, P.-J., Sauceda Felix, H. E., Chmiela, S., Tkatchenko, A., and Müller, K.-R. Schnet: A continuous-filter convolutional neural network for modeling quantum interactions. In Guyon, I., Luxburg, U. V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., and Garnett, R. (eds.), Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.
  • Schütt et al. (2021) Schütt, K. T., Unke, O. T., and Gastegger, M. Equivariant message passing for the prediction of tensorial properties and molecular spectra. ICML, pp.  1–13, 2021.
  • Source (2020) Source, D. L. Main protease structure and xchem fragment screen, 2020.
  • Stärk et al. (2022) Stärk, H., Beaini, D., Corso, G., Tossou, P., Dallago, C., Günnemann, S., and Lió, P. 3D infomax improves GNNs for molecular property prediction. In Chaudhuri, K., Jegelka, S., Song, L., Szepesvari, C., Niu, G., and Sabato, S. (eds.), Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pp.  20479–20502. PMLR, 17–23 Jul 2022.
  • Tang et al. (2023) Tang, J., Zhao, K., and Li, J. A fused gromov-wasserstein framework for unsupervised knowledge graph entity alignment. arXiv preprint arXiv:2305.06574, 2023.
  • Titouan et al. (2019) Titouan, V., Courty, N., Tavenard, R., Laetitia, C., and Flamary, R. Optimal Transport for structured data with application on graphs. In Chaudhuri, K. and Salakhutdinov, R. (eds.), Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pp.  6275–6284. PMLR, June 2019.
  • Titouan et al. (2020) Titouan, V., Chapel, L., Flamary, R., Tavenard, R., and Courty, N. Fused Gromov-Wasserstein Distance for Structured Objects. Algorithms, 13(9):212, August 2020. ISSN 1999-4893. doi: 10.3390/a13090212.
  • Touret et al. (2020) Touret, F., Gilles, M., Barral, K., and et al. In vitro screening of a fda approved chemical library reveals potential inhibitors of sars-cov-2 replication. Sci Rep, 10:13093, 2020. doi: 10.1038/s41598-020-70143-6.
  • Vamathevan et al. (2019) Vamathevan, J., Clark, D., Czodrowski, P., Dunham, I., Ferran, E., Lee, G., Li, B., Madabhushi, A., Shah, P., Spitzer, M., and Zhao, S. Applications of machine learning in drug discovery and development. Nat. Rev. Drug Discov., 18(6):463–477, Jun 2019. ISSN 1474-1784.
  • Veličković et al. (2018) Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., and Bengio, Y. Graph attention networks. In International Conference on Learning Representations, 2018.
  • Vincent-Cuaz et al. (2021) Vincent-Cuaz, C., Vayer, T., Flamary, R., Corneli, M., and Courty, N. Online graph dictionary learning. In International conference on machine learning, pp. 10564–10574. PMLR, 2021.
  • Vincent-Cuaz et al. (2022) Vincent-Cuaz, C., Flamary, R., Corneli, M., Vayer, T., and Courty, N. Template based Graph Neural Network with Optimal Transport Distances. In Koyejo, S., Mohamed, S., Agarwal, A., Belgrave, D., Cho, K., and Oh, A. (eds.), Advances in Neural Information Processing Systems, volume 35, pp.  11800–11814. Curran Associates, Inc., 2022.
  • Wang et al. (2019) Wang, S., Guo, Y., Wang, Y., Sun, H., and Huang, J. Smiles-bert: Large scale unsupervised pre-training for molecular property prediction. In Proceedings of the 10th ACM International Conference on Bioinformatics, Computational Biology and Health Informatics, BCB ’19, pp. 429–436, New York, NY, USA, 2019. Association for Computing Machinery. ISBN 9781450366663. doi: 10.1145/3307339.3342186.
  • Wang et al. (2022) Wang, Y., Wang, J., Cao, Z., and Barati Farimani, A. Molecular contrastive learning of representations via graph neural networks. Nature Machine Intelligence, 4(3):279–287, 2022.
  • Wang et al. (2024a) Wang, Y., Wang, T., Li, S., He, X., Li, M., Wang, Z., Zheng, N., Shao, B., and Liu, T.-Y. Enhancing geometric representations for molecules with equivariant vector-scalar interactive message passing. Nature Communications, 15(1):313, 2024a.
  • Wang et al. (2024b) Wang, Z., Jiang, T., Wang, J., and Xuan, Q. Multi-modal representation learning for molecular property prediction: Sequence, graph, geometry. arXiv preprint arXiv:2401.03369, 2024b.
  • Wu et al. (2018) Wu, Z., Ramsundar, B., Feinberg, E. N., Gomes, J., Geniesse, C., Pappu, A. S., Leswing, K., and Pande, V. MoleculeNet: A benchmark for molecular machine learning. Chemical Science, pp.  513–530, 2018.
  • Xiong et al. (2019) Xiong, Z., Wang, D., Liu, X., Zhong, F., Wan, X., Li, X., Li, Z., Luo, X., Chen, K., Jiang, H., et al. Pushing the boundaries of molecular representation for drug discovery with the graph attention mechanism. Journal of medicinal chemistry, 63(16):8749–8760, 2019.
  • Xu et al. (2019a) Xu, H., Luo, D., and Carin, L. Scalable gromov-wasserstein learning for graph partitioning and matching. Advances in neural information processing systems, 32, 2019a.
  • Xu et al. (2019b) Xu, H., Luo, D., Zha, H., and Duke, L. C. Gromov-wasserstein learning for graph matching and node embedding. In International conference on machine learning, pp. 6932–6941. PMLR, 2019b.
  • Xu et al. (2018) Xu, K., Li, C., Tian, Y., Sonobe, T., Kawarabayashi, K.-i., and Jegelka, S. Representation learning on graphs with jumping knowledge networks. In Dy, J. and Krause, A. (eds.), Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pp.  5453–5462. PMLR, 10–15 Jul 2018.
  • Yang et al. (2019) Yang, K., Swanson, K., Jin, W., Coley, C., Eiden, P., Gao, H., Guzman-Perez, A., Hopper, T., Kelley, B., Mathea, M., Palmer, A., Settels, V., Jaakkola, T., Jensen, K., and Barzilay, R. Analyzing learned molecular representations for property prediction. Journal of Chemical Information and Modeling, 59(8):3370–3388, July 2019. ISSN 1549-960X. doi: 10.1021/acs.jcim.9b00237.
  • Zaverkin & Kästner (2020) Zaverkin, V. and Kästner, J. Gaussian moments as physically inspired molecular descriptors for accurate and scalable machine learning potentials. Journal of Chemical Theory and Computation, 16(8):5410–5421, 2020. doi: 10.1021/acs.jctc.0c00347.
  • Zaverkin et al. (2024) Zaverkin, V., Alesiani, F., Maruyama, T., Errica, F., Christiansen, H., Takamoto, M., Weber, N., and Niepert, M. Higher-rank irreducible Cartesian tensors for equivariant message passing, 2024.
  • Zeng et al. (2023) Zeng, Z., Zhu, R., Xia, Y., Zeng, H., and Tong, H. Generative graph dictionary learning. In International Conference on Machine Learning, pp. 40749–40769. PMLR, 2023.
  • Zhou et al. (2023) Zhou, G., Gao, Z., Ding, Q., Zheng, H., Xu, H., Wei, Z., Zhang, L., and Ke, G. Uni-mol: A universal 3d molecular representation learning framework. In The Eleventh International Conference on Learning Representations, 2023.
  • Zhu et al. (2023) Zhu, Y., Hwang, J., Adams, K., Liu, Z., Nan, B., Stenfors, B., Du, Y., Chauhan, J., Wiest, O., Isayev, O., Coley, C. W., Sun, Y., and Wang, W. Learning over molecular conformer ensembles: Datasets and benchmarks, 2023.

Structure-Aware E(3)-Invariant Molecular Conformer Aggregation Networks Supplementary Material

In this supplementary material, we first present rigorous proofs of the E(3) invariance of the proposed aggregation mechanism in Appendix A, followed by the proofs of the fast convergence of the empirical FGW barycenter in Appendix B. The entropic FGW algorithm and practical GPU considerations are then described in more detail in Appendix C. Finally, additional experimental details on the SchNet architecture, 3D conformer generation, and the comparison between entropic FGW and FGW-mixup are deferred to Appendix D.

Appendix A Proof of Theorem 3.1

We will proceed as follows. First, we prove that 𝐇𝙱𝙲\mathbf{H}^{\mathtt{BC}} is invariant to permutations of the input conformers and actions of the group E(3)E(3) applied to the input conformers. 𝐇𝙱𝙲\mathbf{H}^{\mathtt{BC}} is invariant to the order of the input conformers by definition of the barycenter which is invariant to the order of the input graphs. Moreover, since by definition, actions of the group E(3)E(3) preserve distances between points in a 33-dimensional space and, by assumption, the upstream 3D MPNN is invariant to actions of E(3)E(3), for any input conformer SS and its corresponding graph G(S)=(𝐇,𝑨,𝝎)G(S)=(\mathbf{H},{\bm{A}},\bm{\omega}) and any action gE(3)g\in E(3) we have that G(gS)=(𝐇,𝑨,𝝎)=G(S)G(gS)=(\mathbf{H},{\bm{A}},\bm{\omega})=G(S). 𝐇\mathbf{H} is invariant to actions of the group E(3)E(3) because the 3D MPNN is invariant to actions of the group. 𝑨{\bm{A}} is invariant due to distances between points being invariant. Hence, the input graphs to the barycenter optimization problem are invariant to actions of the group E(3)E(3) on the conformers and, therefore, the output barycenters are invariant to such group actions.

Recall Equation 7, \mathbf{H}^{\mathtt{comb}}=\mathbf{W}^{\mathtt{2D}}\mathbf{H}^{\mathtt{2D}}+\mathbf{W}^{\mathtt{3D}}\mathbf{H}^{\mathtt{3D}}+\mathbf{W}^{\mathtt{BC}}\mathbf{H}^{\mathtt{BC}}. We have shown that \mathbf{H}^{\mathtt{BC}} is invariant to both actions of the group E(3) and permutations of the input conformers. We also know that \mathbf{H}^{\mathtt{3D}} is equivariant to permutations of the input conformers, that is, every permutation of the input conformers permutes the columns of \mathbf{H}^{\mathtt{3D}} in the same way. In addition, \mathbf{H}^{\mathtt{3D}} is invariant to actions of the group E(3) on the input conformers by the assumption that the 3D MPNN is E(3)-invariant.

What remains to be shown is that 1Kk=1K𝐇𝚌𝚘𝚖𝚋\displaystyle\frac{1}{K}\sum_{k=1}^{K}\mathbf{H}^{\mathtt{comb}} with 𝐇𝚌𝚘𝚖𝚋=𝐖𝟸𝙳𝐇𝟸𝙳+𝐖𝟹𝙳𝐇𝟹𝙳+𝐖𝙱𝙲𝐇𝙱𝙲\mathbf{H}^{\mathtt{comb}}=\mathbf{W}^{\mathtt{2D}}\mathbf{H}^{\mathtt{2D}}+\mathbf{W}^{\mathtt{3D}}\mathbf{H}^{\mathtt{3D}}+\mathbf{W}^{\mathtt{BC}}\mathbf{H}^{\mathtt{BC}} is invariant to column permutations of the matrix 𝐇𝟹𝙳\mathbf{H}^{\mathtt{3D}}. Since we compute the average of the columns of 𝐇𝚌𝚘𝚖𝚋\mathbf{H}^{\mathtt{comb}} this is indeed the case.
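The argument above hinges on the fact that the structure matrix of each conformer graph is built from pairwise interatomic distances, which are preserved by any E(3) action. A minimal numerical sanity check of this property (illustrative only; the function and variable names are not part of our implementation) is:

import numpy as np

def distance_matrix(coords):
    # Pairwise Euclidean distances; this is what the structure matrix A is built from.
    diff = coords[:, None, :] - coords[None, :, :]
    return np.linalg.norm(diff, axis=-1)

rng = np.random.default_rng(0)
coords = rng.normal(size=(12, 3))             # a toy conformer with 12 atoms

Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix (rotation or reflection)
t = rng.normal(size=3)                        # random translation
coords_g = coords @ Q.T + t                   # an E(3) action g applied to the conformer

assert np.allclose(distance_matrix(coords), distance_matrix(coords_g))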

Appendix B Proof of Theorem 4.1

We begin by introducing the notation used in the proof of the paper.

Undirected attributed graphs as distributions: Given the set of vertices and edges of the graph (V,E), we define undirected labeled graphs as tuples of the form G=(V,E,\ell_{f},\ell_{s}). Here, \ell_{f}:V\rightarrow\bm{\Omega}_{f} is a labeling function that associates each vertex v_{i}\in V with an attribute or feature {\bm{x}}_{i}=\ell_{f}(v_{i}) in some feature metric space (\bm{\Omega}_{f},d_{f}), and \ell_{s}:V\rightarrow\bm{\Omega}_{s} maps a vertex v_{i} to its structure representation {\bm{a}}_{i}=\ell_{s}(v_{i}) in some structure space (\bm{\Omega}_{s},{\bm{A}}) specific to each graph, where A:\bm{\Omega}_{s}\times\bm{\Omega}_{s}\rightarrow\mathbb{R}_{+} is a symmetric function measuring the similarity between nodes in the graph. In our context, it is sufficient to consider the feature space to be the d-dimensional Euclidean space \mathbb{R}^{1\times d} with the Euclidean distance (\ell^{2} norm), i.e., (\bm{\Omega}_{f},d_{f})=(\mathbb{R}^{1\times d},\ell^{2}). With a slight abuse of notation, we denote by A and {\bm{A}} both the measure of structural similarity and the matrix encoding this similarity between nodes in the graph, i.e., {\bm{A}}[i,k]:=A({\bm{a}}_{i},{\bm{a}}_{k}).
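To make these objects concrete, a conformer graph can be stored as the triple (H, A, ω) used throughout this appendix: node features, a symmetric structure matrix, and node weights that sum to one. The following sketch (illustrative names, not our exact implementation) constructs such a triple from per-atom features and coordinates:

import numpy as np

def conformer_to_graph(node_features, coords):
    """Builds G = (H, A, omega) from per-atom features and 3D coordinates."""
    H = np.asarray(node_features, dtype=float)      # (n, d) feature matrix
    diff = coords[:, None, :] - coords[None, :, :]
    A = np.linalg.norm(diff, axis=-1)               # (n, n) pairwise-distance structure matrix
    omega = np.full(len(H), 1.0 / len(H))           # uniform node weights (a probability vector)
    return H, A, omega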

The Wasserstein (W) and Gromov-Wasserstein (GW) distances: Given two structure graphs G_{1}=({\bm{H}}_{1},{\bm{A}}_{1},\bm{\omega}_{1}) and G_{2}=({\bm{H}}_{2},{\bm{A}}_{2},\bm{\omega}_{2}) of order n_{1} and n_{2}, respectively, described by their probability measures \mu_{1}=\sum_{k}\omega_{1k}\delta_{({\bm{x}}_{1k},{\bm{a}}_{1k})} and \mu_{2}=\sum_{l}\omega_{2l}\delta_{({\bm{x}}_{2l},{\bm{a}}_{2l})}, we denote by \mu_{{\bm{H}}_{1}}=\sum_{k}\omega_{1k}\delta_{{\bm{x}}_{1k}} and \mu_{{\bm{A}}_{1}}=\sum_{k}\omega_{1k}\delta_{{\bm{a}}_{1k}} (resp. \mu_{{\bm{H}}_{2}} and \mu_{{\bm{A}}_{2}}) the marginals of \mu_{1} (resp. \mu_{2}) with respect to the features and the structure, respectively. We next consider the following notations:

Jp(𝑨1,𝑨2,𝝅)\displaystyle J_{p}({\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}) =ijklLijkl(𝑨1,𝑨2)p𝝅ij𝝅kl\displaystyle=\sum_{ijkl}L_{ijkl}({\bm{A}}_{1},{\bm{A}}_{2})^{p}\bm{\pi}_{ij}\bm{\pi}_{kl} (16)
GWp(μ𝑯1,μ𝑯2)p\displaystyle\text{GW}_{p}(\mu_{{\bm{H}}_{1}},\mu_{{\bm{H}}_{2}})^{p} =min𝝅𝚷(𝝎1,𝝎2)Jp(𝑨1,𝑨2,𝝅)\displaystyle=\min_{\bm{\pi}\in\bm{\Pi}(\bm{\omega}_{1},\bm{\omega}_{2})}J_{p}({\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}) (17)
Hp(𝑴,𝝅)\displaystyle H_{p}({\bm{M}},\bm{\pi}) =kldf(𝒙1k,𝒙2l)p𝝅kl\displaystyle=\sum_{kl}d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}\bm{\pi}_{kl} (18)
Wp(μ𝑨1,μ𝑨2)p\displaystyle\text{W}_{p}(\mu_{{\bm{A}}_{1}},\mu_{{\bm{A}}_{2}})^{p} =min𝝅𝚷(𝝎1,𝝎2)Hp(𝑴,𝝅).\displaystyle=\min_{\bm{\pi}\in\bm{\Pi}(\bm{\omega}_{1},\bm{\omega}_{2})}H_{p}({\bm{M}},\bm{\pi}). (19)

Note that 𝔼p,α(𝑴,𝑨1,𝑨2,𝝅)\mathbb{E}_{p,\alpha}\left({\bm{M}},{\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}\right) can be further expanded as follows:

𝔼p,α(𝑴,𝑨1,𝑨2,𝝅)\displaystyle\mathbb{E}_{p,\alpha}\left({\bm{M}},{\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}\right) =(1α)𝑴p+α𝑳(𝑨1,𝑨2)p𝝅,𝝅\displaystyle=\langle(1-\alpha){\bm{M}}^{p}+\alpha{\bm{\mathsfit{L}}}({\bm{A}}_{1},{\bm{A}}_{2})^{p}\otimes\bm{\pi},\bm{\pi}\rangle
=ijkl[(1α)df(𝒙1k,𝒙2l)p+α|𝑨1(i,k)𝑨2(j,l)|p]𝝅ij𝝅kl.\displaystyle=\sum_{ijkl}\Big{[}(1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\alpha\left|{\bm{A}}_{1}(i,k)-{\bm{A}}_{2}(j,l)\right|^{p}\big{]}\bm{\pi}_{ij}\bm{\pi}_{kl}.

Comparison between FGW and W: Let \bm{\pi}\in\bm{\Pi}(\bm{\omega}_{1},\bm{\omega}_{2}) be any admissible coupling between \bm{\omega}_{1} and \bm{\omega}_{2}, and assume that \mu_{1} and \mu_{2} belong to the same ground space (\bm{\Omega},{\bm{A}},\mu). By the definition of the FGW distance in equation (3), i.e.,

FGWp,α(G1,G2):=min𝝅Π(𝝎1,𝝎2)(1α)𝑴+α𝑳(𝑨1,𝑨2)𝝅,𝝅,\displaystyle\text{FGW}_{p,\alpha}(G_{1},G_{2}):=\min_{{\bm{\pi}}\in\Pi\left(\bm{\omega}_{1},\bm{\omega}_{2}\right)}\left\langle(1-\alpha){\bm{M}}+\alpha{\bm{\mathsfit{L}}}\left({\bm{A}}_{1},{\bm{A}}_{2}\right)\otimes{\bm{\pi}},{\bm{\pi}}\right\rangle,

we get the following important relationship:

FGWp,α(G1,G2)\displaystyle\text{FGW}_{p,\alpha}(G_{1},G_{2}) (1α)𝑴+α𝑳(𝑨1,𝑨2)𝝅,𝝅\displaystyle\leq\left\langle(1-\alpha){\bm{M}}+\alpha{\bm{\mathsfit{L}}}\left({\bm{A}}_{1},{\bm{A}}_{2}\right)\otimes{\bm{\pi}},{\bm{\pi}}\right\rangle
=ijkl[(1α)df(𝒙1k,𝒙2l)p+α|𝑨[i,k]𝑨[j,l]|p]𝝅ij𝝅kl\displaystyle=\sum_{ijkl}\Big{[}(1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\alpha\left|{\bm{A}}[i,k]-{\bm{A}}[j,l]\right|^{p}\big{]}\bm{\pi}_{ij}\bm{\pi}_{kl}
ijkl[(1α)df(𝒙1k,𝒙2l)p+α|𝑨[i,j]+𝑨[j,k]𝑨[j,k]+𝑨[k,l]|p]𝝅ij𝝅kl\displaystyle\leq\sum_{ijkl}\Big{[}(1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\alpha\left|{\bm{A}}[i,j]+{\bm{A}}[j,k]-{\bm{A}}[j,k]+{\bm{A}}[k,l]\right|^{p}\big{]}\bm{\pi}_{ij}\bm{\pi}_{kl} (20)
=ijkl[(1α)df(𝒙1k,𝒙2l)p+α|𝑨[i,j]+𝑨[k,l]|p]𝝅ij𝝅kl\displaystyle=\sum_{ijkl}\Big{[}(1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\alpha\left|{\bm{A}}[i,j]+{\bm{A}}[k,l]\right|^{p}\big{]}\bm{\pi}_{ij}\bm{\pi}_{kl}
ijkl[(1α)df(𝒙1k,𝒙2l)p+(α2p1𝑨[i,j]p+α2p1𝑨[k,l]p)]𝝅ij𝝅kl\displaystyle\leq\sum_{ijkl}\Big{[}(1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\left(\alpha 2^{p-1}{\bm{A}}[i,j]^{p}+\alpha 2^{p-1}{\bm{A}}[k,l]^{p}\right)\big{]}\bm{\pi}_{ij}\bm{\pi}_{kl} (21)
ijkl[((1α)df(𝒙1k,𝒙2l)p+α2p1𝑨[k,l]p)\displaystyle\leq\sum_{ijkl}\Big{[}\left((1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\alpha 2^{p-1}{\bm{A}}[k,l]^{p}\right)
+((1α)df(𝒙1i,𝒙2j)p+α2p1𝑨[i,j]p)]𝝅ij𝝅kl\displaystyle\hskip 56.9055pt+\left((1-\alpha)d_{f}({\bm{x}}_{1i},{\bm{x}}_{2j})^{p}+\alpha 2^{p-1}{\bm{A}}[i,j]^{p}\right)\big{]}\bm{\pi}_{ij}\bm{\pi}_{kl}
kl[((1α)df(𝒙1k,𝒙2l)p+α2p1𝑨[k,l]p)]𝝅kl\displaystyle\leq\sum_{kl}\Big{[}\left((1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\alpha 2^{p-1}{\bm{A}}[k,l]^{p}\right)\Big{]}\bm{\pi}_{kl}
+i,j[((1α)df(𝒙1i,𝒙2j)p+α2p1𝑨[i,j]p)]𝝅ij\displaystyle\hskip 56.9055pt+\sum_{i,j}\Big{[}\left((1-\alpha)d_{f}({\bm{x}}_{1i},{\bm{x}}_{2j})^{p}+\alpha 2^{p-1}{\bm{A}}[i,j]^{p}\right)\big{]}\bm{\pi}_{ij}
kl[((1α)df(𝒙1k,𝒙2l)p+2p1α𝑨[k,l]p)]𝝅kl\displaystyle\leq\sum_{kl}\Big{[}\left((1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+2^{p-1}\alpha{\bm{A}}[k,l]^{p}\right)\Big{]}\bm{\pi}_{kl}
kl[((1α)df(𝒙1k,𝒙2l)+2p1α𝑨[k,l])]p𝝅kl.\displaystyle\leq\sum_{kl}\Big{[}\left((1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})+2^{p-1}\alpha{\bm{A}}[k,l]\right)\Big{]}^{p}\bm{\pi}_{kl}. (22)

Here, equation (20) is obtained by using the triangle inequality of the metric {\bm{A}}, while equation (21) follows from Lemma B.1. Note that inequality (22) holds for any admissible coupling \bm{\pi}\in\bm{\Pi}(\bm{\omega}_{1},\bm{\omega}_{2}). In particular, it holds for the optimal coupling, denoted by \overline{\bm{\pi}}, of the Wasserstein distance \text{W}_{p}(\mu_{1},\mu_{2}) defined on the metric space (\bm{\Omega},\overline{d}), where \overline{d} is given by:

d¯((𝒙1,𝒂1),(𝒙2,𝒂2))=(1α)df(𝒙1,𝒙2)+2p1α𝑨(𝒂1,𝒂2).\displaystyle\overline{d}(({\bm{x}}_{1},{\bm{a}}_{1}),({\bm{x}}_{2},{\bm{a}}_{2}))=(1-\alpha)d_{f}({\bm{x}}_{1},{\bm{x}}_{2})+2^{p-1}\alpha{\bm{A}}({\bm{a}}_{1},{\bm{a}}_{2}).

Here, we have to verify that d¯\overline{d} is in fact a distance in 𝛀\bm{\Omega}. Indeed, for the triangle inequality, for any (𝒙1,𝒂1),(𝒙2,𝒂2),(𝒙3,𝒂3)𝛀({\bm{x}}_{1},{\bm{a}}_{1}),({\bm{x}}_{2},{\bm{a}}_{2}),({\bm{x}}_{3},{\bm{a}}_{3})\in\bm{\Omega}, we have

\overline{d}(({\bm{x}}_{1},{\bm{a}}_{1}),({\bm{x}}_{2},{\bm{a}}_{2})) = (1-\alpha)d_{f}({\bm{x}}_{1},{\bm{x}}_{2}) + 2^{p-1}\alpha{\bm{A}}({\bm{a}}_{1},{\bm{a}}_{2})
\leq (1-\alpha)\left[d_{f}({\bm{x}}_{1},{\bm{x}}_{3}) + d_{f}({\bm{x}}_{3},{\bm{x}}_{2})\right] + 2^{p-1}\alpha\left[{\bm{A}}({\bm{a}}_{1},{\bm{a}}_{3}) + {\bm{A}}({\bm{a}}_{3},{\bm{a}}_{2})\right]
= \left[(1-\alpha)d_{f}({\bm{x}}_{1},{\bm{x}}_{3}) + 2^{p-1}\alpha{\bm{A}}({\bm{a}}_{1},{\bm{a}}_{3})\right] + \left[(1-\alpha)d_{f}({\bm{x}}_{3},{\bm{x}}_{2}) + 2^{p-1}\alpha{\bm{A}}({\bm{a}}_{3},{\bm{a}}_{2})\right]
= \overline{d}(({\bm{x}}_{1},{\bm{a}}_{1}),({\bm{x}}_{3},{\bm{a}}_{3})) + \overline{d}(({\bm{x}}_{3},{\bm{a}}_{3}),({\bm{x}}_{2},{\bm{a}}_{2})).

Here, the above inequality follows from the triangle inequalities of d_{f} and {\bm{A}}. The symmetry and identity properties of \overline{d} follow from the corresponding properties of d_{f} and {\bm{A}}.

By definition of Wasserstein distance in equation (19), this implies that

FGWp,α(G1,G2)Wp(μ𝑨1,μ𝑨2).\displaystyle\text{FGW}_{p,\alpha}(G_{1},G_{2})\leq\text{W}_{p}(\mu_{{\bm{A}}_{1}},\mu_{{\bm{A}}_{2}}). (23)
Lemma B.1.

For any p\in\mathbb{N} and any a,b\geq 0, we have

(a+b)^{p}\leq 2^{p-1}(a^{p}+b^{p}). (24)
Proof of Lemma B.1.

The inequality holds with equality for p=1. For any integer p>1, the function t\mapsto t^{p} is convex on \mathbb{R}_{+}, so by Jensen's inequality

\left(\frac{x+y}{2}\right)^{p}\leq\frac{x^{p}+y^{p}}{2},

which is equivalent to (x+y)^{p}\leq 2^{p-1}(x^{p}+y^{p}). ∎
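As a quick numerical sanity check of Lemma B.1 (purely illustrative), the inequality can be verified on random non-negative inputs:

import numpy as np

rng = np.random.default_rng(0)
for p in range(1, 6):
    x, y = rng.uniform(0.0, 10.0, size=(2, 1000))
    # (x + y)^p <= 2^(p-1) (x^p + y^p) should hold for all non-negative samples.
    assert np.all((x + y) ** p <= 2.0 ** (p - 1) * (x ** p + y ** p) + 1e-9)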

Recall that we have

\overline{\mu}_{K}\in\operatorname*{arg\,min}_{\mu\in\mathcal{P}_{p}(\bm{\Omega})}\frac{1}{K}\sum_{k}\text{FGW}_{p,\alpha}^{p}(\mu,\mu_{k}),
\overline{\mu}_{0}\in\operatorname*{arg\,min}_{\mu\in\mathcal{P}_{p}(\bm{\Omega})}\int_{\mathcal{P}_{p}(\bm{\Omega})}\text{FGW}_{p,\alpha}^{p}(\mu,\nu)\,dP(\nu).

Therefore, μ¯K\overline{\mu}_{K} and μ¯0\overline{\mu}_{0} belong to the same ground space (𝛀,𝑨,μ).(\bm{\Omega},{\bm{A}},\mu). By using equation (23), this implies that

FGWp,α(μ¯0,μ¯K)2Wp(μ¯0,μ¯K)p\displaystyle\text{FGW}_{p,\alpha}(\overline{\mu}_{0},\overline{\mu}_{K})\leq 2\text{W}_{p}(\overline{\mu}_{0},\overline{\mu}_{K})^{p} (25)

and hence

\mathbb{E}\left(\text{FGW}_{2,\alpha}(\overline{\mu}_{0},\overline{\mu}_{K})\right)\leq\mathbb{E}\left(\text{W}_{2}^{2}(\overline{\mu}_{0},\overline{\mu}_{K})\right)\leq\frac{4\sigma^{2}_{P}}{(1-\beta+\gamma)^{2}K}. (26)

This is equivalent to the following

𝔼(FGW2,α2(G¯0,G¯K))4σP2(1β+γ)2K.\displaystyle\mathbb{E}\left(\text{FGW}_{2,\alpha}^{2}(\overline{G}_{0},\overline{G}_{K})\right)\leq\frac{4\sigma^{2}_{P}}{(1-\beta+\gamma)^{2}K}. (27)

Here, Lemma B.3 yields the last inequality for the Wasserstein distance \text{W}_{p}(\mu,\nu) on the metric space (\bm{\Omega},\overline{d}).

We recall the following definitions and results.

Definition B.2 (Strongly convex and smooth functions).

Given a separable Hilbert space S, with inner product \langle\cdot,\cdot\rangle and norm |\cdot|, we define the subdifferential \partial\psi\subset S^{2} of a function \psi:S\rightarrow\mathbb{R} by \partial\psi=\left\{(x,g):\forall y\in S,\psi(y)\geq\psi(x)+\langle g,y-x\rangle\right\} and denote \partial\psi(x)=\left\{g\in S:(x,g)\in\partial\psi\right\}. We then call \psi \gamma-strongly convex if for every x\in S it holds that

\partial\psi(x)\neq\emptyset,\text{ and }\langle g,x-y\rangle\geq\psi(x)-\psi(y)+\frac{\gamma}{2}\left|x-y\right|^{2}\text{ for all }g\in\partial\psi(x)\text{ and all }y\in S. (28)

We also recall that a convex function ψ:S\psi:S\rightarrow\mathbb{R} is called β\beta-smooth if

gx,xyψ(x)ψ(y)+β2|xy|2,gxψ(x),x,yS.\displaystyle\langle g_{x},x-y\rangle\leq\psi(x)-\psi(y)+\frac{\beta}{2}\left|x-y\right|^{2},~{}\forall g_{x}\in\partial\psi(x),~{}\forall x,y\in S. (29)
Lemma B.3 (Corollary 4.4 from (Le Gouic et al., 2022)).

Let P\in\mathcal{P}_{2}(\mathcal{P}_{2}(\bm{\Omega})) be a probability measure on the 2-Wasserstein space \text{W}_{2} over the metric space (\bm{\Omega},\overline{d}), and let \overline{\mu}_{0}\in\mathcal{P}_{2}(\bm{\Omega}) and \sigma^{2}_{P} be a barycenter and the variance functional of P, respectively. Let \gamma,\beta>0 and suppose that every \mu\in\operatorname*{supp}(P) is the pushforward of \overline{\mu}_{0} by the gradient of a \gamma-strongly convex and \beta-smooth function \psi_{\overline{\mu}_{0}\rightarrow\mu}, as defined in Definition B.2, i.e., \mu=(\nabla\psi_{\overline{\mu}_{0}\rightarrow\mu})_{\#}\overline{\mu}_{0}. If \beta-\gamma<1, then \overline{\mu}_{0} is unique and any empirical barycenter \overline{\mu}_{K} of P satisfies

𝔼(W22(μ¯0,μ¯K))4σP2(1β+γ)2K.\displaystyle\mathbb{E}\left(\text{W}_{2}^{2}(\overline{\mu}_{0},\overline{\mu}_{K})\right)\leq\frac{4\sigma^{2}_{P}}{(1-\beta+\gamma)^{2}K}. (30)

We then obtain the following important identity

𝔼p,α(𝑴,𝑨1,𝑨2,𝝅)\displaystyle\mathbb{E}_{p,\alpha}\left({\bm{M}},{\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}\right) :=ijkl[(1α)df(𝒙1k,𝒙2l)p+α|𝑨1(i,k)𝑨2(j,l)|p]𝝅ij𝝅kl\displaystyle:=\sum_{ijkl}\Big{[}(1-\alpha)d_{f}({\bm{x}}_{1k},{\bm{x}}_{2l})^{p}+\alpha\left|{\bm{A}}_{1}(i,k)-{\bm{A}}_{2}(j,l)\right|^{p}\big{]}\bm{\pi}_{ij}\bm{\pi}_{kl}
=(1α)Hp(𝑴,𝝅)+αJp(𝑨1,𝑨2,𝝅).\displaystyle=(1-\alpha)H_{p}({\bm{M}},\bm{\pi})+\alpha J_{p}({\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}). (31)

Furthermore, given 𝝅α\bm{\pi}_{\alpha} as the coupling that minimizes 𝔼p,α(𝑴,𝑨1,𝑨2,)\mathbb{E}_{p,\alpha}\left({\bm{M}},{\bm{A}}_{1},{\bm{A}}_{2},\cdot\right), it holds that

FGWp,αp(μ1,μ2)\displaystyle\text{FGW}_{p,\alpha}^{p}(\mu_{1},\mu_{2}) =min𝝅𝚷(𝝎1,𝝎2)𝔼p,α(𝑴,𝑨1,𝑨2,𝝅)\displaystyle=\min_{\bm{\pi}\in\bm{\Pi}(\bm{\omega}_{1},\bm{\omega}_{2})}\mathbb{E}_{p,\alpha}\left({\bm{M}},{\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}\right)
=𝔼p,α(𝑴,𝑨1,𝑨2,𝝅α)\displaystyle=\mathbb{E}_{p,\alpha}\left({\bm{M}},{\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}_{\alpha}\right)
=(1α)Hp(𝑴,𝝅α)+αJp(𝑨1,𝑨2,𝝅α)\displaystyle=(1-\alpha)H_{p}({\bm{M}},\bm{\pi}_{\alpha})+\alpha J_{p}({\bm{A}}_{1},{\bm{A}}_{2},\bm{\pi}_{\alpha})
(1α)Wpp(μ𝑨1,μ𝑨2)+αGWpp(μ𝑯1,μ𝑯2).\displaystyle\geq(1-\alpha)\text{W}_{p}^{p}(\mu_{{\bm{A}}_{1}},\mu_{{\bm{A}}_{2}})+\alpha\text{GW}_{p}^{p}(\mu_{{\bm{H}}_{1}},\mu_{{\bm{H}}_{2}}). (32)

This results in the following by-product:

𝔼(GW22(μ¯0,𝑯1,μ¯K,𝑯2))4σP2α(1β+γ)2K,\displaystyle\mathbb{E}\left(\text{GW}_{2}^{2}(\overline{\mu}_{0,{\bm{H}}_{1}},\overline{\mu}_{K,{\bm{H}}_{2}})\right)\leq\frac{4\sigma^{2}_{P}}{\alpha(1-\beta+\gamma)^{2}K},
𝔼(W22(μ¯0,𝑨1,μ¯K,𝑨2))4σP2(1α)(1β+γ)2K.\displaystyle\mathbb{E}\left({\text{W}}^{2}_{2}(\overline{\mu}_{0,{\bm{A}}_{1}},\overline{\mu}_{K,{\bm{A}}_{2}})\right)\leq\frac{4\sigma^{2}_{P}}{(1-\alpha)(1-\beta+\gamma)^{2}K}. (33)

Appendix C Solving Entropic Fused Gromov-Wasserstein

C.1 Optimization Formulation

Entropic regularization (Cuturi, 2013) has been well studied in various OT formulations, including the entropic Wasserstein (Peyré et al., 2019; Peyré, 2015) and entropic Gromov-Wasserstein (Rioux et al., 2023; Le et al., 2022) distances, for the fast computation of numerous barycenter problems (Cuturi & Doucet, 2014; Peyré et al., 2016; Xu et al., 2019b; Lin et al., 2020). However, adapting the entropic formulation to the FGW barycenter problem for learning molecular representations is, to the best of our knowledge, novel. Our motivation is to solve for the FGW barycenter subgradients with Sinkhorn projections, which can be straightforwardly vectorized, differentiated in reverse mode, and batch-distributed across multiple GPUs, benefiting the scaling of the learning pipeline to large molecular datasets.

Recall that FGW between two graphs G1,G2G_{1},G_{2} can be described as

FGW(G1,G2)FGW2,α(G1,G2):=min𝝅Π(𝝎1,𝝎2)(1α)𝑴+α𝑳(𝑨1,𝑨2)𝝅,𝝅,\operatorname{FGW}\left(G_{1},G_{2}\right)\equiv\text{FGW}_{2,\alpha}\left(G_{1},G_{2}\right):=\min_{{\bm{\pi}}\in\Pi\left(\bm{\omega}_{1},\bm{\omega}_{2}\right)}\left\langle(1-\alpha){\bm{M}}+\alpha{\bm{\mathsfit{L}}}\left({\bm{A}}_{1},{\bm{A}}_{2}\right)\otimes{\bm{\pi}},{\bm{\pi}}\right\rangle, (34)

where {\bm{M}}:=\left(d_{f}({\bm{H}}_{1}[i],{\bm{H}}_{2}[j])\right)_{n_{1}\times n_{2}}\in{\mathbb{R}}^{n_{1}\times n_{2}} is the pairwise node distance matrix and {\bm{\mathsfit{L}}}\left({\bm{A}}_{1},{\bm{A}}_{2}\right):=\{L({\bm{A}}_{1}[i,j],{\bm{A}}_{2}[k,l])\}_{ijkl} the 4-tensor of structure distances. Assuming the loss has the form L(a,b)=f_{1}(a)+f_{2}(b)-h_{1}(a)h_{2}(b), Proposition 1 of (Peyré et al., 2016) allows us to write the second term in Equation 34 as

{\bm{\mathsfit{L}}}\left({\bm{A}}_{1},{\bm{A}}_{2}\right)\otimes{\bm{\pi}} := {\bm{L}}-2h_{1}({\bm{A}}_{1}){\bm{\pi}}h_{2}({\bm{A}}_{2})^{\top},\qquad {\bm{L}} := f_{1}({\bm{A}}_{1})\bm{\omega}_{1}\bm{1}_{n_{2}}^{\top}+\bm{1}_{n_{1}}\bm{\omega}_{2}^{\top}f_{2}({\bm{A}}_{2})^{\top}, (35)

where the squared loss L=\operatorname{L_{2}} has the element-wise functions f_{1}(a)=a^{2}, f_{2}(b)=b^{2}, h_{1}(a)=a, h_{2}(b)=2b, and the KL loss L=\operatorname{KL} has f_{1}(a)=a\log{a}-a, f_{2}(b)=b, h_{1}(a)=a, h_{2}(b)=\log{b}. By definition, the entropic FGW distance adds an entropic term:

FGWϵ(G1,G2):=min𝝅Π(𝝎1,𝝎2)(1α)𝑴+α𝑳(𝑨1,𝑨2)𝝅,𝝅ϵH(𝝅),\operatorname{FGW}_{\epsilon}\left(G_{1},G_{2}\right):=\min_{{\bm{\pi}}\in\Pi\left(\bm{\omega}_{1},\bm{\omega}_{2}\right)}\left\langle(1-\alpha){\bm{M}}+\alpha{\bm{\mathsfit{L}}}\left({\bm{A}}_{1},{\bm{A}}_{2}\right)\otimes{\bm{\pi}},{\bm{\pi}}\right\rangle-\epsilon\operatorname{H}({\bm{\pi}}), (36)

which is a non-convex optimization problem. Following Proposition 2 (Peyré et al., 2016), the update rule solving Equation 36 is the solution of the entropic OT

𝝅=argmin𝝅Π(𝝎1,𝝎2)(1α)𝑴+𝑳2h1(𝑨1)𝝅h2(𝑨2),𝝅ϵH(𝝅),{\bm{\pi}}=\operatorname*{arg~{}min}_{{\bm{\pi}}\in\Pi\left(\bm{\omega}_{1},\bm{\omega}_{2}\right)}\left\langle(1-\alpha){\bm{M}}+{\bm{L}}-2h_{1}({\bm{A}}_{1}){\bm{\pi}}h_{2}({\bm{A}}_{2})^{\top},{\bm{\pi}}\right\rangle-\epsilon\operatorname{H}({\bm{\pi}}), (37)

where the feature and structure matrices 𝑴,𝑳{\bm{M}},\,{\bm{L}} can be precomputed. Since the cost matrix of Equation 37 depends on 𝝅{\bm{\pi}}, solving Equation 36 involves iterations of solving the linear entropic OT problem Equation 37 with Sinkhorn projections, as shown in Algorithm 2.
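A PyTorch sketch of one inner step of Algorithm 2 — computing the linearized cost matrix of Equation 37 for the squared loss, using the decomposition in Equation 35 — is given below. The tensor names mirror the text, but this is only an illustration and not our exact implementation:

import torch

def fgw_cost_matrix(M, A1, A2, w1, w2, pi, alpha):
    """Linearized FGW cost (squared loss): C = (1 - alpha) * M + 2 * alpha * (L - h1(A1) pi h2(A2)^T).

    M  : (n1, n2) pairwise feature distances; A1 : (n1, n1), A2 : (n2, n2) structure matrices;
    w1 : (n1,), w2 : (n2,) node weights; pi : (n1, n2) current coupling.
    """
    # Constant part L = f1(A1) w1 1^T + 1 w2^T f2(A2)^T with f1(a) = a^2, f2(b) = b^2.
    L = ((A1 ** 2) @ w1)[:, None] + (w2 @ (A2 ** 2).T)[None, :]
    # Coupling-dependent part with h1(a) = a, h2(b) = 2 b.
    coupling_term = 2.0 * (A1 @ pi @ A2.T)
    return (1.0 - alpha) * M + 2.0 * alpha * (L - coupling_term)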

Following Proposition 4.1 in  (Peyré et al., 2019), for sufficiently small regularization ϵ\epsilon, the approximate solution from the entropic OT problem

OTϵ(𝝎1,𝝎2)=min𝝅Π(𝝎1,𝝎2)𝑪,𝝅ϵH(𝝅)\operatorname{OT}_{\epsilon}(\bm{\omega}_{1},\bm{\omega}_{2})=\min_{{\bm{\pi}}\in\Pi\left(\bm{\omega}_{1},\bm{\omega}_{2}\right)}\left\langle{\bm{C}},{\bm{\pi}}\right\rangle-\epsilon\operatorname{H}({\bm{\pi}})

approaches that of the original OT problem. However, a small \epsilon incurs serious numerical instability for high-dimensional cost matrices, e.g., when comparing large graphs. In the context of the barycenter problem, a large \epsilon is cheap to compute but leads to a “blurry” barycenter solution, while a smaller \epsilon produces better accuracy but suffers from both numerical instability and high computational cost (Schmitzer, 2019; Feydy et al., 2019). Thus, we solve the dual entropic OT problem (Peyré et al., 2019)

OTϵ(𝝎1,𝝎2)= def. max𝒇,𝒈𝝎1,𝒇+𝝎2,𝒈ε𝝎1𝝎2,exp(1ε(𝒇𝒈𝑪))1,\operatorname{OT}_{\epsilon}(\bm{\omega}_{1},\bm{\omega}_{2})\stackrel{{\scriptstyle\text{ def. }}}{{=}}\max_{{\bm{f}},{\bm{g}}}\langle\bm{\omega}_{1},{\bm{f}}\rangle+\langle\bm{\omega}_{2},{\bm{g}}\rangle-\varepsilon\left\langle\bm{\omega}_{1}\otimes\bm{\omega}_{2},\exp\left(\frac{1}{\varepsilon}({\bm{f}}\oplus{\bm{g}}-{\bm{C}})\right)-1\right\rangle, (38)

where 𝒇n1,𝒈n2{\bm{f}}\in{\mathbb{R}}^{n_{1}},{\bm{g}}\in{\mathbb{R}}^{n_{2}} are the potential vectors and \oplus is the tensor plus, with stabilized log-sum-exp (LSE) operators (Feydy et al., 2019) for i[1,n1],j[1,n2]\forall i\in[1,n_{1}],\,\forall j\in[1,n_{2}]

𝒇[i]=εLSEk=1n2(log(𝝎2[k])+1ε𝒈[k]1ε𝑪[i,k])𝒈[j]=εLSEk=1n1(log(𝝎1[k])+1ε𝒇[k]1ε𝑪[k,j]) where LSEk=1n(𝒙[k])=logk=1nexp(𝒙[k])\begin{gathered}{\bm{f}}[i]=-\varepsilon\operatorname{LSE}_{k=1}^{n_{2}}\left(\log\left(\bm{\omega}_{2}[k]\right)+\frac{1}{\varepsilon}{\bm{g}}[k]-\frac{1}{\varepsilon}{\bm{C}}[i,k]\right)\\ {\bm{g}}[j]=-\varepsilon\operatorname{LSE}_{k=1}^{n_{1}}\left(\log\left(\bm{\omega}_{1}[k]\right)+\frac{1}{\varepsilon}{\bm{f}}[k]-\frac{1}{\varepsilon}{\bm{C}}[k,j]\right)\\ \text{ where }\quad\operatorname{LSE}_{k=1}^{n}\left({\bm{x}}[k]\right)=\log\sum_{k=1}^{n}\exp\left({\bm{x}}[k]\right)\end{gathered} (39)

for numerical stability on high-dimensional problems. In practice, we implement these LSE updates using einsum operations; a simplified sketch is given below.
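The following PyTorch sketch implements the stabilized updates of Equation 39 and recovers the coupling of Algorithm 3; it uses logsumexp rather than einsum for readability and is only an illustration of our implementation:

import torch

def sinkhorn_lse(C, w1, w2, eps=0.1, n_iters=50):
    """Stabilized dual Sinkhorn: returns the potentials (f, g) and the coupling pi.

    C : (n1, n2) cost matrix; w1 : (n1,), w2 : (n2,) marginals; eps : entropic scalar.
    """
    log_w1, log_w2 = w1.log(), w2.log()
    f = torch.zeros_like(w1)
    g = torch.zeros_like(w2)
    for _ in range(n_iters):
        # f[i] = -eps * LSE_k( log w2[k] + (g[k] - C[i, k]) / eps )
        f = -eps * torch.logsumexp(log_w2[None, :] + (g[None, :] - C) / eps, dim=1)
        # g[j] = -eps * LSE_k( log w1[k] + (f[k] - C[k, j]) / eps )
        g = -eps * torch.logsumexp(log_w1[:, None] + (f[:, None] - C) / eps, dim=0)
    # Optimal coupling pi = exp((f ⊕ g - C) / eps) * (w1 ⊗ w2).
    pi = torch.exp((f[:, None] + g[None, :] - C) / eps) * (w1[:, None] * w2[None, :])
    return f, g, pi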

The optimal coupling of the dual entropic OT problem can be computed, once the potential vectors have converged, as

𝝅=exp(1ε(𝒇𝒈𝑪))(𝝎1𝝎2).{\bm{\pi}}^{*}=\exp\left(\frac{1}{\varepsilon}({\bm{f}}^{*}\oplus{\bm{g}}^{*}-{\bm{C}})\right)\cdot(\bm{\omega}_{1}\otimes\bm{\omega}_{2}).

We state the Sinkhorn algorithm for solving the dual entropic OT problem in Algorithm 3. With Algorithm 3, the auto-differentiation gradient is robust to small perturbations of the potential solutions {\bm{f}}^{*},{\bm{g}}^{*}. We observe that \epsilon\in[0.1,0.2] and a few Sinkhorn LSE iterations are sufficient in our setting.

C.2 Empirical Entropic FGW Barycenter

In our experiments, we propose to solve the entropic relaxation of Equation 6 to utilize GPU-accelerated Sinkhorn iterations (Peyré et al., 2019). Given a set of conformer graphs \{G_{s}:=({\bm{H}}_{s},{\bm{A}}_{s},\bm{\omega}_{s})\}_{s=1}^{K}, we optimize the entropic barycenter of Equation 13, where we fix the prior on nodes \overline{\bm{\omega}}. Titouan et al. (2019) solve Equation 13 using Block Coordinate Descent, as shown in Algorithm 1, which iteratively minimizes the original FGW distance between the current barycenter and the graphs G_{s}. In our case, we solve for K couplings of entropic FGW distances to the empirical graphs at each iteration (i.e., \lambda_{s}=1/K) and then follow the update rule for the structure matrix (Proposition 4, Peyré et al. (2016))

𝑨¯(k+1)1𝝎¯𝝎¯s=1Kλs𝝅s(k)𝑨s𝝅s(k),if L:=L2𝑨¯(k+1)exp(1𝝎¯𝝎¯s=1Kλs𝝅s(k)𝑨s𝝅s(k)),if L:=KL,\displaystyle\begin{split}\overline{{\bm{A}}}^{(k+1)}&\leftarrow\frac{1}{\overline{\bm{\omega}}~{}\overline{\bm{\omega}}^{\top}}\sum_{s=1}^{K}\lambda_{s}{{\bm{\pi}}_{s}^{(k)}}{\bm{A}}_{s}{{\bm{\pi}}_{s}^{(k)}}^{\top},\,\textrm{if }L:=L_{2}\\ \overline{{\bm{A}}}^{(k+1)}&\leftarrow\exp\left(\frac{1}{\overline{\bm{\omega}}~{}\overline{\bm{\omega}}^{\top}}\sum_{s=1}^{K}\lambda_{s}{{\bm{\pi}}_{s}^{(k)}}{\bm{A}}_{s}{{\bm{\pi}}_{s}^{(k)}}^{\top}\right),\,\textrm{if }L:=\operatorname{KL},\end{split} (40)

and for the feature matrix (Titouan et al., 2019; Cuturi & Doucet, 2014)

𝑯¯(k+1)diag(1/𝝎¯)s=1Kλs𝝅s(k)𝑯s,\overline{{\bm{H}}}^{(k+1)}\leftarrow\mathrm{diag}(1/\overline{\bm{\omega}})\sum_{s=1}^{K}\lambda_{s}{{\bm{\pi}}_{s}^{(k)}}{\bm{H}}_{s}, (41)

leading to Algorithm 1. Note that, for clarity, Algorithm 1 presents only the structure-matrix update rule for the squared loss L=L_{2}; the update rule can be modified according to the loss type L. In our experiments, we found that the algorithm usually converges after about 10 outer iterations and 30 inner iterations. A sketch of one barycenter update is given below.
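The following sketch shows one outer iteration of the barycenter updates (Equations 40 and 41, squared loss, \lambda_{s}=1/K), assuming the K couplings have already been obtained with the entropic FGW solver sketched above (illustrative code, not the exact implementation):

import torch

def barycenter_update(pis, As, Hs, w_bar):
    """One update of the barycenter structure and feature matrices (squared loss, uniform lambda).

    pis : list of K couplings (n_bar, n_s); As : list of K structure matrices (n_s, n_s);
    Hs  : list of K feature matrices (n_s, d); w_bar : (n_bar,) fixed barycenter node weights.
    """
    K = len(pis)
    # Structure update: A_bar = (1 / (w_bar w_bar^T)) * sum_s lambda_s * pi_s A_s pi_s^T.
    A_bar = sum(pi @ A @ pi.T for pi, A in zip(pis, As)) / (K * torch.outer(w_bar, w_bar))
    # Feature update: H_bar = diag(1 / w_bar) * sum_s lambda_s * pi_s H_s.
    H_bar = sum(pi @ H for pi, H in zip(pis, Hs)) / (K * w_bar[:, None])
    return A_bar, H_bar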

Algorithm 2 Entropic FGW with Sinkhorn projections
  Input: Graph G1,G2G_{1},G_{2}, weighting α\alpha, entropic scalar ϵ\epsilon.
  Optimizing: 𝝅Π(𝝎1,𝝎2){\bm{\pi}}\in\Pi(\bm{\omega}_{1},\bm{\omega}_{2}).
  Compute {\bm{L}}:=f_{1}({\bm{A}}_{1})\bm{\omega}_{1}\bm{1}_{n_{2}}^{\top}+\bm{1}_{n_{1}}\bm{\omega}_{2}^{\top}f_{2}({\bm{A}}_{2})^{\top}.
  Compute 𝑴=(d(𝑯1[i],𝑯2[j]))n1×n2{\bm{M}}=\left(d({\bm{H}}_{1}[i],{\bm{H}}_{2}[j])\right)_{n_{1}\times n_{2}}.
  Initialize 𝝅{\bm{\pi}}.
  repeat
     Compute 𝑪(k)=(1α)𝑴+2α(𝑳h1(𝑨1)𝝅(k)h2(𝑨2)){\bm{C}}^{(k)}=(1-\alpha){\bm{M}}+2\alpha({\bm{L}}-h_{1}({\bm{A}}_{1}){\bm{\pi}}^{(k)}h_{2}({\bm{A}}_{2})^{\top}).
     Solve {\bm{\pi}}^{(k+1)}=\operatorname*{arg\,min}_{{\bm{\pi}}}\left\langle{\bm{C}}^{(k)},{\bm{\pi}}\right\rangle-\epsilon\operatorname{H}({\bm{\pi}}) with Algorithm 3.
  until the maximum number of inner iterations is reached or {\bm{\pi}} has converged
Algorithm 3 Stabilized LSE Sinkhorn algorithm
  Input: Entropic scalar ϵ\epsilon, cost matrix 𝑪{\bm{C}}, marginals 𝝎1,𝝎2\bm{\omega}_{1},\bm{\omega}_{2}.
  Initialize 𝒇,𝒈=𝟎{\bm{f}},{\bm{g}}=\bm{0}.
  while termination criteria not met do
     for i[1,n]\forall i\in[1,n] do
        𝒇[i]=εLSEk=1m(log(𝝎2[k])+1ε𝒈[k]1ε𝑪[i,k]){\bm{f}}[i]=-\varepsilon\operatorname{LSE}_{k=1}^{m}\left(\log\left(\bm{\omega}_{2}[k]\right)+\frac{1}{\varepsilon}{\bm{g}}[k]-\frac{1}{\varepsilon}{\bm{C}}[i,k]\right).
     end for
     for j[1,m]\forall j\in[1,m] do
        𝒈[j]=εLSEk=1n(log(𝝎1[k])+1ε𝒇[k]1ε𝑪[k,j]){\bm{g}}[j]=-\varepsilon\operatorname{LSE}_{k=1}^{n}\left(\log\left(\bm{\omega}_{1}[k]\right)+\frac{1}{\varepsilon}{\bm{f}}[k]-\frac{1}{\varepsilon}{\bm{C}}[k,j]\right).
     end for
  end while
  Return 𝝅=exp(1ε(𝒇𝒈𝑪))(𝝎1𝝎2){\bm{\pi}}^{*}=\exp\left(\frac{1}{\varepsilon}({\bm{f}}^{*}\oplus{\bm{g}}^{*}-{\bm{C}})\right)\cdot(\bm{\omega}_{1}\otimes\bm{\omega}_{2}).

Practical GPU considerations.

Our motivation for adopting the entropic formulation of the FGW barycenter is to solve the barycenter problem efficiently with (stabilized LSE) Sinkhorn projections, which can be straightforwardly vectorized in PyTorch, facilitating end-to-end unsupervised training on GPUs (Cuturi, 2013; Cuturi & Doucet, 2014; Peyré et al., 2019). The entropic formulation avoids the Conditional Gradient solver for FGW (Titouan et al., 2019), which relies on classical network-flow algorithms at each iteration; such algorithms are usually only available in off-the-shelf C++ backend libraries, over which it is difficult to construct an auto-differentiation computation graph. Furthermore, by implementing Algorithm 1 in PyTorch (Paszke et al., 2017), we utilize reverse-mode automatic differentiation over the solver iterations to propagate gradients from the graph parameters to the barycenter solutions. We observe that the inner entropic OT problem usually converges within a few iterations; thus, we typically limit the number of Sinkhorn iterations for the entropic OT problem to reduce the memory burden (Peyré et al., 2019).
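As an illustration of this reverse-mode differentiation (a toy sketch, reusing the illustrative sinkhorn_lse helper from Appendix C.1 and not our exact training code), gradients can be propagated through the Sinkhorn iterations as follows:

import torch

# Toy problem: differentiate a transport cost w.r.t. learnable node features.
n1, n2, d = 8, 6, 16
H1 = torch.randn(n1, d, requires_grad=True)   # e.g., the output of a 2D/3D encoder
H2 = torch.randn(n2, d)
w1 = torch.full((n1,), 1.0 / n1)
w2 = torch.full((n2,), 1.0 / n2)

C = torch.cdist(H1, H2) ** 2                  # differentiable cost matrix
f, g, pi = sinkhorn_lse(C, w1, w2, eps=0.1, n_iters=30)
loss = (pi * C).sum()                         # transport cost under the entropic plan
loss.backward()                               # gradients flow through all Sinkhorn iterations
print(H1.grad.shape)                          # torch.Size([8, 16])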

Scalability and complexity.

As shown in Algorithm 1, we have three nested loops when optimizing the FGW barycenter. However, the inner entropic OT problem typically converges within a few stabilized LSE Sinkhorn iterations. Thus, we fix a constant number of Sinkhorn iterations and denote the maximum numbers of outer (Algorithm 1) and inner (Algorithm 2) iterations by M and N, respectively. In Algorithm 2, the complexity of computing {\bm{C}} is {\mathcal{O}}(n^{3}+n^{2}d) with n:=\max(\{n_{s}\}_{s=1}^{K}); the first term is the cost of the structure term, while the second is that of the feature term. The overall complexity of Algorithm 1 is therefore {\mathcal{O}}(MKN(n^{3}+n^{2}d)), including the feature- and structure-matrix updates. Note that solving the entropic FGW problems for the K graphs can be done in parallel on the GPU. Additionally, this complexity does not depend on the maximum number of edges e:=\max(\{\|E_{s}\|\}_{s=1}^{K}), making each outer iteration very competitive compared with previous graph-matching methods (Neyshabur et al., 2013) when e\gg n.

Appendix D Experiment Configuration Supplements

D.1 SchNet Neural Architecture

We represent each of the KK molecular conformers as a set of atoms VV with atom numbers Z=(Z1,,Zn)Z=(Z_{1},...,Z_{n}) and atomic positions R=(𝒓1,..,𝒓n)R=({\bm{r}}_{1},..,{\bm{r}}_{n}). At each layer \ell an atom vv is represented by a learnable representation 𝐡v\mathbf{h}_{v}. We use the geometric message and aggregation functions of SchNet Schütt et al. (2017) but any other E(3)E(3)-invariant neural network can be used instead. Besides providing a good trade-off between model complexity and efficacy, we choose SchNet as it was used in prior related work (Axelrod & Gómez-Bombarelli, 2023).

SchNet relies on the following building blocks. The initial node attributes are learnable embeddings of the atom types, that is, \mathbf{h}_{v}^{(0)}\in\mathbb{R}^{d} is an embedding of the atom type of node v with d dimensions. Two types of combinations of atom-wise linear layers and activation functions are used:

φi()(𝐡)\displaystyle\varphi_{i}^{(\ell)}\left(\mathbf{h}\right) :=𝐖i(l)𝐡+𝐛i(l) and \displaystyle:=\mathbf{W}_{i}^{(l)}\mathbf{h}+{\bf b}_{i}^{(l)}\ \ \ \ \mbox{ and }\ \ \ \ \
ϕi,j()(𝐡)\displaystyle\phi_{i,j}^{(\ell)}\left(\mathbf{h}\right) :=φj(𝚜𝚜𝚙(φi()(𝐡)))\displaystyle:=\varphi_{j}^{\ell}\left(\mathtt{ssp}\left(\varphi_{i}^{(\ell)}\left(\mathbf{h}\right)\right)\right) (42)

where \mathtt{ssp} is the shifted softplus activation function, \mathbf{W}_{i}^{(l)}\in\mathbb{R}^{d\times d}, {\bf b}_{i}^{(l)}\in\mathbb{R}^{d}, with d the hidden dimension of the atom embeddings. A filter-generating network serves as a rotationally invariant function \mathtt{Inv}:

𝐞v,u\displaystyle\mathbf{e}_{v,u} =𝙸𝚗𝚟(\vv𝐯v(1),\vv𝐯u(1))=ϕ1,2()(𝚁𝙱𝙵(𝐫v𝐫u)),\displaystyle=\mathtt{Inv}\left(\vv{\mathbf{v}}_{v}^{(\ell-1)},\vv{\mathbf{v}}_{u}^{{(\ell-1)}}\right)=\phi_{1,2}^{(\ell)}(\mathtt{RBF}(||\mathbf{r}_{v}-\mathbf{r}_{u}||)),

where 𝚁𝙱𝙵\mathtt{RBF} is the radial basis function and ϕ1,2()\phi_{1,2}^{(\ell)} is a sequence of two dense layers with shifted softplus activation.

E(3)E(3)-invariant message-passing is performed by using the following message function

𝐦v,u()=𝐌()(𝐡v(1),𝐡u(1),𝐞v,u)=φ1()(𝐡u(l1))𝒆v,u,\displaystyle\mathbf{m}_{v,u}^{(\ell)}=\mathbf{M}^{(\ell)}\Bigl{(}\mathbf{h}_{v}^{(\ell-1)},\mathbf{h}_{u}^{(\ell-1)},\mathbf{e}_{v,u}\Bigr{)}=\varphi_{1}^{(\ell)}\left(\mathbf{h}^{(l-1)}_{u}\right)\circ{\bm{e}}_{v,u},

where \circ represents the element-wise multiplication. The aggregation function is now defined as

𝐡¯v()\displaystyle\bar{\mathbf{h}}_{v}^{(\ell)} :=𝖠𝖦𝖦()({{𝐦v,u()uN(v)}})=uN(v)𝐦v,u().\displaystyle:=\mathsf{AGG}^{(\ell)}\bigl{(}\{\!\!\{\mathbf{m}_{v,u}^{(\ell)}\mid u\in N(v)\}\!\!\}\bigr{)}=\sum_{u\in N(v)}\mathbf{m}_{v,u}^{(\ell)}.

Finally, the update function is given by

𝐡v()\displaystyle\mathbf{h}_{v}^{(\ell)} =𝖴𝖯𝖣()(𝐡v(1),𝖠𝖦𝖦()({{𝐦v,u()uN(v)}}))\displaystyle=\mathsf{UPD}^{(\ell)}\Bigl{(}\mathbf{h}_{v}^{(\ell-1)},\mathsf{AGG}^{(\ell)}\bigl{(}\{\!\!\{\mathbf{m}_{v,u}^{(\ell)}\mid u\in N(v)\}\!\!\}\bigr{)}\Bigr{)}
=𝐡v(1)+ϕ3,4()(𝐡¯v()).\displaystyle=\mathbf{h}_{v}^{(\ell-1)}+\phi_{3,4}^{(\ell)}\left(\bar{\mathbf{h}}_{v}^{(\ell)}\right). (43)

We denote the matrix whose columns are the atom-wise features from the last message-passing layer LL with 𝐇\mathbf{H}, that is, 𝐇[v]=𝐡v(L)\mathbf{H}[v]=\mathbf{h}^{(L)}_{v}.
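For reference, a condensed PyTorch sketch of one such interaction layer (radial-basis expansion, invariant filter network, message, sum aggregation, and residual update, cf. Equations 42 and 43) is given below; layer sizes and helper names are illustrative and do not reproduce the exact SchNet hyperparameters:

import torch
import torch.nn as nn

class ShiftedSoftplus(nn.Module):
    # ssp(x) = softplus(x) - ln 2
    def forward(self, x):
        return nn.functional.softplus(x) - 0.6931471805599453

class InteractionLayer(nn.Module):
    def __init__(self, d=128, n_rbf=50, cutoff=10.0, gamma=10.0):
        super().__init__()
        self.register_buffer("centers", torch.linspace(0.0, cutoff, n_rbf))
        self.gamma = gamma
        # Filter-generating network Inv: two dense layers with shifted softplus.
        self.filter_net = nn.Sequential(nn.Linear(n_rbf, d), ShiftedSoftplus(), nn.Linear(d, d))
        self.phi_1 = nn.Linear(d, d)                                   # atom-wise layer in the message
        self.update_net = nn.Sequential(nn.Linear(d, d), ShiftedSoftplus(), nn.Linear(d, d))

    def forward(self, h, pos, edge_index):
        # h: (n, d) atom embeddings; pos: (n, 3) positions; edge_index: (2, num_edges) pairs (v, u).
        v, u = edge_index
        dist = (pos[v] - pos[u]).norm(dim=-1, keepdim=True)            # E(3)-invariant distances
        rbf = torch.exp(-self.gamma * (dist - self.centers) ** 2)      # radial basis expansion
        e = self.filter_net(rbf)                                       # rotationally invariant filter e_{v,u}
        msg = self.phi_1(h[u]) * e                                     # message m_{v,u}
        agg = torch.zeros_like(h).index_add_(0, v, msg)                # sum over neighbors N(v)
        return h + self.update_net(agg)                                # residual update of h_v

Stacking several such layers produces the atom-wise feature matrix H referred to above.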

D.2 Dataset Overview

Molecular Property Prediction Tasks  We conduct our experiments on MoleculeNet (Wu et al., 2018), a comprehensive benchmark for computational chemistry. It spans a wide array of tasks, ranging from predicting quantum mechanical properties to determining biological activities and solubilities of compounds. In our study, we focus on the regression tasks of four datasets from the MoleculeNet benchmark: Lipo, ESOL, FreeSolv, and BACE.

  • The Lipo dataset is a collection of 4200 lipophilicity values for various chemical compounds. Lipophilicity is a key property that impacts a molecule’s pharmacokinetic behavior, making it crucial for drug development.

  • ESOL contains 1128 experimental solubility values for a range of small, drug-like molecules. Understanding solubility is vital in drug discovery, as poor solubility can lead to issues with bioavailability.

  • FreeSolv offers both calculated and experimentally determined hydration-free energies for a collection of 642 small molecules. These hydration-free energies are critical for assessing a molecule’s stability and solubility in water.

  • The BACE dataset focuses on biochemical assays related to Alzheimer’s Disease. It contains 1513 pIC50 values, indicating the efficiency of various molecules in inhibiting the β\beta-site amyloid precursor protein cleaving enzyme 1 (BACE-1).

3D Molecular Classification Tasks  In addition, we evaluate the classification performance using two closely related datasets associated with SARS-CoV: SARS-CoV-2 3CL (CoV-2 3CL), and SARS-CoV-2 (CoV-2).

  • The CoV-2 3CL protease dataset comprises 76 instances of inhibitory interactions, covering a total of 804 unique species. This dataset specifically addresses the inhibition of the SARS-CoV-2 3CL protease (denoted ‘CoV-2 3CL’) (Source, 2020).

  • The CoV-2 dataset encompasses 92 instances across a spectrum of 5,476 unique species. It focuses on the broader context of inhibitory interactions against SARS-CoV-2, measured in vitro in human cells (Ellinger et al., 2020; Touret et al., 2020).

Reaction-Level Molecular Property Prediction  The BDE dataset (Meyer et al., 2018) contains 5915 organometallic catalysts ($ML_{1}L_{2}$) with metal centers (Pd, Pt, Au, Ag, Cu, Ni) and two flexible organic ligands ($L_{1}$ and $L_{2}$) chosen from a 91-ligand library. It includes conformations of each unbound catalyst and of the catalysts bound to ethylene and bromide after reacting with vinyl bromide, resulting in 11830 individual molecules. The dataset provides electronic binding energies, calculated as the energy difference between the bound catalyst complex and the unbound catalyst, both optimized using DFT. Conformers are initially generated with Open Babel and then geometry-optimized so that they likely represent the global minimum-energy structures at the force-field level.

D.3 3D Conformers Generation

RDKit offers two methodologies to generate conformers for molecules:

  • The distance geometry approach employs distance geometry principles for conformer generation, starting with the determination of a molecule’s distance bounds matrix based on connectivity and predefined rules. This matrix is then refined and used to formulate a random distance matrix, which subsequently guides the molecule’s embedding into 3D space. The resulting atomic coordinates undergo further refinement through a specialized “distance geometry force field.”

  • The ETKDG method refines generated conformers by integrating torsion-angle preferences derived from the Cambridge Structural Database (CSD). This technique can be further enhanced with additional torsion terms, catering especially to small rings and macrocycles, and yields high-quality conformers suitable for direct use in many scenarios.

In our experiments, we applied a standardized approach to configuring all benchmark datasets, encompassing the following steps:

  • Conformer Generation: During training, we use RDKit to generate a fixed pool of 200 conformers for every molecular structure specified by its SMILES string. In each epoch, however, each molecule is assigned a random subset of $K$ conformers ($K\ll 200$). For validation and testing, we use a fixed seed and randomly generate $K$ conformers for each sample in the dataset.

  • Parallel Processing: We use a process pool to parallelize conformer generation, improving overall efficiency; a minimal sketch of this pipeline is given after this list. Table 6 reports the average execution time for generating a single conformer from its SMILES string across the datasets.
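The snippet below is a minimal sketch of this setup; the embedding parameters, pool size, and SMILES strings are illustrative assumptions rather than the exact configuration used in our experiments.

```python
import random
from multiprocessing import Pool

from rdkit import Chem
from rdkit.Chem import AllChem


def generate_conformers(smiles, num_confs=200, seed=42):
    """Embed a fixed pool of conformers for one molecule with RDKit distance geometry."""
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    params = AllChem.ETKDGv3()                  # distance geometry with CSD torsion preferences
    params.randomSeed = seed
    AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, params=params)
    return mol


def sample_conformer_ids(mol, k):
    """Per-epoch sampling of K conformers out of the pre-generated pool (K << 200)."""
    ids = [conf.GetId() for conf in mol.GetConformers()]
    return random.sample(ids, min(k, len(ids)))


if __name__ == "__main__":
    smiles_list = ["CCO", "c1ccccc1", "CC(=O)Oc1ccccc1C(=O)O"]   # toy examples
    with Pool(processes=4) as pool:             # parallel conformer generation
        mols = pool.map(generate_conformers, smiles_list)
    print([len(sample_conformer_ids(m, 5)) for m in mols])
```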

For a comprehensive 3D structural analysis, we present summary statistics detailing the number of edges and nodes (Table 6). These statistics provide insight into the structural characteristics of the molecules in each dataset. Average values indicate the typical molecule size in terms of edges and nodes, while minimum and maximum values reflect the varying complexity of molecular structures across datasets. Notably, the Lipo and BACE datasets contain the most intricate graphs, in contrast to ESOL and FreeSolv, which exhibit sparser structures. We illustrate in Figure 7 some typical generated conformers for each dataset.
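Statistics of this kind can be reproduced with a short RDKit script along the following lines; the counting conventions (explicit hydrogens, bonds counted as directed edges) are our assumptions and may differ from those used for Table 6.

```python
from statistics import mean

from rdkit import Chem


def graph_stats(smiles_list, explicit_h=True):
    """Per-dataset node/edge statistics (avg / min / max); conventions are assumptions."""
    nodes, edges = [], []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue
        if explicit_h:
            mol = Chem.AddHs(mol)
        nodes.append(mol.GetNumAtoms())
        edges.append(2 * mol.GetNumBonds())     # each bond counted as two directed edges
    return {"nodes": (mean(nodes), min(nodes), max(nodes)),
            "edges": (mean(edges), min(edges), max(edges))}


print(graph_stats(["CCO", "c1ccccc1O", "CC(=O)Nc1ccc(O)cc1"]))
```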

Table 6: Summary statistics for edge and node counts in the datasets, together with the average runtime needed to generate a single conformer from a molecular structure.
Dataset | Edges (Avg / Min / Max) | Nodes (Avg / Min / Max) | Execution Time (s)
Lipo | 101.8 / 24 / 412 | 48.4 / 12 / 203 | $4.68\times 10^{-6}$
ESOL | 52 / 6 / 252 | 25.6 / 4 / 119 | $3.58\times 10^{-6}$
FreeSolv | 35.5 / 4 / 92 | 18.1 / 3 / 44 | $3.13\times 10^{-6}$
BACE | 135 / 36 / 376 | 64.7 / 17 / 184 | $4.34\times 10^{-6}$
CoV-2 3CL | 56 / 16 / 96 | 27.4 / 8 / 48 | $3.12\times 10^{-6}$
CoV-2 | 95.2 / 4 / 220 | 45.7 / 3 / 100 | $3.96\times 10^{-6}$

D.4 Ablation Studies on the Number of Conformers

Table 7: The impact of the number of conformers $K$ on the accuracy of the ConAN model (without barycenter). Results are MSE $\downarrow$ computed on the validation set. Bold and underlined values denote the best and second-best results.
$K$ | Lipo | ESOL | FreeSolv | BACE
0 | $1.387\pm 0.206$ | $2.288\pm 0.017$ | $8.564\pm 1.345$ | $1.844\pm 0.33$
1 | $0.619\pm 0.045$ | $0.645\pm 0.054$ | $2.306\pm 0.807$ | $0.705\pm 0.064$
3 | $0.581\pm 0.033$ | $0.592\pm 0.072$ | $2.035\pm 0.256$ | $0.653\pm 0.026$
5 | $0.567\pm 0.019$ | $0.581\pm 0.051$ | $1.799\pm 0.662$ | $0.616\pm 0.051$
10 | $0.564\pm 0.030$ | $0.583\pm 0.027$ | $1.568\pm 0.183$ | $0.832\pm 0.143$
20 | $0.569\pm 0.003$ | $0.589\pm 0.012$ | $1.742\pm 0.143$ | $0.670\pm 0.036$

Table 7 shows that incorporating 3D conformers ($K\geq 1$) significantly improves performance over relying solely on 2D molecular graphs, as in the 2D-GAT model (the $K=0$ row). However, the relationship between the number of conformers and model accuracy is neither linear nor straightforward. For instance, while increasing the number of conformers to $K=10$ yields the best results on Lipo and FreeSolv, the best overall performance is usually achieved with $K=5$. This suggests that there is an optimal number of conformers that maximizes model accuracy, and that this number varies with the dataset.

D.5 Entropic FGW versus FGW-Mixup: Details

We provide more details on the efficiency ablation study in Section 6.7. We adapt the original GitHub repository https://github.com/ArthurLeoM/FGWMixup from Ma et al. (2023) as the baseline. In the context of the FGW barycenter problem over $K$ conformers, the numerical instability of the $\exp$ function forces us to use a small step size $\gamma$ for the Bregman projections (Algorithm 2 in Ma et al. (2023)) to avoid NaN outputs of FGW-Mixup on some datasets, which in turn requires more inner iterations to converge. Indeed, it is particularly difficult to find optimal parameters for FGW-Mixup, balancing the marginal errors that induce FGW subgradient noise in the outer iterations against the empirical convergence rate of the inner iterations.
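The trade-off can be illustrated with a toy multiplicative (mirror-descent) update of a coupling followed by a Bregman projection onto the row marginals: with a large step size the $\exp$ term overflows and the subsequent normalization produces NaN values, whereas a small step size remains stable but progresses slowly. The snippet below is only a schematic illustration under these assumptions, not the FGW-Mixup implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 50
G = rng.uniform(-300.0, 300.0, size=(n, n))     # toy (sub)gradient with large entries
a = np.full(n, 1.0 / n)                         # row marginal of the coupling
pi = np.full((n, n), 1.0 / n ** 2)              # initial coupling

for gamma in [1e-3, 1e-1, 10.0]:
    with np.errstate(over="ignore", invalid="ignore"):
        upd = pi * np.exp(-gamma * G)                        # multiplicative mirror-descent step
        upd = upd * (a / upd.sum(axis=1))[:, None]           # Bregman projection onto row marginals
    print(f"gamma={gamma:g}: NaN/Inf produced? {not np.isfinite(upd).all()}")
```

Only the largest step size produces non-finite values here, mirroring why a small $\gamma$ is needed in practice at the cost of more inner iterations.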

Running Time Analysis.   In Figure 5, we compare the running time of our solver with FGW-Mixup on two datasets, FreeSolv and CoV-2 3CL, covering both the forward and backward steps used to update the gradients of the full model. We measure the average time per epoch during training for increasing numbers of conformers $K$. Note that the FGW-Mixup solver does not support execution on GPUs, whereas our algorithm is designed for this purpose and can be scaled to large training sets using distributed data parallelism in PyTorch. In particular, ConAN-FGW Single-GPU and ConAN-FGW Multi-GPU denote the variants trained on one and four Tesla V100-32GB GPUs, respectively.
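The multi-GPU variant follows standard PyTorch distributed data parallelism; a condensed, generic sketch (with a placeholder model and random data, launched via torchrun) is shown below. It is not our training script, only an outline of the mechanism.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Launch with: torchrun --nproc_per_node=4 train_ddp.py
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(16, 1).cuda(local_rank)        # placeholder for the ConAN-FGW model
model = DDP(model, device_ids=[local_rank])

dataset = TensorDataset(torch.randn(1024, 16), torch.randn(1024, 1))
sampler = DistributedSampler(dataset)                  # shards the data across GPUs
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(2):
    sampler.set_epoch(epoch)                           # reshuffle shards each epoch
    for x, y in loader:
        x, y = x.cuda(local_rank), y.cuda(local_rank)
        loss = torch.nn.functional.mse_loss(model(x), y)
        opt.zero_grad()
        loss.backward()                                # gradients are all-reduced across GPUs
        opt.step()

dist.destroy_process_group()
```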

To delve deeper into the computation of the FGW barycenter, we present a runtime analysis in Figure 6. The configuration mirrors that of Figure 5, except that the runtime is measured only for the barycenter computation during the forward step. Notably, the execution times exhibit a pattern consistent with Figure 5, showing that ConAN-FGW outperforms FGW-Mixup in both single-GPU and multi-GPU setups and achieves significantly faster runtimes as the number of conformers is scaled up.

Figure 6: Runtime comparison of FGW-Mixup, ConAN-FGW (single and multi-GPU) in the FGW barycenter computation.

Error Analysis.  In this part, we investigate the solution errors of ConAN-FGW and FGW-Mixup. To this end, we use the solution of the original FGW problem obtained with the Conditional Gradient algorithm (Titouan et al., 2020) as an approximate ground truth for comparing solution errors (Table 8). We fix the same hyperparameters for both solvers as in Figure 5. As expected, the FGW-Mixup solution errors are slightly smaller than those of ConAN-FGW. This is because (i) to prevent numerical instability, we set a small step size for the mirror descent (i.e., the alternating Bregman projections), and (ii) FGW-Mixup asymptotically converges to the original FGW solution up to a bounded gap (Ma et al., 2023). However, this induces more computational time for large FGW problems, as seen in Figures 5 and 6. In contrast, ConAN-FGW maintains solution errors comparable to FGW-Mixup while requiring reasonable computational runtime and being amenable to multi-GPU deployment for large-scale problems.
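The metrics in Table 8 compare an approximate barycenter structure matrix against the conditional-gradient solution. A simple helper of the following form can compute them; this is our own illustrative code, not part of either solver, and the normalization of the Frobenius norm is an assumption.

```python
import numpy as np


def barycenter_errors(C_gt, C_approx, eps=1e-12):
    """Discrepancy metrics between a ground-truth and an approximate barycenter matrix."""
    diff = C_approx - C_gt
    return {
        "n_frobenius": np.linalg.norm(diff) / (np.linalg.norm(C_gt) + eps),  # normalized Frobenius
        "mae": np.mean(np.abs(diff)),                                        # mean absolute error
        "mape": np.mean(np.abs(diff) / (np.abs(C_gt) + eps)),                # mean absolute percentage error
        "mse": np.mean(diff ** 2),                                           # mean squared error
    }


# toy usage with random symmetric "structure" matrices
rng = np.random.default_rng(0)
C_gt = rng.random((20, 20))
C_gt = (C_gt + C_gt.T) / 2
C_approx = C_gt + 0.05 * rng.standard_normal((20, 20))
print(barycenter_errors(C_gt, C_approx))
```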

Figure 7: Visualizing 3D molecular conformers with corresponding SMILES strings across diverse datasets.
Table 8: Solution-error estimation across datasets, showing the influence of the number of conformers when comparing the Ground Truth (GT) barycenter with those of ConAN-FGW and FGW-Mixup. The comparison metrics are the normalized Frobenius norm (N-Frobenius), Mean Absolute Error (MAE), Mean Absolute Percentage Error (MAPE), and Mean Squared Error (MSE).
Dataset | Conformers | GT vs ConAN-FGW (N-Frobenius / MAE / MAPE / MSE) | GT vs FGW-Mixup (N-Frobenius / MAE / MAPE / MSE)
FreeSolv | 3 | 0.1325 / 0.1727 / 0.3523 / 0.0812 | 0.1190 / 0.1590 / 0.3210 / 0.0671
FreeSolv | 5 | 0.1387 / 0.1823 / 0.3753 / 0.0870 | 0.1258 / 0.1695 / 0.3466 / 0.0731
FreeSolv | 10 | 0.1431 / 0.1874 / 0.3876 / 0.0919 | 0.1323 / 0.1776 / 0.3638 / 0.0792
FreeSolv | 15 | 0.1460 / 0.1924 / 0.3980 / 0.0947 | 0.1358 / 0.1819 / 0.3703 / 0.0832
FreeSolv | 20 | 0.1453 / 0.1920 / 0.3954 / 0.0952 | 0.1336 / 0.1805 / 0.3662 / 0.0816
CoV-2 3CL | 3 | 0.0859 / 0.1696 / 0.4207 / 0.0670 | 0.0804 / 0.1626 / 0.4055 / 0.0600
CoV-2 3CL | 5 | 0.0842 / 0.1688 / 0.4114 / 0.0632 | 0.0793 / 0.1616 / 0.3942 / 0.0569
CoV-2 3CL | 10 | 0.0879 / 0.1801 / 0.4452 / 0.0719 | 0.0806 / 0.1697 / 0.4201 / 0.0637
CoV-2 3CL | 15 | 0.0859 / 0.1729 / 0.4251 / 0.0670 | 0.0764 / 0.1571 / 0.3899 / 0.0543
CoV-2 3CL | 20 | 0.0902 / 0.1823 / 0.4558 / 0.0714 | 0.0865 / 0.1779 / 0.4460 / 0.0653

D.6 Visualizing Conformers Generated by RDKit

We present in Figure 7 typical 3D conformers generated by RDKit, with the corresponding SMILES strings shown below each structure.
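For reference, a single RDKit-embedded conformer can be rendered interactively with py3Dmol as sketched below; this is only an illustrative snippet (the molecule and styling are arbitrary) and not necessarily the exact tooling used to produce Figure 7.

```python
import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem

smiles = "CC(=O)Oc1ccccc1C(=O)O"                  # example SMILES (aspirin)
mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
AllChem.EmbedMolecule(mol, AllChem.ETKDGv3())     # embed one 3D conformer

view = py3Dmol.view(width=300, height=300)
view.addModel(Chem.MolToMolBlock(mol), "mol")     # pass coordinates as a MOL block
view.setStyle({"stick": {}})
view.zoomTo()
view.show()                                       # renders inside a Jupyter notebook
```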