Commit f4e0543
Rename hydra heads (#903)
* refactor eqV2 heads
* refactor init_weights
* add output_name attribute to eqV2 heads
* fix deprecated registry names
* remove debug breakpoint
* add default name for rank2 head

Former-commit-id: 4ea0231431c29fed1d61b888ed8ddbf957e5bf13
1 parent 851da37 commit f4e0543

8 files changed: +231 −138 lines changed

src/fairchem/core/models/equiformer_v2/equiformer_v2.py (+34 −130)
@@ -1,28 +1,30 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
 from __future__ import annotations
 
 import contextlib
 import logging
-import math
+import typing
 from functools import partial
 
 import torch
 import torch.nn as nn
+from typing_extensions import deprecated
 
 from fairchem.core.common import gp_utils
 from fairchem.core.common.registry import registry
 from fairchem.core.common.utils import conditional_grad
 from fairchem.core.models.base import (
     GraphModelMixin,
-    HeadInterface,
 )
+from fairchem.core.models.equiformer_v2.heads import EqV2ScalarHead, EqV2VectorHead
 from fairchem.core.models.scn.smearing import GaussianSmearing
 
-with contextlib.suppress(ImportError):
-    pass
-
-
-import typing
-
 from .edge_rot_mat import init_edge_rot_mat
 from .gaussian_rbf import GaussianRadialBasisLayer
 from .input_block import EdgeDegreeEmbedding
@@ -34,7 +36,6 @@
     get_normalization_layer,
 )
 from .module_list import ModuleListInfo
-from .radial_function import RadialFunction
 from .so3 import (
     CoefficientMappingModule,
     SO3_Embedding,
@@ -43,41 +44,43 @@
4344
SO3_Rotation,
4445
)
4546
from .transformer_block import (
46-
FeedForwardNetwork,
47-
SO2EquivariantGraphAttention,
4847
TransBlockV2,
4948
)
49+
from .weight_initialization import eqv2_init_weights
50+
51+
with contextlib.suppress(ImportError):
52+
pass
5053

5154
if typing.TYPE_CHECKING:
5255
from torch_geometric.data.batch import Batch
5356

54-
from fairchem.core.models.base import GraphData
55-
5657
# Statistics of IS2RE 100K
5758
_AVG_NUM_NODES = 77.81317
5859
_AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100
5960

6061

61-
def eqv2_init_weights(m, weight_init):
62-
if isinstance(m, (torch.nn.Linear, SO3_LinearV2)):
63-
if m.bias is not None:
64-
torch.nn.init.constant_(m.bias, 0)
65-
if weight_init == "normal":
66-
std = 1 / math.sqrt(m.in_features)
67-
torch.nn.init.normal_(m.weight, 0, std)
68-
elif isinstance(m, torch.nn.LayerNorm):
69-
torch.nn.init.constant_(m.bias, 0)
70-
torch.nn.init.constant_(m.weight, 1.0)
71-
elif isinstance(m, RadialFunction):
72-
m.apply(eqv2_uniform_init_linear_weights)
62+
@deprecated(
63+
"equiformer_v2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_rank1_head (EqV2Rank1Head)"
64+
)
65+
@registry.register_model("equiformer_v2_force_head")
66+
class EquiformerV2ForceHead(EqV2VectorHead):
67+
def __init__(self, backbone):
68+
logging.warning(
69+
"equiformerV2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_rank1_head (EqV2Rank1Head)"
70+
)
71+
super().__init__(backbone)
7372

7473

75-
def eqv2_uniform_init_linear_weights(m):
76-
if isinstance(m, torch.nn.Linear):
77-
if m.bias is not None:
78-
torch.nn.init.constant_(m.bias, 0)
79-
std = 1 / math.sqrt(m.in_features)
80-
torch.nn.init.uniform_(m.weight, -std, std)
74+
@deprecated(
75+
"equiformer_v2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)"
76+
)
77+
@registry.register_model("equiformer_v2_energy_head")
78+
class EquiformerV2EnergyHead(EqV2ScalarHead):
79+
def __init__(self, backbone, reduce: str = "sum"):
80+
logging.warning(
81+
"equiformerV2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)"
82+
)
83+
super().__init__(backbone, reduce=reduce)
8184

8285

8386
@registry.register_model("equiformer_v2_backbone")
@@ -606,102 +609,3 @@ def no_weight_decay(self) -> set:
             no_wd_list.append(global_parameter_name)
 
         return set(no_wd_list)
-
-
-@registry.register_model("equiformer_v2_energy_head")
-class EquiformerV2EnergyHead(nn.Module, HeadInterface):
-    def __init__(self, backbone, reduce: str = "sum"):
-        super().__init__()
-        self.reduce = reduce
-        self.avg_num_nodes = backbone.avg_num_nodes
-        self.energy_block = FeedForwardNetwork(
-            backbone.sphere_channels,
-            backbone.ffn_hidden_channels,
-            1,
-            backbone.lmax_list,
-            backbone.mmax_list,
-            backbone.SO3_grid,
-            backbone.ffn_activation,
-            backbone.use_gate_act,
-            backbone.use_grid_mlp,
-            backbone.use_sep_s2_act,
-        )
-        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
-
-    def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
-        node_energy = self.energy_block(emb["node_embedding"])
-        node_energy = node_energy.embedding.narrow(1, 0, 1)
-        if gp_utils.initialized():
-            node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0)
-        energy = torch.zeros(
-            len(data.natoms),
-            device=node_energy.device,
-            dtype=node_energy.dtype,
-        )
-
-        energy.index_add_(0, data.batch, node_energy.view(-1))
-        if self.reduce == "sum":
-            return {"energy": energy / self.avg_num_nodes}
-        elif self.reduce == "mean":
-            return {"energy": energy / data.natoms}
-        else:
-            raise ValueError(
-                f"reduce can only be sum or mean, user provided: {self.reduce}"
-            )
-
-
-@registry.register_model("equiformer_v2_force_head")
-class EquiformerV2ForceHead(nn.Module, HeadInterface):
-    def __init__(self, backbone):
-        super().__init__()
-
-        self.activation_checkpoint = backbone.activation_checkpoint
-        self.force_block = SO2EquivariantGraphAttention(
-            backbone.sphere_channels,
-            backbone.attn_hidden_channels,
-            backbone.num_heads,
-            backbone.attn_alpha_channels,
-            backbone.attn_value_channels,
-            1,
-            backbone.lmax_list,
-            backbone.mmax_list,
-            backbone.SO3_rotation,
-            backbone.mappingReduced,
-            backbone.SO3_grid,
-            backbone.max_num_elements,
-            backbone.edge_channels_list,
-            backbone.block_use_atom_edge_embedding,
-            backbone.use_m_share_rad,
-            backbone.attn_activation,
-            backbone.use_s2_act_attn,
-            backbone.use_attn_renorm,
-            backbone.use_gate_act,
-            backbone.use_sep_s2_act,
-            alpha_drop=0.0,
-        )
-        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
-
-    def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
-        if self.activation_checkpoint:
-            forces = torch.utils.checkpoint.checkpoint(
-                self.force_block,
-                emb["node_embedding"],
-                emb["graph"].atomic_numbers_full,
-                emb["graph"].edge_distance,
-                emb["graph"].edge_index,
-                emb["graph"].node_offset,
-                use_reentrant=not self.training,
-            )
-        else:
-            forces = self.force_block(
-                emb["node_embedding"],
-                emb["graph"].atomic_numbers_full,
-                emb["graph"].edge_distance,
-                emb["graph"].edge_index,
-                node_offset=emb["graph"].node_offset,
-            )
-        forces = forces.embedding.narrow(1, 1, 3)
-        forces = forces.view(-1, 3).contiguous()
-        if gp_utils.initialized():
-            forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
-        return {"forces": forces}
src/fairchem/core/models/equiformer_v2/heads/__init__.py (new file, +7 −0)

@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+from .rank2 import Rank2SymmetricTensorHead
+from .scalar import EqV2ScalarHead
+from .vector import EqV2VectorHead
+
+__all__ = ["EqV2ScalarHead", "EqV2VectorHead", "Rank2SymmetricTensorHead"]
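The new package gives downstream code a single import location for all three heads. A usage sketch (assumes fairchem is installed; the constructor calls are commented out because they need an already-built EquiformerV2 backbone):

```python
# Import paths introduced by this commit.
from fairchem.core.models.equiformer_v2.heads import (
    EqV2ScalarHead,
    EqV2VectorHead,
    Rank2SymmetricTensorHead,
)

# energy_head = EqV2ScalarHead(backbone, output_name="energy", reduce="sum")
# force_head = EqV2VectorHead(backbone, output_name="forces")
# stress_head = Rank2SymmetricTensorHead(backbone)  # output_name defaults to "stress"
```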

src/fairchem/core/models/equiformer_v2/prediction_heads/rank2.py → src/fairchem/core/models/equiformer_v2/heads/rank2.py (+2 −2)
@@ -16,8 +16,8 @@
 
 from fairchem.core.common.registry import registry
 from fairchem.core.models.base import BackboneInterface, HeadInterface
-from fairchem.core.models.equiformer_v2.equiformer_v2 import eqv2_init_weights
 from fairchem.core.models.equiformer_v2.layer_norm import get_normalization_layer
+from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights
 
 
 class Rank2Block(nn.Module):

@@ -238,7 +238,7 @@ class Rank2SymmetricTensorHead(nn.Module, HeadInterface):
     def __init__(
         self,
         backbone: BackboneInterface,
-        output_name: str,
+        output_name: str = "stress",
         decompose: bool = False,
         edge_level_mlp: bool = False,
         num_mlp_layers: int = 2,
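With `output_name` now an attribute on every head (defaulted to `"stress"` here), a multi-head "hydra" model can merge each head's output into one prediction dict without hard-coded keys. A toy illustration of that pattern; `ToyHead` is hypothetical, not a fairchem class:

```python
# Each head keys its own output, so merging is a plain dict update.
import torch
from torch import nn


class ToyHead(nn.Module):
    def __init__(self, dim: int, output_name: str):
        super().__init__()
        self.output_name = output_name
        self.proj = nn.Linear(dim, 1)

    def forward(self, emb: torch.Tensor) -> dict[str, torch.Tensor]:
        return {self.output_name: self.proj(emb)}


heads = nn.ModuleList([ToyHead(8, "energy"), ToyHead(8, "stress")])
emb = torch.randn(4, 8)

predictions: dict[str, torch.Tensor] = {}
for head in heads:
    predictions.update(head(emb))

print(sorted(predictions))  # ['energy', 'stress']
```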
src/fairchem/core/models/equiformer_v2/heads/scalar.py (new file, +66 −0)

@@ -0,0 +1,66 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import TYPE_CHECKING
+
+import torch
+from torch import nn
+
+from fairchem.core.common import gp_utils
+from fairchem.core.common.registry import registry
+from fairchem.core.models.base import GraphData, HeadInterface
+from fairchem.core.models.equiformer_v2.transformer_block import FeedForwardNetwork
+from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights
+
+if TYPE_CHECKING:
+    from torch_geometric.data import Batch
+
+
+@registry.register_model("equiformerV2_scalar_head")
+class EqV2ScalarHead(nn.Module, HeadInterface):
+    def __init__(self, backbone, output_name: str = "energy", reduce: str = "sum"):
+        super().__init__()
+        self.output_name = output_name
+        self.reduce = reduce
+        self.avg_num_nodes = backbone.avg_num_nodes
+        self.energy_block = FeedForwardNetwork(
+            backbone.sphere_channels,
+            backbone.ffn_hidden_channels,
+            1,
+            backbone.lmax_list,
+            backbone.mmax_list,
+            backbone.SO3_grid,
+            backbone.ffn_activation,
+            backbone.use_gate_act,
+            backbone.use_grid_mlp,
+            backbone.use_sep_s2_act,
+        )
+        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
+
+    def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
+        node_output = self.energy_block(emb["node_embedding"])
+        node_output = node_output.embedding.narrow(1, 0, 1)
+        if gp_utils.initialized():
+            node_output = gp_utils.gather_from_model_parallel_region(node_output, dim=0)
+        output = torch.zeros(
+            len(data.natoms),
+            device=node_output.device,
+            dtype=node_output.dtype,
+        )
+
+        output.index_add_(0, data.batch, node_output.view(-1))
+        if self.reduce == "sum":
+            return {self.output_name: output / self.avg_num_nodes}
+        elif self.reduce == "mean":
+            return {self.output_name: output / data.natoms}
+        else:
+            raise ValueError(
+                f"reduce can only be sum or mean, user provided: {self.reduce}"
+            )
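The scalar head pools per-node scalars into per-graph values with `index_add_`, then normalizes by either the dataset-average node count (`reduce="sum"`) or the true per-graph node count (`reduce="mean"`). A standalone pure-torch sketch of that reduction; `batch`, `natoms`, and `avg_num_nodes` are made-up stand-ins for the fields a torch_geometric `Batch` and the backbone statistics would provide:

```python
# Per-graph pooling as in EqV2ScalarHead.forward, using only torch.
import torch

node_output = torch.randn(5)           # one scalar per node
batch = torch.tensor([0, 0, 0, 1, 1])  # node -> graph assignment
natoms = torch.tensor([3, 2])          # nodes per graph
avg_num_nodes = 77.81317               # dataset statistic (cf. _AVG_NUM_NODES)

pooled = torch.zeros(len(natoms), dtype=node_output.dtype)
pooled.index_add_(0, batch, node_output)  # per-graph sums of node scalars

energy_sum = pooled / avg_num_nodes  # reduce="sum": sum scaled by a fixed constant
energy_mean = pooled / natoms        # reduce="mean": true per-graph mean
```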
src/fairchem/core/models/equiformer_v2/heads/vector.py (new file, +84 −0)

@@ -0,0 +1,84 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+from functools import partial
+from typing import TYPE_CHECKING
+
+import torch
+from torch import nn
+
+from fairchem.core.common import gp_utils
+from fairchem.core.common.registry import registry
+from fairchem.core.models.base import HeadInterface
+from fairchem.core.models.equiformer_v2.transformer_block import (
+    SO2EquivariantGraphAttention,
+)
+from fairchem.core.models.equiformer_v2.weight_initialization import eqv2_init_weights
+
+if TYPE_CHECKING:
+    from torch_geometric.data import Batch
+
+    from fairchem.core.models.base import BackboneInterface
+
+
+@registry.register_model("equiformerV2_vector_head")
+class EqV2VectorHead(nn.Module, HeadInterface):
+    def __init__(self, backbone: BackboneInterface, output_name: str = "forces"):
+        super().__init__()
+        self.output_name = output_name
+        self.activation_checkpoint = backbone.activation_checkpoint
+        self.force_block = SO2EquivariantGraphAttention(
+            backbone.sphere_channels,
+            backbone.attn_hidden_channels,
+            backbone.num_heads,
+            backbone.attn_alpha_channels,
+            backbone.attn_value_channels,
+            1,
+            backbone.lmax_list,
+            backbone.mmax_list,
+            backbone.SO3_rotation,
+            backbone.mappingReduced,
+            backbone.SO3_grid,
+            backbone.max_num_elements,
+            backbone.edge_channels_list,
+            backbone.block_use_atom_edge_embedding,
+            backbone.use_m_share_rad,
+            backbone.attn_activation,
+            backbone.use_s2_act_attn,
+            backbone.use_attn_renorm,
+            backbone.use_gate_act,
+            backbone.use_sep_s2_act,
+            alpha_drop=0.0,
+        )
+        self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init))
+
+    def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
+        if self.activation_checkpoint:
+            output = torch.utils.checkpoint.checkpoint(
+                self.force_block,
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                emb["graph"].node_offset,
+                use_reentrant=not self.training,
+            )
+        else:
+            output = self.force_block(
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                node_offset=emb["graph"].node_offset,
+            )
+        output = output.embedding.narrow(1, 1, 3)
+        output = output.view(-1, 3).contiguous()
+        if gp_utils.initialized():
+            output = gp_utils.gather_from_model_parallel_region(output, dim=0)
+        return {self.output_name: output}
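When `activation_checkpoint` is set on the backbone, the vector head recomputes the attention block's activations during backward instead of storing them, trading compute for memory. A minimal runnable sketch of that branch, with a toy block standing in for `SO2EquivariantGraphAttention`:

```python
# Activation checkpointing as used in EqV2VectorHead.forward.
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
x = torch.randn(4, 16, requires_grad=True)

# Recompute the block's activations during backward rather than caching them;
# mirrors the head's use_reentrant=not self.training (non-reentrant in training).
training = True
y = checkpoint(block, x, use_reentrant=not training)
y.sum().backward()
```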
