Unbalanced FGW doesn't converge when margins are provided #519
Comments
Hi @selmanozleyen, this seems to come from numerical imprecision; more specifically, the NaNs come directly from the initialization here, where

```python
a = jnp.ones(x.shape[0]) / x.shape[0]
b = jnp.ones(y.shape[0]) / y.shape[0]
```

leads to numerical precision issues.
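The workaround mentioned in this thread is to normalize the marginals to unit mass before handing them to the solver. A minimal sketch in NumPy (the helper name and array sizes are illustrative, not from the issue):

```python
import numpy as np

def normalize_marginal(w: np.ndarray) -> np.ndarray:
    """Rescale a nonnegative weight vector so it sums to 1."""
    total = w.sum()
    if total <= 0:
        raise ValueError("marginal must have positive total mass")
    return w / total

# Marginals that do not sum to 1, as in the failing cases reported below.
a = np.ones(10) * 2.0
b = np.ones(12) * 2.0

a_n, b_n = normalize_marginal(a), normalize_marginal(b)
print(a_n.sum(), b_n.sum())  # both 1.0 up to floating-point error
```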
@michalk8, as you said, it works when I normalize. But when the marginals don't sum to 1 it still fails in many cases; see the examples below. I'd expect unbalanced OT not to require the marginals to sum to 1.

```python
a = np.ones(x.shape[0]) * 2
a[0:4] = 1
b = np.ones(y.shape[0]) * 2
b[0:4] = 1
# or
a = np.ones(x.shape[0]) * 2
b = np.ones(y.shape[0]) * 2
```
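For concreteness, here are the total masses produced by the first example above, assuming illustrative sizes of 10 and 12 points for `x` and `y` (the issue does not specify the shapes):

```python
import numpy as np

# Illustrative sizes; the issue does not give x.shape / y.shape.
n, m = 10, 12

a = np.ones(n) * 2
a[0:4] = 1
b = np.ones(m) * 2
b[0:4] = 1

# Total masses differ from 1 (and from each other when n != m),
# which unbalanced OT is, in principle, supposed to tolerate.
print(a.sum(), b.sum())  # 16.0 20.0
```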
Thanks @selmanozleyen. I think what's happening here is a problem of scales. Although it may seem dividing/multiplying … Tangentially related: I think the …
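One way to sketch the "problem of scales" idea above (this is my own heuristic, not a documented ott behavior) is to divide both marginals by a single common factor, so the solver sees O(1) total mass while the relative masses of `a` and `b` are preserved, and keep the factor around for rescaling afterwards:

```python
import numpy as np

def rescale_marginals(a: np.ndarray, b: np.ndarray):
    """Divide both marginals by the same factor (here, the mean of their
    total masses) so the solver sees O(1) mass while the a/b mass ratio
    is left unchanged. Returns the factor so results can be rescaled."""
    scale = 0.5 * (a.sum() + b.sum())
    return a / scale, b / scale, scale

a = np.ones(10) * 2.0
b = np.ones(12) * 2.0
a_s, b_s, scale = rescale_marginals(a, b)

# The mass ratio a.sum()/b.sum() is unchanged; note that rescaling the
# resulting plan by `scale` is only a heuristic, since unbalanced OT
# solutions are not scale-invariant in general.
print(a_s.sum() + b_s.sum())  # 2.0
```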
Describe the bug
For the application use case, see the tests from moscot: https://github.com/theislab/moscot/actions/runs/8709537760/job/23889450330?pr=677

Unbalanced FGW is unstable, especially when margins are provided. I played with epsilon and the tau values, but it still doesn't converge. I think this started happening after 41906a2.
To Reproduce