-
Notifications
You must be signed in to change notification settings - Fork 11
Added support for scalars in tt-xla #95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@AleksKnezevic do you have the time to take a look, interested in your opinion. Thanks! |
src/common/api_impl.cc
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI You are gonna hit a conflict after rebase since submit
API has been updated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the heads up! Rebased successfully in order to check tests now.
763ad91
to
65ece7a
Compare
65ece7a
to
91045ee
Compare
91045ee
to
5a95bf7
Compare
Due to the fact that scalars are not supported in our TTIR dialect and instead promoted to tensors during StableHLO->TTIR conversion, the PJRT runtime did not now that return values of jax functions were scalars and not 1x1 tensors. This led to the behaviour of the python calls differing between CPU and Silicon calls, with CPU calls returning scalars as expected, and Silicon jax calls returning 1x1 tensors.
This necessitated a workaround reshaping in the test infra, and additionally did not match with that the end user of the jax framework expected. This PR fixes that, by adding tracking of scalars on the PJRT runtime side.
Fixes issue #81