Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error/Warning for incompatible terms with solvers #681

Merged
merged 11 commits into from
Apr 16, 2024
15 changes: 15 additions & 0 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,21 @@ def solve(
"""
solver_class = backends.get_solver(self.problem_kind, backend=backend, return_class=True)
init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
# if linear problem, then alpha is 0.0 by default
# if quadratic problem, then alpha is 1.0 by default
alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0)
if alpha < 0.0 or alpha > 1.0:
raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.")
if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)):
raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.")
if self.problem_kind == "quadratic":
if self.x is None or self.y is None:
raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.")
if alpha != 1.0 and self.xy is None: # means FGW case
raise ValueError(
"`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class."
)

self._solver = solver_class(**init_kwargs)

self._solution = self._solver( # type: ignore[misc]
Expand Down
6 changes: 4 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def prepare(

def solve(
self,
alpha: float = 1.0,
alpha: float = 0.5,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
Expand Down Expand Up @@ -622,7 +622,7 @@ def solve(
Parameters
----------
alpha
Parameter in :math:`(0, 1]` that interpolates between the :term:`quadratic term` and
Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and
the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while
:math:`\alpha \to 0` corresponds to the pure :term:`linear problem`.
epsilon
Expand Down Expand Up @@ -672,6 +672,8 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
if alpha == 1.0:
raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.")
return CompoundProblem.solve(
self, # type: ignore[return-value, arg-type]
alpha=alpha,
Expand Down
12 changes: 10 additions & 2 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,14 @@ def test_prepare_external_star_policy(

@pytest.mark.parametrize(
("epsilon", "alpha", "rank", "initializer"),
[(1e-2, 0.9, -1, None), (2, 0.5, -1, "random"), (2, 0.5, -1, "rank2"), (2, 0.1, -1, None)],
[
(1e-2, 0.9, -1, None),
(2, 0.5, -1, "random"),
(2, 1.0, -1, "rank2"),
(2, 0.1, -1, None),
(2, 1.0, -1, None),
(1.3, 1.0, -1, "random"),
],
)
@pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
@pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
Expand All @@ -104,7 +111,8 @@ def test_solve_balanced(
kwargs["initializer"] = initializer

tp = TranslationProblem(adata_src, adata_tgt)
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr)
joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"}
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)
tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

for key, subsol in tp.solutions.items():
Expand Down
1 change: 1 addition & 0 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def test_solve_balanced(
return # TODO(@MUCDK) fix after refactoring
mp = MappingProblem(adataref, adatasp)
mp = mp.prepare(batch_key="batch", sc_attr=sc_attr, var_names=var_names)
alpha = alpha if mp.filtered_vars is not None else 1.0
mp = mp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

for prob_key in mp:
Expand Down
Loading