Skip to content

Optimize tl.where by converting select to branch when lowering to llvm#820

Closed
manman-ren wants to merge 3 commits into
mren/if-tensor-valuefrom
mren/add-vote
Closed

Optimize tl.where by converting select to branch when lowering to llvm#820
manman-ren wants to merge 3 commits into
mren/if-tensor-valuefrom
mren/add-vote

Conversation

@manman-ren
Copy link
Copy Markdown
Contributor

@manman-ren manman-ren commented Jan 28, 2026

One use case will be rescaling optimization of FA. When any thread in a warp needs rescaling of correction, correction_rescale will be invoked. We currently have 128 rows, 4 warps, each thread is responsible for one row.
Triton currently doesn't support ifOp on a tensor condition, which is needed for FA4 where should_rescale is a tensor value where it is uniform within a warp. The PR attempts to handle it when lowering to llvm, where we have a per-thread view.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 28, 2026
@manman-ren manman-ren marked this pull request as draft January 28, 2026 21:26
alpha = tl.math.exp2(alpha_)
rescale_mask = alpha_ >= -8.0
alpha = tl.where(rescale_mask, 1.0, alpha)
m_ij = tl.where(rescale_mask, m_i, m_ij)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems to hit a TLX issue with this change: Pipeline failed while executing [`TritonTLXFixup
python third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py

CC @htyu

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual error is
error: 'ttng.vote_ballot_sync' op operand #1 must be 1-bit signless integer, but got 'tensor<128x1xi1>'
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
Probably need to make vote_ballot_sync a tensor operation.

# All elements contain the same warp-level ballot value
# Non-zero means at least one thread has alpha_1 < 1.0
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
should_rescale = ballot_result != 0
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alpha_1 is 128x1, we have 4 warps, 128 threads, each thread owns one row. We are voting to get rid of thread divergence within a warp, which means 32 threads will have the same value. The 4 warps can have different value. One way is to unroll this four times, one for each warp, then should_rescale will be 32x1, which is uniform. We can also perform reduction across 128x1 to have one value per 4 warps.
CC @njriasan @htyu @kvbp2k

Copy link
Copy Markdown
Contributor

@htyu htyu Feb 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if one row needs rescale, the full 32 rows within that warp also need rescale as a result based on the current implementation?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, to avoid warp divergence, all 32 threads in the warp will make the same decision.

@manman-ren manman-ren changed the title Add tlx.vote_ballot_sync op that lowers to NVVM::VoteSyncOp Handle ifOp on a tensor condition by converting select to branch when lowering to llvm Feb 2, 2026
@manman-ren manman-ren changed the title Handle ifOp on a tensor condition by converting select to branch when lowering to llvm Optimize tl.where by converting select to branch when lowering to llvm Feb 3, 2026
@manman-ren manman-ren changed the base branch from main to mren/if-tensor-value February 3, 2026 17:27
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@manman-ren manman-ren force-pushed the mren/if-tensor-value branch from 4d1660c to 7d29173 Compare February 5, 2026 00:54
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@manman-ren manman-ren force-pushed the mren/if-tensor-value branch from 7d29173 to a26607a Compare February 6, 2026 23:28
@manman-ren manman-ren deleted the branch mren/if-tensor-value February 13, 2026 18:41
@manman-ren manman-ren closed this Feb 13, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants