9
9
10
10
import contextlib
11
11
import logging
12
+ import os
13
+ import typing
12
14
13
15
import torch
14
16
import torch .nn as nn
15
17
18
+ if typing .TYPE_CHECKING :
19
+ from torch_geometric .data .batch import Batch
20
+
16
21
from fairchem .core .common .registry import registry
22
+ from fairchem .core .models .base import GraphModelMixin
17
23
from fairchem .core .models .escn .so3_exportable import (
18
24
CoefficientMapping ,
19
25
SO3_Grid ,
32
38
33
39
34
40
@registry .register_model ("escn_export" )
35
- class eSCN (nn .Module ):
41
+ class eSCN (nn .Module , GraphModelMixin ):
36
42
"""Equivariant Spherical Channel Network
37
43
Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs
38
44
39
45
40
46
Args:
41
- regress_forces (bool ): Compute forces
42
- cutoff (float): Maximum distance between nieghboring atoms in Angstroms
43
- max_num_elements (int): Maximum atomic number
47
+ max_neighbors(int ): Max neighbors to take per node, when using the graph generation
48
+ cutoff (float): Maximum distance between neighboring atoms in Angstroms
49
+ max_num_elements (int): Maximum atomic number
44
50
num_layers (int): Number of layers in the GNN
45
51
lmax (int): maximum degree of the spherical harmonics (1 to 10)
46
52
mmax (int): maximum order of the spherical harmonics (0 to lmax)
@@ -51,13 +57,15 @@ class eSCN(nn.Module):
51
57
distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu"): Basis function used for distances
52
58
basis_width_scalar (float): Width of distance basis function
53
59
distance_resolution (float): Distance between distance basis functions in Angstroms
60
+ compile (bool): use torch.compile on the forward
61
+ export (bool): use the exportable version of the module
54
62
"""
55
63
56
64
def __init__ (
57
65
self ,
58
- regress_forces : bool = True ,
66
+ max_neighbors : int = 300 ,
59
67
cutoff : float = 8.0 ,
60
- max_num_elements : int = 90 ,
68
+ max_num_elements : int = 100 ,
61
69
num_layers : int = 8 ,
62
70
lmax : int = 4 ,
63
71
mmax : int = 2 ,
@@ -69,6 +77,8 @@ def __init__(
69
77
basis_width_scalar : float = 1.0 ,
70
78
distance_resolution : float = 0.02 ,
71
79
resolution : int | None = None ,
80
+ compile : bool = False ,
81
+ export : bool = False ,
72
82
) -> None :
73
83
super ().__init__ ()
74
84
@@ -78,7 +88,7 @@ def __init__(
78
88
logging .error ("You need to install the e3nn library to use the SCN model" )
79
89
raise ImportError
80
90
81
- self .regress_forces = regress_forces
91
+ self .max_neighbors = max_neighbors
82
92
self .cutoff = cutoff
83
93
self .max_num_elements = max_num_elements
84
94
self .hidden_channels = hidden_channels
@@ -91,6 +101,8 @@ def __init__(
91
101
self .mmax = mmax
92
102
self .basis_width_scalar = basis_width_scalar
93
103
self .distance_function = distance_function
104
+ self .compile = compile
105
+ self .export = export
94
106
95
107
# non-linear activation function used throughout the network
96
108
self .act = nn .SiLU ()
@@ -169,10 +181,9 @@ def __init__(
169
181
self .energy_block = EnergyBlock (
170
182
self .sphere_channels , self .num_sphere_samples , self .act
171
183
)
172
- if self .regress_forces :
173
- self .force_block = ForceBlock (
174
- self .sphere_channels , self .num_sphere_samples , self .act
175
- )
184
+ self .force_block = ForceBlock (
185
+ self .sphere_channels , self .num_sphere_samples , self .act
186
+ )
176
187
177
188
# Create a roughly evenly distributed point sampling of the sphere for the output blocks
178
189
self .sphere_points = nn .Parameter (
@@ -189,29 +200,96 @@ def __init__(
189
200
requires_grad = False ,
190
201
)
191
202
192
- def forward (self , data : dict [str , torch .Tensor ]) -> dict [str , torch .Tensor ]:
193
- pos : torch .Tensor = data ["pos" ]
194
- batch_idx : torch .Tensor = data ["batch" ]
195
- natoms : torch .Tensor = data ["natoms" ]
196
- atomic_numbers : torch .Tensor = data ["atomic_numbers" ]
197
- edge_index : torch .Tensor = data ["edge_index" ]
198
- edge_distance : torch .Tensor = data ["distances" ]
199
- edge_distance_vec : torch .Tensor = data ["edge_distance_vec" ]
200
-
201
- atomic_numbers = atomic_numbers .long ()
202
- # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable
203
- # assert (
204
- # atomic_numbers.max().item() < self.max_num_elements
205
- # ), "Atomic number exceeds that given in model config"
203
+ self .sph_feature_size = int ((self .lmax + 1 ) ** 2 )
204
+ # Pre-load Jd tensors for wigner matrices
205
+ # Borrowed from e3nn @ 0.4.0:
206
+ # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
207
+ # _Jd is a list of tensors of shape (2l+1, 2l+1)
208
+ # TODO: we should probably just bake this into the file as strings to avoid
209
+ # carrying this extra file around
210
+ Jd_list = torch .load (os .path .join (os .path .dirname (__file__ ), "Jd.pt" ))
211
+ for l in range (self .lmax + 1 ):
212
+ self .register_buffer (f"Jd_{ l } " , Jd_list [l ])
213
+
214
+ if self .compile :
215
+ logging .info ("Using the compiled escn forward function..." )
216
+ self .forward = torch .compile (
217
+ options = {"triton.cudagraphs" : True }, fullgraph = True , dynamic = True
218
+ )(self .forward )
219
+
220
+ # torch.export only works with nn.module with an unaltered forward function,
221
+ # furthermore AOT Inductor currently requires a flat list of inputs
222
+ # thus we need to keep the module.forward function as the fully exportable region
223
+ # When not using export, ie for training, we swap out the forward with a version
224
+ # that wraps it with the graph generator
225
+ #
226
+ # TODO: this is really ugly and confusing to read, find a better way to deal
227
+ # with partially exportable model
228
+ if not self .export :
229
+ self ._forward = self .forward
230
+ self .forward = self .forward_trainable
231
+
232
+ def forward_trainable (self , data : Batch ) -> dict [str , torch .Tensor ]:
233
+ # standard forward call that generates the graph on-the-fly with generate_graph
234
+ # this part of the code is not compile/export friendly so we keep it separated and wrap the exportable forward
235
+ graph = self .generate_graph (
236
+ data ,
237
+ max_neighbors = self .max_neighbors ,
238
+ otf_graph = True ,
239
+ use_pbc = True ,
240
+ use_pbc_single = True ,
241
+ )
242
+ energy , forces = self ._forward (
243
+ data .pos ,
244
+ data .batch ,
245
+ data .natoms ,
246
+ data .atomic_numbers .long (),
247
+ graph .edge_index ,
248
+ graph .edge_distance ,
249
+ graph .edge_distance_vec ,
250
+ )
251
+ return {"energy" : energy , "forces" : forces }
252
+
253
+ # a fully compilable/exportable forward function
254
+ # takes a full graph with edges as input
255
+ def forward (
256
+ self ,
257
+ pos : torch .Tensor ,
258
+ batch_idx : torch .Tensor ,
259
+ natoms : torch .Tensor ,
260
+ atomic_numbers : torch .Tensor ,
261
+ edge_index : torch .Tensor ,
262
+ edge_distance : torch .Tensor ,
263
+ edge_distance_vec : torch .Tensor ,
264
+ ) -> list [torch .Tensor ]:
265
+ """
266
+ N: num atoms
267
+ B: batch size
268
+ E: num edges
269
+
270
+ pos: [N, 3] atom positions
271
+ batch_idx: [N] batch index of each atom
272
+ natoms: [B] number of atoms in each batch
273
+ atomic_numbers: [N] atomic number per atom
274
+ edge_index: [2, E] edges between source and target atoms
275
+ edge_distance: [E] cartesian distance for each edge
276
+ edge_distance_vec: [E, 3] direction vector of edges (includes pbc)
277
+ """
278
+ if not self .export and not self .compile :
279
+ assert atomic_numbers .max ().item () < self .max_num_elements
206
280
num_atoms = len (atomic_numbers )
207
281
208
282
###############################################################
209
283
# Initialize data structures
210
284
###############################################################
211
285
212
286
# Compute 3x3 rotation matrix per edge
213
- edge_rot_mat = self ._init_edge_rot_mat (edge_index , edge_distance_vec )
214
- wigner = rotation_to_wigner (edge_rot_mat , 0 , self .lmax ).detach ()
287
+ edge_rot_mat = self ._init_edge_rot_mat (edge_distance_vec )
288
+ Jd_buffers = [
289
+ getattr (self , f"Jd_{ l } " ).type (edge_rot_mat .dtype )
290
+ for l in range (self .lmax + 1 )
291
+ ]
292
+ wigner = rotation_to_wigner (edge_rot_mat , 0 , self .lmax , Jd_buffers ).detach ()
215
293
216
294
###############################################################
217
295
# Initialize node embeddings
@@ -220,7 +298,7 @@ def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
220
298
# Init per node representations using an atomic number based embedding
221
299
x_message = torch .zeros (
222
300
num_atoms ,
223
- int (( self .lmax + 1 ) ** 2 ) ,
301
+ self .sph_feature_size ,
224
302
self .sphere_channels ,
225
303
device = pos .device ,
226
304
dtype = pos .dtype ,
@@ -266,31 +344,20 @@ def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
266
344
# Scale energy to help balance numerical precision w.r.t. forces
267
345
energy = energy * 0.001
268
346
269
- outputs = {"energy" : energy }
270
347
###############################################################
271
348
# Force estimation
272
349
###############################################################
273
- if self .regress_forces :
274
- forces = self .force_block (x_pt , self .sphere_points )
275
- outputs ["forces" ] = forces
350
+ forces = self .force_block (x_pt , self .sphere_points )
276
351
277
- return outputs
352
+ return energy , forces
278
353
279
354
# Initialize the edge rotation matrices
280
- def _init_edge_rot_mat (self , edge_index , edge_distance_vec ):
355
+ def _init_edge_rot_mat (self , edge_distance_vec ):
281
356
edge_vec_0 = edge_distance_vec
282
357
edge_vec_0_distance = torch .sqrt (torch .sum (edge_vec_0 ** 2 , dim = 1 ))
283
358
284
359
# Make sure the atoms are far enough apart
285
- # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable
286
- # if torch.min(edge_vec_0_distance) < 0.0001:
287
- # logging.error(
288
- # f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}"
289
- # )
290
- # (minval, minidx) = torch.min(edge_vec_0_distance, 0)
291
- # logging.error(
292
- # f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}"
293
- # )
360
+ # assert torch.min(edge_vec_0_distance) > 0.0001
294
361
295
362
norm_x = edge_vec_0 / (edge_vec_0_distance .view (- 1 , 1 ))
296
363
0 commit comments