Optimize tl.where by converting select to branch when lowering to llvm #820
manman-ren wants to merge 3 commits into
1889c13 to 87912e7
    alpha = tl.math.exp2(alpha_)
    rescale_mask = alpha_ >= -8.0
    alpha = tl.where(rescale_mask, 1.0, alpha)
    m_ij = tl.where(rescale_mask, m_i, m_ij)
This change seems to hit a TLX issue: Pipeline failed while executing [`TritonTLXFixup
python third_party/tlx/tutorials/blackwell-fa-ws-pipelined-persistent_test.py
CC @htyu
The actual error is:
error: 'ttng.vote_ballot_sync' op operand #1 must be 1-bit signless integer, but got 'tensor<128x1xi1>'
ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
We probably need to make vote_ballot_sync a tensor operation.
    # All elements contain the same warp-level ballot value
    # Non-zero means at least one thread has alpha_1 < 1.0
    ballot_result = tlx.vote_ballot_sync(0xFFFFFFFF, pred)
    should_rescale = ballot_result != 0
alpha_1 is 128x1. We have 4 warps (128 threads), and each thread owns one row. We vote to get rid of thread divergence within a warp, which means all 32 threads in a warp will have the same value, while the 4 warps can have different values. One option is to unroll this four times, once per warp, so that each should_rescale is 32x1 and uniform. Alternatively, we can perform a reduction across the 128x1 tensor to get a single value shared by all 4 warps.
CC @njriasan @htyu @kvbp2k
So with the current implementation, if one row needs rescaling, all 32 rows within that warp are rescaled as well?
Yes, to avoid warp divergence, all 32 threads in the warp will make the same decision.
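To make the per-warp behavior discussed above concrete, here is a minimal host-side sketch in plain Python. It only simulates the voting semantics for 128 rows and 4 warps of 32 threads; it is not the TLX implementation, and the helper names `warp_ballot`, `should_rescale_per_thread`, and `should_rescale_block` are hypothetical.

```python
WARP_SIZE = 32   # threads per warp
NUM_WARPS = 4    # 4 warps * 32 threads = 128 rows, one row per thread


def warp_ballot(pred):
    """Simulate a per-warp ballot over a flat list of per-thread booleans.

    Returns one 32-bit mask per thread; every thread in a warp sees the
    same mask, with bit `lane` set when that lane's predicate is true.
    """
    assert len(pred) % WARP_SIZE == 0
    result = []
    for start in range(0, len(pred), WARP_SIZE):
        mask = 0
        for lane, p in enumerate(pred[start:start + WARP_SIZE]):
            if p:
                mask |= 1 << lane
        # Broadcast the warp's mask to all 32 threads of that warp.
        result.extend([mask] * WARP_SIZE)
    return result


def should_rescale_per_thread(pred):
    """Warp-uniform decision: true for a thread if any lane in its warp voted true."""
    return [mask != 0 for mask in warp_ballot(pred)]


def should_rescale_block(pred):
    """Alternative mentioned in the thread: reduce over all rows to one value."""
    return any(pred)


# Only row 40 (warp 1, lane 8) needs rescaling.
pred = [False] * (WARP_SIZE * NUM_WARPS)
pred[40] = True
decision = should_rescale_per_thread(pred)
```

In this simulation, rows 32 through 63 all decide to rescale because they share a warp with row 40, while the other three warps skip it. The block-wide reduction would instead make all 128 rows rescale.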
ddfaa3b to d81aabf
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
4d1660c to 7d29173
d81aabf to 2ab1690
7d29173 to a26607a
a26607a to 4dfcc3c
One use case is the rescaling optimization in FA. When any thread in a warp needs correction rescaling, correction_rescale is invoked. We currently have 128 rows and 4 warps, with each thread responsible for one row.
Triton currently doesn't support an ifOp on a tensor condition, which FA4 needs: there, should_rescale is a tensor value that is uniform within a warp. This PR attempts to handle it when lowering to LLVM, where we have a per-thread view.
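The difference between the two lowerings can be illustrated with a small plain-Python model of the per-thread view. This is only a sketch of the general select-versus-branch idea under the assumption that each thread holds one row and one scalar condition; it is not the PR's actual LLVM lowering, and `select_lowering` and `branch_lowering` are hypothetical names.

```python
def select_lowering(cond, acc, alpha):
    """Model of a select: every thread computes the rescaled value,
    then picks between it and the original based on its condition."""
    out = []
    num_mults = 0
    for c, a, s in zip(cond, acc, alpha):
        scaled = a * s          # always computed, even if discarded
        num_mults += 1
        out.append(scaled if c else a)
    return out, num_mults


def branch_lowering(cond, acc, alpha):
    """Model of a branch: a thread only computes the rescaled value
    when its (warp-uniform) condition is true."""
    out = []
    num_mults = 0
    for c, a, s in zip(cond, acc, alpha):
        if c:
            out.append(a * s)   # work happens only on the taken path
            num_mults += 1
        else:
            out.append(a)
    return out, num_mults


# 4 warps of 32 threads; only warp 0 needs rescaling.
cond = [True] * 32 + [False] * 96
acc = [2.0] * 128
alpha = [0.5] * 128

out_select, mults_select = select_lowering(cond, acc, alpha)
out_branch, mults_branch = branch_lowering(cond, acc, alpha)
```

Both versions produce the same values, but the branch model performs the multiply only for the 32 threads whose condition is true. Because the condition is uniform within a warp, a real branch on it would not cause divergence inside any warp.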