[TIR] Unify index data type when creating prim func #13327
vinx13
commented
Nov 9, 2022
- Added a data type unification pass that, by default, promotes the data types of all indices and shapes to int64 when creating a prim func.
- Added some fixes to the lowering passes to make them compatible with the int64 data type.
Happy to review and let's fix the CI :-)
junrushao
left a comment
LGTM!
…pache#13327)"" This reverts commit d2dbcb1.
Hello, @vinx13! I have one question about this PR: could you give more information or motivation about why we need to convert indices to the "int64" data type?
A few words about why I am asking. Example: Average Pooling 2D. As you can see, we get an extra cast("int64"). Unfortunately, Hexagon does not support vectorization of "int64" data types, and performance becomes very poor.
P.S. Just as an experiment, I reverted the conversion of indices to int64 and got a performance gain of +40% (!!!). So I would like to fix this somehow, but first I would like to know the motivation for this PR.
Thank you in advance! Just FYI, cc @masahi
Sometimes a model contains mixed index types (e.g. both int32 and int64), which causes a dtype mismatch error during scheduling. This pass is not expected to hurt performance, because another pass, NarrowDataType, should convert the indices back to int32 (or another smaller type) where possible. In your case, the result of the conversion seems more complicated than expected: it contains some unnecessary casts. I'd expect something that looks the same as before, except that variables and constant ints have int64 dtypes.
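The promote-then-narrow idea described above can be sketched in plain Python. This is only a toy illustration under stated assumptions: `unify_index_dtype` and `narrow_index_dtype` are hypothetical helpers, not TVM APIs, and the real NarrowDataType pass works on TIR expressions rather than on plain integer bounds.

```python
# Toy model of "promote indices to int64, then narrow back when safe".
# Both helpers are hypothetical and exist only for illustration;
# they are not part of TVM.


def unify_index_dtype(dtypes):
    """Promote a mix of integer index dtypes to the widest one present."""
    # "int32" -> 32, "int64" -> 64
    widest_bits = max(int(d[len("int"):]) for d in dtypes)
    return "int" + str(widest_bits)


def narrow_index_dtype(lo, hi, target_bits=32):
    """Pick the target dtype if every value in [lo, hi] fits, else int64."""
    limit = 2 ** (target_bits - 1)
    fits = -limit <= lo and hi <= limit - 1
    return "int" + str(target_bits) if fits else "int64"


# Mixed int32/int64 indices are unified to a single dtype first.
unified = unify_index_dtype(["int32", "int64"])

# An index over a 224 x 224 x 64 buffer stays far below 2**31,
# so it can be narrowed back to int32.
small = narrow_index_dtype(0, 224 * 224 * 64 - 1)

# An index that can reach 2**40 has to stay int64.
large = narrow_index_dtype(0, 2**40)
```

In this sketch the narrowing decision depends only on the known value range of the index. If that range cannot be proven small enough (or extra casts obscure it), the index stays int64, which is the situation the Hexagon report above runs into.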