In the previous parts of this series, we looked at Graph Convolutional Networks (GCNs) and Graph Attention Networks (GATs). Both architectures work fine, but they also have some limitations! A big one is that for large graphs, calculating the node representations with GCNs and GATs becomes v-e-r-y slow. Another limitation is that if the graph structure changes, GCNs and GATs will not be able to generalize. So if nodes are added to the graph, a GCN or GAT cannot make predictions for them. Luckily, these issues can be solved!
In this post, I will explain GraphSAGE and how it solves common problems of GCNs and GATs. We will train GraphSAGE and use it for node classification to compare its performance with GCNs and GATs.
New to GNNs? You can start with post 1 about GCNs (also containing the initial setup for running the code samples), and post 2 about GATs.
Two Key Problems with GCNs and GATs
I briefly touched on this in the introduction, but let's dive a bit deeper. What are the problems with the previous GNN models?
Problem 1. They don’t generalize
GCNs and GATs struggle to generalize to unseen graphs. The graph structure needs to be the same as in the training data. This is known as transductive learning, where the model trains and makes predictions on the same fixed graph; it effectively overfits to a specific graph topology. In reality, graphs change: nodes and edges are added or removed, and this happens often in real-world scenarios. We want our GNNs to learn patterns that generalize to unseen nodes, or to entirely new graphs (this is called inductive learning).
Problem 2. They have scalability issues
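Training GCNs and GATs on large-scale graphs is computationally expensive. GCNs rely on repeated neighbor aggregation: to compute one node's representation with L layers, you need its entire L-hop neighborhood, which grows exponentially with the number of layers and can quickly span most of the graph. Full-batch training also keeps the whole graph and all intermediate features in memory. GATs add (multi-head) attention on top, computing attention coefficients for every edge and every head, which scales poorly as the number of nodes and edges increases.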
In large production recommendation systems, with graphs of millions of users and products, GCNs and GATs are impractical and slow.
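To get a feel for how quickly full neighbor aggregation grows with depth, here is a back-of-the-envelope count in plain Python (the helper receptive_field_size is hypothetical). It assumes, for illustration only, that every node has exactly the same number of neighbors and that no neighbors are shared between branches, so it is an upper bound rather than an exact count.

```python
def receptive_field_size(degree: int, num_layers: int) -> int:
    """Upper bound on the number of nodes an L-layer message-passing GNN
    touches to compute ONE node's representation.

    Simplifying assumptions (illustration only): every node has exactly
    `degree` neighbors and no neighbor is shared between branches.
    """
    # hop k contributes degree**k nodes; k = 0 is the target node itself
    return sum(degree ** k for k in range(num_layers + 1))


for layers in range(1, 4):
    print(layers, receptive_field_size(10, layers))
# prints: "1 11", "2 111", "3 1111"
```

Each extra layer multiplies the count by the degree. GraphSAGE does not remove this exponential dependence on depth, but by sampling a fixed number of neighbors per hop it replaces the (possibly huge) real degree of hub nodes with a small constant that you choose.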
Let’s take a look at GraphSAGE to fix these issues.
GraphSAGE (SAmple and aggreGatE)
GraphSAGE makes training much faster and more scalable. It does this by sampling only a subset of neighbors for each node. For very large graphs, processing all neighbors of every node, as traditional GCNs do, is computationally infeasible (unless you have limitless time, which none of us do…). Another important step of GraphSAGE is combining the features of the sampled neighbors with an aggregation function.
We will walk through all the steps of GraphSAGE below.
1. Sampling Neighbors
With tabular data, sampling is easy: it's something you do in every common machine learning project when creating train, test, and validation sets. With graphs, you cannot simply select random nodes, because this can result in disconnected graphs, nodes without neighbors, and so on:

What you can do with graphs is select a random, fixed-size subset of neighbors for each node. For example, in a social network you can sample 3 friends for each user (instead of all friends):
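A minimal sketch of fixed-size neighbor sampling on an adjacency list, in plain Python. The toy social network and the function name sample_neighbors are hypothetical. It samples uniformly without replacement and keeps all neighbors when a node has fewer than k; that is one common choice, and real implementations may differ (for example by sampling with replacement).

```python
import random

def sample_neighbors(adj, node, k, rng):
    """Return up to k neighbors of `node`, sampled uniformly without replacement."""
    neighbors = adj[node]
    if len(neighbors) <= k:
        return list(neighbors)  # small neighborhood: keep everyone
    return rng.sample(neighbors, k)

# Hypothetical toy social network (adjacency list)
adj = {
    "alice": ["bob", "carol", "dave", "erin", "frank"],
    "bob": ["alice", "carol"],
    "carol": ["alice", "bob", "dave"],
    "dave": ["alice", "carol"],
    "erin": ["alice"],
    "frank": ["alice"],
}

rng = random.Random(42)  # fixed seed for reproducibility
print(sample_neighbors(adj, "alice", 3, rng))  # 3 of alice's 5 friends
print(sample_neighbors(adj, "bob", 3, rng))    # bob has only 2 friends: ['alice', 'carol']
```

Because the sample size is fixed, the work per node is bounded by k, no matter how many friends a user has.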

2. Aggregate Information
After selecting the neighbors in the previous step, GraphSAGE combines their features into a single representation. There are multiple ways to do this (multiple aggregation functions). The most common types, and the ones explained in the paper, are mean aggregation, LSTM aggregation, and pooling.
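With mean aggregation, the average is computed over the features of all sampled neighbors (very simple and often effective). In a formula, where $h_u^{k-1}$ is the representation of neighbor $u$ from the previous layer and $\mathcal{N}(v)$ is the sampled neighborhood of node $v$:

$$h_{\mathcal{N}(v)}^{k} = \operatorname{mean}\left(\{\, h_u^{k-1} : u \in \mathcal{N}(v) \,\}\right)$$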
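The mean aggregator in plain Python, as a minimal sketch (the function name mean_aggregate and the feature values are made up for illustration):

```python
def mean_aggregate(neighbor_features):
    """Element-wise mean of a list of equal-length neighbor feature vectors."""
    n = len(neighbor_features)
    dim = len(neighbor_features[0])
    return [sum(vec[i] for vec in neighbor_features) / n for i in range(dim)]

# Hypothetical 3-dimensional features of three sampled neighbors
h_neighbors = [[1.0, 0.0, 2.0],
               [3.0, 2.0, 0.0],
               [2.0, 4.0, 1.0]]
print(mean_aggregate(h_neighbors))  # → [2.0, 2.0, 1.0]
```

Because the mean ignores the order of its inputs, the result is the same however the sampled neighbors happen to be listed, which is exactly what we want for an unordered neighborhood.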
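LSTM aggregation uses an LSTM (a type of recurrent neural network) to process the neighbor features sequentially. It can capture more complex relationships and is more expressive than mean aggregation. Because an LSTM is sensitive to the order of its inputs while neighbors have no natural order, the paper applies it to a random permutation of the neighbors.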
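The third type, pooling aggregation, passes each neighbor's features through a small fully connected layer with a non-linearity and then takes the element-wise maximum across neighbors (think of max-pooling in a convolutional neural network, where you also take the maximum of a set of values).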
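A sketch of the pooling aggregator, assuming the max-pooling variant: each neighbor's features go through a shared fully connected layer with a ReLU, then the element-wise maximum is taken over neighbors. The weights W_pool and b_pool below are hypothetical fixed numbers; in a real model they are learned.

```python
def relu(x):
    return [max(0.0, v) for v in x]

def matvec(W, x):
    """Matrix-vector product with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def max_pool_aggregate(neighbor_features, W_pool, b_pool):
    # 1) transform every neighbor with the SAME fully connected layer + ReLU
    transformed = [
        relu([z + b for z, b in zip(matvec(W_pool, h), b_pool)])
        for h in neighbor_features
    ]
    # 2) element-wise max across neighbors (column-wise over the transformed rows)
    return [max(col) for col in zip(*transformed)]

# Hypothetical parameters and 2-dimensional neighbor features
W_pool = [[1.0, -1.0],
          [0.5, 1.0]]
b_pool = [0.0, 0.0]
neighbors = [[1.0, 2.0], [3.0, 1.0], [0.0, 4.0]]
print(max_pool_aggregate(neighbors, W_pool, b_pool))  # → [2.0, 4.0]
```

Like the mean, the element-wise max does not depend on neighbor order, but it picks out the strongest signal per feature instead of averaging it away.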
3. Update Node Representation
After sampling and aggregation, the node combines its previous features with the aggregated neighbor features. Nodes will learn from their neighbors but also keep their own identity, just like we saw before with GCNs and GATs. Information can flow across the graph effectively.
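This is the formula for this step, where $W^{k}$ is the weight matrix of layer $k$ and $\sigma$ is a non-linearity:

$$h_v^{k} = \sigma\left(W^{k} \cdot \operatorname{CONCAT}\left(h_v^{k-1},\, h_{\mathcal{N}(v)}^{k}\right)\right)$$

The aggregation from step 2 is computed over the sampled neighbors, and the result is concatenated with the node's own representation from the previous layer. This vector is multiplied by the weight matrix and passed through a non-linearity (for example ReLU). As a final step, the representation can be normalized, for example to unit length: $h_v^{k} \leftarrow h_v^{k} / \lVert h_v^{k} \rVert_2$.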
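The update described above (concatenate, multiply by a weight matrix, apply ReLU, optionally normalize) can be sketched in plain Python for a single node. The function name sage_update and all numbers are hypothetical.

```python
import math

def sage_update(h_self, h_neigh, W, normalize=True):
    """One GraphSAGE-style update for a single node.

    h_self:  the node's own representation from the previous layer
    h_neigh: the aggregated representation of its sampled neighbors
    W:       weight matrix as a list of rows, each of length len(h_self) + len(h_neigh)
    """
    z = h_self + h_neigh  # list concatenation, not element-wise addition
    # matrix-vector product followed by ReLU
    h_new = [max(0.0, sum(w * v for w, v in zip(row, z))) for row in W]
    if normalize:
        norm = math.sqrt(sum(v * v for v in h_new))
        if norm > 0:  # avoid dividing by zero when ReLU zeroes everything
            h_new = [v / norm for v in h_new]
    return h_new

# Hypothetical numbers: 2-dim own features, 2-dim aggregated neighbor features
W = [[1.0, 0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0, 1.0]]
print(sage_update([1.0, 0.0], [2.0, 2.0], W))  # → [0.6, 0.8]
```

The weights are shared by every node, so nothing in this function is tied to a particular node ID. That is what makes inductive learning possible: a node added after training can be embedded with the same W, as long as we know its features and those of its neighbors.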
4. Repeat for Multiple Layers
The first three steps can be repeated multiple times, once per layer. With each extra layer, information flows in from neighbors one hop further away. In the image below you see a node with three neighbors selected in the first layer (direct neighbors), and two neighbors selected in the second layer (neighbors of neighbors).
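Repeating the sampling once per layer gives a small computation tree for each target node. Here is a minimal plain-Python sketch with fan-outs [3, 2], matching the three and two neighbors described above; the ring-shaped toy graph and the function name sample_computation_tree are hypothetical.

```python
import random

def sample_computation_tree(adj, target, fanouts, rng):
    """Sample the neighborhood needed to embed `target` with len(fanouts) layers.

    Returns one list per hop; each entry is (node, sampled_neighbors_of_node).
    """
    frontier = [target]
    hops = []
    for k in fanouts:
        hop, next_frontier = [], []
        for node in frontier:
            neigh = adj[node]
            picked = list(neigh) if len(neigh) <= k else rng.sample(neigh, k)
            hop.append((node, picked))
            next_frontier.extend(picked)  # these nodes get expanded at the next hop
        hops.append(hop)
        frontier = next_frontier
    return hops

# Hypothetical toy graph: 10 nodes on a ring, each linked to the 2 nodes on either side
adj = {i: [(i - 2) % 10, (i - 1) % 10, (i + 1) % 10, (i + 2) % 10] for i in range(10)}

tree = sample_computation_tree(adj, target=0, fanouts=[3, 2], rng=random.Random(0))
for depth, hop in enumerate(tree, start=1):
    print(f"hop {depth}: {hop}")
```

However large the graph is, at most 1 + 3 + 3·2 = 10 node slots are involved for this target. Nodes can repeat across branches (the target itself may be sampled again at hop 2); a production sampler would typically merge them into one deduplicated subgraph.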

To summarize, the key strengths of GraphSAGE are its scalability (sampling makes it efficient for massive graphs); its flexibility, since you can use it for inductive learning (it works well for predicting on unseen nodes and graphs); aggregation, which helps with generalization because it smooths out noisy features; and multiple layers, which allow the model to learn from far-away nodes.
Cool! And the best part: GraphSAGE is implemented in PyG (PyTorch Geometric), so we can use it easily in PyTorch.
Predicting with GraphSAGE
In the previous posts, we implemented an MLP, GCN, and GAT on the Cora dataset (CC BY-SA). To refresh your memory, Cora is a dataset of scientific publications where you have to predict the subject of each paper, with seven classes in total. This dataset is relatively small, so it might not be the best set for testing GraphSAGE. We will do this anyway, just to be able to compare. Let's see how well GraphSAGE performs.
Here are the parts of the code related to GraphSAGE that I'd like to highlight:
- The NeighborLoader, which samples the neighbors for each layer:
```python
from torch_geometric.loader import NeighborLoader

# 10 neighbors sampled in the first layer, 10 in the second layer
num_neighbors = [10, 10]

# sample mini-batches starting from the training nodes
# (data and batch_size are defined in the complete script)
train_loader = NeighborLoader(
    data,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=data.train_mask,
)
```
- The aggregation type is set in the SAGEConv layer. The default is mean; you can change this to max or lstm:
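```python
from torch_geometric.nn import SAGEConv

# in_c / out_c: input and output feature dimensions of the layer
SAGEConv(in_c, out_c, aggr='mean')  # or aggr='max', aggr='lstm'
```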
- Another important difference is that GraphSAGE is trained in mini-batches, while GCN and GAT are trained on the full graph. This touches the essence of GraphSAGE: thanks to neighbor sampling, each mini-batch only needs the sampled neighborhoods of its nodes, so we don't need the full graph anymore. In the way we use them here, GCNs and GATs propagate features and compute attention scores over the complete graph, which is why we train them on the full graph.
- The rest of the code is similar to before, except that we now have one class in which all the different models are instantiated based on model_type (GCN, GAT, or SAGE). This makes it easy to compare them or make small changes.
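A quick back-of-the-envelope check of what the NeighborLoader settings mean on Cora. It assumes Cora's standard Planetoid split (2,708 nodes, 140 training nodes); the helper loader_stats is hypothetical and only does arithmetic, it does not call PyG.

```python
import math

def loader_stats(num_train_nodes, batch_size, num_neighbors, num_graph_nodes):
    """Return (number of mini-batches per epoch, upper bound on nodes per mini-batch)."""
    num_batches = math.ceil(num_train_nodes / batch_size)
    # Each seed brings at most k1 first-hop nodes, k1*k2 second-hop nodes, ...
    per_seed, width = 1, 1
    for k in num_neighbors:
        width *= k
        per_seed += width
    # A sampled subgraph can never contain more nodes than the graph itself
    max_nodes = min(batch_size * per_seed, num_graph_nodes)
    return num_batches, max_nodes

print(loader_stats(140, 128, [10, 10], 2708))  # → (2, 2708)
```

The uncapped bound is 128 × (1 + 10 + 100) = 14,208 nodes, more than the whole graph, so on Cora a single mini-batch can cover a large part of the graph. That is one reason why GraphSAGE brings no speed advantage on datasets this small.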
This is the complete script. We train for 100 epochs and repeat the experiment 10 times to calculate the average accuracy and standard deviation for each model:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader

# dataset_name can be 'Cora', 'CiteSeer', 'PubMed'
dataset_name = 'Cora'
hidden_dim = 64
num_layers = 2
num_neighbors = [10, 10]
batch_size = 128
num_epochs = 100
model_types = ['GCN', 'GAT', 'SAGE']

dataset = Planetoid(root='data', name=dataset_name)
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)


class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, model_type='SAGE', gat_heads=8):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.model_type = model_type
        self.gat_heads = gat_heads

        def get_conv(in_c, out_c, is_final=False):
            if model_type == 'GCN':
                return GCNConv(in_c, out_c)
            elif model_type == 'GAT':
                # hidden layers concatenate their heads; the final layer uses a single head
                heads = 1 if is_final else gat_heads
                concat = False if is_final else True
                return GATConv(in_c, out_c, heads=heads, concat=concat)
            else:
                return SAGEConv(in_c, out_c, aggr='mean')

        if model_type == 'GAT':
            self.convs.append(get_conv(in_channels, hidden_channels))
            in_dim = hidden_channels * gat_heads  # output size of concatenated heads
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(in_dim, hidden_channels))
                in_dim = hidden_channels * gat_heads
            self.convs.append(get_conv(in_dim, out_channels, is_final=True))
        else:
            self.convs.append(get_conv(in_channels, hidden_channels))
            for _ in range(num_layers - 2):
                self.convs.append(get_conv(hidden_channels, hidden_channels))
            self.convs.append(get_conv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        x = self.convs[-1](x, edge_index)
        return x


@torch.no_grad()
def test(model):
    # full-graph inference for every model type
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


results = {}
for model_type in model_types:
    print(f'Training {model_type}')
    results[model_type] = []
    for i in range(10):
        model = GNN(dataset.num_features, hidden_dim, dataset.num_classes, num_layers, model_type, gat_heads=8).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

        if model_type == 'SAGE':
            train_loader = NeighborLoader(
                data,
                num_neighbors=num_neighbors,
                batch_size=batch_size,
                input_nodes=data.train_mask,
            )

            def train():
                model.train()
                total_loss = 0
                for batch in train_loader:
                    batch = batch.to(device)
                    optimizer.zero_grad()
                    out = model(batch.x, batch.edge_index)
                    # Only the first batch.batch_size nodes are seed (training) nodes;
                    # the rest are sampled neighbors whose labels must not be used.
                    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                return total_loss / len(train_loader)
        else:
            def train():
                model.train()
                optimizer.zero_grad()
                out = model(data.x, data.edge_index)
                loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
                loss.backward()
                optimizer.step()
                return loss.item()

        best_val_acc = 0
        best_test_acc = 0
        for epoch in range(1, num_epochs + 1):
            loss = train()
            train_acc, val_acc, test_acc = test(model)
            # keep the test accuracy of the epoch with the best validation accuracy
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
            if epoch % 10 == 0:
                print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}')
        results[model_type].append([best_val_acc, best_test_acc])

for model_name, model_results in results.items():
    model_results = torch.tensor(model_results)
    print(f'{model_name} Val Accuracy: {model_results[:, 0].mean():.3f} ± {model_results[:, 0].std():.3f}')
    print(f'{model_name} Test Accuracy: {model_results[:, 1].mean():.3f} ± {model_results[:, 1].std():.3f}')
```
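Why only the seed nodes should count in a sampled mini-batch: NeighborLoader places the seed nodes (taken from input_nodes) first and records how many there are in batch.batch_size; the remaining nodes are sampled neighbors, which may include validation and test nodes, and are only there as context. In PyG the usual pattern is F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size]). The plain-Python sketch below (hypothetical helper names and numbers) shows the same idea without PyTorch.

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy of one example, using a numerically stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def seed_loss(out, y, batch_size):
    """Mean loss over the first `batch_size` rows (the seed nodes)."""
    return sum(cross_entropy(out[i], y[i]) for i in range(batch_size)) / batch_size

# Hypothetical mini-batch: 2 seed (training) nodes followed by 1 sampled neighbor
out = [[2.0, 0.0],   # seed, label 0
       [0.0, 2.0],   # seed, label 1
       [5.0, -5.0]]  # sampled neighbor: its label should NOT be used for training
y = [0, 1, 1]

print(seed_loss(out, y, batch_size=2))       # ≈ 0.127, seeds only
print(seed_loss(out, y, batch_size=len(y)))  # ≈ 3.42, the neighbor's label dominates
```

If every node in the batch counted, the model would be trained on the neighbor's label as well, which is exactly the kind of leakage that inflates accuracy on validation and test nodes.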
And here are the results:
GCN Val Accuracy: 0.791 ± 0.007
GCN Test Accuracy: 0.806 ± 0.006
GAT Val Accuracy: 0.790 ± 0.007
GAT Test Accuracy: 0.800 ± 0.004
SAGE Val Accuracy: 0.899 ± 0.005
SAGE Test Accuracy: 0.907 ± 0.004
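Impressive improvement! On this small dataset, GraphSAGE comes out roughly 10 percentage points ahead of GAT and GCN, and when I repeated this test on the CiteSeer and PubMed datasets, GraphSAGE always came out best. One caution, though: these numbers were produced with the SAGE loss computed as F.cross_entropy(out, batch.y[:out.size(0)]), which covers every node in the sampled subgraph, not just the seed training nodes. Because the sampled neighbors include validation and test nodes, their labels leak into training, so the SAGE scores are likely optimistic; computing the loss only on the first batch.batch_size nodes gives a fair comparison with GCN and GAT.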
What I'd like to note here is that GCN is still very useful: it's one of the most effective baselines (if the graph structure allows it). Also, I didn't do much hyperparameter tuning, but just went with some standard values (like 8 heads for the GAT multi-head attention). In larger, more complex, and noisier graphs, the advantages of GraphSAGE become clearer than in this example. We didn't run any speed benchmarks, because on graphs this small GraphSAGE isn't faster than GCN.
Conclusion
GraphSAGE brings very nice improvements and benefits compared to GATs and GCNs. Inductive learning is possible: GraphSAGE can handle changing graph structures quite well. And although we didn't test it in this post, neighbor sampling makes it possible to create feature representations for larger graphs with good performance.
Related
Optimizing Connections: Mathematical Optimization within Graphs
Graph Neural Networks Part 1. Graph Convolutional Networks Explained
Graph Neural Networks Part 2. Graph Attention Networks vs. GCNs
The post Graph Neural Networks Part 3: How GraphSAGE Handles Changing Graph Structure appeared first on Towards Data Science.