GammaGL is a multi-backend graph learning library based on TensorLayerX, which supports TensorFlow, PyTorch, PaddlePaddle, and MindSpore as backends.
A development tutorial in Chinese is available on our wiki.
GammaGL supports multiple deep learning backends, such as TensorFlow, PyTorch, Paddle and MindSpore. Unlike DGL, GammaGL's examples are implemented with the same code on every backend, which allows users to run the same code on different hardware such as NVIDIA GPUs and Huawei Ascend. Users can also call a particular framework's API directly if they prefer that framework.
Following PyTorch Geometric (PyG), GammaGL uses a tensor-centric API. If you are familiar with PyG, GammaGL will feel familiar: you can think of it as a TensorFlow Geometric, Paddle Geometric, or MindSpore Geometric.
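To make the tensor-centric representation concrete: as in PyG, graph connectivity is typically passed around as an `edge_index` in COO format. This is a [2, num_edges] array whose first row holds source nodes and whose second row holds target nodes. The sketch below builds such an array from an edge list using plain Python lists. The helper `to_edge_index` is hypothetical and not part of GammaGL. It assumes an undirected graph is stored with both edge directions.

```python
def to_edge_index(edges, undirected=True):
    """Convert (src, dst) pairs into a COO edge_index: [sources, targets]."""
    src, dst = [], []
    for u, v in edges:
        src.append(u)
        dst.append(v)
        if undirected:
            # store the reverse edge as well, as is common for undirected graphs
            src.append(v)
            dst.append(u)
    return [src, dst]


# A 3-node path graph 0 - 1 - 2
edge_index = to_edge_index([(0, 1), (1, 2)])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```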
2024-07-29 release v0.5
We release the latest version v0.5.
- 70 GNN models
- More fused operators
- Support GPU sampling
- Support GraphStore and FeatureStore
2024-01-24 release v0.4
We release the latest version v0.4.
- 60 GNN models
- More fused operators, which users can now actually use
- Support the latest version of PyTorch and MindSpore
- Support for graph databases such as Neo4j
2023-07-12 release v0.3
We release the latest version v0.3.
- 50 GNN models
- Efficient message passing operators and fused operator
- Rebuild sampling architecture.
2023-04-01 paper accepted
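Our paper GammaGL: A Multi-Backend Library for Graph Neural Networks is accepted at the SIGIR 2023 resource paper track.
2023-02-21 First Prize of the Science and Technology Progress Award of the Chinese Institute of Electronics
The library supported the project "Intelligent Analysis Technology and Large-Scale Application of Large-Scale Complex Heterogeneous Graph Data". This project was led by BUPT with participation from Ant Group, China Mobile, Haizhi Technology and others. It won the 2022 First Prize of the Science and Technology Progress Award of the Chinese Institute of Electronics.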
2023-01-17 release v0.2
We release the latest version v0.2.
- 40 GNN models
- 20 datasets
- Efficient message passing operators and fused operator
- GPU sampling and heterogeneous graph samplers.
2022-06-20 release v0.1
We release the latest version v0.1.
- Framework-agnostic design
- PyG-like
- Graph data structures, message passing module and sampling module
- 20+ GNN models
Currently, GammaGL requires Python version >= 3.9 and is only supported on Linux.
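A quick way to check both requirements before installing is a small pre-flight script like the sketch below. `check_environment` is a hypothetical helper and is not shipped with GammaGL.

```python
import platform
import sys


def check_environment(version_info=None, system=None):
    """Return a list of unmet requirements (an empty list means the checks pass)."""
    version_info = sys.version_info if version_info is None else version_info
    system = platform.system() if system is None else system
    problems = []
    if tuple(version_info[:2]) < (3, 9):
        problems.append(f"Python >= 3.9 is required, found {version_info[0]}.{version_info[1]}")
    if system != "Linux":
        problems.append(f"Linux is required, found {system}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("Requirement not met:", problem)
```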
- Python environment (Optional): We recommend using Conda package manager
$ conda create -n ggl python=3.9
$ source activate ggl
- Install Backend
# For tensorflow
$ pip install tensorflow-gpu # GPU version
$ pip install tensorflow # CPU version
# For torch, version 2.1 + CUDA 11.8
# https://pytorch.org/get-started/locally/
$ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# For paddle, any latest stable version
# https://www.paddlepaddle.org.cn/
$ python -m pip install paddlepaddle-gpu
# For mindspore, GammaGL supports version 2.2.0, GPU-CUDA 11.6
# https://www.mindspore.cn/install
$ pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.2.0/MindSpore/unified/x86_64/mindspore-2.2.0-cp39-cp39-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
For other backends or specific versions, please check whether TensorLayerX (TLX) supports them.
Install TensorLayerX
pip install git+https://github.com/dddg617/tensorlayerx.git@nightly
Note:
- PyTorch is necessary when installing TensorLayerX.
- This TensorLayerX build is maintained by the BUPT GAMMA Lab team.
- Download GammaGL
You may download the nightly version through the following commands:
$ git clone --recursive https://github.com/BUPT-GAMMA/GammaGL.git
$ cd GammaGL
$ pip install pybind11 pyparsing
$ python setup.py install
Users in mainland China who encounter network problems are recommended to install from the OpenI community.
Try cloning from OpenI:
git clone --recursive https://git.openi.org.cn/GAMMALab/GammaGL.git
Note:
"--recursive" is necessary. If you forgot it, run the command below in the GammaGL root directory:
git submodule update --init
You may also install the stable version; refer to our documentation.
In this quick tour, we highlight the ease of creating and training a GNN model with only a few lines of code.
In the first glimpse of GammaGL, we implement the training of a GNN for classifying papers in a citation graph.
For this, we load the Cora dataset and create a simple 2-layer GCN model using the pre-defined GCNConv:
import tensorlayerx as tlx
from gammagl.layers.conv import GCNConv
from gammagl.datasets import Planetoid
dataset = Planetoid(root='.', name='Cora')
class GCN(tlx.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.relu = tlx.nn.ReLU()
    def forward(self, x, edge_index):
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.conv2(x, edge_index)
        return x
model = GCN(dataset.num_features, 16, dataset.num_classes)
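Here is some intuition for what each GCNConv layer computes. In Kipf & Welling's GCN, the layer adds self-loops and then sums neighbor features weighted by the symmetric normalization 1/sqrt(deg(i)·deg(j)). After that it applies a learnable weight. The pure-Python sketch below performs only that propagation step, with no weight or bias, on plain lists. It illustrates the idea under these assumptions and is not GammaGL's GCNConv implementation. It also assumes an undirected graph stored with both edge directions, so in-degree equals out-degree.

```python
import math


def gcn_propagate(x, edge_index, num_nodes):
    """out_i = sum over j in N(i) plus i itself of x_j / sqrt(deg_i * deg_j)."""
    src, dst = list(edge_index[0]), list(edge_index[1])
    # add a self-loop to every node
    src += range(num_nodes)
    dst += range(num_nodes)
    # degree of each node, counting its self-loop
    deg = [0] * num_nodes
    for d in dst:
        deg[d] += 1
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for s, d in zip(src, dst):
        norm = 1.0 / math.sqrt(deg[s] * deg[d])
        for k in range(dim):
            out[d][k] += norm * x[s][k]
    return out


# Path graph 0 - 1 - 2 with one scalar feature per node
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
x = [[1.0], [2.0], [3.0]]
print(gcn_propagate(x, edge_index, num_nodes=3))  # ≈ [[1.3165], [2.2997], [2.3165]]
```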
We can now optimize the model in a training loop, similar to the standard TensorLayerX training procedure.
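import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep, WithLoss
data = dataset[0]
loss_fn = tlx.losses.softmax_cross_entropy_with_logits
optimizer = tlx.optimizers.Adam(lr=0.01)
# WithLoss feeds a single input to the backbone, so override forward
# to pass both the node features and edge_index to the GCN
class GraphWithLoss(WithLoss):
    def forward(self, data, y):
        logits = self.backbone_network(data['x'], data['edge_index'])
        return self._loss_fn(logits, y)  # for brevity, computed over all nodes
net_with_loss = GraphWithLoss(model, loss_fn)
train_one_step = TrainOneStep(net_with_loss, optimizer, model.trainable_weights)
for epoch in range(200):
    loss = train_one_step({'x': data.x, 'edge_index': data.edge_index}, data.y)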
We can now optimize the model in a training loop, similar to the standard PyTorch training procedure.
import torch
import torch.nn.functional as F
data = dataset[0]
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(200):
    pred = model(data.x, data.edge_index)
    # compute the loss on training nodes only
    loss = F.cross_entropy(pred[data.train_mask], data.y[data.train_mask])
    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
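The PyTorch loop above computes the cross-entropy only on the nodes selected by `data.train_mask`, which is the usual semi-supervised setup on Cora. To show the arithmetic, here is a dependency-free sketch of a masked softmax cross-entropy over plain lists. `masked_cross_entropy` is a hypothetical helper, not a GammaGL or TensorLayerX function.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over the nodes whose mask entry is True."""
    total, count = 0.0, 0
    for row, y, m in zip(logits, labels, mask):
        if not m:
            continue
        # log-sum-exp with the max subtracted for numerical stability
        mx = max(row)
        log_z = mx + math.log(sum(math.exp(v - mx) for v in row))
        total += log_z - row[y]  # equals -log softmax(row)[y]
        count += 1
    if count == 0:
        raise ValueError("mask selects no nodes")
    return total / count


logits = [[0.0, 0.0], [5.0, -5.0], [1.0, 2.0]]
labels = [0, 1, 1]
print(masked_cross_entropy(logits, labels, [True, False, False]))  # log(2) ≈ 0.6931
```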
We can now optimize the model in a training loop, similar to the standard TensorFlow training procedure.
import tensorflow as tf
data = dataset[0]
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
for epoch in range(200):
    with tf.GradientTape() as tape:
        logits = model(data.x, data.edge_index)
        # compute the loss on training nodes only
        loss = loss_fn(tf.boolean_mask(data.y, data.train_mask),
                       tf.boolean_mask(logits, data.train_mask))
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
We can now optimize the model in a training loop, similar to the standard PaddlePaddle training procedure.
import paddle
data = dataset[0]
optim = paddle.optimizer.Adam(parameters=model.parameters())
loss_fn = paddle.nn.CrossEntropyLoss()
model.train()
for epoch in range(200):
    predicts = model(data.x, data.edge_index)
    # compute the loss on training nodes only
    loss = loss_fn(predicts[data.train_mask], data.y[data.train_mask])
    # Backpropagation
    loss.backward()
    optim.step()
    optim.clear_grad()
We can now optimize the model in a training loop, similar to the standard MindSpore training procedure.
import mindspore.nn as nn
data = dataset[0]
# 1. Connect the network with the loss function. nn.WithLossCell passes a single
#    input to the backbone, so a small wrapper forwards both x and edge_index.
class GCNWithLoss(nn.Cell):
    def __init__(self, net, loss_fn):
        super().__init__()
        self.net = net
        self.loss_fn = loss_fn
    def construct(self, x, edge_index, y):
        logits = self.net(x, edge_index)
        return self.loss_fn(logits, y)
# 2. Define the loss function and the optimizer
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_loss = GCNWithLoss(model, loss_fn)
opt = nn.Adam(model.trainable_params(), learning_rate=0.01)
# 3. Define the training network and set it to training mode
train_net = nn.TrainOneStepCell(net_with_loss, opt)
train_net.set_train()
# 4. Training procedure (for brevity the loss covers all labeled nodes)
for epoch in range(200):
    loss = train_net(data.x, data.edge_index, data.y)
    print(f"Epoch: [{epoch + 1} / 200], loss: {loss}")
More information about evaluating final model performance can be found in the corresponding example.
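Evaluation usually reports the accuracy of the arg-max prediction on held-out nodes. For Planetoid datasets, for example, these are the nodes selected by a test mask. Here is a dependency-free sketch. `masked_accuracy` is a hypothetical helper rather than part of GammaGL.

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose arg-max prediction equals the label."""
    correct = total = 0
    for row, y, m in zip(logits, labels, mask):
        if m:
            pred = max(range(len(row)), key=row.__getitem__)  # arg-max class
            correct += int(pred == y)
            total += 1
    return correct / total if total else 0.0


logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
print(masked_accuracy(logits, [1, 1, 1], [True, True, False]))  # 0.5
```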
In addition to the easy application of existing GNNs, GammaGL makes it simple to implement custom Graph Neural Networks (see here for the accompanying tutorial). For example, this is all it takes to implement the edge convolutional layer from Wang et al.:
import tensorlayerx as tlx
from tensorlayerx.nn import Sequential as Seq, Linear, ReLU
from gammagl.layers import MessagePassing
class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp = Seq(Linear(in_features=2 * in_channels, out_features=out_channels),
                       ReLU(),
                       Linear(in_features=out_channels, out_features=out_channels))
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        return self.propagate(x, edge_index, aggr='max')
    def message(self, x, edge_index):
        # gather the features of the target node i and the source node j of every edge
        x_i = tlx.gather(x, edge_index[1, :])  # x_i has shape [E, in_channels]
        x_j = tlx.gather(x, edge_index[0, :])  # x_j has shape [E, in_channels]
        tmp = tlx.concat([x_i, x_j - x_i], axis=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)
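The edge convolution above builds one message per edge j → i from the concatenation [x_i, x_j - x_i]. It then keeps the element-wise maximum of all incoming messages at node i. The pure-Python sketch below reproduces that computation on lists, with any callable standing in for the MLP. It is an illustration, not GammaGL's implementation. Two conventions are assumptions of the sketch: edges point from `edge_index[0]` to `edge_index[1]`, and nodes with no incoming edge receive zeros.

```python
def edge_conv(x, edge_index, mlp, out_dim):
    """out_i = element-wise max over edges j -> i of mlp([x_i, x_j - x_i])."""
    out = [None] * len(x)
    for j, i in zip(edge_index[0], edge_index[1]):
        x_i, x_j = x[i], x[j]
        # concatenate [x_i, x_j - x_i] (list "+" is concatenation), then apply the MLP
        msg = mlp(x_i + [b - a for a, b in zip(x_i, x_j)])
        # element-wise max aggregation at the target node i
        out[i] = msg if out[i] is None else [max(a, b) for a, b in zip(out[i], msg)]
    # nodes with no incoming edge receive zeros (a convention assumed for this sketch)
    return [o if o is not None else [0.0] * out_dim for o in out]


# Toy "MLP" that adds the two halves, so each message reduces to x_j
toy_mlp = lambda h: [h[0] + h[1]]
x = [[1.0], [4.0], [2.0]]
edge_index = [[1, 2, 0], [0, 0, 1]]  # edges 1->0, 2->0, 0->1
print(edge_conv(x, edge_index, toy_mlp, out_dim=1))  # [[4.0], [1.0], [0.0]]
```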
Take GCN as an example:
# cd ./examples/gcn
# set parameters if necessary
python gcn_trainer.py --dataset cora --lr 0.01
If you want to use a specific backend or GPU, just set the environment variables like:
CUDA_VISIBLE_DEVICES="1" TL_BACKEND="paddle" python gcn_trainer.py
Note
The DEFAULT backend is torch and the default GPU is 0. By default, the TensorFlow backend takes up all remaining GPU memory.
The CANDIDATE backends are tensorflow, paddle, torch and mindspore. Set CUDA_VISIBLE_DEVICES=" " if you want to run on the CPU.
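If you launch examples from a Python script instead of a shell, the same two environment variables can be set on the child process. The sketch below uses a hypothetical helper, `backend_env`, that is not part of GammaGL. It validates the backend name against the candidate list above and returns an environment dict for `subprocess`. It assumes, as the note says, that a blank CUDA_VISIBLE_DEVICES means CPU only.

```python
import os

CANDIDATE_BACKENDS = ("tensorflow", "paddle", "torch", "mindspore")


def backend_env(backend="torch", gpu="0", base=None):
    """Build an environment dict selecting a backend and GPU (gpu=None means CPU only)."""
    if backend not in CANDIDATE_BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; expected one of {CANDIDATE_BACKENDS}")
    env = dict(os.environ if base is None else base)
    env["TL_BACKEND"] = backend
    env["CUDA_VISIBLE_DEVICES"] = " " if gpu is None else str(gpu)
    return env


# Example: run the GCN example with the Paddle backend on GPU 1
# import subprocess
# subprocess.run(["python", "gcn_trainer.py", "--dataset", "cora"],
#                env=backend_env("paddle", 1), check=True)
```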
GammaGL currently supports about 70 models, and everyone is welcome to use them or contribute new ones.
Model | TensorFlow | PyTorch | Paddle | MindSpore |
---|---|---|---|---|
GCN [ICLR 2017] | ✔️ | ✔️ | ✔️ | ✔️ |
GAT [ICLR 2018] | ✔️ | ✔️ | ✔️ | ✔️ |
GraphSAGE [NeurIPS 2017] | ✔️ | ✔️ | ✔️ | ✔️ |
ChebNet [NeurIPS 2016] | ✔️ | ✔️ | ✔️ | ✔️ |
GCNII [ICML 2020] | ✔️ | ✔️ | ✔️ | ✔️ |
You may see the other models here.
Contrastive Learning | TensorFlow | PyTorch | Paddle | MindSpore |
---|---|---|---|---|
DGI [ICLR 2019] | ✔️ | ✔️ | ✔️ | ✔️ |
GRACE [ICML 2020 Workshop] | ✔️ | ✔️ | ✔️ | ✔️ |
GRADE [NeurIPS 2022] | ✔️ | ✔️ | ✔️ | ✔️ |
MVGRL [ICML 2020] | ✔️ | ✔️ | ✔️ | ✔️ |
InfoGraph [ICLR 2020] | ✔️ | ✔️ | ✔️ | ✔️ |
MERIT [IJCAI 2021] | ✔️ | ✔️ | ✔️ | |
GNN-POT [NeurIPS 2023] | ✔️ | | | |
MAGCL [AAAI 2023] | ✔️ | ✔️ | ✔️ | ✔️ |
Sp2GCL [NeurIPS 2023] | ✔️ | | | |
Heterogeneous Graph Learning | TensorFlow | PyTorch | Paddle | MindSpore |
---|---|---|---|---|
RGCN [ESWC 2018] | ✔️ | ✔️ | ✔️ | ✔️ |
HAN [WWW 2019] | ✔️ | ✔️ | ✔️ | ✔️ |
HGT [WWW 2020] | ✔️ | ✔️ | ✔️ | ✔️ |
SimpleHGN [KDD 2021] | ✔️ | ✔️ | | |
CompGCN [ICLR 2020] | ✔️ | ✔️ | ✔️ | |
HPN [TKDE 2021] | ✔️ | ✔️ | ✔️ | ✔️ |
ieHGCN [TKDE 2021] | ✔️ | ✔️ | ✔️ | ✔️ |
MetaPath2Vec [KDD 2017] | ✔️ | ✔️ | ✔️ | ✔️ |
HERec [TKDE 2018] | ✔️ | ✔️ | ✔️ | ✔️ |
HeCo [KDD 2021] | ✔️ | ✔️ | ✔️ | |
DHN [TKDE 2023] | ✔️ | | | |
HEAT [T-ITS 2023] | ✔️ | | | |
Note
The models can be run on the MindSpore backend. However, the experimental results are not yet satisfactory because of an issue in the training component, which will be fixed in the future.
GammaGL Team [GAMMA LAB] and Peng Cheng Laboratory.
See more in CONTRIBUTING.
Contributions are always welcome. Please feel free to open an issue or send an email to [email protected].
If you use GammaGL in a scientific publication, we would appreciate citations to the following paper:
@inproceedings{10.1145/3539618.3591891,
author = {Liu, Yaoqi and Yang, Cheng and Zhao, Tianyu and Han, Hui and Zhang, Siyuan and Wu, Jing and Zhou, Guangyu and Huang, Hai and Wang, Hui and Shi, Chuan},
title = {GammaGL: A Multi-Backend Library for Graph Neural Networks},
year = {2023},
isbn = {9781450394086},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
url = {https://doi.org/10.1145/3539618.3591891},
doi = {10.1145/3539618.3591891},
booktitle = {Proceedings of the 46th International ACM SIGIR Conference on Research and Development in Information Retrieval},
pages = {2861--2870},
numpages = {10},
keywords = {graph neural networks, frameworks, deep learning},
location = {Taipei, Taiwan},
series = {SIGIR '23}
}