Skip to content

Commit

Permalink
Feature/lineax ridge (#424)
Browse files Browse the repository at this point in the history
* Update pre-commits

* Dedent `tox` in `pyproject.toml`

* Temporarily indent again

* Run linters

* Update for `tox>=4.9.0`

* Add `ridge` options to `lineax`

* Modify test for `lineax` ridge

* Fix linkchecker
  • Loading branch information
michalk8 authored Sep 5, 2023
1 parent 937e049 commit 21d3627
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@
"https://doi.org/10.1137/17M1140431",
"https://doi.org/10.1137/141000439",
"https://doi.org/10.1002/mana.19901470121",
"https://doi.org/10.1145/2516971.2516977",
"https://doi.org/10.1145/2766963",
]

# List of patterns, relative to source directory, that match files and
Expand Down
19 changes: 17 additions & 2 deletions src/ott/solvers/linear/lineax_implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def solve_lineax(
lin_t: Optional[Callable] = None,
symmetric: bool = False,
nonsym_solver: Optional[lx.AbstractLinearSolver] = None,
ridge_identity: float = 0.0,
ridge_kernel: float = 0.0,
**kwargs: Any
) -> jnp.ndarray:
"""Wrapper around lineax solvers.
Expand All @@ -61,23 +63,36 @@ def solve_lineax(
symmetric: whether `lin` is symmetric.
nonsym_solver: solver used when handling non-symmetric cases. Note that
:class:`~lineax.CG` is used by default in the symmetric case.
ridge_kernel: promotes zero-sum solutions. Only use if `tau_a = tau_b = 1.0`
ridge_identity: handles rank deficient transport matrices (this happens
typically when rows/cols in cost/kernel matrices are collinear, or,
equivalently when two points from either measure are close).
kwargs: arguments passed to :class:`~lineax.AbstractLinearSolver` linear
solver.
"""
input_structure = jax.eval_shape(lambda: b)
kwargs.setdefault("rtol", 1e-6)
kwargs.setdefault("atol", 1e-6)

if ridge_kernel > 0.0 or ridge_identity > 0.0:
lin_reg = lambda x: lin(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x
lin_t_reg = lambda x: lin_t(x) + ridge_kernel * jnp.sum(
x
) + ridge_identity * x
else:
lin_reg, lin_t_reg = lin, lin_t

if symmetric:
solver = lx.CG(**kwargs)
fn_operator = lx.FunctionLinearOperator(
lin, input_structure, tags=lx.positive_semidefinite_tag
lin_reg, input_structure, tags=lx.positive_semidefinite_tag
)
return lx.linear_solve(fn_operator, b, solver).value
# In the non-symmetric case, use NormalCG by default, but consider
# user defined choice of alternative lx solver.
solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver
solver = solver_type(**kwargs)
fn_operator = CustomTransposeLinearOperator(
lin, lin_t, input_structure, input_structure
lin_reg, lin_t_reg, input_structure, input_structure
)
return lx.linear_solve(fn_operator, b, solver).value
44 changes: 23 additions & 21 deletions tests/solvers/linear/sinkhorn_diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,43 +730,45 @@ def loss_from_potential(
class TestSinkhornHessian:

@pytest.mark.fast.with_args(
"lse_mode,tau_a,tau_b,arg", (
(True, 1.0, 1.0, 0),
(False, 1.0, 1.0, 0),
(True, 1.0, 1.0, 1),
(True, 1.0, .91, 0),
(True, .93, .91, 1),
(False, .93, .91, 1),
"lse_mode,tau_a,tau_b,arg,lineax_ridge", (
(True, 1.0, 1.0, 0, 0.0),
(False, 1.0, 1.0, 0, 1e-8),
(True, 1.0, 1.0, 1, 0.0),
(True, 1.0, 0.91, 0, 1e-7),
(True, 0.93, 0.91, 1, 0.0),
(False, 0.93, 0.91, 1, 1e-5),
),
only_fast=-1
)
def test_hessian_sinkhorn(
self, rng: jax.random.PRNGKeyArray, lse_mode: bool, tau_a: float,
tau_b: float, arg: int
tau_b: float, arg: int, lineax_ridge: float
):
"""Test hessian w.r.t. weights and locations."""
try:
from ott.solvers.linear import lineax_implicit # noqa: F401
test_back = True
ridge = lineax_ridge
except ImportError:
test_back = False
ridge = 1e-5

n, m = (12, 15)
dim = 3
rngs = jax.random.split(rng, 6)
x = jax.random.uniform(rngs[0], (n, dim))
y = jax.random.uniform(rngs[1], (m, dim))
a = jax.random.uniform(rngs[2], (n,)) + .1
b = jax.random.uniform(rngs[3], (m,)) + .1
a = jax.random.uniform(rngs[2], (n,)) + 0.1
b = jax.random.uniform(rngs[3], (m,)) + 0.1
a = a / jnp.sum(a)
b = b / jnp.sum(b)
epsilon = 0.1

## Add a ridge when using JAX solvers.
try:
from ott.solvers.linear import lineax_implicit # noqa: F401
solver_kwargs = {}
test_back = True
except ImportError:
solver_kwargs = {
"ridge_identity": 1e-5,
"ridge_kernel": 1e-5 if tau_a == tau_b == 1.0 else 0.0
}
test_back = False
# Add a ridge when using JAX solvers, smaller ridge for lineax solvers
solver_kwargs = {
"ridge_identity": ridge,
"ridge_kernel": ridge if tau_a == tau_b == 1.0 else 0.0
}

imp_dif = implicit_lib.ImplicitDiff(solver_kwargs=solver_kwargs)

Expand Down

0 comments on commit 21d3627

Please sign in to comment.