Notice: This is research code that will not necessarily be maintained in the future. The code is under development, so make sure you are using the most recent version. We welcome bug reports and PRs but make no guarantees about fixes or responses.
`gt_pyg` is an implementation of the Graph Transformer architecture in PyTorch Geometric.
This sketch provides an overview of the Graph Transformer architecture (Dwivedi & Bresson, 2021). In a nutshell, the model implements a dot-product self-attention network with a `softmask` function, where `softmask` is a `softmax` applied only over the non-zero elements of the `(A + I)` matrix.
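As a minimal, self-contained sketch (not gt-pyg's actual code), the "softmax over the non-zero entries of `(A + I)`" can be expressed with PyG's edge-wise softmax utility; the graph and the attention logits below are made up for illustration:

```python
import torch
from torch_geometric.utils import add_self_loops, softmax

# A tiny 3-node graph; edge_index lists the directed edges of A.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
# Adding self-loops makes the attention pattern match the non-zero
# entries of (A + I).
edge_index, _ = add_self_loops(edge_index, num_nodes=3)

# One hypothetical attention logit per edge.
scores = torch.randn(edge_index.size(1))

# PyG's softmax normalizes per destination node, i.e. only over each
# node's incoming edges -- a "softmask" restricted to (A + I).
alpha = softmax(scores, index=edge_index[1], num_nodes=3)
```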
This sketch is an overview of the gating mechanism used in the GT model (Chen et al., 2023, Bioinformatics).
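The exact gating used in GTConv is defined in the source; as a rough, hypothetical sketch of the general idea (a sigmoid gate computed from the input features, modulating the attention output elementwise, in the spirit of Chen et al., 2023 and the AlphaFold2 attention of Jumper et al., 2021):

```python
import torch
import torch.nn as nn

class GatedAttentionOutput(nn.Module):
    """Illustrative sketch of sigmoid gating on an attention output;
    not the gt-pyg implementation."""

    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, attn_out: torch.Tensor) -> torch.Tensor:
        # The gate is computed from the input features, squashed to (0, 1),
        # and modulates each channel of the attention output elementwise.
        return torch.sigmoid(self.gate(x)) * attn_out
```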
Clone and install the software:
git clone https://github.com/pgniewko/gt-pyg.git
cd gt-pyg
pip install .
Create and activate the conda environment:
conda env create -f environment.yml
conda activate gt
The following code snippet demonstrates how to test the installation of gt-pyg and the usage of the GTConv layer.
import torch
from torch_geometric.data import Data
from gt_pyg.nn.gt_conv import GTConv
num_nodes = 10
num_node_features = 3
num_edges = 20
num_edge_features = 2
# Generate random node features
x = torch.randn(num_nodes, num_node_features)
# Generate random edge indices
edge_index = torch.randint(high=num_nodes, size=(2, num_edges))
# Generate random edge attributes (optional)
edge_attr = torch.randn(num_edges, num_edge_features)
gt = GTConv(node_in_dim=num_node_features,
            edge_in_dim=num_edge_features,
            hidden_dim=15,
            num_heads=3)
gt(x=x, edge_index=edge_index, edge_attr=edge_attr)
A complete example, demonstrating how to use the GTConv layer in a model and train a regression model on an ADME task from the Therapeutics Data Commons, can be found in this notebook.
Note: The positional-encoding featurization based on the Gaussian Network Model is implemented here.
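The linked code is authoritative; conceptually, GNM-based positional encodings can be sketched as the low-frequency eigenvectors of the Kirchhoff matrix (the graph Laplacian), much like Laplacian positional encodings. The helper below is a hypothetical illustration, not the gt-pyg implementation:

```python
import torch
from torch_geometric.utils import to_dense_adj

def gnm_positional_encodings(edge_index, num_nodes, pe_dim=6):
    # Kirchhoff (connectivity) matrix of the Gaussian Network Model:
    # the graph Laplacian L = D - A.
    A = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0)
    L = torch.diag(A.sum(dim=1)) - A
    # Use the lowest non-trivial eigenvectors as per-node positional
    # encodings (assumes num_nodes > pe_dim).
    eigvals, eigvecs = torch.linalg.eigh(L)
    return eigvecs[:, 1:pe_dim + 1]
```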
The code works with custom datasets. Let's assume we have a file called `solubility.csv` with two columns: `SMILES` and `logS`. We can prepare a training `DataLoader` object with the following code:
from torch_geometric.loader import DataLoader
# get_data_from_csv and get_tensor_data are gt-pyg helper utilities;
# see the repository for their exact import paths.

fn = 'solubility.csv'
x_label = 'SMILES'
y_label = 'logS'

dataset = get_data_from_csv(fn, x_label=x_label, y_label=y_label)
tr_dataset = get_tensor_data(dataset[x_label], dataset[y_label].to_list(), pe_dim=6)
train_loader = DataLoader(tr_dataset, batch_size=256)
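Assuming `get_tensor_data` yields standard PyG `Data` objects, the resulting loader behaves like any other PyG `DataLoader`; each batch is a collated `Batch` object:

```python
# Attribute names below assume the standard PyG Data layout (x, edge_index, y).
for batch in train_loader:
    print(batch.num_graphs, batch.x.shape, batch.y.shape)
    break
```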
The main idea originates from the WeightWatcher lecture, which suggests using per-layer alpha exponents to control the training process. The premise is that if the per-layer alphas fall within the range `[2, 6]`, the network is well-trained, indicating that the deep neural network (DNN) has successfully captured long-range, scale-invariant correlations between the network parameters, the input data, and the output labels.
The idea developed here is rather simple: whenever a layer's `alpha` is within the desired range, we decrease the learning rate to effectively "capture" the weights in this spectral range. Conversely, when `alpha` falls outside of this range, we increase the learning rate so that the layer's weights are modified more rapidly.
Note: Recently, this idea has been extended in the form of the TempBalance algorithm.
We have empirically demonstrated that this simple procedure indeed leads to capturing the spectral exponents of the layers within the desired range. It is hypothesized that networks regularized in this way exhibit better generalization capabilities.
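A loose sketch of this scheme (not the exact code used here): weightwatcher's `analyze()` reports a per-layer `alpha`, which can drive per-group learning-rate updates. Pairing analyzed layers one-to-one with optimizer parameter groups is an assumption made for illustration, and the shrink/grow factors are hypothetical:

```python
import weightwatcher as ww

ALPHA_LO, ALPHA_HI = 2.0, 6.0  # the "well-trained" range for the power-law exponent

def adjust_learning_rates(model, optimizer, shrink=0.5, grow=2.0):
    # analyze() returns a pandas DataFrame with one row per analyzed layer,
    # including its fitted power-law exponent 'alpha'.
    details = ww.WeightWatcher(model=model).analyze()
    # Hypothetical pairing: one optimizer param group per analyzed layer.
    for group, alpha in zip(optimizer.param_groups, details["alpha"]):
        if ALPHA_LO <= alpha <= ALPHA_HI:
            group["lr"] *= shrink  # alpha in range: slow down to "capture" the weights
        else:
            group["lr"] *= grow    # alpha out of range: speed up the updates
```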
The repository also contains a notebook and an example of how to run the (in)famous MoleculeACE benchmark. You can find it here. Enjoy!
In this project, I trained the model on the data from the Discovery of a structural class of antibiotics with explainable deep learning paper. The primary aim was to replicate the calculations presented in the paper, and the initial results are quite promising.
- The code aims to replicate the original GTConv layer as faithfully as possible.
- There is no need for clipping in the softmax function, since the softmax procedure in PyG employs the Log-Sum-Exp trick, effectively mitigating any risk of overflow (see the short illustration after this list).
- Additional details on implementing message-passing layers in `pytorch-geometric` can be found on the PyG website.
- In the original paper, only `sum` is used for message aggregation. Drawing inspiration from the PNA model, the user can instead utilize a set of aggregators.
- The current implementation adds a gating mechanism (after Chen et al., 2023, Bioinformatics) and sets the biases in the attention mechanism to `False` (after Jumper et al., 2021, Nature). To reproduce the original GT paper, set `qkv_bias=True` and `gate=False`.
- Some implementation techniques are borrowed from the TransformerConv module in the PyTorch Geometric codebase.
- To convert SMILES into tensors, one option is to utilize the `from_smiles` method. However, this featurization approach lacks flexibility: it necessitates creating multiple embeddings, which are then summed, instead of employing a single Linear layer.
- To maintain simplicity, we forgo creating a separate Dataset object since we are working with small datasets. Instead, we pass a list of Data objects to the DataLoader, as explained in the documentation.
- The compound-cleaning procedure drew inspiration from Pat Walters' blog post.
- The test loss is calculated using TorchMetrics for convenience.
- Note: The order in which batch/layer normalization is executed is a topic of debate in the deep learning literature. The current implementation uses a Post-Normalization Transformer layer, as opposed to Pre-Normalization. Although pre-normalization is claimed to work better, this assertion has not been verified in this study. For additional references, see the Graphormer paper or the more comprehensive study by Xiong et al.
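For reference, a minimal illustration of the Log-Sum-Exp trick mentioned in the softmax note above (plain PyTorch, not the PyG internals):

```python
import torch

def stable_softmax(scores: torch.Tensor) -> torch.Tensor:
    # Subtracting the per-row maximum before exponentiating means exp()
    # never sees large positive arguments, so the normalization cannot overflow.
    shifted = scores - scores.max(dim=-1, keepdim=True).values
    exp = shifted.exp()
    return exp / exp.sum(dim=-1, keepdim=True)
```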
- A Generalization of Transformer Networks to Graphs
- A gated graph transformer for protein complex structure quality assessment and its performance in CASP15
- What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision?
- Therapeutics Data Commons
- WeightWatcher
- Gaussian Network Model
Copyright (C) 2023-, Pawel Gniewek
Email: [email protected]
License: MIT
All rights reserved.