Add tlx.vote_ballot_sync op that lowers to NVVM::VoteSyncOp #828
manman-ren wants to merge 1 commit into
Conversation
njriasan left a comment
A couple minor nits but overall this looks great. Thanks!
for mode in ["fwd", "bwd"]:
for causal in [False, True]:
for BWD_BLOCK_M1 in [64, 128]:
for mode in ["fwd"]:  # , "bwd"]:
Is this code fully broken? If not, can we update this code to remove the comments?
This is for debugging. Will revert.
tlx.local_store(subslice, acc)
# Perform warp-level ballot vote to check if any thread needs rescaling
# 0xFFFFFFFF means all 32 threads in the warp participate
if RESCALE_OPT:
Minor nit: Might be easier to understand this code if we assert USE_WHERE = False if RESCALE_OPT = False. That way we can even split the logic into a helper function and make it very simple to understand. What do you think?
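One way to read this suggestion, as a minimal sketch: validate the flag combination once, then pick a path in a small helper. Only the flag names `RESCALE_OPT` and `USE_WHERE` come from the kernel under review; the function names and the path labels below are hypothetical and do not appear in the PR.

```python
def check_rescale_flags(rescale_opt: bool, use_where: bool) -> None:
    """Reject USE_WHERE=True when RESCALE_OPT=False (hypothetical helper)."""
    if use_where and not rescale_opt:
        raise ValueError("USE_WHERE requires RESCALE_OPT to be enabled")


def select_rescale_path(rescale_opt: bool, use_where: bool) -> str:
    """Return a label for the code path a kernel would take.

    The labels are illustrative only; they are not names used in the kernel.
    """
    check_rescale_flags(rescale_opt, use_where)
    if not rescale_opt:
        return "always_rescale"  # optimization off: unconditional rescale
    return "where" if use_where else "scalar_if"
```

With the invalid combination rejected up front, each remaining branch handles exactly one configuration.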
tlx.local_store(subslice, acc)
else:
# option 2: use a single scalar IfOp
if RESCALE_OPT:
I'd be curious to know if we have performance numbers for this yet? I can report the GB300 numbers today if we update the TritonBench kernel.
I'm curious too. I'm wondering if we should defer the kernel changes to a separate PR until we get some perf numbers and numerics results.
Yes, I should have included some perf numbers. I am still working on #820 for the USE_WHERE path. Currently, enabling RESCALE_OPT without USE_WHERE gives some perf win, pending numerical results. This path performs reduction for all rows in the block. I can get rid of the kernel changes or hard-code RESCALE_OPT to off.
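For readers following the thread, the scalar-branch idea (RESCALE_OPT without USE_WHERE) can be sketched on the host in plain Python. This is an illustration under stated assumptions, not the kernel's code: the per-lane predicate `alpha != 1.0` and the names `ballot` and `rescale_if_any` are hypothetical, and the real kernel operates on Triton tensors rather than lists.

```python
from typing import List, Sequence, Tuple

WARP_SIZE = 32  # lanes per warp on NVIDIA GPUs


def ballot(preds: Sequence[bool]) -> int:
    """Pack one predicate per lane into an integer bitmask (full-warp mask)."""
    return sum(1 << lane for lane, p in enumerate(preds) if p)


def rescale_if_any(acc: Sequence[float],
                   alpha: Sequence[float]) -> Tuple[List[float], bool]:
    """Rescale every lane only when at least one lane needs it.

    Assumption: a lane "needs rescaling" when its factor alpha differs from 1.
    """
    if len(acc) != WARP_SIZE or len(alpha) != WARP_SIZE:
        raise ValueError("expected one value per lane")
    needs = [a != 1.0 for a in alpha]
    if ballot(needs) != 0:  # a single, warp-uniform branch
        return [x * a for x, a in zip(acc, alpha)], True
    return list(acc), False  # the multiply is skipped entirely
```

Because the branch condition is the same for all lanes, lanes whose factor is 1 still multiply when any other lane votes true; the saving comes only from iterations where no lane votes.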
# Run the kernel with 1 warp
vote_ballot_kernel[(1, )](output, BLOCK_SIZE=32, num_warps=1)
torch.cuda.synchronize()
Why is torch.cuda.synchronize() needed here?
@tl.builtin
def vote_ballot_sync(
    mask: tl.constexpr,
Is mask required to be constant from the PTX point of view?
No, it doesn't need to be.
Returns:
If pred is scalar: A 32-bit integer where bit N is set if thread N's
predicate was true and thread N is in the mask.
If pred is tensor: A tensor of i32 with the same shape, where each
returns a tensor with the same shape as pred
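The scalar return value described in the docstring can be illustrated with a small host-side emulation. This is a sketch of the documented semantics only (bit N is set when lane N's predicate is true and lane N is in the mask); the function name `emulate_ballot` is hypothetical, and the code does not call the tlx op or model what hardware does when a lane outside the mask executes the instruction.

```python
from typing import Sequence

WARP_SIZE = 32  # lanes per warp on NVIDIA GPUs


def emulate_ballot(mask: int, preds: Sequence[bool]) -> int:
    """Return a 32-bit value whose bit N is set iff preds[N] is true
    and bit N of mask is set."""
    if len(preds) != WARP_SIZE:
        raise ValueError("expected one predicate per lane")
    result = 0
    for lane, pred in enumerate(preds):
        if pred and (mask >> lane) & 1:
            result |= 1 << lane
    return result


# Even-numbered lanes vote true.
preds = [lane % 2 == 0 for lane in range(WARP_SIZE)]
full = emulate_ballot(0xFFFFFFFF, preds)   # all 32 lanes participate
lower = emulate_ballot(0x0000FFFF, preds)  # only lanes 0..15 participate
print(hex(full), hex(lower))  # 0x55555555 0x5555
```

Note that `mask` is an ordinary runtime integer here, which is consistent with the reply above that PTX does not require it to be a constant.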
C1 = 0.695146143436431884765625
C2 = 0.227564394474029541015625
C3 = 0.077119089663028717041015625
These changes seem unrelated to the PR?
Yes, these are for exp simulation from your original PR. Will revert
htyu left a comment
Can you please also update the README? Thanks.
Force-pushed: 4d1660c to 7d29173
Force-pushed: 7d29173 to a26607a
@manman-ren has imported this pull request. If you are a Meta employee, you can view this in D92753224.
Summary: This will be used by rescaling optimization in FA4. Kernel change will be in a separate PR.
Reviewed By: htyu
Differential Revision: D92753224
Pulled By: manman-ren
Force-pushed: a26607a to 4dfcc3c
@manman-ren has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92753224.
@manman-ren merged this pull request in 517ce79.
Summary: This will be used by rescaling optimization in FA4. Kernel change will be in a separate PR.
Pull Request resolved: #828
Reviewed By: htyu
Differential Revision: D92753224
Pulled By: manman-ren
fbshipit-source-id: f3df62dcb5193f0c1022fca41bd9aa2828084de8