-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathtensornet.py
410 lines (377 loc) · 16 KB
/
tensornet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
import torch
import numpy as np
from typing import Optional, Tuple
from torch import Tensor, nn
from torch_scatter import scatter
from torchmdnet.models.utils import (
CosineCutoff,
OptimizedDistance,
rbf_class_mapping,
act_class_mapping,
)
# Process-wide PyTorch matmul settings, applied as a side effect of importing
# this module: request the 'high' float32 matmul precision mode and allow TF32
# in CUDA matmuls.
# NOTE(review): per the PyTorch docs these settings trade some float32 matmul
# accuracy for speed and affect every model in the process -- confirm this is
# intended at import time.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True
def vector_to_skewtensor(vector):
    """Create the skew-symmetric (cross-product) 3x3 matrix of each vector.

    Args:
        vector: Tensor indexed as ``vector[:, k]`` for ``k`` in 0..2, i.e. a
            batch of 3-component vectors laid out as ``(batch, 3)``.

    Returns:
        Tensor of shape ``(batch, 3, 3)`` where entry ``b`` is
        ``[[0, -z, y], [z, 0, -x], [-y, x, 0]]`` for ``(x, y, z) = vector[b]``.
        Because of the trailing ``squeeze(0)``, a batch of exactly one vector
        comes back as ``(3, 3)`` (the leading axis is dropped).
    """
    x, y, z = vector[:, 0], vector[:, 1], vector[:, 2]
    o = torch.zeros(vector.size(0), device=vector.device, dtype=vector.dtype)
    # Row-major listing of the nine matrix entries, one column per entry.
    flat = torch.stack((o, -z, y, z, o, -x, -y, x, o), dim=1)
    return flat.view(-1, 3, 3).squeeze(0)
def vector_to_symtensor(vector):
    """Create a symmetric traceless tensor from the outer product v v^T.

    Args:
        vector: Tensor whose last axis holds the vector components; the
            identity used for trace removal is 3x3, so the last axis is
            expected to have length 3.

    Returns:
        Tensor with two trailing 3x3 axes: the symmetrized outer product of
        each vector with itself, minus (mean of its diagonal) * identity.
    """
    outer = torch.matmul(vector.unsqueeze(-1), vector.unsqueeze(-2))
    identity = torch.eye(3, 3, device=outer.device, dtype=outer.dtype)
    # Isotropic part: mean of the diagonal, broadcast onto the identity.
    mean_diag = outer.diagonal(offset=0, dim1=-1, dim2=-2).mean(-1)
    isotropic = mean_diag[..., None, None] * identity
    symmetric = 0.5 * (outer + outer.transpose(-2, -1))
    return symmetric - isotropic
def decompose_tensor(tensor):
    """Split 3x3 tensors into isotropic, antisymmetric and symmetric traceless parts.

    Args:
        tensor: Tensor whose last two axes are 3x3 matrices.

    Returns:
        Tuple ``(I, A, S)`` with the same shape as ``tensor``:
        ``I`` is (mean of the diagonal) * identity, ``A`` is
        ``0.5 * (T - T^T)`` and ``S`` is ``0.5 * (T + T^T) - I``, so that
        ``I + A + S`` reconstructs ``tensor``.
    """
    transposed = tensor.transpose(-2, -1)
    identity = torch.eye(3, 3, device=tensor.device, dtype=tensor.dtype)
    mean_diag = tensor.diagonal(offset=0, dim1=-1, dim2=-2).mean(-1)
    isotropic = mean_diag[..., None, None] * identity
    antisymmetric = 0.5 * (tensor - transposed)
    sym_traceless = 0.5 * (tensor + transposed) - isotropic
    return isotropic, antisymmetric, sym_traceless
def new_radial_tensor(I, A, S, f_I, f_A, f_S):
    """Scale each irreducible tensor component by its own invariant factor.

    Args:
        I, A, S: Tensor components whose last two axes are the matrix axes.
        f_I, f_A, f_S: Scalar (invariant) features; two trailing singleton
            axes are appended so they broadcast over the matrix axes.

    Returns:
        Tuple ``(f_I * I, f_A * A, f_S * S)`` after broadcasting.
    """
    return (
        f_I[..., None, None] * I,
        f_A[..., None, None] * A,
        f_S[..., None, None] * S,
    )
def tensor_norm(tensor):
    """Return the squared Frobenius norm over the last two axes.

    The entries are squared and summed over axes ``-2`` and ``-1``; no square
    root is taken, so the result is the square of the Frobenius norm.
    """
    return torch.sum(tensor.pow(2), (-2, -1))
class TensorNet(nn.Module):
    r"""TensorNet's architecture, from TensorNet: Cartesian Tensor Representations
    for Efficient Learning of Molecular Potentials; G. Simeon and G. de Fabritiis.

    Args:
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`128`)
        num_layers (int, optional): The number of interaction layers.
            (default: :obj:`2`)
        num_rbf (int, optional): The number of radial basis functions :math:`\mu`.
            (default: :obj:`32`)
        rbf_type (string, optional): The type of radial basis function to use.
            Must be a key of ``rbf_class_mapping``.
            (default: :obj:`"expnorm"`)
        trainable_rbf (bool, optional): Whether to train RBF parameters with
            backpropagation. (default: :obj:`False`)
        activation (string, optional): The type of activation function to use.
            Must be a key of ``act_class_mapping``.
            (default: :obj:`"silu"`)
        cutoff_lower (float, optional): Lower cutoff distance for interatomic interactions.
            (default: :obj:`0.0`)
        cutoff_upper (float, optional): Upper cutoff distance for interatomic interactions.
            (default: :obj:`4.5`)
        max_num_neighbors (int, optional): Maximum number of neighbors to return for a
            given node/atom when constructing the molecular graph during forward passes.
            (default: :obj:`64`)
        max_z (int, optional): Maximum atomic number. Used for initializing embeddings.
            (default: :obj:`128`)
        equivariance_invariance_group (string, optional): Group under whose action on input
            positions internal tensor features will be equivariant and scalar predictions
            will be invariant. O(3) or SO(3).
            (default :obj:`"O(3)"`)
        static_shapes (bool, optional): Whether to enforce static shapes.
            Makes the model CUDA-graph compatible.
            (default: :obj:`True`)
        dtype (torch.dtype, optional): dtype forwarded to the embedding and
            interaction sub-modules and to the output linear/norm layers.
            (default: :obj:`torch.float32`)
    """

    def __init__(
        self,
        hidden_channels=128,
        num_layers=2,
        num_rbf=32,
        rbf_type="expnorm",
        trainable_rbf=False,
        activation="silu",
        cutoff_lower=0,
        cutoff_upper=4.5,
        max_num_neighbors=64,
        max_z=128,
        equivariance_invariance_group="O(3)",
        static_shapes=True,
        dtype=torch.float32,
    ):
        super().__init__()

        # Validate the string-valued options before building anything.
        assert rbf_type in rbf_class_mapping, (
            f'Unknown RBF type "{rbf_type}". '
            f'Choose from {", ".join(rbf_class_mapping.keys())}.'
        )
        assert activation in act_class_mapping, (
            f'Unknown activation function "{activation}". '
            f'Choose from {", ".join(act_class_mapping.keys())}.'
        )
        assert equivariance_invariance_group in ["O(3)", "SO(3)"], (
            f'Unknown group "{equivariance_invariance_group}". '
            f"Choose O(3) or SO(3)."
        )

        # Keep the configuration on the instance.
        self.hidden_channels = hidden_channels
        self.equivariance_invariance_group = equivariance_invariance_group
        self.num_layers = num_layers
        self.num_rbf = num_rbf
        self.rbf_type = rbf_type
        self.activation = activation
        self.cutoff_lower = cutoff_lower
        self.cutoff_upper = cutoff_upper

        act_class = act_class_mapping[activation]
        rbf_class = rbf_class_mapping[rbf_type]

        # Sub-modules are created in a fixed order: RBF expansion, initial
        # tensor embedding, interaction layers, then the output head.
        self.distance_expansion = rbf_class(
            cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
        )
        self.tensor_embedding = TensorEmbedding(
            hidden_channels,
            num_rbf,
            act_class,
            cutoff_lower,
            cutoff_upper,
            trainable_rbf,
            max_z,
            dtype,
        )
        # A non-positive num_layers simply yields an empty ModuleList.
        self.layers = nn.ModuleList(
            [
                Interaction(
                    num_rbf,
                    hidden_channels,
                    act_class,
                    cutoff_lower,
                    cutoff_upper,
                    equivariance_invariance_group,
                    dtype,
                )
                for _ in range(num_layers)
            ]
        )
        self.linear = nn.Linear(3 * hidden_channels, hidden_channels, dtype=dtype)
        self.out_norm = nn.LayerNorm(3 * hidden_channels, dtype=dtype)
        self.act = act_class()

        # resize_to_fit=False (i.e. static_shapes=True) makes the distance module
        # return a statically-shaped tensor of size
        # max_num_pairs = pos.size * max_num_neighbors.
        # A negative max_num_pairs argument means "per particle".
        # NOTE(review): the original comment here described long_edge_index=False
        # (int32 neighbor indices, less memory, fewer kernel launches), but the
        # call passes long_edge_index=True -- confirm which one is intended.
        self.static_shapes = static_shapes
        self.distance = OptimizedDistance(
            cutoff_lower,
            cutoff_upper,
            max_num_pairs=-max_num_neighbors,
            return_vecs=True,
            loop=True,
            check_errors=False,
            resize_to_fit=not self.static_shapes,
            long_edge_index=True,
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the embedding, every interaction layer and the output head."""
        self.tensor_embedding.reset_parameters()
        for interaction in self.layers:
            interaction.reset_parameters()
        self.linear.reset_parameters()
        self.out_norm.reset_parameters()

    def forward(
        self,
        z: Tensor,
        pos: Tensor,
        batch: Tensor,
        q: Optional[Tensor] = None,
        s: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Tensor, Tensor, Tensor]:
        """Compute per-atom scalar features.

        Args:
            z: Atomic numbers, used as indices into the embedding.
            pos: Atomic positions handed to the distance module.
            batch: Batch assignment handed to the distance module.
            q: Accepted for interface compatibility; not used here.
            s: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(x, None, z, pos, batch)`` where ``x`` holds the per-atom
            features after the output norm, linear layer and activation; the
            remaining entries are the inputs passed through unchanged.
        """
        # Build the graph: pair indices, distances and relative position vectors.
        edge_index, edge_weight, edge_vec = self.distance(pos, batch)
        # This assert convinces TorchScript that edge_vec is a Tensor and not an
        # Optional[Tensor].
        assert (
            edge_vec is not None
        ), "Distance module did not return directional information"

        # The distance module returns -1 for non-existing edges. To keep shapes
        # static (CUDA graphs) instead of resizing, every non-existing edge is
        # re-pointed at the first atom with zero length and zero vector.
        if self.static_shapes:
            valid = (edge_index[0] >= 0).unsqueeze(0).expand_as(edge_index)
            # WARNING: This can hurt performance if max_num_pairs >> actual_num_pairs
            edge_index = edge_index * valid
            edge_weight = edge_weight * valid[0]
            edge_vec = edge_vec * valid[0].unsqueeze(-1).expand_as(edge_vec)

        edge_attr = self.distance_expansion(edge_weight)

        # Normalizing edge vectors by their length can produce NaNs and break
        # autograd, so pairs joining an atom to itself (self loops and the
        # masked edges above) are divided by 1 instead of their length.
        self_pair = edge_index[0] == edge_index[1]
        safe_length = edge_weight.masked_fill(self_pair, 1).unsqueeze(1)
        edge_vec = edge_vec / safe_length

        X = self.tensor_embedding(z, edge_index, edge_weight, edge_vec, edge_attr)
        for interaction in self.layers:
            X = interaction(X, edge_index, edge_weight, edge_attr)

        # Scalar readout: squared norms of the three irreducible components,
        # concatenated along the channel axis.
        I, A, S = decompose_tensor(X)
        x = torch.cat((tensor_norm(I), tensor_norm(A), tensor_norm(S)), dim=-1)
        x = self.out_norm(x)
        x = self.act(self.linear(x))
        return x, None, z, pos, batch
class TensorEmbedding(nn.Module):
    """Build the initial per-atom tensor features from atomic numbers and edges.

    Each edge contributes an identity, a skew-symmetric and a symmetric
    traceless 3x3 tensor per channel, weighted by cutoff-modulated projections
    of the radial basis and by a learned embedding of the two atoms' atomic
    numbers. Edge tensors are scattered onto atoms and then mixed/normalized.

    Args:
        hidden_channels: Number of feature channels.
        num_rbf: Number of radial basis functions (input width of the
            distance projections).
        activation: Activation class; instantiated with no arguments.
        cutoff_lower: Lower cutoff passed to ``CosineCutoff``.
        cutoff_upper: Upper cutoff passed to ``CosineCutoff``.
        trainable_rbf: Accepted for signature compatibility; not used in
            this module.
        max_z: Number of rows in the atomic-number embedding table.
        dtype: dtype of every learnable layer in this module, including the
            bias-free tensor-mixing linears.
    """

    def __init__(
        self,
        hidden_channels,
        num_rbf,
        activation,
        cutoff_lower,
        cutoff_upper,
        trainable_rbf=False,
        max_z=128,
        dtype=torch.float32,
    ):
        super(TensorEmbedding, self).__init__()

        self.hidden_channels = hidden_channels
        # One radial projection per irreducible component (I, A, S).
        self.distance_proj1 = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
        self.distance_proj2 = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
        self.distance_proj3 = nn.Linear(num_rbf, hidden_channels, dtype=dtype)
        self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
        self.max_z = max_z
        self.emb = nn.Embedding(max_z, hidden_channels, dtype=dtype)
        # Maps the concatenated embeddings of an edge's two atoms to one vector.
        self.emb2 = nn.Linear(2 * hidden_channels, hidden_channels, dtype=dtype)
        self.act = activation()
        self.linears_tensor = nn.ModuleList()
        for _ in range(3):
            # dtype is forwarded here so these layers match the rest of the
            # module; previously they were always created in the default dtype.
            self.linears_tensor.append(
                nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
            )
        self.linears_scalar = nn.ModuleList()
        self.linears_scalar.append(
            nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
        )
        self.linears_scalar.append(
            nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
        )
        self.init_norm = nn.LayerNorm(hidden_channels, dtype=dtype)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every learnable layer of this module."""
        self.distance_proj1.reset_parameters()
        self.distance_proj2.reset_parameters()
        self.distance_proj3.reset_parameters()
        self.emb.reset_parameters()
        self.emb2.reset_parameters()
        for linear in self.linears_tensor:
            linear.reset_parameters()
        for linear in self.linears_scalar:
            linear.reset_parameters()
        self.init_norm.reset_parameters()

    def forward(
        self,
        z: Tensor,
        edge_index: Tensor,
        edge_weight: Tensor,
        edge_vec_norm: Tensor,
        edge_attr: Tensor,
    ) -> Tensor:
        """Compute the initial tensor features.

        Args:
            z: Atomic numbers, one per atom; ``z.shape[0]`` fixes the number
                of output rows.
            edge_index: Pair indices with two rows; row 0 is the atom that
                receives the edge's contribution.
            edge_weight: Per-edge distances fed to the cutoff function.
            edge_vec_norm: Per-edge direction vectors (``TensorNet.forward``
                passes edge vectors divided by their length).
            edge_attr: Per-edge radial basis expansion of the distances.

        Returns:
            Tensor with one row per atom, ``hidden_channels`` channels and two
            trailing 3x3 axes.
        """
        # Cutoff-modulated radial weights, one set per irreducible component.
        C = self.cutoff(edge_weight)
        W1 = self.distance_proj1(edge_attr) * C.view(-1, 1)
        W2 = self.distance_proj2(edge_attr) * C.view(-1, 1)
        W3 = self.distance_proj3(edge_attr) * C.view(-1, 1)
        # Per-edge, per-channel tensors: the geometric 3x3 tensors get a
        # singleton channel axis and broadcast against the radial weights.
        Iij, Aij, Sij = new_radial_tensor(
            torch.eye(3, 3, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype)[
                None, None, :, :
            ],
            vector_to_skewtensor(edge_vec_norm)[..., None, :, :],
            vector_to_symtensor(edge_vec_norm)[..., None, :, :],
            W1,
            W2,
            W3,
        )
        Z = self.emb(z)
        # edge_index.t().reshape(-1) interleaves the two atoms of every edge, so
        # the view pairs their embeddings into one row of width 2 * hidden.
        Zij = self.emb2(
            Z.index_select(0, edge_index.t().reshape(-1)).view(
                -1, self.hidden_channels * 2
            )
        )[..., None, None]
        # Aggregate edge tensors onto the atom in row 0 of edge_index.
        I = scatter(Zij * Iij, edge_index[0], dim=0, dim_size=z.shape[0])
        A = scatter(Zij * Aij, edge_index[0], dim=0, dim_size=z.shape[0])
        S = scatter(Zij * Sij, edge_index[0], dim=0, dim_size=z.shape[0])
        # Invariant per-channel scalars (squared norms), layer-normalized.
        norm = self.init_norm(tensor_norm(I + A + S))
        # Channel mixing: move the channel axis last for nn.Linear, then back.
        I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        for linear_scalar in self.linears_scalar:
            norm = self.act(linear_scalar(norm))
        # Split the 3 * hidden scalars into one factor per component.
        norm = norm.reshape(norm.shape[0], self.hidden_channels, 3)
        I, A, S = new_radial_tensor(I, A, S, norm[..., 0], norm[..., 1], norm[..., 2])
        X = I + A + S
        return X
def tensor_message_passing(edge_index: Tensor, factor: Tensor, tensor: Tensor, natoms: int) -> Tensor:
    """Aggregate factor-weighted neighbor tensors onto their receiving atoms.

    Args:
        edge_index: Pair indices with two rows; row 1 selects the atom whose
            tensor is gathered, row 0 the atom that receives the message.
        factor: Per-edge weights, broadcast against the gathered tensors.
        tensor: Per-atom tensors, indexed along axis 0.
        natoms: Number of rows in the output (``dim_size`` of the scatter).

    Returns:
        Per-atom tensor of aggregated messages, reduced with ``scatter``'s
        default reduction.
    """
    gathered = tensor.index_select(0, edge_index[1])
    return scatter(factor * gathered, edge_index[0], dim=0, dim_size=natoms)
class Interaction(nn.Module):
    """One TensorNet interaction layer acting on per-atom tensor features.

    The layer normalizes the input tensors, mixes channels of each irreducible
    component, passes radially-weighted messages between atoms, combines the
    messages with the node tensors by matrix products, and returns the input
    updated with ``dX + dX @ dX``.

    Args:
        num_rbf: Number of radial basis functions (width of ``edge_attr``).
        hidden_channels: Number of feature channels.
        activation: Activation class; instantiated with no arguments.
        cutoff_lower: Lower cutoff passed to ``CosineCutoff``.
        cutoff_upper: Upper cutoff passed to ``CosineCutoff``.
        equivariance_invariance_group: ``"O(3)"`` or ``"SO(3)"``; selects how
            messages and node tensors are multiplied in ``forward``.
        dtype: dtype of every learnable layer in this module, including the
            bias-free tensor-mixing linears.
    """

    def __init__(
        self,
        num_rbf,
        hidden_channels,
        activation,
        cutoff_lower,
        cutoff_upper,
        equivariance_invariance_group,
        dtype=torch.float32,
    ):
        super(Interaction, self).__init__()

        self.num_rbf = num_rbf
        self.hidden_channels = hidden_channels
        self.cutoff = CosineCutoff(cutoff_lower, cutoff_upper)
        # Scalar MLP: num_rbf -> hidden -> 2 * hidden -> 3 * hidden.
        self.linears_scalar = nn.ModuleList()
        self.linears_scalar.append(
            nn.Linear(num_rbf, hidden_channels, bias=True, dtype=dtype)
        )
        self.linears_scalar.append(
            nn.Linear(hidden_channels, 2 * hidden_channels, bias=True, dtype=dtype)
        )
        self.linears_scalar.append(
            nn.Linear(2 * hidden_channels, 3 * hidden_channels, bias=True, dtype=dtype)
        )
        # Channel-mixing linears: indices 0-2 before message passing,
        # indices 3-5 after it (one each for I, A, S).
        self.linears_tensor = nn.ModuleList()
        for _ in range(6):
            # dtype is forwarded here so these layers match the rest of the
            # module; previously they were always created in the default dtype.
            self.linears_tensor.append(
                nn.Linear(hidden_channels, hidden_channels, bias=False, dtype=dtype)
            )
        self.act = activation()
        self.equivariance_invariance_group = equivariance_invariance_group
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every learnable layer of this module."""
        for linear in self.linears_scalar:
            linear.reset_parameters()
        for linear in self.linears_tensor:
            linear.reset_parameters()

    def forward(
        self, X: Tensor, edge_index: Tensor, edge_weight: Tensor, edge_attr: Tensor
    ) -> Tensor:
        """Apply one interaction update.

        Args:
            X: Per-atom tensor features with channels on axis 1 and two
                trailing 3x3 axes.
            edge_index: Pair indices with two rows (receiver, source).
            edge_weight: Per-edge distances fed to the cutoff function.
            edge_attr: Per-edge radial basis features of width ``num_rbf``.

        Returns:
            Updated tensor features with the same shape as ``X``.
        """
        # Edge scalars: MLP on the radial features, damped by the cutoff and
        # split into one factor per irreducible component.
        C = self.cutoff(edge_weight)
        for linear_scalar in self.linears_scalar:
            edge_attr = self.act(linear_scalar(edge_attr))
        edge_attr = (edge_attr * C.view(-1, 1)).reshape(
            edge_attr.shape[0], self.hidden_channels, 3
        )
        # Normalize by (squared norm + 1) per atom and channel.
        X = X / (tensor_norm(X) + 1)[..., None, None]
        I, A, S = decompose_tensor(X)
        # Channel mixing: move the channel axis last for nn.Linear, then back.
        I = self.linears_tensor[0](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        A = self.linears_tensor[1](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        S = self.linears_tensor[2](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        Y = I + A + S
        Im = tensor_message_passing(edge_index, edge_attr[..., 0, None, None], I, X.shape[0])
        Am = tensor_message_passing(edge_index, edge_attr[..., 1, None, None], A, X.shape[0])
        Sm = tensor_message_passing(edge_index, edge_attr[..., 2, None, None], S, X.shape[0])
        msg = Im + Am + Sm
        # Combine messages with node tensors. The group names are mutually
        # exclusive, so at most one branch runs.
        if self.equivariance_invariance_group == "O(3)":
            # Both product orders are summed for O(3).
            msg_Y = torch.matmul(msg, Y)
            Y_msg = torch.matmul(Y, msg)
            I, A, S = decompose_tensor(msg_Y + Y_msg)
        elif self.equivariance_invariance_group == "SO(3)":
            # A single product order, doubled, for SO(3).
            Y_msg = torch.matmul(Y, msg)
            I, A, S = decompose_tensor(2 * Y_msg)
        normp1 = (tensor_norm(I + A + S) + 1)[..., None, None]
        I, A, S = I / normp1, A / normp1, S / normp1
        I = self.linears_tensor[3](I.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        A = self.linears_tensor[4](A.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        S = self.linears_tensor[5](S.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        dX = I + A + S
        # Residual update with a first- and a second-order term in dX.
        X = X + dX + torch.matrix_power(dX, 2)
        return X