diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index ca71bb0b88..e08820a694 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -611,8 +611,9 @@ class TllmGenFmhaKernel { float globalModelingKernelTime = FLT_MAX; // Loop over each candidate tile size. for (int tileSizeQ : candidateTileSizesQ) { - // Only consider candidates <= default tileSizeQ. - if (tileSizeQ > defaultTileSizeQ) { + // Only consider candidates <= default tileSizeQ and ensure each CTA processes full + // numHeadsQPerKv. + if (tileSizeQ > defaultTileSizeQ || tileSizeQ < params.mNumHeadsQPerKv) { continue; }