
Add tlx.vote_ballot_sync op that lowers to NVVM::VoteSyncOp #828

Closed
manman-ren wants to merge 1 commit into main from mren/if-tensor-value
Conversation

Contributor

@manman-ren manman-ren commented Feb 2, 2026

This will be used by rescaling optimization in FA4. Kernel change will be in a separate PR.

@meta-cla meta-cla Bot added the CLA Signed label Feb 2, 2026
@manman-ren manman-ren marked this pull request as draft February 2, 2026 19:44
@manman-ren manman-ren marked this pull request as ready for review February 3, 2026 16:11
@manman-ren manman-ren requested review from htyu and njriasan February 3, 2026 16:11
Contributor

@njriasan njriasan left a comment

A couple minor nits but overall this looks great. Thanks!

for mode in ["fwd", "bwd"]:
for causal in [False, True]:
for BWD_BLOCK_M1 in [64, 128]:
for mode in ["fwd"]: # , "bwd"]:
Contributor

Is this code fully broken? If not can we update this code to remove the comments?

Contributor Author

This is for debugging. Will revert.

tlx.local_store(subslice, acc)
# Perform warp-level ballot vote to check if any thread needs rescaling
# 0xFFFFFFFF means all 32 threads in the warp participate
if RESCALE_OPT:
Contributor

Minor nit: This code might be easier to understand if we assert that USE_WHERE is False whenever RESCALE_OPT is False. That way we can even split the logic into a helper function and make it very simple to understand. What do you think?
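
As a rough illustration of the check that the code comment in this thread describes (a ballot over all 32 lanes to see whether any thread needs rescaling), here is a pure-Python model. It is only a sketch: `warp_needs_rescale` and `lane_flags` are hypothetical names introduced for illustration, and this is not the kernel code or the tlx API.

```python
WARP_SIZE = 32
FULL_MASK = 0xFFFFFFFF  # all 32 lanes participate, as in the comment above


def warp_needs_rescale(lane_flags):
    """Hypothetical helper: model of 'does any lane in the warp need rescaling?'.

    lane_flags[i] is the per-lane predicate of lane i. The ballot packs those
    predicates into one 32-bit word that every lane would see identically, so
    comparing it with zero gives a single warp-uniform condition.
    """
    if len(lane_flags) != WARP_SIZE:
        raise ValueError("expected one predicate per lane")
    ballot = 0
    for lane, flag in enumerate(lane_flags):
        if flag and (FULL_MASK >> lane) & 1:
            ballot |= 1 << lane
    return ballot != 0


no_lane = [False] * WARP_SIZE
one_lane = [lane == 7 for lane in range(WARP_SIZE)]

skip_rescale = not warp_needs_rescale(no_lane)  # no lane voted, rescale can be skipped
do_rescale = warp_needs_rescale(one_lane)       # a single lane is enough to trigger it
```

Because the ballot word is the same for every lane, the resulting condition is a scalar per warp rather than a per-element mask.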

tlx.local_store(subslice, acc)
else:
# option 2: use a single scalar IfOp
if RESCALE_OPT:
Contributor

I'd be curious to know if we have performance numbers for this yet? I can report the GB300 numbers today if we update the TritonBench kernel.

Contributor

I'm curious too. I'm wondering if we should defer the kernel changes to a separate PR until we get some perf numbers and numerics results.

Contributor Author

Yes, I should have included some perf numbers. I am still working on #820 for the USE_WHERE path. Currently, enabling RESCALE_OPT with USE_WHERE disabled shows some perf win, pending numerical results. This path performs the reduction for all rows in the block. I can remove the kernel changes or hard-code RESCALE_OPT to off.


# Run the kernel with 1 warp
vote_ballot_kernel[(1, )](output, BLOCK_SIZE=32, num_warps=1)
torch.cuda.synchronize()
Contributor

Why is torch.cuda.synchronize() needed here?


@tl.builtin
def vote_ballot_sync(
mask: tl.constexpr,
Contributor

Is mask required to be a constant from the PTX point of view?

Contributor Author

No, it doesn't need to be.

Contributor

Perhaps change int to int ?

Comment thread third_party/tlx/language/tlx/warp_ops.py
Returns:
If pred is scalar: A 32-bit integer where bit N is set if thread N's
predicate was true and thread N is in the mask.
If pred is tensor: A tensor of i32 with the same shape, where each
Contributor

Same shape with result value?

Contributor Author

It returns a tensor with the same shape as pred.
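
To make the docstring's scalar case concrete, the semantics it states (bit N is set if thread N's predicate was true and thread N is in the mask) can be modeled in plain Python. This is a minimal sketch under that reading of the docstring: `ballot_sync_model` is a hypothetical name, the model runs on the CPU, and it does not call the tlx op or reproduce its lowering.

```python
WARP_SIZE = 32


def ballot_sync_model(mask, preds):
    """Hypothetical CPU model of a warp ballot, not the tlx implementation.

    mask  -- 32-bit integer; bit N selects lane N as a participant
    preds -- list of 32 booleans, one predicate per lane
    Returns a 32-bit integer whose bit N is set when lane N is in the mask
    and its predicate is true.
    """
    if len(preds) != WARP_SIZE:
        raise ValueError("expected one predicate per lane")
    result = 0
    for lane, pred in enumerate(preds):
        lane_in_mask = (mask >> lane) & 1
        if lane_in_mask and pred:
            result |= 1 << lane
    return result


# Even-numbered lanes vote true.
preds = [lane % 2 == 0 for lane in range(WARP_SIZE)]

full = ballot_sync_model(0xFFFFFFFF, preds)      # all 32 lanes participate
low_half = ballot_sync_model(0x0000FFFF, preds)  # only lanes 0..15 participate

print(hex(full))      # 0x55555555
print(hex(low_half))  # 0x5555
```

For a tensor `pred`, the reply above says the result has the same shape as `pred`; this model covers only the scalar case.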

Comment thread third_party/tlx/language/tlx/warp_ops.py
C1 = 0.695146143436431884765625
C2 = 0.227564394474029541015625
C3 = 0.077119089663028717041015625

Contributor

These changes seem unrelated to the PR?

Contributor Author

Yes, these are for the exp simulation from your original PR. Will revert.

Contributor

@htyu htyu left a comment

Can you please also update the README? Thanks.

@manman-ren manman-ren force-pushed the mren/if-tensor-value branch from 4d1660c to 7d29173 Compare February 5, 2026 00:54
Contributor

@htyu htyu left a comment

LGTM!


@manman-ren manman-ren force-pushed the mren/if-tensor-value branch from 7d29173 to a26607a Compare February 6, 2026 23:28
Contributor

meta-codesync Bot commented Feb 9, 2026

@manman-ren has imported this pull request. If you are a Meta employee, you can view this in D92753224.

Summary:
This will be used by rescaling optimization in FA4. Kernel change will be in a separate PR.


Reviewed By: htyu

Differential Revision: D92753224

Pulled By: manman-ren
Contributor

meta-codesync Bot commented Feb 11, 2026

@manman-ren has exported this pull request. If you are a Meta employee, you can view the originating Diff in D92753224.

Contributor

meta-codesync Bot commented Feb 12, 2026

@manman-ren merged this pull request in 517ce79.

@manman-ren manman-ren deleted the mren/if-tensor-value branch February 13, 2026 18:41
htyu pushed a commit that referenced this pull request Mar 3, 2026
Summary:
This will be used by rescaling optimization in FA4. Kernel change will be in a separate PR.

Pull Request resolved: #828

Reviewed By: htyu

Differential Revision: D92753224

Pulled By: manman-ren

fbshipit-source-id: f3df62dcb5193f0c1022fca41bd9aa2828084de8

Labels

CLA Signed, fb-exported, Merged, meta-exported

4 participants