Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Sep 7, 2023
1 parent e06f1b0 commit f04757f
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions tests/solvers/quadratic/gw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,3 +407,77 @@ def test_relative_epsilon(
else:
assert 0.215 < out.reg_gw_cost < 0.22
assert 0.19 < out.primal_cost < 0.20

@pytest.mark.parametrize(("tau_a", "tau_b", "eps", "ti"),
[(0.99, 0.95, 0.0, True), (0.9, 0.8, 1e-3, False),
(1.0, 0.999, 0.0, True), (0.5, 1.0, 1e-2, False)])
def test_gwlr_unbalanced(
self, tau_a: float, tau_b: float, eps: float, ti: bool
):
geom_x = pointcloud.PointCloud(self.x)
geom_y = pointcloud.PointCloud(self.y)
a = self.a.at[:2].set(0.0)
b = self.b.at[15:20].set(0.0)
prob = quadratic_problem.QuadraticProblem(
geom_x,
geom_y,
a=a,
b=b,
tau_a=tau_a,
tau_b=tau_b,
)
solver = jax.jit(
gromov_wasserstein_lr.LRGromovWasserstein(
rank=4, epsilon=eps, kwargs_dys={"translation_invariant": ti}
)
)

res = solver(prob)

np.testing.assert_array_equal(jnp.isfinite(res.errors), True)
np.testing.assert_array_equal(jnp.isfinite(res.costs), True)

@pytest.mark.parametrize(("rank", "eps"), [(5, 0.0), (10, 1e-3), (15, 1e-2)])
def test_gwlr_unbalanced_matches_balanced(
self, rank: int, eps: float, enable_x64: bool
):
del enable_x64

geom_x = pointcloud.PointCloud(self.x)
geom_y = pointcloud.PointCloud(self.y)
prob = quadratic_problem.QuadraticProblem(
geom_x,
geom_y,
a=self.a,
b=self.b,
tau_a=1.0,
tau_b=1.0,
)
prob_unbal = quadratic_problem.QuadraticProblem(
geom_x,
geom_y,
a=self.a,
b=self.b,
tau_a=0.9999,
tau_b=0.9999,
)
solver = jax.jit(
gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
epsilon=eps,
initializer="random",
min_iterations=50,
max_iterations=50
)
)

res = solver(prob)
res_unbal = solver(prob_unbal)

np.testing.assert_allclose(res.transport_mass, 1.0, rtol=1e-4, atol=1e-4)
np.testing.assert_allclose(
res.transport_mass, res_unbal.transport_mass, rtol=1e-4, atol=1e-4
)
np.testing.assert_allclose(
res.primal_cost, res_unbal.primal_cost, rtol=1e-3, atol=1e-3
)

0 comments on commit f04757f

Please sign in to comment.