
ot.emd2() does not work as expected with empty weights if the JAX backend is used #534

Closed
Francis-Hsu opened this issue Oct 15, 2023 · 4 comments


Francis-Hsu commented Oct 15, 2023

Describe the bug

Per the documentation of ot.emd2(), uniform weights are used if empty lists are passed as the arguments. However, doing so with the JAX backend causes a broadcasting error.

To Reproduce

Simulate some data first:

import jax
from jax import numpy as jnp


key = jax.random.PRNGKey(1)
x = jax.random.normal(key, (100, 2))
y = jax.random.normal(key, (100, 2))

With numpy backend, the following works without an issue:

import numpy as np
import ot
from opt_einsum import contract

M = contract('mi,ni->mn', x, y, backend='numpy') ** 2.
emt = np.empty((0,))
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis

However, errors occur once we switch to jnp:

M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt = jnp.empty((0))
Wass_dis = ot.emd2(emt, emt, M=M)
Wass_dis

Partial error message:

File [c:\ProgramData\anaconda3\Lib\site-packages\ot\lp\__init__.py:567](file:///C:/ProgramData/anaconda3/Lib/site-packages/ot/lp/__init__.py:567), in emd2.<locals>.f(b)
    559     warnings.warn(
    560         "Input histogram consists of integer. The transport plan will be "
    561         "casted accordingly, possibly resulting in a loss of precision. "
   (...)
    564         stacklevel=2
    565     )
    566 G = nx.from_numpy(G, type_as=type_as)
--> 567 cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
    568                         (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
    569                                        nx.from_numpy(v - np.mean(v), type_as=type_as), G))
    571 check_result(result_code)
    572 return cost

File [c:\ProgramData\anaconda3\Lib\site-packages\ot\backend.py:1392](file:///C:/ProgramData/anaconda3/Lib/site-packages/ot/backend.py:1392), in JaxBackend.set_gradients(self, val, inputs, grads)
   1389 ravelled_inputs, _ = ravel_pytree(inputs)
   1390 ravelled_grads, _ = ravel_pytree(grads)
-> 1392 aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
   1393 aux = aux - jax.lax.stop_gradient(aux)
   1395 val, = jax.tree_map(lambda z: z + aux, (val,))

File [c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\array_methods.py:256](file:///C:/ProgramData/anaconda3/Lib/site-packages/jax/_src/numpy/array_methods.py:256), in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
    254 args = (other, self) if swap else (self, other)
    255 if isinstance(other, _accepted_binop_types):
--> 256   return binary_op(*args)
    257 # Note: don't use isinstance here, because we don't want to raise for
    258 # subclasses, e.g. NamedTuple objects that may override operators.
    259 if type(other) in _rejected_binop_types:

    [... skipping hidden 12 frame]

File [c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\numpy\ufuncs.py:97](file:///C:/ProgramData/anaconda3/Lib/site-packages/jax/_src/numpy/ufuncs.py:97), in _maybe_bool_binop.<locals>.fn(x1, x2)
     95 def fn(x1, x2, /):
     96   x1, x2 = promote_args(numpy_fn.__name__, x1, x2)
---> 97   return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)

    [... skipping hidden 7 frame]

File [c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\lax\lax.py:1591](file:///C:/ProgramData/anaconda3/Lib/site-packages/jax/_src/lax/lax.py:1591), in broadcasting_shape_rule(name, *avals)
   1589       result_shape.append(non_1s[0])
   1590     else:
-> 1591       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1592                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1594 return tuple(result_shape)

TypeError: mul got incompatible shapes for broadcasting: (10000,), (10200,).
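The shape mismatch is consistent with how JaxBackend.set_gradients ravels its operands: the inputs (a0, b0, M0) and the gradients (u, v, G) are each flattened and concatenated, then multiplied elementwise. When a0 and b0 are empty, the ravelled inputs contain only the 100 × 100 = 10000 entries of M, while the gradients always contain 100 + 100 + 10000 = 10200 entries. A numpy-only sketch of this size accounting (the variable names mirror the traceback; the arrays here are placeholders, not real POT outputs):

```python
import numpy as np

n = 100  # number of source/target samples in the reproduction

# Inputs as passed to emd2 with empty weights: two empty histograms plus M.
a0 = np.empty((0,))
b0 = np.empty((0,))
M0 = np.zeros((n, n))

# Gradients returned by the solver are always full-sized:
# one entry per source weight, per target weight, and per plan entry.
u = np.zeros(n)
v = np.zeros(n)
G = np.zeros((n, n))

ravelled_inputs = np.concatenate([a0.ravel(), b0.ravel(), M0.ravel()])
ravelled_grads = np.concatenate([u.ravel(), v.ravel(), G.ravel()])

print(ravelled_inputs.size)  # 10000
print(ravelled_grads.size)   # 10200
# Multiplying these elementwise, as set_gradients does, cannot broadcast:
# (10000,) vs (10200,), matching the TypeError above.
```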

Possible solution:

This problem can be avoided by generating the uniform weights ourselves:

M = contract('mi,ni->mn', x, y, backend='jax') ** 2.
emt0 = jnp.ones((M.shape[0],)) / M.shape[0]
emt1 = jnp.ones((M.shape[1],)) / M.shape[1]
Wass_dis = ot.emd2(emt0, emt1, M=M)
Wass_dis # correct result
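For reuse, the workaround above can be wrapped in a small helper (the function name is hypothetical and not part of POT; with the JAX backend, jnp.ones works the same way):

```python
import numpy as np

def uniform_weights(n):
    """Explicit uniform histogram of length n, summing to 1."""
    return np.ones((n,)) / n

# Usage with an arbitrary cost matrix M:
M = np.random.rand(100, 100)
a = uniform_weights(M.shape[0])
b = uniform_weights(M.shape[1])
# Wass_dis = ot.emd2(a, b, M=M)  # explicit weights sidestep the bug
```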

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Windows
  • Python version: 3.11.4
  • How was POT installed (source, pip, conda): pip

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Windows-10-10.0.22621-SP0
Python 3.11.4 | packaged by Anaconda, Inc. | (main, Jul  5 2023, 13:38:37) [MSC v.1916 64 bit (AMD64)]
NumPy 1.24.3
SciPy 1.10.1
POT 0.9.1
@rflamary
Collaborator

Hello @Francis-Hsu, and thanks for the feedback. Could you do a quick check and see if there is a bug when you provide an actual empty Python list (a=[]) instead of empty jax arrays?

Unless I'm mistaken, the documentation says "empty list", and the function should handle this well for any backend.

@rflamary
Collaborator

Also note that with the new API, weights are optional, so there is no need for empty lists:

Wass_dis = ot.solve(M).value

@Francis-Hsu
Author

> Hello @Francis-Hsu, and thanks for the feedback. Could you do a quick check and see if there is a bug when you provide an actual empty Python list (a=[]) instead of empty jax arrays?
>
> Unless I'm mistaken, the documentation says "empty list", and the function should handle this well for any backend.

Hi @rflamary. Thank you for the feedback. If I use ot.emd2([], [], M=M), I get a type-checking error:

ValueError: All array should be from the same type/backend. Current types are : [<class 'jaxlib.xla_extension.ArrayImpl'>, <class 'numpy.ndarray'>, <class 'numpy.ndarray'>]

But indeed the ot.solve(M) interface is much more convenient. I didn't know about it until now :P

@rflamary
Collaborator

rflamary commented Mar 4, 2024

This one should be fixed in #606

@rflamary rflamary closed this as completed Mar 4, 2024