[inductor][TMA] Use revert_sort_idx rather than sort_idx #172009
kundaMwiza wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172009
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 unrelated failure) As of commit fd8c60d with merge base a7ac482. FLAKY: the following job failed, but was likely due to flakiness present on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
 def revert(self, attr):
     if not self.is_identity:
-        return [attr[i] for i in self.sort_idx]
+        return [attr[i] for i in self.revert_sort_idx]
```
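A minimal standalone sketch of the sort/revert pattern the diff fixes. The names `sort_idx` and `revert_sort_idx` follow the PR; the `Sorter` class itself is illustrative, not the actual Inductor code. The key point is that undoing a gather by `sort_idx` requires indexing with its inverse permutation, not with `sort_idx` again.

```python
class Sorter:
    """Hypothetical sketch: reorder attributes by sorted key, then undo it."""

    def __init__(self, keys):
        # sort_idx[i] is the original position of the i-th smallest key.
        self.sort_idx = sorted(range(len(keys)), key=lambda i: keys[i])
        # revert_sort_idx is the inverse permutation of sort_idx:
        # element at original position old_pos ended up at new_pos.
        self.revert_sort_idx = [0] * len(keys)
        for new_pos, old_pos in enumerate(self.sort_idx):
            self.revert_sort_idx[old_pos] = new_pos

    def apply(self, attr):
        # Gather attr into sorted-key order.
        return [attr[i] for i in self.sort_idx]

    def revert(self, attr):
        # Undo apply(): must use the inverse permutation, not sort_idx.
        return [attr[i] for i in self.revert_sort_idx]


s = Sorter([3, 1, 2])
sorted_attr = s.apply(["a", "b", "c"])
print(sorted_attr)            # ["b", "c", "a"]
print(s.revert(sorted_attr))  # ["a", "b", "c"], the original order

# Indexing by sort_idx instead of revert_sort_idx does NOT round-trip here:
wrong = [sorted_attr[i] for i in s.sort_idx]
print(wrong)                  # ["c", "a", "b"] != ["a", "b", "c"]
```

Note that for some permutations (ones that are their own inverse) the wrong version happens to give the right answer, which is why a test can pass by accident for certain sizes.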
There is a test case for this, but surprisingly it passes because of the block size that is chosen: for certain block sizes, the result is correct regardless of how the revert is done. Tested this locally.
e.g. incorrect:

```python
tl.reshape(
    tl.trans(
        tl.broadcast_to(
            tl.reshape(
                tl.broadcast_to(tmp7, [YBLOCK, XBLOCK]),
                [XBLOCK, ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), 1, (3 + YBLOCK) // 4]
            ),
            [XBLOCK, ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))),
             ((1) * ((1) <= ((3 + YBLOCK) // 4)) + ((3 + YBLOCK) // 4) * (((3 + YBLOCK) // 4) < (1))),
             (3 + YBLOCK) // 4]),
        [2, 0, 3, 1]),
    [((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), (3 + YBLOCK) // 4, XBLOCK]).to(tl.float32), boundary_check=[1, 2])
```
correct:

```python
tl.reshape(
    tl.trans(
        tl.broadcast_to(
            tl.reshape(
                tl.broadcast_to(tmp7, [YBLOCK, XBLOCK]),
                [(3 + YBLOCK) // 4, 1, ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), XBLOCK]
            ),
            [(3 + YBLOCK) // 4,
             ((1) * ((1) <= ((3 + YBLOCK) // 4)) + ((3 + YBLOCK) // 4) * (((3 + YBLOCK) // 4) < (1))),
             ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), XBLOCK]),
        [2, 0, 3, 1]),
    [((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), (3 + YBLOCK) // 4, XBLOCK]).to(tl.float32), boundary_check=[1, 2])
```
For YBLOCK = 1, 2, or 4 and XBLOCK any power of two, the two operations above produce the same result, but once YBLOCK exceeds 4 they differ.
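The mechanism behind the accidental pass can be illustrated in isolation: when a permutation is its own inverse (an involution), gathering by `sort_idx` and by `revert_sort_idx` give identical results, so the bug is invisible. The Triton reshape/trans sizes above are a more involved instance of this; the sketch below (hypothetical helper `inverse`, not Inductor code) only shows the underlying permutation fact.

```python
def inverse(perm):
    """Compute the inverse of a permutation given as a list of indices."""
    inv = [0] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        inv[old_pos] = new_pos
    return inv


# An involution: swapping adjacent pairs equals its own inverse, so
# indexing with perm or inverse(perm) is indistinguishable.
involution = [1, 0, 3, 2]
print(inverse(involution))  # [1, 0, 3, 2] -- same as the input

# A 3-cycle is not an involution, so the two indexings differ and the
# wrong choice of index becomes observable.
three_cycle = [1, 2, 0]
print(inverse(three_cycle))  # [2, 0, 1] -- different from the input
```

This is consistent with the observation that only certain block sizes expose the difference: for the coincidental sizes the effective permutation is self-inverse.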
@pytorchbot label "topic: not user facing"
@njriasan I believe the failing test is unrelated. See Dao-AILab/flash-attention#2153 (assuming this is where the test is actually from). The failing test: https://github.com/pytorch/pytorch/actions/runs/20827118782/job/59909424611
Force-pushed from 7c4846f to fd8c60d
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Reverting the sort should use revert_sort_idx instead.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo