ML Katas

Graph Convolutional Network for Node Classification

hard (<30 mins) pytorch gnn graph gcn
yesterday by E

Description

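Implement a simple Graph Convolutional Network (GCN) [1] to perform node classification on a graph dataset such as Cora, a citation graph in which each node is a paper to be assigned a subject category. Each GCN layer updates a node's feature representation by aggregating the features of its neighbors, together with its own features via a self-loop.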
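To make "aggregating the features of its neighbors" concrete, here is a minimal pure-PyTorch sketch of plain sum aggregation over an edge list. It assumes the COO `edge_index` layout used by torch_geometric (row 0 = source nodes, row 1 = target nodes); the toy graph and the helper `sum_neighbors` are hypothetical. A real GCN layer additionally adds self-loops, normalizes by node degree, and applies a learned weight matrix.

```python
import torch

# Hypothetical toy graph: 3 nodes, undirected edges 0-1 and 1-2, stored with
# both directions in COO form: row 0 = source node, row 1 = target node.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
x = torch.tensor([[1.0], [10.0], [100.0]])  # one scalar feature per node


def sum_neighbors(x, edge_index):
    """For every node, sum the features of the nodes that have an edge into it."""
    src, dst = edge_index
    out = torch.zeros_like(x)
    out.index_add_(0, dst, x[src])  # out[dst[k]] += x[src[k]] for each edge k
    return out


# node 0 gets 10, node 1 gets 1 + 100 = 101, node 2 gets 10
print(sum_neighbors(x, edge_index))
```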

Equation

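The GCN layer-wise propagation rule, a first-order simplification of spectral graph convolutions, is:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)$$

where $\tilde{A} = A + I_N$ is the adjacency matrix with added self-loops, $\tilde{D}$ is the diagonal degree matrix of $\tilde{A}$ (i.e. $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$), $H^{(l)}$ holds the node features at layer $l$ (with $H^{(0)} = X$, the input features), $W^{(l)}$ is the layer's trainable weight matrix, and $\sigma$ is a nonlinearity such as ReLU. Luckily, libraries like PyTorch Geometric [2] abstract this for us: GCNConv adds the self-loops and applies the symmetric normalization internally.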
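The rule can be checked numerically with a small dense implementation. This is a sketch for intuition only, assuming a dense, symmetric float adjacency matrix without existing self-loops; `normalized_adjacency` and `gcn_layer` are hypothetical helper names, and GCNConv itself works on sparse `edge_index` tensors instead of dense matrices.

```python
import torch


def normalized_adjacency(adj):
    """Return D~^{-1/2} (A + I) D~^{-1/2} for a dense, symmetric adjacency matrix."""
    a_tilde = adj + torch.eye(adj.size(0), dtype=adj.dtype)  # A~ = A + I_N (self-loops)
    deg = a_tilde.sum(dim=1)          # D~_ii = sum_j A~_ij, always >= 1 thanks to self-loops
    d_inv_sqrt = deg.pow(-0.5)
    # Scaling row i by d_i^{-1/2} and column j by d_j^{-1/2} equals D~^{-1/2} A~ D~^{-1/2}
    return d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]


def gcn_layer(adj, h, weight, activation=torch.relu):
    """H^{(l+1)} = sigma(A_hat H^{(l)} W^{(l)}), with sigma = ReLU by default."""
    return activation(normalized_adjacency(adj) @ h @ weight)


# Two nodes joined by one edge: A~ is all ones, both degrees are 2, so A_hat is all 0.5.
adj = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
h = torch.tensor([[1.0], [3.0]])
w = torch.tensor([[1.0]])
print(gcn_layer(adj, h, w))  # each node gets 0.5 * 1 + 0.5 * 3 = 2
```

Because self-loops guarantee every degree is at least 1, the inverse square root is always defined, even for isolated nodes.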

Guidance

Your task is to build an nn.Module that uses the pre-built GCNConv layer from the torch_geometric library. Stack two such layers, with a ReLU activation and dropout in between: the first layer maps the input features to a hidden dimension, and the second maps the hidden representation to one score per class.

Starter Code

# This example uses the torch_geometric library.
# You may need to install it: pip install torch_geometric
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class SimpleGCN(nn.Module):
    def __init__(self, num_node_features, num_classes, hidden_channels=16):
        super().__init__()
        # hidden_channels is a suggested extra argument; 16 is the hidden size used in [1].
        # 1. Initialize two GCNConv layers.
        # The first maps num_node_features to hidden_channels.
        # The second maps hidden_channels to num_classes.
        pass

    def forward(self, data):
        # The 'data' object from torch_geometric contains x and edge_index
        x, edge_index = data.x, data.edge_index

        # 1. Pass data through the first GCNConv layer.
        # 2. Apply a ReLU activation.
        # 3. Apply dropout (pass training=self.training so it is disabled in eval mode).
        # 4. Pass the result through the second GCNConv layer.
        # 5. Return the log_softmax of the final output over the class dimension (dim=1).
        pass

Verification

Use torch_geometric to load the Cora dataset (torch_geometric.datasets.Planetoid with name="Cora"). Train the model on the nodes selected by data.train_mask and evaluate accuracy on the nodes selected by data.test_mask. A two-layer GCN should reach roughly 80% test accuracy on Cora; Kipf & Welling [1] report 81.5% on this standard split.
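A minimal full-batch training and evaluation loop is sketched below. The helper names (`train_epoch`, `accuracy`, `run_cora`) are hypothetical, and the optimizer settings (Adam, learning rate 0.01, weight decay 5e-4, 200 epochs) are the ones commonly used for GCN on Cora following [1]; for simplicity the weight decay is applied to all parameters and the paper's validation-based early stopping is omitted. `run_cora` takes a model constructor with the `(num_node_features, num_classes)` signature of the starter `SimpleGCN`, and imports `Planetoid` lazily so the other two functions work without torch_geometric.

```python
import torch
import torch.nn.functional as F


def train_epoch(model, data, optimizer):
    """One full-batch training step on the nodes selected by data.train_mask."""
    model.train()                      # enables dropout
    optimizer.zero_grad()
    out = model(data)                  # log-probabilities, shape [num_nodes, num_classes]
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def accuracy(model, data, mask):
    """Fraction of nodes in `mask` whose arg-max prediction equals the label."""
    model.eval()                       # disables dropout
    pred = model(data).argmax(dim=1)
    correct = (pred[mask] == data.y[mask]).sum().item()
    return correct / int(mask.sum())


def run_cora(make_model, epochs=200, root="data/Cora"):
    """Hypothetical helper: load Cora, train make_model(...), return test accuracy."""
    from torch_geometric.datasets import Planetoid  # lazy import; requires PyG
    dataset = Planetoid(root=root, name="Cora")
    data = dataset[0]
    model = make_model(dataset.num_node_features, dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    for _ in range(epochs):
        train_epoch(model, data, optimizer)
    return accuracy(model, data, data.test_mask)
```

Once SimpleGCN is filled in, `print(run_cora(SimpleGCN))` downloads Cora into `data/Cora` on first use and prints the test accuracy; results vary somewhat from run to run unless the seed is fixed with `torch.manual_seed`.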

References

[1] Kipf, T. N., & Welling, M. (2016). Semi-Supervised Classification with Graph Convolutional Networks. arXiv:1609.02907 (published at ICLR 2017).

[2] PyTorch Geometric Library Documentation.