Skip to content

Commit

Permalink
Merge branch 'main' into add/sparse-geodesic
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen authored Apr 3, 2024
2 parents 5b83621 + 3151da7 commit d85fa68
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 20 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ repos:
- id: check-yaml
- id: check-toml
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.1
rev: v3.15.2
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
Expand All @@ -63,7 +63,7 @@ repos:
- id: doc8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.3
rev: v0.3.5
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
11 changes: 11 additions & 0 deletions src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ def _validate_args_cell_transition(
raise TypeError(f"Expected argument to be either `str` or `dict`, found `{type(arg)}`.")


def _assert_series_match(a: pd.Series, b: pd.Series) -> None:
"""Assert that two series are equal ignoring the names."""
pd.testing.assert_series_equal(a, b, check_names=False)


def _assert_columns_and_index_match(a: pd.Series, b: pd.DataFrame) -> None:
"""Assert that a series and a dataframe's index and columns are matching."""
_assert_series_match(a, b.index.to_series())
_assert_series_match(a, b.columns.to_series())


def _get_cell_indices(
adata: AnnData,
key: Optional[str] = None,
Expand Down
33 changes: 15 additions & 18 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from moscot.base.output import BaseSolverOutput, MatrixSolverOutput
from moscot.base.problems._utils import (
TimeScalesHeatKernel,
_assert_columns_and_index_match,
_assert_series_match,
require_solution,
wrap_prepare,
wrap_solve,
Expand Down Expand Up @@ -530,8 +532,8 @@ def set_solution(
raise ValueError(f"`{self}` already contains a solution, use `overwrite=True` to overwrite it.")

if isinstance(solution, pd.DataFrame):
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), solution.index.to_series())
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), solution.columns.to_series())
_assert_series_match(self.adata_src.obs_names.to_series(), solution.index.to_series())
_assert_series_match(self.adata_tgt.obs_names.to_series(), solution.columns.to_series())
solution = solution.to_numpy()
if not isinstance(solution, BaseSolverOutput):
solution = MatrixSolverOutput(solution, **kwargs)
Expand Down Expand Up @@ -729,13 +731,12 @@ def set_graph_xy(
"""
expected_series = pd.concat([self.adata_src.obs_names.to_series(), self.adata_tgt.obs_names.to_series()])
if isinstance(data, pd.DataFrame):
pd.testing.assert_series_equal(expected_series, data.index.to_series())
pd.testing.assert_series_equal(expected_series, data.columns.to_series())
_assert_columns_and_index_match(expected_series, data)
data_src = data.to_numpy()
elif isinstance(data, tuple):
data_src, index_src, index_tgt = data
pd.testing.assert_series_equal(expected_series, index_src)
pd.testing.assert_series_equal(expected_series, index_tgt)
_assert_series_match(expected_series, index_src)
_assert_series_match(expected_series, index_tgt)
else:
raise ValueError(
"Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`, `pd.Series`), "
Expand Down Expand Up @@ -780,12 +781,11 @@ def set_graph_x(
"""
expected_series = self.adata_src.obs_names.to_series()
if isinstance(data, pd.DataFrame):
pd.testing.assert_series_equal(expected_series, data.index.to_series())
pd.testing.assert_series_equal(expected_series, data.columns.to_series())
_assert_columns_and_index_match(expected_series, data)
data_src = data.to_numpy()
elif isinstance(data, tuple):
data_src, index_src = data
pd.testing.assert_series_equal(expected_series, index_src)
_assert_series_match(expected_series, index_src)
else:
raise ValueError(
"Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`), "
Expand Down Expand Up @@ -830,12 +830,11 @@ def set_graph_y(
"""
expected_series = self.adata_tgt.obs_names.to_series()
if isinstance(data, pd.DataFrame):
pd.testing.assert_series_equal(expected_series, data.index.to_series())
pd.testing.assert_series_equal(expected_series, data.columns.to_series())
_assert_columns_and_index_match(expected_series, data)
data_src = data.to_numpy()
elif isinstance(data, tuple):
data_src, index_src = data
pd.testing.assert_series_equal(expected_series, index_src)
_assert_series_match(expected_series, index_src)
else:
raise ValueError(
"Expected data to be a `pd.DataFrame` or a tuple of (`sp.csr_matrix`, `pd.Series`), "
Expand Down Expand Up @@ -869,8 +868,8 @@ def set_xy(
- :attr:`xy` - the :term:`linear term`.
- :attr:`stage` - set to ``'prepared'``.
"""
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.index.to_series())
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.columns.to_series())
_assert_series_match(self.adata_src.obs_names.to_series(), data.index.to_series())
_assert_series_match(self.adata_tgt.obs_names.to_series(), data.columns.to_series())

self._xy = TaggedArray(data_src=data.to_numpy(), data_tgt=None, tag=Tag(tag), cost="cost")
self._stage = "prepared"
Expand All @@ -893,8 +892,7 @@ def set_x(self, data: pd.DataFrame, tag: Literal["cost_matrix", "kernel"]) -> No
- :attr:`x` - the source :term:`quadratic term`.
- :attr:`stage` - set to ``'prepared'``.
"""
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.index.to_series())
pd.testing.assert_series_equal(self.adata_src.obs_names.to_series(), data.columns.to_series())
_assert_columns_and_index_match(self.adata_src.obs_names.to_series(), data)

if self.problem_kind == "linear":
logger.info(f"Changing the problem type from {self.problem_kind!r} to 'quadratic (fused)'.")
Expand All @@ -920,8 +918,7 @@ def set_y(self, data: pd.DataFrame, tag: Literal["cost_matrix", "kernel"]) -> No
- :attr:`y` - the target :term:`quadratic term`.
- :attr:`stage` - set to ``'prepared'``.
"""
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.index.to_series())
pd.testing.assert_series_equal(self.adata_tgt.obs_names.to_series(), data.columns.to_series())
_assert_columns_and_index_match(self.adata_tgt.obs_names.to_series(), data)

if self.problem_kind == "linear":
logger.info(f"Changing the problem type from {self.problem_kind!r} to 'quadratic (fused)'.")
Expand Down

0 comments on commit d85fa68

Please sign in to comment.