8 changes: 7 additions & 1 deletion csrc/fmha_v2_run.cu
@@ -250,9 +250,15 @@ static inline void determine_launch_params(
launch_params.device_l2_cache_size = props.l2CacheSize;

// threshold for adopting flash attention or warp_specialized kernels.
// For Q_PAGED_KV layouts, flash attention is required regardless of
// seqlen — the paged KV dispatch path only supports flash attention
// kernels, and `s` is the padded max_kv_len (not per-request seq_lens,
// which the kernel predicates on separately). Falling back to the
// non-flash path produces silently wrong output when max_kv_len < 16.
bool const is_paged_kv = input_layout == Attention_input_layout::Q_PAGED_KV;
launch_params.flash_attention =
(data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
(s >= 16 && d >= 16) && !force_non_flash_attention;
(is_paged_kv || (s >= 16 && d >= 16)) && !force_non_flash_attention;

// enable warp_specialized kernels when s >= 512 on hopper
// note that warp_specialized kernels need flash attention + tma