Skip to content

[inductor][TMA] Use revert_sort_idx rather than sort_idx#172009

Closed
kundaMwiza wants to merge 1 commit intopytorch:mainfrom
graphcore:mwizak/fix-stride-sort-revert-v2
Closed

[inductor][TMA] Use revert_sort_idx rather than sort_idx#172009
kundaMwiza wants to merge 1 commit intopytorch:mainfrom
graphcore:mwizak/fix-stride-sort-revert-v2

Conversation

@kundaMwiza
Copy link
Copy Markdown
Collaborator

@kundaMwiza kundaMwiza commented Jan 8, 2026

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Jan 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172009

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit fd8c60d with merge base a7ac482 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kundaMwiza kundaMwiza marked this pull request as draft January 8, 2026 19:20
def revert(self, attr):
if not self.is_identity:
return [attr[i] for i in self.sort_idx]
return [attr[i] for i in self.revert_sort_idx]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a test case for this but surprisingly it passes because of the block size that is chosen i.e. for certain block sizes regardless of how revert is done the result will be correct. Tested this locally

Copy link
Copy Markdown
Collaborator Author

@kundaMwiza kundaMwiza Jan 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. incorrect:

tl.reshape(
    tl.trans(
        tl.broadcast_to(
            tl.reshape(
                tl.broadcast_to(tmp7, [YBLOCK, XBLOCK]),
                [XBLOCK, ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), 1, (3 + YBLOCK) // 4]
            ), 
        [XBLOCK, ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), ((1) * ((1) <= ((3 + YBLOCK) // 4)) + ((3 + YBLOCK) // 4) * (((3 + YBLOCK) // 4) < (1))), 
    (3 + YBLOCK) // 4]), [2, 0, 3, 1]), [((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), (3 + YBLOCK) // 4, XBLOCK]).to(tl.float32), boundary_check=[1, 2]) 

correct:

 tl.reshape(
     tl.trans(
          tl.broadcast_to(
               tl.reshape(
                   tl.broadcast_to(tmp7, [YBLOCK, XBLOCK]), 
                   [(3 + YBLOCK) // 4, 1, ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), XBLOCK]
               ), 
          [(3 + YBLOCK) // 4, ((1) * ((1) <= ((3 + YBLOCK) // 4)) + ((3 + YBLOCK) // 4) * (((3 + YBLOCK) // 4) < (1))
     ), ((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), XBLOCK]), [2, 0, 3, 1]), [((4) * ((4) <= (YBLOCK)) + (YBLOCK) * ((YBLOCK) < (4))), (3 + YBLOCK) // 4, XBLOCK]).to(tl.float32), boundary_check=[1, 2])

For YBLOCK = 1, 2, or 4; XBLOCK = any power of two, the above two operations are the same, but if YBLOCK exceeds 4 then there is a difference

@kundaMwiza
Copy link
Copy Markdown
Collaborator Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jan 8, 2026
@kundaMwiza kundaMwiza marked this pull request as ready for review January 8, 2026 20:32
@bdhirsh bdhirsh requested review from jansel and njriasan January 12, 2026 15:49
@bdhirsh bdhirsh added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jan 12, 2026
@kundaMwiza
Copy link
Copy Markdown
Collaborator Author

kundaMwiza commented Jan 12, 2026

@njriasan I believe the failing test is unrelated. See Dao-AILab/flash-attention#2153 (assuming this is where the test is actually from). The failing test: https://github.com/pytorch/pytorch/actions/runs/20827118782/job/59909424611

@kundaMwiza kundaMwiza force-pushed the mwizak/fix-stride-sort-revert-v2 branch from 7c4846f to fd8c60d Compare January 15, 2026 09:57
@kundaMwiza
Copy link
Copy Markdown
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 15, 2026
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants