From 561b1aeefa31f9767482bb97ad7506781afe1b03 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 9 Apr 2024 14:19:38 +0200 Subject: [PATCH 1/3] change min_max_iterations default --- src/moscot/base/problems/compound_problem.py | 5 +++++ src/moscot/problems/cross_modality/_translation.py | 4 ++-- src/moscot/problems/generic/_generic.py | 12 ++++++------ src/moscot/problems/space/_alignment.py | 4 ++-- src/moscot/problems/space/_mapping.py | 4 ++-- .../problems/spatiotemporal/_spatio_temporal.py | 4 ++-- src/moscot/problems/time/_lineage.py | 8 ++++---- 7 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 0f6f7462..eef27314 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -313,6 +313,11 @@ def solve( problems = self._problem_manager.get_problems(stage=stage) logger.info(f"Solving `{len(problems)}` problems") + # expose min/max iterations to the user but remove them if they are None + if "min_iterations" in kwargs and kwargs["min_iterations"] is None: + kwargs.pop("min_iterations") + if "max_iterations" in kwargs and kwargs["max_iterations"] is None: + kwargs.pop("max_iterations") for problem in problems.values(): logger.info(f"Solving problem {problem}.") _ = problem.solve(**kwargs) diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index ff91f015..693058c7 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -192,8 +192,8 @@ def solve( # type: ignore[override] initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 09f1e99b..d1b19f08 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -147,8 +147,8 @@ def solve( threshold: float = 1e-3, lse_mode: bool = True, inner_iterations: int = 10, - min_iterations: int = 0, - max_iterations: int = 2000, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, device: Optional[Literal["cpu", "gpu", "tpu"]] = None, **kwargs: Any, ) -> "SinkhornProblem[K,B]": @@ -373,8 +373,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, @@ -604,8 +604,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/space/_alignment.py b/src/moscot/problems/space/_alignment.py index 9d10f36e..f5a398d2 100644 --- a/src/moscot/problems/space/_alignment.py +++ b/src/moscot/problems/space/_alignment.py @@ -160,8 +160,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index f98e20a5..f3243344 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -207,8 +207,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/spatiotemporal/_spatio_temporal.py b/src/moscot/problems/spatiotemporal/_spatio_temporal.py index 7a80d97c..6a2e0424 100644 --- a/src/moscot/problems/spatiotemporal/_spatio_temporal.py +++ b/src/moscot/problems/spatiotemporal/_spatio_temporal.py @@ -174,8 +174,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 8c878e8d..fea290f4 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -160,8 +160,8 @@ def solve( threshold: float = 1e-3, lse_mode: bool = True, inner_iterations: int = 10, - min_iterations: int = 0, - max_iterations: int = 2000, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, device: Optional[Literal["cpu", "gpu", "tpu"]] = None, **kwargs: Any, ) -> "TemporalProblem": @@ -406,8 +406,8 @@ def solve( initializer: QuadInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), jit: bool = True, - min_iterations: int = 5, - max_iterations: int = 50, + min_iterations: Optional[int] = None, + max_iterations: Optional[int] = None, threshold: float = 1e-3, linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, From 057cdf3f3a8534efd3146583125b4e09e81c329b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 15 Apr 2024 22:22:57 +0200 Subject: [PATCH 2/3] resolve comment --- src/moscot/base/problems/compound_problem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index eef27314..c48a9c6e 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -314,9 +314,9 @@ def solve( logger.info(f"Solving `{len(problems)}` problems") # expose min/max iterations to the user but remove them if they are None - if "min_iterations" in kwargs and kwargs["min_iterations"] is None: + if kwargs.get("min_iterations") is None: kwargs.pop("min_iterations") - if "max_iterations" in kwargs and kwargs["max_iterations"] is None: + if kwargs.get("max_iterations") is None: kwargs.pop("max_iterations") for problem in problems.values(): logger.info(f"Solving problem {problem}.") From aa6ad5a1ea73e51c9d24adfe533aeb88cd4464ac Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 18 Apr 2024 10:29:00 +0200 Subject: [PATCH 3/3] revert to old commit --- src/moscot/base/problems/compound_problem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index c48a9c6e..eef27314 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -314,9 +314,9 @@ def solve( logger.info(f"Solving `{len(problems)}` problems") # expose min/max iterations to the user but remove them if they are None - if kwargs.get("min_iterations") is None: + if "min_iterations" in kwargs and kwargs["min_iterations"] is None: kwargs.pop("min_iterations") - if kwargs.get("max_iterations") is None: + if "max_iterations" in kwargs and kwargs["max_iterations"] is None: kwargs.pop("max_iterations") for problem in problems.values(): logger.info(f"Solving problem {problem}.")