Skip to content

Commit 32a5700

Browse files
miskolbluque
andauthored
add optional field to calculator to output only requested (#922)
fix lint undo change to packages undo change to packages Co-authored-by: Luis Barroso-Luque <[email protected]> Former-commit-id: 7a32302ac1427cf87dcbc2d69001d66345a2fcdb
1 parent 639f5bb commit 32a5700

File tree

4 files changed

+42
-1
lines changed

4 files changed

+42
-1
lines changed

src/fairchem/core/common/relaxation/ase_utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(
124124
trainer: str | None = None,
125125
cpu: bool = True,
126126
seed: int | None = None,
127+
only_output: list[str] | None = None,
127128
) -> None:
128129
"""
129130
OCP-ASE Calculator
@@ -209,6 +210,20 @@ def __init__(
209210
self.config["checkpoint"] = str(checkpoint_path)
210211
del config["dataset"]["src"]
211212

213+
# some models that are published have configs that include tasks
214+
# which are not output by the model
215+
if only_output is not None:
216+
assert isinstance(
217+
only_output, list
218+
), "only output must be a list of targets to output"
219+
for key in only_output:
220+
assert (
221+
key in config["outputs"]
222+
), f"{key} listed in only_outputs is not present in current model outputs {config['outputs'].keys()}"
223+
remove_outputs = set(config["outputs"].keys()) - set(only_output)
224+
for key in remove_outputs:
225+
config["outputs"].pop(key)
226+
212227
self.trainer = registry.get_trainer_class(config["trainer"])(
213228
task=config.get("task", {}),
214229
model=config["model"],

src/fairchem/core/trainers/ocp_trainer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ def _forward(self, batch):
301301
)
302302
else:
303303
raise AttributeError(
304-
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}"
304+
f"Output target: '{target_key}', not found in model outputs: {list(out.keys())}\n"
305+
+ "If this is being called from OCPCalculator consider using only_output=[..]"
305306
)
306307

307308
### not all models are consistent with the output shape
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# serializer version: 1
2+
# name: test_energy_with_is2re_model
3+
1.09
4+
# ---
25
# name: test_relaxation_final_energy
36
0.92
47
# ---

tests/core/common/test_ase_calculator.py

+22
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def atoms() -> Atoms:
3838
"PaiNN-S2EF-OC20-All",
3939
"GemNet-OC-Large-S2EF-OC20-All+MD",
4040
"SCN-S2EF-OC20-All+MD",
41+
"PaiNN-IS2RE-OC20-All",
4142
# Equiformer v2 # already tested in test_relaxation_final_energy
4243
# "EquiformerV2-153M-S2EF-OC20-All+MD"
4344
# eSCNm # already tested in test_random_seed_final_energy
@@ -54,6 +55,27 @@ def test_calculator_setup(checkpoint_path):
5455
_ = OCPCalculator(checkpoint_path=checkpoint_path, cpu=True)
5556

5657

58+
def test_energy_with_is2re_model(atoms, tmp_path, snapshot):
59+
random.seed(1)
60+
torch.manual_seed(1)
61+
62+
with pytest.raises(AttributeError): # noqa
63+
calc = OCPCalculator(
64+
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
65+
cpu=True,
66+
)
67+
atoms.set_calculator(calc)
68+
atoms.get_potential_energy()
69+
70+
calc = OCPCalculator(
71+
checkpoint_path=model_name_to_local_file("PaiNN-IS2RE-OC20-All", tmp_path),
72+
cpu=True,
73+
only_output=["energy"],
74+
)
75+
atoms.set_calculator(calc)
76+
assert snapshot == round(atoms.get_potential_energy(), 2)
77+
78+
5779
# test relaxation with EqV2
5880
def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None:
5981
random.seed(1)

0 commit comments

Comments
 (0)