Graph Neural Networks
Introduction
Graphs are data structures consisting of nodes and edges, and they are common in the natural sciences, the social sciences, and the web. Unlike images, audio, and language, graph data is non-Euclidean: it does not lie on a regular grid or sequence and cannot be fully represented in Euclidean space. This makes it hard to apply standard deep learning techniques such as Convolutional Neural Networks (CNNs) and Recurrent Neural Networks (RNNs) to such data. To address this challenge, researchers have developed a class of deep learning methods known as Graph Neural Networks (GNNs). Bronstein et al.[1] provide an overview of applying deep learning techniques to non-Euclidean data like graphs and manifolds. GNNs operate on graphs by propagating information along edges between nodes, which lets them learn complex relationships and dependencies between nodes. They can be used for a variety of tasks, including node classification, link prediction, and graph classification. In this article, we cover graphs and their structure, GNNs, and the math behind GNNs.
Structure of Graphs [2]
First, we need to understand the graph data structure. A graph consists of relations (edges) between entities (nodes). The following features are also commonly present,
- Vertex (or node) attributes like node identity, number of neighbours
- Edge (or link) attributes like edge weight
- Global graph attributes like number of nodes, longest path
There are two commonly used data structures to represent graphs,
1. Adjacency Matrix
One common way to represent the connections in a graph is the adjacency matrix. If a graph contains $N$ nodes, the adjacency matrix $A$ is an $N \times N$ matrix with $A_{ij} = 1$ if node $i$ and node $j$ are connected and $A_{ij} = 0$ otherwise. It therefore consists only of 1s and 0s.
2. Adjacency Lists
An adjacency list is an efficient way to store sparse graphs (number of edges $\ll N^2$). It is basically a list storing all the edges in the graph. For instance, the adjacency list for the leftmost graph in Figure 1 would be
[[1, 4], [2, 4], [3, 4]]
Instead of storing 16 numbers for the 16 entries of the adjacency matrix, we store only the 3 edges.
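The two representations can be converted into each other. The following is a minimal sketch using NumPy; it assumes the graph is undirected and that nodes are labelled 1 to 4 as in the edge list above. The helper names `edges_to_adjacency` and `adjacency_to_edges` are illustrative, not part of any library.

```python
import numpy as np

# Edge list from the text, using 1-based node labels (4 nodes, 3 edges).
edges = [[1, 4], [2, 4], [3, 4]]
num_nodes = 4

def edges_to_adjacency(edges, num_nodes):
    """Build a symmetric 0/1 adjacency matrix from an undirected edge list."""
    A = np.zeros((num_nodes, num_nodes), dtype=int)
    for i, j in edges:
        A[i - 1, j - 1] = 1  # shift labels to 0-based indices
        A[j - 1, i - 1] = 1  # undirected graph: mirror the entry
    return A

def adjacency_to_edges(A):
    """Recover the edge list (i < j, 1-based) from a symmetric adjacency matrix."""
    rows, cols = np.nonzero(np.triu(A, k=1))  # upper triangle avoids duplicates
    return [[int(i) + 1, int(j) + 1] for i, j in zip(rows, cols)]

A = edges_to_adjacency(edges, num_nodes)
print(A)
print(adjacency_to_edges(A))
```

The matrix holds 16 entries, of which only 6 are non-zero (each undirected edge appears twice), while the edge list stores just 3 pairs.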
Modelling Tasks on Graphs
Let's look at three tasks for which GNNs are commonly used,
- Node-Level Prediction: These are concerned with predicting the role/properties of each node within a graph. For instance, figure 2 is a commonly used example.
- Graph-Level Prediction: These are concerned with predicting the property of the entire graph. For instance, predicting molecular properties like potential energy.
- Edge-Level Prediction: These are concerned with predicting the presence/property of edges in a graph. For instance, predicting relations between subjects and objects in a knowledge graph.
Motivation for GNNs
Consider the task of predicting a molecular property given atom features. Suppose a molecule has N atoms and each atom/node has 2 features. In that case, the input feature matrix will be an N x 2 matrix. However, there is an issue when using conventional machine learning techniques: if we change the order of the nodes (the rows of this matrix), a standard neural network will in general produce a different output. Standard neural networks are not permutation-invariant, which means their output is not preserved when the order of the nodes is altered. Since a graph has no inherent node order, reordering the nodes should not affect the network's output.
To tackle this issue, Graph Neural Networks are designed with the permutation-invariance property in mind.
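The problem can be reproduced in a few lines. This is an illustrative sketch with random (untrained) weights, not a real molecular model: `flat_model` stands in for a conventional dense layer on the flattened N x 2 matrix, and `sum_model` for a model that treats every node identically and then sums. Both names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 5
X = rng.normal(size=(N, 2))        # N atoms, 2 features each
perm = np.array([1, 2, 3, 4, 0])   # the same atoms, listed in a different order

# Conventional model: flatten the N x 2 matrix and apply one dense layer.
W_flat = rng.normal(size=(N * 2,))

def flat_model(X):
    return float(X.reshape(-1) @ W_flat)

# Order-independent model: same transform for every node, then a sum.
w_node = rng.normal(size=(2,))

def sum_model(X):
    return float(np.tanh(X @ w_node).sum())

print(flat_model(X), flat_model(X[perm]))  # differ: output depends on row order
print(sum_model(X), sum_model(X[perm]))    # agree up to floating-point rounding
```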
Graph Neural Networks
We will describe how these networks work before going into the mathematical details. A GNN generally consists of the following layers,
- Message Passing Layers: Update each node's embedding by aggregating messages from its neighbours
- Pooling Layer: Pool node embeddings to get a single graph-level embedding
The embeddings of nodes, edges, and sometimes even entire graphs are learned during the training process, depending on the task at hand. Initially, these embeddings can be established using the properties of nodes, edges, and graphs.
Message Passing Layers
Typically, message passing layers update the embedding of each node $v$ using the following three steps,
- Gather the node and edge embeddings of all the neighbours $u \in \mathcal{N}(v)$, where $\mathcal{N}(v)$ is the set of neighbours of node $v$
- Aggregate them using a specified function (sum, max, mean, attention, etc).
- Update the aggregated message using a neural network
Different variants of GNNs have different variations in how the message passing is done but the broad set of steps remain the same.
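The three steps can be sketched as a generic function that takes the aggregation and update rules as arguments. This is a simplified illustration that assumes node embeddings only (edge embeddings are left out) and uses a random matrix `W` as a stand-in for trained weights; `message_passing_step` is a hypothetical name.

```python
import numpy as np

def message_passing_step(H, neighbours, aggregate, update):
    """One round of message passing.

    H: (N, d) array of node embeddings.
    neighbours: dict mapping each node index to a list of neighbour indices.
    """
    H_new = []
    for v in range(H.shape[0]):
        messages = H[neighbours[v]]      # 1. gather neighbour embeddings
        m_v = aggregate(messages)        # 2. aggregate (sum, mean, max, ...)
        H_new.append(update(H[v], m_v))  # 3. update with a learned function
    return np.stack(H_new)

# Star graph with 0-based labels: node 3 is connected to nodes 0, 1 and 2.
neighbours = {0: [3], 1: [3], 2: [3], 3: [0, 1, 2]}
H = np.eye(4)                            # one-hot initial embeddings
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 4))              # stand-in for trained weights

H_next = message_passing_step(
    H, neighbours,
    aggregate=lambda msgs: msgs.sum(axis=0),
    update=lambda h_v, m_v: np.tanh((h_v + m_v) @ W),
)
print(H_next.shape)
```

Swapping the `aggregate` or `update` argument is what distinguishes one GNN variant from another in this picture.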
Pooling Layers
In these layers, the node embeddings from the last message passing layer are pooled/aggregated to generate a single representation for the complete graph. These representations can be used for downstream tasks, such as graph-level predictions. Commonly used pooling schemes are mean, sum, and max.
How do GNNs guarantee permutation-invariance?
An important aspect of GNNs is permutation-invariance. If we look at the two components of GNNs carefully,
- Message Passing Layers are permutation-equivariant: each node is updated from the set of its neighbours, which does not depend on how the nodes are numbered. Reordering the input nodes therefore reorders the output embeddings in the same way and changes nothing else, so the network recognizes the same graph structure regardless of the node ordering.
- Pooling layers are permutation-invariant, which means that the output remains the same regardless of the order of the input. This is because pooling operations, such as sum, mean, or max, are commutative and associative and do not rely on the ordering of the input elements. The pooling layer output therefore remains unchanged even if the nodes of the input graph are permuted.
Composing equivariant message passing layers with an invariant pooling layer therefore yields a graph-level output that does not change when the nodes are reordered.
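Both properties can be checked numerically. The sketch below assumes a simple sum-aggregation layer with random weights and a mean readout; these are illustrative choices rather than a specific published architecture.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 3
A = np.triu(rng.integers(0, 2, size=(N, N)), k=1)
A = A + A.T                       # random undirected graph, no self-loops
H = rng.normal(size=(N, d))       # node embeddings
W = rng.normal(size=(d, d))       # stand-in for trained weights

def layer(A, H):
    """Sum neighbour embeddings, add the node's own, apply a shared transform."""
    return np.tanh((A @ H + H) @ W)

def readout(H):
    """Mean pooling over nodes."""
    return H.mean(axis=0)

perm = rng.permutation(N)
P = np.eye(N)[perm]               # permutation matrix
A_p, H_p = P @ A @ P.T, P @ H     # the same graph with relabelled nodes

# Equivariance: permuting the input permutes the output rows identically.
equivariant = np.allclose(layer(A_p, H_p), P @ layer(A, H))
# Invariance: after pooling, the permutation has no effect at all.
invariant = np.allclose(readout(layer(A_p, H_p)), readout(layer(A, H)))
print(equivariant, invariant)
```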
Popular Variants of Graph Neural Networks
There are various popular graph neural networks in literature. We look at two important variants,
In the equations shown below, $h_v^{(k)}$ represents the embedding of node $v$ at the $k$-th message passing layer and $\mathcal{N}(v)$ represents the neighbours of node $v$. Initially, the embeddings $h_v^{(0)}$ are set to the node features and are then iteratively updated.
Graph Convolutional Networks (GCNs)
Intuition: In each GCN layer, we average the embeddings of the neighbours and transform the result with linear weights and a non-linearity. By doing so, each GCN layer updates the node embedding with information from its neighbours.
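Message passing layers of GCNs perform the following,
- For a node $v$ with embedding $h_v^{(k-1)}$, gather the embeddings $h_u^{(k-1)}$ of neighbouring nodes ($u \in \mathcal{N}(v)$). Sum them and normalize the result by the node degree $|\mathcal{N}(v)|$ (number of neighbours)
- Transform the normalized embedding value ($\sum_{u \in \mathcal{N}(v)} h_u^{(k-1)} / |\mathcal{N}(v)|$) with a weight matrix $W^{(k)}$
- Transform $h_v^{(k-1)}$ using $B^{(k)}$
- Transform the sum of steps 2 and 3 with $f^{(k)}$ (like an activation function) to get the new embedding $h_v^{(k)}$ for node $v$
$$h_v^{(k)} = f^{(k)}\left( W^{(k)} \cdot \frac{\sum_{u \in \mathcal{N}(v)} h_u^{(k-1)}}{|\mathcal{N}(v)|} + B^{(k)} \cdot h_v^{(k-1)} \right)$$
where $W^{(k)}$ and $B^{(k)}$ hold the learnable parameters.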
To obtain a graph-level embedding $h_G$, a permutation-invariant readout layer gathers the node-level embeddings $h_v^{(K)}$ from the last message passing layer $K$,
$$h_G = \mathrm{READOUT}\left( \{ h_v^{(K)} : v \in G \} \right)$$
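The GCN steps above can be written in matrix form. The following is a minimal NumPy sketch of the mean-normalized update described in this section, with random stand-ins for the trained matrices; `gcn_layer` and `readout` are illustrative names, and the sum readout is just one permutation-invariant choice.

```python
import numpy as np

def gcn_layer(A, H, W, B, f=np.tanh):
    """Apply h_v <- f(W * mean of neighbour embeddings + B * h_v) to all nodes."""
    deg = A.sum(axis=1, keepdims=True)         # |N(v)| for each node
    neigh_mean = (A @ H) / np.maximum(deg, 1)  # step 1; guards isolated nodes
    return f(neigh_mean @ W.T + H @ B.T)       # steps 2 to 4

def readout(H):
    """Sum pooling: one permutation-invariant readout."""
    return H.sum(axis=0)

# Star graph with 0-based labels: node 3 is connected to nodes 0, 1 and 2.
A = np.zeros((4, 4))
A[3, :3] = A[:3, 3] = 1

rng = np.random.default_rng(0)
H = rng.normal(size=(4, 2))   # 2 input features per node
W = rng.normal(size=(3, 2))   # maps 2 -> 3 dimensions (untrained)
B = rng.normal(size=(3, 2))

H1 = gcn_layer(A, H, W, B)
h_G = readout(H1)
print(H1.shape, h_G.shape)
```

Stacking several such layers lets information travel further: after K layers, each node embedding depends on nodes up to K hops away.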
Graph Attention Networks (GATs)
Intuition: In each GAT layer, instead of equally weighting all neighbour messages like GCNs, we use the attention weights derived using the neighbour and node embeddings. By doing so, we can attend to important messages from neighbours.
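GATs use attention weights to aggregate messages from neighbours. The message passing equation for the $k$-th layer performs the following steps,
- For a node $v$ with embedding $h_v^{(k-1)}$, gather the embeddings $h_u^{(k-1)}$ of neighbouring nodes $u \in \mathcal{N}(v)$.
- Perform a weighted sum using attention weights calculated using the usual attention mechanism ($\alpha_{vu}^{(k-1)}$).
- Transform $h_v^{(k-1)}$ using the self-attention weight $\alpha_{vv}^{(k-1)}$
- Add the embeddings from steps 2 and 3 and transform the dimension with a weight matrix $W^{(k)}$
- Apply the activation $f^{(k)}$ to the output of step 4
$$h_v^{(k)} = f^{(k)}\left( W^{(k)} \cdot \left[ \sum_{u \in \mathcal{N}(v)} \alpha_{vu}^{(k-1)} h_u^{(k-1)} + \alpha_{vv}^{(k-1)} h_v^{(k-1)} \right] \right)$$
where $W^{(k)}$ and the parameters of the attention mechanism are learnable. $\alpha_{vu}^{(k-1)}$ is computed using a usual attention mechanism[5] shared across all nodes. We can use a readout layer as described in the previous section to get graph-level embeddings.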
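A rough sketch of one such layer follows. It makes a simplifying assumption that should be stated plainly: the attention weights are computed with unparameterized scaled dot-product attention, normalized with a softmax over the node and its neighbours. This is not the scoring function of the original GAT paper, and a real layer would learn the attention parameters. `gat_layer` is an illustrative name and `W` is untrained.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())   # subtract the max for numerical stability
    return e / e.sum()

def gat_layer(A, H, W, f=np.tanh):
    """Attention-weighted aggregation over each node's neighbours and itself."""
    N, d = H.shape
    H_new = np.zeros((N, W.shape[0]))
    for v in range(N):
        idx = np.append(np.nonzero(A[v])[0], v)  # neighbours of v, then v itself
        scores = H[idx] @ H[v] / np.sqrt(d)      # assumed dot-product scores
        alpha = softmax(scores)                  # alpha_vu, with alpha_vv last
        combined = alpha @ H[idx]                # steps 2 and 3: weighted sum
        H_new[v] = f(W @ combined)               # steps 4 and 5
    return H_new

# Star graph with 0-based labels: node 3 is connected to nodes 0, 1 and 2.
A = np.zeros((4, 4))
A[3, :3] = A[:3, 3] = 1

rng = np.random.default_rng(0)
H = rng.normal(size=(4, 2))
W = rng.normal(size=(3, 2))   # stand-in for the trained weight matrix

H1 = gat_layer(A, H, W)
print(H1.shape)
```

Because the attention weights of each node sum to 1, the aggregation is a weighted average in which more relevant neighbours contribute more, instead of the uniform average used by the GCN layer.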
Problems with Graph Neural Networks
GNNs are useful in building expressive representations and have delivered state of the art performance on various tasks. However, there are certain theoretical challenges associated with graph neural networks because of message passing layers. We have a look at two of those challenges,
Oversmoothing and Oversquashing
Oversmoothing happens when repeated aggregation in message passing neural networks causes node features to converge to similar values, so that nodes become hard to tell apart. Oversquashing, in contrast, arises because the receptive field of a node can grow exponentially with each message passing layer. Since node embeddings have a fixed size, aggregating information from such a large number of nodes can result in "squashing" of information.
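Oversmoothing is easy to observe in a toy setting. The sketch below assumes the simplest possible message passing, repeated mean aggregation over each node and its neighbours with no learned weights, on a path graph of 6 nodes. It tracks how far apart the node features remain after each round.

```python
import numpy as np

# Path graph on 6 nodes with self-loops, row-normalized so each step is a mean.
N = 6
A = np.eye(N)
for i in range(N - 1):
    A[i, i + 1] = A[i + 1, i] = 1
P = A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
H = rng.normal(size=(N, 2))              # random initial node features

spread = []
for _ in range(100):
    spread.append(H.std(axis=0).max())   # how different the nodes still are
    H = P @ H                            # one round of mean aggregation
spread.append(H.std(axis=0).max())

print(spread[0], spread[10], spread[-1])
```

The spread across nodes shrinks towards zero as rounds are added: every node ends up with nearly the same feature vector, which is the convergence that oversmoothing refers to.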
Message Passing Networks are at most 1-WL powerful[6][7]
Let's look at an example, the figure on the right clearly has two graphs with different structures, but if we apply standard message passing networks like GCN or GAT, it will result in the same graph embedding for the two graphs.
The Weisfeiler-Lehman (WL) test is a graph isomorphism test, and k-WL refers to a hierarchy of increasingly powerful variants. Its simplest form, 1-WL, works by iteratively recolouring each node based on its own colour and the multiset of colours of its neighbours. The final histograms of colours of the two graphs are then compared: if they differ, the graphs are not isomorphic; if they match, the test is inconclusive. The expressive power of standard message passing networks is bounded by 1-WL, which is why they often cannot distinguish non-isomorphic graphs like those shown in the example.
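The 1-WL procedure is short enough to write out. The sketch below is a plain-Python illustration; `wl_histogram` is a hypothetical helper, and the two test graphs (a 6-cycle and two disjoint triangles) are a standard textbook pair chosen for illustration, not necessarily the graphs in the figure.

```python
from collections import Counter

def wl_histogram(adj, rounds=3):
    """Run 1-WL colour refinement and return the final colour histogram.

    adj: dict mapping each node to a list of its neighbours.
    """
    colours = {v: 0 for v in adj}   # every node starts with the same colour
    for _ in range(rounds):
        # New colour = (own colour, sorted multiset of neighbour colours).
        colours = {
            v: (colours[v], tuple(sorted(colours[u] for u in adj[v])))
            for v in adj
        }
    return Counter(colours.values())

# A 6-cycle versus two disjoint triangles: not isomorphic, all degrees are 2.
hexagon = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
# A path on 6 nodes, whose end nodes have degree 1.
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

print(wl_histogram(hexagon) == wl_histogram(triangles))  # 1-WL cannot separate them
print(wl_histogram(hexagon) == wl_histogram(path))       # 1-WL separates these
```

In the first pair every node sees exactly two neighbours with the same colour in every round, so the colourings never diverge. A message passing network with the same view of each neighbourhood faces the same limitation.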
Python Implementations
There have been various packages developed for using GNNs with graph-structured datasets. Some of the commonly used packages are as follows,
Follow-up Work
Given the various challenges associated with GNNs, this field has been an active area of research. There has been significant practical and theoretical work done to reduce oversquashing and oversmoothing of GNNs and increase the expressivity of these networks. Some interesting directions are as follows,
- Utilizing the transformer architecture for graphs to allow modelling long-distance node relations[8]
- Improving message passing using multiple aggregators[9]
References
1. Bronstein MM, Bruna J, LeCun Y, Szlam A, Vandergheynst P. Geometric deep learning: going beyond Euclidean data. IEEE Signal Processing Magazine. 2017 Jul 11;34(4):18-42.
2. https://distill.pub/2021/gnn-intro/
3. Kipf TN, Welling M. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907. 2016 Sep 9.
4. Veličković P, Cucurull G, Casanova A, Romero A, Liò P, Bengio Y. Graph attention networks. arXiv preprint arXiv:1710.10903. 2017 Oct 30.
5. Vaswani A, Shazeer N, Parmar N, Uszkoreit J, Jones L, Gomez AN, Kaiser Ł, Polosukhin I. Attention is all you need. Advances in Neural Information Processing Systems. 2017;30.
6. Provably expressive graph neural networks. https://blog.twitter.com/engineering/en_us/topics/insights/2021/provably-expressive-graph-neural-networks
7. Xu K, Hu W, Leskovec J, Jegelka S. How powerful are graph neural networks? arXiv preprint arXiv:1810.00826. 2018 Oct 1.
8. Kreuzer D, Beaini D, Hamilton W, Létourneau V, Tossou P. Rethinking graph transformers with spectral attention. Advances in Neural Information Processing Systems. 2021 Dec 6;34:21618-29.
9. Corso G, Cavalleri L, Beaini D, Liò P, Veličković P. Principal neighbourhood aggregation for graph nets. Advances in Neural Information Processing Systems. 2020;33:13260-71.