From ef8a2785dd5932e409d33efaa4c3c19f83f532eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 25 Mar 2024 20:30:39 +0000 Subject: [PATCH 1/3] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/asottile/pyupgrade: v3.15.1 → v3.15.2](https://github.com/asottile/pyupgrade/compare/v3.15.1...v3.15.2) - [github.com/astral-sh/ruff-pre-commit: v0.3.3 → v0.3.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.3.3...v0.3.4) --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0da702a1e..6635a2ff0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,7 +42,7 @@ repos: - id: check-yaml - id: check-toml - repo: https://github.com/asottile/pyupgrade - rev: v3.15.1 + rev: v3.15.2 hooks: - id: pyupgrade args: [--py3-plus, --py38-plus, --keep-runtime-typing] @@ -63,7 +63,7 @@ repos: - id: doc8 - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.3 + rev: v0.3.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 25c6a0e5c30e7d1748c49303e2ba8309534f3d46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:18:03 +0200 Subject: [PATCH 2/3] [pre-commit.ci] pre-commit autoupdate (#682) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.3.4 → v0.3.5](https://github.com/astral-sh/ruff-pre-commit/compare/v0.3.4...v0.3.5) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6635a2ff0..f0235f776 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,7 +63,7 @@ repos: - id: doc8 - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.4 + rev: v0.3.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From 3151da7b324672fab6757b8a84bcd1d0bcb9f027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Tue, 2 Apr 2024 10:37:05 +0300 Subject: [PATCH 3/3] Don't assert for index/column names when comparing series (e.g. in `set_x`, `set_y`, and `set_x`) (#669) * don't assert for index/column names * create and use util func * use the utils everywhere --- src/moscot/base/problems/_utils.py | 11 ++++++++++ src/moscot/base/problems/problem.py | 33 +++++++++++++---------------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 7219b58e3..40bf0a99a 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -158,6 +158,17 @@ def _validate_args_cell_transition( raise TypeError(f"Expected argument to be either `str` or `dict`, found `{type(arg)}`.") +def _assert_series_match(a: pd.Series, b: pd.Series) -> None: + """Assert that two series are equal ignoring the names.""" + pd.testing.assert_series_equal(a, b, check_names=False) + + +def _assert_columns_and_index_match(a: pd.Series, b: pd.DataFrame) -> None: + """Assert that a series and a dataframe's index and columns are matching.""" + _assert_series_match(a, b.index.to_series()) + _assert_series_match(a, b.columns.to_series()) + + def _get_cell_indices( adata: AnnData, key: Optional[str] = None, diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 65f3e9692..c602e1a22 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -31,6 +31,8 @@ from moscot.base.output import BaseSolverOutput, MatrixSolverOutput from moscot.base.problems._utils import ( TimeScalesHeatKernel, + _assert_columns_and_index_match, + _assert_series_match, require_solution, wrap_prepare, wrap_solve, @@ -530,8 +532,8 @@ def set_solution( raise ValueError(f"`{self}` already contains a solution, use `overwrite=True` to overwrite it.") if isinstance(solution, pd.DataFrame): - pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), solution.index.to_series()) - pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), solution.columns.to_series()) + _assert_series_match(self.adata_src.obs_names.to_series(), solution.index.to_series()) + _assert_series_match(self.adata_tgt.obs_names.to_series(), solution.columns.to_series()) solution = solution.to_numpy() if not isinstance(solution, BaseSolverOutput): solution = MatrixSolverOutput(solution, **kwargs) @@ -729,13 +731,12 @@ def set_graph_xy( """ expected_series = pd.concat([self.adata_src.obs_names.to_series(), self.adata_tgt.obs_names.to_series()]) if isinstance(data, pd.DataFrame): - pd.testing.assert_series_equal(expected_series, data.index.to_series()) - pd.testing.assert_series_equal(expected_series, data.columns.to_series()) + _assert_columns_and_index_match(expected_series, data) data_src = data.to_numpy() elif isinstance(data, tuple): data_src, index_src, index_tgt = data - pd.testing.assert_series_equal(expected_series, index_src) - pd.testing.assert_series_equal(expected_series, index_tgt) + _assert_series_match(expected_series, index_src) + _assert_series_match(expected_series, index_tgt) else: raise ValueError( "Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`, `pd.Series`), " @@ -780,12 +781,11 @@ def set_graph_x( """ expected_series = self.adata_src.obs_names.to_series() if isinstance(data, pd.DataFrame): - pd.testing.assert_series_equal(expected_series, data.index.to_series()) - pd.testing.assert_series_equal(expected_series, data.columns.to_series()) + _assert_columns_and_index_match(expected_series, data) data_src = data.to_numpy() elif isinstance(data, tuple): data_src, index_src = data - pd.testing.assert_series_equal(expected_series, index_src) + _assert_series_match(expected_series, index_src) else: raise ValueError( "Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`), " @@ -830,12 +830,11 @@ def set_graph_y( """ expected_series = self.adata_tgt.obs_names.to_series() if isinstance(data, pd.DataFrame): - pd.testing.assert_series_equal(expected_series, data.index.to_series()) - pd.testing.assert_series_equal(expected_series, data.columns.to_series()) + _assert_columns_and_index_match(expected_series, data) data_src = data.to_numpy() elif isinstance(data, tuple): data_src, index_src = data - pd.testing.assert_series_equal(expected_series, index_src) + _assert_series_match(expected_series, index_src) else: raise ValueError( "Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`), " @@ -869,8 +868,8 @@ def set_xy( - :attr:`xy` - the :term:`linear term`. - :attr:`stage` - set to ``'prepared'``. """ - pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.index.to_series()) - pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.columns.to_series()) + _assert_series_match(self.adata_src.obs_names.to_series(), data.index.to_series()) + _assert_series_match(self.adata_tgt.obs_names.to_series(), data.columns.to_series()) self._xy = TaggedArray(data_src=data.to_numpy(), data_tgt=None, tag=Tag(tag), cost="cost") self._stage = "prepared" @@ -893,8 +892,7 @@ def set_x(self, data: pd.DataFrame, tag: Literal["cost_matrix", "kernel"]) -> No - :attr:`x` - the source :term:`quadratic term`. - :attr:`stage` - set to ``'prepared'``. """ - pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.index.to_series()) - pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.columns.to_series()) + _assert_columns_and_index_match(self.adata_src.obs_names.to_series(), data) if self.problem_kind == "linear": logger.info(f"Changing the problem type from {self.problem_kind!r} to 'quadratic (fused)'.") @@ -920,8 +918,7 @@ def set_y(self, data: pd.DataFrame, tag: Literal["cost_matrix", "kernel"]) -> No - :attr:`y` - the target :term:`quadratic term`. - :attr:`stage` - set to ``'prepared'``. """ - pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.index.to_series()) - pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.columns.to_series()) + _assert_columns_and_index_match(self.adata_tgt.obs_names.to_series(), data) if self.problem_kind == "linear": logger.info(f"Changing the problem type from {self.problem_kind!r} to 'quadratic (fused)'.")