Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions contact_map/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _atom_slice(traj, indices):
unitcell_lengths=unitcell_lengths,
unitcell_angles=unitcell_angles)


def _residue_for_atom(topology, atom_list):
return set([topology.atom(a).residue for a in atom_list])

Expand Down Expand Up @@ -197,7 +198,7 @@ def _set_atom_idx_to_residue_idx(self):
def s_idx_to_idx(self, idx):
"""function to convert a sliced atom index back to real index"""
if self._use_atom_slice:
return(self._all_atoms[idx])
return self._all_atoms[idx]
else:
return idx

Expand Down Expand Up @@ -740,8 +741,7 @@ def __init__(self, trajectory, query=None, haystack=None, cutoff=0.45,
query, haystack, cutoff,
n_neighbors_ignored)
contacts = self._build_contact_map(trajectory)
(atom_contacts, self._residue_contacts) = contacts
self._atom_contacts = self.convert_atom_contacts(atom_contacts)
(self._atom_contacts, self._residue_contacts) = contacts

def __hash__(self):
return hash((super(ContactFrequency, self).__hash__(),
Expand Down Expand Up @@ -777,7 +777,6 @@ def _build_contact_map(self, trajectory):
residue_query_atom_idxs = self.residue_query_atom_idxs

used_trajectory = self.slice_trajectory(trajectory)

for frame_num in self.frames:
frame_contacts = self.contact_map(used_trajectory, frame_num,
residue_query_atom_idxs,
Expand All @@ -787,7 +786,7 @@ def _build_contact_map(self, trajectory):
# self._atom_contacts_count += frame_atom_contacts
atom_contacts_count.update(frame_atom_contacts)
residue_contacts_count += frame_residue_contacts

atom_contacts_count = self.convert_atom_contacts(atom_contacts_count)
return (atom_contacts_count, residue_contacts_count)

@property
Expand Down
29 changes: 28 additions & 1 deletion contact_map/tests/test_dask_runner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

# pylint: disable=wildcard-import, missing-docstring, protected-access
# pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use
# pylint: disable=wrong-import-order, unused-wildcard-import

from .utils import *
from contact_map.dask_runner import *
from contact_map import ContactFrequency

def dask_setup_test_cluster(distributed, n_workers=4, n_attempts=3):
"""Set up a test cluster using dask.distributed. Try up to n_attempts
Expand Down Expand Up @@ -40,3 +40,30 @@ def test_dask_integration(self):
n_neighbors_ignored=0)
client.close()
assert dask_freq.n_frames == 5

def test_dask_atom_slice(self):
# This is an integration test to check that dask works with atom_slice
dask = pytest.importorskip('dask') # pylint: disable=W0612
distributed = pytest.importorskip('dask.distributed')
# Explicitly set only 4 workers on Travis instead of 31
# Fix copied from https://github.com/spencerahill/aospy/pull/220/files
cluster = dask_setup_test_cluster(distributed, n_workers=4)
client = distributed.Client(cluster)
filename = find_testfile("trajectory.pdb")

dask_freq0 = DaskContactFrequency(client, filename, query=[3, 4],
haystack=[6, 7], cutoff=0.075,
n_neighbors_ignored=0)
client.close()
assert dask_freq0.n_frames == 5
client = distributed.Client(cluster)
# Set the slicing of contact frequency (used in the frqeuency task)
# to False
ContactFrequency._class_use_atom_slice = False
dask_freq1 = DaskContactFrequency(client, filename, query=[3, 4],
haystack=[6, 7], cutoff=0.075,
n_neighbors_ignored=0)
client.close()
assert dask_freq0._use_atom_slice is True
assert dask_freq1._use_atom_slice is False
assert dask_freq0 == dask_freq1