-
Notifications
You must be signed in to change notification settings - Fork 830
feat: Add TRTLLM fmha_v2 library for SM90 attention with Skip-Softmax #2446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f032fdf
e7b0da6
17f8eef
00bec06
ef742fd
b6a3eaf
b0227ef
50de0e2
8168009
eaf1739
6dde167
3b2cc1e
b1f3fb7
b40f2fa
39c911d
da8852e
65f2d56
9a433ee
2ea7730
0e5974f
847af74
bb1739d
a2bf401
d0bd900
03415b0
7db51d1
063afc8
ad4653b
e8ebf7c
d56e2d3
6d86e4a
3e13201
c05a4b9
56cd4d0
d049a41
b5a13e1
ae4aaed
c2b39fa
d4b014a
53f5789
bb1a87a
36db6d8
0e8fd66
c668515
e8365c7
0d83dfc
a86f591
ff1c5a7
f2936d3
feccb1a
c778991
4cde977
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -179,7 +179,7 @@ struct Compute { | |||||||||||||||||||||||||
| USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \ | ||||||||||||||||||||||||||
| : (q_step_idx * STEP_Q + head_info.q_tile_offset), \ | ||||||||||||||||||||||||||
| kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \ | ||||||||||||||||||||||||||
| kv_step_idx == kv_idx_end - 1); | ||||||||||||||||||||||||||
| &shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| //////////////////////////////////////////////////////////////////////////////////////////////// | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -277,6 +277,12 @@ struct Compute { | |||||||||||||||||||||||||
| int const actual_kv_seqlen = | ||||||||||||||||||||||||||
| SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Update threshold of Skip-Softmax | ||||||||||||||||||||||||||
| if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) { | ||||||||||||||||||||||||||
| softmax.skip_softmax_threshold = | ||||||||||||||||||||||||||
| params.skip_softmax_threshold_scale_factor / actual_kv_seqlen; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Calculate the alibi head_scaling_factor. | ||||||||||||||||||||||||||
| float alibi_head_scale = APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>( | ||||||||||||||||||||||||||
| head_info.bidh, params.alibi_params) | ||||||||||||||||||||||||||
|
|
@@ -411,6 +417,12 @@ struct Compute { | |||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #ifdef SKIP_SOFTMAX_STAT | ||||||||||||||||||||||||||
| if (tidx == 0) { | ||||||||||||||||||||||||||
| atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks); | ||||||||||||||||||||||||||
| atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| //////////////////////////////////////////////////////////////////////////////////////////////// | ||||||||||||||||||||||||||
|
|
@@ -421,7 +433,14 @@ struct Compute { | |||||||||||||||||||||||||
| float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M], int const tidx, | ||||||||||||||||||||||||||
| int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset, | ||||||||||||||||||||||||||
| int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr, | ||||||||||||||||||||||||||
| Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, bool complete = false) { | ||||||||||||||||||||||||||
| Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote, | ||||||||||||||||||||||||||
| bool complete = false) { | ||||||||||||||||||||||||||
| // Skip-softmax vote initialization | ||||||||||||||||||||||||||
| if (tidx == 0) { | ||||||||||||||||||||||||||
| // Note that we need a named_barrier_wait in compute_single_tile to make sure init is before | ||||||||||||||||||||||||||
| // voting. | ||||||||||||||||||||||||||
| *skip_softmax_vote = 1; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
Comment on lines
+438
to
+443
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Initialize skip-softmax votes per warpgroup, not only thread 0. π§ Proposed fix- if (tidx == 0) {
+ if ((tidx % 128) == 0) {
// Note that we need a named_barrier_wait in compute_single_tile to make sure init is before
// voting.
*skip_softmax_vote = 1;
}π Committable suggestion
Suggested change
π€ Prompt for AI Agents |
||||||||||||||||||||||||||
| // load the scales of K/V from global memory | ||||||||||||||||||||||||||
| #define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \ | ||||||||||||||||||||||||||
| if constexpr (block_size > 0) { \ | ||||||||||||||||||||||||||
|
|
@@ -453,6 +472,10 @@ struct Compute { | |||||||||||||||||||||||||
| // Ctile_p is only used once by each n step. | ||||||||||||||||||||||||||
| ctile_p.clear(); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // If skip_softmax is enabled, make sure there is no racing between the initialization and | ||||||||||||||||||||||||||
| // writing of skip_softmax_vote. | ||||||||||||||||||||||||||
| named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // BMM1 (Q x K'). | ||||||||||||||||||||||||||
| warpgroup_arrive(); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -513,8 +536,27 @@ struct Compute { | |||||||||||||||||||||||||
| softmax.apply_alibi_and_mask<APPLY_MASK>(ctile_p, params.alibi_params, alibi_head_scale, | ||||||||||||||||||||||||||
| actual_kv_seqlen, row_offset, col_offset); | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // Softmax Exp, max/sum, and update scales. | ||||||||||||||||||||||||||
| softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum); | ||||||||||||||||||||||||||
| // Softmax Exp, max/sum, and update scales. If returns false we skip the rest. | ||||||||||||||||||||||||||
| if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote)) { | ||||||||||||||||||||||||||
| if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) { | ||||||||||||||||||||||||||
| // Notify another warpgroup to execute QGMMA. | ||||||||||||||||||||||||||
| mutex.named_bar_arrive(); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| // Need to wait V, otherwise compute-sanitizer synccheck will fail. | ||||||||||||||||||||||||||
| int ready2 = cbr_v.peek(); | ||||||||||||||||||||||||||
| if (!ready2) { | ||||||||||||||||||||||||||
| cbr_v.wait(); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||||
| // Advance V descriptor by the same amount as the BMM2 loop would, | ||||||||||||||||||||||||||
| // so that the descriptor stays in sync for subsequent KV steps. | ||||||||||||||||||||||||||
| for (int kbi = 0; kbi < BMM2_MMAS_K_GROUPS - 1; kbi++) { | ||||||||||||||||||||||||||
| ctile_o.increment_gmma_desc_group(); | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| // experiments show that here is the best place to load scales of V | ||||||||||||||||||||||||||
| float scales_v[SAGE_BLOCKS_PER_STEP_V]; | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard skip-softmax threshold against zero-length KV.
If
actual_kv_seqlenis 0, the division yields inf/NaN and can taint skip decisions.π‘οΈ Proposed fix
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) { - softmax.skip_softmax_threshold = - params.skip_softmax_threshold_scale_factor / actual_kv_seqlen; + int denom = actual_kv_seqlen > 0 ? actual_kv_seqlen : 1; + softmax.skip_softmax_threshold = + params.skip_softmax_threshold_scale_factor / denom; }π€ Prompt for AI Agents