Merge pull request #156 from laserkelvin/upstream-mace
Upstream MACE implementation
smiret-intel authored Mar 18, 2024
2 parents 4b74259 + 96a99f8 commit 807532e
Showing 7 changed files with 471 additions and 5 deletions.
58 changes: 58 additions & 0 deletions examples/model_demos/mace_wrapper.py
@@ -0,0 +1,58 @@
from __future__ import annotations

import pytorch_lightning as pl
from torch import nn
from e3nn.o3 import Irreps
from mace.modules.blocks import RealAgnosticInteractionBlock

from matsciml.datasets.transforms import (
    PointCloudToGraphTransform,
    PeriodicPropertiesTransform,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models.base import ScalarRegressionTask
from matsciml.models.pyg.mace import MACEWrapper


"""
This example script runs through a fast development run of the IS2RE devset
in combination with a PyG implementation of EGNN.
"""

# construct IS2RE relaxed energy regression with the PyG implementation of MACE
task = ScalarRegressionTask(
    encoder_class=MACEWrapper,
    encoder_kwargs={
        "r_max": 6.0,
        "num_bessel": 3,
        "num_polynomial_cutoff": 3,
        "max_ell": 2,
        "interaction_cls": RealAgnosticInteractionBlock,
        "interaction_cls_first": RealAgnosticInteractionBlock,
        "num_interactions": 2,
        "atom_embedding_dim": 64,
        "MLP_irreps": Irreps("256x0e"),
        "avg_num_neighbors": 10.0,
        "correlation": 1,
        "radial_type": "bessel",
        "gate": nn.Identity(),
    },
    task_keys=["energy_relaxed"],
)
# matsciml devsets for OCP are serialized with DGL; the graph transform below
# converts the data into PyG graphs for this wrapper
dm = MatSciMLDataModule.from_devset(
    "IS2REDataset",
    dset_kwargs={
        "transforms": [
            PeriodicPropertiesTransform(6.0, adaptive_cutoff=True),
            PointCloudToGraphTransform(
                "pyg",
                node_keys=["pos", "atomic_numbers"],
            ),
        ],
    },
)

# run a quick training loop
trainer = pl.Trainer(fast_dev_run=10)
trainer.fit(task, datamodule=dm)
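For a full training run instead of a smoke test, ``fast_dev_run`` can be replaced with explicit limits; a minimal sketch (the ``max_epochs`` value is illustrative, not part of this PR):

trainer = pl.Trainer(max_epochs=10)
trainer.fit(task, datamodule=dm)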
123 changes: 123 additions & 0 deletions matsciml/common/inspection.py
@@ -0,0 +1,123 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT
from __future__ import annotations

from inspect import Parameter, signature
from typing import Callable, Type

from torch import nn

"""
Simple utility functions for inspecting functions and objects.
The idea is that this should provide reusable functions that
are useful for mapping kwargs onto classes that belong outside
of this library, where we might not know what is required or not.
"""


def get_args_without_defaults(func: Callable, exclude_self: bool = True) -> list[str]:
    """
    Inspect a function for required positional input arguments.

    The function works by looping through arguments and checking
    whether defaults are available. The option ``exclude_self`` is
    also available to specify whether or not to remove ``self`` and
    ``cls`` entries from the list, since these correspond to methods.

    Parameters
    ----------
    func : Callable
        A callable function with some input arguments.
    exclude_self : bool
        If True, omits ``self`` and ``cls`` from the returned list.

    Returns
    -------
    list[str]
        List of argument names that are required by ``func``.
    """
    parameters = signature(func).parameters
    matches = []
    for arg_name, parameter in parameters.items():
        # arguments without defaults report ``Parameter.empty``; a plain
        # truthiness check on ``default`` would misclassify truthy defaults
        if parameter.default is Parameter.empty:
            matches.append(arg_name)
    # remove self/cls from the list if requested
    if exclude_self:
        matches = list(filter(lambda x: x not in ["self", "cls"], matches))
    return matches


def get_all_args(func: Callable) -> list[str]:
    """
    Get all of the arguments of a function, including those with defaults.

    Parameters
    ----------
    func : Callable
        Function to inspect arguments for.

    Returns
    -------
    list[str]
        List of argument names, positional and keyword.
    """
    parameters = signature(func).parameters
    return list(parameters.keys())


def get_model_required_args(model: Type[nn.Module]) -> list[str]:
    """
    Inspect a child of PyTorch ``nn.Module`` for required arguments.

    The idea behind this is to identify which parameters are needed to
    instantiate a model, which is useful for determining how args/kwargs
    should be unpacked into a model from an abstract interface.

    Parameters
    ----------
    model : Type[nn.Module]
        A model class to inspect.

    Returns
    -------
    list[str]
        List of argument names that are required to instantiate ``model``.
    """
    return get_args_without_defaults(model.__init__, exclude_self=True)


def get_model_all_args(model: Type[nn.Module]) -> list[str]:
    """
    Inspect a model for all of its initialization arguments, including
    optional ones.

    Parameters
    ----------
    model : Type[nn.Module]
        Model class to inspect for arguments.

    Returns
    -------
    list[str]
        List of arguments used by the model instantiation.
    """
    return get_all_args(model.__init__)


def get_model_forward_args(model: Type[nn.Module]) -> list[str]:
    """
    Inspect a model's ``forward`` method for argument names.

    Parameters
    ----------
    model : Type[nn.Module]
        Model to inspect.

    Returns
    -------
    list[str]
        List of argument names in a model's forward method.
    """
    return get_all_args(model.forward)
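As a quick illustration of how these helpers behave (a sketch for demonstration only; ``ToyModel`` is a hypothetical class, not part of this library):

from torch import nn

class ToyModel(nn.Module):
    def __init__(self, hidden_dim: int, num_layers: int = 2):
        super().__init__()
        self.layers = nn.Sequential(
            *[nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)]
        )

    def forward(self, x):
        return self.layers(x)

get_model_required_args(ToyModel)  # ["hidden_dim"]; num_layers has a default
get_model_all_args(ToyModel)       # ["self", "hidden_dim", "num_layers"]
get_model_forward_args(ToyModel)   # ["self", "x"]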
6 changes: 2 additions & 4 deletions matsciml/models/pyg/mace/__init__.py
@@ -1,8 +1,6 @@
 from __future__ import annotations
 
 from matsciml.models.pyg.mace.modules.models import MACE, ScaleShiftMACE
+from matsciml.models.pyg.mace.wrapper.model import MACEWrapper
 
-__all__ = [
-    "MACE",
-    "ScaleShiftMACE",
-]
+__all__ = ["MACE", "ScaleShiftMACE", "MACEWrapper"]
Empty file.
181 changes: 181 additions & 0 deletions matsciml/models/pyg/mace/wrapper/model.py
@@ -0,0 +1,181 @@
from __future__ import annotations

from logging import getLogger
from typing import Any, Callable
from functools import cache

import torch
import numpy as np
from e3nn.o3 import Irreps
from mace.modules import MACE
from torch_geometric.nn import pool

from matsciml.models.base import AbstractPyGModel
from matsciml.common.types import BatchDict, DataDict, AbstractGraph, Embeddings
from matsciml.common.registry import registry
from matsciml.common.inspection import get_model_required_args, get_model_all_args


logger = getLogger(__file__)

__all__ = ["MACEWrapper"]


@registry.register_model("MACEWrapper")
class MACEWrapper(AbstractPyGModel):
    def __init__(
        self,
        atom_embedding_dim: int,
        num_atom_embedding: int = 100,
        embedding_kwargs: Any = None,
        encoder_only: bool = True,
        readout_method: str | Callable = "add",
        **mace_kwargs,
    ) -> None:
        if embedding_kwargs is not None:
            logger.warning("`embedding_kwargs` is not used for MACE models.")
        super().__init__(atom_embedding_dim, num_atom_embedding, {}, encoder_only)
        # dynamically check which arguments are needed by MACE
        __mace_required_args = get_model_required_args(MACE)
        __mace_all_args = get_model_all_args(MACE)
        for key in mace_kwargs:
            assert (
                key in __mace_all_args
            ), f"{key} was passed as a MACE kwarg but does not match expected arguments."
        # remove the embedding table, as MACE uses e3nn layers
        del self.atom_embedding
        if "num_elements" in mace_kwargs:
            raise KeyError(
                "Please use `num_atom_embedding` instead of passing `num_elements`."
            )
        if "hidden_irreps" in mace_kwargs:
            raise KeyError(
                "Please use `atom_embedding_dim` instead of passing `hidden_irreps`."
            )
        atom_embedding_dim = Irreps(f"{atom_embedding_dim}x0e")
        # pack the remapped settings into the MACE kwargs
        mace_kwargs["num_elements"] = num_atom_embedding
        mace_kwargs["hidden_irreps"] = atom_embedding_dim
        mace_kwargs["atomic_numbers"] = list(range(1, num_atom_embedding))
        if "atomic_energies" not in mace_kwargs:
            logger.warning("No ``atomic_energies`` provided, defaulting to ones.")
            mace_kwargs["atomic_energies"] = np.ones(num_atom_embedding)
        # check to make sure everything required by MACE is present
        for key in __mace_required_args:
            if key not in mace_kwargs:
                raise KeyError(
                    f"{key} is required by MACE, but was not found in kwargs."
                )
        self.encoder = MACE(**mace_kwargs)
        # if a string is passed, grab the PyG builtins
        if isinstance(readout_method, str):
            readout_type = getattr(pool, f"global_{readout_method}_pool", None)
            if not readout_type:
                possible_methods = list(filter(lambda x: "global" in x, dir(pool)))
                raise NotImplementedError(
                    f"{readout_method} is not a valid function in PyG pooling."
                    f" Supported methods are: {possible_methods}"
                )
            readout_method = readout_type
        self.readout = readout_method
        self.save_hyperparameters()
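        # A brief illustration of the readout dispatch above: string values
        # resolve to PyG's global pooling functions, e.g.
        # ``readout_method="mean"`` -> ``pool.global_mean_pool`` and
        # ``"max"`` -> ``pool.global_max_pool``; alternatively, any callable
        # accepting ``(node_feats, batch)`` may be passed in directly.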

    @property
    @cache
    def _atom_eye(self) -> torch.Tensor:
        return torch.eye(
            self.hparams.num_atom_embedding, device=self.device, dtype=self.dtype
        )

    def atomic_numbers_to_one_hot(self, atomic_numbers: torch.Tensor) -> torch.Tensor:
        """
        Convert discrete atomic numbers into one-hot vectors based
        on some maximum number of elements possible.

        Parameters
        ----------
        atomic_numbers : torch.Tensor
            1D tensor of integers corresponding to atomic numbers.

        Returns
        -------
        torch.Tensor
            2D tensor of one-hot vectors for each node.
        """
        return self._atom_eye[atomic_numbers.long()]

    def read_batch(self, batch: BatchDict) -> DataDict:
        data = {}
        # expect a PyG graph already
        graph = batch["graph"]
        atomic_numbers = graph.atomic_numbers
        one_hot_atoms = self.atomic_numbers_to_one_hot(atomic_numbers)
        # check to make sure we have unit cell shifts
        for key in ["cell", "offsets"]:
            if key not in batch:
                raise KeyError(
                    f"Expected periodic property {key} to be in batch."
                    " Please include ``PeriodicPropertiesTransform``."
                )
        assert hasattr(graph, "ptr"), "Graph is missing the `ptr` attribute!"
        # the names of these keys match up with our `_forward`, and
        # later get remapped to MACE ones
        data.update(
            {
                "graph": graph,
                "pos": graph.pos,
                "node_feats": one_hot_atoms,
                "cell": batch["cell"],
                "shifts": batch["offsets"],
            }
        )
        return data

    def _forward(
        self,
        graph: AbstractGraph,
        node_feats: torch.Tensor,
        pos: torch.Tensor,
        **kwargs,
    ) -> Embeddings:
        """
        Takes arguments in the standardized format, and passes them into MACE
        with some redundant mapping.

        Parameters
        ----------
        graph : AbstractGraph
            Graph structure containing node and graph properties.
        node_feats : torch.Tensor
            Tensor containing one-hot node features, shape ``[num_nodes, num_elements]``.
        pos : torch.Tensor
            2D tensor containing node positions, shape ``[num_nodes, 3]``.

        Returns
        -------
        Embeddings
            MatSciML ``Embeddings`` structure.
        """
        # repack data into MACE format
        mace_data = {
            "positions": pos,
            "node_attrs": node_feats,
            "ptr": graph.ptr,
            "cell": kwargs["cell"],
            "shifts": kwargs["shifts"],
            "batch": graph.batch,
            "edge_index": graph.edge_index,
        }
        outputs = self.encoder(
            mace_data,
            training=self.training,
            compute_force=False,
            compute_virials=False,
            compute_stress=False,
            compute_displacement=False,
        )
        node_embeddings = outputs["node_feats"]
        graph_embeddings = self.readout(node_embeddings, graph.batch)
        return Embeddings(graph_embeddings, node_embeddings)
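As a small illustration of the one-hot conversion used by ``read_batch`` (a sketch; the atomic numbers here are arbitrary examples, and the wrapper itself uses its cached ``_atom_eye`` buffer):

import torch

eye = torch.eye(100)                      # num_atom_embedding = 100
atomic_numbers = torch.tensor([1, 6, 8])  # H, C, O
node_feats = eye[atomic_numbers.long()]   # shape [3, 100]; a single 1.0 per row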