Skip to content

Commit

Permalink
Merge pull request #686 from theislab/fix/min_max_iter_default
Browse files Browse the repository at this point in the history
Set defaults of `min_iterations` and `max_iterations` to None
  • Loading branch information
selmanozleyen authored May 7, 2024
2 parents 389e008 + 7926ea3 commit 6cc2215
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 18 deletions.
5 changes: 5 additions & 0 deletions src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,11 @@ def solve(
problems = self._problem_manager.get_problems(stage=stage)

logger.info(f"Solving `{len(problems)}` problems")
# expose min/max iterations to the user but remove them if they are None
if "min_iterations" in kwargs and kwargs["min_iterations"] is None:
kwargs.pop("min_iterations")
if "max_iterations" in kwargs and kwargs["max_iterations"] is None:
kwargs.pop("max_iterations")
for problem in problems.values():
logger.info(f"Solving problem {problem}.")
_ = problem.solve(**kwargs)
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/cross_modality/_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def solve( # type: ignore[override]
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
12 changes: 6 additions & 6 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def solve(
threshold: float = 1e-3,
lse_mode: bool = True,
inner_iterations: int = 10,
min_iterations: int = 0,
max_iterations: int = 2000,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "SinkhornProblem[K,B]":
Expand Down Expand Up @@ -373,8 +373,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down Expand Up @@ -604,8 +604,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/space/_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/space/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/problems/spatiotemporal/_spatio_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down
8 changes: 4 additions & 4 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def solve(
threshold: float = 1e-3,
lse_mode: bool = True,
inner_iterations: int = 10,
min_iterations: int = 0,
max_iterations: int = 2000,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "TemporalProblem":
Expand Down Expand Up @@ -406,8 +406,8 @@ def solve(
initializer: QuadInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
jit: bool = True,
min_iterations: int = 5,
max_iterations: int = 50,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
threshold: float = 1e-3,
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
Expand Down

0 comments on commit 6cc2215

Please sign in to comment.