|
| 1 | +""" |
| 2 | +Copyright (c) Meta, Inc. and its affiliates. |
| 3 | +
|
| 4 | +This source code is licensed under the MIT license found in the |
| 5 | +LICENSE file in the root directory of this source tree. |
| 6 | +""" |
| 7 | + |
1 | 8 | from __future__ import annotations
|
2 | 9 |
|
3 | 10 | import contextlib
|
4 | 11 | import logging
|
5 |
| -import math |
| 12 | +import typing |
6 | 13 | from functools import partial
|
7 | 14 |
|
8 | 15 | import torch
|
9 | 16 | import torch.nn as nn
|
| 17 | +from typing_extensions import deprecated |
10 | 18 |
|
11 | 19 | from fairchem.core.common import gp_utils
|
12 | 20 | from fairchem.core.common.registry import registry
|
13 | 21 | from fairchem.core.common.utils import conditional_grad
|
14 | 22 | from fairchem.core.models.base import (
|
15 | 23 | GraphModelMixin,
|
16 |
| - HeadInterface, |
17 | 24 | )
|
| 25 | +from fairchem.core.models.equiformer_v2.heads import EqV2ScalarHead, EqV2VectorHead |
18 | 26 | from fairchem.core.models.scn.smearing import GaussianSmearing
|
19 | 27 |
|
20 |
| -with contextlib.suppress(ImportError): |
21 |
| - pass |
22 |
| - |
23 |
| - |
24 |
| -import typing |
25 |
| - |
26 | 28 | from .edge_rot_mat import init_edge_rot_mat
|
27 | 29 | from .gaussian_rbf import GaussianRadialBasisLayer
|
28 | 30 | from .input_block import EdgeDegreeEmbedding
|
|
34 | 36 | get_normalization_layer,
|
35 | 37 | )
|
36 | 38 | from .module_list import ModuleListInfo
|
37 |
| -from .radial_function import RadialFunction |
38 | 39 | from .so3 import (
|
39 | 40 | CoefficientMappingModule,
|
40 | 41 | SO3_Embedding,
|
|
43 | 44 | SO3_Rotation,
|
44 | 45 | )
|
45 | 46 | from .transformer_block import (
|
46 |
| - FeedForwardNetwork, |
47 |
| - SO2EquivariantGraphAttention, |
48 | 47 | TransBlockV2,
|
49 | 48 | )
|
| 49 | +from .weight_initialization import eqv2_init_weights |
| 50 | + |
| 51 | +with contextlib.suppress(ImportError): |
| 52 | + pass |
50 | 53 |
|
51 | 54 | if typing.TYPE_CHECKING:
|
52 | 55 | from torch_geometric.data.batch import Batch
|
53 | 56 |
|
54 |
| - from fairchem.core.models.base import GraphData |
55 |
| - |
56 | 57 | # Statistics of IS2RE 100K
|
57 | 58 | _AVG_NUM_NODES = 77.81317
|
58 | 59 | _AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100
|
59 | 60 |
|
60 | 61 |
|
61 |
| -def eqv2_init_weights(m, weight_init): |
62 |
| - if isinstance(m, (torch.nn.Linear, SO3_LinearV2)): |
63 |
| - if m.bias is not None: |
64 |
| - torch.nn.init.constant_(m.bias, 0) |
65 |
| - if weight_init == "normal": |
66 |
| - std = 1 / math.sqrt(m.in_features) |
67 |
| - torch.nn.init.normal_(m.weight, 0, std) |
68 |
| - elif isinstance(m, torch.nn.LayerNorm): |
69 |
| - torch.nn.init.constant_(m.bias, 0) |
70 |
| - torch.nn.init.constant_(m.weight, 1.0) |
71 |
| - elif isinstance(m, RadialFunction): |
72 |
| - m.apply(eqv2_uniform_init_linear_weights) |
| 62 | +@deprecated( |
| 63 | + "equiformer_v2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_rank1_head (EqV2Rank1Head)" |
| 64 | +) |
| 65 | +@registry.register_model("equiformer_v2_force_head") |
| 66 | +class EquiformerV2ForceHead(EqV2VectorHead): |
| 67 | + def __init__(self, backbone): |
| 68 | + logging.warning( |
| 69 | + "equiformerV2_force_head (EquiformerV2ForceHead) class is deprecated in favor of equiformerV2_rank1_head (EqV2Rank1Head)" |
| 70 | + ) |
| 71 | + super().__init__(backbone) |
73 | 72 |
|
74 | 73 |
|
75 |
| -def eqv2_uniform_init_linear_weights(m): |
76 |
| - if isinstance(m, torch.nn.Linear): |
77 |
| - if m.bias is not None: |
78 |
| - torch.nn.init.constant_(m.bias, 0) |
79 |
| - std = 1 / math.sqrt(m.in_features) |
80 |
| - torch.nn.init.uniform_(m.weight, -std, std) |
| 74 | +@deprecated( |
| 75 | + "equiformer_v2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)" |
| 76 | +) |
| 77 | +@registry.register_model("equiformer_v2_energy_head") |
| 78 | +class EquiformerV2EnergyHead(EqV2ScalarHead): |
| 79 | + def __init__(self, backbone, reduce: str = "sum"): |
| 80 | + logging.warning( |
| 81 | + "equiformerV2_energy_head (EquiformerV2EnergyHead) class is deprecated in favor of equiformerV2_scalar_head (EqV2ScalarHead)" |
| 82 | + ) |
| 83 | + super().__init__(backbone, reduce=reduce) |
81 | 84 |
|
82 | 85 |
|
83 | 86 | @registry.register_model("equiformer_v2_backbone")
|
@@ -606,102 +609,3 @@ def no_weight_decay(self) -> set:
|
606 | 609 | no_wd_list.append(global_parameter_name)
|
607 | 610 |
|
608 | 611 | return set(no_wd_list)
|
609 |
| - |
610 |
| - |
611 |
| -@registry.register_model("equiformer_v2_energy_head") |
612 |
| -class EquiformerV2EnergyHead(nn.Module, HeadInterface): |
613 |
| - def __init__(self, backbone, reduce: str = "sum"): |
614 |
| - super().__init__() |
615 |
| - self.reduce = reduce |
616 |
| - self.avg_num_nodes = backbone.avg_num_nodes |
617 |
| - self.energy_block = FeedForwardNetwork( |
618 |
| - backbone.sphere_channels, |
619 |
| - backbone.ffn_hidden_channels, |
620 |
| - 1, |
621 |
| - backbone.lmax_list, |
622 |
| - backbone.mmax_list, |
623 |
| - backbone.SO3_grid, |
624 |
| - backbone.ffn_activation, |
625 |
| - backbone.use_gate_act, |
626 |
| - backbone.use_grid_mlp, |
627 |
| - backbone.use_sep_s2_act, |
628 |
| - ) |
629 |
| - self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) |
630 |
| - |
631 |
| - def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): |
632 |
| - node_energy = self.energy_block(emb["node_embedding"]) |
633 |
| - node_energy = node_energy.embedding.narrow(1, 0, 1) |
634 |
| - if gp_utils.initialized(): |
635 |
| - node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0) |
636 |
| - energy = torch.zeros( |
637 |
| - len(data.natoms), |
638 |
| - device=node_energy.device, |
639 |
| - dtype=node_energy.dtype, |
640 |
| - ) |
641 |
| - |
642 |
| - energy.index_add_(0, data.batch, node_energy.view(-1)) |
643 |
| - if self.reduce == "sum": |
644 |
| - return {"energy": energy / self.avg_num_nodes} |
645 |
| - elif self.reduce == "mean": |
646 |
| - return {"energy": energy / data.natoms} |
647 |
| - else: |
648 |
| - raise ValueError( |
649 |
| - f"reduce can only be sum or mean, user provided: {self.reduce}" |
650 |
| - ) |
651 |
| - |
652 |
| - |
653 |
| -@registry.register_model("equiformer_v2_force_head") |
654 |
| -class EquiformerV2ForceHead(nn.Module, HeadInterface): |
655 |
| - def __init__(self, backbone): |
656 |
| - super().__init__() |
657 |
| - |
658 |
| - self.activation_checkpoint = backbone.activation_checkpoint |
659 |
| - self.force_block = SO2EquivariantGraphAttention( |
660 |
| - backbone.sphere_channels, |
661 |
| - backbone.attn_hidden_channels, |
662 |
| - backbone.num_heads, |
663 |
| - backbone.attn_alpha_channels, |
664 |
| - backbone.attn_value_channels, |
665 |
| - 1, |
666 |
| - backbone.lmax_list, |
667 |
| - backbone.mmax_list, |
668 |
| - backbone.SO3_rotation, |
669 |
| - backbone.mappingReduced, |
670 |
| - backbone.SO3_grid, |
671 |
| - backbone.max_num_elements, |
672 |
| - backbone.edge_channels_list, |
673 |
| - backbone.block_use_atom_edge_embedding, |
674 |
| - backbone.use_m_share_rad, |
675 |
| - backbone.attn_activation, |
676 |
| - backbone.use_s2_act_attn, |
677 |
| - backbone.use_attn_renorm, |
678 |
| - backbone.use_gate_act, |
679 |
| - backbone.use_sep_s2_act, |
680 |
| - alpha_drop=0.0, |
681 |
| - ) |
682 |
| - self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) |
683 |
| - |
684 |
| - def forward(self, data: Batch, emb: dict[str, torch.Tensor]): |
685 |
| - if self.activation_checkpoint: |
686 |
| - forces = torch.utils.checkpoint.checkpoint( |
687 |
| - self.force_block, |
688 |
| - emb["node_embedding"], |
689 |
| - emb["graph"].atomic_numbers_full, |
690 |
| - emb["graph"].edge_distance, |
691 |
| - emb["graph"].edge_index, |
692 |
| - emb["graph"].node_offset, |
693 |
| - use_reentrant=not self.training, |
694 |
| - ) |
695 |
| - else: |
696 |
| - forces = self.force_block( |
697 |
| - emb["node_embedding"], |
698 |
| - emb["graph"].atomic_numbers_full, |
699 |
| - emb["graph"].edge_distance, |
700 |
| - emb["graph"].edge_index, |
701 |
| - node_offset=emb["graph"].node_offset, |
702 |
| - ) |
703 |
| - forces = forces.embedding.narrow(1, 1, 3) |
704 |
| - forces = forces.view(-1, 3).contiguous() |
705 |
| - if gp_utils.initialized(): |
706 |
| - forces = gp_utils.gather_from_model_parallel_region(forces, dim=0) |
707 |
| - return {"forces": forces} |
0 commit comments