Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
f032fdf
init wip
jimmyzho Jan 7, 2026
e7b0da6
Merge branch 'flashinfer-ai:main' into fmhav2
jimmyzho Jan 8, 2026
17f8eef
move files
jimmyzho Jan 8, 2026
00bec06
Merge branch 'fmhav2' of github.com:jimmyzho/flashinfer into fmhav2
jimmyzho Jan 8, 2026
ef742fd
refactor module gen and compile
jimmyzho Jan 9, 2026
b6a3eaf
Merge branch 'flashinfer-ai:main' into fmhav2
jimmyzho Jan 9, 2026
b0227ef
(wip) compile flow and basic launcher
jimmyzho Jan 23, 2026
50de0e2
fp16 paged
jimmyzho Jan 26, 2026
8168009
fix paged tests
jimmyzho Jan 27, 2026
eaf1739
packed-qkv and separate-qkv
jimmyzho Jan 29, 2026
6dde167
sliding window off by one
jimmyzho Jan 30, 2026
3b2cc1e
attention sinks and refactor tests
jimmyzho Jan 30, 2026
b1f3fb7
kernel level done, next to enable it in python
zhou-yuxin Feb 5, 2026
b40f2fa
Merge branch 'flashinfer-ai:main' into fmhav2
jimmyzho Feb 5, 2026
39c911d
fp8 paged
jimmyzho Feb 5, 2026
da8852e
Merge branch 'fmhav2' of github.com:jimmyzho/flashinfer into fmhav2
jimmyzho Feb 5, 2026
65f2d56
pre-commit format
zhou-yuxin Feb 5, 2026
9a433ee
1. enum skip-softmax kernels in fmha_library.py;
zhou-yuxin Feb 5, 2026
2ea7730
fp8 unit tests, test refactor
jimmyzho Feb 6, 2026
0e5974f
Merge branch 'fmhav2' into add-skip-softmax
jimmyzho Feb 6, 2026
847af74
Merge pull request #1 from zhou-yuxin/add-skip-softmax
jimmyzho Feb 6, 2026
bb1739d
save-softmax, cleanup
jimmyzho Feb 6, 2026
a2bf401
save-softmax, cleanup
jimmyzho Feb 6, 2026
d0bd900
Merge branch 'fmhav2' of github.com:jimmyzho/flashinfer into fmhav2
jimmyzho Feb 6, 2026
03415b0
1. add skip_softmax_threshold_scale_factor;
zhou-yuxin Feb 11, 2026
7db51d1
Add bsz=1 and max_seqlen=16384 and skip_softmax_threshold_scale_facto…
bobboli Feb 13, 2026
063afc8
max seq len and fixes
jimmyzho Feb 13, 2026
ad4653b
Merge branch 'fmhav2' into add-skip-softmax
jimmyzho Feb 13, 2026
e8ebf7c
Merge pull request #2 from zhou-yuxin/add-skip-softmax
jimmyzho Feb 13, 2026
d56e2d3
style
jimmyzho Feb 14, 2026
6d86e4a
Merge branch 'fmhav2' of github.com:jimmyzho/flashinfer into fmhav2
jimmyzho Feb 14, 2026
3e13201
Split skip softmax test out of the base tests to reduce combos.
bobboli Feb 17, 2026
c05a4b9
Merge pull request #3 from bobboli/pr-2446
jimmyzho Feb 17, 2026
56cd4d0
adjust tol
jimmyzho Feb 17, 2026
d049a41
feat: add SM120 support for fmha_v2 flash attention kernels
blake-snc Feb 20, 2026
b5a13e1
cleanup, rm non_blocking, rm overhead
jimmyzho Feb 24, 2026
ae4aaed
Merge pull request #4 from blake-snc/fmhav2-sm120
jimmyzho Feb 24, 2026
c2b39fa
Merge branch 'flashinfer-ai:main' into fmhav2
jimmyzho Feb 25, 2026
d4b014a
cleanup, add chunked prefill chunked attention tests
jimmyzho Mar 4, 2026
53f5789
Merge remote-tracking branch 'origin/main' into fmhav2
jimmyzho Mar 4, 2026
bb1a87a
Only JIT the kernels that belong to the corresponding CUDA arch.
bobboli Mar 4, 2026
36db6d8
Refactor input_layout.
bobboli Mar 4, 2026
0e8fd66
Fix error.
bobboli Mar 4, 2026
c668515
fix function params, fix sm120
jimmyzho Mar 4, 2026
e8365c7
Merge branch 'main' into fmhav2
jimmyzho Mar 4, 2026
0d83dfc
Merge branch 'fmhav2' into pr-2446
jimmyzho Mar 4, 2026
a86f591
Refactor target SM version checks
jimmyzho Mar 4, 2026
ff1c5a7
kernel inclusion
jimmyzho Mar 4, 2026
f2936d3
delete target
jimmyzho Mar 4, 2026
feccb1a
Merge pull request #5 from bobboli/pr-2446
jimmyzho Mar 4, 2026
c778991
docstring, sm120 unit test fix
jimmyzho Mar 5, 2026
4cde977
skip hanging config, global workspace alloc testing
jimmyzho Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 0 additions & 196 deletions csrc/fmha_v2/convert.cu

This file was deleted.

6 changes: 3 additions & 3 deletions csrc/fmha_v2/fmha/gmem_tile_o_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ struct Hmma_gmem_tile_o {
//
// row_offset += binfo.bidx * VALID_BYTES_PER_ROW;
//
row_offset += binfo.bidx * valid_bytes_per_row;
row_offset += (int64_t)binfo.bidx * valid_bytes_per_row;

// Assemble the final pointer.
o_ptr_ += row_offset + col_in_bytes_;
Expand Down Expand Up @@ -753,7 +753,7 @@ struct Gmem_tile_o_8bit {
// The amount of bytes per row without padding (runtime).
int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT;
// Take the batch/head offset into account.
row_offset += block_info.bidx * valid_bytes_per_row;
row_offset += (int64_t)block_info.bidx * valid_bytes_per_row;
// Assemble the final pointer.
o_ptr_ += row_offset + col_in_bytes_;

Expand Down Expand Up @@ -1088,7 +1088,7 @@ struct Gmem_tile_o_16bit {
// The amount of bytes per row without padding (runtime).
int const valid_bytes_per_row = params.dv * BYTES_PER_ELEMENT;
// Take the batch/head offset into account.
row_offset += block_info.bidx * valid_bytes_per_row;
row_offset += (int64_t)block_info.bidx * valid_bytes_per_row;
// Assemble the final pointer.
o_ptr_ += row_offset + col_in_bytes_;

Expand Down
6 changes: 3 additions & 3 deletions csrc/fmha_v2/fmha/gmem_tile_ps.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ struct Gmem_tile_ps {
int col = warp / Cta_tile::WARPS_M * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG;

// The offset of the 1st row written by the thread. We store the P matrix interleaved.
int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW;
int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + (int64_t)bidx * BYTES_PER_ROW;
// Finalize the pointer.
ptr_ += row_offset + col * BYTES_PER_ELEMENT;
}
Expand Down Expand Up @@ -654,7 +654,7 @@ struct Gmem_tile_ps<Volta_hmma_fp16_traits, Cta_tile, 16> {

// The offset of the 1st row written by the thread. We store the P matrix interleaved.
int64_t row_offset =
(int64_t)row * params_stride_in_bytes_ + bidx * BYTES_PER_ROW + cta_row_offset;
(int64_t)row * params_stride_in_bytes_ + (int64_t)bidx * BYTES_PER_ROW + cta_row_offset;

// Finalize the pointer.
ptr_ += row_offset + col * BYTES_PER_ELEMENT;
Expand Down Expand Up @@ -760,7 +760,7 @@ struct Gmem_tile_ps_hopper {
int col = warpgroup_idx * Mma_tile::N_PER_MMA + lane % 4 * ELEMENTS_PER_STG;

// The offset of the 1st row written by the thread. We store the P matrix interleaved.
int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + bidx * bytes_per_row;
int64_t row_offset = (int64_t)row * params_stride_in_bytes_ + (int64_t)bidx * bytes_per_row;
// Finalize the pointer.
ptr_ += row_offset + col * BYTES_PER_ELEMENT;
}
Expand Down
6 changes: 3 additions & 3 deletions csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ struct Gmem_tile_o_hopper_16bits {

// The offset of the 1st row written by the thread. We store the P matrix interleaved.
int64_t row_offset =
(int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW;
(int64_t)row_ * params_o_stride_in_bytes_ + (int64_t)block_info.bidx * BYTES_PER_ROW;
// Finalize the pointer.
o_ptr_ += row_offset + col * BYTES_PER_ELEMENT;
}
Expand Down Expand Up @@ -599,7 +599,7 @@ struct Gmem_tile_o_gmma_32bit_8bit {

// The offset of the 1st row written by the thread. We store the P matrix interleaved.
int64_t row_offset =
(int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW;
(int64_t)row_ * params_o_stride_in_bytes_ + (int64_t)block_info.bidx * BYTES_PER_ROW;
// Finalize the pointer.
o_ptr_ += row_offset + col_ * BYTES_PER_ELEMENT;
}
Expand Down Expand Up @@ -1065,7 +1065,7 @@ struct Gmem_tile_o_qgmma_fp32_16bits {

// The offset of the 1st row written by the thread. We store the P matrix interleaved.
int64_t row_offset =
(int64_t)row_ * params_o_stride_in_bytes_ + block_info.bidx * BYTES_PER_ROW;
(int64_t)row_ * params_o_stride_in_bytes_ + (int64_t)block_info.bidx * BYTES_PER_ROW;
// Finalize the pointer.
o_ptr_ += row_offset + col * BYTES_PER_ELEMENT;
}
Expand Down
7 changes: 4 additions & 3 deletions csrc/fmha_v2/fmha/kernel_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ struct Kernel_traits_ {

// Compute the total BMM2_MMAS_K (might not the same as Mma_tile_o::MMAS_K if the granular tiling
// is used).
static_assert(S % CTA_O_TILE_K == 0, "");
// S=0 for flash attention (variable sequence length): tile counts are determined at runtime.
static_assert(S == 0 || S % CTA_O_TILE_K == 0, "");

enum { TOTAL_BMM2_MMAS_K = Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K) };
enum { TOTAL_BMM2_MMAS_K = S == 0 ? 0 : Mma_tile_o::MMAS_K * (S / CTA_O_TILE_K) };

// Constraints on the K dimension.
static_assert(Mma_tile_p::K_PER_MMA <= static_cast<int>(D));
static_assert(Mma_tile_o::K_PER_MMA <= S);
static_assert(S == 0 || Mma_tile_o::K_PER_MMA <= S);

// The version.
enum { VERSION = VERSION_ };
Expand Down
50 changes: 46 additions & 4 deletions csrc/fmha_v2/fmha/warpspec/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ struct Compute {
USE_CUSTOM_MASK ? (head_info.mask_sum_s + q_step_idx * STEP_Q + local_q_tile_offset) \
: (q_step_idx * STEP_Q + head_info.q_tile_offset), \
kv_step_idx * STEP_KV, sage_scale_row, cbr, cbr_v, mutex_accessor, \
kv_step_idx == kv_idx_end - 1);
&shared->skip_softmax_votes[kv_step_idx & 1][warpgroup_id], kv_step_idx == kv_idx_end - 1);

////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -277,6 +277,12 @@ struct Compute {
int const actual_kv_seqlen =
SEPARATE_Q_KV_BUFFER ? head_info.actual_kv_seqlen : actual_q_seqlen;

// Update threshold of Skip-Softmax
if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) {
softmax.skip_softmax_threshold =
params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
}
Comment on lines +280 to +284
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Guard skip-softmax threshold against zero-length KV.
If actual_kv_seqlen is 0, the division yields inf/NaN and can taint skip decisions.

πŸ›‘οΈ Proposed fix
     if constexpr (Kernel_traits::ENABLE_SKIP_SOFTMAX) {
-      softmax.skip_softmax_threshold =
-          params.skip_softmax_threshold_scale_factor / actual_kv_seqlen;
+      int denom = actual_kv_seqlen > 0 ? actual_kv_seqlen : 1;
+      softmax.skip_softmax_threshold =
+          params.skip_softmax_threshold_scale_factor / denom;
     }
πŸ€– Prompt for AI Agents
In `@csrc/fmha_v2/fmha/warpspec/compute.h` around lines 280 - 284, Guard the
division by zero when computing softmax.skip_softmax_threshold: inside the
Kernel_traits::ENABLE_SKIP_SOFTMAX block check actual_kv_seqlen > 0 before doing
the division (using params.skip_softmax_threshold_scale_factor /
actual_kv_seqlen) and, if actual_kv_seqlen == 0, assign a safe non-NaN value
(e.g. std::numeric_limits<float>::max()) to softmax.skip_softmax_threshold;
include <limits> if needed and keep the change localized around the current
block with the same symbols (Kernel_traits::ENABLE_SKIP_SOFTMAX,
softmax.skip_softmax_threshold, params.skip_softmax_threshold_scale_factor,
actual_kv_seqlen).


// Calculate the alibi head_scaling_factor.
float alibi_head_scale = APPLY_ALIBI ? get_alibi_head_scaling_factor<AlibiParams>(
head_info.bidh, params.alibi_params)
Expand Down Expand Up @@ -411,6 +417,12 @@ struct Compute {
}
}
}
#ifdef SKIP_SOFTMAX_STAT
if (tidx == 0) {
atomicAdd(params.skip_softmax_total_blocks, softmax.total_blocks);
atomicAdd(params.skip_softmax_skipped_blocks, softmax.skipped_blocks);
}
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -421,7 +433,14 @@ struct Compute {
float (&p_max)[Mma_tile_p::CORES_M], float (&p_sum)[Mma_tile_p::CORES_M], int const tidx,
int const actual_kv_seqlen, float const alibi_head_scale, int const row_offset,
int const col_offset, int const sage_scale_row, Circular_buffer_q_reader& cbr,
Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, bool complete = false) {
Circular_buffer_kv_reader& cbr_v, OrderedMutexAccessor& mutex, uint32_t* skip_softmax_vote,
bool complete = false) {
// Skip-softmax vote initialization
if (tidx == 0) {
// Note that we need a named_barrier_wait in compute_single_tile to make sure init is before
// voting.
*skip_softmax_vote = 1;
}
Comment on lines +438 to +443
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Initialize skip-softmax votes per warpgroup, not only thread 0.
tidx == 0 only resets the vote for warpgroup 0; warpgroup 1 reads stale values and can skip incorrectly.

πŸ”§ Proposed fix
-    if (tidx == 0) {
+    if ((tidx % 128) == 0) {
       // Note that we need a named_barrier_wait in compute_single_tile to make sure init is before
       // voting.
       *skip_softmax_vote = 1;
     }
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Skip-softmax vote initialization
if (tidx == 0) {
// Note that we need a named_barrier_wait in compute_single_tile to make sure init is before
// voting.
*skip_softmax_vote = 1;
}
// Skip-softmax vote initialization
if ((tidx % 128) == 0) {
// Note that we need a named_barrier_wait in compute_single_tile to make sure init is before
// voting.
*skip_softmax_vote = 1;
}
πŸ€– Prompt for AI Agents
In `@csrc/fmha_v2/fmha/warpspec/compute.h` around lines 438 - 443, The code
currently sets *skip_softmax_vote = 1 only when tidx == 0 which initializes the
vote for warpgroup 0 only; change the initialization to run once per warpgroup
by having each warp's leader initialize the vote (e.g., use the lane/warp check
instead of tidx == 0 β€” for example check laneId == 0 or (tidx & (WARP_SIZE-1))
== 0) so every warpgroup writes *skip_softmax_vote = 1 before voting; keep the
existing comment about needing a named_barrier_wait in compute_single_tile to
ensure ordering.

// load the scales of K/V from global memory
#define LOAD_SCALES_KV(dst, which, blocks_per_step, block_size) \
if constexpr (block_size > 0) { \
Expand Down Expand Up @@ -453,6 +472,10 @@ struct Compute {
// Ctile_p is only used once by each n step.
ctile_p.clear();

// If skip_softmax is enabled, make sure there is no racing between the initialization and
// writing of skip_softmax_vote.
named_barrier_wait(Kernel_traits::SKIP_SOFTMAX_BARRIER_ID + threadIdx.x / 128, 128);

// BMM1 (Q x K').
warpgroup_arrive();

Expand Down Expand Up @@ -513,8 +536,27 @@ struct Compute {
softmax.apply_alibi_and_mask<APPLY_MASK>(ctile_p, params.alibi_params, alibi_head_scale,
actual_kv_seqlen, row_offset, col_offset);

// Softmax Exp, max/sum, and update scales.
softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum);
// Softmax Exp, max/sum, and update scales. If returns false we skip the rest.
if (!softmax.compute_and_update_scale<IS_FIRST_COL>(p_max, p_sum, skip_softmax_vote)) {
if constexpr (ENABLE_MUTEX && Kernel_traits::ELEMENT_BYTES == 1) {
// Notify another warpgroup to execute QGMMA.
mutex.named_bar_arrive();
}
// Need to wait V, otherwise compute-sanitizer synccheck will fail.
int ready2 = cbr_v.peek();
if (!ready2) {
cbr_v.wait();
}

#pragma unroll
// Advance V descriptor by the same amount as the BMM2 loop would,
// so that the descriptor stays in sync for subsequent KV steps.
for (int kbi = 0; kbi < BMM2_MMAS_K_GROUPS - 1; kbi++) {
ctile_o.increment_gmma_desc_group();
}

return;
}

// experiments show that here is the best place to load scales of V
float scales_v[SAGE_BLOCKS_PER_STEP_V];
Expand Down
Loading
Loading