
Commit 7707b3e

ESCN export II (#848)
* temp fix to train mptraj
* add compile option
* move jd to init
* fix dynamic export
* add value testing for export
* update test
* lint
* update comment
* update forward code
* reraise error
* wrap escn
* revert packages
* format

Former-commit-id: 362482e00920bea5af5bf7fcb9c035bd62966aa6
1 parent 60f56bd commit 7707b3e

File tree

3 files changed: +349 −199 lines changed


src/fairchem/core/models/escn/escn_exportable.py (+110 −43)

```diff
@@ -9,11 +9,17 @@

 import contextlib
 import logging
+import os
+import typing

 import torch
 import torch.nn as nn

+if typing.TYPE_CHECKING:
+    from torch_geometric.data.batch import Batch
+
 from fairchem.core.common.registry import registry
+from fairchem.core.models.base import GraphModelMixin
 from fairchem.core.models.escn.so3_exportable import (
     CoefficientMapping,
     SO3_Grid,
@@ -32,15 +38,15 @@


 @registry.register_model("escn_export")
-class eSCN(nn.Module):
+class eSCN(nn.Module, GraphModelMixin):
     """Equivariant Spherical Channel Network
     Paper: Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs


     Args:
-        regress_forces (bool): Compute forces
-        cutoff (float): Maximum distance between neighboring atoms in Angstroms
-        max_num_elements (int): Maximum atomic number
+        max_neighbors (int): Max neighbors to take per node when using graph generation
+        cutoff (float): Maximum distance between neighboring atoms in Angstroms
+        max_num_elements (int): Maximum atomic number
         num_layers (int): Number of layers in the GNN
         lmax (int): maximum degree of the spherical harmonics (1 to 10)
         mmax (int): maximum order of the spherical harmonics (0 to lmax)
@@ -51,13 +57,15 @@ class eSCN(nn.Module):
         distance_function ("gaussian", "sigmoid", "linearsigmoid", "silu"): Basis function used for distances
         basis_width_scalar (float): Width of distance basis function
         distance_resolution (float): Distance between distance basis functions in Angstroms
+        compile (bool): use torch.compile on the forward
+        export (bool): use the exportable version of the module
     """

     def __init__(
         self,
-        regress_forces: bool = True,
+        max_neighbors: int = 300,
         cutoff: float = 8.0,
-        max_num_elements: int = 90,
+        max_num_elements: int = 100,
         num_layers: int = 8,
         lmax: int = 4,
         mmax: int = 2,
@@ -69,6 +77,8 @@ def __init__(
         basis_width_scalar: float = 1.0,
         distance_resolution: float = 0.02,
         resolution: int | None = None,
+        compile: bool = False,
+        export: bool = False,
     ) -> None:
         super().__init__()

@@ -78,7 +88,7 @@ def __init__(
             logging.error("You need to install the e3nn library to use the SCN model")
             raise ImportError

-        self.regress_forces = regress_forces
+        self.max_neighbors = max_neighbors
         self.cutoff = cutoff
         self.max_num_elements = max_num_elements
         self.hidden_channels = hidden_channels
@@ -91,6 +101,8 @@ def __init__(
         self.mmax = mmax
         self.basis_width_scalar = basis_width_scalar
         self.distance_function = distance_function
+        self.compile = compile
+        self.export = export

         # non-linear activation function used throughout the network
         self.act = nn.SiLU()
@@ -169,10 +181,9 @@ def __init__(
         self.energy_block = EnergyBlock(
             self.sphere_channels, self.num_sphere_samples, self.act
         )
-        if self.regress_forces:
-            self.force_block = ForceBlock(
-                self.sphere_channels, self.num_sphere_samples, self.act
-            )
+        self.force_block = ForceBlock(
+            self.sphere_channels, self.num_sphere_samples, self.act
+        )

         # Create a roughly evenly distributed point sampling of the sphere for the output blocks
         self.sphere_points = nn.Parameter(
@@ -189,29 +200,96 @@ def __init__(
             requires_grad=False,
         )

-    def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        pos: torch.Tensor = data["pos"]
-        batch_idx: torch.Tensor = data["batch"]
-        natoms: torch.Tensor = data["natoms"]
-        atomic_numbers: torch.Tensor = data["atomic_numbers"]
-        edge_index: torch.Tensor = data["edge_index"]
-        edge_distance: torch.Tensor = data["distances"]
-        edge_distance_vec: torch.Tensor = data["edge_distance_vec"]
-
-        atomic_numbers = atomic_numbers.long()
-        # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable
-        # assert (
-        #     atomic_numbers.max().item() < self.max_num_elements
-        # ), "Atomic number exceeds that given in model config"
+        self.sph_feature_size = int((self.lmax + 1) ** 2)
+        # Pre-load Jd tensors for wigner matrices
+        # Borrowed from e3nn @ 0.4.0:
+        # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
+        # _Jd is a list of tensors of shape (2l+1, 2l+1)
+        # TODO: we should probably just bake this into the file as strings to avoid
+        # carrying this extra file around
+        Jd_list = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
+        for l in range(self.lmax + 1):
+            self.register_buffer(f"Jd_{l}", Jd_list[l])
+
+        if self.compile:
+            logging.info("Using the compiled escn forward function...")
+            self.forward = torch.compile(
+                options={"triton.cudagraphs": True}, fullgraph=True, dynamic=True
+            )(self.forward)
+
+        # torch.export only works on an nn.Module with an unaltered forward function,
+        # and AOT Inductor currently requires a flat list of inputs, thus we need to
+        # keep module.forward as the fully exportable region.
+        # When not using export, i.e. for training, we swap out the forward with a
+        # version that wraps it with the graph generator.
+        #
+        # TODO: this is really ugly and confusing to read, find a better way to deal
+        # with a partially exportable model
+        if not self.export:
+            self._forward = self.forward
+            self.forward = self.forward_trainable
+
+    def forward_trainable(self, data: Batch) -> dict[str, torch.Tensor]:
+        # standard forward call that generates the graph on-the-fly with generate_graph;
+        # this part is not compile/export friendly, so we keep it separate and wrap the exportable forward
+        graph = self.generate_graph(
+            data,
+            max_neighbors=self.max_neighbors,
+            otf_graph=True,
+            use_pbc=True,
+            use_pbc_single=True,
+        )
+        energy, forces = self._forward(
+            data.pos,
+            data.batch,
+            data.natoms,
+            data.atomic_numbers.long(),
+            graph.edge_index,
+            graph.edge_distance,
+            graph.edge_distance_vec,
+        )
+        return {"energy": energy, "forces": forces}
+
+    # a fully compilable/exportable forward function
+    # takes a full graph with edges as input
+    def forward(
+        self,
+        pos: torch.Tensor,
+        batch_idx: torch.Tensor,
+        natoms: torch.Tensor,
+        atomic_numbers: torch.Tensor,
+        edge_index: torch.Tensor,
+        edge_distance: torch.Tensor,
+        edge_distance_vec: torch.Tensor,
+    ) -> list[torch.Tensor]:
+        """
+        N: num atoms
+        B: batch size
+        E: num edges
+
+        pos: [N, 3] atom positions
+        batch_idx: [N] batch index of each atom
+        natoms: [B] number of atoms in each batch
+        atomic_numbers: [N] atomic number per atom
+        edge_index: [2, E] edges between source and target atoms
+        edge_distance: [E] cartesian distance for each edge
+        edge_distance_vec: [E, 3] direction vector of edges (includes pbc)
+        """
+        if not self.export and not self.compile:
+            assert atomic_numbers.max().item() < self.max_num_elements
         num_atoms = len(atomic_numbers)

         ###############################################################
         # Initialize data structures
         ###############################################################

         # Compute 3x3 rotation matrix per edge
-        edge_rot_mat = self._init_edge_rot_mat(edge_index, edge_distance_vec)
-        wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax).detach()
+        edge_rot_mat = self._init_edge_rot_mat(edge_distance_vec)
+        Jd_buffers = [
+            getattr(self, f"Jd_{l}").type(edge_rot_mat.dtype)
+            for l in range(self.lmax + 1)
+        ]
+        wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax, Jd_buffers).detach()

         ###############################################################
         # Initialize node embeddings
@@ -220,7 +298,7 @@ def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         # Init per node representations using an atomic number based embedding
         x_message = torch.zeros(
             num_atoms,
-            int((self.lmax + 1) ** 2),
+            self.sph_feature_size,
             self.sphere_channels,
             device=pos.device,
             dtype=pos.dtype,
@@ -266,31 +344,20 @@
         # Scale energy to help balance numerical precision w.r.t. forces
         energy = energy * 0.001

-        outputs = {"energy": energy}
         ###############################################################
         # Force estimation
         ###############################################################
-        if self.regress_forces:
-            forces = self.force_block(x_pt, self.sphere_points)
-            outputs["forces"] = forces
+        forces = self.force_block(x_pt, self.sphere_points)

-        return outputs
+        return energy, forces

     # Initialize the edge rotation matrices
-    def _init_edge_rot_mat(self, edge_index, edge_distance_vec):
+    def _init_edge_rot_mat(self, edge_distance_vec):
         edge_vec_0 = edge_distance_vec
         edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))

         # Make sure the atoms are far enough apart
-        # TODO: this requires upgrade to torch2.4 with export non-strict mode to enable
-        # if torch.min(edge_vec_0_distance) < 0.0001:
-        #     logging.error(
-        #         f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}"
-        #     )
-        #     (minval, minidx) = torch.min(edge_vec_0_distance, 0)
-        #     logging.error(
-        #         f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}"
-        #     )
+        # assert torch.min(edge_vec_0_distance) < 0.0001

         norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))

```
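To make the new split concrete, here is a minimal sketch of how the flat-input `forward` might be exported once the module is built with `export=True`. This is not taken from the commit's tests: the toy graph, the `Dim` names, and the dynamic-shape markings are illustrative assumptions, and in real use the inputs would come from `generate_graph` on a `Batch` (with `compile=True`, the same flat-input region is wrapped by `torch.compile` instead, as the constructor hunk above shows).

```python
import torch
from torch.export import Dim, export

from fairchem.core.models.escn.escn_exportable import eSCN

# Toy graph, shapes as in the forward() docstring: N atoms, B systems, E edges.
# A real graph would come from generate_graph() on a torch_geometric Batch.
N, E = 4, 6
pos = torch.randn(N, 3)
batch_idx = torch.zeros(N, dtype=torch.long)
natoms = torch.tensor([N])
atomic_numbers = torch.randint(1, 90, (N,))
edge_index = torch.tensor([[0, 1, 2, 3, 0, 2],   # no self-loops, so distances stay nonzero
                           [1, 2, 3, 0, 2, 0]])
edge_distance_vec = pos[edge_index[1]] - pos[edge_index[0]]
edge_distance = edge_distance_vec.norm(dim=-1)

# export=True keeps the flat-input forward() as the exportable region.
model = eSCN(export=True).eval()

# Atom and edge counts vary per system, so mark those dimensions dynamic.
# Depending on the torch version, non-strict mode (strict=False) may be needed.
num_atoms, num_edges = Dim("num_atoms", min=2), Dim("num_edges", min=2)
ep = export(
    model,
    args=(pos, batch_idx, natoms, atomic_numbers,
          edge_index, edge_distance, edge_distance_vec),
    dynamic_shapes=({0: num_atoms}, {0: num_atoms}, None, {0: num_atoms},
                    {1: num_edges}, {0: num_edges}, {0: num_edges}),
)
energy, forces = ep.module()(pos, batch_idx, natoms, atomic_numbers,
                             edge_index, edge_distance, edge_distance_vec)
```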
src/fairchem/core/models/escn/so3_exportable.py (+28 −27)

```diff
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import math
-import os

 import torch

@@ -11,51 +10,53 @@
 except ImportError:
     pass

-# Borrowed from e3nn @ 0.4.0:
-# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
-# _Jd is a list of tensors of shape (2l+1, 2l+1)
-__Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"))
-
-
-@torch.compiler.assume_constant_result
-def get_jd() -> torch.Tensor:
-    return __Jd
-

 # Borrowed from e3nn @ 0.4.0:
 # https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37
 #
 # In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower:
 # https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92
 def wigner_D(
-    lv: int, alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor
+    lv: int,
+    alpha: torch.Tensor,
+    beta: torch.Tensor,
+    gamma: torch.Tensor,
+    _Jd: list[torch.Tensor],
 ) -> torch.Tensor:
-    _Jd = get_jd()
-    assert (
-        lv < len(_Jd)
-    ), f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
-
     alpha, beta, gamma = torch.broadcast_tensors(alpha, beta, gamma)
-    J = _Jd[lv].to(dtype=alpha.dtype, device=alpha.device)
+    J = _Jd[lv]
     Xa = _z_rot_mat(alpha, lv)
     Xb = _z_rot_mat(beta, lv)
     Xc = _z_rot_mat(gamma, lv)
     return Xa @ J @ Xb @ J @ Xc


 def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
-    shape, device, dtype = angle.shape, angle.device, angle.dtype
-    M = angle.new_zeros((*shape, 2 * lv + 1, 2 * lv + 1))
-    inds = torch.arange(0, 2 * lv + 1, 1, device=device)
-    reversed_inds = torch.arange(2 * lv, -1, -1, device=device)
-    frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device)
-    M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
-    M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
+    M = angle.new_zeros((*angle.shape, 2 * lv + 1, 2 * lv + 1))
+
+    # The following code needs to be replaced with a for loop because
+    # torch.export barfs on outer-product-like operations,
+    # i.e. torch.outer(frequencies, angle) (same as frequencies * angle[..., None])
+    # will place a nonsense guard on the dimensions of angle when attempting to
+    # export with angle (the edge dimension) set as dynamic. This may be fixed in torch 2.4.
+
+    # inds = torch.arange(0, 2 * lv + 1, 1, device=device)
+    # reversed_inds = torch.arange(2 * lv, -1, -1, device=device)
+    # frequencies = torch.arange(lv, -lv - 1, -1, dtype=dtype, device=device)
+    # M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
+    # M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
+
+    inds = list(range(0, 2 * lv + 1, 1))
+    reversed_inds = list(range(2 * lv, -1, -1))
+    frequencies = list(range(lv, -lv - 1, -1))
+    for i in range(len(frequencies)):
+        M[..., inds[i], reversed_inds[i]] = torch.sin(frequencies[i] * angle)
+        M[..., inds[i], inds[i]] = torch.cos(frequencies[i] * angle)
     return M


 def rotation_to_wigner(
-    edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int
+    edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int, Jd: list[torch.Tensor]
 ) -> torch.Tensor:
     x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
     alpha, beta = o3.xyz_to_angles(x)
@@ -69,7 +70,7 @@ def rotation_to_wigner(
     wigner = torch.zeros(len(alpha), size, size, device=edge_rot_mat.device)
     start = 0
     for lmax in range(start_lmax, end_lmax + 1):
-        block = wigner_D(lmax, alpha, beta, gamma)
+        block = wigner_D(lmax, alpha, beta, gamma, Jd)
         end = start + block.size()[1]
         wigner[:, start:end, start:end] = block
         start = end

```
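The loop-based `_z_rot_mat` is intended to be numerically identical to the vectorized e3nn formulation it replaces, only friendlier to torch.export's shape guards. A quick sanity check in that spirit (a sketch, not part of this commit's test suite):

```python
import torch

from fairchem.core.models.escn.so3_exportable import _z_rot_mat

lv = 3
angle = torch.randn(5)  # one rotation angle per edge

# Vectorized reference, as in the commented-out block in the diff above.
inds = torch.arange(0, 2 * lv + 1, 1)
reversed_inds = torch.arange(2 * lv, -1, -1)
frequencies = torch.arange(lv, -lv - 1, -1, dtype=angle.dtype)
ref = angle.new_zeros((*angle.shape, 2 * lv + 1, 2 * lv + 1))
ref[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
ref[..., inds, inds] = torch.cos(frequencies * angle[..., None])

# The for-loop version used by the exportable module should agree exactly.
torch.testing.assert_close(_z_rot_mat(angle, lv), ref)
```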