Skip to content

Conversation

@vinx13
Copy link
Member

@vinx13 vinx13 commented Nov 9, 2022

  • Added data type pass unification pass to by default promote data types of all indices and shapes to int64 when creating prim func.
  • Added some fixes for lowering passes to make it compatible with int64 data type.

@tvm-bot
Copy link
Collaborator

tvm-bot commented Nov 9, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@vinx13 vinx13 force-pushed the dtype-legalize-2-up branch 4 times, most recently from 24d0375 to a8d56f3 Compare November 11, 2022 01:52
@vinx13 vinx13 force-pushed the dtype-legalize-2-up branch from 3879996 to 9c35fe5 Compare November 14, 2022 21:21
@vinx13 vinx13 marked this pull request as ready for review November 15, 2022 00:43
@vinx13 vinx13 force-pushed the dtype-legalize-2-up branch from a09058b to 5ac8b38 Compare November 15, 2022 18:22
@junrushao
Copy link
Member

Happy to review and let's fix the CI :-)

@junrushao junrushao self-assigned this Nov 15, 2022
@vinx13 vinx13 force-pushed the dtype-legalize-2-up branch from 973fe51 to 7dd2917 Compare November 16, 2022 19:32
Copy link
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@junrushao junrushao merged commit ad5c811 into apache:main Nov 17, 2022
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
* Added data type pass unification pass to by default promote data types of all indices and shapes to int64 when creating prim func.
* Added some fixes for lowering passes to make it compatible with int64 data type.
masahi added a commit to masahi/tvm that referenced this pull request Nov 25, 2022
masahi added a commit to masahi/tvm that referenced this pull request Nov 25, 2022
masahi added a commit to masahi/tvm that referenced this pull request Nov 29, 2022
masahi added a commit to masahi/tvm that referenced this pull request Nov 29, 2022
masahi added a commit to masahi/tvm that referenced this pull request Dec 1, 2022
masahi added a commit to masahi/tvm that referenced this pull request Dec 1, 2022
masahi added a commit to masahi/tvm that referenced this pull request Dec 2, 2022
masahi added a commit to masahi/tvm that referenced this pull request Dec 2, 2022
@ibsidorenko
Copy link
Contributor

Hello, @vinx13 !

I have one question about this PR... Is it possible to give more information or motivation about why do we need to convert indexes into "int64" data type?

A few words about why I am asking:
I am working on MetaScheduler for Hexagon target. And found that this PR dramatically reduce performance for some operations.

Example: Average Pooling 2D
For this operator we use indexes in its compute function and pool2d divisor.
Before IndexDataTypeNormalizer:
pool_avg[ax0, ax1, ax2, ax3, ax4] = (pool_sum[ax0, ax1, ax2, ax3, ax4] / max((((min(1, (34 - ax2)) + 2) - max((1 - ax2), 0))*((min(1, (34 - ax3)) + 2) - max((1 - ax3), 0))), 1))
After IndexDataTypeNormalizer:
pool_avg[ax0, ax1, ax2, ax3, ax4] = cast(int32, (cast(int64, pool_sum[ax0, ax1, ax2, ax3, ax4]) / max((((min(1i64, (34i64 - ax2)) + 2i64) - max((1i64 - ax2), 0i64))*((min(1i64, (34i64 - ax3)) + 2i64) - max((1i64 - ax3), 0i64))), 1i64)))

As you can see we get extra cast("int64"). Unfortunately, Hexagon does not support vectorization of "int64" data types and performance became very very poor.

P.S. Just for experiment I have reverted conversion of indexes into int64 and get performance gain +40% (!!!).

So, I would like to fix it somehow but I would like to know motivation for this PR.

Thank you in advance!

Just FYI cc @masahi

@vinx13
Copy link
Member Author

vinx13 commented Mar 7, 2023

Sometimes the model contains mixed indices type (e.g both int32 and int64). It causes dtype mismatch error during scheduling. It is expected that this pass doesn't hurt performance since there is another pass NarrowDataType that should convert it back to int32 (or other smaller types) if possible.

In your case, the result of conversion seems more complicated than expected, it contains some unnecessary cast. I'd expect something like


pool_avg[ax0, ax1, ax2, ax3, ax4] = (pool_sum[ax0, ax1, ax2, ax3, ax4] / max((((min(1, (34 - ax2)) + 2) - max((1 - ax2), 0))*((min(1, (34 - ax3)) + 2) - max((1 - ax3), 0))), 1))

that looks the same as before, except that variables and constant ints have int64 dtypes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants