[Bug-fix][XLA:CPU][oneDNN] Fix BINARY_ADD fusion to Dot #13301
Conversation
Thank you for the fix! Could you please explain what is causing the issue, and how this fix addresses it? Is it because of a rank mismatch and incorrect auto-broadcasting, or something else?
ENTRY main {
  constant.2 = f32[] constant(1e-06)
  broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={}
I don't think the size needs to be this big to reproduce the failure. Would 10 work?
The issue is not reproducible with a smaller size.
  subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13)
  constant.4 = f32[] constant(0)
  broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={}
  dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
Can we reduce the test to just the ops necessary to reproduce the failure? I don't think all the dots are needed.
The bug only manifests with this particular case.
I'd like to try to get this fix in within the next day or so, so I can incorporate it in the next JAX release, please.
oneDNN expects a Matmul followed by a Bias-Add followed by a Binary-Add. Here, however, the Matmul is followed by a Binary-Add and then a Bias-Add, a pattern oneDNN does not support. The fix extends the dimensions of the Bias-Add addend so that it becomes a Binary-Add, which is supported.
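A rough NumPy sketch of the idea behind the fix (the function name and shapes here are illustrative, not taken from the actual XLA pass): a rank-1 bias addend is padded with leading size-1 dimensions until its rank matches the dot's output, turning the Bias-Add into a plain elementwise Binary-Add.

```python
import numpy as np

def extend_addend_rank(addend, dot_rank):
    """Prepend size-1 dimensions until the addend's rank matches
    the dot output's rank, so the add becomes an ordinary
    elementwise Binary-Add that the fusion pattern accepts."""
    while addend.ndim < dot_rank:
        addend = addend[np.newaxis, ...]  # e.g. f32[3] -> f32[1,3]
    return addend

dot_out = np.zeros((4, 3), dtype=np.float32)     # stand-in for a dot result
bias = np.full((3,), 1e-6, dtype=np.float32)     # rank-1 addend
bias2d = extend_addend_rank(bias, dot_out.ndim)  # now rank-2: shape (1, 3)
result = dot_out + bias2d                        # elementwise Binary-Add
```

The values are unchanged; only the addend's rank is adjusted so the add matches the Matmul-then-Binary-Add pattern oneDNN supports.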
Thank you very much for the clarifications!
Imported from GitHub PR openxla/xla#13301

This PR fixes a bug reported for JAX (openxla/xla#13054)

Copybara import of the project:

-- 47d5bde8eab607d0fe9b60c6fd82d95365c8169f by mdfaijul <[email protected]>:

Make addend rank same to dot.

Merging this change closes #13301

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#13301 from Intel-tensorflow:amin/bug-fix-jax 47d5bde8eab607d0fe9b60c6fd82d95365c8169f

PiperOrigin-RevId: 640081553