diff --git a/contact_map/contact_map.py b/contact_map/contact_map.py index f765e0c..7c145f1 100644 --- a/contact_map/contact_map.py +++ b/contact_map/contact_map.py @@ -70,6 +70,7 @@ def _atom_slice(traj, indices): unitcell_lengths=unitcell_lengths, unitcell_angles=unitcell_angles) + def _residue_for_atom(topology, atom_list): return set([topology.atom(a).residue for a in atom_list]) @@ -197,7 +198,7 @@ def _set_atom_idx_to_residue_idx(self): def s_idx_to_idx(self, idx): """function to convert a sliced atom index back to real index""" if self._use_atom_slice: - return(self._all_atoms[idx]) + return self._all_atoms[idx] else: return idx @@ -740,8 +741,7 @@ def __init__(self, trajectory, query=None, haystack=None, cutoff=0.45, query, haystack, cutoff, n_neighbors_ignored) contacts = self._build_contact_map(trajectory) - (atom_contacts, self._residue_contacts) = contacts - self._atom_contacts = self.convert_atom_contacts(atom_contacts) + (self._atom_contacts, self._residue_contacts) = contacts def __hash__(self): return hash((super(ContactFrequency, self).__hash__(), @@ -777,7 +777,6 @@ def _build_contact_map(self, trajectory): residue_query_atom_idxs = self.residue_query_atom_idxs used_trajectory = self.slice_trajectory(trajectory) - for frame_num in self.frames: frame_contacts = self.contact_map(used_trajectory, frame_num, residue_query_atom_idxs, @@ -787,7 +786,7 @@ def _build_contact_map(self, trajectory): # self._atom_contacts_count += frame_atom_contacts atom_contacts_count.update(frame_atom_contacts) residue_contacts_count += frame_residue_contacts - + atom_contacts_count = self.convert_atom_contacts(atom_contacts_count) return (atom_contacts_count, residue_contacts_count) @property diff --git a/contact_map/tests/test_dask_runner.py b/contact_map/tests/test_dask_runner.py index f630c13..3776812 100644 --- a/contact_map/tests/test_dask_runner.py +++ b/contact_map/tests/test_dask_runner.py @@ -1,10 +1,10 @@ - # pylint: disable=wildcard-import, missing-docstring, protected-access # pylint: disable=attribute-defined-outside-init, invalid-name, no-self-use # pylint: disable=wrong-import-order, unused-wildcard-import from .utils import * from contact_map.dask_runner import * +from contact_map import ContactFrequency def dask_setup_test_cluster(distributed, n_workers=4, n_attempts=3): """Set up a test cluster using dask.distributed. Try up to n_attempts @@ -40,3 +40,30 @@ def test_dask_integration(self): n_neighbors_ignored=0) client.close() assert dask_freq.n_frames == 5 + + def test_dask_atom_slice(self): + # This is an integration test to check that dask works with atom_slice + dask = pytest.importorskip('dask') # pylint: disable=W0612 + distributed = pytest.importorskip('dask.distributed') + # Explicitly set only 4 workers on Travis instead of 31 + # Fix copied from https://github.com/spencerahill/aospy/pull/220/files + cluster = dask_setup_test_cluster(distributed, n_workers=4) + client = distributed.Client(cluster) + filename = find_testfile("trajectory.pdb") + + dask_freq0 = DaskContactFrequency(client, filename, query=[3, 4], + haystack=[6, 7], cutoff=0.075, + n_neighbors_ignored=0) + client.close() + assert dask_freq0.n_frames == 5 + client = distributed.Client(cluster) + # Set the slicing of contact frequency (used in the frqeuency task) + # to False + ContactFrequency._class_use_atom_slice = False + dask_freq1 = DaskContactFrequency(client, filename, query=[3, 4], + haystack=[6, 7], cutoff=0.075, + n_neighbors_ignored=0) + client.close() + assert dask_freq0._use_atom_slice is True + assert dask_freq1._use_atom_slice is False + assert dask_freq0 == dask_freq1