Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: optimize the graph network with CUDA Graph #60

Closed
wants to merge 12 commits into from
124 changes: 96 additions & 28 deletions benchmarks/graph_network.ipynb

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions tests/test_neigbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest
from pytest import mark
from sklearn import neighbors
import torch as pt

from torchmdnet.models.utils import DistanceBruteForce, Distance

@mark.parametrize('num_atoms', [5, 7, 11, 13, 17])
@mark.parametrize('device', ['cpu', 'cuda'])
def test_neighbors(num_atoms, device):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

device = pt.device(device)

# Generate random inputs
pos = (10 * pt.rand(num_atoms, 3, dtype=pt.float32, device=device) - 5)

simple = Distance(0.0, 100.0)
brute_force = DistanceBruteForce()

_, simple_distances, _ = simple(pos, None)
_, brute_force_distances, _ = brute_force(pos, None)

simple_distances = simple_distances.sort().values
brute_force_distances = brute_force_distances.sort().values

assert pt.allclose(simple_distances, brute_force_distances)
10 changes: 8 additions & 2 deletions torchmdnet/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchmdnet.models import output_modules
from torchmdnet.models.wrappers import AtomFilter
from torchmdnet import priors
import warnings


def create_model(args, prior_model=None, mean=None, std=None):
Expand All @@ -25,6 +26,9 @@ def create_model(args, prior_model=None, mean=None, std=None):
max_num_neighbors=args["max_num_neighbors"],
)

if "neighbors" in args:
shared_args["neighbors"] = args["neighbors"]

# representation network
if args["model"] == "graph-network":
from torchmdnet.models.torchmd_gn import TorchMD_GN
Expand Down Expand Up @@ -100,7 +104,8 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
args = ckpt["hyper_parameters"]

for key, value in kwargs.items():
assert key in args, "Unknown hyperparameter '{key}'."
if not key in args:
warnings.warn(f'Unknown hyperparameter: {key}={value}')
args[key] = value

model = create_model(args)
Expand Down Expand Up @@ -173,7 +178,8 @@ def forward(self, z, pos, batch: Optional[torch.Tensor] = None):
x = self.prior_model(x, z, pos, batch)

# aggregate atoms
out = scatter(x, batch, dim=0, reduce=self.reduce_op)
out = x.sum(0, keepdim=True) if self.reduce_op == "simple_add" else \
scatter(x, batch, dim=0, reduce=self.reduce_op)

# shift by data mean
if self.mean is not None:
Expand Down
13 changes: 10 additions & 3 deletions torchmdnet/models/torchmd_gn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
NeighborEmbedding,
CosineCutoff,
Distance,
DistanceBruteForce,
rbf_class_mapping,
act_class_mapping,
)
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
max_z=100,
max_num_neighbors=32,
aggr="add",
neighbors="simple"
):
super(TorchMD_GN, self).__init__()

Expand Down Expand Up @@ -99,14 +101,19 @@ def __init__(
self.cutoff_upper = cutoff_upper
self.max_z = max_z
self.aggr = aggr
self.neighbors = neighbors

act_class = act_class_mapping[activation]

self.embedding = nn.Embedding(self.max_z, hidden_channels)

self.distance = Distance(
cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors
)
if self.neighbors == "simple":
self.distance = Distance(cutoff_lower, cutoff_upper, max_num_neighbors=max_num_neighbors)
elif self.neighbors == "brute_force":
self.distance = DistanceBruteForce()
else:
raise ValueError('neighbours')

self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
)
Expand Down
30 changes: 25 additions & 5 deletions torchmdnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def reset_parameters(self):

def forward(self, z, x, edge_index, edge_weight, edge_attr):
# remove self loops
mask = edge_index[0] != edge_index[1]
if not mask.all():
edge_index = edge_index[:, mask]
edge_weight = edge_weight[mask]
edge_attr = edge_attr[mask]
# mask = edge_index[0] != edge_index[1]
# if not mask.all():
# edge_index = edge_index[:, mask]
# edge_weight = edge_weight[mask]
# edge_attr = edge_attr[mask]

C = self.cutoff(edge_weight)
W = self.distance_proj(edge_attr) * C.view(-1, 1)
Expand Down Expand Up @@ -238,6 +238,26 @@ def forward(self, pos, batch):
return edge_index, edge_weight, None


class DistanceBruteForce(nn.Module):
def __init__(self):
super().__init__()

def forward(self, pos, batch):

num_nodes = len(pos)
indices = torch.arange(0, num_nodes * (num_nodes - 1), device=pos.device)

row = torch.div(indices, num_nodes - 1, rounding_mode='floor')
column = torch.div(indices, num_nodes, rounding_mode='floor')
column = torch.remainder(indices + column + 1, num_nodes)

edge_index = torch.vstack((row, column))
edge_vec = torch.index_select(pos, 0, row) - torch.index_select(pos, 0, column)
edge_weight = torch.norm(edge_vec, dim=-1)

return edge_index, edge_weight, None


class GatedEquivariantBlock(nn.Module):
"""Gated Equivariant Block as defined in Schütt et al. (2021):
Equivariant message passing for the prediction of tensorial properties and molecular spectra
Expand Down