Skip to content

Commit

Permalink
Add a couple more tests for lookup tables (#129)
Browse files Browse the repository at this point in the history
* start adding tests

* add more tests

* add some more tests
  • Loading branch information
lilyminium authored Jul 22, 2024
1 parent fe5d09b commit 515b336
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 0 deletions.
2 changes: 2 additions & 0 deletions openff/nagl/tests/data/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,5 @@

EXAMPLE_FEATURIZED_LAZY_DATA = data_directory / "cbe6f394311f594a9df33d7580e8b8478f0aef5b505f16f8b2f6af721a14e30d.arrow"
EXAMPLE_FEATURIZED_LAZY_DATA_SHORT = data_directory / "b6713a9ba87e89cb53d264256664bb9e4e4a5831f0c7660808e9f44a9f832ab5.arrow"

EXAMPLE_MODEL_RC3 = data_directory / "openff-gnn-am1bcc-0.1.0-rc.3.pt"
Binary file not shown.
30 changes: 30 additions & 0 deletions openff/nagl/tests/nn/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from openff.nagl.lookups import AtomPropertiesLookupTable, AtomPropertiesLookupTableEntry
from openff.nagl.tests.data.files import (
EXAMPLE_AM1BCC_MODEL,
EXAMPLE_MODEL_RC3
)
from openff.nagl.features.atoms import (
AtomicElement,
Expand Down Expand Up @@ -394,3 +395,32 @@ def test_outside_lookup_table(self, am1bcc_model):
[-0.738375, 0.246125, 0.246125, 0.246125],
atol=1e-5
)

class TestGNNModelRC3:

@pytest.fixture()
def model(self):
return GNNModel.load(EXAMPLE_MODEL_RC3, eval_mode=True)

def test_contains_lookup_tables(self, model):
assert "am1bcc_charges" in model.lookup_tables
assert len(model.lookup_tables) == 1
assert len(model.lookup_tables["am1bcc_charges"]) == 13944

@pytest.mark.parametrize("lookup, expected_charges", [
(True, [-0.10866 , 0.027165, 0.027165, 0.027165, 0.027165]),
(False, [-0.159474, 0.039869, 0.039869, 0.039869, 0.039869])
])
def test_compute_property(
self, model, openff_methane_uncharged, lookup, expected_charges
):
charges = model.compute_property(
openff_methane_uncharged,
as_numpy=True,
check_lookup_table=lookup,

)
assert charges.shape == (5,)
assert charges.dtype == np.float32

assert_allclose(charges, expected_charges, atol=1e-5)

0 comments on commit 515b336

Please sign in to comment.