diff --git a/docs/conf.py b/docs/conf.py index 8c5dc72be..2e1257ff6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -125,6 +125,8 @@ "https://doi.org/10.1137/17M1140431", "https://doi.org/10.1137/141000439", "https://doi.org/10.1002/mana.19901470121", + "https://doi.org/10.1145/2516971.2516977", + "https://doi.org/10.1145/2766963", ] # List of patterns, relative to source directory, that match files and diff --git a/src/ott/solvers/linear/lineax_implicit.py b/src/ott/solvers/linear/lineax_implicit.py index beb03f929..75259a1c0 100644 --- a/src/ott/solvers/linear/lineax_implicit.py +++ b/src/ott/solvers/linear/lineax_implicit.py @@ -50,6 +50,8 @@ def solve_lineax( lin_t: Optional[Callable] = None, symmetric: bool = False, nonsym_solver: Optional[lx.AbstractLinearSolver] = None, + ridge_identity: float = 0.0, + ridge_kernel: float = 0.0, **kwargs: Any ) -> jnp.ndarray: """Wrapper around lineax solvers. @@ -61,16 +63,29 @@ def solve_lineax( symmetric: whether `lin` is symmetric. nonsym_solver: solver used when handling non-symmetric cases. Note that :class:`~lineax.CG` is used by default in the symmetric case. + ridge_kernel: promotes zero-sum solutions. Only use if `tau_a = tau_b = 1.0` + ridge_identity: handles rank-deficient transport matrices (this typically + happens when rows/cols in cost/kernel matrices are collinear or, + equivalently, when two points from either measure are close). kwargs: arguments passed to :class:`~lineax.AbstractLinearSolver` linear solver.
""" input_structure = jax.eval_shape(lambda: b) kwargs.setdefault("rtol", 1e-6) kwargs.setdefault("atol", 1e-6) + + if ridge_kernel > 0.0 or ridge_identity > 0.0: + lin_reg = lambda x: lin(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x + lin_t_reg = lambda x: lin_t(x) + ridge_kernel * jnp.sum( + x + ) + ridge_identity * x + else: + lin_reg, lin_t_reg = lin, lin_t + if symmetric: solver = lx.CG(**kwargs) fn_operator = lx.FunctionLinearOperator( - lin, input_structure, tags=lx.positive_semidefinite_tag + lin_reg, input_structure, tags=lx.positive_semidefinite_tag ) return lx.linear_solve(fn_operator, b, solver).value # In the non-symmetric case, use NormalCG by default, but consider @@ -78,6 +93,6 @@ def solve_lineax( solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver solver = solver_type(**kwargs) fn_operator = CustomTransposeLinearOperator( - lin, lin_t, input_structure, input_structure + lin_reg, lin_t_reg, input_structure, input_structure ) return lx.linear_solve(fn_operator, b, solver).value diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 4eff77eea..3d0371954 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -730,43 +730,45 @@ def loss_from_potential( class TestSinkhornHessian: @pytest.mark.fast.with_args( - "lse_mode,tau_a,tau_b,arg", ( - (True, 1.0, 1.0, 0), - (False, 1.0, 1.0, 0), - (True, 1.0, 1.0, 1), - (True, 1.0, .91, 0), - (True, .93, .91, 1), - (False, .93, .91, 1), + "lse_mode,tau_a,tau_b,arg,lineax_ridge", ( + (True, 1.0, 1.0, 0, 0.0), + (False, 1.0, 1.0, 0, 1e-8), + (True, 1.0, 1.0, 1, 0.0), + (True, 1.0, 0.91, 0, 1e-7), + (True, 0.93, 0.91, 1, 0.0), + (False, 0.93, 0.91, 1, 1e-5), ), only_fast=-1 ) def test_hessian_sinkhorn( self, rng: jax.random.PRNGKeyArray, lse_mode: bool, tau_a: float, - tau_b: float, arg: int + tau_b: float, arg: int, lineax_ridge: float ): """Test hessian w.r.t. 
weights and locations.""" + try: + from ott.solvers.linear import lineax_implicit # noqa: F401 + test_back = True + ridge = lineax_ridge + except ImportError: + test_back = False + ridge = 1e-5 + n, m = (12, 15) dim = 3 rngs = jax.random.split(rng, 6) x = jax.random.uniform(rngs[0], (n, dim)) y = jax.random.uniform(rngs[1], (m, dim)) - a = jax.random.uniform(rngs[2], (n,)) + .1 - b = jax.random.uniform(rngs[3], (m,)) + .1 + a = jax.random.uniform(rngs[2], (n,)) + 0.1 + b = jax.random.uniform(rngs[3], (m,)) + 0.1 a = a / jnp.sum(a) b = b / jnp.sum(b) epsilon = 0.1 - ## Add a ridge when using JAX solvers. - try: - from ott.solvers.linear import lineax_implicit # noqa: F401 - solver_kwargs = {} - test_back = True - except ImportError: - solver_kwargs = { - "ridge_identity": 1e-5, - "ridge_kernel": 1e-5 if tau_a == tau_b == 1.0 else 0.0 - } - test_back = False + # Add a ridge when using JAX solvers, smaller ridge for lineax solvers + solver_kwargs = { + "ridge_identity": ridge, + "ridge_kernel": ridge if tau_a == tau_b == 1.0 else 0.0 + } imp_dif = implicit_lib.ImplicitDiff(solver_kwargs=solver_kwargs)