From 0ef91775a72b8ae477330415ddf57c2b7bfd6514 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Tue, 16 Apr 2024 08:18:15 +0200 Subject: [PATCH] Error/Warning for incompatible terms with solvers (#681) * init * add warning * change implementation * Revert "init" This reverts commit 7457fbfdac15d16e25ef2827b384d8714e4f3183. * check on OTProblem level * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change defaults according to linear vs quadratic * fix tests and use public getters * revert notebook metadata * add tests where alpha is used --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/moscot/base/problems/problem.py | 15 +++++++++++++++ src/moscot/problems/generic/_generic.py | 6 ++++-- .../cross_modality/test_translation_problem.py | 12 ++++++++++-- tests/problems/space/test_mapping_problem.py | 1 + 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index c602e1a22..e2b47cabb 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -403,6 +403,21 @@ def solve( """ solver_class = backends.get_solver(self.problem_kind, backend=backend, return_class=True) init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) + # if linear problem, then alpha is 0.0 by default + # if quadratic problem, then alpha is 1.0 by default + alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0) + if alpha < 0.0 or alpha > 1.0: + raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.") + if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)): + raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.") + if self.problem_kind == "quadratic": + if self.x is None or self.y is None: + raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.") + if alpha != 1.0 and self.xy is None: # means FGW case + raise ValueError( + "`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class." + ) + self._solver = solver_class(**init_kwargs) self._solution = self._solver( # type: ignore[misc] diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 09f1e99ba..6f21cc964 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -593,7 +593,7 @@ def prepare( def solve( self, - alpha: float = 1.0, + alpha: float = 0.5, epsilon: float = 1e-3, tau_a: float = 1.0, tau_b: float = 1.0, @@ -622,7 +622,7 @@ def solve( Parameters ---------- alpha - Parameter in :math:`(0, 1]` that interpolates between the :term:`quadratic term` and + Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while :math:`\alpha \to 0` corresponds to the pure :term:`linear problem`. epsilon @@ -672,6 +672,8 @@ def solve( - :attr:`solutions` - the :term:`OT` solutions for each subproblem. - :attr:`stage` - set to ``'solved'``. """ + if alpha == 1.0: + raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.") return CompoundProblem.solve( self, # type: ignore[return-value, arg-type] alpha=alpha, diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index f2c3d0314..a90845d75 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -83,7 +83,14 @@ def test_prepare_external_star_policy( @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), - [(1e-2, 0.9, -1, None), (2, 0.5, -1, "random"), (2, 0.5, -1, "rank2"), (2, 0.1, -1, None)], + [ + (1e-2, 0.9, -1, None), + (2, 0.5, -1, "random"), + (2, 1.0, -1, "rank2"), + (2, 0.1, -1, None), + (2, 1.0, -1, None), + (1.3, 1.0, -1, "random"), + ], ) @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}]) @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}]) @@ -104,7 +111,8 @@ def test_solve_balanced( kwargs["initializer"] = initializer tp = TranslationProblem(adata_src, adata_tgt) - tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr) + joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"} + tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr) tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) for key, subsol in tp.solutions.items(): diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index c814f80e5..82a3d090e 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -119,6 +119,7 @@ def test_solve_balanced( return # TODO(@MUCDK) fix after refactoring mp = MappingProblem(adataref, adatasp) mp = mp.prepare(batch_key="batch", sc_attr=sc_attr, var_names=var_names) + alpha = alpha if mp.filtered_vars is not None else 1.0 mp = mp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) for prob_key in mp: