
ReGrAt: Regularization in Graphs using Attention to handle class imbalance

Neeraja Kirtane1, Jeshuren Chelladurai1, Balaraman Ravindran1, Ashish Tendulkar2
Abstract

Node classification is an important task in graph-based learning. Although much work has been done in this field, class imbalance has largely been neglected. Real-world data is rarely perfect and is often imbalanced. Besides text and images, data can be represented as graphs, so addressing imbalance in graphs has become very important. In node classification, imbalance means that one class has fewer examples than the others. A popular way to address this is to change the data composition by resampling the data to balance the dataset. However, resampling can lead to loss of information or add noise to the dataset. Therefore, in this work, we address the problem implicitly by changing the model loss. Specifically, we study how attention networks can help tackle imbalance. Moreover, we observe that using a regularizer to assign larger weights to minority nodes helps mitigate this imbalance. We achieve state-of-the-art results compared to existing methods on two standard citation benchmark datasets.

Introduction

Many instances of real-world data can be represented as graphs. One important task in graph learning is semi-supervised node classification (Yang, Cohen, and Salakhudinov 2016), where only a small number of nodes are labeled. Kipf and Welling (2016) solve this problem using a basic GCN architecture. However, existing architectures for node classification do not address imbalance in the data and perform well only in a balanced setting. In the context of node classification, we say that data is imbalanced when one class has fewer nodes than the others. We call that class the minority class and the rest the majority classes.

One real-life example of imbalance in graphs is fraud detection, where the members committing fraud are far fewer than the users not involved in any fraud. Here, the nodes of the graph are the actual users or bots, and the edges represent the interactions between them. The number of bots is much smaller than the number of actual users. The problem becomes harder in the semi-supervised setting, where the labeled training set is very small: the trained model then works well on the majority class and poorly on the minority class.

There are three broad categories of approaches to tackle imbalance in machine learning: data-level, model-level, and hybrid approaches. Data-level methods resample the data to make the dataset more balanced. However, this is not always possible, as it can lead to data loss, and sometimes the resampled data loses its physical significance. For example, if we resample a protein structure by adding or deleting nodes, it loses its structural identity. Therefore, there is a need to handle the imbalance implicitly by changing the model rather than the data. Model-level approaches modify the loss function so that it focuses more on the minority nodes. Hybrid approaches combine both methods and ensemble the resulting models.

In this work, we propose a model-level approach that uses Graph Attention Networks with a regularizer to handle the imbalance. The regularizer makes the attention weights focus more on the minority nodes than on the majority nodes. This raises the weights of minority nodes, making them more likely to be classified correctly. To the best of our knowledge, this is the first work that handles imbalance in graph networks implicitly while keeping the data intact. Our major contributions in this paper are as follows:

  • We implicitly address class imbalance by changing the model loss.

  • We use an attention mechanism to handle class imbalance.

  • Our results outperform existing imbalance-handling methods on citation network datasets.

Related work

Class imbalance

Handling class imbalance is a long-studied problem in the context of Machine Learning, and much work has been done previously on it. As mentioned earlier, the methods are broadly classified into three types: data-level, model-level, and hybrid approaches.

Data-level methods usually resample the data to make the class distribution more balanced. SMOTE (Chawla et al. 2002) is a widely used method that handles data imbalance by adding synthetic samples. In SMOTE, artificial samples are produced by interpolating between existing minority samples and their nearest minority neighbors. GraphSMOTE (Zhao, Zhang, and Wang 2021) extends SMOTE to graphs by creating new nodes from the minority classes. It constructs an embedding space that encodes the similarity among nodes and generates the new samples in that space. In addition, an edge generator is trained simultaneously to model the relational information and predict the links between the new nodes and the existing ones.

Existing model-level methods take a cost-sensitive approach that increases the priority of minority classes by reweighting the loss function. Hybrid approaches train multiple models using the methods mentioned above and then ensemble them to obtain the final results (Shi et al. 2021).

Graph Neural Networks

Graph Neural Networks (GNNs) have received much attention and are developing rapidly because of the need to work on non-Euclidean, graph-structured data. Graph Convolutional Networks (GCNs) (Kipf and Welling 2016) perform operations similar to those of CNNs: each node learns features by aggregating information from its neighboring nodes. The major difference is that CNNs are built to operate on regular, grid-structured (Euclidean) data, whereas GNNs generalize them to non-Euclidean data, where the number of neighbors varies from node to node and the neighbors are unordered. Current GNNs follow a message-passing framework, which is composed of pattern extraction and interaction modeling within each layer. Graph Attention Networks (Veličković et al. 2017) overcome a shortcoming of GCNs by using dynamic, learned weights for the neighbors instead of static ones. This is done using masked self-attention layers.
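To make the static-versus-dynamic distinction concrete, the sketch below computes the fixed neighbor weights that a GCN layer uses. It assumes the symmetric normalization with self-loops from Kipf and Welling (2016), c_vu = 1/sqrt((d_v + 1)(d_u + 1)); the adjacency-list input and the function name `gcn_coefficients` are illustrative choices, not part of this paper.

```python
import math

def gcn_coefficients(neighbours):
    """Static GCN neighbour weights with self-loops (renormalization trick).

    neighbours: dict mapping each node to a list of its neighbours (undirected graph).
    Returns {(v, u): c_vu} with c_vu = 1 / sqrt((d_v + 1) * (d_u + 1)).
    """
    # Degree of every node after adding a self-loop
    deg = {v: len(ns) + 1 for v, ns in neighbours.items()}
    coeffs = {}
    for v, ns in neighbours.items():
        for u in list(ns) + [v]:  # neighbours of v plus v itself
            coeffs[(v, u)] = 1.0 / math.sqrt(deg[v] * deg[u])
    return coeffs

# Toy graph: node 0 is linked to nodes 1 and 2
coeffs = gcn_coefficients({0: [1, 2], 1: [0], 2: [0]})
```

These coefficients depend only on node degrees, so they are identical for every input; an attention layer instead recomputes its weights from the node embeddings.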

Proposed Method

Methods        Cora                                 CiteSeer
               Accuracy  AUC-ROC Score  F1 Score    Accuracy  AUC-ROC Score  F1 Score
GCN with CE    0.804     0.959          0.783       0.673     0.897          0.609
GCN with FL    0.802     0.971          0.785       0.648     0.875          0.608
GraphSMOTE     0.774     0.953          0.768       0.31      0.632          0.283
Reg with CE    0.827     0.975          0.806       0.683     0.895          0.640
Reg with FL    0.823     0.976          0.805       0.668     0.893          0.615
Table 1: Results for Cora and CiteSeer datasets

Given a graph $G$ with $V$ nodes and $E$ edges, each node is labeled with one of $N$ class labels. Further, each node has a set of feature embeddings $f$. Let $m$ be the number of features and $X$ the feature embedding matrix of shape $(V, m)$. The connectivity and structure of the nodes are represented by the adjacency matrix $Adj$. Graph Attention Networks have been widely used for node classification (Veličković et al. 2017). They classify nodes based on $X$ and an attention mechanism $A$, which works as follows:

$$h_v^{(k)} = f^{(k)}\Big(W^{(k)} \cdot \Big[\sum_{u \in N(v)} a_{vu}^{(k-1)} h_u^{(k-1)} + a_{vv}^{(k-1)} h_v^{(k-1)}\Big]\Big) \tag{1}$$

Here $h_v^{(k)}$ is the embedding of node $v$ at step $k$, and $h_v^{(0)} = x_v \in X$ is the input feature embedding of the node. This is computed for all nodes in the graph. The attention weights $a^{(k)}$ are produced by an attention mechanism $A^{(k)}$ and normalized so that they sum to 1 over the neighbours of each node $v$:

$$a_{vu}^{(k)} = \frac{A^{(k)}\big(h_v^{(k)}, h_u^{(k)}\big)}{\sum_{w \in N(v)} A^{(k)}\big(h_v^{(k)}, h_w^{(k)}\big)}, \quad (v, u) \in E \tag{2}$$
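Equations (1) and (2) can be sketched for a single node as follows. Eq. (2) leaves the scoring function $A$ unspecified, so this sketch assumes the GAT form exp(LeakyReLU(a_vec · [h_v || h_u])) applied directly to the embeddings, with a hypothetical parameter vector `a_vec`. Because Eq. (1) also contains the self term $a_{vv}$, we assume, as GAT does, that each node attends to itself, so the normalization runs over $N(v) \cup \{v\}$. ReLU stands in for $f$; all function names are illustrative.

```python
import math

def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x

def score(h_v, h_u, a_vec):
    # Assumed GAT-style unnormalized score: exp(LeakyReLU(a . [h_v || h_u]))
    concat = h_v + h_u  # list concatenation [h_v || h_u]
    return math.exp(leaky_relu(sum(ai * xi for ai, xi in zip(a_vec, concat))))

def attention_weights(h, v, neighbours, a_vec):
    # Eq. (2): normalize scores over v's neighbours, with v included (self-loop assumption)
    nodes = list(neighbours[v]) + [v]
    raw = {u: score(h[v], h[u], a_vec) for u in nodes}
    total = sum(raw.values())
    return {u: s / total for u, s in raw.items()}

def layer_update(h, v, neighbours, a_vec, W, f=lambda x: max(0.0, x)):
    # Eq. (1): attention-weighted sum of h_u (including the a_vv * h_v term),
    # then the linear map W and an elementwise nonlinearity f (ReLU assumed)
    alpha = attention_weights(h, v, neighbours, a_vec)
    dim = len(h[v])
    agg = [sum(alpha[u] * h[u][i] for u in alpha) for i in range(dim)]
    return [f(sum(row[i] * agg[i] for i in range(dim))) for row in W]

h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
neighbours = {0: [1, 2], 1: [0], 2: [0]}
alpha = attention_weights(h, 0, neighbours, [0.5, -0.5, 0.5, 0.5])
```

With an all-zero `a_vec`, every score is exp(0) = 1, so the weights reduce to a uniform average; a learned `a_vec` lets the layer weight neighbours unequally, which is the property the regularizer exploits.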

Multiple attention heads are used instead of a single head because this stabilizes training (Veličković et al. 2017). The outputs of the attention heads in a hidden layer are concatenated to form the input of the next hidden layer. From the output of the last hidden layer, we obtain a probability distribution over the $N$ classes. The model is trained with the cross-entropy loss ($CEloss$).

$$\hat{y}_v = F\big(h_v^{(K)}\big) \tag{3}$$

Here $\hat{y}_v$ is the predicted class distribution for node $v$, $F$ maps the final embedding to this distribution, and $K$ is the index of the last hidden layer.
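The head concatenation and the readout in Eq. (3) can be sketched as below. The text does not specify $F$; this sketch assumes a linear map to $N$ class scores followed by a softmax, and the weight matrix `W_out` and the helper names are hypothetical.

```python
import math

def concat_heads(head_outputs):
    # Concatenate the per-head embeddings of one node into a single vector
    return [x for head in head_outputs for x in head]

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in z]
    total = sum(exps)
    return [e / total for e in exps]

def predict(h_K, W_out):
    # Assumed readout F: linear map to N class scores, then softmax (Eq. 3)
    logits = [sum(w * x for w, x in zip(row, h_K)) for row in W_out]
    return softmax(logits)

# Two heads, matching the experimental setup; three classes for illustration
h_K = concat_heads([[0.2, 0.1], [0.5, -0.3]])
y_hat = predict(h_K, [[1.0, 0.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0, 0.0],
                      [0.0, 0.0, 1.0, 0.0]])
```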

In this form, the model does not work well on imbalanced datasets: further investigation reveals that it focuses more on the majority nodes. We therefore propose a regularizer that makes the model focus on minority nodes, and we modify the loss function as follows:

$$loss_{train} = loss + \lambda \cdot regularization \tag{4}$$

Here $loss_{train}$ is the loss used to train the model, $loss$ is the base loss function used together with the regularizer, $regularization$ is the regularization term that handles the imbalance, and $\lambda$ is a hyperparameter with a value between 0 and 1.

Loss functions used

We use two base loss functions that are commonly used when the data is imbalanced: weighted cross-entropy loss and focal loss (Lin et al. 2017). Weighted cross-entropy loss assigns each class a different weight based on the number of samples in that class:

$$W_c = \frac{1}{\sqrt{N(c)}} \tag{5}$$

Here $N(c)$ is the number of samples in class $c$. This ensures that classes with fewer samples, that is, the minority classes, receive larger weights than the majority classes when the loss is computed.
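A minimal sketch of Eq. (5) and the resulting weighted cross-entropy, assuming class counts are taken from the labels passed in and that the loss is averaged over the number of nodes; other normalizations, such as dividing by the sum of the applied weights, are also common. The function names are illustrative.

```python
import math
from collections import Counter

def class_weights(labels):
    # Eq. (5): W_c = 1 / sqrt(N(c)), where N(c) counts the samples of class c
    counts = Counter(labels)
    return {c: 1.0 / math.sqrt(n) for c, n in counts.items()}

def weighted_cross_entropy(probs, labels, weights):
    # Mean over nodes of -W_{y_v} * log(p_v[y_v]); probs[v] is a class distribution
    total = sum(-weights[y] * math.log(p[y]) for p, y in zip(probs, labels))
    return total / len(labels)

labels = [0, 0, 0, 0, 1]          # class 1 is the minority class
weights = class_weights(labels)   # {0: 0.5, 1: 1.0}
loss = weighted_cross_entropy([[0.5, 0.5]] * 5, labels, weights)
```

The single minority node contributes as much to the loss as two majority nodes here, since its weight is twice as large.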

Lin et al. (2017) proposed focal loss to address extreme class imbalance in dense object detection; we apply it in the context of graphs. Focal loss reshapes the cross-entropy loss so that it reduces the impact of easily classified examples and focuses on hard ones. This is done by multiplying the cross-entropy loss by a modulating factor $(1-p_t)^{\gamma}$, where $p_t$ is the predicted probability of the true class. The hyperparameter $\gamma$ controls the rate at which easy examples are down-weighted: for easily classified examples, $p_t$ tends towards 1, so the modulating factor approaches 0 and the example contributes little to the loss. Focal loss is calculated as follows:

$$CE(p_t) = -\log(p_t) \tag{6}$$
$$FL(p_t) = -(1-p_t)^{\gamma} \log(p_t) \tag{7}$$

As $\gamma$ increases, the model focuses more on misclassified examples. In our experiments, $\gamma = 0.6$ gave the best results.
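Equations (6) and (7) can be sketched for a single example as below, using the $\gamma = 0.6$ value reported above as the default. The function names and the two probe probabilities are illustrative.

```python
import math

def cross_entropy(p_t):
    # Eq. (6): p_t is the predicted probability of the true class
    return -math.log(p_t)

def focal_loss(p_t, gamma=0.6):
    # Eq. (7): cross-entropy scaled by the modulating factor (1 - p_t)^gamma
    return (1.0 - p_t) ** gamma * cross_entropy(p_t)

easy, hard = 0.9, 0.1
easy_factor = focal_loss(easy) / cross_entropy(easy)  # (0.1)^0.6, about 0.25
hard_factor = focal_loss(hard) / cross_entropy(hard)  # (0.9)^0.6, about 0.94
```

With $\gamma = 0.6$, a well-classified example keeps only about a quarter of its cross-entropy loss, while a badly classified one keeps almost all of it; with $\gamma = 0$ the focal loss reduces exactly to cross-entropy.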

Regularization

We propose the following regularizer, $KLDiv_{att}$. Out of all the attention heads used, we make one head focus solely on the minority nodes. We obtain $Adj_{minority} \subset Adj$ by taking only the rows of the minority nodes from the adjacency matrix. Similarly, we obtain $a_{minority} \subset a$ by taking only the rows of the minority nodes from the attention weights. The attention regularizer is the Kullback-Leibler (KL) divergence between $Adj_{minority}$ and $a_{minority}$; KL divergence measures the difference between two distributions. $Adj_{minority}$ consists of zeros and ones: an entry is one when the corresponding node is a neighbour of a minority node and zero otherwise. $a_{minority}$ consists of the corresponding attention weights, which lie between zero and one. We want to increase these weights, bringing them numerically close to one wherever a node neighbours a minority node, so that $a_{minority}$ resembles $Adj_{minority}$. We therefore compute the KL divergence as follows:

$$KLDiv_{att} = \sum_{i,j} (Adj_{minority})_{ij} \log\frac{(Adj_{minority})_{ij}}{(a_{minority})_{ij}} \tag{8}$$
where terms with $(Adj_{minority})_{ij} = 0$ are taken to be 0.

We obtain the new training loss by adding $KLDiv_{att}$ to the base loss (weighted cross-entropy or focal loss):

$$loss_{train} = loss + \lambda \cdot KLDiv_{att} \tag{9}$$

The model is then trained on this custom loss function, where $\lambda$ is a hyperparameter tuned between 0 and 1. Both base loss functions help handle the imbalance, complementing the regularizer.
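Because $Adj_{minority}$ is binary, Eq. (8) simplifies: entries equal to 0 contribute nothing, and each entry equal to 1 contributes $1 \cdot \log(1/a) = -\log a$. The regularizer is therefore the sum of $-\log a$ over the attention weights on minority-node edges, which shrinks toward 0 as those weights approach 1. The sketch below computes Eqs. (8) and (9) on plain nested lists. It assumes the attention rows come from the one head dedicated to minority nodes; the toy graph, the choice of minority node, the clipping constant `eps`, and the default `lam = 0.5` are hypothetical.

```python
import math

def minority_rows(matrix, minority_nodes):
    # Adj_minority / a_minority: keep only the rows of minority-class nodes
    return [matrix[i] for i in minority_nodes]

def kl_attention_regularizer(adj_min, att_min, eps=1e-12):
    # Eq. (8), summed over all entries with the convention 0 * log(0 / a) = 0.
    # Since Adj_minority is binary, every nonzero term equals -log(a).
    total = 0.0
    for adj_row, att_row in zip(adj_min, att_min):
        for target, weight in zip(adj_row, att_row):
            if target > 0:
                total += target * math.log(target / max(weight, eps))
    return total

def train_loss(base_loss, adj_min, att_min, lam=0.5):
    # Eq. (9): loss_train = loss + lambda * KLDiv_att, with lambda in [0, 1]
    return base_loss + lam * kl_attention_regularizer(adj_min, att_min)

adj = [[0, 1, 1], [1, 0, 0], [1, 0, 0]]
att = [[0.0, 0.5, 0.5], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]  # one head's weights
minority = [0]  # hypothetical: node 0 is the only minority node
reg = kl_attention_regularizer(minority_rows(adj, minority),
                               minority_rows(att, minority))
```

Here node 0 splits its attention evenly over two neighbours, so the regularizer equals $2\log 2$; it would be 0 only if every weight on a minority edge were 1.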

Experiments

Datasets

The datasets used in these experiments are Cora and CiteSeer. These are widely used benchmark datasets (Wu et al. 2020). Both are citation network datasets. The number of edges, nodes, node features and the number of classes are shown in Table 2.

                     Cora    CiteSeer
Number of nodes      2708    3327
Number of edges      5429    4732
Number of features   1433    3703
Number of classes    7       6
Table 2: Description of the datasets used

The distribution of classes is shown in Table 3. We observe that the classes are unevenly distributed. For Cora, we treat L1, L3, L5 and L6 as the minority classes and the remaining classes as majority classes. Similarly, for CiteSeer we treat L3 and L4 as the minority classes.

Labels      L0    L1    L2    L3    L4    L5    L6
Cora        29    9*    16    13*   15    11*   7*
CiteSeer    18    20    21    8*    15*   18    -
Table 3: Distribution of classes (%) (* = minority class)

Our experimental settings are as follows. We use a semi-supervised setup for node classification with 20 labeled training examples per class. The validation set has 500 nodes, and the test set has 1000 nodes. The learning rate is set to 5e-3 and the weight decay to 5e-4 for all experiments. Following standard graph methods, the minority classes are the same in the training and test sets. We use two attention heads in the Graph Attention Networks. We vary the number of epochs between 100 and 1000 and found that training for 300 epochs gave the best results with the regularizer. We report the macro F1 score and the individual F1 scores of the minority classes.
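The settings above can be collected into a configuration, together with a way to draw a split of the stated sizes. This is only an illustrative sketch: the benchmark splits commonly used in prior work are fixed, whereas `split_nodes` below draws a random split, and the `CONFIG` keys, the function name, and the seed are hypothetical.

```python
import random

# Values taken from the experimental settings; key names are hypothetical
CONFIG = {
    "train_per_class": 20,
    "num_val": 500,
    "num_test": 1000,
    "lr": 5e-3,
    "weight_decay": 5e-4,
    "heads": 2,
    "epochs": 300,  # best setting found with the regularizer
}

def split_nodes(labels, cfg=CONFIG, seed=0):
    # Hypothetical random split: up to 20 labelled nodes per class for training,
    # then disjoint validation and test sets drawn from the remaining nodes.
    rng = random.Random(seed)
    nodes = list(range(len(labels)))
    rng.shuffle(nodes)
    per_class, train, rest = {}, [], []
    for v in nodes:
        c = labels[v]
        if per_class.get(c, 0) < cfg["train_per_class"]:
            train.append(v)
            per_class[c] = per_class.get(c, 0) + 1
        else:
            rest.append(v)
    val = rest[: cfg["num_val"]]
    test = rest[cfg["num_val"]: cfg["num_val"] + cfg["num_test"]]
    return train, val, test
```

Note that with exactly 20 training nodes per class, the training set itself is balanced; the imbalance shows up in the validation and test nodes, which follow the class distribution in Table 3.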

Table 1 shows the node classification results on the Cora and CiteSeer datasets. The methods we compare are GCNs with weighted CE, GCNs with FL, the existing GraphSMOTE method, the regularizer with weighted CE, and the regularizer with FL. We report accuracy, AUC-ROC score, and macro F1 score. Macro F1 is the most informative of these metrics for imbalanced node classification, because it averages the per-class F1 scores and thus weights every class equally regardless of its size.
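The difference between accuracy and macro F1 can be seen on a toy example. The sketch below computes per-class F1 (the quantity reported for minority classes in Tables 4 and 5) and its unweighted mean; the function names and the toy labels are illustrative.

```python
def f1_per_class(y_true, y_pred, cls):
    # F1 for one class, treating it as the positive class
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != cls and p == cls)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p != cls)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

def macro_f1(y_true, y_pred):
    # Unweighted mean of per-class F1 scores: every class counts equally
    classes = sorted(set(y_true) | set(y_pred))
    return sum(f1_per_class(y_true, y_pred, c) for c in classes) / len(classes)

y_true = [0, 0, 0, 0, 1, 1]  # class 1 is the minority
y_pred = [0, 0, 0, 0, 0, 1]
acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)  # 5/6
mf1 = macro_f1(y_true, y_pred)  # (8/9 + 2/3) / 2 = 7/9
```

A single minority-class error lowers accuracy to about 0.83 but lowers macro F1 to about 0.78, because the minority class's F1 (2/3) counts as much as the majority class's (8/9).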

Methods        Cora minority classes
               L1     L3     L5     L6
GCN with CE    0.65   0.78   0.77   0.71
GCN with FL    0.85   0.88   0.74   0.73
GraphSMOTE     0.67   0.88   0.78   0.64
Reg with CE    0.85   0.92   0.74   0.73
Reg with FL    0.85   0.93   0.74   0.73
Table 4: Individual F1 scores for minority classes of the Cora dataset
Methods        CiteSeer minority classes
               L3     L4
GCN with CE    0.24   0.66
GCN with FL    0.29   0.73
GraphSMOTE     0.17   0.14
Reg with CE    0.32   0.71
Reg with FL    0.26   0.71
Table 5: Individual F1 scores for minority classes of the CiteSeer dataset

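Comparison with GCNs: We run experiments with Graph Convolutional Networks as our baseline, using both weighted CE and FL as the loss. As shown in Table 1, adding the regularizer increases all three performance metrics (accuracy, AUC-ROC score, and F1 score) on both datasets, with one exception: with weighted CE on CiteSeer, AUC-ROC drops marginally from 0.897 to 0.895. With weighted CE, the individual F1 scores of the minority classes also increase, except for L5 on Cora (Table 4 for Cora and Table 5 for CiteSeer). With FL, the minority F1 scores on Cora are unchanged or higher, while on CiteSeer they decrease slightly (L3 from 0.29 to 0.26, L4 from 0.73 to 0.71).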

Comparison with GraphSMOTE: We compare our results with GraphSMOTE (Zhao, Zhang, and Wang 2021), the existing state-of-the-art method for handling imbalance in graph networks. As shown in Table 1, our method yields a considerable rise in F1 score on both datasets: with Reg+CE, from 0.768 to 0.806 on Cora and from 0.283 to 0.640 on CiteSeer.

Reg+CE works best, as it has the highest F1 score on both datasets. Compared with GCN with CE, it also raises the individual F1 scores of every Cora minority class except L5 (Table 4) and of both CiteSeer minority classes (Table 5). These gains in minority-class F1 suggest that the model focuses more on the minority nodes than before.

Conclusion and Future work

In this work, we address class imbalance in node classification by making intrinsic changes to the graph attention model: a regularizer makes the attention weights focus more on the minority nodes than on the majority nodes, and a custom loss function trains the model under this constraint. Importantly, we do not change the data composition in any way, because changing the data can lead to information loss or introduce unnecessary noise. We observe that altering the attention weights with the regularizer improves the results: on Cora and CiteSeer, our method achieves a higher macro F1 score than the existing methods we compare against. The individual F1 scores of the minority classes provide empirical evidence that our method helps mitigate the imbalance by focusing more on the minority classes.

We plan to apply similar techniques to larger datasets. We also aim to check whether this regularization-based method, or other intrinsic methods, work for other tasks such as edge-level and graph-level prediction. Finally, we intend to explore how the sparsity of graphs relates to handling imbalance.

References

  • Chawla et al. (2002) Chawla, N. V.; Bowyer, K. W.; Hall, L. O.; and Kegelmeyer, W. P. 2002. SMOTE: synthetic minority over-sampling technique. Journal of Artificial Intelligence Research, 16: 321–357.
  • Kipf and Welling (2016) Kipf, T. N.; and Welling, M. 2016. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907.
  • Lin et al. (2017) Lin, T.-Y.; Goyal, P.; Girshick, R.; He, K.; and Dollár, P. 2017. Focal loss for dense object detection. In Proceedings of the IEEE International Conference on Computer Vision, 2980–2988.
  • Shi et al. (2021) Shi, S.; Qiao, K.; Yang, S.; Wang, L.; Chen, J.; and Yan, B. 2021. Boosting-GNN: Boosting Algorithm for Graph Networks on Imbalanced Node Classification. Frontiers in Neurorobotics, 154.
  • Veličković et al. (2017) Veličković, P.; Cucurull, G.; Casanova, A.; Romero, A.; Lio, P.; and Bengio, Y. 2017. Graph attention networks. arXiv preprint arXiv:1710.10903.
  • Wu et al. (2020) Wu, Z.; Pan, S.; Chen, F.; Long, G.; Zhang, C.; and Yu, P. S. 2020. A comprehensive survey on graph neural networks. IEEE Transactions on Neural Networks and Learning Systems, 32(1): 4–24.
  • Yang, Cohen, and Salakhudinov (2016) Yang, Z.; Cohen, W.; and Salakhudinov, R. 2016. Revisiting semi-supervised learning with graph embeddings. In International Conference on Machine Learning, 40–48. PMLR.
  • Zhao, Zhang, and Wang (2021) Zhao, T.; Zhang, X.; and Wang, S. 2021. GraphSMOTE: Imbalanced Node Classification on Graphs with Graph Neural Networks. arXiv preprint arXiv:2103.08826.