-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[TIR][Arith] Implemented padded inverses in IndexMap #11235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This allows it to be used for any expression containing an `IterMapExpr`, not just expressions whose top-level node is an `IterMapExpr`.
The existing DetectIterMap tries to rewrite index expression as a linear combination of split/fused iterators, where the new iterators cover the exact same indices as the original expression. DetectPaddedIterMap relaxes this condition, allowing the new iterators to cover a superset of indices that the initial index expression covered. It uses the minimum amount of padding necessary to represent these transformations, and also a predicate that identifies any padding that has been added. This is a utility function to be used for layout transformations of buffers, in cases where the pre-transformation shape of the buffer does not evenly fit into the post-transformation shape.
Allow non-surjective transformations, with DetectIterMap used to determine the minimum padding to insert. Returns the inverse function, along with a predicate that identifies padding indices. The predicate is in terms of the transformed variables.
- `IndexMap::Inverse` exposed as `IndexMap.inverse` - `IndexMap::MapShape` exposed as `IndexMap.map_shape` - `IndexMap::NonSurjectiveInverse` exposed as `IndexMap.non_surjective_inverse`
In preparation for adding additional tests for the IndexMap class, which will require this functionality.
Initially disabled as dynamic shapes resulted in padded lengths whose divisiblity couldn't be proven. Re-enabled along with a simplification rule to resolve it.
|
Also, name suggestions for "non-surjective inverse" would be appreciated. It's something of a mouthful, and describes what it is doing without describing why it would be desired. |
|
Thanks for the PR, I like this change. There is a failing test case |
Thank you, and I'm liking it as well. This sort of utility should be useful regardless on the buffer padding semantics used.
Yeah, so far I've narrowed it down to the handling of cases like |
Hi~ Does it mean pass index like |
Prior to this PR, that is correct. The check on the main branch is here, which requires the offset in a numerator to be zero. After this PR, an index like |
* [Debug] Error logging in DetectIterMap * [Affine] Allowed PrimExpr argument to NormalizeIterMapToExpr This allows it to be used for any expression containing an `IterMapExpr`, not just expressions whose top-level node is an `IterMapExpr`. * [Affine] Implemented DetectPaddedIterMap The existing DetectIterMap tries to rewrite index expression as a linear combination of split/fused iterators, where the new iterators cover the exact same indices as the original expression. DetectPaddedIterMap relaxes this condition, allowing the new iterators to cover a superset of indices that the initial index expression covered. It uses the minimum amount of padding necessary to represent these transformations, and also a predicate that identifies any padding that has been added. This is a utility function to be used for layout transformations of buffers, in cases where the pre-transformation shape of the buffer does not evenly fit into the post-transformation shape. * [IndexMap] Implemented IndexMap::NonSurjectiveInverse Allow non-surjective transformations, with DetectIterMap used to determine the minimum padding to insert. Returns the inverse function, along with a predicate that identifies padding indices. The predicate is in terms of the transformed variables. * [IndexMap] Exposed methods to python - `IndexMap::Inverse` exposed as `IndexMap.inverse` - `IndexMap::MapShape` exposed as `IndexMap.map_shape` - `IndexMap::NonSurjectiveInverse` exposed as `IndexMap.non_surjective_inverse` * [IndexMap] Extracted _assert_equal_index_map into class method In preparation for adding additional tests for the IndexMap class, which will require this functionality. * [IndexMap] Added unit tests for new behavior * Re-enabled divisibility check in CheckMapping Initially disabled as dynamic shapes resulted in padded lengths whose divisiblity couldn't be proven. Re-enabled along with a simplification rule to resolve it. * Fixed breakage in compute_at primitive * Corrected typos/examples in docstring
* [Debug] Error logging in DetectIterMap * [Affine] Allowed PrimExpr argument to NormalizeIterMapToExpr This allows it to be used for any expression containing an `IterMapExpr`, not just expressions whose top-level node is an `IterMapExpr`. * [Affine] Implemented DetectPaddedIterMap The existing DetectIterMap tries to rewrite index expression as a linear combination of split/fused iterators, where the new iterators cover the exact same indices as the original expression. DetectPaddedIterMap relaxes this condition, allowing the new iterators to cover a superset of indices that the initial index expression covered. It uses the minimum amount of padding necessary to represent these transformations, and also a predicate that identifies any padding that has been added. This is a utility function to be used for layout transformations of buffers, in cases where the pre-transformation shape of the buffer does not evenly fit into the post-transformation shape. * [IndexMap] Implemented IndexMap::NonSurjectiveInverse Allow non-surjective transformations, with DetectIterMap used to determine the minimum padding to insert. Returns the inverse function, along with a predicate that identifies padding indices. The predicate is in terms of the transformed variables. * [IndexMap] Exposed methods to python - `IndexMap::Inverse` exposed as `IndexMap.inverse` - `IndexMap::MapShape` exposed as `IndexMap.map_shape` - `IndexMap::NonSurjectiveInverse` exposed as `IndexMap.non_surjective_inverse` * [IndexMap] Extracted _assert_equal_index_map into class method In preparation for adding additional tests for the IndexMap class, which will require this functionality. * [IndexMap] Added unit tests for new behavior * Re-enabled divisibility check in CheckMapping Initially disabled as dynamic shapes resulted in padded lengths whose divisiblity couldn't be proven. Re-enabled along with a simplification rule to resolve it. * Fixed breakage in compute_at primitive * Corrected typos/examples in docstring
* [Debug] Error logging in DetectIterMap * [Affine] Allowed PrimExpr argument to NormalizeIterMapToExpr This allows it to be used for any expression containing an `IterMapExpr`, not just expressions whose top-level node is an `IterMapExpr`. * [Affine] Implemented DetectPaddedIterMap The existing DetectIterMap tries to rewrite index expression as a linear combination of split/fused iterators, where the new iterators cover the exact same indices as the original expression. DetectPaddedIterMap relaxes this condition, allowing the new iterators to cover a superset of indices that the initial index expression covered. It uses the minimum amount of padding necessary to represent these transformations, and also a predicate that identifies any padding that has been added. This is a utility function to be used for layout transformations of buffers, in cases where the pre-transformation shape of the buffer does not evenly fit into the post-transformation shape. * [IndexMap] Implemented IndexMap::NonSurjectiveInverse Allow non-surjective transformations, with DetectIterMap used to determine the minimum padding to insert. Returns the inverse function, along with a predicate that identifies padding indices. The predicate is in terms of the transformed variables. * [IndexMap] Exposed methods to python - `IndexMap::Inverse` exposed as `IndexMap.inverse` - `IndexMap::MapShape` exposed as `IndexMap.map_shape` - `IndexMap::NonSurjectiveInverse` exposed as `IndexMap.non_surjective_inverse` * [IndexMap] Extracted _assert_equal_index_map into class method In preparation for adding additional tests for the IndexMap class, which will require this functionality. * [IndexMap] Added unit tests for new behavior * Re-enabled divisibility check in CheckMapping Initially disabled as dynamic shapes resulted in padded lengths whose divisiblity couldn't be proven. Re-enabled along with a simplification rule to resolve it. * Fixed breakage in compute_at primitive * Corrected typos/examples in docstring
Follow-up from apache#11235, all error messages should be based on expressions that are not IterMapExpr.
Follow-up from #11235, all error messages should be based on expressions that are not IterMapExpr.
The goal of this PR is to allow
IndexMapto express transformations that introduce padding in the transformed shape. For an arbitrary input shape, the new methodIndexMap.non_surjective_inversedetermines the inverse transformation, along with a predicate specifying which coordinates in the transformed index space do not contain an inverse in the original index space. The previous behavior ofIndexMap.inverse, requiring transformations to be bijective over the range given, is maintained.This functionality will be used in the future to allow buffer transformations (see #9727 and #10538) to introduce padding to the buffer.