-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #156 from laserkelvin/upstream-mace
Upstream MACE implementation
- Loading branch information
Showing
7 changed files
with
471 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from __future__ import annotations | ||
|
||
import pytorch_lightning as pl | ||
from torch import nn | ||
from e3nn.o3 import Irreps | ||
from mace.modules.blocks import RealAgnosticInteractionBlock | ||
|
||
from matsciml.datasets.transforms import ( | ||
PointCloudToGraphTransform, | ||
PeriodicPropertiesTransform, | ||
) | ||
from matsciml.lightning.data_utils import MatSciMLDataModule | ||
from matsciml.models.base import ScalarRegressionTask | ||
from matsciml.models.pyg.mace import MACEWrapper | ||
|
||
|
||
"""
This example script runs through a fast development run of the IS2RE devset
in combination with the PyG-based MACE wrapper (``MACEWrapper``).
"""

# Both the encoder's ``r_max`` and the ``PeriodicPropertiesTransform`` below
# were hard-coded to the same value; keep it in one place so they cannot drift.
CUTOFF_RADIUS = 6.0

# construct IS2RE relaxed energy regression with the MACE encoder
task = ScalarRegressionTask(
    encoder_class=MACEWrapper,
    encoder_kwargs={
        "r_max": CUTOFF_RADIUS,
        "num_bessel": 3,
        "num_polynomial_cutoff": 3,
        "max_ell": 2,
        "interaction_cls": RealAgnosticInteractionBlock,
        "interaction_cls_first": RealAgnosticInteractionBlock,
        "num_interactions": 2,
        "atom_embedding_dim": 64,
        "MLP_irreps": Irreps("256x0e"),
        "avg_num_neighbors": 10.0,
        "correlation": 1,
        "radial_type": "bessel",
        "gate": nn.Identity(),
    },
    task_keys=["energy_relaxed"],
)
# matsciml devset for OCP are serialized with DGL - this transform goes between the two frameworks
dm = MatSciMLDataModule.from_devset(
    "IS2REDataset",
    dset_kwargs={
        "transforms": [
            PeriodicPropertiesTransform(CUTOFF_RADIUS, adaptive_cutoff=True),
            PointCloudToGraphTransform(
                "pyg",
                node_keys=["pos", "atomic_numbers"],
            ),
        ],
    },
)

# run a quick training loop
trainer = pl.Trainer(fast_dev_run=10)
trainer.fit(task, datamodule=dm)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: MIT License | ||
from __future__ import annotations | ||
|
||
from inspect import signature | ||
from typing import Callable, Type | ||
|
||
from torch import nn | ||
|
||
""" | ||
Simple utility functions for inspecting functions and objects. | ||
The idea is that this should provide reusable functions that | ||
are useful for mapping kwargs onto classes that belong outside | ||
of this library, where we might not know what is required or not. | ||
""" | ||
|
||
|
||
def get_args_without_defaults(func: Callable, exclude_self: bool = True) -> list[str]:
    """
    Inspect a function for required input arguments.

    An argument counts as required when its signature entry carries no
    default value. Variadic ``*args`` / ``**kwargs`` entries are skipped,
    because they never have a default yet a caller is free to omit them.
    The option ``exclude_self`` is also available to specify whether or
    not to remove ``self``/``cls`` entries from the list, since these may
    correspond to (class) methods.

    Parameters
    ----------
    func : Callable
        A callable function with some input arguments.
    exclude_self
        If True, ignores ``self`` and ``cls`` from the returned
        list.

    Returns
    -------
    list[str]
        List of argument names that are required by ``func``, in
        signature order.
    """
    required = []
    for arg_name, parameter in signature(func).parameters.items():
        # ``*args`` and ``**kwargs`` have no default but are optional
        if parameter.kind in (parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD):
            continue
        # ``empty`` is the sentinel ``inspect`` uses for "no default"; it has
        # to be compared by identity, because a truthiness test also matches
        # every truthy default (e.g. ``x=100``) and misses nothing useful
        if parameter.default is parameter.empty:
            required.append(arg_name)
    # remove self/cls from list if requested
    if exclude_self:
        required = [name for name in required if name not in ("self", "cls")]
    return required
|
||
|
||
def get_all_args(func: Callable) -> list[str]:
    """
    List every argument name in a function's signature.

    Arguments with and without defaults are both included, in the order
    they appear in the signature.

    Parameters
    ----------
    func : Callable
        Function to inspect arguments for.

    Returns
    -------
    list[str]
        List of argument names, positional and keyword.
    """
    return [arg_name for arg_name in signature(func).parameters]
|
||
|
||
def get_model_required_args(model: Type[nn.Module]) -> list[str]:
    """
    Report which arguments a PyTorch ``nn.Module`` subclass needs to be built.

    This inspects ``model.__init__`` through ``get_args_without_defaults``
    with ``self``/``cls`` filtered out. It is intended for deciding how
    args/kwargs should be unpacked into a model from an abstract interface.

    Parameters
    ----------
    model : Type[nn.Module]
        A model class to inspect

    Returns
    -------
    list[str]
        List of argument names that are required to instantiate ``model``
    """
    constructor = model.__init__
    return get_args_without_defaults(constructor, exclude_self=True)
|
||
|
||
def get_model_all_args(model: Type[nn.Module]) -> list[str]:
    """
    Collect every argument name accepted by a model's ``__init__``.

    Optional arguments are included alongside required ones. The names
    come straight from the ``__init__`` signature via ``get_all_args``.

    Parameters
    ----------
    model : Type[nn.Module]
        Model class to inspect for arguments.

    Returns
    -------
    list[str]
        List of arguments used by the model instantiation.
    """
    constructor = model.__init__
    return get_all_args(constructor)
|
||
|
||
def get_model_forward_args(model: Type[nn.Module]) -> list[str]:
    """
    Collect the argument names of a model's ``forward`` method.

    The names come straight from the ``forward`` signature via
    ``get_all_args``.

    Parameters
    ----------
    model : Type[nn.Module]
        Model to inspect.

    Returns
    -------
    list[str]
        List of argument names in a model's forward method.
    """
    forward_fn = model.forward
    return get_all_args(forward_fn)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
from matsciml.models.pyg.mace.modules.models import MACE, ScaleShiftMACE | ||
from matsciml.models.pyg.mace.wrapper.model import MACEWrapper | ||
|
||
__all__ = [ | ||
"MACE", | ||
"ScaleShiftMACE", | ||
] | ||
__all__ = ["MACE", "ScaleShiftMACE", "MACEWrapper"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
from __future__ import annotations | ||
|
||
from logging import getLogger | ||
from typing import Any, Callable | ||
from functools import cache | ||
|
||
import torch | ||
import numpy as np | ||
from e3nn.o3 import Irreps | ||
from mace.modules import MACE | ||
from torch_geometric.nn import pool | ||
|
||
from matsciml.models.base import AbstractPyGModel | ||
from matsciml.common.types import BatchDict, DataDict, AbstractGraph, Embeddings | ||
from matsciml.common.registry import registry | ||
from matsciml.common.inspection import get_model_required_args, get_model_all_args | ||
|
||
|
||
# Name the logger after the module's dotted import path rather than its
# filesystem path, so it sits in the package's logger hierarchy and can be
# configured through its parent loggers.
logger = getLogger(__name__)
|
||
__all__ = ["MACEWrapper"] | ||
|
||
|
||
@registry.register_model("MACEWrapper")
class MACEWrapper(AbstractPyGModel):
    """
    Adapter that exposes the upstream ``mace.modules.MACE`` model through
    the ``AbstractPyGModel`` interface.

    The wrapper translates the constructor arguments used in this library
    (``atom_embedding_dim``, ``num_atom_embedding``) into the ones MACE
    expects (``hidden_irreps``, ``num_elements``), one-hot encodes atomic
    numbers for MACE's ``node_attrs`` input, and pools MACE's node features
    into graph-level embeddings.
    """

    def __init__(
        self,
        atom_embedding_dim: int,
        num_atom_embedding: int = 100,
        embedding_kwargs: Any = None,
        encoder_only: bool = True,
        readout_method: str | Callable = "add",
        **mace_kwargs,
    ) -> None:
        """
        Parameters
        ----------
        atom_embedding_dim : int
            Number of scalar (``0e``) channels; converted to the MACE
            ``hidden_irreps`` as ``"{atom_embedding_dim}x0e"``.
        num_atom_embedding : int
            Width of the one-hot atom encoding; passed to MACE as
            ``num_elements``.
        embedding_kwargs : Any
            Unused for MACE; a warning is logged if it is not ``None``.
        encoder_only : bool
            Forwarded to the parent class constructor.
        readout_method : str | Callable
            Either a callable taking ``(node_embeddings, batch)``, or a
            string ``name`` resolved to ``global_{name}_pool`` in
            ``torch_geometric.nn.pool``.
        **mace_kwargs
            Additional keyword arguments passed to ``MACE``.

        Raises
        ------
        AssertionError
            If a key in ``mace_kwargs`` is not an argument of ``MACE``.
        KeyError
            If ``num_elements`` or ``hidden_irreps`` is passed directly, or
            if an argument reported as required by ``MACE`` is missing.
        NotImplementedError
            If ``readout_method`` is a string with no matching PyG pooling
            function.
        """
        if embedding_kwargs is not None:
            logger.warning("`embedding_kwargs` is not used for MACE models.")
        super().__init__(atom_embedding_dim, num_atom_embedding, {}, encoder_only)
        # dynamically check which arguments are needed by MACE
        mace_required_args = get_model_required_args(MACE)
        mace_all_args = get_model_all_args(MACE)
        for key in mace_kwargs:
            assert (
                key in mace_all_args
            ), f"{key} was passed as a MACE kwarg but does not match expected arguments."
        # remove the embedding table, as MACE uses e3nn layers
        del self.atom_embedding
        if "num_elements" in mace_kwargs:
            raise KeyError(
                "Please use `num_atom_embedding` instead of passing `num_elements`."
            )
        if "hidden_irreps" in mace_kwargs:
            raise KeyError(
                "Please use `atom_embedding_dim` instead of passing `hidden_irreps`."
            )
        # Derived values go into *new* names: ``save_hyperparameters`` below
        # reads the current values of the ``__init__`` arguments from this
        # frame, so rebinding or mutating them would store values that cannot
        # be passed back into ``__init__`` (e.g. ``num_elements`` is rejected
        # above, and an ``Irreps`` is not a valid ``atom_embedding_dim``).
        encoder_kwargs = dict(mace_kwargs)
        encoder_kwargs["num_elements"] = num_atom_embedding
        encoder_kwargs["hidden_irreps"] = Irreps(f"{atom_embedding_dim}x0e")
        # NOTE(review): this list has ``num_atom_embedding - 1`` entries while
        # ``num_elements`` and the default ``atomic_energies`` below use
        # ``num_atom_embedding`` - confirm the off-by-one is intended.
        encoder_kwargs["atomic_numbers"] = list(range(1, num_atom_embedding))
        if "atomic_energies" not in encoder_kwargs:
            logger.warning("No ``atomic_energies`` provided, defaulting to ones.")
            encoder_kwargs["atomic_energies"] = np.ones(num_atom_embedding)
        # check to make sure everything MACE requires has been provided
        for key in mace_required_args:
            if key not in encoder_kwargs:
                raise KeyError(
                    f"{key} is required by MACE, but was not found in kwargs."
                )
        self.encoder = MACE(**encoder_kwargs)
        # if a string is passed, grab the PyG builtins
        readout_fn = readout_method
        if isinstance(readout_method, str):
            readout_fn = getattr(pool, f"global_{readout_method}_pool", None)
            if not readout_fn:
                possible_methods = list(filter(lambda x: "global" in x, dir(pool)))
                raise NotImplementedError(
                    f"{readout_method} is not a valid function in PyG pooling."
                    f" Supported methods are: {possible_methods}"
                )
        self.readout = readout_fn
        self.save_hyperparameters()

    @property
    def _atom_eye(self) -> torch.Tensor:
        """
        Identity matrix of size ``num_atom_embedding`` used as a one-hot
        lookup table.

        The table is rebuilt on every access instead of being memoized, so
        it always follows the module's current ``device`` and ``dtype``.
        """
        return torch.eye(
            self.hparams.num_atom_embedding, device=self.device, dtype=self.dtype
        )

    def atomic_numbers_to_one_hot(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
        """
        Convert discrete atomic numbers into one-hot vectors based
        on some maximum number of elements possible.

        Parameters
        ----------
        atomic_numbers : torch.Tensor
            1D tensor of integers corresponding to atomic numbers.

        Returns
        -------
        torch.Tensor
            2D tensor of one-hot vectors for each node.
        """
        indices = atomic_numbers.long()
        # keep the lookup table on the same device as the indices it is
        # gathered with
        lookup = self._atom_eye.to(indices.device)
        return lookup[indices]

    def read_batch(self, batch: BatchDict) -> DataDict:
        """
        Extract the inputs needed by ``_forward`` from a batch.

        Parameters
        ----------
        batch : BatchDict
            Batch containing a PyG graph under ``"graph"`` as well as the
            periodic properties ``"cell"`` and ``"offsets"``.

        Returns
        -------
        DataDict
            Dictionary with keys ``graph``, ``pos``, ``node_feats``,
            ``cell`` and ``shifts``.

        Raises
        ------
        KeyError
            If ``"cell"`` or ``"offsets"`` is missing from ``batch``.
        AssertionError
            If the graph has no ``ptr`` attribute.
        """
        # expect a PyG graph already
        graph = batch["graph"]
        one_hot_atoms = self.atomic_numbers_to_one_hot(graph.atomic_numbers)
        # check to make sure we have unit cell shifts
        for key in ["cell", "offsets"]:
            if key not in batch:
                raise KeyError(
                    f"Expected periodic property {key} to be in batch."
                    " Please include ``PeriodicPropertiesTransform``."
                )
        assert hasattr(graph, "ptr"), "Graph is missing the `ptr` attribute!"
        # the name of these keys matches up with our `_forward`, and
        # later get remapped to MACE ones
        return {
            "graph": graph,
            "pos": graph.pos,
            "node_feats": one_hot_atoms,
            "cell": batch["cell"],
            "shifts": batch["offsets"],
        }

    def _forward(
        self,
        graph: AbstractGraph,
        node_feats: torch.Tensor,
        pos: torch.Tensor,
        **kwargs,
    ) -> Embeddings:
        """
        Takes arguments in the standardized format, and passes them into MACE
        with some redundant mapping.

        Parameters
        ----------
        graph : AbstractGraph
            Graph structure containing node and graph properties
        node_feats : torch.Tensor
            Tensor containing one-hot node features, shape ``[num_nodes, num_elements]``
        pos : torch.Tensor
            2D tensor containing node positions, shape ``[num_nodes, 3]``
        **kwargs
            Must contain ``cell`` and ``shifts``, as produced by
            ``read_batch``.

        Returns
        -------
        Embeddings
            MatSciML ``Embeddings`` structure
        """
        # repack data into MACE format
        mace_data = {
            "positions": pos,
            "node_attrs": node_feats,
            "ptr": graph.ptr,
            "cell": kwargs["cell"],
            "shifts": kwargs["shifts"],
            "batch": graph.batch,
            "edge_index": graph.edge_index,
        }
        # only the node features are consumed here, so every derivative
        # quantity is switched off
        outputs = self.encoder(
            mace_data,
            training=self.training,
            compute_force=False,
            compute_virials=False,
            compute_stress=False,
            compute_displacement=False,
        )
        node_embeddings = outputs["node_feats"]
        graph_embeddings = self.readout(node_embeddings, graph.batch)
        return Embeddings(graph_embeddings, node_embeddings)
Oops, something went wrong.