Skip to content

Commit

Permalink
Fix low-rank convergence criterion (#547)
Browse files (browse the repository at this point in the history)
* Fix low-rank convergence criterion

* Remove gamma scaling factor

* Fix `test_progress_fn` test

* Increase tolerance in `test_lr_unbalanced_ti`

* Remove generalized k-means test
  • Loading branch information
michalk8 committed Jul 3, 2024
1 parent d5d8a47 commit 7cfd393
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 63 deletions.
28 changes: 16 additions & 12 deletions src/ott/solvers/linear/sinkhorn_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def compute_error( # noqa: D102
err_r = mu.gen_js(self.r, previous_state.r, c=1.0)
err_g = mu.gen_js(self.g, previous_state.g, c=1.0)

return ((1.0 / self.gamma) ** 2) * (err_q + err_r + err_g)
# don't scale by (1 / gamma ** 2); https://github.com/ott-jax/ott/pull/547
return err_q + err_r + err_g

def reg_ot_cost( # noqa: D102
self,
Expand Down Expand Up @@ -175,6 +176,7 @@ class LRSinkhornOutput(NamedTuple):
ot_prob: linear_problem.LinearProblem
epsilon: float
inner_iterations: int
converged: bool
# TODO(michalk8): Optional is an artifact of the current impl., refactor
reg_ot_cost: Optional[float] = None

Expand Down Expand Up @@ -221,12 +223,6 @@ def b(self) -> jnp.ndarray: # noqa: D102
def n_iters(self) -> int: # noqa: D102
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def converged(self) -> bool: # noqa: D102
return jnp.logical_and(
jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs))
)

@property
def matrix(self) -> jnp.ndarray:
"""Transport matrix if it can be instantiated."""
Expand Down Expand Up @@ -687,7 +683,10 @@ def one_iteration(
lambda: state.reg_ot_cost(ot_prob, epsilon=self.epsilon),
lambda: jnp.inf
)
error = state.compute_error(previous_state)
error = jax.lax.cond(
iteration >= self.min_iterations,
lambda: state.compute_error(previous_state), lambda: jnp.inf
)
crossed_threshold = jnp.logical_or(
state.crossed_threshold,
jnp.logical_and(
Expand Down Expand Up @@ -761,6 +760,8 @@ def output_from_state(
Returns:
A LRSinkhornOutput.
"""
it = jnp.sum(state.errors != -1.0) * self.inner_iterations
converged = self._converged(state, it)
return LRSinkhornOutput(
q=state.q,
r=state.r,
Expand All @@ -770,6 +771,7 @@ def output_from_state(
errors=state.errors,
epsilon=self.epsilon,
inner_iterations=self.inner_iterations,
converged=converged,
)

def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
Expand Down Expand Up @@ -800,11 +802,13 @@ def conv_not_crossed(prev_err: float, curr_err: float) -> bool:
)

def _diverged(self, state: LRSinkhornState, iteration: int) -> bool:
it = iteration // self.inner_iterations
return jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it - 1])),
jnp.logical_not(jnp.isfinite(state.costs[it - 1]))
it = iteration // self.inner_iterations - 1
is_not_finite = jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it])),
jnp.logical_not(jnp.isfinite(state.costs[it]))
)
# `jnp.inf` is used if `it < self.min_iterations`
return jnp.logical_and(it >= self.min_iterations, is_not_finite)


def run(
Expand Down
26 changes: 15 additions & 11 deletions src/ott/solvers/quadratic/gromov_wasserstein_lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def ent(x: jnp.ndarray) -> float:
errors=None,
epsilon=None,
inner_iterations=None,
converged=False,
)

cost = out.primal_cost - epsilon * (ent(q) + ent(r) + ent(g))
Expand All @@ -148,6 +149,7 @@ class LRGWOutput(NamedTuple):
ot_prob: quadratic_problem.QuadraticProblem
epsilon: float
inner_iterations: int
converged: bool
reg_gw_cost: Optional[float] = None

def set(self, **kwargs: Any) -> "LRGWOutput":
Expand Down Expand Up @@ -194,12 +196,6 @@ def b(self) -> jnp.ndarray: # noqa: D102
def n_iters(self) -> int: # noqa: D102
return jnp.sum(self.errors != -1) * self.inner_iterations

@property
def converged(self) -> bool: # noqa: D102
return jnp.logical_and(
jnp.any(self.costs == -1), jnp.all(jnp.isfinite(self.costs))
)

@property
def matrix(self) -> jnp.ndarray:
"""Transport matrix if it can be instantiated."""
Expand Down Expand Up @@ -718,7 +714,10 @@ def one_iteration(
lambda: state.reg_gw_cost(ot_prob, epsilon=self.epsilon),
lambda: jnp.inf
)
error = state.compute_error(previous_state)
error = jax.lax.cond(
iteration >= self.min_iterations,
lambda: state.compute_error(previous_state), lambda: jnp.inf
)
crossed_threshold = jnp.logical_or(
state.crossed_threshold,
jnp.logical_and(
Expand Down Expand Up @@ -794,6 +793,8 @@ def output_from_state(
Returns:
A LRGWOutput.
"""
it = jnp.sum(state.errors != -1.0) * self.inner_iterations
converged = self._converged(state, it)
return LRGWOutput(
q=state.q,
r=state.r,
Expand All @@ -803,6 +804,7 @@ def output_from_state(
errors=state.errors,
epsilon=self.epsilon,
inner_iterations=self.inner_iterations,
converged=converged,
)

def _converged(self, state: LRGWState, iteration: int) -> bool:
Expand Down Expand Up @@ -833,11 +835,13 @@ def conv_not_crossed(prev_err: float, curr_err: float) -> bool:
)

def _diverged(self, state: LRGWState, iteration: int) -> bool:
it = iteration // self.inner_iterations
return jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it - 1])),
jnp.logical_not(jnp.isfinite(state.costs[it - 1]))
it = iteration // self.inner_iterations - 1
is_not_finite = jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it])),
jnp.logical_not(jnp.isfinite(state.costs[it]))
)
# `jnp.inf` is used if `it < self.min_iterations`
return jnp.logical_and(it >= self.min_iterations, is_not_finite)


def run(
Expand Down
31 changes: 1 addition & 30 deletions tests/initializers/linear/sinkhorn_lr_init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax.numpy as jnp
import numpy as np

from ott.geometry import geometry, pointcloud
from ott.geometry import pointcloud
from ott.initializers.linear import initializers_lr
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr
Expand Down Expand Up @@ -82,35 +82,6 @@ def test_generalized_k_means_has_correct_rank(
assert jnp.linalg.matrix_rank(q) == rank
assert jnp.linalg.matrix_rank(r) == rank

def test_generalized_k_means_matches_k_means(self, rng: jax.Array):
n, d, rank = 27, 7, 5
eps = 1e-1
rng1, rng2 = jax.random.split(rng, 2)
x = jax.random.normal(rng1, (n, d))
y = jax.random.normal(rng1, (n, d))

pc = pointcloud.PointCloud(x, y, epsilon=eps)
geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps)
pc_problem = linear_problem.LinearProblem(pc)
geom_problem = linear_problem.LinearProblem(geom)

solver = sinkhorn_lr.LRSinkhorn(
rank=rank, initializer="k-means", max_iterations=5000
)
pc_out = solver(pc_problem)

solver = sinkhorn_lr.LRSinkhorn(
rank=rank, initializer="generalized-k-means", max_iterations=5000
)
geom_out = solver(geom_problem)

with pytest.raises(AssertionError):
np.testing.assert_allclose(pc_out.costs, geom_out.costs)

np.testing.assert_allclose(
pc_out.reg_ot_cost, geom_out.reg_ot_cost, atol=0.5, rtol=0.02
)

@pytest.mark.parametrize("epsilon", [0.0, 1e-1])
def test_better_initialization_helps(self, rng: jax.Array, epsilon: float):
n, d, rank = 81, 13, 3
Expand Down
14 changes: 4 additions & 10 deletions tests/solvers/linear/sinkhorn_lr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,9 @@ def progress_fn(
lin_prob
)

# check that the function is called on the 10th iteration (iter #9), the
# 20th iteration (iter #19).
assert traced_values["iters"] == [9, 19]

# check that error decreases
np.testing.assert_array_equal(np.diff(traced_values["error"]) < 0, True)

# check that max iterations is provided each time: [30, 30]
assert traced_values["total"] == [num_iterations] * 2
assert traced_values["iters"] == [9, 19, 29, 39]
assert traced_values["total"] == [num_iterations
] * len(traced_values["total"])

@pytest.mark.fast.with_args(eps=[0.0, 1e-1])
def test_lse_matches_kernel_mode(self, eps: float):
Expand Down Expand Up @@ -318,7 +312,7 @@ def test_lr_unbalanced_ti(

assert out.converged
assert out_ti.converged
np.testing.assert_allclose(out.errors, out_ti.errors, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(out.errors, out_ti.errors, rtol=5e-4, atol=5e-4)
np.testing.assert_allclose(
out.reg_ot_cost, out_ti.reg_ot_cost, rtol=1e-2, atol=1e-2
)
Expand Down

0 comments on commit 7cfd393

Please sign in to comment.