Only relevant for CPU: I noticed that with jax==0.4.32 or newer, simulation time is 10x slower and gradient computation is 5x slower than with older versions of JAX 🤯
From the CHANGELOG of JAX:
This release of jaxlib switched to a new version of the CPU backend, which should compile faster
and leverage parallelism better. If you experience any problems due to this change, you can
temporarily enable the old CPU backend by setting the environment variable
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false. If you need to do this, please file a JAX
bug with instructions to reproduce.
Indeed, adding os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false' at the beginning of the notebook (before importing JAX) fixes the slowdown for newer versions of JAX as well.
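The one-liner above overwrites whatever is already in XLA_FLAGS. If other XLA flags are set too, a slightly more defensive sketch is to append the flag instead. The helper name `disable_cpu_thunk_runtime` is my own (hypothetical) and not part of JAX; it only edits the environment variable and has to run before `import jax`.

```python
import os
from typing import MutableMapping, Optional

THUNK_FLAG = "--xla_cpu_use_thunk_runtime=false"


def disable_cpu_thunk_runtime(env: Optional[MutableMapping[str, str]] = None) -> str:
    """Add the flag to XLA_FLAGS while keeping any other XLA flags already set.

    Hypothetical helper: it only edits the environment variable, so it must
    run before JAX is imported.
    """
    env = os.environ if env is None else env
    # Drop any earlier setting of the same flag so only one value remains.
    kept = [
        f
        for f in env.get("XLA_FLAGS", "").split()
        if not f.startswith("--xla_cpu_use_thunk_runtime")
    ]
    env["XLA_FLAGS"] = " ".join(kept + [THUNK_FLAG])
    return env["XLA_FLAGS"]


# First cell of the notebook, before `import jax`:
disable_cpu_thunk_runtime()
```

After this cell, `import jax` as usual. If JAX was already imported in the running kernel, restarting the kernel is the safe way to make sure the flag is picked up.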
See also this issue that I created on the JAX repo.
As an interim fix, #570 proposes pinning JAX to an older version.
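An alternative to pinning, sketched under the assumption that you only want to touch XLA_FLAGS on the releases this report concerns (0.4.32 and newer): read the installed JAX version from package metadata with `importlib.metadata` (so JAX itself is not imported too early) and set the flag only when needed. The names `FIRST_AFFECTED`, `release_tuple`, `is_affected` and `installed_jax_version` are hypothetical.

```python
import os
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

# First JAX release reported above as affected on CPU.
FIRST_AFFECTED = (0, 4, 32)


def release_tuple(v: str) -> Tuple[int, ...]:
    """Leading numeric part of a version string, e.g. '0.4.35rc1' -> (0, 4, 35)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    return tuple(int(p) for p in m.group(0).split("."))


def is_affected(jax_version: str) -> bool:
    """True for versions at or after the first release reported as slow."""
    return release_tuple(jax_version) >= FIRST_AFFECTED


def installed_jax_version() -> Optional[str]:
    # Reads package metadata only; JAX itself is not imported here.
    try:
        return version("jax")
    except PackageNotFoundError:
        return None


installed = installed_jax_version()
if installed is not None and is_affected(installed):
    # Appended rather than overwritten so other XLA flags survive.
    os.environ["XLA_FLAGS"] = (
        os.environ.get("XLA_FLAGS", "") + " --xla_cpu_use_thunk_runtime=false"
    ).strip()
```

This keeps newer JAX releases installable while leaving older installs untouched. Like the environment-variable fix itself, it has to run before `import jax`.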