
Commit ec91ac3

lbluque and zulissimeta authored
OptimizableBatch and stress relaxations (#718)
* remove r_edges, radius, max_neigh and add deprecation warning
* edit typing and dont use dicts as default
* use super() and remove overkill deprecation warning
* set implemented_properties from config
* make determine step a method
* allow calculator to operate on batches
* only update if old config is used
* reshape properties
* no test classes in ase calculator
* yaml load fix
* use mappingproxy
* expressive import
* remove duplicated code
* optimizable batch class for ase compatible batch relaxations
* fix optimizable batch
* optimizable goodies
* apply force constraints
* use optimizable batch instead and remove torchcalc
* update ml relaxations to use optimizable batch correctly
* force_consistent check for ASE compat
* force_consistent check for ASE compat
* check force_consistent
* init docs in lbfgs
* unitcellfilter for batch relaxations
* ruff
* UnitCellOptimizable as child class instead of filter
* allow running unit cell relaxations
* ruff
* no grad in run_relaxations
* make batched_dot and determine_step methods
* imports
* rename to optimizableunitcellbatch
* allow passing energy and forces explicitly to batch to atoms
* check convergence in optimizable and allow passing general results to atoms_from_batch
* relaxation test
* unit tests
* move update mask to optimizable
* use energy instead of y
* all setting/getting positions and convergence in optimizable
* more (unfinished) tests
* backwards compatible test
* minor fixes
* code cleanup
* add/fix tests
* fix lbfgs
* assert using norm
* add eps to masked batches if using ASE optimizers
* match iterations from previous implementation
* use float64 for forces
* float32
* use energy_relaxed instead of y_relaxed
* energy_relaxed and more explicit error msg
* default to batch_size 1 if not set in config
* keep float64 training
* rename y_relaxed -> energy_relaxed
* rm expcell batch
* convenience commit from no_experimental_resolve
* use numatoms tensor for cell factor
* remove positions tests (wrapping atoms gives different results)
* allow wrapping positions in batch to atoms
* fix test
* wrap_positions in batch_to_atoms
* take a2g properties from model
* test lbfgs traj writes
* remove comments
* use model generate graph
* fix cell_factor
* fix using model in ddp
* fix r_edges in OCPcalculator
* write initial and final structure if save_full is false
* check unique atoms saved in trajectory
* tighter tol
* update ASE release comment
* remove cumulative mask option
* remove left over cumulative_mask
* fix batching when sids as str
* do not try to fetch energy and forces if no explicit results
* accept Path objects
* clean up setting defaults
* expose ml_relax in relaxation
* force set r_pbc True
* make relax_opt optional
* no ema on inference only
* define ema none to avoid issues
* lower force threshold to make sure test does not converge
* clean up exception msg
* allow strings in batch
* remove device argument from lbfgs
* minor cleanup
* fix optimizable import
* do not pass device in ml_relax
* simplify enforce max neighbors
* fix tests (still not testing stress)
* pin sphinx autoapi
* typo in version

---------

Co-authored-by: zulissimeta <[email protected]>
Co-authored-by: Zack Ulissi <[email protected]>

Former-commit-id: 69f13a70241995cdbfab66b8ce1d1459aa10c229
1 parent a8c448b commit ec91ac3

17 files changed: +1068 -241 lines changed
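At its core, the change wraps a torch_geometric Batch in an ASE-style optimizable object so that whole batches of structures can be relaxed at once. The following is a minimal sketch of the new call pattern, mirroring the names in the ml_relaxation.py diff below; `batch` (a collated Batch) and `trainer` (a loaded fairchem trainer) are placeholders assumed to exist already, and this is not the only supported entry point.

    from fairchem.core.common.relaxation import OptimizableBatch
    from fairchem.core.common.relaxation.optimizers.lbfgs_torch import LBFGS

    # batch: collated torch_geometric Batch; trainer: loaded fairchem trainer.
    # Both are assumed to exist already; this only sketches the new flow.
    optimizable = OptimizableBatch(batch, trainer=trainer, mask_converged=True)

    # The torch LBFGS optimizer now takes the optimizable batch directly.
    optimizer = LBFGS(optimizable_batch=optimizable, save_full_traj=False)
    optimizer.run(fmax=0.02, steps=300)

    relaxed_batch = optimizable.batch  # positions are updated in place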

packages/fairchem-core/pyproject.toml
+1 -1

@@ -28,7 +28,7 @@ dependencies = [
 
 [project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev]
 dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
-docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi", "umap-learn", "vdict"]
+docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "umap-learn", "vdict"]
 adsorbml = ["dscribe","x3dase","scikit-image"]
 
 [project.scripts]
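This is the "pin sphinx autoapi" item from the commit message: the docs extra now pins sphinx-autoapi to 3.3.3. Following the comment in the file itself, these optional dependency groups install as `pip install fairchem.core[docs]` (the dev and adsorbml extras work the same way).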
src/fairchem/core/common/relaxation/__init__.py (new file)
+13 -0

@@ -0,0 +1,13 @@
+"""
+Copyright (c) Meta, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+from .ml_relaxation import ml_relax
+from .optimizable import OptimizableBatch, OptimizableUnitCellBatch
+
+__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"]
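This new package __init__ is the "expose ml_relax in relaxation" item: the relaxation API becomes importable from a single place, for example:

    from fairchem.core.common.relaxation import (
        OptimizableBatch,
        OptimizableUnitCellBatch,
        ml_relax,
    )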

src/fairchem/core/common/relaxation/ase_utils.py
+77 -31

@@ -14,13 +14,15 @@
 
 import copy
 import logging
-from typing import ClassVar
+from types import MappingProxyType
+from typing import TYPE_CHECKING
 
 import torch
 from ase import Atoms
 from ase.calculators.calculator import Calculator
-from ase.calculators.singlepoint import SinglePointCalculator as sp
+from ase.calculators.singlepoint import SinglePointCalculator
 from ase.constraints import FixAtoms
+from ase.geometry import wrap_positions
 
 from fairchem.core.common.registry import registry
 from fairchem.core.common.utils import (
@@ -33,51 +35,93 @@
 from fairchem.core.models.model_registry import model_name_to_local_file
 from fairchem.core.preprocessing import AtomsToGraphs
 
+if TYPE_CHECKING:
+    from pathlib import Path
 
-def batch_to_atoms(batch):
+    from torch_geometric.data import Batch
+
+
+# system level model predictions have different shapes than expected by ASE
+ASE_PROP_RESHAPE = MappingProxyType(
+    {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
+)
+
+
+def batch_to_atoms(
+    batch: Batch,
+    results: dict[str, torch.Tensor] | None = None,
+    wrap_pos: bool = True,
+    eps: float = 1e-7,
+) -> list[Atoms]:
+    """Convert a data batch to ase Atoms
+
+    Args:
+        batch: data batch
+        results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
+            are given no calculator will be added to the atoms objects.
+        wrap_pos: wrap positions back into the cell.
+        eps: Small number to prevent slightly negative coordinates from being wrapped.
+
+    Returns:
+        list of Atoms
+    """
     n_systems = batch.natoms.shape[0]
     natoms = batch.natoms.tolist()
     numbers = torch.split(batch.atomic_numbers, natoms)
     fixed = torch.split(batch.fixed.to(torch.bool), natoms)
-    forces = torch.split(batch.force, natoms)
+    if results is not None:
+        results = {
+            key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist()
+            if len(val) == len(batch)
+            else [v.cpu().detach().numpy() for v in torch.split(val, natoms)]
+            for key, val in results.items()
+        }
+
     positions = torch.split(batch.pos, natoms)
     tags = torch.split(batch.tags, natoms)
     cells = batch.cell
-    energies = batch.energy.view(-1).tolist()
 
     atoms_objects = []
     for idx in range(n_systems):
+        pos = positions[idx].cpu().detach().numpy()
+        cell = cells[idx].cpu().detach().numpy()
+
+        # TODO take pbc from data
+        if wrap_pos:
+            pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)
+
         atoms = Atoms(
             numbers=numbers[idx].tolist(),
-            positions=positions[idx].cpu().detach().numpy(),
+            cell=cell,
+            positions=pos,
             tags=tags[idx].tolist(),
-            cell=cells[idx].cpu().detach().numpy(),
             constraint=FixAtoms(mask=fixed[idx].tolist()),
             pbc=[True, True, True],
         )
-        calc = sp(
-            atoms=atoms,
-            energy=energies[idx],
-            forces=forces[idx].cpu().detach().numpy(),
-        )
-        atoms.set_calculator(calc)
+
+        if results is not None:
+            calc = SinglePointCalculator(
+                atoms=atoms, **{key: val[idx] for key, val in results.items()}
+            )
+            atoms.set_calculator(calc)
+
         atoms_objects.append(atoms)
 
     return atoms_objects
 
 
 class OCPCalculator(Calculator):
-    implemented_properties: ClassVar[list[str]] = ["energy", "forces"]
+    """ASE based calculator using an OCP model"""
+
+    _reshaped_props = ASE_PROP_RESHAPE
 
     def __init__(
         self,
         config_yml: str | None = None,
-        checkpoint_path: str | None = None,
+        checkpoint_path: str | Path | None = None,
         model_name: str | None = None,
         local_cache: str | None = None,
         trainer: str | None = None,
-        cutoff: int = 6,
-        max_neighbors: int = 50,
         cpu: bool = True,
         seed: int | None = None,
     ) -> None:
@@ -96,16 +140,12 @@ def __init__(
                 Directory to save pretrained model checkpoints.
             trainer (str):
                 OCP trainer to be used. "forces" for S2EF, "energy" for IS2RE.
-            cutoff (int):
-                Cutoff radius to be used for data preprocessing.
-            max_neighbors (int):
-                Maximum amount of neighbors to store for a given atom.
             cpu (bool):
                 Whether to load and run the model on CPU. Set `False` for GPU.
         """
         setup_imports()
         setup_logging()
-        Calculator.__init__(self)
+        super().__init__()
 
         if model_name is not None:
             if checkpoint_path is not None:
@@ -165,9 +205,8 @@ def __init__(
         ### backwards compatability with OCP v<2.0
         config = update_config(config)
 
-        # Save config so obj can be transported over network (pkl)
         self.config = copy.deepcopy(config)
-        self.config["checkpoint"] = checkpoint_path
+        self.config["checkpoint"] = str(checkpoint_path)
         del config["dataset"]["src"]
 
         self.trainer = registry.get_trainer_class(config["trainer"])(
@@ -199,14 +238,13 @@ def __init__(
         self.trainer.set_seed(seed)
 
         self.a2g = AtomsToGraphs(
-            max_neigh=max_neighbors,
-            radius=cutoff,
             r_energy=False,
             r_forces=False,
             r_distances=False,
-            r_edges=False,
             r_pbc=True,
+            r_edges=not self.trainer.model.otf_graph,  # otf graph should not be a property of the model...
        )
+        self.implemented_properties = list(self.config["outputs"].keys())
 
     def load_checkpoint(
         self, checkpoint_path: str, checkpoint: dict | None = None
@@ -217,6 +255,8 @@ def load_checkpoint(
         Args:
             checkpoint_path: string
                 Path to trained model
+            checkpoint: dict
+                A pretrained checkpoint dict
         """
         try:
             self.trainer.load_checkpoint(
@@ -225,14 +265,20 @@ def load_checkpoint(
         except NotImplementedError:
             logging.warning("Unable to load checkpoint!")
 
-    def calculate(self, atoms: Atoms, properties, system_changes) -> None:
-        Calculator.calculate(self, atoms, properties, system_changes)
-        data_object = self.a2g.convert(atoms)
-        batch = data_list_collater([data_object], otf_graph=True)
+    def calculate(self, atoms: Atoms | Batch, properties, system_changes) -> None:
+        """Calculate implemented properties for a single Atoms object or a Batch of them."""
+        super().calculate(atoms, properties, system_changes)
+        if isinstance(atoms, Atoms):
+            data_object = self.a2g.convert(atoms)
+            batch = data_list_collater([data_object], otf_graph=True)
+        else:
+            batch = atoms
 
         predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True)
 
         for key in predictions:
             _pred = predictions[key]
             _pred = _pred.item() if _pred.numel() == 1 else _pred.cpu().numpy()
+            if key in OCPCalculator._reshaped_props:
+                _pred = _pred.reshape(OCPCalculator._reshaped_props.get(key)).squeeze()
             self.results[key] = _pred
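Since implemented_properties is now read from the model config and calculate() accepts a Batch as well as a single Atoms object, the calculator behaves like any other ASE calculator. A usage sketch under stated assumptions: the checkpoint path is a placeholder, and the checkpoint's configured outputs are assumed to include energy and forces.

    from ase.build import fcc111
    from fairchem.core.common.relaxation.ase_utils import OCPCalculator

    # "/path/to/checkpoint.pt" is a placeholder for any S2EF checkpoint.
    calc = OCPCalculator(checkpoint_path="/path/to/checkpoint.pt", cpu=True, seed=0)

    atoms = fcc111("Cu", size=(2, 2, 3), vacuum=8.0)
    atoms.calc = calc
    energy = atoms.get_potential_energy()  # works because "energy" is in the configured outputs
    forces = atoms.get_forces()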

src/fairchem/core/common/relaxation/ml_relaxation.py
+79 -38

@@ -10,85 +10,126 @@
 import logging
 from collections import deque
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import torch
 from torch_geometric.data import Batch
 
 from fairchem.core.common.typing import assert_is_instance
 from fairchem.core.datasets.lmdb_dataset import data_list_collater
 
-from .optimizers.lbfgs_torch import LBFGS, TorchCalc
+from .optimizable import OptimizableBatch, OptimizableUnitCellBatch
+from .optimizers.lbfgs_torch import LBFGS
+
+if TYPE_CHECKING:
+    from fairchem.core.trainers import BaseTrainer
 
 
 def ml_relax(
-    batch,
-    model,
+    batch: Batch,
+    model: BaseTrainer,
     steps: int,
     fmax: float,
-    relax_opt,
-    save_full_traj,
-    device: str = "cuda:0",
-    transform=None,
-    early_stop_batch: bool = False,
+    relax_opt: dict[str] | None = None,
+    relax_cell: bool = False,
+    relax_volume: bool = False,
+    save_full_traj: bool = True,
+    transform: torch.nn.Module | None = None,
+    mask_converged: bool = True,
 ):
-    """
-    Runs ML-based relaxations.
+    """Runs ML-based relaxations.
+
     Args:
-        batch: object
-        model: object
-        steps: int
-            Max number of steps in the structure relaxation.
-        fmax: float
-            Structure relaxation terminates when the max force
-            of the system is no bigger than fmax.
-        relax_opt: str
-            Optimizer and corresponding parameters to be used for structure relaxations.
-        save_full_traj: bool
-            Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
+        batch: a data batch object.
+        model: a trainer object with model.
+        steps: Max number of steps in the structure relaxation.
+        fmax: Structure relaxation terminates when the max force of the system is no bigger than fmax.
+        relax_opt: Optimizer parameters to be used for structure relaxations.
+        relax_cell: if true will use stress predictions to relax crystallographic cell.
+            The model given must predict stress
+        relax_volume: if true will relax the cell isotropically. the given model must predict stress.
+        save_full_traj: Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
+        mask_converged: whether to mask batches where all atoms are below convergence threshold
+        cumulative_mask: if true, once system is masked then it remains masked even if new predictions give forces
+            above threshold, ie. once masked always masked. Note if this is used make sure to check convergence with
+            the same fmax always
     """
+    relax_opt = relax_opt or {}
+    # if not pbc is set, ignore it when comparing batches
+    if not hasattr(batch, "pbc"):
+        OptimizableBatch.ignored_changes = {"pbc"}
+
     batches = deque([batch])
     relaxed_batches = []
     while batches:
         batch = batches.popleft()
         oom = False
         ids = batch.sid
-        calc = TorchCalc(model, transform)
+
+        # clone the batch otherwise you can not run batch.to_data_list
+        # see https://github.com/pyg-team/pytorch_geometric/issues/8439#issuecomment-1826747915
+        if relax_cell or relax_volume:
+            optimizable = OptimizableUnitCellBatch(
+                batch.clone(),
+                trainer=model,
+                transform=transform,
+                mask_converged=mask_converged,
+                hydrostatic_strain=relax_volume,
+            )
+        else:
+            optimizable = OptimizableBatch(
+                batch.clone(),
+                trainer=model,
+                transform=transform,
+                mask_converged=mask_converged,
+            )
 
         # Run ML-based relaxation
-        traj_dir = relax_opt.get("traj_dir", None)
+        traj_dir = relax_opt.get("traj_dir")
+        relax_opt.update({"traj_dir": Path(traj_dir) if traj_dir is not None else None})
+
         optimizer = LBFGS(
-            batch,
-            calc,
-            maxstep=relax_opt.get("maxstep", 0.2),
-            memory=relax_opt["memory"],
-            damping=relax_opt.get("damping", 1.2),
-            alpha=relax_opt.get("alpha", 80.0),
-            device=device,
+            optimizable_batch=optimizable,
             save_full_traj=save_full_traj,
-            traj_dir=Path(traj_dir) if traj_dir is not None else None,
             traj_names=ids,
-            early_stop_batch=early_stop_batch,
+            **relax_opt,
         )
 
         e: RuntimeError | None = None
         try:
-            relaxed_batch = optimizer.run(fmax=fmax, steps=steps)
-            relaxed_batches.append(relaxed_batch)
+            optimizer.run(fmax=fmax, steps=steps)
+            relaxed_batches.append(optimizable.batch)
         except RuntimeError as err:
             e = err
             oom = True
             torch.cuda.empty_cache()
 
         if oom:
-            # move OOM recovery code outside of except clause to allow tensors to be freed.
+            # move OOM recovery code outside off except clause to allow tensors to be freed.
             data_list = batch.to_data_list()
             if len(data_list) == 1:
                 raise assert_is_instance(e, RuntimeError)
             logging.info(
                 f"Failed to relax batch with size: {len(data_list)}, splitting into two..."
             )
             mid = len(data_list) // 2
-            batches.appendleft(data_list_collater(data_list[:mid]))
-            batches.appendleft(data_list_collater(data_list[mid:]))
+            batches.appendleft(
+                data_list_collater(data_list[:mid], otf_graph=optimizable.otf_graph)
+            )
+            batches.appendleft(
+                data_list_collater(data_list[mid:], otf_graph=optimizable.otf_graph)
+            )
+
+    # reset for good measure
+    OptimizableBatch.ignored_changes = {}
+
+    relaxed_batch = Batch.from_data_list(relaxed_batches)
+
+    # Batch.from_data_list is not intended to be used with a list of batches, so when sid is a list of str
+    # it will be incorrectly collated as a list of lists for each batch.
+    # but we can not use to_data_list in the relaxed batches (since they have been changed, see linked comment above).
+    # So instead just manually fix it for now. Remove this once pyg dependency is removed
+    if isinstance(relaxed_batch.sid, list):
+        relaxed_batch.sid = [sid for sid_list in relaxed_batch.sid for sid in sid_list]
 
-    return Batch.from_data_list(relaxed_batches)
+    return relaxed_batch