diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index c602e1a22..e2b47cabb 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -403,6 +403,21 @@ def solve( """ solver_class = backends.get_solver(self.problem_kind, backend=backend, return_class=True) init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) + # if linear problem, then alpha is 0.0 by default + # if quadratic problem, then alpha is 1.0 by default + alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0) + if alpha < 0.0 or alpha > 1.0: + raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.") + if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)): + raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.") + if self.problem_kind == "quadratic": + if self.x is None or self.y is None: + raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.") + if alpha != 1.0 and self.xy is None: # means FGW case + raise ValueError( + "`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class." + ) + self._solver = solver_class(**init_kwargs) self._solution = self._solver( # type: ignore[misc] diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 09f1e99ba..6f21cc964 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -593,7 +593,7 @@ def prepare( def solve( self, - alpha: float = 1.0, + alpha: float = 0.5, epsilon: float = 1e-3, tau_a: float = 1.0, tau_b: float = 1.0, @@ -622,7 +622,7 @@ def solve( Parameters ---------- alpha - Parameter in :math:`(0, 1]` that interpolates between the :term:`quadratic term` and + Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while :math:`\alpha \to 0` corresponds to the pure :term:`linear problem`. epsilon @@ -672,6 +672,8 @@ def solve( - :attr:`solutions` - the :term:`OT` solutions for each subproblem. - :attr:`stage` - set to ``'solved'``. """ + if alpha == 1.0: + raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.") return CompoundProblem.solve( self, # type: ignore[return-value, arg-type] alpha=alpha, diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index f2c3d0314..a90845d75 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -83,7 +83,14 @@ def test_prepare_external_star_policy( @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), - [(1e-2, 0.9, -1, None), (2, 0.5, -1, "random"), (2, 0.5, -1, "rank2"), (2, 0.1, -1, None)], + [ + (1e-2, 0.9, -1, None), + (2, 0.5, -1, "random"), + (2, 1.0, -1, "rank2"), + (2, 0.1, -1, None), + (2, 1.0, -1, None), + (1.3, 1.0, -1, "random"), + ], ) @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}]) @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}]) @@ -104,7 +111,8 @@ def test_solve_balanced( kwargs["initializer"] = initializer tp = TranslationProblem(adata_src, adata_tgt) - tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr) + joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"} + tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr) tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) for key, subsol in tp.solutions.items(): diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 6a8020a4f..65e5641ee 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -120,6 +120,7 @@ def test_solve_balanced( return # TODO(@MUCDK) fix after refactoring mp = MappingProblem(adataref, adatasp) mp = mp.prepare(batch_key="batch", sc_attr=sc_attr, var_names=var_names) + alpha = alpha if mp.filtered_vars is not None else 1.0 mp = mp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) for prob_key in mp: