diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index e8d7c7ca4..fb3825ce1 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -407,3 +407,77 @@ def test_relative_epsilon( else: assert 0.215 < out.reg_gw_cost < 0.22 assert 0.19 < out.primal_cost < 0.20 + + @pytest.mark.parametrize(("tau_a", "tau_b", "eps", "ti"), + [(0.99, 0.95, 0.0, True), (0.9, 0.8, 1e-3, False), + (1.0, 0.999, 0.0, True), (0.5, 1.0, 1e-2, False)]) + def test_gwlr_unbalanced( + self, tau_a: float, tau_b: float, eps: float, ti: bool + ): + geom_x = pointcloud.PointCloud(self.x) + geom_y = pointcloud.PointCloud(self.y) + a = self.a.at[:2].set(0.0) + b = self.b.at[15:20].set(0.0) + prob = quadratic_problem.QuadraticProblem( + geom_x, + geom_y, + a=a, + b=b, + tau_a=tau_a, + tau_b=tau_b, + ) + solver = jax.jit( + gromov_wasserstein_lr.LRGromovWasserstein( + rank=4, epsilon=eps, kwargs_dys={"translation_invariant": ti} + ) + ) + + res = solver(prob) + + np.testing.assert_array_equal(jnp.isfinite(res.errors), True) + np.testing.assert_array_equal(jnp.isfinite(res.costs), True) + + @pytest.mark.parametrize(("rank", "eps"), [(5, 0.0), (10, 1e-3), (15, 1e-2)]) + def test_gwlr_unbalanced_matches_balanced( + self, rank: int, eps: float, enable_x64: bool + ): + del enable_x64 + + geom_x = pointcloud.PointCloud(self.x) + geom_y = pointcloud.PointCloud(self.y) + prob = quadratic_problem.QuadraticProblem( + geom_x, + geom_y, + a=self.a, + b=self.b, + tau_a=1.0, + tau_b=1.0, + ) + prob_unbal = quadratic_problem.QuadraticProblem( + geom_x, + geom_y, + a=self.a, + b=self.b, + tau_a=0.9999, + tau_b=0.9999, + ) + solver = jax.jit( + gromov_wasserstein_lr.LRGromovWasserstein( + rank=rank, + epsilon=eps, + initializer="random", + min_iterations=50, + max_iterations=50 + ) + ) + + res = solver(prob) + res_unbal = solver(prob_unbal) + + np.testing.assert_allclose(res.transport_mass, 1.0, rtol=1e-4, atol=1e-4) + np.testing.assert_allclose( + res.transport_mass, res_unbal.transport_mass, rtol=1e-4, atol=1e-4 + ) + np.testing.assert_allclose( + res.primal_cost, res_unbal.primal_cost, rtol=1e-3, atol=1e-3 + )