Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies:
- click
- typing-extensions
- openmm >=8.0.0,!=8.1.0,<8.2.0
- openmmtools <0.24.1
- openmmtools >=0.24.1
- openmmforcefields
- perses>=0.10.3
- pooch
Expand Down
23 changes: 23 additions & 0 deletions news/multistate-variablewriting.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* Add support for variable position/velocity trajectory writing

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
22 changes: 22 additions & 0 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,33 @@ def _get_reporter(
time_per_iteration=simulation_settings.time_per_iteration,
)

if output_settings.positions_write_frequency is not None:
pos_interval = settings_validation.divmod_time_and_check(
numerator=output_settings.positions_write_frequency,
denominator=simulation_settings.time_per_iteration,
numerator_name="output settings' position_write_frequency",
denominator_name="simulation settings' time_per_iteration"
)
else:
pos_interval = 0

if output_settings.velocities_write_frequency is not None:
vel_interval = settings_validation.divmod_time_and_check(
numerator=output_settings.velocities_write_frequency,
denominator=simulation_settings.time_per_iteration,
numerator_name="output settings' velocity_write_frequency",
denominator_name="simulation settings' time_per_iteration"
)
else:
vel_interval = 0

reporter = multistate.MultiStateReporter(
storage=nc,
analysis_particle_indices=selection_indices,
checkpoint_interval=chk_intervals,
checkpoint_storage=chk,
position_interval=pos_interval,
velocity_interval=vel_interval,
)

# Write out the structure's PDB whilst we're here
Expand Down
23 changes: 23 additions & 0 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,11 +920,34 @@ def run(self, *, dry=False, verbose=True,

nc = shared_basepath / output_settings.output_filename
chk = output_settings.checkpoint_storage_filename

if output_settings.positions_write_frequency is not None:
pos_interval = settings_validation.divmod_time_and_check(
numerator=output_settings.positions_write_frequency,
denominator=sampler_settings.time_per_iteration,
numerator_name="output settings' position_write_frequency",
denominator_name="sampler settings' time_per_iteration"
)
else:
pos_interval = 0

if output_settings.velocities_write_frequency is not None:
vel_interval = settings_validation.divmod_time_and_check(
numerator=output_settings.velocities_write_frequency,
denominator=sampler_settings.time_per_iteration,
numerator_name="output settings' velocity_write_frequency",
denominator_name="sampler settings' time_per_iteration"
)
else:
vel_interval = 0

reporter = multistate.MultiStateReporter(
storage=nc,
analysis_particle_indices=selection_indices,
checkpoint_interval=chk_intervals,
checkpoint_storage=chk,
position_interval=pos_interval,
velocity_interval=vel_interval,
)

# b. Write out a PDB containing the subsampled hybrid state
Expand Down
29 changes: 29 additions & 0 deletions openfe/protocols/openmm_utils/omm_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,35 @@
to visualise and further manipulate the system.
Default 'hybrid_system.pdb'.
"""
positions_write_frequency: Optional[FloatQuantity['picosecond']] = 100 * unit.picosecond
"""
Frequency at which positions are written to the simulation trajectory
storage file (defined by ``output_filename``).

If ``None``, no positions will be written to the trajectory.

Unless set to ``None``, must be divisible by
``MultiStateSimulationSettings.time_per_iteration``.
"""
velocities_write_frequency: Optional[FloatQuantity['picosecond']] = None
"""
Frequency at which velocities are written to the simulation
trajectory storage file (defined by ``output_filename``).

If ``None`` (default), no velocities will be written to the trajectory.

Unless set to ``None``, must be divisible by
``MultiStateSimulationSettings.time_per_iteration``.
"""


@validator('positions_write_frequency', 'velocities_write_frequency')
def must_be_positive(cls, v):
if v is not None and v < 0:
errmsg = ("Position_write_frequency and velocities_write_frequency"

Check warning on line 479 in openfe/protocols/openmm_utils/omm_settings.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_utils/omm_settings.py#L479

Added line #L479 was not covered by tests
f" must be positive (or None), got {v}.")
raise ValueError(errmsg)

Check warning on line 481 in openfe/protocols/openmm_utils/omm_settings.py

View check run for this annotation

Codecov / codecov/patch

openfe/protocols/openmm_utils/omm_settings.py#L481

Added line #L481 was not covered by tests
return v


class SimulationSettings(SettingsBaseModel):
Expand Down
54 changes: 54 additions & 0 deletions openfe/tests/protocols/test_openmm_afe_solvation_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,3 +905,57 @@ def test_filenotfound_replica_states(self, protocolresult):

with pytest.raises(ValueError, match=errmsg):
protocolresult.get_replica_states()


@pytest.mark.parametrize('positions_write_frequency,velocities_write_frequency',
[[100 * offunit.picosecond, None],
[None, None],
[None, 100 * offunit.picosecond]])
def test_dry_run_vacuum_write_frequency(benzene_modifications,
positions_write_frequency,
velocities_write_frequency,
tmpdir):
s = openmm_afe.AbsoluteSolvationProtocol.default_settings()
s.protocol_repeats = 1
s.solvent_output_settings.output_indices = "resname UNK"
s.solvent_output_settings.positions_write_frequency = positions_write_frequency
s.solvent_output_settings.velocities_write_frequency = velocities_write_frequency
s.vacuum_output_settings.positions_write_frequency = positions_write_frequency
s.vacuum_output_settings.velocities_write_frequency = velocities_write_frequency

protocol = openmm_afe.AbsoluteSolvationProtocol(
settings=s,
)

stateA = ChemicalSystem({
'benzene': benzene_modifications['benzene'],
'solvent': SolventComponent()
})

stateB = ChemicalSystem({
'solvent': SolventComponent(),
})

# Create DAG from protocol, get the vacuum and solvent units
# and eventually dry run the first solvent unit
dag = protocol.create(
stateA=stateA,
stateB=stateB,
mapping=None,
)
prot_units = list(dag.protocol_units)

assert len(prot_units) == 2

with tmpdir.as_cwd():
for u in prot_units:
sampler = u.run(dry=True)['debug']['sampler']
reporter = sampler._reporter
if positions_write_frequency:
assert reporter.position_interval == positions_write_frequency.m
else:
assert reporter.position_interval == 0
if velocities_write_frequency:
assert reporter.velocity_interval == velocities_write_frequency.m
else:
assert reporter.velocity_interval == 0
78 changes: 78 additions & 0 deletions openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -2115,3 +2115,81 @@ def test_structural_analysis_error(tmpdir):

assert 'structural_analysis_error' in ret
assert 'structural_analysis' not in ret


@pytest.mark.parametrize('positions_write_frequency,velocities_write_frequency',
[[100 * unit.picosecond, None],
[None, None],
[None, 100 * unit.picosecond]])
def test_dry_run_vacuum_write_frequency(benzene_vacuum_system,
toluene_vacuum_system,
benzene_to_toluene_mapping,
positions_write_frequency,
velocities_write_frequency,
tmpdir,
):

vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings()
vac_settings.forcefield_settings.nonbonded_method = 'nocutoff'
vac_settings.output_settings.positions_write_frequency = positions_write_frequency
vac_settings.output_settings.velocities_write_frequency = velocities_write_frequency
vac_settings.protocol_repeats = 1

protocol = openmm_rfe.RelativeHybridTopologyProtocol(
settings=vac_settings,
)

# create DAG from protocol and take first (and only) work unit from within
dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
mapping=benzene_to_toluene_mapping,
)
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
sampler = dag_unit.run(dry=True)['debug']['sampler']
reporter = sampler._reporter
if positions_write_frequency:
assert reporter.position_interval == positions_write_frequency.m
else:
assert reporter.position_interval == 0
if velocities_write_frequency:
assert reporter.velocity_interval == velocities_write_frequency.m
else:
assert reporter.velocity_interval == 0


@pytest.mark.parametrize('positions_write_frequency,velocities_write_frequency',
[[100.1 * unit.picosecond, 100 * unit.picosecond],
[100 * unit.picosecond, 100.1 * unit.picosecond]])
def test_pos_write_frequency_not_divisible(benzene_vacuum_system,
toluene_vacuum_system,
benzene_to_toluene_mapping,
positions_write_frequency,
velocities_write_frequency,
tmpdir,
):

vac_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings()
vac_settings.forcefield_settings.nonbonded_method = 'nocutoff'
vac_settings.output_settings.positions_write_frequency = positions_write_frequency
vac_settings.output_settings.velocities_write_frequency = velocities_write_frequency
vac_settings.protocol_repeats = 1

protocol = openmm_rfe.RelativeHybridTopologyProtocol(
settings=vac_settings,
)

# create DAG from protocol and take first (and only) work unit from within
dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
mapping=benzene_to_toluene_mapping,
)
dag_unit = list(dag.protocol_units)[0]

with tmpdir.as_cwd():
errmsg = "The output settings' "
with pytest.raises(ValueError, match=errmsg):
dag_unit.run(dry=True)['debug']['sampler']
1 change: 1 addition & 0 deletions openfe/tests/protocols/test_openmm_rfe_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
s.protocol_repeats = 1
s.engine_settings.compute_platform = platform
s.output_settings.checkpoint_interval = 20 * unit.femtosecond
s.output_settings.positions_write_frequency = 20 * unit.femtosecond

Check warning on line 56 in openfe/tests/protocols/test_openmm_rfe_slow.py

View check run for this annotation

Codecov / codecov/patch

openfe/tests/protocols/test_openmm_rfe_slow.py#L56

Added line #L56 was not covered by tests

p = openmm_rfe.RelativeHybridTopologyProtocol(s)

Expand Down