Skip to content

Commit

Permalink
Update pre-commits (#416)
Browse files Browse the repository at this point in the history
* Update pre-commits

* Dedent `tox` in `pyproject.toml`

* Temporarily indent again

* Run linters

* Update for `tox>=4.9.0`
  • Loading branch information
michalk8 authored Aug 22, 2023
1 parent 441576c commit 137fd3a
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 25 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@ default_language_version:
default_stages:
- commit
- push
minimum_pre_commit_version: 2.16.0
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/google/yapf
rev: v0.32.0
rev: v0.40.0
hooks:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.3
rev: 1.7.0
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
- id: nbqa-black
- id: nbqa-isort
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.7.0
rev: v2.10.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
Expand All @@ -38,12 +38,12 @@ repos:
- id: check-case-conflict
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: v0.0.252
rev: v0.0.285
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/rstcheck/rstcheck
rev: v6.1.1
rev: v6.1.2
hooks:
- id: rstcheck
additional_dependencies: [tomli]
Expand Down
1 change: 0 additions & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
Expand Down
2 changes: 2 additions & 0 deletions docs/spelling/technical.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ logit
macOS
methylation
neuroimaging
normed
numerics
omics
optimality
Expand All @@ -89,6 +90,7 @@ parameterizing
piecewise
pluripotent
positivity
postfix
potentials
precompile
precompute
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ ignore_path = ["docs/**/_autosummary", "docs/contributing.rst"]
legacy_tox_ini = """
[tox]
min_version = 4.0
env_list = lint-code,py{3.8,3.9,3.10,3.11}
env_list = lint-code,py{3.8,3.9,3.10,3.11},py3.9-jax-default
skip_missing_interpreters = true
[testenv]
Expand Down Expand Up @@ -299,10 +299,10 @@ select = [
unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"]
target-version = "py38"
[tool.ruff.per-file-ignores]
"tests/*" = ["D", "PT004"] # TODO(michalk8): remove `self.initialize` in `tests/`
"*/__init__.py" = ["F401"]
"docs/*" = ["D"]
"src/ott/types.py" = ["D102"]
"tests/*" = ["D", "PT004"] # TODO(michalk8): remove `self.initialize` in `tests/`
"*/__init__.py" = ["F401"]
"docs/*" = ["D"]
"src/ott/types.py" = ["D102"]
[tool.ruff.pydocstyle]
convention = "google"
[tool.ruff.pyupgrade]
Expand Down
4 changes: 2 additions & 2 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,8 @@ def _reg(self, z: jnp.ndarray) -> float: # noqa: D102
# Choose first index satisfying constraint in Prop 2.1
lower_bound = cesaro - top_w >= 0
# Last upper bound is always True.
upper_bound = jnp.concatenate(((top_w[1:] - cesaro[:-1] > 0),
jnp.array((True,))))
upper_bound = jnp.concatenate(((top_w[1:] - cesaro[:-1]
> 0), jnp.array((True,))))
r = jnp.argmax(lower_bound * upper_bound)
s = jnp.sum(jnp.where(jnp.arange(k) < k - r - 1, jnp.flip(top_w) ** 2, 0))

Expand Down
4 changes: 2 additions & 2 deletions src/ott/initializers/linear/initializers_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,8 @@ def body_fn(
crossed_threshold = jnp.logical_or(
state.crossed_threshold,
jnp.logical_and(
state.criterions[it - 1] >= consts.threshold,
criterion < consts.threshold
state.criterions[it - 1] >= consts.threshold, criterion
< consts.threshold
)
)

Expand Down
4 changes: 2 additions & 2 deletions src/ott/solvers/linear/acceleration.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def weight(self, state: "sinkhorn.SinkhornState", iteration: int) -> float:

return jax.lax.cond(
jnp.logical_and(
iteration >= self.start,
state.errors[idx - 1, -1] < self.error_threshold
iteration >= self.start, state.errors[idx - 1, -1]
< self.error_threshold
), lambda state: self.lehmann(state), lambda state: self.value, state
)

Expand Down
4 changes: 2 additions & 2 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,8 +1077,8 @@ def output_from_state(
# position) is lower than the threshold.

converged = jnp.logical_and(
jnp.logical_not(jnp.any(jnp.isnan(state.errors))),
state.errors[-1] < self.threshold
jnp.logical_not(jnp.any(jnp.isnan(state.errors))), state.errors[-1]
< self.threshold
)[0]

return SinkhornOutput(
Expand Down
5 changes: 3 additions & 2 deletions src/ott/tools/gaussian_mixture/scale_tril.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,9 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray:
x0 = _flatten_cov(self.covariance())
x1 = _flatten_cov(other.covariance())
cost_fn = costs.Bures(dimension=dimension)
return (cost_fn.norm(x0) + cost_fn.norm(x1) +
cost_fn.pairwise(x0, x1))[...,]
return (cost_fn.norm(x0) + cost_fn.norm(x1) + cost_fn.pairwise(x0, x1))[
...,
]

def gaussian_map(self, dest_scale: "ScaleTriL") -> jnp.ndarray:
"""Scaling matrix used in transport between 0-mean Gaussians.
Expand Down
4 changes: 1 addition & 3 deletions tests/solvers/linear/sinkhorn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,9 +493,7 @@ def test_sinkhorn_online_memory_jit(self):
assert out.converged
assert out.primal_cost > 0.0

@pytest.mark.fast.with_args(
cost_fn=[None, costs.SqPNorm(1.6)],
)
@pytest.mark.fast.with_args(cost_fn=[None, costs.SqPNorm(1.6)])
def test_primal_cost_grid(self, cost_fn: Optional[costs.CostFn]):
"""Test computation of primal / costs for Grids."""
ns = [6, 7, 11]
Expand Down

0 comments on commit 137fd3a

Please sign in to comment.