
[TRTLLM-Gen Fmha] add optimized trtllm-gen decode kernels for high throughput + speculative decoding #2265

Merged
yzh119 merged 6 commits into flashinfer-ai:main from
PerkzZheng:user/perkzz/trtllm-gen-groups-tokens-heads
Jan 7, 2026

Conversation

@PerkzZheng
Contributor

@PerkzZheng PerkzZheng commented Dec 24, 2025

📌 Description

This PR adds optimized decode attention kernels for high throughput (large batch sizes) + speculative decoding (seqlen_q > 1).

See below for the speedups, collected with benchmarks/flashinfer_benchmark.py; seqLenKv is 16K in all cases.

| test case | median_time_ms | median_time_ms (opt) | speedup |
|---|---|---|---|
| Qwen3-480B-fp8_e4m3-batchSize8-seqLenQ2 | 0.057 | 0.046 | 1.24 |
| Qwen3-480B-fp8_e4m3-batchSize16-seqLenQ2 | 0.11 | 0.083 | 1.33 |
| Qwen3-480B-fp8_e4m3-batchSize32-seqLenQ2 | 0.213 | 0.168 | 1.27 |
| Qwen3-480B-fp8_e4m3-batchSize40-seqLenQ2 | 0.266 | 0.241 | 1.10 |
| Qwen3-480B-fp8_e4m3-batchSize64-seqLenQ2 | 0.432 | 0.336 | 1.29 |
| Qwen3-480B-fp8_e4m3-batchSize8-seqLenQ4 | 0.109 | 0.048 | 2.27 |
| Qwen3-480B-fp8_e4m3-batchSize16-seqLenQ4 | 0.212 | 0.083 | 2.55 |
| Qwen3-480B-fp8_e4m3-batchSize32-seqLenQ4 | 0.371 | 0.168 | 2.21 |
| Qwen3-480B-fp8_e4m3-batchSize40-seqLenQ4 | 0.472 | 0.245 | 1.93 |
| Qwen3-480B-fp8_e4m3-batchSize64-seqLenQ4 | 0.736 | 0.348 | 2.11 |
| Qwen3-480B-fp8_e4m3-batchSize8-seqLenQ8 | 0.212 | 0.061 | 3.48 |
| Qwen3-480B-fp8_e4m3-batchSize16-seqLenQ8 | 0.37 | 0.106 | 3.49 |
| Qwen3-480B-fp8_e4m3-batchSize32-seqLenQ8 | 0.732 | 0.239 | 3.06 |
| Qwen3-480B-fp8_e4m3-batchSize40-seqLenQ8 | 0.937 | 0.321 | 2.92 |
| Qwen3-480B-fp8_e4m3-batchSize64-seqLenQ8 | 1.456 | 0.484 | 3.01 |
| GPT-OSS-fp8_e4m3-batchSize8-seqLenQ2 | 0.051 | 0.03 | 1.70 |
| GPT-OSS-fp8_e4m3-batchSize16-seqLenQ2 | 0.098 | 0.054 | 1.81 |
| GPT-OSS-fp8_e4m3-batchSize32-seqLenQ2 | 0.188 | 0.104 | 1.81 |
| GPT-OSS-fp8_e4m3-batchSize40-seqLenQ2 | 0.234 | 0.15 | 1.56 |
| GPT-OSS-fp8_e4m3-batchSize64-seqLenQ2 | 0.332 | 0.199 | 1.67 |
| GPT-OSS-fp8_e4m3-batchSize8-seqLenQ4 | 0.099 | 0.038 | 2.61 |
| GPT-OSS-fp8_e4m3-batchSize16-seqLenQ4 | 0.188 | 0.07 | 2.69 |
| GPT-OSS-fp8_e4m3-batchSize32-seqLenQ4 | 0.332 | 0.136 | 2.44 |
| GPT-OSS-fp8_e4m3-batchSize40-seqLenQ4 | 0.418 | 0.2 | 2.09 |
| GPT-OSS-fp8_e4m3-batchSize64-seqLenQ4 | 0.647 | 0.265 | 2.44 |
| GPT-OSS-fp8_e4m3-batchSize8-seqLenQ8 | 0.188 | 0.039 | 4.82 |
| GPT-OSS-fp8_e4m3-batchSize16-seqLenQ8 | 0.332 | 0.065 | 5.11 |
| GPT-OSS-fp8_e4m3-batchSize32-seqLenQ8 | 0.647 | 0.126 | 5.13 |
| GPT-OSS-fp8_e4m3-batchSize40-seqLenQ8 | 0.83 | 0.185 | 4.49 |
| GPT-OSS-fp8_e4m3-batchSize64-seqLenQ8 | 1.29 | 0.245 | 5.27 |

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Generation attention now enforces causal masking during token generation.
  • Performance / Refactor

    • Improved kernel selection and on-demand loading for better performance and GPU compatibility.
    • Added finer-grained tuning parameters for tile/grouping, tokens-per-CTA and inflation to enable more optimal kernel choices.
  • Chores

    • Updated FMHA artifact paths and checksums.
  • Tests

    • Expanded parameterized tests to cover larger batch decoding scenarios.


@coderabbitai
Contributor

coderabbitai bot commented Dec 24, 2025

📝 Walkthrough

Walkthrough

Refactors trtllm FMHA kernel selection/launch to on-demand cubin loading with a new CtaLaunchParams API, expands kernel hashing/selection surface and TMA token-grouping logic, switches the trtllm generation path maskType to Causal (naming retained), updates artifact checksum/path, and adds head_dim=256 batch=32 tests.

Changes

| Cohort / File(s) | Summary |
|---|---|
| FMHA Kernel Selection & Launching Refactor: `include/flashinfer/trtllm/fmha/fmhaKernels.cuh` | Added `isFamilySpecificSMPair`; introduced public `CtaLaunchParams`; extended hash/key construction; changed signatures for `hashID`/`hashFromRunnerParams`/`loadKernel`; added `selectKernel` helpers; implemented on-demand module/cubin loading and caching; updated CTA/cluster launch wiring. |
| Kernel Parameter Infrastructure: `include/flashinfer/trtllm/fmha/kernelParams.h` | Added `FastModDivInt32`; added `mInflateMax`, `mNumTokensPerCtaQ`, and `mNumHeadsQPerKvDivisor` to `KernelParams`; `makeTmaShapeStrideQ` gains `groupsTokensHeadsQ` and now returns `numTokensPerCtaQ`; callers and `setKernelParams` updated. |
| Kernel Selection Parameters: `include/flashinfer/trtllm/fmha/fmhaRunnerParams.h` | `TllmGenSelectKernelParams` adds `mNumTokensPerPage` and `mTileSizeQ`; constructor initializes `mNumTokensPerPage` and defaults `mTileSizeQ=128`. |
| Kernel Launcher Masking: `csrc/trtllm_fmha_kernel_launcher.cu` | In `trtllm_paged_attention_launcher`'s generation branch, `mMaskType` changed from Dense to Causal; clarifying comment added explaining that the dense naming is retained for performance and per-CTA equivalence. |
| Artifacts: `flashinfer/artifacts.py` | Updated `ArtifactPath.TRTLLM_GEN_FMHA` and `CheckSumHash.TRTLLM_GEN_FMHA` constants (new path and checksum). |
| Tests: `tests/attention/test_trtllm_gen_attention.py` | Added three parameterized test cases for head_dim=256 with batch_size=32 across page_size / num_kv_heads / head_grp_size combinations. |
```mermaid
sequenceDiagram
  autonumber
  participant Host as Host (Runner)
  participant Selector as SelectKernelParams
  participant Loader as KernelLoader (cache)
  participant CUDA as CUDA Driver
  participant GPU as GPU Kernel

  Host->>Selector: prepare RunnerParams & SelectKernelParams
  Note right of Selector: decide tileSizeQ, scheduler,<br/>maskType, kernelType, numTokensPerPage
  Selector->>Loader: request kernel by hash (hashFromRunnerParams)
  alt cached
    Loader-->>Selector: return CUfunction + KernelMeta (cached)
  else not cached
    Loader->>CUDA: load module / cubin
    CUDA-->>Loader: CUfunction + KernelMeta
    Loader-->>Selector: cache & return CUfunction + KernelMeta
  end
  Host->>CUDA: configure grid/cluster using CtaLaunchParams
  CUDA->>GPU: launch kernel with smem per KernelMeta
  GPU-->>Host: execution complete
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • cyx-6
  • aleozlx
  • joker-eph
  • wenscarl
  • nvmbreughe
  • djmmoss
  • bkryu

Poem

🐰 I hopped through kernels, hashes, and queues,
CtaLaunchParams snug in my paws and shoes.
Causal masks whispered while modules load fast,
Tiles and tokens arranged to hold fast.
Hop—FMHA refined, now let's run to the grass! 🥕

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly describes the main objective: adding optimized trtllm-gen decode kernels for high throughput and speculative decoding, which directly aligns with the changeset. |
| Description check | ✅ Passed | The description includes all required template sections: detailed explanation of changes, benchmark results demonstrating improvements, completed pre-commit checklist, and test verification. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 80.00%, which meets the required threshold of 80.00%. |




@gemini-code-assist
Contributor

Summary of Changes

Hello @PerkzZheng, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of decode attention kernels, particularly for high-throughput scenarios involving large batch sizes and speculative decoding. By introducing optimized kernels and a refined dynamic kernel selection strategy, the changes aim to improve efficiency and reduce latency in generation-phase operations. The update also includes adjustments to mask handling, expanded kernel parameterization, and new test cases to ensure robustness and validate the reported speedups.

Highlights

  • Optimized Attention Kernels: Introduced new decode attention kernels specifically designed for high-throughput (large batch size) and speculative decoding (seqlen_q > 1), leading to significant speedups as demonstrated in the PR description benchmarks.
  • Dynamic Kernel Selection Logic: Implemented a more sophisticated kernel selection mechanism for generation-phase attention, including heuristics for MLA and GQA kernels, and an experimental kernel-timing model for grouping tokens and heads into a single CTA to find the optimal tileSizeQ.
  • Mask Type Refinement: Changed the default attention mask type for generation kernels from Dense to Causal in the launcher, with clarification that for single-token queries, a dense mask behaves like a causal mask for performance.
  • Kernel Parameterization Enhancements: Extended kernel parameters to include tileSizeQ and numTokensPerCtaQ, and introduced a FastModDivInt32 utility for efficient division operations within kernel parameters.
  • Artifact Updates: Updated compiled kernel artifacts and their checksums to reflect the new optimizations and kernel changes.
  • Test Coverage: Added new test cases to validate the performance and correctness of the batch decode attention with various configurations, specifically for head_dim_256.



@PerkzZheng PerkzZheng changed the title [TRTLLM-Gen Fmha] update trtllm-gen to support groups tokens and headsQ [TRTLLM-Gen Fmha] add optimized trtllm-gen decode kernels for high throughput + speculative decoding Dec 24, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant optimizations for speculative decoding in TRT-LLM's FMHA kernels by adding support for grouping tokens and heads. The changes are extensive, involving a major refactoring of the kernel selection logic, hash calculation, and kernel parameter structures. The introduction of a more modular, heuristic-based kernel selection mechanism is a notable improvement. My review focuses on a potential bug in the new performance heuristic and a minor improvement in a math utility function.

Comment on lines +559 to +655
  void selectTileSizeQForGqaGeneration(RunnerParams const& params,
                                       SelectKernelParams& selectKernelParams) const {
    // Define the per-tile mainloop cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelMainloopCost = {
        {128, 2.2},  // Cost factor when tileSizeQ = 128
        {64, 1.68},  // Cost factor when tileSizeQ = 64
        {32, 1.48},  // Cost factor when tileSizeQ = 32
        {16, 1.2},   // Cost factor when tileSizeQ = 16
        {8, 1.0}     // Cost factor when tileSizeQ = 8
    };

    // Define the per-tile reduction cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelReductionCost = {
        {128, 1.32},  // Reduction cost factor when tileSizeQ = 128
        {64, 1.2},    // Reduction cost factor when tileSizeQ = 64
        {32, 1.08},   // Reduction cost factor when tileSizeQ = 32
        {16, 1.03},   // Reduction cost factor when tileSizeQ = 16
        {8, 1.0}      // Reduction cost factor when tileSizeQ = 8
    };

    // The reduction cost emulated as a sequence length factor.
    float const kernelReductionSeqLenFactor = 128.0f;

    // The parameters for launching the kernel.
    CtaLaunchParams ctaLaunchParams;
    // The copy of the selectKernelParams, which makes sure it won't modify the original
    // selectKernelParams when computing the number of CTAs.
    SelectKernelParams selectKernelParamsCopy = selectKernelParams;
    // Load the kernel.
    auto [func, kernelMeta] = loadKernel(params, selectKernelParamsCopy);
    // Compute numCtasX, numCtasY and numCtasZ.
    computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

    // If there are no free SMs or tileSizeQ is already the smallest one, skip the heuristic
    // selection.
    if (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ * 2 >
            params.mMultiProcessorCount ||
        selectKernelParamsCopy.mTileSizeQ <= 8) {
      // No need to select the kernel further.
      return;
    }

    // Candidate tile sizes for tileSizeQ to explore.
    int const candidateTileSizesQ[] = {128, 64, 32, 16, 8};

    // The default tileSizeQ.
    int defaultTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
    // The selected tileSizeQ.
    int selectedTileSizeQ = selectKernelParamsCopy.mTileSizeQ;

    // The minimum modeling kernel time.
    float globalModelingKernelTime = FLT_MAX;
    // Loop over each candidate tile size.
    for (int tileSizeQ : candidateTileSizesQ) {
      // Only consider candidates <= default tileSizeQ.
      if (tileSizeQ > defaultTileSizeQ) {
        continue;
      }

      // Compute the number of CTAs.
      computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

      // Compute the seqLenPerCtaKv.
      int32_t seqLenPerCtaKv =
          flashinfer::ceil_div(flashinfer::ceil_div(params.mMaxSeqLenKv, kernelMeta.mStepKv),
                               ctaLaunchParams.mMaxNumCtasKv) *
          kernelMeta.mStepKv;

      // Compute the modeling kernel time = mainloop cost + reduction cost.
      float modelingKernelTime = kernelMainloopCost[tileSizeQ] * seqLenPerCtaKv +
                                 kernelReductionCost[tileSizeQ] * kernelReductionSeqLenFactor *
                                     ctaLaunchParams.mMaxNumCtasKv;

      // Compute the total number of CTAs.
      int32_t numCtas =
          ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ;
      // Compute the number of waves.
      int32_t numWaves = flashinfer::ceil_div(numCtas, params.mMultiProcessorCount);
      // Compute the total modeling kernel time.
      modelingKernelTime *= numWaves;

      // If this candidate has a lower time than the global minimum, update the global minimum.
      if (modelingKernelTime < globalModelingKernelTime) {
        globalModelingKernelTime = modelingKernelTime;
        selectedTileSizeQ = tileSizeQ;
      }
    }

    // Update the tileSizeQ.
    selectKernelParams.mTileSizeQ = selectedTileSizeQ;
    // Update the kernel type.
    if (selectKernelParams.mTileSizeQ >= 64) {
      selectKernelParams.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
    } else {
      selectKernelParams.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
    }
  }

Severity: high

The heuristic for selecting tileSizeQ in selectTileSizeQForGqaGeneration appears to have a logical flaw. The kernelMeta and ctaLaunchParams are computed once before the loop that iterates through candidate tileSizeQ values. However, inside the loop, computeCtaAndClusterConfig is called again with the same stale kernelMeta, which means ctaLaunchParams is not updated based on the candidate tileSizeQ. This leads to an incorrect cost model evaluation, as the launch configuration (number of CTAs, waves, etc.) doesn't change with the tile size being evaluated.

To correctly evaluate each candidate tileSizeQ, the corresponding kernelMeta should be retrieved, and the launch parameters should be recomputed within the loop. This can be done efficiently by looking up the kernel metadata from mKernelMetaMap without fully reloading the kernel function in each iteration. I've provided a suggested refactoring of the function to address this.

  void selectTileSizeQForGqaGeneration(RunnerParams const& params,
                                       SelectKernelParams& selectKernelParams) const {
    // Define the per-tile mainloop cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelMainloopCost = {
        {128, 2.2},  // Cost factor when tileSizeQ = 128
        {64, 1.68},  // Cost factor when tileSizeQ = 64
        {32, 1.48},  // Cost factor when tileSizeQ = 32
        {16, 1.2},   // Cost factor when tileSizeQ = 16
        {8, 1.0}     // Cost factor when tileSizeQ = 8
    };

    // Define the per-tile reduction cost model for different tileSizeQ choices.
    std::unordered_map<int, float> kernelReductionCost = {
        {128, 1.32},  // Reduction cost factor when tileSizeQ = 128
        {64, 1.2},    // Reduction cost factor when tileSizeQ = 64
        {32, 1.08},   // Reduction cost factor when tileSizeQ = 32
        {16, 1.03},   // Reduction cost factor when tileSizeQ = 16
        {8, 1.0}      // Reduction cost factor when tileSizeQ = 8
    };

    // The reduction cost emulated as a sequence length factor.
    float const kernelReductionSeqLenFactor = 128.0f;

    // The parameters for launching the kernel.
    CtaLaunchParams ctaLaunchParams;
    // The copy of the selectKernelParams, which makes sure it won't modify the original
    // selectKernelParams when computing the number of CTAs.
    SelectKernelParams selectKernelParamsCopy = selectKernelParams;

    // Candidate tile sizes for tileSizeQ to explore.
    int const candidateTileSizesQ[] = {128, 64, 32, 16, 8};

    // The default tileSizeQ.
    int defaultTileSizeQ = selectKernelParamsCopy.mTileSizeQ;
    // The selected tileSizeQ.
    int selectedTileSizeQ = selectKernelParamsCopy.mTileSizeQ;

    // The minimum modeling kernel time.
    float globalModelingKernelTime = FLT_MAX;
    // Loop over each candidate tile size.
    for (int tileSizeQ : candidateTileSizesQ) {
      // Only consider candidates <= default tileSizeQ.
      if (tileSizeQ > defaultTileSizeQ) {
        continue;
      }

      selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
      if (tileSizeQ >= 64) {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
      } else {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
      }

      // Find kernel meta info without loading the kernel function
      auto [hashId, info] = hashFromRunnerParams(params, selectKernelParamsCopy);
      auto const findMetaIter = mKernelMetaMap.find(hashId);
      if (findMetaIter == mKernelMetaMap.end()) {
        continue;  // No kernel available for this tile size
      }
      auto const& kernelMeta = mKernelMeta[findMetaIter->second];

      // Compute the number of CTAs.
      computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

      // If there are no free SMs, this tile size is not a good candidate.
      if (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ * 2 >
              params.mMultiProcessorCount &&
          tileSizeQ > 8) {  // allow smallest tile size to be selected even if it oversubscribes
        continue;
      }

      // Compute the seqLenPerCtaKv.
      int32_t seqLenPerCtaKv =
          flashinfer::ceil_div(flashinfer::ceil_div(params.mMaxSeqLenKv, kernelMeta.mStepKv),
                               ctaLaunchParams.mMaxNumCtasKv) *
          kernelMeta.mStepKv;

      // Compute the modeling kernel time = mainloop cost + reduction cost.
      float modelingKernelTime = kernelMainloopCost.at(tileSizeQ) * seqLenPerCtaKv +
                                 kernelReductionCost.at(tileSizeQ) * kernelReductionSeqLenFactor *
                                     ctaLaunchParams.mMaxNumCtasKv;

      // Compute the total number of CTAs.
      int32_t numCtas =
          ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ;
      // Compute the number of waves.
      int32_t numWaves = flashinfer::ceil_div(numCtas, params.mMultiProcessorCount);
      // Compute the total modeling kernel time.
      modelingKernelTime *= numWaves;

      // If this candidate has a lower time than the global minimum, update the global minimum.
      if (modelingKernelTime < globalModelingKernelTime) {
        globalModelingKernelTime = modelingKernelTime;
        selectedTileSizeQ = tileSizeQ;
      }
    }

    // Update the tileSizeQ.
    selectKernelParams.mTileSizeQ = selectedTileSizeQ;
    // Update the kernel type.
    if (selectKernelParams.mTileSizeQ >= 64) {
      selectKernelParams.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
    } else {
      selectKernelParams.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
    }
  }

Contributor Author


It makes sense and I have fixed it.

  }

  int32_t ceilLog2(int32_t value) const {
    return static_cast<int32_t>(std::ceil(std::log2(value)));

Severity: medium

The ceilLog2 implementation uses floating-point math (std::log2), which can be less efficient and potentially have precision issues for integer inputs. A bitwise implementation would be more robust and performant, especially if this function is ever used in a performance-critical path. Consider replacing it with a portable integer-based implementation.

    if (value <= 0) return 0; // Match behavior of std::log2 for non-positive values
    int32_t result = 0;
    int32_t temp = value - 1;
    while (temp > 0) {
        temp >>= 1;
        result++;
    }
    return result;

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)

497-498: Remove extraneous semicolon.

There's a double semicolon on line 498 after the ceil_div call.

🔎 Proposed fix
-    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
-    ;
+    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);

558-655: Cost model uses undocumented magic numbers.

The kernelMainloopCost and kernelReductionCost maps use empirical constants (2.2, 1.68, 1.48, etc.) without documentation on how they were derived. Consider adding a brief comment explaining these are benchmarked/profiled values or referencing the source.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0ccf4e3 and c5b673a.

📒 Files selected for processing (6)
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_trtllm_gen_attention.py
🧰 Additional context used
🧬 Code graph analysis (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h (1)
  • ceilDiv (42-44)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (15)
flashinfer/artifacts.py (1)

90-90: LGTM! Artifact path and checksum updated consistently.

The TRTLLM_GEN_FMHA artifact path and its corresponding checksum are updated together, which is correct. These changes align with the PR objectives to integrate optimized decode attention kernels, and the existing verify_cubin logic (line 224) will validate the checksum during artifact download.

Also applies to: 110-110

tests/attention/test_trtllm_gen_attention.py (1)

1282-1284: AI summary inconsistency: function name and parameter values are incorrect.

The AI summary states these changes are for test_trtllm_batch_decode_head_dim_256, but they are actually added to test_trtllm_batch_decode_long_sequence_length (line 1300). Additionally, the parameter interpretation is incorrect—the actual values are (batch_size=32, q_len_per_req=[4,8,16], page_size=16, num_kv_heads=2, head_grp_size=8), not what the summary describes.

The test additions themselves look good. They appropriately expand coverage for long-sequence scenarios with moderate batch sizes (32) and varying speculative decoding lengths (q_len_per_req of 4, 8, 16), which aligns well with the PR objectives to optimize for high-throughput workloads and speculative decoding.

include/flashinfer/trtllm/fmha/kernelParams.h (3)

282-301: LGTM – logic for grouping tokens and heads per CTA is clear.

The conditional handling for groupsTokensHeadsQ correctly computes numTokensPerCtaQ when grouping is enabled, with appropriate padding comments. The tuple return value now includes numTokensPerCtaQ for downstream use.


594-596: Helpful debugging guidance added.

The comment explaining that TMA descriptor errors may be caused by previous kernels and suggesting CUDA_LAUNCH_BLOCKING or cuda-gdb is useful for debugging.


801-803: Verify FastModDivInt32 handles mNumHeadsQPerKv = 1 safely.

Given the potential edge case with ceilLog2(1) mentioned earlier, ensure that when options.mNumHeadsQPerKv == 1, the FastModDivInt32 constructor doesn't produce invalid values.

csrc/trtllm_fmha_kernel_launcher.cu (1)

162-166: LGTM – mask type change with clear rationale.

The switch from Dense to Causal for generation is well-documented. The comment clarifies that kernel naming retains "dense" for performance reasons while the behavior is equivalent when each CTA processes a single tokenQ.

include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (2)

351-361: New kernel selection parameters added.

The mNumTokensPerPage and mTileSizeQ fields extend the kernel selection API to support the new grouping and tile sizing logic. These integrate well with the updated hashID and kernel selection flow in fmhaKernels.cuh.


376-382: Constructor properly initializes new fields.

mNumTokensPerPage is propagated from input params, and mTileSizeQ/mTileSizeKv are initialized to 128, which is consistent with the largest supported tile size. The initialization order matches the declaration order.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (7)

54-62: LGTM – SM family/specific pair detection.

The isFamilySpecificSMPair helper correctly identifies when SM values are family/specific variants (e.g., kSM_100f with kSM_100 or kSM_103), enabling proper hash conflict resolution.


76-92: LGTM – CtaLaunchParams encapsulates launch configuration.

The new struct cleanly groups related launch parameters, improving code organization and reducing parameter passing overhead.


117-133: LGTM – Hash conflict resolution prefers specific SM.

The logic correctly handles hash conflicts between family and specific SM versions, preferring the specific version (e.g., kSM_100 or kSM_103 over kSM_100f).


399-402: LGTM – Factor of 2 for reduction overhead.

The comment clearly explains the factor of 2 is applied to balance the reduction overhead against mainloop benefits.


768-829: LGTM – On-demand kernel loading with caching.

The refactored loadKernel method properly caches modules and functions, handles shared memory configuration for large kernels (≥48KB), and includes helpful error messages.


334-345: LGTM – Generation kernel CTA handling for spec-decoding.

The logic correctly distinguishes between groupsTokensHeadsQ enabled/disabled scenarios, computing numCtasPerSeqQ appropriately for speculative decoding workloads.


583-590: Line 588 is not within the heuristic selection loop—it's the initial kernel load before the loop begins. The subsequent heuristic loop (lines 613+) only iterates over candidate tile sizes using the already-loaded kernelMeta; it does not call loadKernel again. If loadKernel fails due to a missing kernel, that's a legitimate error condition (kernel hash not found) that should be reported, not silently caught during heuristic probing.

@PerkzZheng PerkzZheng marked this pull request as draft December 24, 2025 15:12

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
include/flashinfer/trtllm/fmha/kernelParams.h (2)

53-55: Consider replacing floating-point ceilLog2 with a bitwise implementation.

The current implementation uses std::log2, which is less efficient and may have precision issues for integer inputs. A bitwise approach would be more robust and performant.

🔎 Proposed bitwise implementation
  int32_t ceilLog2(int32_t value) const {
-   return static_cast<int32_t>(std::ceil(std::log2(value)));
+   if (value <= 1) return 0;
+   int32_t result = 0;
+   int32_t temp = value - 1;
+   while (temp > 0) {
+     temp >>= 1;
+     result++;
+   }
+   return result;
  }

Based on past review comments.


46-50: Critical: Edge case divisor == 1 leads to negative mShift causing undefined behavior.

When divisor = 1, ceilLog2(1) returns 0, so mShift = 0 - 1 = -1. Negative shift values cause undefined behavior in the multiplier calculation (uint64_t(1) << (32 + mShift) at line 49) and may break downstream usages of mShift.

🔎 Proposed fix
  FastModDivInt32(int32_t divisor) : mDivisor(divisor) {
+   if (divisor == 1) {
+     mShift = 0;
+     mMultiplier = 1;
+     return;
+   }
    mShift = ceilLog2(mDivisor) - 1;
    mMultiplier = static_cast<uint32_t>(
        flashinfer::ceil_div(uint64_t(1) << (32 + mShift), static_cast<uint64_t>(mDivisor)));
  }

Based on past review comments.

🧹 Nitpick comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (2)

559-665: The loop now correctly reloads kernel metadata for each candidate tileSizeQ.

The function addresses the prior review concern by:

  • Updating selectKernelParamsCopy.mTileSizeQ at line 618 for each candidate
  • Reloading the kernel with loadKernel at line 626 to get the correct kernelMeta
  • Recomputing CTA configuration at line 629 with the updated metadata

However, loadKernel (line 787) throws an exception if a kernel for a particular tileSizeQ doesn't exist. Consider wrapping the loadKernel call in a try-catch to gracefully skip unavailable candidates rather than aborting the entire selection.

🔎 Optional: graceful handling of missing kernels
    for (int tileSizeQ : candidateTileSizesQ) {
      // Only consider candidates <= default tileSizeQ.
      if (tileSizeQ > defaultTileSizeQ) {
        continue;
      }

      selectKernelParamsCopy.mTileSizeQ = tileSizeQ;
      if (tileSizeQ >= 64) {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration;
      } else {
        selectKernelParamsCopy.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration;
      }

-     // Load the kernel.
-     std::tie(func, kernelMeta) = loadKernel(params, selectKernelParamsCopy);
+     // Load the kernel. Skip if not available.
+     try {
+       std::tie(func, kernelMeta) = loadKernel(params, selectKernelParamsCopy);
+     } catch (...) {
+       continue;  // Skip candidates without available kernels
+     }

      // Compute the number of CTAs.
      computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy);

496-498: Optional: Remove stray semicolon.

Line 498 has an extra semicolon (;;) after the maxNumCtasPerSeqKv calculation, which is harmless but reduces code cleanliness.

🔎 Proposed fix
    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
-   ;
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c5b673a and e7f9ba0.

📒 Files selected for processing (2)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
include/flashinfer/trtllm/fmha/kernelParams.h (6)

148-150: LGTM.

The new mInflateMax field is clearly documented and adds support for adjusting max value inflation during iterations.


173-174: Verify default initialization {1} given FastModDivInt32 edge case.

The default initialization FastModDivInt32 mNumHeadsQPerKvDivisor{1} triggers the divisor == 1 edge case flagged earlier (leading to mShift = -1). Ensure this default is intentional and that the critical fix for FastModDivInt32 constructor is applied to prevent undefined behavior.


205-207: LGTM.

The signature expansion to include groupsTokensHeadsQ aligns with the new per-CTA tokenization logic described in the PR summary.


296-296: LGTM.

The return tuple now includes numTokensPerCtaQ, consistent with the expanded per-CTA tokenization logic.


639-641: LGTM.

The call site correctly passes the new groupsTokensHeadsQ parameter and captures the expanded return tuple including numTokensPerCtaQ.


796-798: LGTM.

The new fields mNumHeadsQPerKvDivisor and mNumTokensPerCtaQ are correctly initialized from options and the computed numTokensPerCtaQ value.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (10)

54-62: LGTM.

The isFamilySpecificSMPair function correctly identifies when two SM values represent family/specific architecture pairs (e.g., SM_100f with SM_100 or SM_103), supporting graceful handling of Blackwell GPU variants.


76-93: LGTM.

The new CtaLaunchParams struct is a clean refactor that groups related kernel launch parameters, replacing multiple tuple returns with a more maintainable structure.


117-134: LGTM.

The hash conflict detection correctly allows only family/specific SM pairs to share a hash, and sensibly prefers specific SM versions (e.g., SM_100 over SM_100f) when both are present.


140-183: LGTM.

The expanded hashID signature and bit layout correctly incorporate the new kernel selection parameters (tileSizeQ, reuseSmemKForV, uses2CtaMma, sparseMla) with appropriate assertions and bit-packing.


511-556: LGTM.

The selectMlaGenerationKernel heuristic appropriately selects between low-latency (SwapsMmaAbForGeneration) and high-throughput (KeepsMmaAbForGeneration) kernels based on numHeadsQPerKv and GPU utilization, with clear logic for enabling 2-CTA MMA mode.


667-707: LGTM.

The selectGqGenerationKernel function uses clear threshold-based heuristics for selecting tileSizeQ and kernelType, and appropriately delegates to the cost-model-based selectTileSizeQForGqaGeneration when maxSeqLenQ > 1 for speculative decoding.


399-402: LGTM.

The factor-of-2 adjustment in maxNumCtasPerSeqKv is well-justified by the comment: it prevents splitting KV sequences so finely that reduction overhead exceeds mainloop speedup benefits.


477-484: LGTM.

The refactor to update CtaLaunchParams in place (lines 478–483) is clean and consistent with the new struct-based parameter passing introduced in this PR.


777-839: LGTM.

The loadKernel refactor implements clean on-demand kernel loading with two-level caching (modules and functions), appropriate shared memory configuration for large allocations (≥48KB), and clear error messages when kernels are not found.


204-314: LGTM.

The run() method correctly integrates the new CtaLaunchParams struct and on-demand loadKernel approach, replacing the previous tuple-based parameter passing with a cleaner, more maintainable flow.

Comment on lines 280 to 294
if (groupsHeadsQ) {
if (isSpecDecodingGenerationKernel(options.mKernelType)) {
FLASHINFER_CHECK((tileSizeQ % numGroupedHeads == 0), "internal error");
tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
static_cast<uint32_t>(numGroupedHeads), 1,
static_cast<uint32_t>(tileSizeQ / numGroupedHeads)};
if (groupsTokensHeadsQ) {
// Currently, it requires each CTA to process complete headsQ (i.e. numGroupedHeads) at a
// time, so it allows paddings in the end. Removing paddings needs re-organizing the Q
// tensor to [numTokensQ, numGroupedHeads, numHeads, headDimQ] and we might want to revisit
// this in the future.
numTokensPerCtaQ = static_cast<int32_t>(numTokensPerCtaQ / numGroupedHeads);
} else {
tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
static_cast<uint32_t>(tileSizeQ), 1, 1};
numGroupedHeads = tileSizeQ;
numTokensPerCtaQ = 1;
}
tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
static_cast<uint32_t>(numGroupedHeads), 1,
static_cast<uint32_t>(numTokensPerCtaQ)};
}
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# Search for makeTmaShapeStrideQ definition and context
rg -nP 'makeTmaShapeStrideQ' --type=cpp -B5 -A20

Repository: flashinfer-ai/flashinfer

Length of output: 7590


🏁 Script executed:

# Search for tileSizeQ and numGroupedHeads initialization/assignment
rg -nP '\btileSizeQ\s*=' --type=cpp -B2 -A2 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look for the complete function containing this code
fd -e h | xargs rg -l 'groupsHeadsQ\|groupsTokensHeadsQ' | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check for any assertions or validation around division
rg -nP 'numTokensPerCtaQ|numGroupedHeads' include/flashinfer/trtllm/fmha/kernelParams.h -B2 -A2

Repository: flashinfer-ai/flashinfer

Length of output: 2957


🏁 Script executed:

# Search for tileSizeQ/mTileSizeQ definition and constraints
rg -nP 'mTileSizeQ|tileSizeQ' --type=cpp -B3 -A3 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 5598


🏁 Script executed:

# Search for kernel metadata definition and mTileSizeQ initialization
fd -e h -e cpp | xargs rg -l 'mTileSizeQ' | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 163


🏁 Script executed:

# Check if there are any constraints on tileSizeQ values in test files or kernel selection logic
rg -nP 'mTileSizeQ.*=' --type=cpp | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for kernel metadata setup and how mTileSizeQ gets assigned/constrained
rg -nP 'mGroupsTokensHeadsQ|groupsTokensHeadsQ' --type=cpp -B5 -A5 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 3587


🏁 Script executed:

# Look for any validation or constraints on tileSizeQ relative to mNumHeadsQPerKv
rg -nP 'numHeadsQPerKv|mNumHeadsQPerKv' --type=cpp | grep -i 'tile\|size' | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 593


🏁 Script executed:

# Check if there are tests or examples that show tileSizeQ and mNumHeadsQPerKv combinations
fd -e cu -e cpp | xargs rg -l 'mTileSizeQ.*=' | head -3

Repository: flashinfer-ai/flashinfer

Length of output: 89


🏁 Script executed:

# Search for where kernel metadata is created/selected and if there are constraints between tileSizeQ and mNumHeadsQPerKv
fd -e cpp -e cu | xargs rg -l 'mGroupsTokensHeadsQ.*=' | head -5

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check kernel selection logic and metadata initialization
rg -nP 'KernelMeta|kernelMeta.*{' --type=cpp -A10 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 4231


🏁 Script executed:

# Look for any divmod, remainder, or modulo checks related to tileSizeQ or numGroupedHeads
rg -nP '%|divmod' include/flashinfer/trtllm/fmha/kernelParams.h -B3 -A3

Repository: flashinfer-ai/flashinfer

Length of output: 1465


Add a runtime assertion or comment to document the divisibility constraint at line 286.

At line 286, numTokensPerCtaQ is computed as tileSizeQ / numGroupedHeads and cast to int32_t, which silently truncates if the division is inexact. The comment acknowledges that padding is "currently required" and suggests revisiting this design. While the current implementation tolerates padding, there should be either:

  • An explicit FLASHINFER_CHECK(tileSizeQ % numGroupedHeads == 0, ...) to catch misconfigured kernel metadata, or
  • A clearer comment documenting that tileSizeQ must be divisible by numGroupedHeads when groupsTokensHeadsQ is true

This aligns with existing validation patterns in the codebase (e.g., line 225 for head counts, line 806 for sparse MLA top-k).

@PerkzZheng PerkzZheng marked this pull request as ready for review December 24, 2025 15:15
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

593-604: Out-of-bounds access when dim < 5 in error logging.

The error logging code accesses fixed indices [0-4] for shapes, tileShapes, and tileStrides, and [0-3] for stridesInBytes. However, dim can be 2–5 (per the check at line 555), so accessing index 4 when dim is 2, 3, or 4 is undefined behavior.

🔎 Proposed fix using a loop or conditional printing
-      std::cerr << "Shape: " << shapes[0] << " " << shapes[1] << " " << shapes[2] << " "
-                << shapes[3] << " " << shapes[4] << std::endl;
-      std::cerr << "Stride: " << stridesInBytes[0] << " " << stridesInBytes[1] << " "
-                << stridesInBytes[2] << " " << stridesInBytes[3] << std::endl;
-      std::cerr << "tileShapes: " << tileShapes[0] << " " << tileShapes[1] << " " << tileShapes[2]
-                << " " << tileShapes[3] << " " << tileShapes[4] << std::endl;
-      std::cerr << "tileStrides: " << tileStrides[0] << " " << tileStrides[1] << " "
-                << tileStrides[2] << " " << tileStrides[3] << " " << tileStrides[4] << std::endl;
+      std::cerr << "Shape:";
+      for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << shapes[ii];
+      std::cerr << std::endl;
+      std::cerr << "Stride:";
+      for (int32_t ii = 0; ii < dim - 1; ++ii) std::cerr << " " << stridesInBytes[ii];
+      std::cerr << std::endl;
+      std::cerr << "tileShapes:";
+      for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << tileShapes[ii];
+      std::cerr << std::endl;
+      std::cerr << "tileStrides:";
+      for (int32_t ii = 0; ii < dim; ++ii) std::cerr << " " << tileStrides[ii];
+      std::cerr << std::endl;
♻️ Duplicate comments (1)
include/flashinfer/trtllm/fmha/kernelParams.h (1)

280-294: Add a runtime assertion for the divisibility constraint.

The division at line 286 can silently truncate if tileSizeQ is not evenly divisible by numGroupedHeads. While the comment documents that padding is "currently required," adding an explicit FLASHINFER_CHECK would catch misconfigured kernel metadata early and align with validation patterns elsewhere in this file (e.g., line 225, line 806).

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e7f9ba0 and e2734cd.

📒 Files selected for processing (1)
  • include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/trtllm/fmha/kernelParams.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (7)
include/flashinfer/trtllm/fmha/kernelParams.h (7)

44-62: LGTM - FastModDivInt32 implementation correctly addresses prior feedback.

The edge case for divisor == 1 is now handled via std::max(ceilLog2(mDivisor) - 1, 0), and the code reuses flashinfer::ceil_div instead of a local duplicate. The mAdd field is correctly left at 0 per the CCCL fast modulo division algorithm for signed types.


148-150: LGTM - New inflation parameter is well-documented.

The mInflateMax field has a clear comment explaining its purpose. The field will be zero-initialized via memset in setKernelParams.


173-180: LGTM - New kernel parameters for token/head grouping.

The mNumHeadsQPerKvDivisor is correctly typed as FastModDivInt32{1} for fast modulo operations, and mNumTokensPerCtaQ supports the new grouping kernel feature.


205-207: LGTM - Clean API extension.

The new groupsTokensHeadsQ parameter cleanly extends the function signature to support the new token/head grouping mode.


589-591: Helpful debugging guidance added.

The updated comment clarifying that errors may originate from previous kernels and suggesting CUDA_LAUNCH_BLOCKING or cuda-gdb is a useful addition for debugging TMA initialization failures.


639-641: LGTM - Call site correctly updated.

The structured binding now captures the new numTokensPerCtaQ return value, and kernelMeta.mGroupsTokensHeadsQ is passed as the new parameter.


796-798: LGTM - Kernel parameters correctly assigned.

The mNumHeadsQPerKvDivisor is initialized from options.mNumHeadsQPerKv, and mNumTokensPerCtaQ is propagated from the makeTmaShapeStrideQ return value.

@yzh119
Collaborator

yzh119 commented Dec 25, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !218 has been created, and the CI pipeline #40778397 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #40778397: 1/20 passed

@yzh119 yzh119 requested a review from nvmbreughe as a code owner January 7, 2026 08:31
@yzh119
Collaborator

yzh119 commented Jan 7, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !218 has been updated with latest changes, and the CI pipeline #41270824 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 left a comment

Should be ready to merge as long as CI passed.

@yzh119 yzh119 mentioned this pull request Jan 7, 2026
5 tasks
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @include/flashinfer/trtllm/fmha/fmhaKernels.cuh:
- Around line 783-788: The code dereferences findMetaIter (used to set
metaIndex) before checking its validity, which can cause UB if the key is
absent; to fix, first check that findMetaIter != mKernelMetaMap.end() using the
existing FLASHINFER_CHECK (or an explicit if and error) and only after that
assign metaIndex = findMetaIter->second; update the block around findMetaIter,
metaIndex, mKernelMetaMap and FLASHINFER_CHECK (and preserve the error message
using info) so the iterator is validated prior to dereference.
🧹 Nitpick comments (3)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (3)

498-499: Remove redundant semicolon.

Line 499 has a double semicolon which is a minor style issue.

Proposed fix
-    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
-    ;
+    int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);

559-600: Consider making cost model maps static const for efficiency.

The kernelMainloopCost and kernelReductionCost maps are constant and reconstructed on each function call. Making them static const would avoid repeated construction.

Proposed fix
   void selectTileSizeQForGqaGeneration(RunnerParams const& params,
                                        SelectKernelParams& selectKernelParams) const {
     // Define the per-tile mainloop cost model for different tileSizeQ choices.
-    std::unordered_map<int, float> kernelMainloopCost = {
+    static const std::unordered_map<int, float> kernelMainloopCost = {
         {128, 2.2},  // Cost factor when tileSizeQ = 128
         {64, 1.68},  // Cost factor when tileSizeQ = 64
         {32, 1.48},  // Cost factor when tileSizeQ = 32
         {16, 1.2},   // Cost factor when tileSizeQ = 16
         {8, 1.0}     // Cost factor when tileSizeQ = 8
     };

     // Define the per-tile reduction cost model for different tileSizeQ choices.
-    std::unordered_map<int, float> kernelReductionCost = {
+    static const std::unordered_map<int, float> kernelReductionCost = {
         {128, 1.32},  // Reduction cost factor when tileSizeQ = 128
         {64, 1.2},    // Reduction cost factor when tileSizeQ = 64
         {32, 1.08},   // Reduction cost factor when tileSizeQ = 32
         {16, 1.03},   // Reduction cost factor when tileSizeQ = 16
         {8, 1.0}      // Reduction cost factor when tileSizeQ = 8
     };

799-804: Remove unused capitalizeFirst lambda.

This lambda is defined but never used within the function; it appears to be left over from refactoring.

Proposed fix
       // Check if the module is already loaded.
       auto findModuleIter = mModules.find(kernelMeta.mFuncName);
-      auto capitalizeFirst = [](std::string str) {
-        if (!str.empty()) {
-          str[0] = std::toupper(str[0]);
-        }
-        return str;
-      };
       if (findModuleIter == mModules.end()) {
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4507598 and 78448ee.

📒 Files selected for processing (1)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (11)

55-63: LGTM!

The isFamilySpecificSMPair helper correctly implements symmetric detection of family/specific SM version pairs, enabling proper hash conflict resolution.


77-93: LGTM!

The CtaLaunchParams struct cleanly encapsulates kernel launch dimensions with appropriate documentation for each member.


118-134: LGTM!

Hash conflict handling correctly prioritizes specific SM versions (sm100, sm103) over the family version (sm100f), with appropriate validation that conflicts only occur between compatible SM variants.


141-184: LGTM!

The expanded hashID function correctly encodes the new kernel selection parameters into the 64-bit hash with proper bit allocations and validation checks.


203-315: LGTM!

The refactored run() method cleanly integrates the new kernel selection, on-demand loading, and CTA configuration workflow with proper deadlock detection and cluster configuration handling.


335-347: Approve the token/heads grouping logic for speculative decoding.

The logic correctly handles the groupsTokensHeadsQ flag to support speculative decoding scenarios where multiple tokens and heads are processed within a single CTA. The calculation ensures complete numHeadsQPerKv groups are processed together.


512-557: LGTM!

The MLA generation kernel selection heuristic is well-documented with clear conditions for choosing between low-latency and high-throughput kernel variants. As per coding guidelines, the performance-critical algorithmic choices are properly explained in comments.


613-666: LGTM!

The tile size selection loop correctly updates parameters for each candidate, computes CTA configuration, and selects the optimal tile size based on the cost model. The previously reported issue about not updating tileSizeQ per iteration has been fixed.


668-708: LGTM!

The GQA generation kernel selection correctly handles mixed precision cases separately and uses threshold-based tile size selection for standard cases, with cost-model-based selection for speculative decoding scenarios.


710-740: LGTM!

The selectKernel method provides a clean entry point that dispatches to appropriate kernel selection logic based on kernel type and handles mask type and page token configuration appropriately.


742-776: LGTM!

The hashFromRunnerParams method correctly constructs both the debug info string and hash ID with matching parameters for kernel lookup and debugging.

Comment on lines +783 to +788
auto const findMetaIter = mKernelMetaMap.find(hashId);
// The meta index.
auto const metaIndex = findMetaIter->second;

// Add debug info when kernels are not found.
FLASHINFER_CHECK(findMetaIter != mKernelMetaMap.end(), "Trtllm-gen kernels not found: " + info);

⚠️ Potential issue | 🔴 Critical

Iterator dereferenced before validity check - undefined behavior if kernel not found.

Line 785 accesses findMetaIter->second before line 788 checks whether findMetaIter != mKernelMetaMap.end(). If the hash is not found in the map, dereferencing the end iterator causes undefined behavior.

Proposed fix
   std::pair<CUfunction, KernelMeta> loadKernel(RunnerParams const& params,
                                                SelectKernelParams& selectKernelParams) const {
     // Hash the runner params.
     auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams);
     auto const findMetaIter = mKernelMetaMap.find(hashId);
-    // The meta index.
-    auto const metaIndex = findMetaIter->second;

     // Add debug info when kernels are not found.
     FLASHINFER_CHECK(findMetaIter != mKernelMetaMap.end(), "Trtllm-gen kernels not found: " + info);

+    // The meta index.
+    auto const metaIndex = findMetaIter->second;
+
     // Load the function if not found.
     if (mFunctions.find(hashId) == mFunctions.end()) {
🤖 Prompt for AI Agents
In @include/flashinfer/trtllm/fmha/fmhaKernels.cuh around lines 783 - 788, The
code dereferences findMetaIter (used to set metaIndex) before checking its
validity, which can cause UB if the key is absent; to fix, first check that
findMetaIter != mKernelMetaMap.end() using the existing FLASHINFER_CHECK (or an
explicit if and error) and only after that assign metaIndex =
findMetaIter->second; update the block around findMetaIter, metaIndex,
mKernelMetaMap and FLASHINFER_CHECK (and preserve the error message using info)
so the iterator is validated prior to dereference.

@bkryu
Collaborator

bkryu commented Jan 7, 2026

/bot stop

@bkryu
Collaborator

bkryu commented Jan 7, 2026

/bot run

@flashinfer-bot
Collaborator

The GitLab CI pipeline #41270824 has been cancelled.

@flashinfer-bot
Collaborator

GitLab MR !218 has been updated with latest changes, and the CI pipeline #41294820 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator

bkryu commented Jan 7, 2026

@yzh119, the unit test failures on Blackwell cards are unrelated 👍

@yzh119
Collaborator

yzh119 commented Jan 7, 2026

The only failed public CI job ("script returned exit code 143") is not relevant: https://ci.tlcpack.ai/blue/organizations/jenkins/flashinfer-ci/detail/PR-2265/9/pipeline. Bypassing and merging to unblock the 0.6.0 release.
