diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index 1e64fb329f..15d950a4ff 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -356,7 +356,7 @@ class TllmGenFmhaKernel { } else { // Compute numTokensPerCtaQ where each CTA must process complete numGroupedHeadsQ. // Note that each CTA must process complete numHeadsQPerKv. - int numTokensPerCtaQ = kernelMeta.mStepQ / params.mNumHeadsQPerKv; + int numTokensPerCtaQ = std::max(1, kernelMeta.mStepQ / params.mNumHeadsQPerKv); // Group both headsQ and tokensQ into one CTA. numCtasPerSeqQ = flashinfer::ceil_div(params.mMaxSeqLenQ, numTokensPerCtaQ); } @@ -747,7 +747,7 @@ class TllmGenFmhaKernel { int& tileSizeQ = selectKernelParams.mTileSizeQ; // Mixed precision kernels don't work with groupsTokensHeadsQ = true for now. - if (mDtypeQ != mDtypeKv || mDtypeOut == DATA_TYPE_E2M1) { + if (mDtypeQ != mDtypeKv) { tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16; kernelType = FmhaKernelType::SwapsMmaAbForGeneration; return; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index efe5981dd3..fb1807c655 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -701,6 +701,7 @@ def _test_trtllm_batch_prefill( ) plan_params["q_data_type"] = q.dtype plan_params["kv_data_type"] = kv_cache.dtype + plan_params["o_data_type"] = DTYPE_MAP[o_dtype] wrapper_trtllm_gen.plan(**plan_params) output_wrapper = wrapper_trtllm_gen.run( q_input, @@ -915,15 +916,6 @@ def _test_trtllm_batch_decode( if backend == "xqa" and not uses_shared_paged_kv_idx: pytest.skip("xqa backend does not support non-shared page indices") - if o_dtype == "nvfp4" and ( - q_len_per_req is not None - and q_len_per_req > 1 - or max_q_len is not None - and max_q_len > 1 - ): - # todo(Yingyi): add support for nvfp4 with speculative decoding - pytest.skip("nvfp4 is not supported for q_len_per_req > 1 or max_q_len > 1 yet") - if backend == "trtllm-gen" and o_dtype == "fp8" and q_dtype != "fp8": pytest.skip("trtllm-gen backend only supports fp8 output for fp8 query") @@ -1169,6 +1161,7 @@ def _test_trtllm_batch_decode( ) plan_params["q_data_type"] = q.dtype plan_params["kv_data_type"] = kv_cache.dtype + plan_params["o_data_type"] = DTYPE_MAP[o_dtype] wrapper_trtllm_gen.plan(**plan_params) output_wrapper = wrapper_trtllm_gen.run( q_input, @@ -1258,7 +1251,7 @@ def _test_trtllm_batch_decode( @pytest.mark.parametrize("enable_pdl", [True, False, None]) @pytest.mark.parametrize("enable_sink", [True, False]) @pytest.mark.parametrize("max_in_kv_len", [110]) -@pytest.mark.parametrize("head_dim", [128]) +@pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("non_contiguous_query", [False, True]) @pytest.mark.parametrize("skips_softmax", [False, True]) @pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) @@ -1721,25 +1714,27 @@ def make_query_non_contiguous( @pytest.mark.parametrize("backend", ["trtllm-gen"]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize( - "batch_size,max_q_len,page_size,num_kv_heads,head_grp_size", + "batch_size,max_q_len,page_size,num_kv_heads,head_grp_size,head_dim", [ - (4, 1, 16, 2, 1), - (4, 1, 32, 2, 5), - (4, 2, 64, 2, 5), - (4, 3, 32, 2, 5), - (4, 3, 64, 2, 1), - (4, 4, 64, 4, 1), - (4, 5, 64, 4, 8), - (128, 1, 64, 2, 5), - (128, 2, 32, 4, 1), - (128, 3, 16, 4, 8), - (128, 4, 16, 2, 5), - (128, 5, 16, 2, 5), - (256, 1, 64, 4, 8), - (256, 2, 16, 2, 8), - (256, 3, 64, 4, 5), - (256, 4, 32, 2, 8), - (256, 5, 32, 2, 1), + (4, 1, 16, 2, 1, 128), + (4, 1, 32, 2, 5, 128), + (4, 2, 64, 2, 5, 128), + (4, 3, 32, 2, 5, 128), + (4, 3, 64, 2, 1, 128), + (4, 4, 64, 4, 1, 128), + (4, 5, 64, 4, 8, 128), + # Iterate over head_dim 128, 256 for these configs to simplify + *[(bs, 4, 64, 4, 16, hd) for bs in [4, 8, 16, 32] for hd in [128, 256]], + (128, 1, 64, 2, 5, 128), + (128, 2, 32, 4, 1, 128), + (128, 3, 16, 4, 8, 128), + (128, 4, 16, 2, 5, 128), + (128, 5, 16, 2, 5, 128), + (256, 1, 64, 4, 8, 256), + (256, 2, 16, 2, 8, 256), + (256, 3, 64, 4, 5, 256), + (256, 4, 32, 2, 8, 256), + (256, 16, 32, 2, 8, 256), ], ) @pytest.mark.parametrize("window_left", [-1, 127]) @@ -1761,7 +1756,6 @@ def make_query_non_contiguous( @pytest.mark.parametrize("enable_pdl", [True, False, None]) @pytest.mark.parametrize("enable_sink", [True, False]) @pytest.mark.parametrize("max_in_kv_len", [110]) -@pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("skips_softmax", [False, True]) @pytest.mark.parametrize("uses_shared_paged_kv_idx", [False, True]) def test_trtllm_batch_decode_spec(