Disabling TF32 in JAX #22701
Dear All,

I would like to ask how to prevent JAX from running TF32 computations on A100 GPUs. I know that I can manually set … Is there a way to globally disable TF32 in JAX? So far, I am using …

Best,
Looking at `jax/_src/config.py`, the following config should do the job:

```python
import jax
jax.config.update('jax_default_matmul_precision', 'highest')  # full float32 matmuls, no TF32
```