diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1435bab10..396cca399 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,22 +4,22 @@ default_language_version: default_stages: - commit - push -minimum_pre_commit_version: 2.16.0 +minimum_pre_commit_version: 3.0.0 repos: - repo: https://github.com/google/yapf - rev: v0.32.0 + rev: v0.40.0 hooks: - id: yapf additional_dependencies: [toml] - repo: https://github.com/nbQA-dev/nbQA - rev: 1.6.3 + rev: 1.7.0 hooks: - id: nbqa-pyupgrade args: [--py38-plus] - id: nbqa-black - id: nbqa-isort - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.7.0 + rev: v2.10.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '2'] @@ -38,12 +38,12 @@ repos: - id: check-case-conflict - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: v0.0.252 + rev: v0.0.285 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/rstcheck/rstcheck - rev: v6.1.1 + rev: v6.1.2 hooks: - id: rstcheck additional_dependencies: [tomli] diff --git a/LICENSE b/LICENSE index d64569567..261eeb9e9 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,3 @@ - Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ diff --git a/docs/spelling/technical.txt b/docs/spelling/technical.txt index 008dac7b4..5f7d2581a 100644 --- a/docs/spelling/technical.txt +++ b/docs/spelling/technical.txt @@ -79,6 +79,7 @@ logit macOS methylation neuroimaging +normed numerics omics optimality @@ -89,6 +90,7 @@ parameterizing piecewise pluripotent positivity +postfix potentials precompile precompute diff --git a/pyproject.toml b/pyproject.toml index bb7a86614..a100ee68e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -179,7 +179,7 @@ ignore_path = ["docs/**/_autosummary", "docs/contributing.rst"] legacy_tox_ini = """ [tox] min_version = 4.0 - env_list = lint-code,py{3.8,3.9,3.10,3.11} + env_list = lint-code,py{3.8,3.9,3.10,3.11},py3.9-jax-default skip_missing_interpreters = true [testenv] @@ -299,10 +299,10 @@ select = [ unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"] target-version = "py38" [tool.ruff.per-file-ignores] - "tests/*" = ["D", "PT004"] # TODO(michalk8): remove `self.initialize` in `tests/` - "*/__init__.py" = ["F401"] - "docs/*" = ["D"] - "src/ott/types.py" = ["D102"] +"tests/*" = ["D", "PT004"] # TODO(michalk8): remove `self.initialize` in `tests/` +"*/__init__.py" = ["F401"] +"docs/*" = ["D"] +"src/ott/types.py" = ["D102"] [tool.ruff.pydocstyle] convention = "google" [tool.ruff.pyupgrade] diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 4556269cd..9f1a6c3a0 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -614,8 +614,8 @@ def _reg(self, z: jnp.ndarray) -> float: # noqa: D102 # Choose first index satisfying constraint in Prop 2.1 lower_bound = cesaro - top_w >= 0 # Last upper bound is always True. - upper_bound = jnp.concatenate(((top_w[1:] - cesaro[:-1] > 0), - jnp.array((True,)))) + upper_bound = jnp.concatenate(((top_w[1:] - cesaro[:-1] + > 0), jnp.array((True,)))) r = jnp.argmax(lower_bound * upper_bound) s = jnp.sum(jnp.where(jnp.arange(k) < k - r - 1, jnp.flip(top_w) ** 2, 0)) diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index d4c3aea3e..78146b52e 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -613,8 +613,8 @@ def body_fn( crossed_threshold = jnp.logical_or( state.crossed_threshold, jnp.logical_and( - state.criterions[it - 1] >= consts.threshold, - criterion < consts.threshold + state.criterions[it - 1] >= consts.threshold, criterion + < consts.threshold ) ) diff --git a/src/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py index 641985794..4529e7f78 100644 --- a/src/ott/solvers/linear/acceleration.py +++ b/src/ott/solvers/linear/acceleration.py @@ -144,8 +144,8 @@ def weight(self, state: "sinkhorn.SinkhornState", iteration: int) -> float: return jax.lax.cond( jnp.logical_and( - iteration >= self.start, - state.errors[idx - 1, -1] < self.error_threshold + iteration >= self.start, state.errors[idx - 1, -1] + < self.error_threshold ), lambda state: self.lehmann(state), lambda state: self.value, state ) diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 59ef5a5a9..b3825d520 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -1077,8 +1077,8 @@ def output_from_state( # position) is lower than the threshold. converged = jnp.logical_and( - jnp.logical_not(jnp.any(jnp.isnan(state.errors))), - state.errors[-1] < self.threshold + jnp.logical_not(jnp.any(jnp.isnan(state.errors))), state.errors[-1] + < self.threshold )[0] return SinkhornOutput( diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index 03cbf0a7a..2fbe1bb82 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -155,8 +155,9 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: x0 = _flatten_cov(self.covariance()) x1 = _flatten_cov(other.covariance()) cost_fn = costs.Bures(dimension=dimension) - return (cost_fn.norm(x0) + cost_fn.norm(x1) + - cost_fn.pairwise(x0, x1))[...,] + return (cost_fn.norm(x0) + cost_fn.norm(x1) + cost_fn.pairwise(x0, x1))[ + ..., + ] def gaussian_map(self, dest_scale: "ScaleTriL") -> jnp.ndarray: """Scaling matrix used in transport between 0-mean Gaussians. diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index 15445a8f4..c0045bc65 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -493,9 +493,7 @@ def test_sinkhorn_online_memory_jit(self): assert out.converged assert out.primal_cost > 0.0 - @pytest.mark.fast.with_args( - cost_fn=[None, costs.SqPNorm(1.6)], - ) + @pytest.mark.fast.with_args(cost_fn=[None, costs.SqPNorm(1.6)]) def test_primal_cost_grid(self, cost_fn: Optional[costs.CostFn]): """Test computation of primal / costs for Grids.""" ns = [6, 7, 11]