Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/flashinfer/trtllm/fmha/fmhaKernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ class TllmGenFmhaKernel {
} else {
// Compute numTokensPerCtaQ where each CTA must process complete numGroupedHeadsQ.
// Note that each CTA must process complete numHeadsQPerKv.
int numTokensPerCtaQ = kernelMeta.mStepQ / params.mNumHeadsQPerKv;
int numTokensPerCtaQ = std::max(1, kernelMeta.mStepQ / params.mNumHeadsQPerKv);
// Group both headsQ and tokensQ into one CTA.
numCtasPerSeqQ = flashinfer::ceil_div(params.mMaxSeqLenQ, numTokensPerCtaQ);
}
Expand Down Expand Up @@ -747,7 +747,7 @@ class TllmGenFmhaKernel {
int& tileSizeQ = selectKernelParams.mTileSizeQ;

// Mixed precision kernels don't work with groupsTokensHeadsQ = true for now.
if (mDtypeQ != mDtypeKv || mDtypeOut == DATA_TYPE_E2M1) {
if (mDtypeQ != mDtypeKv) {
tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16;
kernelType = FmhaKernelType::SwapsMmaAbForGeneration;
return;
Expand Down
52 changes: 23 additions & 29 deletions tests/attention/test_trtllm_gen_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ def _test_trtllm_batch_prefill(
)
plan_params["q_data_type"] = q.dtype
plan_params["kv_data_type"] = kv_cache.dtype
plan_params["o_data_type"] = DTYPE_MAP[o_dtype]
wrapper_trtllm_gen.plan(**plan_params)
output_wrapper = wrapper_trtllm_gen.run(
q_input,
Expand Down Expand Up @@ -915,15 +916,6 @@ def _test_trtllm_batch_decode(
if backend == "xqa" and not uses_shared_paged_kv_idx:
pytest.skip("xqa backend does not support non-shared page indices")

if o_dtype == "nvfp4" and (
q_len_per_req is not None
and q_len_per_req > 1
or max_q_len is not None
and max_q_len > 1
):
# todo(Yingyi): add support for nvfp4 with speculative decoding
pytest.skip("nvfp4 is not supported for q_len_per_req > 1 or max_q_len > 1 yet")

if backend == "trtllm-gen" and o_dtype == "fp8" and q_dtype != "fp8":
pytest.skip("trtllm-gen backend only supports fp8 output for fp8 query")

Expand Down Expand Up @@ -1169,6 +1161,7 @@ def _test_trtllm_batch_decode(
)
plan_params["q_data_type"] = q.dtype
plan_params["kv_data_type"] = kv_cache.dtype
plan_params["o_data_type"] = DTYPE_MAP[o_dtype]
wrapper_trtllm_gen.plan(**plan_params)
output_wrapper = wrapper_trtllm_gen.run(
q_input,
Expand Down Expand Up @@ -1258,7 +1251,7 @@ def _test_trtllm_batch_decode(
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("enable_sink", [True, False])
@pytest.mark.parametrize("max_in_kv_len", [110])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("non_contiguous_query", [False, True])
@pytest.mark.parametrize("skips_softmax", [False, True])
@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False])
Expand Down Expand Up @@ -1721,25 +1714,27 @@ def make_query_non_contiguous(
@pytest.mark.parametrize("backend", ["trtllm-gen"])
@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
@pytest.mark.parametrize(
"batch_size,max_q_len,page_size,num_kv_heads,head_grp_size",
"batch_size,max_q_len,page_size,num_kv_heads,head_grp_size,head_dim",
[
(4, 1, 16, 2, 1),
(4, 1, 32, 2, 5),
(4, 2, 64, 2, 5),
(4, 3, 32, 2, 5),
(4, 3, 64, 2, 1),
(4, 4, 64, 4, 1),
(4, 5, 64, 4, 8),
(128, 1, 64, 2, 5),
(128, 2, 32, 4, 1),
(128, 3, 16, 4, 8),
(128, 4, 16, 2, 5),
(128, 5, 16, 2, 5),
(256, 1, 64, 4, 8),
(256, 2, 16, 2, 8),
(256, 3, 64, 4, 5),
(256, 4, 32, 2, 8),
(256, 5, 32, 2, 1),
(4, 1, 16, 2, 1, 128),
(4, 1, 32, 2, 5, 128),
(4, 2, 64, 2, 5, 128),
(4, 3, 32, 2, 5, 128),
(4, 3, 64, 2, 1, 128),
(4, 4, 64, 4, 1, 128),
(4, 5, 64, 4, 8, 128),
# Iterate over head_dim 128, 256 for these configs to simplify
*[(bs, 4, 64, 4, 16, hd) for bs in [4, 8, 16, 32] for hd in [128, 256]],
(128, 1, 64, 2, 5, 128),
(128, 2, 32, 4, 1, 128),
(128, 3, 16, 4, 8, 128),
(128, 4, 16, 2, 5, 128),
(128, 5, 16, 2, 5, 128),
(256, 1, 64, 4, 8, 256),
(256, 2, 16, 2, 8, 256),
(256, 3, 64, 4, 5, 256),
(256, 4, 32, 2, 8, 256),
(256, 16, 32, 2, 8, 256),
],
)
@pytest.mark.parametrize("window_left", [-1, 127])
Expand All @@ -1761,7 +1756,6 @@ def make_query_non_contiguous(
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("enable_sink", [True, False])
@pytest.mark.parametrize("max_in_kv_len", [110])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("skips_softmax", [False, True])
@pytest.mark.parametrize("uses_shared_paged_kv_idx", [False, True])
def test_trtllm_batch_decode_spec(
Expand Down
Loading