
Jax backend: jax.errors.TracerArrayConversionError #625

Open
act65 opened this issue May 21, 2024 · 3 comments


act65 commented May 21, 2024

Describe the bug

As far as I understand it, I should be able to use this library with jax via your backend switching (which dispatches on input types)? However, I am getting a jax.errors.TracerArrayConversionError, which seems to arise because POT converts to numpy (not jax.numpy) in the backend, despite my passing only jax.numpy inputs.

To Reproduce

import jax.numpy as jnp
from jax import random, grad
import ot as pot

key = random.PRNGKey(0)
B = 10

# two batches of 1-d samples
key, subkey = random.split(key)
x = random.normal(subkey, (B, 1))
key, subkey = random.split(key)
y = random.normal(subkey, (B, 1))


def loss_fn(x, y):
    # pairwise squared-Euclidean cost matrix
    costs = jnp.linalg.norm(x[:, None] - y[None, :], axis=-1)**2

    # exact OT plan between uniform marginals
    pi = pot.emd(
        jnp.ones(B) / B,
        jnp.ones(B) / B,
        costs)

    return jnp.sum(pi * costs)

g = grad(loss_fn)(x, y)
print(g)

(Note: the problem isn't specific to grad; it also applies to vmap, jit, ...)

Traceback (most recent call last):
  File "/home/telfaralex/Documents/phdv2/code/sinterp/sinterp/tests/test_couplings.py", line 58, in test_grad
    g = grad(loss_fn)(x, y)
        ^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/Documents/phdv2/code/sinterp/sinterp/tests/test_couplings.py", line 51, in loss_fn
    pi = ot_fn(
         ^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/lp/__init__.py", line 318, in emd
    M, a, b = nx.to_numpy(M, a, b)
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 260, in to_numpy
    return [self._to_numpy(array) for array in arrays]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 260, in <listcomp>
    return [self._to_numpy(array) for array in arrays]
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/telfaralex/miniconda3/lib/python3.11/site-packages/ot/backend.py", line 1439, in _to_numpy
    return np.array(a)
           ^^^^^^^^^^^
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10,10]
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

Environment:

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.11.4
  • How was POT installed (source, pip, conda): pip (v0.9.3)
@rflamary
Collaborator

You should use ot.emd2, which returns the OT loss (no need to sum), a value with proper grads. The ot.emd function returns an OT plan that is indeed detached from the input, since the exact OT plan is not differentiable. Could you please try that and tell us if you have the same error?
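For reference, a minimal sketch of the suggested change (reusing loss_fn from the reproduction above; ot.emd2 takes the same (a, b, M) arguments and returns the transport cost as a scalar):

def loss_fn(x, y):
    costs = jnp.linalg.norm(x[:, None] - y[None, :], axis=-1)**2
    # emd2 returns sum(pi * costs) directly, with gradients w.r.t. costs
    return pot.emd2(jnp.ones(B) / B, jnp.ones(B) / B, costs)

g = grad(loss_fn)(x, y)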


act65 commented May 23, 2024

I tried with ot.emd2. Same issue.
Like I said, the issue isn't specific to grad: it also applies to jit, vmap, ... any of jax's code-transformation functions.
In the ot backend, numpy is being used where it should be jax.numpy!?

@rflamary
Collaborator

jit and vmap will NOT work for the exact OT solver: it uses specific C++ solvers, and the backend format does not allow us to handle that properly with jax. Grad should work and should be tested. We will look into that after some pressing deadlines. If you have some ideas, please help us: while we provide a jax backend, we are mainly pytorch users, not jax experts.
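One possible workaround, sketched below (not a POT API; emd2_jax and its helpers are hypothetical names): wrap the numpy solver in jax.pure_callback so it can run under jit, and attach a custom VJP so grad works too. It relies on the envelope-theorem fact that the gradient of the exact OT cost with respect to the cost matrix is the optimal plan itself, and it returns zero gradients for the marginals for simplicity.

import numpy as np
import jax
import jax.numpy as jnp
import ot as pot

def _solve_emd(a, b, M):
    # runs outside of jax tracing, on concrete numpy arrays
    return pot.emd(np.asarray(a), np.asarray(b), np.asarray(M)).astype(M.dtype)

@jax.custom_vjp
def emd2_jax(a, b, M):
    # hypothetical wrapper, not part of POT
    pi = jax.pure_callback(
        _solve_emd, jax.ShapeDtypeStruct(M.shape, M.dtype), a, b, M)
    return jnp.sum(pi * M)

def _emd2_fwd(a, b, M):
    pi = jax.pure_callback(
        _solve_emd, jax.ShapeDtypeStruct(M.shape, M.dtype), a, b, M)
    return jnp.sum(pi * M), pi

def _emd2_bwd(pi, g):
    # envelope theorem: d(cost)/dM = pi; marginal grads omitted (zeros)
    return (jnp.zeros(pi.shape[0], pi.dtype),
            jnp.zeros(pi.shape[1], pi.dtype),
            g * pi)

emd2_jax.defvjp(_emd2_fwd, _emd2_bwd)

With a wrapper like this, grad and jit both run (vmap would additionally require batching the callback), at the cost of a host round-trip per solve.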
