diff --git a/environment.yml b/environment.yml index cd3a70c..2a15530 100644 --- a/environment.yml +++ b/environment.yml @@ -8,9 +8,11 @@ dependencies: - openff-units - pip - tqdm + - pyyaml # for testing - coverage - pooch - pytest - pytest-cov - pytest-xdist + - pytest-rerunfailures diff --git a/pyproject.toml b/pyproject.toml index 0bff04d..0efdae1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,9 @@ dependencies = [ 'numpy', 'netCDF4', 'openff-units', + 'pyyaml', ] -description="Analysis of free energy calculations." +description="Trajectory analysis of free energy calculations." readme="README.md" requires-python = ">=3.10" classifiers = [ @@ -29,11 +30,11 @@ classifiers = [ [tool.setuptools] zip-safe = false +include-package-data = true license-files = ["LICENSE"] [tool.setuptools.packages.find] where = ["src"] -include = ["openfe_analysis"] [project.optional-dependencies] test = [ diff --git a/src/openfe_analysis/__init__.py b/src/openfe_analysis/__init__.py index a6cfa64..ed37028 100644 --- a/src/openfe_analysis/__init__.py +++ b/src/openfe_analysis/__init__.py @@ -1,6 +1,5 @@ from ._version import __version__ -from . import handle_trajectories from .reader import FEReader from .transformations import ( NoJump, diff --git a/src/openfe_analysis/reader.py b/src/openfe_analysis/reader.py index 5eebf14..b27eec0 100644 --- a/src/openfe_analysis/reader.py +++ b/src/openfe_analysis/reader.py @@ -1,42 +1,63 @@ from MDAnalysis.coordinates.base import ReaderBase, Timestep import netCDF4 as nc from openff.units import unit +import numpy as np +import yaml from typing import Optional -from . import handle_trajectories +from openfe_analysis.utils import multistate, serialization +from openfe_analysis.utils.multistate import _determine_position_indices -def _determine_dt(ds) -> float: - # first grab integrator timestep - mcmc_move_data = ds.groups['mcmc_moves']['move0'][0].split('\n') - in_timestep = False - for line in mcmc_move_data: - if line.startswith('timestep'): - in_timestep = True - if in_timestep and line.strip().startswith('value'): - timestep = float(line.split()[-1]) / 1000. # convert to ps - break - else: - raise ValueError("Didn't find timestep") - # next get the save interval - option_data = ds.variables['options'][0].split('\n') - for line in option_data: - if line.startswith('online_analysis_interval'): - nsteps = float(line.split()[-1]) - break - else: - raise ValueError("Didn't find online_analysis_interval") +def _determine_iteration_dt(dataset) -> float: + """ + Find out the timestep between each frame in the trajectory. + + Parameters + ---------- + dataset : nc.Dataset + Dataset holding the MultiStateReporter generated NetCDF file. + + Returns + ------- + float + The timestep in units of picoseconds. + + Raises + ------ + KeyError + If either `timestep` or `n_steps` cannot be found in the + zeroth MCMC move. + + Notes + ----- + This assumes an MCMC move which serializes in a manner similar + to `openmmtools.mcmc.LangevinDynamicsMove`, i.e. it must have + both a `timestep` and `n_steps` defined. + """ + # Deserialize the MCMC move information for the 0th entry. + mcmc_move_data = yaml.load( + dataset.groups['mcmc_moves']['move0'][0], + Loader=serialization.UnitedYamlLoader, + ) + + try: + dt = mcmc_move_data['n_steps'] * mcmc_move_data['timestep'] + except KeyError: + msg = "Either `n_steps` or `timestep` are missing from the MCMC move" + raise KeyError(msg) - return timestep * nsteps + return dt.to('picosecond').m class FEReader(ReaderBase): - """A MDAnalysis Reader for nc files created by openfe RFE Protocol + """A MDAnalysis Reader for NetCDF files created by + `openmmtools.multistate.MultiStateReporter` - Looks along a multistate .nc file along one of two axes: - - constant state/lambda (varying replica) - - constant replica (varying lambda) + Looks along a multistate NetCDF file along one of two axes: + - constant state/lambda (varying replica) + - constant replica (varying lambda) """ _state_id: Optional[int] _replica_id: Optional[int] @@ -44,23 +65,42 @@ class FEReader(ReaderBase): _dataset: nc.Dataset _dataset_owner: bool - format = 'openfe RFE' + format = 'MultiStateReporter' - def __init__(self, filename, convert_units=True, **kwargs): + units = { + 'time': 'ps', + 'length': 'nanometer' + } + + def __init__( + self, filename, convert_units=True, + state_id=None, replica_id=None, **kwargs + ): """ Parameters ---------- filename : pathlike or nc.Dataset path to the .nc file convert_units : bool - convert positions to A + convert positions to Angstrom + state_id : Optional[int] + The Hamiltonian state index to extract. Must be defined if + ``replica_id`` is not defined. May be negative (see notes below). + replica_id : Optional[int] + The replica index to extract. Must be defined if ``state_id`` + is not defined. May be negative (see notes below). + + Notes + ----- + A negative index may be passed to either ``state_id`` or + ``replica_id``. This will be interpreted as indexing in reverse + starting from the last state/replica. For example, passing a + value of -2 for ``replica_id`` will select the before last replica. """ - self._state_id = kwargs.pop('state_id', None) - self._replica_id = kwargs.pop('replica_id', None) - if not ((self._state_id is None) ^ (self._replica_id is None)): + if not ((state_id is None) ^ (replica_id is None)): raise ValueError("Specify one and only one of state or replica, " - f"got state id={self._state_id} " - f"replica_id={self._replica_id}") + f"got state id={state_id} " + f"replica_id={replica_id}") super().__init__(filename, convert_units, **kwargs) @@ -70,9 +110,23 @@ def __init__(self, filename, convert_units=True, **kwargs): else: self._dataset = nc.Dataset(filename) self._dataset_owner = True + + # Handle the negative ID case + if state_id is not None and state_id < 0: + state_id = range(self._dataset.dimensions['state'].size)[state_id] + + if replica_id is not None and replica_id < 0: + replica_id = range(self._dataset.dimensions['replica'].size)[replica_id] + + self._state_id = state_id + self._replica_id = replica_id + self._n_atoms = self._dataset.dimensions['atom'].size self.ts = Timestep(self._n_atoms) - self._dt = _determine_dt(self._dataset) + self._frames = _determine_position_indices(self._dataset) + # The MDAnalysis trajectory "dt" is the iteration dt + # multiplied by the number of iterations between frames. + self._dt = _determine_iteration_dt(self._dataset) * np.diff(self._frames)[0] self._read_frame(0) @staticmethod @@ -86,13 +140,12 @@ def n_atoms(self) -> int: @property def n_frames(self) -> int: - return self._dataset.dimensions['iteration'].size + return len(self._frames) @staticmethod def parse_n_atoms(filename, **kwargs) -> int: with nc.Dataset(filename) as ds: n_atoms = ds.dimensions['atom'].size - return n_atoms def _read_next_timestep(self, ts=None) -> Timestep: @@ -104,24 +157,39 @@ def _read_frame(self, frame: int) -> Timestep: self._frame_index = frame if self._state_id is not None: - rep = handle_trajectories._state_to_replica( + rep = multistate._state_to_replica( self._dataset, self._state_id, - self._frame_index + self._frames[self._frame_index] ) else: rep = self._replica_id - pos = handle_trajectories._replica_positions_at_frame( + pos = multistate._replica_positions_at_frame( self._dataset, rep, - self._frame_index) - dim = handle_trajectories._get_unitcell( + self._frames[self._frame_index] + ) + dim = multistate._get_unitcell( self._dataset, rep, - self._frame_index) + self._frames[self._frame_index] + ) + + if pos is None: + errmsg = ( + "NetCDF dataset frame without positions was accessed " + "this likely indicates that the reader failed to work out " + "the write frequency and there is a deeper issue with how " + "this file was written." + ) + raise RuntimeError(errmsg) - self.ts.positions = (pos.to(unit.angstrom)).m + # Convert to base MDAnalysis distance units (Angstrom) if requested + if self.convert_units: + self.ts.positions = (pos.to(unit.angstrom)).m + else: + self.ts.positions = pos.m self.ts.dimensions = dim self.ts.frame = self._frame_index self.ts.time = self._frame_index * self._dt diff --git a/src/openfe_analysis/tests/conftest.py b/src/openfe_analysis/tests/conftest.py new file mode 100644 index 0000000..ad6c909 --- /dev/null +++ b/src/openfe_analysis/tests/conftest.py @@ -0,0 +1,58 @@ +from importlib import resources +import pooch +import pytest + + +RFE_OUTPUT = pooch.create( + path=pooch.os_cache("openfe_analysis"), + base_url="doi:10.6084/m9.figshare.24101655", + registry={ + "checkpoint.nc": "5af398cb14340fddf7492114998b244424b6c3f4514b2e07e4bd411484c08464", + "db.json": "b671f9eb4daf9853f3e1645f9fd7c18150fd2a9bf17c18f23c5cf0c9fd5ca5b3", + "hybrid_system.pdb": "07203679cb14b840b36e4320484df2360f45e323faadb02d6eacac244fddd517", + "simulation.nc": "92361a0864d4359a75399470135f56642b72c605069a4c33dbc4be6f91f28b31", + "simulation_real_time_analysis.yaml": "65706002f371fafba96037f29b054fd7e050e442915205df88567f48f5e5e1cf", + } +) + + +RFE_OUTPUT_skipped_frames = pooch.create( + path=pooch.os_cache("openfe_analysis_skipped"), + base_url="doi:10.6084/m9.figshare.28263203", + registry={ + "hybrid_system.pdb": "77c7914b78724e568f38d5a308d36923f5837c03a1d094e26320b20aeec65fee", + "simulation.nc": "6749e2c895f16b7e4eba196261c34756a0a062741d36cc74925676b91a36d0cd", + } +) + + +@pytest.fixture(scope='session') +def simulation_nc(): + return RFE_OUTPUT.fetch("simulation.nc") + + +@pytest.fixture(scope='session') +def simulation_skipped_nc(): + return RFE_OUTPUT_skipped_frames.fetch("simulation.nc") + + +@pytest.fixture(scope='session') +def hybrid_system_pdb(): + return RFE_OUTPUT.fetch("hybrid_system.pdb") + + +@pytest.fixture(scope='session') +def hybrid_system_skipped_pdb(): + return RFE_OUTPUT_skipped_frames.fetch("hybrid_system.pdb") + + +@pytest.fixture(scope='session') +def mcmc_serialized(): + return ( + '_serialized__class_name: LangevinDynamicsMove\n' + '_serialized__module_name: openmmtools.mcmc\n' + 'collision_rate: !Quantity\n unit: /picosecond\n value: 1\n' + 'constraint_tolerance: 1.0e-06\nn_restart_attempts: 20\n' + 'n_steps: 625\nreassign_velocities: false\n' + 'timestep: !Quantity\n unit: femtosecond\n value: 4\n' + ) diff --git a/src/openfe_analysis/tests/data/__init__.py b/src/openfe_analysis/tests/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/openfe_analysis/tests/test_reader.py b/src/openfe_analysis/tests/test_reader.py new file mode 100644 index 0000000..98e20c5 --- /dev/null +++ b/src/openfe_analysis/tests/test_reader.py @@ -0,0 +1,183 @@ +import MDAnalysis as mda +from openfe_analysis.reader import FEReader, _determine_iteration_dt +import netCDF4 as nc +from numpy.testing import assert_allclose +import numpy as np +import pytest + + +def test_determine_dt(tmpdir, mcmc_serialized): + with tmpdir.as_cwd(): + # create a fake dataset with a fake mcmc move group + ds = nc.Dataset('foo', 'w', format='NETCDF3_64BIT_OFFSET') + ds.groups['mcmc_moves'] = { + 'move0': [mcmc_serialized] + } + + assert _determine_iteration_dt(ds) == 2.5 + + +def test_determine_dt_keyerror(tmpdir, mcmc_serialized): + with tmpdir.as_cwd(): + # create a fake dataset with fake mcmc move without timestep + ds = nc.Dataset('foo', 'w', format='NETCDF3_64BIT_OFFSET') + ds.groups['mcmc_moves'] = { + 'move0': [mcmc_serialized[:-51]] + } + + with pytest.raises(KeyError, match="`n_steps` or `timestep` are"): + _ = _determine_iteration_dt(ds) + + +def test_universe_creation(simulation_nc, hybrid_system_pdb): + with pytest.warns(UserWarning, match='This is an older NetCDF file that'): + u = mda.Universe(hybrid_system_pdb, simulation_nc, + format=FEReader, state_id=0) + + # Check that a Universe exists + assert u + + # Test the basics + assert len(u.atoms) == 4782 + assert len(u.trajectory) == 501 + assert u.trajectory.dt == pytest.approx(1.0) + assert u.trajectory.time == pytest.approx(0.0) + assert u.trajectory.totaltime == pytest.approx(500) + + # Check the dimensions & positions of the first frame + # Note: we multipy the positions by 10 since it's stored in nm + assert_allclose( + u.atoms[:3].positions, + np.array([[6.51474, -1.7640617, 8.406607], + [6.641961, -1.8410535, 8.433087], + [6.71369, -1.8112476, 8.533738]]) * 10, + ) + assert_allclose( + u.dimensions, + [82.06851, 82.06851, 82.06851, 90., 90., 90.] + ) + + # Now check the second frame + u.trajectory[1] + assert u.trajectory.time == pytest.approx(1.0) + assert_allclose( + u.atoms[4:7].positions, + np.array([[6.78754, -1.2783755, 8.433636], + [6.62524, -1.333609, 8.399696], + [6.744502, -1.5663723, 8.332421]]) * 10, + ) + assert_allclose( + u.dimensions, + [82.191055, 82.191055, 82.191055, 90., 90., 90.] + ) + + # Now check the last frame + u.trajectory[-1] + assert u.trajectory.time == pytest.approx(500.0) + assert_allclose( + u.atoms[-3:].positions, + np.array( + [[2.9948092, 7.7675443, 0.19704354], + [0.95652354, 2.99566, 1.3466661], + [4.0027137, 4.695961, 3.6892936]] + ) * 10, + ) + assert_allclose( + u.dimensions, + [82.12723, 82.12723, 82.12723, 90., 90., 90.] + ) + + # Finally we rewind to the second frame to make sure that's possible + u.trajectory[1] + assert u.trajectory.time == pytest.approx(1.0) + assert_allclose( + u.atoms[4:7].positions, + np.array([[6.78754, -1.2783755, 8.433636], + [6.62524, -1.333609, 8.399696], + [6.744502, -1.5663723, 8.332421]]) * 10, + ) + assert_allclose( + u.dimensions, + [82.191055, 82.191055, 82.191055, 90., 90., 90.] + ) + + +def test_universe_from_nc_file(simulation_nc, hybrid_system_pdb): + ds = nc.Dataset(simulation_nc) + + with pytest.warns(UserWarning, match='This is an older NetCDF file that'): + u = mda.Universe(hybrid_system_pdb, ds, + format='MultiStateReporter', state_id=0) + + assert u + assert len(u.atoms) == 4782 + assert len(u.trajectory) == 501 + assert u.trajectory.dt == pytest.approx(1.0) + + +def test_universe_creation_noconversion(simulation_nc, hybrid_system_pdb): + with pytest.warns(UserWarning, match='This is an older NetCDF file that'): + u = mda.Universe(hybrid_system_pdb, simulation_nc, + format=FEReader, state_id=0, convert_units=False) + + assert_allclose( + u.atoms[:3].positions, + np.array([[6.51474, -1.7640617, 8.406607], + [6.641961, -1.8410535, 8.433087], + [6.71369, -1.8112476, 8.533738]]), + ) + + +def test_fereader_negative_state(simulation_nc, hybrid_system_pdb): + with pytest.warns(UserWarning, match='This is an older NetCDF file that'): + u = mda.Universe( + hybrid_system_pdb, simulation_nc, format=FEReader, + state_id=-1 + ) + + assert u.trajectory._state_id == 10 + assert u.trajectory._replica_id is None + + +def test_fereader_negative_replica(simulation_nc, hybrid_system_pdb): + with pytest.warns(UserWarning, match='This is an older NetCDF file that'): + u = mda.Universe( + hybrid_system_pdb, simulation_nc, format=FEReader, + replica_id=-2 + ) + + assert u.trajectory._state_id is None + assert u.trajectory._replica_id == 9 + + +@pytest.mark.parametrize('rep_id, state_id', [ + [None, None], + [1, 1] +]) +@pytest.mark.flaky(reruns=3) +def test_fereader_replica_state_id_error( + simulation_nc, hybrid_system_pdb, rep_id, state_id +): + with pytest.raises(ValueError, match="Specify one and only one"): + _ = mda.Universe( + hybrid_system_pdb, simulation_nc, format=FEReader, + state_id=state_id, replica_id=rep_id + ) + + +def test_simulation_skipped_nc( + simulation_skipped_nc, hybrid_system_skipped_pdb +): + u = mda.Universe( + hybrid_system_skipped_pdb, simulation_skipped_nc, + format=FEReader, replica_id=0, + ) + + assert len(u.trajectory) == 6 + assert u.trajectory.n_frames == 6 + times = [0, 100, 200, 300, 400, 500] + for inx, ts in enumerate(u.trajectory): + assert ts.time == times[inx] + assert np.all(u.atoms.positions > 0) + with pytest.raises(mda.exceptions.NoDataError, match='This Timestep has no velocities'): + u.atoms.velocities diff --git a/src/tests/test_transformations.py b/src/openfe_analysis/tests/test_transformations.py similarity index 93% rename from src/tests/test_transformations.py rename to src/openfe_analysis/tests/test_transformations.py index 9f63a9b..327b8de 100644 --- a/src/tests/test_transformations.py +++ b/src/openfe_analysis/tests/test_transformations.py @@ -15,11 +15,12 @@ def universe(hybrid_system_pdb, simulation_nc): return mda.Universe( hybrid_system_pdb, simulation_nc, - format='openfe rfe', + format='MultiStateReporter', state_id=0, ) +@pytest.mark.flaky(reruns=3) def test_minimiser(universe): prot = universe.select_atoms('protein and name CA') lig = universe.select_atoms('resname UNK') @@ -34,6 +35,7 @@ def test_minimiser(universe): assert d == pytest.approx(11.10, abs=0.01) +@pytest.mark.flaky(reruns=3) def test_nojump(universe): # find frame where protein would teleport across boundary and check it prot = universe.select_atoms('protein and name CA') @@ -49,6 +51,7 @@ def test_nojump(universe): assert prot.center_of_mass() == pytest.approx(ref, abs=0.01) +@pytest.mark.flaky(reruns=3) def test_aligner(universe): # checks that rmsd is identical with/without center&super prot = universe.select_atoms('protein and name CA') diff --git a/src/openfe_analysis/tests/utils/__init__.py b/src/openfe_analysis/tests/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/openfe_analysis/tests/utils/test_multistate.py b/src/openfe_analysis/tests/utils/test_multistate.py new file mode 100644 index 0000000..dfba318 --- /dev/null +++ b/src/openfe_analysis/tests/utils/test_multistate.py @@ -0,0 +1,98 @@ +import netCDF4 as nc +import numpy as np +from openff.units import unit +from numpy.testing import assert_allclose +import pytest +from openfe_analysis import __version__ +from openfe_analysis.utils.multistate import ( + _state_to_replica, + _replica_positions_at_frame, + _create_new_dataset, + _get_unitcell, +) + + +@pytest.fixture(scope='module') +def dataset(simulation_nc): + return nc.Dataset(simulation_nc) + + +@pytest.mark.flaky(reruns=3) +@pytest.mark.parametrize('state, frame, replica', [ + [0, 0, 0], + [0, 1, 3], + [0, -1, 7], + [3, 100, 6] +]) +def test_state_to_replica(dataset, state, frame, replica): + assert _state_to_replica(dataset, state, frame) == replica + + +@pytest.mark.flaky(reruns=3) +def test_replica_positions_at_frame(dataset): + pos = _replica_positions_at_frame(dataset, 1, -1) + assert_allclose( + pos[-3] * unit('nanometer'), + np.array([0.6037003, 7.2835016, 5.804355]) * unit('nanometer') + ) + + +def test_create_new_dataset(tmpdir): + with tmpdir.as_cwd(): + ds = _create_new_dataset('foo.nc', 100, title='bar') + + # Test metadata + assert ds.Conventions == 'AMBER' + assert ds.ConventionVersion == '1.0' + assert ds.application == 'openfe_analysis' + assert ds.program == f"openfe_analysis {__version__}" + assert ds.programVersion == f"{__version__}" + assert ds.title == 'bar' + + # Test dimensions + assert ds.dimensions['frame'].size == 0 + assert ds.dimensions['spatial'].size == 3 + assert ds.dimensions['atom'].size == 100 + assert ds.dimensions['cell_spatial'].size == 3 + assert ds.dimensions['cell_angular'].size == 3 + assert ds.dimensions['label'].size == 5 + + # Test variables + assert ds.variables['coordinates'].units == 'angstrom' + assert ds.variables['coordinates'].get_dims()[0].name == 'frame' + assert ds.variables['coordinates'].get_dims()[1].name == 'atom' + assert ds.variables['coordinates'].get_dims()[2].name == 'spatial' + assert ds.variables['coordinates'].dtype.name == 'float32' + + assert ds.variables['cell_lengths'].units == 'angstrom' + assert ds.variables['cell_lengths'].get_dims()[0].name == 'frame' + assert ds.variables['cell_lengths'].get_dims()[1].name == 'cell_spatial' + assert ds.variables['cell_lengths'].dtype.name == 'float64' + + assert ds.variables['cell_angles'].units == 'degree' + assert ds.variables['cell_angles'].get_dims()[0].name == 'frame' + assert ds.variables['cell_angles'].get_dims()[1].name == 'cell_angular' + assert ds.variables['cell_angles'].dtype.name == 'float64' + + +def test_get_unitcell(dataset): + dims = _get_unitcell(dataset, 7, -1) + assert_allclose( + dims, + [82.12723, 82.12723, 82.12723, 90., 90., 90.] + ) + + dims = _get_unitcell(dataset, 3, 1) + assert_allclose( + dims, + [82.191055, 82.191055, 82.191055, 90., 90., 90.] + ) + + +def test_simulation_skipped_nc_no_positions_box_vectors_frame1( + simulation_skipped_nc, +): + dataset = nc.Dataset(simulation_skipped_nc) + + assert _get_unitcell(dataset, 1, 1) is None + assert dataset.variables['positions'][1][0].mask.all() diff --git a/src/openfe_analysis/tests/utils/test_serialization.py b/src/openfe_analysis/tests/utils/test_serialization.py new file mode 100644 index 0000000..bf4e95f --- /dev/null +++ b/src/openfe_analysis/tests/utils/test_serialization.py @@ -0,0 +1,37 @@ +import pytest +import yaml +from openfe_analysis.utils.serialization import ( + omm_quantity_string_to_offunit, + UnitedYamlLoader, +) +from openff.units import unit + + +@pytest.mark.parametrize('expression, expected', [ + ['/ picosecond', 1 / unit('picosecond')], + ['5 kilocalorie / mole', 5 * unit('kilocalorie_per_mole')], + ['4 femtosecond', 4 * unit('femtosecond')], +]) +def test_quantity_string_to_offunit(expression, expected): + retval = omm_quantity_string_to_offunit(expression,) + + assert retval == expected + + +def test_unitedyamlloader(mcmc_serialized): + data = yaml.load(mcmc_serialized, Loader=UnitedYamlLoader) + + expected = { + '_serialized__class_name': 'LangevinDynamicsMove', + '_serialized__module_name': 'openmmtools.mcmc', + 'collision_rate': 1 / unit.picosecond, + 'constraint_tolerance': 1.0e-06, + 'n_restart_attempts': 20, + 'n_steps': 625, + 'reassign_velocities': False, + 'timestep': 4 * unit.femtosecond, + } + + assert data.keys() == expected.keys() + for key in data.keys(): + assert data[key] == expected[key] diff --git a/src/openfe_analysis/utils/__init__.py b/src/openfe_analysis/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/openfe_analysis/handle_trajectories.py b/src/openfe_analysis/utils/multistate.py similarity index 67% rename from src/openfe_analysis/handle_trajectories.py rename to src/openfe_analysis/utils/multistate.py index 9ea3923..27d91f1 100644 --- a/src/openfe_analysis/handle_trajectories.py +++ b/src/openfe_analysis/utils/multistate.py @@ -1,15 +1,61 @@ import netCDF4 as nc import numpy as np +from numpy.typing import NDArray from pathlib import Path +import warnings from openff.units import unit -from typing import Optional +from typing import Optional, Tuple -from . import __version__ +from openfe_analysis import __version__ + + +def _determine_position_indices(dataset: nc.Dataset) -> NDArray: + """ + Determine which iteration indices hold positions. + + Parameters + ---------- + dataset : nc.Dataset + Dataset holding the MultiStateReporter generated NetCDF file. + + Returns + ------- + indices : NDArray[int] + An ordered array of iteration indices which hold positions. + + Note + ---- + This assumes that the indices are equally spaced by a given + value. + """ + if hasattr(dataset, 'PositionInterval'): + indices = [ + i for i in + range(0, dataset.dimensions['iteration'].size, dataset.PositionInterval) + ] + else: + wmsg = ('This is an older NetCDF file that does not yet contain ' + 'information about the write frequency of positions and ' + 'velocities. We will assume that positions and velocities ' + 'were written out at every iteration. ') + warnings.warn(wmsg) + indices = [i for i in range(0, dataset.dimensions['iteration'].size)] + + indices = np.array(indices) + + if not all(np.diff(indices) == np.diff(indices)[0]): + errmsg = ( + "Positions are not written at a consistent frame rate, " + "this is not currently supported" + ) + raise ValueError(errmsg) + + return indices def _state_to_replica(dataset: nc.Dataset, state_num: int, frame_num: int) -> int: - """Convert a state index to replica index at a given frame + """Convert a state index to replica index at a given Dataset frame Parameters ---------- @@ -31,11 +77,14 @@ def _state_to_replica(dataset: nc.Dataset, state_num: int, return np.where(state_distribution == state_num)[0][0] -def _replica_positions_at_frame(dataset: nc.Dataset, - replica_index: int, - frame_num: int) -> unit.Quantity: +def _replica_positions_at_frame( + dataset: nc.Dataset, + replica_index: int, + frame_num: int +) -> Optional[unit.Quantity]: """ - Helper method to extract atom positions of a state at a given frame. + Helper method to extract atom positions of a state at a given + Dataset frame. Parameters ---------- @@ -48,16 +97,24 @@ def _replica_positions_at_frame(dataset: nc.Dataset, Returns ------- - unit.Quantity - n_atoms * 3 position Quantity array + Optional[unit.Quantity] + A n_atoms * 3 position Quantity array. Returns ``None`` + if all the values are masked (i.e. no positions were stored + for that frame). """ + # If all the positions are masked (i.e. not present) + if dataset.variables['positions'][frame_num][replica_index].mask.all(): + return None + pos = dataset.variables['positions'][frame_num][replica_index].data pos_units = dataset.variables['positions'].units return pos * unit(pos_units) -def _create_new_dataset(filename: Path, n_atoms: int, - title: str) -> nc.Dataset: +def _create_new_dataset( + filename: Path, n_atoms: int, + title: str +) -> nc.Dataset: """ Helper method to create a new NetCDF dataset which follows the AMBER convention (see: https://ambermd.org/netcdf/nctraj.xhtml) @@ -77,14 +134,14 @@ def _create_new_dataset(filename: Path, n_atoms: int, AMBER Conventions compliant NetCDF dataset to store information contained in MultiState reporter generated NetCDF file. """ - ncfile = nc.Dataset(filename, 'w', format='NETCDF3_64BIT') + ncfile = nc.Dataset(filename, 'w', format='NETCDF3_64BIT_OFFSET') ncfile.Conventions = 'AMBER' ncfile.ConventionVersion = "1.0" ncfile.application = "openfe_analysis" ncfile.program = f"openfe_analysis {__version__}" ncfile.programVersion = f"{__version__}" ncfile.title = title - + # Set the dimensions ncfile.createDimension('frame', None) ncfile.createDimension('spatial', 3) @@ -92,7 +149,7 @@ def _create_new_dataset(filename: Path, n_atoms: int, ncfile.createDimension('cell_spatial', 3) ncfile.createDimension('cell_angular', 3) ncfile.createDimension('label', 5) - + # Set the variables # positions pos = ncfile.createVariable('coordinates', 'f4', ('frame', 'atom', 'spatial')) @@ -111,18 +168,20 @@ def _create_new_dataset(filename: Path, n_atoms: int, ) cell_lengths.units = 'angstrom' cell_angles = ncfile.createVariable( - 'cell_angles', 'f8', ('frame', 'cell_spatial') + 'cell_angles', 'f8', ('frame', 'cell_angular') ) cell_angles.units = 'degree' - + return ncfile -def _get_unitcell(dataset: nc.Dataset, replica_index: int, frame_num: int): +def _get_unitcell( + dataset: nc.Dataset, replica_index: int, frame_num: int +) -> Optional[Tuple[unit.Quantity]]: """ Helper method to extract a unit cell from the stored box vectors in a MultiState reporter generated NetCDF file - at a given state and frame. + at a given state and Dataset frame. Parameters ---------- @@ -135,9 +194,15 @@ def _get_unitcell(dataset: nc.Dataset, replica_index: int, frame_num: int): Returns ------- - Tuple[lx, ly, lz, alpha, beta, gamma] + Optional[Tuple[lx, ly, lz, alpha, beta, gamma]] Unit cell lengths and angles in angstroms and degrees. + If box_vectors are masked (i.e. they were not stored at this frame), + will return ``None``. """ + # Case: no box_vectors were stored at this frame + if dataset.variables['box_vectors'][frame_num][replica_index].mask.all(): + return None + vecs = dataset.variables['box_vectors'][frame_num][replica_index].data vecs_units = dataset.variables['box_vectors'].units x, y, z = (vecs * unit(vecs_units)).to('angstrom').m @@ -183,15 +248,16 @@ def trajectory_from_multistate(input_file: Path, output_file: Path, multistate = nc.Dataset(input_file, 'r') n_atoms = len(multistate.variables['positions'][0][0]) n_replicas = len(multistate.variables['positions'][0]) - n_frames = len(multistate.variables['positions']) - + frame_list = _determine_position_indices(multistate) + n_frames = len(frame_list) + # Sanity check if state_number is not None and (state_number + 1 > n_replicas): # Note this works for now, but when we have more states # than replicas (e.g. SAMS) this won't really work errmsg = "State does not exist" raise ValueError(errmsg) - + # Create output AMBER NetCDF convention file traj = _create_new_dataset( output_file, n_atoms, @@ -202,15 +268,18 @@ def trajectory_from_multistate(input_file: Path, output_file: Path, if replica_number is not None: replica_id = replica_number - # Loopy de loop + # Loopy de loop over n_frames so that the new Dataset + # is just 0 -> n_frames for frame in range(n_frames): if state_number is not None: - replica_id = _state_to_replica(multistate, state_number, frame) + replica_id = _state_to_replica( + multistate, state_number, frame_list[frame] + ) traj.variables['coordinates'][frame] = _replica_positions_at_frame( - multistate, replica_id, frame + multistate, replica_id, frame_list[frame] ).to('angstrom').m - unitcell = _get_unitcell(multistate, replica_id, frame) + unitcell = _get_unitcell(multistate, replica_id, frame_list[frame]) traj.variables['cell_lengths'][frame] = unitcell[:3] traj.variables['cell_angles'][frame] = unitcell[3:] diff --git a/src/openfe_analysis/utils/serialization.py b/src/openfe_analysis/utils/serialization.py new file mode 100644 index 0000000..b4c7d1e --- /dev/null +++ b/src/openfe_analysis/utils/serialization.py @@ -0,0 +1,49 @@ +import yaml +from openff.units import unit +import numpy as np + + +def omm_quantity_string_to_offunit(expression): + """ + Convert an OpenMM Quantity string to an OpenFF Unit. + + Parameters + ---------- + expression : str + The string expression to convert to an OpenFF Unit. + + Returns + ------- + openff.units.Quantity + An OpenFF unit Quantity. + + Notes + ----- + Inspired by `openmmtools.utils.utils.quantity_from_string`. + """ + # Special case where a quantity can be `/ unit` to represent `1 / unit` + if expression[0] == '/': + expression = f"({expression[1:]})**(-1)" + + return unit(expression) + + +class UnitedYamlLoader(yaml.CLoader): + """ + A YamlLoader that can read !Quantity tags and return + them as OpenFF Units. + + Notes + ----- + Modified from `openmmtools.storage.iodrivers._DictYamlLoader`. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.add_constructor(u'!Quantity', self.quantity_constructor) + + @staticmethod + def quantity_constructor(loader, node): + loaded_mapping = loader.construct_mapping(node) + data_unit = omm_quantity_string_to_offunit(loaded_mapping['unit']) + data_value = loaded_mapping['value'] + return data_value * data_unit diff --git a/src/tests/conftest.py b/src/tests/conftest.py deleted file mode 100644 index c5b8fed..0000000 --- a/src/tests/conftest.py +++ /dev/null @@ -1,25 +0,0 @@ -import pooch -import pytest - - -RFE_OUTPUT = pooch.create( - path=pooch.os_cache("openfe_analysis"), - base_url="doi:10.6084/m9.figshare.24101655", - registry={ - "checkpoint.nc": "5af398cb14340fddf7492114998b244424b6c3f4514b2e07e4bd411484c08464", - "db.json": "b671f9eb4daf9853f3e1645f9fd7c18150fd2a9bf17c18f23c5cf0c9fd5ca5b3", - "hybrid_system.pdb": "07203679cb14b840b36e4320484df2360f45e323faadb02d6eacac244fddd517", - "simulation.nc": "92361a0864d4359a75399470135f56642b72c605069a4c33dbc4be6f91f28b31", - "simulation_real_time_analysis.yaml": "65706002f371fafba96037f29b054fd7e050e442915205df88567f48f5e5e1cf", - } -) - - -@pytest.fixture -def simulation_nc(): - return RFE_OUTPUT.fetch("simulation.nc") - - -@pytest.fixture -def hybrid_system_pdb(): - return RFE_OUTPUT.fetch("hybrid_system.pdb") diff --git a/src/tests/test_reader.py b/src/tests/test_reader.py deleted file mode 100644 index fe3f997..0000000 --- a/src/tests/test_reader.py +++ /dev/null @@ -1,26 +0,0 @@ -import MDAnalysis as mda -from openfe_analysis import FEReader -import netCDF4 as nc -import pytest - - -def test_universe_creation(simulation_nc, hybrid_system_pdb): - u = mda.Universe(hybrid_system_pdb, simulation_nc, - format='openfe rfe', state_id=0) - - assert u - assert len(u.atoms) == 4782 - assert len(u.trajectory) == 501 - assert u.trajectory.dt == pytest.approx(1.0) - - -def test_universe_from_nc_file(simulation_nc, hybrid_system_pdb): - ds = nc.Dataset(simulation_nc) - - u = mda.Universe(hybrid_system_pdb, ds, - format='openfe rfe', state_id=0) - - assert u - assert len(u.atoms) == 4782 - -