
Make the default matmul precision float32 even on TPUs #7010

Open
shoyer opened this issue Jun 17, 2021 · 15 comments
Labels: enhancement (New feature or request)
@shoyer
Member

shoyer commented Jun 17, 2021

Follow-up from #2161:

The default low precision is a bit of a footgun, at least when doing anything that isn't implementing a neural net layer. In my opinion, it would be much safer to use "highest" precision by default (which isn't that much slower) on float32 data. Neural net libraries, of course, can default to lower precision, so this really only affects users who directly use NumPy APIs or the @ infix operator.
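To make the footgun concrete, here is a small pure-Python sketch (not JAX code; bfloat16 rounding is emulated by truncating float32 bit patterns, and real hardware rounds rather than truncates) showing how far a dot product over bfloat16-rounded inputs can drift from the full-precision result:

```python
import struct

def to_bfloat16(x):
    # Emulate bfloat16 by keeping only the top 16 bits of the
    # float32 representation (truncation, for illustration).
    b = struct.pack('>f', x)
    return struct.unpack('>f', b[:2] + b'\x00\x00')[0]

# Dot product with bfloat16-rounded inputs vs. plain float arithmetic.
xs = [1.0 + i * 1e-3 for i in range(256)]
ys = [1.0 - i * 1e-3 for i in range(256)]

exact = sum(a * b for a, b in zip(xs, ys))
low = sum(to_bfloat16(a) * to_bfloat16(b) for a, b in zip(xs, ys))
rel_err = abs(exact - low) / abs(exact)
print(rel_err)  # several decimal digits lost just from rounding the inputs
```

The inputs here differ only in their low-order digits, which is exactly the regime (small perturbations around 1.0) where bfloat16's 7 mantissa bits discard most of the information.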

@shoyer shoyer added the enhancement New feature or request label Jun 17, 2021
@j-towns
Contributor

j-towns commented Jun 21, 2021

I would argue this is also a footgun for some neural net use cases. Both times that I've carefully re-implemented a model in JAX, I've found that performance is worse than expected with the default precision, and it took me some time to realise that precision was the cause.

As a (temporary) improvement on the current situation, we could at least add some information on this issue to point 5 of the ‘Current Gotchas’ in the readme.

@shoyer
Member Author

shoyer commented Jun 22, 2021

#6143 added a config flag, so this should in principle be easier to change now.
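A sketch of what that flag enables, assuming the flag name `jax_default_matmul_precision` from #6143 and the `jax.lax.Precision` enum (on CPU and GPU these settings are no-ops):

```python
import jax
import jax.numpy as jnp

# Process-wide default via the config flag added in #6143:
jax.config.update('jax_default_matmul_precision', 'highest')

# Per-operation override, which takes precedence over the flag:
x = jnp.ones((8, 8), dtype=jnp.float32)
y = jnp.ones((8, 8), dtype=jnp.float32)
z = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)
print(z.shape)  # prints (8, 8)
```

Because the flag is read at trace time, flipping the default no longer requires touching every call site, which is what makes changing the TPU default feasible now.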

@j-towns
Contributor

j-towns commented Jul 16, 2021

As @juliuskunze just pointed out to me, this is the only difference in semantics (that we can think of) between the GPU, CPU and TPU backends. 'Backend transparency' is a valuable property in Julius's and my opinion, and given how close JAX is to achieving it (if this really is the only case where semantics differ significantly), it's surely worth changing the TPU default for all matmul-like ops (including conv) to closely approximate GPU and CPU behaviour. I understand this will hurt 'default' speed, but I think we can mitigate that by making the availability of the bfloat16 option clear.

@jonbarron
Contributor

My team shot itself in the foot last week for the ~fourth time due to matmul defaulting to bfloat16. This issue continues to be my biggest/only grievance with JAX.

@mattjj
Member

mattjj commented Sep 8, 2021

Just got another +1 for this issue. Another foot lost.

copybara-service bot pushed a commit that referenced this issue Sep 8, 2021
On CPU and GPU, this change has no effect.

On TPU, this PR changes the default matmul algorithm from a fast, low-quality algorithm to a slower, high-precision algorithm that uses multiple passes. Many users have reported the low-quality-by-default behavior to be a footgun, especially when performing non-neural network computations.

The old behavior can be restored either by passing an explicit Precision option to operators such as `dot`, or by changing the default precision, e.g. `jax.config.update('jax_default_matmul_precision', 'fastest')`.

#7010

PiperOrigin-RevId: 395549544
@patrickvonplaten

+1 for this issue from the Transformers team for one of the most popular architectures for speech recognition (Wav2Vec2) - see: huggingface/transformers#15754

@inoryy
Contributor

inoryy commented May 12, 2022

+1 to setting default to f32.

Having had the pleasure of shooting both my feet, with both performance gotchas and numerics gotchas, I'd wholeheartedly prefer debugging the performance ones, as they're much easier to spot.

@ppwwyyxx
Contributor

ppwwyyxx commented Jun 29, 2022

Similar discussion in PyTorch: pytorch/pytorch#67384. PyTorch once enabled TF32 by default for a few ops, and then had to revert the decision due to similar complaints. Enabling bfloat16 by default is presumably even worse.
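The "even worse" claim is easy to quantify: TF32 keeps 10 explicit mantissa bits while bfloat16 keeps only 7. A pure-Python sketch, emulating both formats by truncating float32 bit patterns (real hardware rounds rather than truncates):

```python
import struct

def truncate_mantissa(x, keep_bits):
    # Zero out the low (23 - keep_bits) mantissa bits of a float32 value.
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    mask = 0xFFFFFFFF & ~((1 << (23 - keep_bits)) - 1)
    (y,) = struct.unpack('>f', struct.pack('>I', bits & mask))
    return y

x = 1.0 / 3.0
tf32 = truncate_mantissa(x, 10)  # TF32: 10 explicit mantissa bits
bf16 = truncate_mantissa(x, 7)   # bfloat16: 7 explicit mantissa bits
print(abs(x - tf32), abs(x - bf16))  # the bfloat16 error is larger
```

Three fewer mantissa bits means roughly 8x the rounding error per input, before any accumulation effects in the matmul itself.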

@jakeh-gc
Contributor

jakeh-gc commented Nov 15, 2022

If anyone here is collecting feet 🦶🔫, I lost one to something related to this too. #12008 (comment)

@nouiz
Collaborator

nouiz commented Nov 15, 2022

What about printing a warning/info message once per process when the low precision is used by default? Printing too much by default, as TF does, isn't great, but I think this one is worth it. We should also allow users to suppress that warning.

@DavidNorman

It does seem a poor user-experience choice for the system to do the wrong thing by default, forcing developers to debug its unexpected behaviour, rather than doing the right thing by default and giving developers the opportunity to feel good about optimising performance by selecting lower-precision maths, a choice they will find easy to undo should the system perform badly due to the limited precision.

@ayaka14732
Member

It does seem a poor user-experience choice for the system to do the wrong thing by default, forcing developers to debug its unexpected behaviour, rather than doing the right thing by default and giving developers the opportunity to feel good about optimising performance by selecting lower-precision maths, a choice they will find easy to undo should the system perform badly due to the limited precision.

+1 for this. My benchmarks show that models run in low precision do not always match the quality of their high-precision counterparts, so the correct behaviour (i.e. high precision) should be the default.

@nouiz
Collaborator

nouiz commented Nov 23, 2022

My personal view is that this is more complicated than that. Performance, when we're talking about a >2x speed-up, is a feature; a smallish speed-up like 1.Nx is less of one. Here, the DL and non-DL communities need different features.

When software targets, or is dominantly used by, one community, you just pick that community's favourite default. But JAX has a fair number of non-DL users who need different features than DL users do, so the choice is more complicated. If we make the switch, other users will complain about the big slowdown.

Raising awareness of this issue would be a good first step that I guess everyone would agree on, e.g. adding it to the FAQ, the Sharp Bits, and the GPU docs. If people know about it, they can adjust more easily. Are you interested in contributing this?

@jaschau

jaschau commented Jan 15, 2023

Solving neural ODEs with diffrax is also affected by the unexpected default choice of TensorFloat32; see patrick-kidger/diffrax#213.

@kvablack

+1, another foot lost (kvablack/ddpo-pytorch#3 (comment)), this time in the form of forcing me to update a quarter of the results for a paper I had already submitted and released. I do think using high precision by default and letting users opt in to better performance is much easier to debug.
