`ot.emd2()` does not work as expected with empty weights if the JAX backend is used #534
Describe the bug
Per the documentation of `ot.emd2()`, uniform weights are used if empty lists are passed as the weight arguments. However, doing so with the JAX backend causes a broadcasting issue.
To Reproduce
Simulate some data first:
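The snippet that originally followed is not preserved in this copy. As a minimal sketch of such a setup, assuming two small 2-D point clouds and a squared Euclidean cost (the sizes, seed, and variable names here are illustrative and not taken from the report):

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the run is repeatable

# Illustrative sizes only; the report does not state the original ones.
n_source, n_target, dim = 5, 4, 2
x_source = rng.normal(size=(n_source, dim))
x_target = rng.normal(size=(n_target, dim))

# Squared Euclidean cost between every source/target pair.
# Broadcasting gives an array of shape (n_source, n_target).
M = ((x_source[:, None, :] - x_target[None, :, :]) ** 2).sum(axis=-1)
```

The cost matrix `M` is the third argument to `ot.emd2()`; the first two are the source and target weights, which this report passes as empty lists.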
With the `numpy` backend, the following works without an issue:
However, errors occur once we switch to `jnp`:
Partial error message:
Possible solution:
This problem can be avoided if we generate the uniform weights ourselves:
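The original workaround code is not preserved in this copy. A minimal sketch of the idea, assuming `M` is a floating-point cost matrix and that weights are wanted as JAX arrays (the helper names `uniform_weights` and `emd2_uniform` are hypothetical, not part of POT):

```python
import jax.numpy as jnp


def uniform_weights(n):
    """Return a length-n JAX vector whose entries are all 1/n."""
    return jnp.full((n,), 1.0 / n)


def emd2_uniform(M):
    """Call ot.emd2 with explicit uniform weights instead of empty lists."""
    import ot  # POT; imported lazily so uniform_weights works without it

    M = jnp.asarray(M)
    a = uniform_weights(M.shape[0])  # source weights, one per row of M
    b = uniform_weights(M.shape[1])  # target weights, one per column of M
    return ot.emd2(a, b, M)
```

Passing `a` and `b` explicitly means `ot.emd2()` never has to build the default weights itself, which is the step the report identifies as failing under the JAX backend.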
Environment (please complete the following information):
Installed via (`pip`, `conda`): pip
Output of the following code snippet: