diff --git a/environment.yml b/environment.yml index 344c3ccaa..c166af527 100644 --- a/environment.yml +++ b/environment.yml @@ -27,7 +27,7 @@ dependencies: - click - typing-extensions - openmm >=8.0.0,!=8.1.0,<8.2.0 - - openmmtools <0.24.1 + - openmmtools >=0.24.1 - openmmforcefields - perses>=0.10.3 - pooch diff --git a/news/multistate-variablewriting.rst b/news/multistate-variablewriting.rst new file mode 100644 index 000000000..deab65c58 --- /dev/null +++ b/news/multistate-variablewriting.rst @@ -0,0 +1,23 @@ +**Added:** + +* Add support for variable position/velocity trajectory writing + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 633ec884a..88caf4242 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -695,11 +695,33 @@ def _get_reporter( time_per_iteration=simulation_settings.time_per_iteration, ) + if output_settings.positions_write_frequency is not None: + pos_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' position_write_frequency", + denominator_name="simulation settings' time_per_iteration" + ) + else: + pos_interval = 0 + + if output_settings.velocities_write_frequency is not None: + vel_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' velocity_write_frequency", + denominator_name="simulation settings' time_per_iteration" + ) + else: + vel_interval = 0 + reporter = multistate.MultiStateReporter( storage=nc, analysis_particle_indices=selection_indices, checkpoint_interval=chk_intervals, checkpoint_storage=chk, + position_interval=pos_interval, + velocity_interval=vel_interval, ) # Write out the structure's PDB whilst we're here diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 680b0b229..d2f95f2c3 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -920,11 +920,34 @@ def run(self, *, dry=False, verbose=True, nc = shared_basepath / output_settings.output_filename chk = output_settings.checkpoint_storage_filename + + if output_settings.positions_write_frequency is not None: + pos_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=sampler_settings.time_per_iteration, + numerator_name="output settings' position_write_frequency", + denominator_name="sampler settings' time_per_iteration" + ) + else: + pos_interval = 0 + + if output_settings.velocities_write_frequency is not None: + vel_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=sampler_settings.time_per_iteration, + numerator_name="output settings' velocity_write_frequency", + denominator_name="sampler settings' time_per_iteration" + ) + else: + vel_interval = 0 + reporter = multistate.MultiStateReporter( storage=nc, analysis_particle_indices=selection_indices, checkpoint_interval=chk_intervals, checkpoint_storage=chk, + position_interval=pos_interval, + velocity_interval=vel_interval, ) # b. Write out a PDB containing the subsampled hybrid state diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/openfe/protocols/openmm_utils/omm_settings.py index f194769e7..59e719f34 100644 --- a/openfe/protocols/openmm_utils/omm_settings.py +++ b/openfe/protocols/openmm_utils/omm_settings.py @@ -451,6 +451,35 @@ class Config: to visualise and further manipulate the system. Default 'hybrid_system.pdb'. """ + positions_write_frequency: Optional[FloatQuantity['picosecond']] = 100 * unit.picosecond + """ + Frequency at which positions are written to the simulation trajectory + storage file (defined by ``output_filename``). + + If ``None``, no positions will be written to the trajectory. + + Unless set to ``None``, must be divisible by + ``MultiStateSimulationSettings.time_per_iteration``. + """ + velocities_write_frequency: Optional[FloatQuantity['picosecond']] = None + """ + Frequency at which velocities are written to the simulation + trajectory storage file (defined by ``output_filename``). + + If ``None`` (default), no velocities will be written to the trajectory. + + Unless set to ``None``, must be divisible by + ``MultiStateSimulationSettings.time_per_iteration``. + """ + + + @validator('positions_write_frequency', 'velocities_write_frequency') + def must_be_positive(cls, v): + if v is not None and v < 0: + errmsg = ("Position_write_frequency and velocities_write_frequency" + f" must be positive (or None), got {v}.") + raise ValueError(errmsg) + return v class SimulationSettings(SettingsBaseModel): diff --git a/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py index 2e69c6464..30ce93aba 100644 --- a/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py +++ b/openfe/tests/protocols/test_openmm_afe_solvation_protocol.py @@ -905,3 +905,57 @@ def test_filenotfound_replica_states(self, protocolresult): with pytest.raises(ValueError, match=errmsg): protocolresult.get_replica_states() + + +@pytest.mark.parametrize('positions_write_frequency,velocities_write_frequency', + [[100 * offunit.picosecond, None], + [None, None], + [None, 100 * offunit.picosecond]]) +def test_dry_run_vacuum_write_frequency(benzene_modifications, + positions_write_frequency, + velocities_write_frequency, + tmpdir): + s = openmm_afe.AbsoluteSolvationProtocol.default_settings() + s.protocol_repeats = 1 + s.solvent_output_settings.output_indices = "resname UNK" + s.solvent_output_settings.positions_write_frequency = positions_write_frequency + s.solvent_output_settings.velocities_write_frequency = velocities_write_frequency + s.vacuum_output_settings.positions_write_frequency = positions_write_frequency + s.vacuum_output_settings.velocities_write_frequency = velocities_write_frequency + + protocol = openmm_afe.AbsoluteSolvationProtocol( + settings=s, + ) + + stateA = ChemicalSystem({ + 'benzene': benzene_modifications['benzene'], + 'solvent': SolventComponent() + }) + + stateB = ChemicalSystem({ + 'solvent': SolventComponent(), + }) + + # Create DAG from protocol, get the vacuum and solvent units + # and eventually dry run the first solvent unit + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + prot_units = list(dag.protocol_units) + + assert len(prot_units) == 2 + + with tmpdir.as_cwd(): + for u in prot_units: + sampler = u.run(dry=True)['debug']['sampler'] + reporter = sampler._reporter + if positions_write_frequency: + assert reporter.position_interval == positions_write_frequency.m + else: + assert reporter.position_interval == 0 + if velocities_write_frequency: + assert reporter.velocity_interval == velocities_write_frequency.m + else: + assert reporter.velocity_interval == 0 diff --git a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py index 9b3183585..513962355 100644 --- a/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py +++ b/openfe/tests/protocols/test_openmm_equil_rfe_protocols.py @@ -2115,3 +2115,81 @@ def test_structural_analysis_error(tmpdir): assert 'structural_analysis_error' in ret assert 'structural_analysis' not in ret + + +@pytest.mark.parametrize('positions_write_frequency,velocities_write_frequency', + [[100 * unit.picosecond, None], + [None, None], + [None, 100 * unit.picosecond]]) +def test_dry_run_vacuum_write_frequency(benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + positions_write_frequency, + velocities_write_frequency, + tmpdir, + ): + + vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.output_settings.positions_write_frequency = positions_write_frequency + vac_settings.output_settings.velocities_write_frequency = velocities_write_frequency + vac_settings.protocol_repeats = 1 + + protocol = openmm_rfe.RelativeHybridTopologyProtocol( + settings=vac_settings, + ) + + # create DAG from protocol and take first (and only) work unit from within + dag = protocol.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + ) + dag_unit = list(dag.protocol_units)[0] + + with tmpdir.as_cwd(): + sampler = dag_unit.run(dry=True)['debug']['sampler'] + reporter = sampler._reporter + if positions_write_frequency: + assert reporter.position_interval == positions_write_frequency.m + else: + assert reporter.position_interval == 0 + if velocities_write_frequency: + assert reporter.velocity_interval == velocities_write_frequency.m + else: + assert reporter.velocity_interval == 0 + + +@pytest.mark.parametrize('positions_write_frequency,velocities_write_frequency', + [[100.1 * unit.picosecond, 100 * unit.picosecond], + [100 * unit.picosecond, 100.1 * unit.picosecond]]) +def test_pos_write_frequency_not_divisible(benzene_vacuum_system, + toluene_vacuum_system, + benzene_to_toluene_mapping, + positions_write_frequency, + velocities_write_frequency, + tmpdir, + ): + + vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + vac_settings.forcefield_settings.nonbonded_method = 'nocutoff' + vac_settings.output_settings.positions_write_frequency = positions_write_frequency + vac_settings.output_settings.velocities_write_frequency = velocities_write_frequency + vac_settings.protocol_repeats = 1 + + protocol = openmm_rfe.RelativeHybridTopologyProtocol( + settings=vac_settings, + ) + + # create DAG from protocol and take first (and only) work unit from within + dag = protocol.create( + stateA=benzene_vacuum_system, + stateB=toluene_vacuum_system, + mapping=benzene_to_toluene_mapping, + ) + dag_unit = list(dag.protocol_units)[0] + + with tmpdir.as_cwd(): + errmsg = "The output settings' " + with pytest.raises(ValueError, match=errmsg): + dag_unit.run(dry=True)['debug']['sampler'] diff --git a/openfe/tests/protocols/test_openmm_rfe_slow.py b/openfe/tests/protocols/test_openmm_rfe_slow.py index 7e3d4b55f..f0fe5a0a7 100644 --- a/openfe/tests/protocols/test_openmm_rfe_slow.py +++ b/openfe/tests/protocols/test_openmm_rfe_slow.py @@ -53,6 +53,7 @@ def test_openmm_run_engine(benzene_vacuum_system, platform, s.protocol_repeats = 1 s.engine_settings.compute_platform = platform s.output_settings.checkpoint_interval = 20 * unit.femtosecond + s.output_settings.positions_write_frequency = 20 * unit.femtosecond p = openmm_rfe.RelativeHybridTopologyProtocol(s)