diff --git a/ci/conda-recipe/meta.yaml b/ci/conda-recipe/meta.yaml index 3f74281..bdc6572 100644 --- a/ci/conda-recipe/meta.yaml +++ b/ci/conda-recipe/meta.yaml @@ -20,7 +20,7 @@ requirements: - mdtraj - numpy - scipy - - pandas<1.0 + - pandas test: requires: diff --git a/contact_map/contact_count.py b/contact_map/contact_count.py index c9beb2e..4e57732 100644 --- a/contact_map/contact_count.py +++ b/contact_map/contact_count.py @@ -13,6 +13,8 @@ else: HAS_MATPLOTLIB = True +# pandas 0.25 not available on py27; can drop this when we drop py27 +_PD_VERSION = tuple(int(x) for x in pd.__version__.split('.')[:2]) def _colorbar(with_colorbar, cmap_f, norm, min_val): if with_colorbar is False: @@ -25,6 +27,34 @@ def _colorbar(with_colorbar, cmap_f, norm, min_val): return cb +# TODO: remove following: this is a monkeypatch for a bug in pandas +# see: https://github.com/pandas-dev/pandas/issues/29814 +from pandas._libs.sparse import BlockIndex, IntIndex, SparseIndex +def _patch_from_spmatrix(cls, data): + length, ncol = data.shape + + if ncol != 1: + raise ValueError("'data' must have a single column, not '{}'".format(ncol)) + + # our sparse index classes require that the positions be strictly + # increasing. So we need to sort loc, and arr accordingly. + arr = data.data + #idx, _ = data.nonzero() + idx = data.indices + loc = np.argsort(idx) + arr = arr.take(loc) + idx.sort() + + zero = np.array(0, dtype=arr.dtype).item() + dtype = pd.SparseDtype(arr.dtype, zero) + index = IntIndex(length, idx) + + return cls._simple_new(arr, index, dtype) + +if _PD_VERSION >= (0, 25): + pd.core.arrays.SparseArray.from_spmatrix = classmethod(_patch_from_spmatrix) +# TODO: this is the end of what to remove when pandas is fixed + class ContactCount(object): """Return object when dealing with contacts (residue or atom). @@ -95,10 +125,22 @@ def df(self): Rows/columns correspond to indices and the values correspond to the count """ - mtx = self.sparse_matrix.tocoo() + mtx = self.sparse_matrix index = list(range(self.n_x)) columns = list(range(self.n_y)) - return pd.SparseDataFrame(mtx, index=index, columns=columns) + + if _PD_VERSION < (0, 25): # py27 only + mtx = mtx.tocoo() + return pd.SparseDataFrame(mtx, index=index, columns=columns) + + df = pd.DataFrame.sparse.from_spmatrix(mtx, index=index, + columns=columns) + # note: I think we can always use float here for dtype; but in + # principle maybe we need to inspect and get the internal type? + # Problem is, pandas technically stores a different dtype for each + # column. + df = df.astype(pd.SparseDtype("float", np.nan)) + return df def _check_number_of_pixels(self, figure): """ diff --git a/contact_map/tests/test_contact_count.py b/contact_map/tests/test_contact_count.py index ef91eb3..77f3a5d 100644 --- a/contact_map/tests/test_contact_count.py +++ b/contact_map/tests/test_contact_count.py @@ -89,12 +89,23 @@ def test_sparse_matrix(self): def test_df(self): atom_df = self.map.atom_contacts.df residue_df = self.map.residue_contacts.df - assert isinstance(atom_df, pd.SparseDataFrame) - assert isinstance(residue_df, pd.SparseDataFrame) - assert_array_equal(atom_df.to_dense().values, + # this block is for old pandas on py27 + pd_version = tuple(int(x) for x in pd.__version__.split('.')[:2]) + if pd_version < (0, 25): + assert isinstance(atom_df, pd.SparseDataFrame) + assert isinstance(residue_df, pd.SparseDataFrame) + assert_array_equal(atom_df.to_dense().values, + zero_to_nan(self.atom_matrix)) + assert_array_equal(residue_df.to_dense().values, + zero_to_nan(self.residue_matrix)) + return + + assert isinstance(atom_df, pd.DataFrame) + assert isinstance(residue_df, pd.DataFrame) + assert_array_equal(atom_df.sparse.to_dense().values, zero_to_nan(self.atom_matrix)) - assert_array_equal(residue_df.to_dense().values, + assert_array_equal(residue_df.sparse.to_dense().values, zero_to_nan(self.residue_matrix)) @pytest.mark.parametrize("obj_type", ['atom', 'res']) diff --git a/requirements.txt b/requirements.txt index 2bc28c8..6cd0d95 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ future numpy scipy -pandas<1.0 +pandas mdtraj diff --git a/setup.cfg b/setup.cfg index 199836f..63f45c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ install_requires = numpy mdtraj scipy - pandas<1.0 + pandas packages = find: [bdist_wheel]