
Commit 87b0839

[Doc] Add Graph Transformer Tutorial Documentation (#6889)
1 parent 4323986 commit 87b0839

File tree

4 files changed: +192 -0 lines changed


docs/source/graphtransformer/data.rst

+90
@@ -0,0 +1,90 @@
Prepare Data
============

In this section, we prepare the data for the Graphormer model introduced in the previous section. We can use any dataset containing :class:`~dgl.DGLGraph` objects and a standard PyTorch ``DataLoader`` to feed the data to the model. The key is to define a collate function that groups the features of multiple graphs into batches. An example collate function is shown below:

.. code:: python

    import dgl
    import torch as th

    def collate(graphs):
        # compute shortest path features; this can be done in advance
        for g in graphs:
            spd, path = dgl.shortest_dist(g, root=None, return_paths=True)
            g.ndata["spd"] = spd
            g.ndata["path"] = path

        num_graphs = len(graphs)
        num_nodes = [g.num_nodes() for g in graphs]
        max_num_nodes = max(num_nodes)

        attn_mask = th.zeros(num_graphs, max_num_nodes, max_num_nodes)
        node_feat = []
        in_degree, out_degree = [], []
        path_data = []
        # Since shortest_dist returns -1 for unreachable node pairs and padded
        # nodes are unreachable to others, distances relevant to padded nodes
        # use -1 padding as well.
        dist = -th.ones(
            (num_graphs, max_num_nodes, max_num_nodes), dtype=th.long
        )

        for i in range(num_graphs):
            # A binary mask where invalid positions are indicated by True.
            # Avoid the case where all positions are invalid.
            attn_mask[i, :, num_nodes[i] + 1 :] = 1

            # +1 to distinguish padded non-existing nodes from real nodes
            node_feat.append(graphs[i].ndata["feat"] + 1)

            # 0 for padding
            in_degree.append(
                th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)
            )
            out_degree.append(
                th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)
            )

            # Pad paths so that all of them have the same length "max_len".
            path = graphs[i].ndata["path"]
            path_len = path.size(dim=2)
            # shape of shortest_path: [n, n, max_len]
            max_len = 5
            if path_len >= max_len:
                shortest_path = path[:, :, :max_len]
            else:
                p1d = (0, max_len - path_len)
                # Use the same -1 padding as shortest_dist for
                # invalid edge IDs.
                shortest_path = th.nn.functional.pad(path, p1d, "constant", -1)
            pad_num_nodes = max_num_nodes - num_nodes[i]
            p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)
            shortest_path = th.nn.functional.pad(shortest_path, p3d, "constant", -1)
            # +1 to distinguish padded non-existing edges from real edges
            edata = graphs[i].edata["feat"] + 1

            # shortest_dist pads non-existing edges (at the end of shortest
            # paths) with edge IDs -1, and th.zeros(1, edata.shape[1]) stands
            # for all padded edge features.
            edata = th.cat(
                (edata, th.zeros(1, edata.shape[1]).to(edata.device)), dim=0
            )
            path_data.append(edata[shortest_path])

            dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata["spd"]

        # node feature padding
        node_feat = th.nn.utils.rnn.pad_sequence(node_feat, batch_first=True)

        # degree padding
        in_degree = th.nn.utils.rnn.pad_sequence(in_degree, batch_first=True)
        out_degree = th.nn.utils.rnn.pad_sequence(out_degree, batch_first=True)

        return (
            node_feat,
            in_degree,
            out_degree,
            attn_mask,
            th.stack(path_data),
            dist,
        )

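With this collate function in place, batching can be handled by a regular PyTorch ``DataLoader``. The snippet below is a minimal sketch, where ``dataset`` is assumed to be any sequence of :class:`~dgl.DGLGraph` objects carrying categorical ``"feat"`` features on nodes and edges:

.. code:: python

    from torch.utils.data import DataLoader

    dataloader = DataLoader(
        dataset,             # assumed: a list or dataset of DGLGraph objects
        batch_size=32,
        shuffle=True,
        collate_fn=collate,  # the collate function defined above
    )

    # Each batch yields the padded tensors returned by collate.
    node_feat, in_degree, out_degree, attn_mask, path_data, dist = \
        next(iter(dataloader))
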
In this example, we also omit details like the addition of a virtual node. For more details, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer>`_.
docs/source/graphtransformer/index.rst

+12

@@ -0,0 +1,12 @@
🆕 Tutorial: GraphTransformer
=============================

This tutorial introduces the **graphtransformer** module, which is a set of
utility modules for building and training graph transformer models.

.. toctree::
  :maxdepth: 2
  :titlesonly:

  model
  data
docs/source/graphtransformer/model.rst

+89

@@ -0,0 +1,89 @@
Build Model
===========

**GraphTransformer** is a graph neural network that uses multi-head self-attention (sparse or dense) to encode the graph structure and node features. It is a generalization of the `Transformer <https://arxiv.org/abs/1706.03762>`_ architecture to arbitrary graphs.

In this tutorial, we show how to build a graph transformer model with DGL, using the `Graphormer <https://arxiv.org/abs/2106.05234>`_ model as an example.

Graphormer is a Transformer model designed for graph-structured data that encodes the structural information of a graph into the standard Transformer architecture. Specifically, Graphormer uses degree encoding to measure the importance of nodes, and spatial and path encoding to measure the relations between node pairs. The degree encoding and the node features serve as input to Graphormer, while the spatial and path encodings act as bias terms in the self-attention module.

Degree Encoding
---------------

The degree encoder is a learnable embedding layer that encodes the degree of each node into a vector. It takes as input the batched in-degrees and out-degrees of the graph nodes, and outputs the degree embeddings of the nodes.

.. code:: python

    degree_encoder = dgl.nn.DegreeEncoder(
        max_degree=8,       # the maximum degree to cut off
        embedding_dim=512,  # the dimension of the degree embedding
    )

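As a usage sketch, assuming ``in_degree`` and ``out_degree`` are the batched degree tensors of shape ``(batch_size, max_num_nodes)`` prepared in the data section:

.. code:: python

    # Stack in- and out-degrees so the encoder embeds both directions.
    deg_emb = degree_encoder(th.stack((in_degree, out_degree)))
    # deg_emb: (batch_size, max_num_nodes, 512)
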
Path Encoding
-------------

The path encoder encodes the edge features on the shortest path between two nodes to produce an attention bias for the self-attention module. It takes as input the batched edge features along the shortest paths and outputs the attention bias based on path encoding.

.. code:: python

    path_encoder = dgl.nn.PathEncoder(
        max_len=5,     # the maximum length of the shortest path
        feat_dim=512,  # the dimension of the edge feature
        num_heads=8,   # the number of attention heads
    )

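Assuming ``dist`` and ``path_data`` are the batched shortest-distance and path-feature tensors prepared in the data section, a usage sketch looks like:

.. code:: python

    # dist:      (batch_size, max_num_nodes, max_num_nodes)
    # path_data: (batch_size, max_num_nodes, max_num_nodes, max_len, feat_dim)
    path_encoding = path_encoder(dist, path_data)
    # path_encoding: (batch_size, max_num_nodes, max_num_nodes, num_heads)
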
Spatial Encoding
----------------

The spatial encoder encodes the shortest distance between two nodes to produce an attention bias for the self-attention module. It takes as input the shortest distance between two nodes and outputs the attention bias based on spatial encoding.

.. code:: python

    spatial_encoder = dgl.nn.SpatialEncoder(
        max_dist=5,   # the maximum distance between two nodes
        num_heads=8,  # the number of attention heads
    )

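Assuming the same ``dist`` tensor as above, the spatial encoder produces a bias of the same shape as the path encoding:

.. code:: python

    spatial_encoding = spatial_encoder(dist)
    # spatial_encoding: (batch_size, max_num_nodes, max_num_nodes, num_heads)
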
Graphormer Layer
----------------

The Graphormer layer is like a Transformer encoder layer with the multi-head attention part replaced with :class:`~dgl.nn.BiasedMHA`. It takes in not only the input node features, but also the attention bias computed above, and outputs the updated node features.

We can stack multiple Graphormer layers in a list, just like implementing a Transformer encoder in PyTorch.

.. code:: python

    layers = th.nn.ModuleList([
        dgl.nn.GraphormerLayer(
            feat_size=512,            # the dimension of the input node features
            hidden_size=1024,         # the dimension of the hidden layer
            num_heads=8,              # the number of attention heads
            dropout=0.1,              # the dropout rate
            activation=th.nn.ReLU(),  # the activation function
            norm_first=False,         # whether to put the normalization before attention and feedforward
        )
        for _ in range(6)
    ])

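To illustrate the expected tensor shapes, here is a sketch that applies a single layer to random inputs (the sizes below are illustrative assumptions):

.. code:: python

    x = th.rand(16, 30, 512)       # (batch_size, num_nodes, feat_size)
    bias = th.rand(16, 30, 30, 8)  # (batch_size, num_nodes, num_nodes, num_heads)
    out = layers[0](x, attn_bias=bias)
    # out has the same shape as x: (16, 30, 512)
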
Model Forward
-------------

Grouping the modules above gives us the primary components of the Graphormer model. We can then define the forward process as follows:

.. code:: python

    node_feat, in_degree, out_degree, attn_mask, path_data, dist = \
        next(iter(dataloader))  # we will use the first batch as an example
    num_graphs, max_num_nodes, _ = node_feat.shape
    deg_emb = degree_encoder(th.stack((in_degree, out_degree)))

    # node feature + degree encoding as input
    node_feat = node_feat + deg_emb

    # spatial encoding and path encoding serve as attention bias
    path_encoding = path_encoder(dist, path_data)
    spatial_encoding = spatial_encoder(dist)
    # attn_bias is prepared in the omitted setup; in the full example index 0
    # is reserved for the virtual node, hence the indexing from 1.
    attn_bias[:, 1:, 1:, :] = path_encoding + spatial_encoding

    # graphormer layers
    # x is the layer input (node features with the virtual node prepended in
    # the full example).
    for layer in layers:
        x = layer(
            x,
            attn_mask=attn_mask,
            attn_bias=attn_bias,
        )

For simplicity, we omit some details in the forward process. For the complete implementation, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer>`_.

You can also explore other `utility modules <https://docs.dgl.ai/api/python/nn-pytorch.html#utility-modules-for-graph-transformer>`_ to customize your own graph transformer model. In the next section, we will show how to prepare the data for training.

docs/source/index.rst

+1
@@ -26,6 +26,7 @@ Welcome to Deep Graph Library Tutorials and Documentation
   guide/index
   guide_cn/index
   guide_ko/index
+  graphtransformer/index
   notebooks/sparse/index
   tutorials/cpu/index
   tutorials/multi/index
