diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 0f6f7462..eef27314 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -313,6 +313,11 @@ def solve( problems = self._problem_manager.get_problems(stage=stage) logger.info(f"Solving `{len(problems)}` problems") + # expose min/max iterations to the user but remove them if they are None + if "min_iterations" in kwargs and kwargs["min_iterations"] is None: + kwargs.pop("min_iterations") + if "max_iterations" in kwargs and kwargs["max_iterations"] is None: + kwargs.pop("max_iterations") for problem in problems.values(): logger.info(f"Solving problem {problem}.") _ = problem.solve(**kwargs) diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index ff91f015..693058c7 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -192,8 +192,8 @@ def solve( # type: ignore[override] initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 6f21cc96..1466aa99 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -147,8 +147,8 @@ def solve( threshold: float = 1e-3, lse_mode: bool = True, inner_iterations: int = 10, - min_iterations: int = 0, - max_iterations: int = 2000, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, device: Optional[Literal["cpu", "gpu", "tpu"]] = None, **kwargs: Any, ) -> "SinkhornProblem[K,B]": @@ -373,8 +373,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, @@ -604,8 +604,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/space/_alignment.py b/src/moscot/problems/space/_alignment.py index 9d10f36e..f5a398d2 100644 --- a/src/moscot/problems/space/_alignment.py +++ b/src/moscot/problems/space/_alignment.py @@ -160,8 +160,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index f98e20a5..f3243344 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -207,8 +207,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/spatiotemporal/_spatio_temporal.py b/src/moscot/problems/spatiotemporal/_spatio_temporal.py index 7a80d97c..6a2e0424 100644 --- a/src/moscot/problems/spatiotemporal/_spatio_temporal.py +++ b/src/moscot/problems/spatiotemporal/_spatio_temporal.py @@ -174,8 +174,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 8c878e8d..fea290f4 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -160,8 +160,8 @@ def solve( threshold: float = 1e-3, lse_mode: bool = True, inner_iterations: int = 10, - min_iterations: int = 0, - max_iterations: int = 2000, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, device: Optional[Literal["cpu", "gpu", "tpu"]] = None, **kwargs: Any, ) -> "TemporalProblem": @@ -406,8 +406,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None,