
Efficient and effective training of language and graph neural network models

Vassilis N. Ioannidis, Xiang Song, Da Zheng, George Karypis
Amazon Web Services AI, USA
Houyu Zhang, Jun Ma, Yi Xu, Belinda Zeng, Trishul Chilimbi
Amazon Search AI, USA
Abstract

Can we combine heterogeneous graph structure with text to learn high-quality semantic and behavioural representations? Graph neural networks (GNNs) encode numerical node attributes and graph structure to achieve impressive performance in a variety of supervised learning tasks. Current GNN approaches are challenged by textual features, which typically need to be encoded into a numerical vector before being provided to the GNN, a step that may incur some information loss. In this paper, we put forth an efficient and effective framework termed language model GNN (LM-GNN) to jointly train large-scale language models and graph neural networks. The effectiveness of our framework is achieved by stage-wise fine-tuning of the BERT model, first with heterogeneous graph information and then with a GNN model. Several system and design optimizations are proposed to enable scalable and efficient training. LM-GNN accommodates node and edge classification as well as link prediction tasks. We evaluate the LM-GNN framework on different datasets and showcase the effectiveness of the proposed approach. LM-GNN provides competitive results in an Amazon query-purchase-product application.

1 Introduction

GNNs rely on a layered processing architecture comprising trainable graph convolutional operations that linearly combine features per graph neighborhood, followed by pointwise nonlinear functions applied to the linearly transformed features [6]. GNNs have shown remarkable success in a variety of graph machine learning tasks in both supervised and unsupervised learning settings [13]. Typically, the graphs used for profiling GNN models have numerical node features. These numerical attributes may be the output of a network that encodes much richer original information in the form of text or images. One could apply such a general pre-trained network to extract the representations and use them as feature vectors in a GNN. However, as we show in this work, such an approach is suboptimal. This raises the question of how we can train better GNN models with rich text features. This work presents a stage-wise fine-tuning framework termed LM-GNN for encoding text data with transformers and GNN models.

We implement a sequence of fine-tuning steps to gradually infuse the transformer model with graph structure information. Our study reveals the necessity of pre-fine-tuning the transformer on graph-aware tasks. The graph-aware BERT is subsequently fine-tuned together with a GNN model for any downstream task, which allows the model to access multi-hop information. Besides achieving good performance, our stage-wise fine-tuning significantly reduces training time compared to end-to-end training without it, because the model converges faster. Further, LM-GNN is a distributed framework that can scale to hundreds of millions of nodes. Besides improving model performance, we implement a series of system and design optimizations to speed up training. Our LM-GNN framework improves over the baselines on four public datasets and one query-purchase-product application.

2 Related work

Node classification is typically formulated as a semi-supervised learning (SSL) task over graphs, where the labels for a subset of nodes are available for training [4]. GNNs achieve state-of-the-art performance in SSL by utilizing regular graph convolution [17] or graph attention [23], and these models have later been extended to the heterogeneous graph setting [21, 10, 25]. Another prominent graph machine learning task is link prediction, with numerous applications in recommendation systems [24] and drug discovery [31, 15]. Knowledge-graph (KG) embedding models for link prediction map the nodes and edges of the KG to a vector space by maximizing a score function for existing KG edges [24, 28, 29]. Relational graph convolutional network (RGCN) models [21] have been successful in link prediction and, contrary to KG embedding models, can further utilize node features.

The aforementioned graph machine learning tasks may also be addressed using unsupervised learning approaches. Graph representation learning approaches map nodes into an embedding space where the graph topology and structure are preserved [13]. Such unsupervised representations can be consumed by any downstream model for task-specific prediction. Unsupervised methods employ a GNN encoder that generates node embeddings and is supervised by a task-dependent decoder. Typically, decoders perform matrix factorization [22, 3, 5, 7, 19], random walks [11, 20], or a combination of learning tasks [16].

Language models (LMs) are powerful in modeling text data [9], yet harnessing the power of LMs with graph data is under-explored. This work details a framework for training large-scale LMs jointly with GNNs. Recent work [8] also identifies that pre-training BERT models on graph data can be beneficial and exploits a neighborhood prediction objective to enrich the BERT model with graph information. However, [8] did not explore fine-tuning the BERT and GNN models together. Another prominent line of work [18, 32] trains GNN models to improve the search results in sponsored search. That work can be seen as a special case of this framework, although [18, 32] did not explore the stage-wise fine-tuning that we introduce in this work.

3 Definitions and Problem formulation

A heterogeneous graph with $T$ node types and $R$ relation types is defined as $\mathcal{G}:=\{\{\mathcal{V}_{t}\}_{t=1}^{T},\{\mathcal{E}_{r}\}_{r=1}^{R}\}$. The node types represent the different entities, and the relation types represent how these entities are semantically associated with each other. For example, in the query-product network of Fig. 1(a), the node types correspond to queries and products, whereas the relation types may correspond to whether a product was clicked or purchased based on a query. The number of nodes of type $t$ is denoted by $N_{t}$ and its associated node set by $\mathcal{V}_{t}:=\{n_{t}\}_{n=1}^{N_{t}}$. The total number of nodes in $\mathcal{G}$ is $N:=\sum_{t=1}^{T}N_{t}$, and the total number of edges is $E:=\sum_{r=1}^{R}|\mathcal{E}_{r}|$, where $|\cdot|$ denotes the number of elements in a set. The $r$th relation type, $\mathcal{E}_{r}:=\{(n_{t},r,n^{\prime}_{t^{\prime}})\in\mathcal{V}_{t}\times\mathcal{V}_{t^{\prime}}\}$, holds all interactions of a certain type between $\mathcal{V}_{t}$ and $\mathcal{V}_{t^{\prime}}$ and may represent, e.g., that a product was clicked based on a query.
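To make the notation concrete, the following minimal Python sketch stores a toy heterogeneous graph as one node list per type $t$ and one edge list per relation $r$. It then computes $N_t$, $N$, and $E$ as defined above. The node names echo Fig. 1(a), but the edges and the dictionary layout are hypothetical illustrations, not the data structures used by LM-GNN.

```python
# Toy instance of G := {{V_t}_{t=1}^T, {E_r}_{r=1}^R}; all edges are hypothetical.
node_sets = {  # V_t, keyed by node type t
    "query": ["running shoes", "hiking shoes"],
    "product": ["Nike ZoomX Vaporfly", "Addidas swift", "Timberland boots"],
}
edge_sets = {  # E_r, keyed by (head type t, relation r, tail type t')
    ("query", "purchased", "product"): [(0, 0), (0, 1), (1, 2)],
    ("query", "clicked", "product"): [(0, 0), (1, 1), (1, 2)],
}

def nodes_per_type(node_sets):
    """N_t for every node type t."""
    return {t: len(nodes) for t, nodes in node_sets.items()}

def total_nodes(node_sets):
    """N := sum_t N_t."""
    return sum(nodes_per_type(node_sets).values())

def total_edges(edge_sets):
    """E := sum_r |E_r|."""
    return sum(len(edges) for edges in edge_sets.values())

def validate(node_sets, edge_sets):
    """Check that every (n_t, r, n'_t') indexes valid nodes of its endpoint types."""
    for (head_t, _, tail_t), edges in edge_sets.items():
        for head, tail in edges:
            if not (0 <= head < len(node_sets[head_t]) and 0 <= tail < len(node_sets[tail_t])):
                return False
    return True

print(nodes_per_type(node_sets), total_nodes(node_sets), total_edges(edge_sets))
# {'query': 2, 'product': 3} 5 6
```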

Each node $n_{t}$ is also associated with a short text. In the query-product network, for example, each product is accompanied by a title and each query by the query text. Notice that different node types could have texts drawn from different distributions. For example, the text of product titles can be distinctly different from that of queries. Oftentimes, such text features are mapped to an $F\times 1$ embedding vector $\mathbf{x}_{n_{t}}$ by transformers in a task-independent fashion. In this paper, we explore different methods to implement this mapping that directly relate to the downstream application of interest.

4 LM-GNN Models: Adapt and fine-tune

In this paper, our high-level goal is to investigate how to fuse transformer and GNN models to learn informative representations from graph and text data. Our LM-GNN framework achieves this by stage-wise fine-tuning that gradually fuses the transformer with graph information.

4.1 Semantic encoder

We employ the BERT model [9] as the transformer in the LM-GNN framework to encode the nodes' textual semantics. Given a node's text, BERT encodes the textual information into an $F\times 1$ embedding vector $\mathbf{x}_{n_{t}}$. This embedding vector corresponds to the [CLS] token embedding of the BERT model. The mapping from the text to the embedding is defined as $\mathbf{x}_{n_{t}}:=g_{\text{BERT}}(n_{t};\mathbf{W}_{\text{BERT}})$ and is controlled by the learnable parameters $\mathbf{W}_{\text{BERT}}$. These parameters can be instantiated by any language model pre-training approach, e.g., masked language modeling (MLM). Pre-training BERT models on large unlabeled text data has shown significant benefit in different LM applications. However, transferring this benefit to graph ML applications is not straightforward. We employ the techniques in Sec. 4.3 to pre-train BERT models with graph data. For different node types $t\in\{1,\ldots,T\}$ in the graph, we may consider different semantic encoders, e.g., one for queries and one for products.

4.2 Graph Encoder

Although the LM-GNN framework can utilize any GNN model as an encoder [27], in this paper LM-GNN uses a modified RGCN encoder [21]. RGCNs extend the graph convolution operation [17] to heterogeneous graphs. The $l$th self-RGCN layer computes the representation $\mathbf{h}^{(l+1)}_{n}$ of the $n$th node as follows

$$\mathbf{h}^{(l+1)}_{n}:=\sigma\left(\mathbf{W}_{\text{self}}^{(l)}\mathbf{h}_{n}^{(l)}+\sum_{r=1}^{R}\sum_{n^{\prime}\in\mathcal{N}_{n}^{r}}\mathbf{W}^{(l)}_{r}\mathbf{h}_{n^{\prime}}^{(l)}\right), \tag{1}$$

where $\mathcal{N}_{n}^{r}$ is the neighborhood of node $n$ under relation $r$, $\sigma$ is the rectified linear unit (ReLU) nonlinearity, $\mathbf{W}^{(l)}_{r}$ is a learnable matrix associated with the $r$th relation, and $\mathbf{W}_{\text{self}}^{(l)}$ is a projection matrix for the node's own embedding at layer $l$. Our adaptation augments the messages at each layer with a projected transformation of the node's own embedding at that layer. This term acts as a self loop with its own weight matrix. We make this adaptation over the traditional RGCN because the node's own representation, produced by the semantic encoder, is essential for the downstream applications and needs to be treated separately by the model. The matrix $\mathbf{H}$ in this paper collects the embeddings extracted at the final layer. The node features are the input of the first layer of the model, i.e., $\mathbf{h}^{(0)}_{n}=\mathbf{x}_{n}$, where $\mathbf{x}_{n}$ is the feature vector of node $n$.
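A minimal pure-Python sketch of one self-RGCN layer, Eq. (1), is shown below. It assumes dense toy matrices stored as lists of rows. Like Eq. (1), it applies no per-relation normalization constant. The function and argument names are hypothetical. A practical implementation would use a tensor library and sparse message passing.

```python
def matvec(W, h):
    """Multiply matrix W (a list of rows) by vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def relu(v):
    return [max(0.0, x) for x in v]

def self_rgcn_layer(h, neighbors, W_self, W_rel):
    """One self-RGCN layer following Eq. (1).

    h:         dict node -> h_n^(l)
    neighbors: dict relation r -> dict node n -> list N_n^r
    W_self:    matrix W_self^(l) applied to the node's own embedding
    W_rel:     dict relation r -> matrix W_r^(l)
    """
    out = {}
    for n, h_n in h.items():
        acc = matvec(W_self, h_n)                 # self-loop term
        for r, adjacency in neighbors.items():
            for m in adjacency.get(n, []):        # sum over n' in N_n^r
                msg = matvec(W_rel[r], h[m])
                acc = [a + b for a, b in zip(acc, msg)]
        out[n] = relu(acc)                        # sigma = ReLU
    return out

I = [[1.0, 0.0], [0.0, 1.0]]
h0 = {0: [1.0, -2.0], 1: [0.5, 1.0]}              # h^(0)_n = x_n (toy features)
h1 = self_rgcn_layer(h0, {"clicked": {0: [1]}}, I, {"clicked": I})
print(h1)  # {0: [1.5, 0.0], 1: [0.5, 1.0]}
```

Node 0 adds the message from its single "clicked" neighbor to its own projected embedding before the ReLU. Node 1 has no neighbors here, so only its self-loop term survives.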

4.3 Supervision approaches

Structure prediction task. Link prediction is a specific type of structure prediction and is our focus in what follows. Consider the heterogeneous graph $\mathcal{G}$ in Sec. 3. Given the sets of links $\{\mathcal{E}_{r}\}_{r=1}^{R}$ and the node features, the goal of link prediction is to predict whether the node pairs in a new set are linked or not.

Typically, structure prediction models utilize a contrastive loss function that requires the model to distinguish between positive and negative examples [29]. In this context, positive examples are the existing links in the graph. The negative examples are links that the model should classify as nonexistent, and they are typically sampled from the missing links in the graph. For each positive triplet $q=(n_{t},r,n^{\prime}_{t^{\prime}})$, a number of negative links is generated by corrupting the head or the tail entity at random, yielding $(n_{t},r,\nu^{\prime}_{t^{\prime}})$ and $(\nu_{t},r,n^{\prime}_{t^{\prime}})$.
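The corruption step can be sketched as follows. The sketch makes three assumptions: head and tail corruptions alternate, replacements are drawn uniformly from nodes of the matching type, and corrupted triplets that happen to be existing links are not filtered out. The helper name `corrupt_triplet` is hypothetical.

```python
import random

def corrupt_triplet(triplet, head_pool, tail_pool, num_neg, rng):
    """Corrupted negatives for a positive triplet q = (n_t, r, n'_t').

    Even-indexed negatives replace the tail, giving (n_t, r, nu'_t'); odd-indexed
    ones replace the head, giving (nu_t, r, n'_t'). Replacements are drawn
    uniformly from nodes of the matching type; accidental positives are kept.
    """
    head, rel, tail = triplet
    negatives = []
    for i in range(num_neg):
        if i % 2 == 0:
            negatives.append((head, rel, rng.choice(tail_pool)))
        else:
            negatives.append((rng.choice(head_pool), rel, tail))
    return negatives

rng = random.Random(0)
queries, products = list(range(100)), list(range(100, 1100))  # toy node ids per type
negs = corrupt_triplet((3, "purchased", 105), queries, products, num_neg=4, rng=rng)
```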

The loss minimized for link prediction is defined as follows

$$\sum_{(n_{t},r,n^{\prime}_{t^{\prime}})\in\mathbb{D}^{+}\cup\mathbb{D}^{-}}\log\left(1+\exp\left(-y\times c(n_{t},r,n^{\prime}_{t^{\prime}})\right)\right), \tag{2}$$

where:
- $c$ is a scoring function, such as the DistMult model [28], that returns a scalar given the head node, the tail node, and the relation;
- $\mathbb{D}^{+}$ and $\mathbb{D}^{-}$ are the positive and negative sets of triplets;
- $y=1$ if the triplet corresponds to a positive example and $y=-1$ otherwise.
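For concreteness, the sketch below instantiates $c$ with DistMult, $c(n_t,r,n'_{t'})=\sum_i h_{n_t,i}\,w_{r,i}\,h_{n'_{t'},i}$, where $\mathbf{w}_r$ is a per-relation weight vector. It then evaluates loss (2) with a numerically stable softplus. Plain Python lists stand in for embedding tensors, and the function names are illustrative.

```python
import math

def distmult(h_head, w_rel, h_tail):
    """DistMult score c(n_t, r, n'_t') = sum_i h_head[i] * w_rel[i] * h_tail[i]."""
    return sum(a * b * c for a, b, c in zip(h_head, w_rel, h_tail))

def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def link_prediction_loss(scored):
    """Eq. (2): sum over D+ and D- of log(1 + exp(-y * c)).

    scored: list of (c, y) pairs with y = +1 for positives and y = -1 for negatives.
    """
    return sum(softplus(-y * c) for c, y in scored)

pos = distmult([1.0, 2.0], [1.0, 1.0], [3.0, 4.0])   # 1*1*3 + 2*1*4 = 11.0
neg = distmult([1.0, 2.0], [1.0, 1.0], [-1.0, 0.0])  # -1.0
loss = link_prediction_loss([(pos, 1), (neg, -1)])
```

A well-scored positive ($c=11$) and a well-scored negative ($c=-1$) both contribute small terms. The stable softplus avoids overflow in `exp` when $-y\,c$ is large.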

Node prediction task. Each node $n$ has a label $y_{n}\in\{0,\ldots,P-1\}$, which in the query-product network may represent the type of a product. In semi-supervised learning, we know labels only for a subset of nodes $\{y_{n}\}_{n\in\mathcal{M}}$, with $\mathcal{M}\subset\mathcal{V}$. This partial availability may be attributed to privacy concerns (medical data), energy considerations (sensor networks), or unrated items (recommender systems). The $N\times P$ matrix $\mathbf{Y}$ is the one-hot representation of the true node labels. That is, if $y_{n}=p$, then $Y_{np}=1$ and $Y_{np^{\prime}}=0,\ \forall p^{\prime}\neq p$. The minimization objective in this task is the cross-entropy loss over the labeled nodes in $\mathcal{M}$.
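The semi-supervised objective can be sketched as a cross-entropy that only visits the labeled nodes in $\mathcal{M}$. The sketch assumes that the model has already produced $P$ class scores per node. It averages over $\mathcal{M}$; the text does not say whether the loss is summed or averaged, so averaging is an assumption.

```python
import math

def one_hot(label, num_classes):
    """Row of Y: 1 at position y_n, 0 elsewhere."""
    return [1.0 if p == label else 0.0 for p in range(num_classes)]

def masked_cross_entropy(logits, labels):
    """Average cross-entropy over the labeled subset M.

    logits: dict node -> list of P class scores (e.g., computed from h_n)
    labels: dict node -> y_n, defined only for nodes in M
    """
    total = 0.0
    for n, y in labels.items():
        row = logits[n]
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # log-sum-exp
        total -= sum(t * (v - log_z) for t, v in zip(one_hot(y, len(row)), row))
    return total / len(labels)

logits = {0: [0.0, 0.0, 0.0], 1: [2.0, 0.0, 0.0], 2: [0.0, 9.0, 0.0]}
labels = {0: 2, 1: 0}   # node 2 is unlabeled (not in M), so it does not contribute
loss = masked_cross_entropy(logits, labels)
```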

Edge prediction task. Each link $l$ of a certain type $r$ may also be associated with a label of interest $\psi_{l}\in\{0,\ldots,\Pi-1\}$. For example, consider the query-to-product graph in Fig. 1(a), where the task is to predict whether a connected (query, product) pair is an exact search match or not. Hence, given an existing link, we predict its class label. This differs from the structure prediction task described earlier in this section, whose objective is to predict the existence of a link.

The final predicted label for link $l$ is the output of the following edge classification decoder

$$\hat{\psi}_{l}=\mathbf{W}_{\text{ec}}\left(\mathbf{h}_{n_{t}}\,\|\,\mathbf{h}_{n^{\prime}_{t^{\prime}}}\right), \tag{3}$$

where $\mathbf{W}_{\text{ec}}$ is a projection matrix of appropriate dimension and $\|$ denotes concatenation. Hence, the predicted label for $l$ is a function of the entity embeddings of the nodes $n_{t}$ and $n^{\prime}_{t^{\prime}}$ at the endpoints of the link.
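A minimal sketch of decoder (3) follows. It assumes that $\mathbf{W}_{\text{ec}}$ has $\Pi$ rows and $2F$ columns, so that $\hat{\psi}_l$ is a vector of $\Pi$ class scores. It takes the predicted class to be the index of the largest score. Both the sizes in the example and this decision rule are illustrative assumptions.

```python
def edge_logits(W_ec, h_head, h_tail):
    """Eq. (3): apply W_ec to the concatenation h_head || h_tail."""
    z = h_head + h_tail                       # list concatenation plays the role of ||
    return [sum(w * x for w, x in zip(row, z)) for row in W_ec]

def predict_edge_class(W_ec, h_head, h_tail):
    """Index of the largest entry of psi_hat (assumed decision rule)."""
    scores = edge_logits(W_ec, h_head, h_tail)
    return max(range(len(scores)), key=scores.__getitem__)

# Hypothetical sizes: F = 2 and Pi = 4 classes (e.g., exact/substitute/complement/irrelevant).
W_ec = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 2.0],
]
h_query, h_product = [1.0, 0.0], [0.0, 1.0]
```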

4.4 LM-GNN: Training at scale using graph and text

A straightforward approach would directly use the LM encoder as a semantic encoder that feeds representations to the GNN encoder, and then train the resulting architecture end-to-end. However, jointly training large-scale language models and graph neural networks involves challenges related to both efficiency and effectiveness.

Effectiveness challenges stem from the fact that the pre-trained language model is well optimized for language tasks but has not been trained on graph tasks. This surfaces three main issues. (1) Such a pre-trained transformer may not be the most appropriate initialization and may trap the GNN in a sub-optimal local minimum. (2) Further, a transformer that is well optimized for text tasks may be more resistant to parameter updates. (3) Another hurdle stems from the random initialization of the GNN weights relative to the well-attuned transformer model, which may challenge the optimization of such an end-to-end framework.

Efficiency challenges relate to the large number of neighbors required by message passing in GNNs. In mini-batch training of a $k$-layer GNN, the $k$-hop ego-network of every target node is created. The target node embedding is then computed as a function of all the nodes in the expanded ego-network (also known as source nodes). The number of source nodes in an ego-network may be very large even for shallow GNNs. To alleviate this issue, recent GNN approaches apply random sampling [12] to reduce the number of neighbors. However, even with a shallow GNN (2 layers) and modest sampling (20 neighbors per layer), there are up to $20\times 20=400$ source nodes at the second hop alone. This remains a serious challenge in our setting, where the transformer model needs to make on the order of 400 forward passes to calculate the embedding of a single target node. As a consequence, the required GPU memory is quite large even for small mini-batch sizes, which is a challenge unique to our framework.
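The growth of the ego-network can be checked with a few lines of arithmetic. The sketch below computes worst-case per-hop sizes, assuming that no sampled neighbor repeats. With two layers and a fanout of 20, it gives 20 first-hop and 400 second-hop nodes. Counting the target itself, up to 421 texts must be encoded per target node. The function names are illustrative.

```python
def hop_sizes(fanouts):
    """Worst-case number of sampled nodes at each hop, assuming no neighbor repeats."""
    sizes, frontier = [], 1
    for fanout in fanouts:
        frontier *= fanout
        sizes.append(frontier)
    return sizes

def text_encodings_per_batch(batch_size, fanouts):
    """Upper bound on transformer forward passes: targets plus all sampled source nodes."""
    return batch_size * (1 + sum(hop_sizes(fanouts)))

print(hop_sizes([20, 20]))                    # [20, 400]
print(text_encodings_per_batch(1, [20, 20]))  # 421
```

Even a modest mini-batch of 32 target nodes can therefore require up to 13,472 transformer forward passes.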

[Figure 1 panels: (a) Query to products graph, with "Purchased" edges between query and product nodes; (b) Transformer embeddings; (c) Graph-aware transformer embeddings; (d) LM-GNN embedding. Entities shown: Running shoes, Hiking shoes, Nike ZoomX Vaporfly, Addidas swift, Timberland boots.]
Figure 1: (a) The underlying graph among products and queries, where an edge signifies that a query leads to the purchase of a product. (b-d) The 2-D projected embeddings generated by different pipelines. (b) The transformer maps entities based solely on their text and fails to capture semantic similarity beyond language cues; e.g., shoes are placed close to boots. (c) The graph-aware transformer maps connected entities close to each other but fails to capture higher-order relations, embedding the two running shoes in different regions. (d) The proposed LM-GNN captures the connectivity, the higher-order structure, and the text semantics, and provides refined representations useful for retrieval tasks.

4.4.1 Addressing effectiveness

Consider the search graph among products and queries depicted in Fig. 1(a). Such a graph is typically encountered in catalog systems for query-product datasets. One could attempt to directly use the embeddings generated by a transformer as input to a GNN model for further fine-tuning. However, such transformer embeddings only take the text into account and may introduce noise during message passing. Indeed, Fig. 1(b) shows that the embeddings of nodes that are connected in the graph may be located in different regions of the embedding space. The poor performance of such a scheme is also documented in the experiments; see Section 6. Our contribution in this context is to pre-fine-tune the transformer with graph information. This endows the text embeddings with relational semantics and boosts performance when the transformer is used as a semantic encoder.

Graph-aware pre-fine-tuning. We consider the structure prediction decoder that directly uses the scoring function $c$ introduced in Section 4.3. The graph-aware transformer model uses the structure prediction decoder as supervision to predict whether an edge exists between two nodes. Specifically, the transformer generates the [CLS] token embeddings for the text associated with the nodes, and these vectors enter the loss (2). The resulting graph-aware transformer embeddings respect both the semantics introduced by the language and the relations imposed by the graph; see also Fig. 1(c). The graph-aware pre-fine-tuning also results in an LM that is more suitable for end-to-end training with the GNN, as supported by the experiments in Section 6. The top part of Fig. 2 showcases the graph-aware pre-fine-tuning pipeline.

Our proposed framework LM-GNN employs the graph-aware transformer as a semantic encoder whose text embeddings are fed to the GNN encoder. However, the GNN model is typically initialized at random. As a result, end-to-end fine-tuning may be challenged and get trapped in undesirable local minima. Hence, we warm-start the GNN weights by keeping the transformer weights fixed for a few iterations and optimizing only the GNN encoder. This way, the GNN model weights have a good initialization before we attempt joint training. Finally, we fine-tune the semantic encoders and GNN models end-to-end for the downstream task. The resulting embeddings abide by the text semantics, the graph relations, and the multi-hop graph structure; see Fig. 1(d). Our overall stage-wise fine-tuning pipeline is depicted in the bottom part of Fig. 2.
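The three stages can be summarized as a schedule that states which parameter groups are updated when. The names in the sketch below (`bert`, `gnn`, and the decoders) are hypothetical, as are the per-stage step budgets. The text above does not specify how many warm-start iterations are used.

```python
STAGES = ("graph_aware_pre_fine_tune", "gnn_warm_start", "end_to_end")

# Parameter groups updated in each stage (group names are hypothetical).
TRAINABLE = {
    "graph_aware_pre_fine_tune": {"bert", "link_decoder"},  # Fig. 2 (top), loss (2)
    "gnn_warm_start": {"gnn", "task_decoder"},              # transformer weights frozen
    "end_to_end": {"bert", "gnn", "task_decoder"},          # joint fine-tuning
}

def stage_at(step, pre_steps, warm_steps):
    """Map a global step to its stage, given per-stage step budgets."""
    if step < pre_steps:
        return STAGES[0]
    if step < pre_steps + warm_steps:
        return STAGES[1]
    return STAGES[2]

def is_trainable(group, step, pre_steps, warm_steps):
    """Whether a parameter group receives gradient updates at this step."""
    return group in TRAINABLE[stage_at(step, pre_steps, warm_steps)]
```

Keeping the schedule as data makes the warm-start stage explicit: during that stage, the transformer is excluded from the trainable set while the randomly initialized GNN catches up.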

[Figure 2 panels: (top) the texts "Running shoes" and "Nike ZoomX Vaporfly" are encoded by the graph-aware transformer and passed to a structure prediction decoder that predicts whether they are linked; (bottom) the same pipeline with a GNN encoder between the graph-aware transformer and the decoder, labeled LM-GNN.]
Figure 2: (Top) The graph-aware transformer framework relies on the input text to predict whether two entities are connected in the heterogeneous graph. (Bottom) The LM-GNN framework employs the graph-aware transformer as a semantic encoder that is further fine-tuned using the GNN encoder for predicting links in the heterogeneous graph. Unlike the graph-aware transformer framework, LM-GNN can access nodes in the multi-hop neighborhood.

4.4.2 Addressing efficiency

The high computation overhead and memory consumption of the LM-GNN framework limit its wide applicability. To address these issues, we adopt several optimizations to train LM-GNN efficiently.

Back-propagate on samples. Instead of back-propagating gradients to the transformer models for all nodes, we sub-sample a fixed number of nodes (train nodes) in a mini-batch and back-propagate gradients to the BERT model only for them. For the rest of the nodes (inference nodes), we only run the BERT forward computation to generate BERT embeddings. To further reduce memory consumption and allow training on machines with limited GPU memory, we split the inference nodes into multiple sub-batches of fixed size.
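The node split can be sketched as follows. The sketch only plans the computation and assumes that train nodes are drawn uniformly at random. In a PyTorch implementation, the inference sub-batches would typically be encoded under `torch.no_grad()`. The function and parameter names are hypothetical.

```python
import random

def plan_bert_computation(batch_nodes, num_train, sub_batch_size, rng):
    """Split the nodes of a mini-batch for BERT computation.

    Returns (train_nodes, inference_sub_batches):
      train_nodes:           fixed-size random sample whose BERT pass keeps gradients
      inference_sub_batches: remaining nodes, encoded forward-only in fixed-size chunks
    """
    nodes = list(batch_nodes)
    train_nodes = rng.sample(nodes, min(num_train, len(nodes)))
    chosen = set(train_nodes)
    inference = [n for n in nodes if n not in chosen]
    sub_batches = [inference[i:i + sub_batch_size]
                   for i in range(0, len(inference), sub_batch_size)]
    return train_nodes, sub_batches

train_nodes, sub_batches = plan_bert_computation(
    range(100), num_train=16, sub_batch_size=32, rng=random.Random(0))
print(len(train_nodes), [len(b) for b in sub_batches])  # 16 [32, 32, 20]
```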

Cache BERT embeddings. To further reduce transformer computations, we cache the text embeddings of some nodes in a mini-batch. During training, whenever we compute new text embeddings, we save them in the cache. Whenever we need a node's BERT embedding, we fetch it from the cache. Because the BERT weights keep changing during fine-tuning, some cached text embeddings may be out of date in a large graph, which may lower the overall model accuracy.
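A minimal cache along these lines is sketched below. The staleness bound `max_staleness` forces re-encoding after a given number of steps. It is a hypothetical addition that illustrates the accuracy trade-off mentioned above and is not a parameter described in this paper. The toy encoder stands in for a BERT forward pass.

```python
class EmbeddingCache:
    """Text embeddings keyed by node id, each tagged with the step it was computed at."""

    def __init__(self, max_staleness):
        self.max_staleness = max_staleness  # hypothetical knob, not from the paper
        self._store = {}                    # node -> (step, embedding)

    def put(self, node, embedding, step):
        self._store[node] = (step, embedding)

    def get(self, node, step):
        """Cached embedding, or None if it is missing or older than max_staleness."""
        entry = self._store.get(node)
        if entry is None or step - entry[0] > self.max_staleness:
            return None
        return entry[1]

def embed_nodes(nodes, cache, encode, step):
    """Serve embeddings from the cache and run the encoder only on misses."""
    out = {}
    for n in nodes:
        emb = cache.get(n, step)
        if emb is None:
            emb = encode(n)           # a BERT forward pass in the real system
            cache.put(n, emb, step)   # save every newly computed embedding
        out[n] = emb
    return out

encoded = []                          # records which nodes hit the (toy) encoder

def toy_encode(n):
    encoded.append(n)
    return [float(n)]

cache = EmbeddingCache(max_staleness=2)
embed_nodes([1, 2], cache, toy_encode, step=0)     # two misses
embed_nodes([1, 2, 3], cache, toy_encode, step=1)  # only node 3 is new
embed_nodes([1], cache, toy_encode, step=10)       # node 1 is stale, re-encoded
print(encoded)  # [1, 2, 3, 1]
```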

Joint negative sampling. Training for link prediction requires positive and negative samples, as detailed in Sec. 4.3. Hence, we construct $k$ negative edges for each positive edge in a mini-batch. By default, we sample the $k$ negative edges for each positive edge independently, which requires sampling $k\times n$ new nodes, where $n$ is the number of positive links. Joint negative sampling instead samples $n$ nodes and uses them to construct the $k$ negative edges jointly. Specifically, we reuse these nodes and randomly pair them with nodes in our positive set to generate negative pairs. As a result, we can significantly reduce the number of nodes in a mini-batch and accelerate training. The default method generates $2\times n+k\times n$ end-point nodes and their neighbor nodes, while joint negative sampling generates only $3\times n$ end-point nodes.
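The node-count saving can be checked with the sketch below. For simplicity, it corrupts only the tail of each positive edge. It pairs each positive head with $k$ of the shared nodes, drawn with replacement. These choices, and the function names, are assumptions for illustration.

```python
import random

def independent_negatives(pos_edges, k, node_pool, rng):
    """Default: draw k fresh tail nodes per positive edge (k * n new nodes)."""
    new_nodes, negatives = [], []
    for head, _ in pos_edges:
        for _ in range(k):
            tail = rng.choice(node_pool)
            new_nodes.append(tail)
            negatives.append((head, tail))
    return negatives, new_nodes

def joint_negatives(pos_edges, k, node_pool, rng):
    """Joint: draw n nodes once and reuse them for all k * n negative edges."""
    new_nodes = [rng.choice(node_pool) for _ in pos_edges]
    negatives = [(head, tail)
                 for head, _ in pos_edges
                 for tail in rng.choices(new_nodes, k=k)]
    return negatives, new_nodes

def endpoint_slots(pos_edges, new_nodes):
    """End-point nodes to encode: 2n positive endpoints plus the newly drawn nodes."""
    return 2 * len(pos_edges) + len(new_nodes)

rng = random.Random(0)
pos_edges = [(i, 1000 + i) for i in range(10)]   # n = 10 toy positive links
pool = list(range(5000))
neg_d, new_d = independent_negatives(pos_edges, 5, pool, rng)
neg_j, new_j = joint_negatives(pos_edges, 5, pool, rng)
print(endpoint_slots(pos_edges, new_d), endpoint_slots(pos_edges, new_j))  # 70 30
```

With $n=10$ and $k=5$, both schemes produce 50 negative edges. The default needs $2n+kn=70$ end-point nodes, whereas the joint scheme needs $3n=30$.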

Distributed GNN training. Finally, to allow scalability to billion-node graphs, we exploit and extend the distributed GNN training framework [30] to accommodate our end-to-end fine-tuning setting. To increase training efficiency, we apply hierarchical graph partitioning in DGL's distributed training. With this method, the target nodes/edges of a mini-batch are sampled from the same partition. Therefore, when we sample their neighbors, different target nodes are more likely to sample the same neighbor nodes, which reduces the number of nodes in a mini-batch.
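The idea of partition-local mini-batches can be sketched independently of DGL. The sketch assumes a precomputed node-to-partition map and cuts each mini-batch from the target nodes of a single partition. It does not reproduce DGL's actual hierarchical partitioning API.

```python
import random
from collections import defaultdict

def partition_local_batches(targets, partition_of, batch_size, rng):
    """Cut mini-batches so that all target nodes in a batch share one partition."""
    by_partition = defaultdict(list)
    for t in targets:
        by_partition[partition_of[t]].append(t)
    batches = []
    for part in sorted(by_partition):
        nodes = by_partition[part]
        rng.shuffle(nodes)
        batches.extend(nodes[i:i + batch_size] for i in range(0, len(nodes), batch_size))
    rng.shuffle(batches)  # randomize batch order while keeping each batch partition-local
    return batches

partition_of = {t: t % 3 for t in range(30)}   # toy assignment into 3 partitions
batches = partition_local_batches(range(30), partition_of, batch_size=4, rng=random.Random(0))
```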

5 Experimental setting

In the experiments, we evaluate our techniques for improving the transformer representation with graph information to allow better joint fine-tuning with any GNN model. Hence, although LM-GNN can include any GNN model as an encoder, here we only evaluate GraphSAGE [13] for homogeneous graphs and RGCN [21] for heterogeneous graphs.

5.1 Datasets.

Public. Our setting requires graph datasets where the nodes are associated with text. We employ the arxiv and products datasets from the OGB benchmark [14], with $N$=169,343 and $E$=1,166,243, and $N$=2,449,029 and $E$=61,859,140, respectively, using the standard split ratios from [14]. We further augment the data with the original text of each node, as collected in [8]. In the arxiv dataset, the original title and abstract are used as the text feature of a node. The products dataset represents Amazon products, and the product title, crawled from the web, is used as the text feature. The benchmark in [14] provides embeddings of the original text as node features. The task here is to predict node labels in a standard semi-supervised setting. The labels are the category of the paper for arxiv and the category of the product for products. For these datasets, we also formulate a link prediction problem with an 80% training, 10% validation, and 10% testing split, where the tasks are predicting paper citation and product co-purchase links.

We further construct the Yelp dataset augmented with text, using sources provided by [2]. The following node types are included, with the corresponding number of nodes: business $N_{1}$=160,585, category $N_{2}$=1,330, city $N_{3}$=836, review $N_{4}$=8,635,403, user $N_{5}$=2,189,457. The following edges are considered: (user, friendship, user) $E_{1}$=17,971,548, (business, in, city) $E_{2}$=160,585, (business, in category, category) $E_{3}$=708,968, (review, on, business) $E_{4}$=8,635,403, (user, write, review) $E_{5}$=8,635,403. In this dataset, only the review nodes are associated with text. The task here is to predict the stars of a business and is formulated as a node prediction task.

Private. Additionally, we consider the dataset provided by the recent Amazon KDD22 challenge [1]. Its graph structure is similar to the one depicted in Fig. 1(a). There are $N_{1}$=646,640 product nodes, $N_{2}$=33,804 query nodes, and $E$=781,744 edges that represent a match between a query and a product. Each edge in this dataset is associated with a label indicating whether the match between the query and the product is exact, substitute, complement, or irrelevant. This problem is known as ESCI and is solved as an edge classification task. We create a custom split for this task by splitting the set of edges into 60% for training, 10% for validation, and 30% for testing.

Finally, we also consider the private query-purchase-product dataset, which is used to predict which product will be bought based on a query. Specifically, there are $N_{1}$=130,191,253 products and $N_{2}$=5,878,377 queries. The following behavioral data are represented as edges between queries and products: based on a query, a product was added to the basket ($E_{1}$=15,163,234), clicked ($E_{2}$=24,896,086), returned as a candidate ($E_{3}$=140,008,811), or purchased ($E_{4}$=13,138,944). Given a query, the task is to return the product most likely to be purchased, which is naturally formulated as a link prediction task.

5.2 Baseline setting

Next we explain the different parameters that define the various approaches considered in this work.

Encoders. We consider the following semantic encoders in this work. BERT is the pre-trained BERT model from [26]. graph-aware BERT is the pre-trained BERT model [26] that we further fine-tune for graph structure prediction with the loss (2). BERT-PR is a BERT model pre-trained with the MLM objective on the company's proprietary data.
Task encoders. We consider two candidate encoders for the experiments presented here. MLP is a single-layer MLP that projects the text embedding to an appropriate dimension for node classification or link prediction. It bypasses the graph and allows a direct comparison with the language model alone, so it serves as a baseline that uses the semantic encoder directly in the downstream tasks. GraphSAGE is the model presented in [12] and is used as our baseline GNN model.
Fine-tune. This parameter defines whether we back-propagate the loss to the semantic encoder during learning. Training is orders of magnitude faster when the loss is not back-propagated.
Warm-start. This parameter defines whether we warm-start the GNN model by keeping the semantic encoder parameters fixed for some iterations before end-to-end fine-tuning.
Model configuration. We select the hyperparameters that maximize validation-set performance. The number of GNN layers is chosen from $\{1,2,3\}$, the GNN hidden dimension from $\{128,256,512\}$, and the learning rate from $\{10^{-3},10^{-4},10^{-5}\}$.
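The hyperparameter search above amounts to a $3\times 3\times 3$ grid of 27 configurations. A minimal sketch follows. It assumes a user-supplied `validate` function, which is hypothetical here, that trains a model with a given configuration and returns its validation score.

```python
from itertools import product

SEARCH_SPACE = {
    "num_layers": [1, 2, 3],
    "hidden_dim": [128, 256, 512],
    "lr": [1e-3, 1e-4, 1e-5],
}

def grid(space):
    """Yield every configuration in the Cartesian product of the search space."""
    keys = list(space)
    for values in product(*(space[k] for k in keys)):
        yield dict(zip(keys, values))

def select_best(space, validate):
    """Return the configuration with the highest validation score."""
    return max(grid(space), key=validate)

print(sum(1 for _ in grid(SEARCH_SPACE)))  # 27
```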

6 Experiments

6.1 Node classification

Table 1 collects the node classification results on the public datasets for different training configurations and encoder models. Notice that the first two rows apply the node prediction loss directly to the text embeddings of the nodes after they are mapped by a single-layer MLP. Fine-tuning in this context means that the gradient also updates the parameters of the semantic encoder; otherwise, only the MLP parameters are updated. The target nodes in the Yelp dataset do not have any text, so we cannot evaluate the first two settings there. For this experiment, warm-start did not give a significant improvement and hence is not included.

By comparing lines 2 and 3, we observe that fine-tuning BERT directly for the downstream task while disregarding the graph structure performs on par with keeping the BERT model fixed and feeding its representations to the GNN model. This suggests that the initial BERT embeddings are indeed not the most appropriate semantic embeddings. By comparing lines 3 and 4, we see a performance benefit of fine-tuning the BERT model through the GNN, since the multi-hop information is captured by the GNN. By comparing lines 3 and 5, we observe the clear advantage of the graph-aware BERT. Graph-aware pre-fine-tuning fuses the transformer with graph information, making it the most suitable semantic encoder. Finally, line 6 corresponds to the proposed LM-GNN framework, which includes the proposed stage-wise fine-tuning approach, and this configuration achieves the best overall performance. Hence, pre-fine-tuning the BERT model for link prediction also provides good performance in node classification tasks. This result is important because it allows training a BERT model on link prediction and transferring the knowledge to different downstream tasks, which may speed up overall training.

Table 1: Node classification results for the public datasets. Results are measured in accuracy (%).
# | Semantic encoder | Graph encoder | Fine-tune | arxiv | products | Yelp
1 | BERT | MLP | No | 62.91 | 61.83 | -
2 | BERT | MLP | Yes | 72.98 | 77.64 | -
3 | BERT | GNN | No | 71.39 | 79.10 | 65.81
4 | BERT | GNN | Yes | 73.42 | 81.24 | 73.06
5 | graph-aware BERT | GNN | No | 73.79 | 80.53 | 66.88
6 | graph-aware BERT | GNN | Yes | 74.97 | 82.35 | 76.47

6.2 Link prediction

Table 2 collects the link prediction performance of the various baselines, measured by the MRR score. The first row applies the link prediction supervision in (2) directly to the text encodings of the nodes after they are mapped to an embedding by a single-layer MLP, and the whole architecture is fine-tuned end-to-end. This model corresponds to the graph-aware pre-fine-tuned model for link prediction. It is the same model used as the semantic encoder in lines 5, 6, and 7.
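Since Table 2 reports MRR, a minimal sketch of the metric may help. It assumes that each positive test edge is ranked against its own set of negative scores and counts ties pessimistically. This section does not specify the exact evaluation protocol (number of negatives, tie handling, or scaling of the reported values).

```python
def reciprocal_rank(pos_score, neg_scores):
    """1 / rank of the positive among its negatives; ties count against the positive."""
    rank = 1 + sum(1 for s in neg_scores if s >= pos_score)
    return 1.0 / rank

def mean_reciprocal_rank(queries):
    """queries: list of (positive score, list of negative scores) pairs."""
    return sum(reciprocal_rank(p, negs) for p, negs in queries) / len(queries)

queries = [
    (0.9, [0.1, 0.2]),        # positive ranked 1st -> 1
    (0.5, [0.7, 0.1]),        # ranked 2nd -> 1/2
    (0.3, [0.4, 0.6, 0.8]),   # ranked 4th -> 1/4
]
print(mean_reciprocal_rank(queries))  # (1 + 1/2 + 1/4) / 3
```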

By comparing line 1 with lines 2, 3, and 4, we observe that the original BERT model is indeed not appropriate as a semantic encoder for the GNN. On the other hand, fine-tuning the BERT model for link prediction in line 1 achieves a very good MRR. Lines 5, 6, and 7, relative to lines 2, 3, and 4, showcase the advantage of graph-aware pre-fine-tuning as an essential step in our LM-GNN framework, since it leads to a large performance boost. Furthermore, by comparing line 5 with lines 6 and 7, we observe that warm-starting the GNN encoder is necessary in certain cases to avoid undesirable local minima. Because the GNN model is initialized at random while the graph-aware BERT is already well trained, optimizing this model without warm-start is challenging. By comparing lines 3 and 4, we see that warm-starting also benefits the original BERT model when it is fine-tuned through the GNN.

Convergence improvement. Besides the performance gains in Table 2, the warm-starting option provides a significant training speed-up. Specifically, on the ogbn-products dataset the option without warm-start (row 6) takes 168 hours to reach its maximum reported performance, whereas the option with warm-start (row 7) takes only 13 hours to reach the same MRR. Thus, warm-start provides a 13x training speed-up.
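The speed-up factor follows from the two reported wall-clock times, where T denotes the time needed to reach the same MRR on ogbn-products:

```latex
\text{speed-up}
  = \frac{T_{\text{without warm-start}}}{T_{\text{with warm-start}}}
  = \frac{168\,\text{h}}{13\,\text{h}}
  \approx 12.9
  \approx 13\times
```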

Table 2: Link prediction results for the public datasets, measured in MRR. Note that the converged model for row 1 is the graph-aware BERT used in rows 5, 6, and 7.
# | Semantic encoder | Graph encoder | Warm-start | Fine-tune | arxiv | products
1 | graph-aware BERT | MLP | No | Yes | 59.32 | 82.29
2 | BERT | GNN | No | No | 12.43 | 74.50
3 | BERT | GNN | No | Yes | 10.11 | 72.13
4 | BERT | GNN | Yes | Yes | 15.23 | 77.42
5 | graph-aware BERT | GNN | No | No | 58.13 | 84.34
6 | graph-aware BERT | GNN | No | Yes | 55.32 | 78.31
7 | graph-aware BERT | GNN | Yes | Yes | 63.21 | 87.23

6.3 Public ESCI: edge classification

Table 3 contains the edge classification results for the various baselines on ESCI. Here we also report the performance for each individual class, since in our application we also want to predict the rare classes well.

Comparing rows 2 and 4, which both fine-tune the BERT-PR model, we observe a strong boost of 320 bps in performance when the GNN is used. This indicates that the GNN can substantially improve performance, possibly on the rare classes (S, C, I). Comparing rows 3 and 4, we observe that it is very important to fine-tune the BERT-PR model during GNN training. Comparing rows 3 and 6, we observe that graph-aware pre-fine-tuning gives a significant boost when the BERT embeddings are fixed; this benefit diminishes when the BERT model is fine-tuned. Finally, the overall performance in rows 5-8 is quite similar, but interestingly the performance on the rare classes is highest in rows 5 and 6. We plan to dive deeper into these results and analyze the performance for different sample sizes beyond the current split.
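The 320 bps figure is the difference in the overall (E/S/C/I) F1 column of Table 3 between row 4 (BERT-PR with GNN, fine-tuned) and row 2 (fine-tuned graph-aware BERT with MLP), using 1 basis point = 0.01 percentage points:

```latex
\Delta F_1 = F_1^{(\text{row }4)} - F_1^{(\text{row }2)}
           = 40.56 - 37.36
           = 3.20\ \text{points}
           = 320\ \text{bps}
```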

Table 3: Edge classification results for the public ESCI dataset (E: exact, S: substitute, C: complement, I: irrelevant). The performance is measured in F1-score; the E/S/C/I column reports the score over all four classes, followed by the score for each individual class. The converged model in row 2 is the graph-aware BERT used in rows 6, 7, and 8.
# | Semantic encoder | Graph encoder | Warm-start | Fine-tune | E/S/C/I | E | S | C | I
1 | BERT-PR | MLP | No | No | 32.62 | 57.42 | 38.12 | 14.21 | 20.56
2 | graph-aware BERT | MLP | No | Yes | 37.36 | 59.25 | 36.13 | 21.23 | 32.34
3 | BERT-PR | GNN | No | No | 35.12 | 53.34 | 38.21 | 25.23 | 25.42
4 | BERT-PR | GNN | No | Yes | 40.56 | 61.90 | 42.21 | 26.30 | 30.21
5 | BERT-PR | GNN | Yes | Yes | 40.23 | 60.21 | 45.21 | 30.02 | 30.11
6 | graph-aware BERT | GNN | No | No | 37.43 | 55.62 | 43.22 | 32.60 | 35.24
7 | graph-aware BERT | GNN | No | Yes | 39.51 | 60.81 | 42.45 | 21.56 | 28.80
8 | graph-aware BERT | GNN | Yes | Yes | 40.13 | 60.90 | 43.82 | 18.20 | 31.52
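Per-class F1 scores make rare-class behavior visible, whereas a single accuracy number would be dominated by the frequent E class. The sketch below computes one-vs-rest F1 for each ESCI class and their unweighted (macro) mean. Whether the overall column of Table 3 is exactly a macro average is an assumption: averaging the four per-class scores of row 1 gives (57.42 + 38.12 + 14.21 + 20.56) / 4 ≈ 32.58, close to but not identical to the reported 32.62. The function names are hypothetical.

```python
from typing import Dict, Sequence

CLASSES = ("E", "S", "C", "I")  # exact, substitute, complement, irrelevant


def per_class_f1(y_true: Sequence[str], y_pred: Sequence[str]) -> Dict[str, float]:
    """One-vs-rest F1 for each ESCI class: F1 = 2TP / (2TP + FP + FN)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    scores = {}
    for c in CLASSES:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores[c] = 2 * tp / denom if denom else 0.0  # 0 if class never occurs
    return scores


def macro_f1(y_true: Sequence[str], y_pred: Sequence[str]) -> float:
    """Unweighted mean of per-class F1, so rare classes count as much as E."""
    scores = per_class_f1(y_true, y_pred)
    return sum(scores.values()) / len(scores)


# Toy labels for six query-product edges.
y_true = ["E", "E", "E", "S", "C", "I"]
y_pred = ["E", "E", "S", "S", "E", "I"]
print(per_class_f1(y_true, y_pred))
print(macro_f1(y_true, y_pred))
```

In the toy data the single C edge is misclassified, which pulls the macro score down to 7/12 even though most edges are correct.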

6.4 Query-purchase-product dataset

Table 4 collects the results for predicting whether a search query leads to the purchase of a product, which we treat as a link prediction task. The evaluation metric is macro recall at 100, i.e., the percentage of true (purchased) products that appear among the top 100 products retrieved by each method. Experiments that fine-tune the BERT model through the GNN are ongoing. Comparing rows 1 and 3, neither of which fine-tunes the BERT-PR model, we observe a very large performance boost from the GNN model, which considers the graph structure. An even larger boost is observed in row 2, where the BERT model is fine-tuned with graph information. The best performance is obtained by using the graph-aware BERT model as a fixed semantic encoder for the GNN model (row 4). Future work will focus on end-to-end fine-tuning.
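The metric can be sketched as follows. We assume here that "macro" means computing recall@k separately for every query and then averaging over queries, so that queries with many purchases do not dominate; the exact aggregation is not spelled out in this section, and the function names are hypothetical. The result is multiplied by 100 to match the percentages in Table 4.

```python
from typing import Dict, List, Set


def recall_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of a query's purchased products found in its top-k results."""
    if not relevant:
        raise ValueError("each evaluated query needs at least one purchased product")
    hits = len(set(retrieved[:k]) & relevant)
    return hits / len(relevant)


def macro_recall_at_k(retrieved: Dict[str, List[str]],
                      relevant: Dict[str, Set[str]],
                      k: int = 100) -> float:
    """Average of per-query recall@k (every query weighs the same), in percent."""
    per_query = [recall_at_k(retrieved.get(q, []), rel, k)
                 for q, rel in relevant.items()]
    return 100.0 * sum(per_query) / len(per_query)


# Toy example with k = 2: q1 finds one of its two purchases, q2 finds its only one.
relevant = {"q1": {"p1", "p2"}, "q2": {"p3"}}
retrieved = {"q1": ["p1", "p9", "p2"], "q2": ["p3", "p4"]}
print(macro_recall_at_k(retrieved, relevant, k=2))
```

A query with no retrieved list counts as zero recall, which penalizes methods that fail to return candidates for some queries.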

Table 4: Macro recall at 100 for the query-purchase-product dataset. The converged model for row 2 is the graph-aware BERT in row 4.
# | Semantic encoder | Graph encoder | Fine-tune | Macro recall@100
1 | BERT-PR | MLP | No | 34.12
2 | graph-aware BERT | MLP | Yes | 79.06
3 | BERT-PR | GNN | No | 77.26
4 | graph-aware BERT | GNN | No | 86.53

7 Conclusion

In this paper we develop a framework termed LM-GNN that learns high-quality representations for graph data with rich textual features. Our framework employs stage-wise fine-tuning steps that allow the BERT model to gradually adapt to the graph domain data. We demonstrate the power of the LM-GNN framework with experiments on four public datasets and one query-purchase-product dataset.

References

  • [1] Amazon KDD 2022 challenge. [Online]. Available: https://www.aicrowd.com/challenges/esci-challenge-for-improving-product-search.
  • [2] Yelp dataset. [Online]. Available: https://www.yelp.com/dataset.
  • [3] Amr Ahmed, Nino Shervashidze, Shravan Narayanamurthy, Vanja Josifovski, and Alexander J Smola. Distributed large-scale natural graph factorization. In Proceedings of the 22nd international conference on World Wide Web, pages 37–48, 2013.
  • [4] M. Belkin, I. Matveeva, and P. Niyogi. Regularization and semi-supervised learning on large graphs. In Proc. Annual Conf. Learning Theory, volume 3120, pages 624–638, Banff, Canada, Jul. 2004. Springer.
  • [5] Mikhail Belkin and Partha Niyogi. Laplacian eigenmaps and spectral techniques for embedding and clustering. In Advances in neural information processing systems, pages 585–591, 2002.
  • [6] Michael M Bronstein, Joan Bruna, Yann LeCun, Arthur Szlam, and Pierre Vandergheynst. Geometric deep learning: going beyond euclidean data. IEEE Sig. Process. Mag., 34(4):18–42, 2017.
  • [7] Shaosheng Cao, Wei Lu, and Qiongkai Xu. Grarep: Learning graph representations with global structural information. In Proceedings of the 24th ACM international on conference on information and knowledge management, pages 891–900, 2015.
  • [8] Eli Chien, Wei-Cheng Chang, Cho-Jui Hsieh, Hsiang-Fu Yu, Jiong Zhang, Olgica Milenkovic, and Inderjit S Dhillon. Node feature extraction by self-supervised multi-scale neighborhood prediction. arXiv preprint arXiv:2111.00064, 2021.
  • [9] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • [10] Xinyu Fu, Jiani Zhang, Ziqiao Meng, and Irwin King. Magnn: Metapath aggregated graph neural network for heterogeneous graph embedding. In Proceedings of The Web Conference 2020, pages 2331–2341, 2020.
  • [11] Aditya Grover and Jure Leskovec. node2vec: Scalable feature learning for networks. In Proceedings of the 22nd ACM SIGKDD international conference on Knowledge discovery and data mining, pages 855–864, 2016.
  • [12] William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs. In Proceedings of the 31st International Conference on Neural Information Processing Systems, NIPS’17, page 1025–1035, 2017.
  • [13] William L Hamilton, Rex Ying, and Jure Leskovec. Representation learning on graphs: Methods and applications. arXiv preprint arXiv:1709.05584, 2017.
  • [14] Weihua Hu, Matthias Fey, Marinka Zitnik, Yuxiao Dong, Hongyu Ren, Bowen Liu, Michele Catasta, and Jure Leskovec. Open graph benchmark: Datasets for machine learning on graphs. In H. Larochelle, M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing Systems, volume 33, pages 22118–22133. Curran Associates, Inc., 2020.
  • [15] Vassilis N Ioannidis, Da Zheng, and George Karypis. Few-shot link prediction via graph neural networks for covid-19 drug-repurposing. In ICML 2020; Graph Representation Learning and Beyond workshop, 2020.
  • [16] Vassilis N Ioannidis, Da Zheng, and George Karypis. Panrep: Graph neural networks for extracting universal node embeddings in heterogeneous graphs. In KDD 2020; Workshop on Deep Learning on Graphs: Methods and Applications, 2020.
  • [17] Thomas N Kipf and Max Welling. Semi-supervised classification with graph convolutional networks. In Proc. Int. Conf. on Learn. Representations, Toulon, France, Apr. 2017.
  • [18] Chaozhuo Li, Bochen Pang, Yuming Liu, Hao Sun, Zheng Liu, Xing Xie, Tianqi Yang, Yanling Cui, Liangjie Zhang, and Qi Zhang. Adsgnn: Behavior-graph augmented relevance modeling in sponsored search. In Proceedings of the 44th International ACM SIGIR Conference on Research and Development in Information Retrieval, pages 223–232, 2021.
  • [19] Mingdong Ou, Peng Cui, Jian Pei, Ziwei Zhang, and Wenwu Zhu. Asymmetric transitivity preserving graph embedding. In Proceedings of the 22nd ACM SIGKDD international conference on Knowledge discovery and data mining, pages 1105–1114, 2016.
  • [20] Bryan Perozzi, Rami Al-Rfou, and Steven Skiena. Deepwalk: Online learning of social representations. In Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining, pages 701–710, 2014.
  • [21] Michael Schlichtkrull, Thomas N Kipf, Peter Bloem, Rianne Van Den Berg, Ivan Titov, and Max Welling. Modeling relational data with graph convolutional networks. In European Semantic Web Conference, pages 593–607. Springer, 2018.
  • [22] Jian Tang, Meng Qu, Mingzhe Wang, Ming Zhang, Jun Yan, and Qiaozhu Mei. Line: Large-scale information network embedding. In Proceedings of the 24th international conference on world wide web, pages 1067–1077, 2015.
  • [23] Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Lio, and Yoshua Bengio. Graph attention networks. In Proc. Int. Conf. on Learn. Representations, 2018.
  • [24] Quan Wang, Zhendong Mao, Bin Wang, and Li Guo. Knowledge graph embedding: A survey of approaches and applications. IEEE Transactions on Knowledge and Data Engineering, 29(12):2724–2743, 2017.
  • [25] Xiao Wang, Houye Ji, Chuan Shi, Bai Wang, Yanfang Ye, Peng Cui, and Philip S Yu. Heterogeneous graph attention network. In The World Wide Web Conference, pages 2022–2032, 2019.
  • [26] Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, Tim Rault, Rémi Louf, Morgan Funtowicz, et al. Huggingface’s transformers: State-of-the-art natural language processing. arXiv preprint arXiv:1910.03771, 2019.
  • [27] Zonghan Wu, Shirui Pan, Fengwen Chen, Guodong Long, Chengqi Zhang, and S Yu Philip. A comprehensive survey on graph neural networks. IEEE Transactions on Neural Networks and Learning Systems, 2020.
  • [28] Bishan Yang, Wen-tau Yih, Xiaodong He, Jianfeng Gao, and Li Deng. Embedding entities and relations for learning and inference in knowledge bases. arXiv preprint arXiv:1412.6575, 2014.
  • [29] Da Zheng, Xiang Song, Chao Ma, Zeyuan Tan, Zihao Ye, Jin Dong, Hao Xiong, Zheng Zhang, and George Karypis. Dgl-ke: Training knowledge graph embeddings at scale. arXiv preprint arXiv:2004.08532, 2020.
  • [30] Da Zheng, Xiang Song, Chengru Yang, Dominique LaSalle, Qidong Su, Minjie Wang, Chao Ma, and George Karypis. Distributed hybrid cpu and gpu training for graph neural networks on billion-scale graphs. arXiv preprint arXiv:2112.15345, 2021.
  • [31] Yadi Zhou, Yuan Hou, Jiayu Shen, Yin Huang, William Martin, and Feixiong Cheng. Network-based drug repurposing for novel coronavirus 2019-ncov/sars-cov-2. Cell discovery, 6(1):1–18, 2020.
  • [32] Jason Zhu, Yanling Cui, Yuming Liu, Hao Sun, Xue Li, Markus Pelger, Tianqi Yang, Liangjie Zhang, Ruofei Zhang, and Huasha Zhao. Textgnn: Improving text encoder via graph neural network in sponsored search. In Proceedings of the Web Conference 2021, pages 2848–2857, 2021.