Slow simulation on new JAX versions #571

Open
michaeldeistler opened this issue Jan 28, 2025 · 0 comments
Labels
bug Something isn't working


This only affects CPU: with jax==0.4.32 or newer, simulation is 10x slower and gradient computation is 5x slower than with older versions of JAX 🤯

From the CHANGELOG of JAX:

This release of jaxlib switched to a new version of the CPU backend, which should compile faster
and leverage parallelism better. If you experience any problems due to this change, you can
temporarily enable the old CPU backend by setting the environment variable
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false. If you need to do this, please file a JAX
bug with instructions to reproduce.

Indeed, setting os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' at the beginning of the notebook (before importing JAX) restores the old speed on newer versions of JAX as well.
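For anyone who already has other XLA flags set, here is a minimal sketch of the workaround that appends the flag instead of overwriting the variable. The helper name `add_xla_flag` is made up for illustration; the only thing taken from the changelog is the flag itself.

```python
import os


def add_xla_flag(flag: str) -> str:
    """Append `flag` to XLA_FLAGS unless it is already there; return the new value.

    Hypothetical helper, not part of JAX. It only edits the environment variable.
    """
    flags = os.environ.get("XLA_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(flags)
    return os.environ["XLA_FLAGS"]


# Run this in the first notebook cell, before JAX is imported anywhere.
add_xla_flag("--xla_cpu_use_thunk_runtime=false")

# Only after this point: `import jax` (and any package that imports JAX itself).
```

If JAX was already imported in the running kernel, I would restart the kernel before trying this, since the workaround above was only tested with the variable set before the import.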

See also this issue that I created on the JAX repo.

As an interim fix, #570 proposes pinning JAX to an older version.
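To check whether a given environment falls into the affected range, something like the sketch below might help. It assumes the cutoff is exactly 0.4.32 as reported above; the names `parse_version` and `uses_new_cpu_backend` are hypothetical, and the check says nothing about which version #570 actually pins to.

```python
import re
from importlib import metadata

# First version reported as slow in this issue (assumption: the cutoff is exact).
AFFECTED_FROM = (0, 4, 32)


def parse_version(version: str) -> tuple:
    """Return up to three leading numeric components, e.g. '0.4.32' -> (0, 4, 32)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def uses_new_cpu_backend(version: str) -> bool:
    """True if `version` is at or above the release that switched the CPU backend."""
    return parse_version(version) >= AFFECTED_FROM


try:
    installed = metadata.version("jax")
except metadata.PackageNotFoundError:
    installed = None  # JAX is not installed in this environment

if installed is not None and uses_new_cpu_backend(installed):
    print(f"jax {installed} is in the range reported as slow on CPU in this issue")
```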
