Unbalanced FGW doesn't converge when margins are provided #519

Open
Tracked by #677
selmanozleyen opened this issue Apr 17, 2024 · 3 comments

@selmanozleyen
Contributor

selmanozleyen commented Apr 17, 2024

Describe the bug
For the application use case, see the moscot tests: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677

Unbalanced FGW is unstable, especially when margins are provided. I tried various values of epsilon and tau, but it still doesn't converge. I think this started after 41906a2.

To Reproduce

import numpy as np
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.solvers.quadratic import solve


# Generating random data for x and y
x = np.random.rand(96, 2)  # 96 points in 2D
y = np.random.rand(96, 2)  # Another 96 points in 2D

# Create PointCloud instances
geom_xx = pointcloud.PointCloud(x)
geom_yy = pointcloud.PointCloud(y)
geom_xy = pointcloud.PointCloud(x, y)

# a and b are vectors of ones with lengths matching the number of points in x and y, respectively
a = jnp.ones(x.shape[0])
b = jnp.ones(y.shape[0])

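# Call solve function with the specified parameters
out = solve(geom_xx=geom_xx, geom_yy=geom_yy, geom_xy=geom_xy, tau_a=0.9, tau_b=0.9,
            fused_penalty=1.0, epsilon=1.0, a=a, b=b)
# Check the convergence flag and whether the coupling contains NaNs
print(out.converged, jnp.isnan(out.matrix).any())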
@selmanozleyen changed the title "Unbalanced FGW is doesn't converge when margins are provided" to "Unbalanced FGW doesn't converge when margins are provided" (Apr 17, 2024)
@michalk8
Collaborator

michalk8 commented Apr 24, 2024

Hi @selmanozleyen, this seems to come from numerical imprecision; more specifically, the NaNs come directly from the initialization here, where marginal_1 is an array of all 0s. This leads to a transport mass of 0 and, later, to a NaN rescaling factor.
I will look into whether there's a more numerically stable way of computing this; however, simply using

a = jnp.ones(x.shape[0]) / x.shape[0]
b = jnp.ones(y.shape[0]) / y.shape[0]

avoids the numerical precision issues.
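The failure mode described above (an all-zero marginal giving a transport mass of 0, which then poisons a rescaling step) can be reproduced in isolation. The helper `hypothetical_rescale` below is hypothetical and is not ott's initialization code; it only assumes that a marginal gets rescaled by a factor of the form `target_mass / transport_mass`. NumPy is used here; jax.numpy follows the same IEEE 754 rules, just without emitting warnings.

```python
import numpy as np


def hypothetical_rescale(marginal, target_mass):
    """Illustrative only (not ott's code): rescale `marginal` so that its
    total mass becomes `target_mass`."""
    transport_mass = marginal.sum()
    # If `marginal` is all zeros, transport_mass == 0.0 and the division
    # yields inf (or NaN when target_mass is 0.0 too), per IEEE 754.
    with np.errstate(divide="ignore", invalid="ignore"):
        factor = target_mass / transport_mass
        rescaled = marginal * factor  # 0.0 * inf and 0.0 * nan both give nan
    return factor, rescaled


marginal_1 = np.zeros(96)  # degenerate marginal, as in the initialization above
factor, rescaled = hypothetical_rescale(marginal_1, target_mass=1.0)
print(factor, np.isnan(rescaled).all())
```

Once a single NaN enters the marginals, every subsequent iteration that uses them inherits it, which is consistent with the solver never reporting a finite result.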

@selmanozleyen
Contributor Author

@michalk8, as you said, it works when I normalize. But when the marginals don't sum to 1, it still doesn't converge in many cases; see the examples below. I'd expect unbalanced OT not to require the marginals to sum to 1.

a = np.ones(x.shape[0])*2
a[0:4] = 1
b = np.ones(y.shape[0])*2
b[0:4] = 1
# or 
a = np.ones(x.shape[0])*2
b = np.ones(y.shape[0])*2
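To compare these two cases with the normalized weights (total mass 1), a quick sketch computing their total mass, assuming the same 96 points as in the reproduction:

```python
import numpy as np

n = 96  # number of points, as in the reproduction above

# First case: weight 2 everywhere, except weight 1 for the first 4 points.
a1 = np.ones(n) * 2
a1[0:4] = 1

# Second case: weight 2 everywhere.
a2 = np.ones(n) * 2

for name, w in [("case 1", a1), ("case 2", a2)]:
    # A normalized (probability) vector would have total mass 1.
    print(name, "total mass:", w.sum())
```

Both vectors carry a total mass of roughly 190 (188 and 192), so they sit at a very different scale from the normalized vectors that do work.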

@marcocuturi
Contributor

Thanks @selmanozleyen. I think what's happening here is a problem of scale. Although it may seem that multiplying or dividing a/b by a constant should have no bearing on the optimization, this is likely not the case for entropic GW because of the interplay with other parameters (notably epsilon, but also, more generally, the scale of the cost matrix, since the unbalanced problem adds a KL term).
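One way to see why the scale of a and b matters is a back-of-the-envelope derivation. It assumes a common textbook form of the entropic unbalanced GW objective (square loss, fused term omitted) with generalized KL penalties; ott's actual objective, including its constants and how tau_a/tau_b map to the weights ρ_a, ρ_b, may differ, so treat this as a sketch:

```math
\begin{aligned}
\mathcal{E}_{a,b}(P) &= \sum_{i,j,k,l} \big(C^x_{ik} - C^y_{jl}\big)^2 P_{ij} P_{kl}
  + \varepsilon\,\mathrm{KL}(P \,\|\, a \otimes b)
  + \rho_a\,\mathrm{KL}(P\mathbf{1} \,\|\, a)
  + \rho_b\,\mathrm{KL}(P^\top\mathbf{1} \,\|\, b), \\
\mathrm{KL}(p \,\|\, q) &= \sum_i p_i \log\frac{p_i}{q_i} - p_i + q_i,
  \qquad \mathrm{KL}(\lambda p \,\|\, \lambda q) = \lambda\,\mathrm{KL}(p \,\|\, q).
\end{aligned}
```

Replacing a, b by λa, λb (so the reference measure a ⊗ b becomes λ² a ⊗ b) and writing P = λ² P̃:

```math
\mathcal{E}_{\lambda a,\lambda b}(\lambda^2 \tilde P) = \lambda^4 \Big[\sum_{i,j,k,l} \big(C^x_{ik} - C^y_{jl}\big)^2 \tilde P_{ij}\tilde P_{kl}
  + \frac{\varepsilon}{\lambda^2}\,\mathrm{KL}(\tilde P \,\|\, a \otimes b)
  + \frac{\rho_a}{\lambda^3}\,\mathrm{KL}(\lambda\tilde P\mathbf{1} \,\|\, a)
  + \frac{\rho_b}{\lambda^3}\,\mathrm{KL}(\lambda\tilde P^\top\mathbf{1} \,\|\, b)\Big]
```

Under this model, scaling the marginals by λ acts like dividing epsilon by λ² relative to the quadratic term (a fused linear term ⟨M, P⟩ would also scale only as λ²), while the marginal penalties do not reduce to the unscaled ones at all. With a = jnp.ones(96) instead of jnp.ones(96) / 96, λ = 96 and the effective epsilon would be ε / 9216.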

Tangentially related: I think the converged flag in GW was bugged, as discussed in #566
