From f13ccfa642bbb07d2590ea2a29beea9423602b5a Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Wed, 13 Mar 2024 09:28:16 -0700 Subject: [PATCH 01/22] deps: loosening e3nn requirement to match MACE --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c0dd1be5..183e71fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,8 @@ dependencies = [ "matgl==1.0.0", "einops==0.7.0", "mendeleev==0.14.0", - "e3nn==0.5.1" + "e3nn", + "mace-torch==0.3.4" ] description = "PyTorch Lightning and Deep Graph Library enabled materials science deep learning pipeline" dynamic = ["version", "readme"] From e0f0c6bb04fabb624d89a5822b49805a1d422cbb Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Wed, 13 Mar 2024 17:38:24 -0700 Subject: [PATCH 02/22] feat: initial ground work for wrapper Signed-off-by: Kin Long Kelvin Lee --- matsciml/models/pyg/mace/original/__init__.py | 0 matsciml/models/pyg/mace/original/model.py | 43 +++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 matsciml/models/pyg/mace/original/__init__.py create mode 100644 matsciml/models/pyg/mace/original/model.py diff --git a/matsciml/models/pyg/mace/original/__init__.py b/matsciml/models/pyg/mace/original/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/matsciml/models/pyg/mace/original/model.py b/matsciml/models/pyg/mace/original/model.py new file mode 100644 index 00000000..5764e2d6 --- /dev/null +++ b/matsciml/models/pyg/mace/original/model.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from inspect import signature +from typing import Any + +from e3nn.o3 import Irreps +from mace.modules import MACE + +from matsciml.models.base import AbstractPyGModel +from matsciml.common.registry import registry + + +__mace_signature = signature(MACE) +__mace_parameters = __mace_signature.parameters + + +@registry.register_model("MACEWrapper") +class MACEWrapper(AbstractPyGModel): + def __init__( + self, + atom_embedding_dim: int, + num_atom_embedding: int = 100, + embedding_kwargs: Any = None, + encoder_only: bool = True, + **mace_kwargs, + ) -> None: + super().__init__(atom_embedding_dim, num_atom_embedding, {}, encoder_only) + for key in mace_kwargs: + assert ( + key in __mace_parameters + ), f"{key} was passed as a MACE kwarg but does not match expected arguments." + # remove the embedding table, as MACE uses e3nn layers + del self.atom_embedding + if "num_elements" in mace_kwargs: + raise KeyError( + "Please use `num_atom_embedding` instead of passing `num_elements`." + ) + if "hidden_irreps" in mace_kwargs: + raise KeyError( + "Please use `atom_embedding_dim` instead of passing `hidden_irreps`." + ) + atom_embedding_dim = Irreps(f"{atom_embedding_dim}x0e") + self.encoder = ... From 82242593fa7f6f73c23745d2badd3761b856b0f8 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Thu, 14 Mar 2024 13:49:55 -0700 Subject: [PATCH 03/22] feat: adding logger warning for MACE when embedding kwargs are passed --- matsciml/models/pyg/mace/original/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/matsciml/models/pyg/mace/original/model.py b/matsciml/models/pyg/mace/original/model.py index 5764e2d6..b5006aa9 100644 --- a/matsciml/models/pyg/mace/original/model.py +++ b/matsciml/models/pyg/mace/original/model.py @@ -1,6 +1,7 @@ from __future__ import annotations from inspect import signature +from logging import getLogger from typing import Any from e3nn.o3 import Irreps @@ -14,6 +15,9 @@ __mace_parameters = __mace_signature.parameters +logger = getLogger(__file__) + + @registry.register_model("MACEWrapper") class MACEWrapper(AbstractPyGModel): def __init__( @@ -24,6 +28,8 @@ def __init__( encoder_only: bool = True, **mace_kwargs, ) -> None: + if embedding_kwargs is not None: + logger.warning("`embedding_kwargs` is not used for MACE models.") super().__init__(atom_embedding_dim, num_atom_embedding, {}, encoder_only) for key in mace_kwargs: assert ( From 60fe6185eb6a6411928b8d352afec30a8e60a618 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 12:54:01 -0700 Subject: [PATCH 04/22] refactor: renaming module to something more descriptive Signed-off-by: Kin Long Kelvin Lee --- matsciml/models/pyg/mace/{original => wrapper}/__init__.py | 0 matsciml/models/pyg/mace/{original => wrapper}/model.py | 5 ++--- 2 files changed, 2 insertions(+), 3 deletions(-) rename matsciml/models/pyg/mace/{original => wrapper}/__init__.py (100%) rename matsciml/models/pyg/mace/{original => wrapper}/model.py (92%) diff --git a/matsciml/models/pyg/mace/original/__init__.py b/matsciml/models/pyg/mace/wrapper/__init__.py similarity index 100% rename from matsciml/models/pyg/mace/original/__init__.py rename to matsciml/models/pyg/mace/wrapper/__init__.py diff --git a/matsciml/models/pyg/mace/original/model.py b/matsciml/models/pyg/mace/wrapper/model.py similarity index 92% rename from matsciml/models/pyg/mace/original/model.py rename to matsciml/models/pyg/mace/wrapper/model.py index b5006aa9..c82a989f 100644 --- a/matsciml/models/pyg/mace/original/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -1,6 +1,5 @@ from __future__ import annotations -from inspect import signature from logging import getLogger from typing import Any @@ -9,10 +8,10 @@ from matsciml.models.base import AbstractPyGModel from matsciml.common.registry import registry +from matsciml.common.inspection import get_model_required_args -__mace_signature = signature(MACE) -__mace_parameters = __mace_signature.parameters +__mace_required_args = get_model_required_args(MACE) logger = getLogger(__file__) From 50496f5e506645708a68be25357de9ff9b2ec606 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 13:10:46 -0700 Subject: [PATCH 05/22] feat: adding utility functions for inspecting functions and models Signed-off-by: Kin Long Kelvin Lee --- matsciml/common/inspection.py | 70 +++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 matsciml/common/inspection.py diff --git a/matsciml/common/inspection.py b/matsciml/common/inspection.py new file mode 100644 index 00000000..81fd92d5 --- /dev/null +++ b/matsciml/common/inspection.py @@ -0,0 +1,70 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: MIT License +from __future__ import annotations + +from inspect import signature +from typing import Callable, Type + +from torch import nn + +""" +Simple utility functions for inspecting functions and objects. + +The idea is that this should provide reusable functions that +are useful for mapping kwargs onto classes that belong outside +of this library, where we might not know what is required or not. +""" + + +def get_args_without_defaults(func: Callable, exclude_self: bool = True) -> list[str]: + """ + Inspect a function for required positional input arguments. + + The function works by looping through arguments and checking + if defaults are available. The option ``exclude_self`` is also + available to specify whether or not to remove ``self`` entries + from the list, since this may correspond to class methods. + + Parameters + ---------- + func : Callable + A callable function with some input arguments. + exclude_self + If True, ignores ``self`` and ``cls`` from the returned + list. + + Returns + ------- + list[str] + List of argument names that are required by ``func``. + """ + parameters = signature(func).parameters + matches = [] + for arg_name, parameter in parameters.items(): + if getattr(parameter, "default", None): + matches.append(arg_name) + # remove self from list if requested + if exclude_self: + matches = list(filter(lambda x: x not in ["self", "cls"], matches)) + return matches + + +def get_model_required_args(model: Type[nn.Module]) -> list[str]: + """ + Inspect a child of PyTorch ``nn.Module`` for required arguments. + + The idea behind this is to identify which parameters are needed to + instantiate a model, which is useful for determining how args/kwargs + should be unpacked into a model from an abstract interface. + + Parameters + ---------- + model : Type[nn.Module] + A model class to inspect + + Returns + ------- + list[str] + List of argument names that are required to instantiate ``model`` + """ + return get_args_without_defaults(model.__init__, exclude_self=True) From 1137e37a730c821b3471bca0a521d155ff77205c Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 13:23:30 -0700 Subject: [PATCH 06/22] feat: added functions to inspect model init and forward --- matsciml/common/inspection.py | 53 +++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/matsciml/common/inspection.py b/matsciml/common/inspection.py index 81fd92d5..f762fcb7 100644 --- a/matsciml/common/inspection.py +++ b/matsciml/common/inspection.py @@ -49,6 +49,24 @@ def get_args_without_defaults(func: Callable, exclude_self: bool = True) -> list return matches +def get_all_args(func: Callable) -> list[str]: + """ + Get all the arguments of a function, include with defaults. + + Parameters + ---------- + func : Callable + Function to inspect arguments for. + + Returns + ------- + list[str] + List of argument names, positional and keyword. + """ + parameters = signature(func).parameters + return list(parameters.keys()) + + def get_model_required_args(model: Type[nn.Module]) -> list[str]: """ Inspect a child of PyTorch ``nn.Module`` for required arguments. @@ -68,3 +86,38 @@ def get_model_required_args(model: Type[nn.Module]) -> list[str]: List of argument names that are required to instantiate ``model`` """ return get_args_without_defaults(model.__init__, exclude_self=True) + + +def get_model_all_args(model: Type[nn.Module]) -> list[str]: + """ + Inspect a model for all of its initialization arguments, including + optional ones. + + Parameters + ---------- + model : Type[nn.Module] + Model class to inspect for arguments. + + Returns + ------- + list[str] + List of arguments used by the model instantiation. + """ + return get_all_args(model.__init__) + + +def get_model_forward_args(model: Type[nn.Module]) -> list[str]: + """ + Inspect a model's ``forward`` method for argument names. + + Parameters + ---------- + model : Type[nn.Module] + Model to inspect. + + Returns + ------- + list[str] + List of argument names in a model's forward method. + """ + return get_all_args(model.forward) From f1b353a60b4f713502e37e7c693ff8e372321cba Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 13:24:32 -0700 Subject: [PATCH 07/22] refactor: using function to make sure mace kwargs are present --- matsciml/models/pyg/mace/wrapper/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index c82a989f..58d87a94 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -8,10 +8,11 @@ from matsciml.models.base import AbstractPyGModel from matsciml.common.registry import registry -from matsciml.common.inspection import get_model_required_args +from matsciml.common.inspection import get_model_required_args, get_model_all_args __mace_required_args = get_model_required_args(MACE) +__mace_all_args = get_model_all_args(MACE) logger = getLogger(__file__) @@ -32,7 +33,7 @@ def __init__( super().__init__(atom_embedding_dim, num_atom_embedding, {}, encoder_only) for key in mace_kwargs: assert ( - key in __mace_parameters + key in __mace_all_args ), f"{key} was passed as a MACE kwarg but does not match expected arguments." # remove the embedding table, as MACE uses e3nn layers del self.atom_embedding From dfc4d172143211f696830d93fce18dcd8586aaa6 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 13:33:50 -0700 Subject: [PATCH 08/22] feat: unpacking some arguments into MACE constructor --- matsciml/models/pyg/mace/wrapper/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 58d87a94..23331365 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -46,4 +46,7 @@ def __init__( "Please use `atom_embedding_dim` instead of passing `hidden_irreps`." ) atom_embedding_dim = Irreps(f"{atom_embedding_dim}x0e") - self.encoder = ... + # pack stuff into the mace kwargs + mace_kwargs["num_elements"] = num_atom_embedding + mace_kwargs["hidden_irreps"] = atom_embedding_dim + self.encoder = MACE(**mace_kwargs) From 2eb7c502bd79d7a7551a1b389394e874f0484fbd Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 13:36:14 -0700 Subject: [PATCH 09/22] refactor: adding check to make sure all args are passed into MACE as needed Signed-off-by: Kin Long Kelvin Lee --- matsciml/models/pyg/mace/wrapper/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 23331365..9c187afc 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -49,4 +49,8 @@ def __init__( # pack stuff into the mace kwargs mace_kwargs["num_elements"] = num_atom_embedding mace_kwargs["hidden_irreps"] = atom_embedding_dim + # check to make sure all that's required is + for key in __mace_required_args: + if key not in mace_kwargs: + raise KeyError(f"{key} is required by MACE, but was not found in kwargs.") self.encoder = MACE(**mace_kwargs) From f0a1782a645785aed53a92fabb393bf9641e6d18 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 13:50:31 -0700 Subject: [PATCH 10/22] feat: implemented one-hot vector conversion --- matsciml/models/pyg/mace/wrapper/model.py | 31 ++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 9c187afc..1f890969 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -2,7 +2,9 @@ from logging import getLogger from typing import Any +from functools import cache +import torch from e3nn.o3 import Irreps from mace.modules import MACE @@ -52,5 +54,32 @@ def __init__( # check to make sure all that's required is for key in __mace_required_args: if key not in mace_kwargs: - raise KeyError(f"{key} is required by MACE, but was not found in kwargs.") + raise KeyError( + f"{key} is required by MACE, but was not found in kwargs." + ) self.encoder = MACE(**mace_kwargs) + self.save_hyperparameters() + + @property + @cache + def _atom_eye(self) -> torch.Tensor: + return torch.eye( + self.hparams.num_atom_embedding, device=self.device, dtype=self.dtype + ) + + def atomic_numbers_to_one_hot(self, atomic_numbers: torch.Tensor) -> torch.Tensor: + """ + Convert discrete atomic numbers into one-hot vectors based + on some maximum number of elements possible. + + Parameters + ---------- + atomic_numbers : torch.Tensor + 1D tensor of integers corresponding to atomic numbers. + + Returns + ------- + torch.Tensor + 2D tensor of one-hot vectors for each node. + """ + return self._atom_eye[atomic_numbers.long()] From 0dfd5d4400fb965f137104368d9ae42ce45132ef Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:16:58 -0700 Subject: [PATCH 11/22] feat: ostensibly complete read batch method --- matsciml/models/pyg/mace/wrapper/model.py | 26 +++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 1f890969..20487548 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -9,6 +9,7 @@ from mace.modules import MACE from matsciml.models.base import AbstractPyGModel +from matsciml.common.types import BatchDict, DataDict from matsciml.common.registry import registry from matsciml.common.inspection import get_model_required_args, get_model_all_args @@ -83,3 +84,28 @@ def atomic_numbers_to_one_hot(self, atomic_numbers: torch.Tensor) -> torch.Tenso 2D tensor of one-hot vectors for each node. """ return self._atom_eye[atomic_numbers.long()] + + def read_batch(self, batch: BatchDict) -> DataDict: + data = super().read_batch(batch) + # expect a PyG graph already + graph = batch["graph"] + atomic_numbers = graph.atomic_numbers + one_hot_atoms = self.atomic_numbers_to_one_hot(atomic_numbers) + # check to make sure we have unit cell shifts + for key in ["cell", "offsets"]: + if key not in batch: + raise KeyError( + f"Expected periodic property {key} to be in batch." + " Please include ``PeriodicPropertiesTransform``." + ) + data.update( + { + "positions": graph.pos, + "edge_index": graph.edge_index, + "node_attrs": one_hot_atoms, + "ptr": graph.ptr, # refers to pointers/node segments + "cell": batch["cell"], + "shifts": batch["offsets"], + } + ) + return data From e06bf883489a567ede8a7ce61dbac20411c965de Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:36:52 -0700 Subject: [PATCH 12/22] feat: ready to test forward pass --- matsciml/models/pyg/mace/wrapper/model.py | 58 +++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 20487548..3ec43fef 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -9,7 +9,7 @@ from mace.modules import MACE from matsciml.models.base import AbstractPyGModel -from matsciml.common.types import BatchDict, DataDict +from matsciml.common.types import BatchDict, DataDict, AbstractGraph, Embeddings from matsciml.common.registry import registry from matsciml.common.inspection import get_model_required_args, get_model_all_args @@ -98,14 +98,66 @@ def read_batch(self, batch: BatchDict) -> DataDict: f"Expected periodic property {key} to be in batch." " Please include ``PeriodicPropertiesTransform``." ) + # the name of these keys matches up with our `_forward`, and + # later get remapped to MACE ones data.update( { - "positions": graph.pos, + "pos": graph.pos, "edge_index": graph.edge_index, - "node_attrs": one_hot_atoms, + "node_feats": one_hot_atoms, "ptr": graph.ptr, # refers to pointers/node segments "cell": batch["cell"], "shifts": batch["offsets"], } ) return data + + def _forward( + self, + graph: AbstractGraph, + node_feats: torch.Tensor, + pos: torch.Tensor, + **kwargs, + ) -> Embeddings: + """ + Takes arguments in the standardized format, and passes them into MACE + with some redundant mapping. + + Parameters + ---------- + graph : AbstractGraph + Graph structure containing node and graph properties + + node_feats : torch.Tensor + Tensor containing one-hot node features, shape ``[num_nodes, num_elements]`` + + pos : torch.Tensor + 2D tensor containing node positions, shape ``[num_nodes, 3]`` + + Returns + ------- + Embeddings + MatSciML ``Embeddings`` structure + """ + # repack data into MACE format + mace_data = { + "positions": pos, + "node_attrs": node_feats, + "ptr": kwargs["ptr"], + "cell": kwargs["cell"], + "shifts": kwargs["shifts"], + "batch": graph.batch, + "edge_index": graph.edge_index, + } + outputs = self.encoder( + mace_data, + training=self.training, + compute_force=False, + compute_virials=False, + compute_stress=False, + compute_displacement=False, + ) + # TODO check that these are the correct things to unpack + node_embeddings = outputs["node_feats"] + graph_embeddings = outputs["contributions"] + return Embeddings(graph_embeddings, node_embeddings) From a7a9d7bed84c29e1c0ef90934313a6bb93f42a48 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:39:34 -0700 Subject: [PATCH 13/22] feat: adding MACEWrapper to __all__ --- matsciml/models/pyg/mace/wrapper/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 3ec43fef..5d14ce9d 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -20,6 +20,8 @@ logger = getLogger(__file__) +__all__ = ["MACEWrapper"] + @registry.register_model("MACEWrapper") class MACEWrapper(AbstractPyGModel): From 332301c76a3b86a2100c96997ab03f42f801143b Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:40:16 -0700 Subject: [PATCH 14/22] feat: exposing MACEWrapper to mace submodule --- matsciml/models/pyg/mace/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/matsciml/models/pyg/mace/__init__.py b/matsciml/models/pyg/mace/__init__.py index 22130b7e..cb16b780 100644 --- a/matsciml/models/pyg/mace/__init__.py +++ b/matsciml/models/pyg/mace/__init__.py @@ -1,8 +1,6 @@ from __future__ import annotations from matsciml.models.pyg.mace.modules.models import MACE, ScaleShiftMACE +from matsciml.models.pyg.mace.wrapper.model import MACEWrapper -__all__ = [ - "MACE", - "ScaleShiftMACE", -] +__all__ = ["MACE", "ScaleShiftMACE", "MACEWrapper"] From 150b7c40d59014c5512513e9108be5b48b8efab5 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:49:27 -0700 Subject: [PATCH 15/22] feat: adding defaults to MACE --- matsciml/models/pyg/mace/wrapper/model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 5d14ce9d..617450ae 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -5,6 +5,7 @@ from functools import cache import torch +import numpy as np from e3nn.o3 import Irreps from mace.modules import MACE @@ -54,6 +55,10 @@ def __init__( # pack stuff into the mace kwargs mace_kwargs["num_elements"] = num_atom_embedding mace_kwargs["hidden_irreps"] = atom_embedding_dim + mace_kwargs["atomic_numbers"] = list(range(1, num_atom_embedding)) + if "atomic_energies" not in mace_kwargs: + logger.warning("No ``atomic_energies`` provided, defaulting to ones.") + mace_kwargs["atomic_energies"] = np.ones(num_atom_embedding) # check to make sure all that's required is for key in __mace_required_args: if key not in mace_kwargs: From 4c710aae32a40a0879d4152e1e070e25b924974f Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:55:07 -0700 Subject: [PATCH 16/22] refactor: moving argspec checks into init method --- matsciml/models/pyg/mace/wrapper/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 617450ae..c3a7971d 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -15,10 +15,6 @@ from matsciml.common.inspection import get_model_required_args, get_model_all_args -__mace_required_args = get_model_required_args(MACE) -__mace_all_args = get_model_all_args(MACE) - - logger = getLogger(__file__) __all__ = ["MACEWrapper"] @@ -37,6 +33,9 @@ def __init__( if embedding_kwargs is not None: logger.warning("`embedding_kwargs` is not used for MACE models.") super().__init__(atom_embedding_dim, num_atom_embedding, {}, encoder_only) + # dynamically check to check which arguments are needed by MACE + __mace_required_args = get_model_required_args(MACE) + __mace_all_args = get_model_all_args(MACE) for key in mace_kwargs: assert ( key in __mace_all_args From 74216b0e6a7de63ec53fed5e4344f1018cfcca89 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:56:34 -0700 Subject: [PATCH 17/22] refactor: not using super read_batch because we don't have an embedding table --- matsciml/models/pyg/mace/wrapper/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index c3a7971d..f97dc0d2 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -92,7 +92,7 @@ def atomic_numbers_to_one_hot(self, atomic_numbers: torch.Tensor) -> torch.Tenso return self._atom_eye[atomic_numbers.long()] def read_batch(self, batch: BatchDict) -> DataDict: - data = super().read_batch(batch) + data = {} # expect a PyG graph already graph = batch["graph"] atomic_numbers = graph.atomic_numbers From 5ea8abc0631eb5e93575fc9c88404fcc6d52dc75 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 14:56:46 -0700 Subject: [PATCH 18/22] fix: added missing graph key from read batch --- matsciml/models/pyg/mace/wrapper/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index f97dc0d2..f8c5cca8 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -108,6 +108,7 @@ def read_batch(self, batch: BatchDict) -> DataDict: # later get remapped to MACE ones data.update( { + "graph": graph, "pos": graph.pos, "edge_index": graph.edge_index, "node_feats": one_hot_atoms, From 4cfdf12d838a2c7df0ea89c6c2df10db707642ea Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 15:06:58 -0700 Subject: [PATCH 19/22] refactor: use own readout method Not immediately obvious if a readout is provided from MACE, so we do it ourselves on the node embeddings --- matsciml/models/pyg/mace/wrapper/model.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index f8c5cca8..16f90464 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -1,13 +1,14 @@ from __future__ import annotations from logging import getLogger -from typing import Any +from typing import Any, Callable from functools import cache import torch import numpy as np from e3nn.o3 import Irreps from mace.modules import MACE +from torch_geometric.nn import pool from matsciml.models.base import AbstractPyGModel from matsciml.common.types import BatchDict, DataDict, AbstractGraph, Embeddings @@ -28,6 +29,7 @@ def __init__( num_atom_embedding: int = 100, embedding_kwargs: Any = None, encoder_only: bool = True, + readout_method: str | Callable = "add", **mace_kwargs, ) -> None: if embedding_kwargs is not None: @@ -65,6 +67,17 @@ def __init__( f"{key} is required by MACE, but was not found in kwargs." ) self.encoder = MACE(**mace_kwargs) + # if a string is passed, grab the PyG builtins + if isinstance(readout_method, str): + readout_type = getattr(pool, f"global_{readout_method}_pool", None) + if not readout_type: + possible_methods = list(filter(lambda x: "global" in x, dir(pool))) + raise NotImplementedError( + f"{readout_method} is not a valid function in PyG pooling." + f" Supported methods are: {possible_methods}" + ) + readout_method = readout_type + self.readout = readout_method self.save_hyperparameters() @property @@ -164,7 +177,6 @@ def _forward( compute_stress=False, compute_displacement=False, ) - # TODO check that these are the correct things to unpack node_embeddings = outputs["node_feats"] - graph_embeddings = outputs["contributions"] + graph_embeddings = self.readout(node_embeddings, graph.batch) return Embeddings(graph_embeddings, node_embeddings) From 2030f2ef8247b199038e730270a58101d914cce8 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 15:07:22 -0700 Subject: [PATCH 20/22] script: added example mace wrapper script --- examples/model_demos/mace_wrapper.py | 58 ++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 examples/model_demos/mace_wrapper.py diff --git a/examples/model_demos/mace_wrapper.py b/examples/model_demos/mace_wrapper.py new file mode 100644 index 00000000..6861d3de --- /dev/null +++ b/examples/model_demos/mace_wrapper.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import pytorch_lightning as pl +from torch import nn +from e3nn.o3 import Irreps +from mace.modules.blocks import RealAgnosticInteractionBlock + +from matsciml.datasets.transforms import ( + PointCloudToGraphTransform, + PeriodicPropertiesTransform, +) +from matsciml.lightning.data_utils import MatSciMLDataModule +from matsciml.models.base import ScalarRegressionTask +from matsciml.models.pyg.mace import MACEWrapper + + +""" +This example script runs through a fast development run of the IS2RE devset +in combination with a PyG implementation of EGNN. +""" + +# construct IS2RE relaxed energy regression with PyG implementation of E(n)-GNN +task = ScalarRegressionTask( + encoder_class=MACEWrapper, + encoder_kwargs={ + "r_max": 6.0, + "num_bessel": 3, + "num_polynomial_cutoff": 3, + "max_ell": 2, + "interaction_cls": RealAgnosticInteractionBlock, + "interaction_cls_first": RealAgnosticInteractionBlock, + "num_interactions": 2, + "atom_embedding_dim": 64, + "MLP_irreps": Irreps("256x0e"), + "avg_num_neighbors": 10.0, + "correlation": 1, + "radial_type": "bessel", + "gate": nn.Identity(), + }, + task_keys=["energy_relaxed"], +) +# matsciml devset for OCP are serialized with DGL - this transform goes between the two frameworks +dm = MatSciMLDataModule.from_devset( + "IS2REDataset", + dset_kwargs={ + "transforms": [ + PeriodicPropertiesTransform(6.0, adaptive_cutoff=True), + PointCloudToGraphTransform( + "pyg", + node_keys=["pos", "atomic_numbers"], + ), + ], + }, +) + +# run a quick training loop +trainer = pl.Trainer(fast_dev_run=10) +trainer.fit(task, datamodule=dm) From dbfa113103436d5a59ddee9a1416d0c21c5ce166 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 15:22:59 -0700 Subject: [PATCH 21/22] refactor: removing redundant steps from read_batch --- matsciml/models/pyg/mace/wrapper/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/matsciml/models/pyg/mace/wrapper/model.py b/matsciml/models/pyg/mace/wrapper/model.py index 16f90464..7678be57 100644 --- a/matsciml/models/pyg/mace/wrapper/model.py +++ b/matsciml/models/pyg/mace/wrapper/model.py @@ -117,15 +117,14 @@ def read_batch(self, batch: BatchDict) -> DataDict: f"Expected periodic property {key} to be in batch." " Please include ``PeriodicPropertiesTransform``." ) + assert hasattr(graph, "ptr"), "Graph is missing the `ptr` attribute!" # the name of these keys matches up with our `_forward`, and # later get remapped to MACE ones data.update( { "graph": graph, "pos": graph.pos, - "edge_index": graph.edge_index, "node_feats": one_hot_atoms, - "ptr": graph.ptr, # refers to pointers/node segments "cell": batch["cell"], "shifts": batch["offsets"], } @@ -163,7 +162,7 @@ def _forward( mace_data = { "positions": pos, "node_attrs": node_feats, - "ptr": kwargs["ptr"], + "ptr": graph.ptr, "cell": kwargs["cell"], "shifts": kwargs["shifts"], "batch": graph.batch, From d86cf1b1637570a7f7d36f8723fd151b53a85149 Mon Sep 17 00:00:00 2001 From: Kin Long Kelvin Lee Date: Fri, 15 Mar 2024 15:30:29 -0700 Subject: [PATCH 22/22] test: adding test suite for mace wrapper across datasets Signed-off-by: Kin Long Kelvin Lee --- .../mace/wrapper/tests/test_mace_wrapper.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py diff --git a/matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py b/matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py new file mode 100644 index 00000000..87def6ba --- /dev/null +++ b/matsciml/models/pyg/mace/wrapper/tests/test_mace_wrapper.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import pytest +import torch +from torch import nn +from e3nn.o3 import Irreps +from mace.modules.blocks import RealAgnosticInteractionBlock + +# this import is not used, but ensures that the registry is updated +from matsciml import datasets # noqa: F401 +from matsciml.common.registry import registry +from matsciml.datasets.transforms import ( + PeriodicPropertiesTransform, + PointCloudToGraphTransform, +) +from matsciml.lightning import MatSciMLDataModule +from matsciml.models.pyg.mace import MACEWrapper + + +@pytest.fixture +def mace_architecture() -> MACEWrapper: + """ + Fixture for a nominal mace architecture. + + Some lightweight (but realistic) hyperparameters are + used to test data flowing through the model. + + Returns + ------- + mace + Concrete mace object + """ + model_config = { + "r_max": 6.0, + "num_bessel": 3, + "num_polynomial_cutoff": 3, + "max_ell": 2, + "interaction_cls": RealAgnosticInteractionBlock, + "interaction_cls_first": RealAgnosticInteractionBlock, + "num_interactions": 2, + "atom_embedding_dim": 64, + "MLP_irreps": Irreps("256x0e"), + "avg_num_neighbors": 10.0, + "correlation": 1, + "radial_type": "bessel", + "gate": nn.Identity(), + } + model = MACEWrapper(**model_config) + return model + + +# here we filter out datasets from the registry that don't make sense +ignore_dset = ["Multi", "M3G", "PyG", "Cdvae"] +filtered_list = list( + filter( + lambda x: all([target_str not in x for target_str in ignore_dset]), + registry.__entries__["datasets"].keys(), + ), +) + + +@pytest.mark.parametrize( + "dset_class_name", + filtered_list, +) +def test_model_forward_nograd(dset_class_name: str, mace_architecture: MACEWrapper): + # these are necessary for the model to work as intended + """ + This test checks model ``forward`` compatibility with datasets. + + The test is parameterized to run on all datasets in the registry + that have *not* been filtered out; this list should be sparse, + as the idea is to maximize coverage and we can just ignore failing + combinations if they do not make sense and we can at least be + aware of them. + + Parameters + ---------- + dset_class_name : str + Name of the dataset class to retrieve + mace_architecture : EGNN + Concrete mace object with some parameters + """ + transforms = [ + PeriodicPropertiesTransform(cutoff_radius=6.0, adaptive_cutoff=True), + PointCloudToGraphTransform("pyg"), + ] + dm = MatSciMLDataModule.from_devset( + dset_class_name, + batch_size=4, + dset_kwargs={"transforms": transforms}, + ) + # dummy initialization + dm.setup("fit") + loader = dm.train_dataloader() + batch = next(iter(loader)) + # run the model without gradient tracking + with torch.no_grad(): + embeddings = mace_architecture(batch) + # returns embeddings, and runs numerical checks + for z in [embeddings.system_embedding, embeddings.point_embedding]: + assert torch.isreal(z).all() + assert ~torch.isnan(z).all() # check there are no NaNs + assert torch.isfinite(z).all() + assert torch.all(torch.abs(z) <= 1000) # ensure reasonable values