ASE Calculator Updates #299

Merged
merged 8 commits on Oct 1, 2024
127 changes: 80 additions & 47 deletions matsciml/interfaces/ase/base.py
@@ -17,6 +17,7 @@
)
from matsciml.datasets.transforms.base import AbstractDataTransform
from matsciml.interfaces.ase import multitask as mt
from matsciml.datasets.utils import concatenate_keys

__all__ = ["MatSciMLCalculator"]

@@ -83,17 +84,21 @@ class MatSciMLCalculator(Calculator):

def __init__(
self,
task_module: ScalarRegressionTask
| GradFreeForceRegressionTask
| ForceRegressionTask
| MultiTaskLitModule,
task_module: (
ScalarRegressionTask
| GradFreeForceRegressionTask
| ForceRegressionTask
| MultiTaskLitModule
),
transforms: list[AbstractDataTransform | Callable] | None = None,
restart=None,
label=None,
atoms: Atoms | None = None,
directory=".",
conversion_factor: float | dict[str, float] = 1.0,
multitask_strategy: str | Callable | mt.AbstractStrategy = "AverageTasks",
output_map: dict[str, str] | None = None,
matsciml_model: bool = True,
**kwargs,
):
"""
@@ -144,33 +149,39 @@ def __init__(
to ``ase``. If a single ``float`` is passed, we assume that
the conversion is applied to the energy output. Each factor
is multiplied with the result.
output_map : dict[str, str] | None, default None
specifies how model outputs should be mapped to the Calculator's
expected results, keyed as {"ase_expected": "model_output"},
e.g. {"forces": "force"}.
matsciml_model : bool, default True
flag indicating whether the model was trained with matsciml; if
``False``, the model's ``forward`` is called directly on the ``Atoms`` object.
"""
super().__init__(
restart, label=label, atoms=atoms, directory=directory, **kwargs
)
assert isinstance(
task_module,
(
ForceRegressionTask,
ScalarRegressionTask,
GradFreeForceRegressionTask,
MultiTaskLitModule,
),
), f"Expected task to be one that is capable of energy/force prediction. Got {task_module.__type__}."
if isinstance(task_module, MultiTaskLitModule):
assert any(
[
isinstance(
subtask,
(
ForceRegressionTask,
ScalarRegressionTask,
GradFreeForceRegressionTask,
),
)
for subtask in task_module.task_list
]
), "Expected at least one subtask to be energy/force predictor."
if matsciml_model:
assert isinstance(
task_module,
(
ForceRegressionTask,
ScalarRegressionTask,
GradFreeForceRegressionTask,
MultiTaskLitModule,
),
), f"Expected task to be one that is capable of energy/force prediction. Got {task_module.__type__}."
if isinstance(task_module, MultiTaskLitModule):
assert any(
[
isinstance(
subtask,
(
ForceRegressionTask,
ScalarRegressionTask,
GradFreeForceRegressionTask,
),
)
for subtask in task_module.task_list
]
), "Expected at least one subtask to be energy/force predictor."
self.task_module = task_module
self.transforms = transforms
self.conversion_factor = conversion_factor
@@ -182,6 +193,18 @@ def __init__(
)
multitask_strategy = cls_name()
self.multitask_strategy = multitask_strategy
self.matsciml_model = matsciml_model
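# default to an identity mapping (each implemented property maps to itself);
# user-supplied ``output_map`` entries override individual keys below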
self.output_map = dict(
zip(self.implemented_properties, self.implemented_properties)
)
if output_map is not None:
for k, v in output_map.items():
if k not in self.output_map:
raise KeyError(
f"Specified key {k} is not one of the implemented_properties of this calculator: {self.implemented_properties}"
)
else:
self.output_map[k] = v

@property
def conversion_factor(self) -> dict[str, float]:
@@ -212,9 +235,8 @@ def _format_atoms(self, atoms: Atoms) -> DataDict:
data_dict["pos"] = pos
data_dict["atomic_numbers"] = atomic_numbers
data_dict["cell"] = cell
# ptr and batch are usually expected by MACE even if it's a single graph
data_dict["ptr"] = torch.tensor([0])
data_dict["batch"] = torch.zeros((pos.size(0)))
data_dict["frac_coords"] = torch.from_numpy(atoms.get_scaled_positions())
data_dict["natoms"] = pos.size(0)
return data_dict

def _format_pipeline(self, atoms: Atoms) -> DataDict:
@@ -230,10 +252,6 @@ def _format_pipeline(self, atoms: Atoms) -> DataDict:
"""
# initial formatting to get something akin to dataset outputs
data_dict = self._format_atoms(atoms)
# type cast into the type expected by the model
data_dict = recursive_type_cast(
data_dict, self.dtype, ignore_keys=["atomic_numbers"], convert_numpy=True
)
# now run through the same transform pipeline as for datasets
if self.transforms:
for transform in self.transforms:
@@ -248,24 +266,39 @@ def calculate(
) -> None:
# retrieve atoms even if not passed
Calculator.calculate(self, atoms)
# get into format ready for matsciml model
data_dict = self._format_pipeline(atoms)
# run the data structure through the model
output = self.task_module.predict(data_dict)
if self.matsciml_model:
# get into format ready for matsciml model
data_dict = self._format_pipeline(atoms)
# concatenate_keys batches data and adds some attributes that may be expected, like ptr.
data_dict = concatenate_keys([data_dict])
# type cast into the type expected by the model
data_dict = recursive_type_cast(
data_dict,
self.dtype,
ignore_keys=["atomic_numbers"],
convert_numpy=True,
)
# run the data structure through the model
output = self.task_module.predict(data_dict)
else:
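# non-matsciml models are called directly on the Atoms object and are
# expected to return a dict of tensors keyed by model output names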
output = self.task_module.forward(atoms)
if isinstance(self.task_module, MultiTaskLitModule):
# use a more complicated parser for multitasks
results = self.multitask_strategy(output, self.task_module)
self.results = results
else:
# add outputs to self.results as expected by ase
if "energy" in output:
self.results["energy"] = output["energy"].detach().item()
if "force" in output:
self.results["forces"] = output["force"].detach().numpy()
if "stress" in output:
self.results["stress"] = output["stress"].detach().numpy()
if "dipole" in output:
self.results["dipole"] = output["dipole"].detach().numpy()
# add outputs to self.results as expected by ase, for the requested
# ``properties``, translating names through ``self.output_map``
for ase_property in properties:
model_property = self.output_map[ase_property]
model_output = output.get(model_property, None)
if model_output is not None:
self.results[ase_property] = model_output.detach().numpy()
else:
raise KeyError(
f"Expected model to return {model_property} as an output."
)

if len(self.results) == 0:
raise RuntimeError(
f"No expected properties were written. Output dict: {output}"
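As a usage sketch of the new ``output_map`` option: the encoder settings below are placeholders, and the graph/PBC transform pipeline used by the tests is omitted for brevity, so treat this as illustrative rather than exact.

```python
from ase.build import molecule

from matsciml.interfaces.ase import MatSciMLCalculator
from matsciml.models.base import ForceRegressionTask
from matsciml.models.pyg import EGNN

# hypothetical encoder/output settings for illustration
task = ForceRegressionTask(
    encoder_class=EGNN,
    encoder_kwargs={"hidden_dim": 32},
    output_kwargs={"hidden_dim": 32},
)
# the task emits "force", while ASE looks up "forces"
calc = MatSciMLCalculator(task, output_map={"forces": "force"})

atoms = molecule("H2O")
atoms.calc = calc
energy = atoms.get_potential_energy()
forces = atoms.get_forces()
```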
53 changes: 51 additions & 2 deletions matsciml/interfaces/ase/tests/test_ase_calc.py
@@ -14,6 +14,12 @@
ForceRegressionTask,
)
from matsciml.models.pyg import EGNN
from types import MethodType

import matgl
import torch
from matgl.ext.ase import Atoms2Graph


np.random.seed(21516136)

@@ -48,7 +54,9 @@ def test_egnn_energy_forces(egnn_config: dict, test_pbc: Atoms, pbc_transform: list):
task = ForceRegressionTask(
encoder_class=EGNN, encoder_kwargs=egnn_config, output_kwargs={"hidden_dim": 32}
)
calc = MatSciMLCalculator(task, transforms=pbc_transform)
calc = MatSciMLCalculator(
task, transforms=pbc_transform, output_map={"forces": "force"}
)
atoms = test_pbc.copy()
atoms.calc = calc
energy = atoms.get_potential_energy()
@@ -62,8 +70,49 @@ def test_egnn_dynamics(egnn_config: dict, test_pbc: Atoms, pbc_transform: list):
task = ForceRegressionTask(
encoder_class=EGNN, encoder_kwargs=egnn_config, output_kwargs={"hidden_dim": 32}
)
calc = MatSciMLCalculator(task, transforms=pbc_transform)
calc = MatSciMLCalculator(
task, transforms=pbc_transform, output_map={"forces": "force"}
)
atoms = test_pbc.copy()
atoms.calc = calc
dyn = VelocityVerlet(atoms, timestep=5 * units.fs, logfile="md.log")
dyn.run(3)


def test_matgl():
matgl_model = matgl.load_model("CHGNet-MPtrj-2024.2.13-PES-11M")

def forward(self, atoms):
graph_converter = Atoms2Graph(
element_types=matgl_model.model.element_types,
cutoff=matgl_model.model.cutoff,
)
graph, lattice, state_feats_default = graph_converter.get_graph(atoms)
graph.edata["pbc_offshift"] = torch.matmul(
graph.edata["pbc_offset"], lattice[0]
)
graph.ndata["pos"] = graph.ndata["frac_coords"] @ lattice[0]
state_feats = torch.tensor(state_feats_default)
total_energies, forces, stresses, *others = self.matgl_forward(
graph, lattice, state_feats
)
output = {}
output["energy"] = total_energies
output["forces"] = forces
output["stress"] = stresses
return output

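# stash the original graph-based forward, then bind the Atoms-based wrapper in its place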
matgl_model.matgl_forward = matgl_model.forward
matgl_model.forward = MethodType(forward, matgl_model)

calc = MatSciMLCalculator(matgl_model, matsciml_model=False)
pos = np.random.normal(0.0, 1.0, size=(10, 3))
# Use a different atoms object because the pretrained model's atom
# embedding expects a different range of atomic numbers.
atomic_numbers = np.random.randint(1, 94, size=(10,))
atoms = Atoms(numbers=atomic_numbers, positions=pos)
atoms.calc = calc
energy = atoms.get_potential_energy()
assert np.isfinite(energy)
forces = atoms.get_forces()
assert np.isfinite(forces).all()
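For reference, a minimal sketch of the contract the ``matsciml_model=False`` path assumes: the wrapped object exposes a ``forward`` that takes an ``ase.Atoms`` and returns a dict of tensors whose keys line up with ``output_map``. The class and values here are hypothetical stand-ins.

```python
import torch
from ase import Atoms

from matsciml.interfaces.ase import MatSciMLCalculator


class ToyExternalModel:
    """Hypothetical stand-in for a model not trained with matsciml."""

    def forward(self, atoms: Atoms) -> dict[str, torch.Tensor]:
        n = len(atoms)
        # placeholder predictions; a real model would run inference here
        return {"energy": torch.zeros(1), "forces": torch.zeros(n, 3)}


calc = MatSciMLCalculator(ToyExternalModel(), matsciml_model=False)
atoms = Atoms(
    "H2O",
    positions=[[0.0, 0.0, 0.0], [0.76, 0.59, 0.0], [-0.76, 0.59, 0.0]],
)
atoms.calc = calc
energy = atoms.get_potential_energy()  # pulled from the model's "energy" output
forces = atoms.get_forces()  # resolved via output_map["forces"]
```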