
perf: Update trtllm-gen batched GEMM kernels - faster, more NVFP4 tile dims, MXFP8 with relu2 act#2667

Merged
aleozlx merged 6 commits into flashinfer-ai:main from
amitz-nv:update-batched-gemm-faster-gemms
Mar 3, 2026

Conversation

@amitz-nv
Contributor

@amitz-nv amitz-nv commented Mar 2, 2026

📌 Description

Updates the trtllm-gen batched GEMM kernel hashes and include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export directory for:

  • Faster batched GEMM kernels
  • Additional NVFP4 tile dims kernels
  • Additional MXFP8 with Relu^2 non-gated activation kernels
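For reference, the Relu^2 (squared-ReLU) non-gated activation named above computes max(x, 0)^2. A minimal scalar sketch, with an illustrative function name; the actual kernels fuse this elementwise into the MXFP8 GEMM epilogue:

```cpp
#include <algorithm>
#include <cassert>

// Squared ReLU: relu2(x) = max(x, 0)^2. Scalar reference only; the
// trtllm-gen kernels apply this in the batched GEMM epilogue.
float relu2(float x) {
  float r = std::max(x, 0.0f);
  return r * r;
}
```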

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores
    • Updated an internal batched matrix‑multiplication interface (call signature changed).
    • Refreshed native artifact references and checksums for GPU-accelerated components.
    • No end-user behavioral changes expected; these are internal compatibility and maintenance updates.

… NVFP4 tile dims, MXFP8 with relu2 non-gated activation

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
@coderabbitai
Contributor

coderabbitai bot commented Mar 2, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 351a374 and ca601c4.

📒 Files selected for processing (1)
  • csrc/trtllm_batched_gemm_runner.cu

📝 Walkthrough

Walkthrough

Updates the BatchedGemmInterface::run signature and call site by inserting a new parameter (passed as nullptr at the call site), and updates artifact paths and checksum values for TRTLLM batched GEMM artifacts.

Changes

Cohort / File(s) Summary
GEMM Runner API Update
csrc/trtllm_batched_gemm_runner.cu
Call site updated: BatchedGemmInterface::run now receives an additional pinnedHostBuffer argument (passed as nullptr) before the existing module cache parameter, reflecting a signature change.
Artifact Versioning
flashinfer/artifacts.py
Updated artifact paths: ArtifactPath.TRTLLM_GEN_BMM, ArtifactPath.TRTLLM_GEN_GEMM (added trailing slash) and checksum: CheckSumHash.TRTLLM_GEN_BMM changed to new values pointing to updated batched GEMM artifacts.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • cyx-6
  • IwakuraRein
  • bkryu

Poem

🐇 A tiny arg hops into place—
Before the cache, it finds its space.
Artifacts swapped, the hashes sing,
Batches ready, engines spring,
I nibble code and dance with grace. 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 64.71%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check: ✅ Passed. The title accurately and specifically describes the main changes: updating trtllm-gen batched GEMM kernels with performance improvements and new feature support.
  • Description check: ✅ Passed. The description covers the main changes and includes completed pre-commit and testing checklists, though the related issues section is left blank.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates significant updates to the trtllm-gen batched GEMM kernels, focusing on performance improvements, broader support for NVFP4 tile dimensions, and the addition of MXFP8 kernels with Relu^2 activation. A key architectural change involves the introduction of dynamic tile scheduling mechanisms, which enhance the kernel's adaptability and efficiency on modern GPU architectures. These changes collectively aim to optimize matrix multiplication operations critical for deep learning workloads.

Highlights

  • Performance Enhancements: Updated batched GEMM kernels for improved performance, as indicated by the PR title and artifact hash updates.
  • Expanded NVFP4 Support: Added support for more NVFP4 tile dimensions, enhancing flexibility for specific hardware configurations.
  • New MXFP8 Kernels: Introduced additional MXFP8 kernels with Relu^2 non-gated activation, expanding the range of supported operations.
  • Dynamic Tile Scheduling: Implemented dynamic tile scheduling capabilities, including new StaticPersistent and PersistentSm90 schedulers, and associated logic for managing dynamic tile counters.
  • Artifact Updates: Updated the trtllm-gen batched GEMM artifact hashes and paths, reflecting the new kernel versions.
  • Copyright Year Update: Updated the copyright year to 2026 across numerous header files within the trtllmGen_bmm_export directory.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • csrc/trtllm_batched_gemm_runner.cu
    • Modified the bmm.run function call to pass a nullptr for a newly introduced pinnedHostBuffer parameter.
  • flashinfer/artifacts.py
    • Updated the artifact path and checksum hash for TRTLLM_GEN_BMM.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
    • Added a mPtrDynamicTileCounter member to the BatchedGemmData struct.
    • Updated the run method signature to include a pinnedHostBuffer parameter.
    • Implemented logic for initializing the dynamic tile counter when dynamic scheduling is enabled.
    • Introduced a new getFixedGridDim method to calculate grid dimensions for fixed-size grids.
    • Modified getLaunchGrid and getNumCtas to account for fixed grid dimensions and new tile schedulers.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
    • Updated the copyright year to 2026.
    • Added a clcFastDrain boolean member to the BatchedGemmOptions struct and its constructor.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
    • Updated the copyright year to 2026.
    • Expanded the TileScheduler enum with StaticPersistent and PersistentSm90 options.
    • Added helper functions isPersistentScheduler and supportsCleanEarlyExit.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
    • Updated the copyright year to 2026.
    • Added clcFastDrain to the GemmOptions constructor and members.
    • Provided a toString specialization for CtaSwizzleType.
    • Updated dumpOptions to include the mClcFastDrain option.
    • Adjusted checkAndUpdateGemmOptions to use the new isPersistentScheduler helper and refined warning messages for persistent schedulers.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h
    • Updated the copyright year to 2026.
    • Extended setKernelParams to accept and assign a ptrDynamicTileCounter.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
    • Updated the copyright year to 2026.
    • Added a ptrDynamicTileCounter member to the KernelParams structure for dynamic tile scheduling.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
    • Updated the copyright year to 2026.
    • Initialized several member variables (mMmaKind, mFuseUtccpWithUtcmma, mUseMaxTmemOverlap, mNumEpilogueWarps) with default values.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h
    • Updated the copyright year to 2026.
    • Adjusted the cuFuncSetAttribute call by explicitly declaring CUresult.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h
    • Updated the copyright year to 2026.
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SparsityDecl.h
    • Updated the copyright year to 2026.
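The TileScheduler additions in the changelog above can be sketched as follows. This is a hedged reconstruction: the enum's other members and the exact helper logic are assumptions, and the real definitions live in Enums.h.

```cpp
#include <cassert>

// Illustrative only: member names besides the two new ones are assumed.
// StaticPersistent and PersistentSm90 are the additions this PR describes.
enum class TileScheduler { Static, Persistent, StaticPersistent, PersistentSm90 };

// Sketch of the isPersistentScheduler helper mentioned in the changelog:
// true for any scheduler variant that keeps CTAs resident across tiles.
inline bool isPersistentScheduler(TileScheduler s) {
  return s == TileScheduler::Persistent || s == TileScheduler::StaticPersistent ||
         s == TileScheduler::PersistentSm90;
}
```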
Activity
  • No human activity has been recorded on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review. Performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary. Provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist. Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help. Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@aleozlx aleozlx added v0.6.5 release blocker label for v0.6.5 run-ci labels Mar 2, 2026
@aleozlx
Collaborator

aleozlx commented Mar 2, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !362 has been created, and the CI pipeline #45166470 is currently running. I'll report back once the pipeline job completes.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the trtllm-gen batched GEMM kernels, introducing performance improvements and new features like additional tile dimensions for NVFP4 and MXFP8 with Relu^2 activation. The changes primarily involve updating kernel hashes and modifying the batched GEMM interface to support new dynamic scheduling features. While most of the changes correctly implement the new functionalities, I've identified a critical issue in BatchedGemmInterface.h concerning memory management and thread safety. A static local variable used for a pinned host buffer results in a memory leak and is not thread-safe. My review includes a comment with a suggested fix to address this by making the caller responsible for the buffer's lifecycle.

Comment on lines +569 to +583
static uint32_t* pinnedCounterValue_ = nullptr;
if (pinnedHostBuffer) {
  // Use the pre-allocated pinned host memory if provided.
  pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
}
// Use pinned host memory for the source value to support CUDA graph, or it silently breaks
// during CUDA graph replay and corrupts the data.
if (pinnedHostBuffer == nullptr && pinnedCounterValue_ == nullptr) {
  // Allocate pinned host memory if not provided.
  auto err = cudaMallocHost((void**)&pinnedCounterValue_, sizeof(uint32_t));
  if (err != cudaSuccess) {
    return 1;
  }
}
*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;

critical

The use of a static local variable pinnedCounterValue_ here introduces two critical issues:

  1. Memory Leak: The memory allocated with cudaMallocHost is never freed, leading to a memory leak. The static pointer will hold the allocated memory address for the lifetime of the application.
  2. Thread Safety: static local variables are shared across all calls to the function, even from different threads and on different object instances. This can lead to race conditions if multiple threads call this run method concurrently.

It's recommended to make the caller responsible for managing the lifecycle of this pinned buffer. This utility class should not be allocating memory it doesn't own.

A possible fix is to require pinnedHostBuffer to be non-null when dynamic scheduling is enabled and remove the allocation logic from this function. This will require the caller (TrtllmGenBatchedGemmRunner) to manage this buffer.

Suggested change
static uint32_t* pinnedCounterValue_ = nullptr;
if (pinnedHostBuffer) {
  // Use the pre-allocated pinned host memory if provided.
  pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
}
// Use pinned host memory for the source value to support CUDA graph, or it silently breaks
// during CUDA graph replay and corrupts the data.
if (pinnedHostBuffer == nullptr && pinnedCounterValue_ == nullptr) {
  // Allocate pinned host memory if not provided.
  auto err = cudaMallocHost((void**)&pinnedCounterValue_, sizeof(uint32_t));
  if (err != cudaSuccess) {
    return 1;
  }
}
*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;

if (!pinnedHostBuffer) {
  // A pinned host buffer must be provided for dynamic scheduling with PersistentSm90.
  return 1;
}
uint32_t* pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/trtllm_batched_gemm_runner.cu`:
- Around line 262-264: The call to bmm.run(...) is passing nullptr for the
pinnedHostBuffer parameter which works because the implementation
auto-allocates, but per BatchedGemmInterface's comment this breaks CUDA graph
replay when CUDA graphs are used with dynamic/persistent scheduling
(PersistentSm90); update the call site around bmm.run to: detect when CUDA graph
mode or dynamic scheduling/persistent SM (e.g., using enable_pdl or
multiProcessorCount flag or a dedicated isPersistentSm90 check) and in that case
require or allocate a real pinned host buffer and pass it instead of nullptr,
otherwise keep nullptr; alternatively add an explicit runtime check/ASSERT
logging that prevents using nullptr in graph+persistent modes to avoid silent
corruption.

In
`@include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h`:
- Around line 743-748: After rounding numCtasXy to clusterSizeXy granularity in
the fixed-grid calculation, guard against it becoming zero: ensure numCtasXy is
at least clusterSizeXy (or set to 1*clusterSizeXy) before multiplying by
multiProcessorOccupancy so the computed CTA grid is never zero; adjust the code
around the variables numCtasXy, clusterSizeXy, multiProcessorCount,
options.mNumSlicesForSplitK, and multiProcessorOccupancy to apply this min-check
prior to returning the tuple.
- Around line 569-586: Remove the static shared pinnedCounterValue_ and require
the caller to supply per-call pinnedHostBuffer: in run() fail early (return
error) if pinnedHostBuffer is null, then declare a stack-local pointer uint32_t*
pinnedCounterValue = reinterpret_cast<uint32_t*>(pinnedHostBuffer) (no
cudaMallocHost or static storage), write the counter into *pinnedCounterValue
and use that pointer as the source for cudaMemcpyAsync to
batchedGemmData.mInputBuffers.mPtrDynamicTileCounter; this guarantees each
invocation uses caller-provided isolated pinned memory and avoids races on
pinnedCounterValue_ across in-flight kernels.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c25bd14 and 8b0cf4b.

⛔ Files ignored due to path filters (7)
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CommonUtils.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaArchDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/CudaKernelLauncher.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/MmaDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SfLayoutDecl.h is excluded by !**/gen/**
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/SparsityDecl.h is excluded by !**/gen/**
📒 Files selected for processing (12)
  • csrc/trtllm_batched_gemm_runner.cu
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h
  • include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h

Comment on lines +262 to +264
auto const err =
    bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
            enable_pdl, nullptr, globalTrtllmGenBatchedGemmModuleCache);

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# Find the BatchedGemmInterface header file
find . -name "BatchedGemmInterface.h" -o -name "*BatchedGemmInterface*" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 152


🏁 Script executed:

# Search for the run() method signature in the include files
rg -n "void\s+run\s*\(" include/ csrc/ -A 5 -B 2 --type cpp --type h

Repository: flashinfer-ai/flashinfer

Length of output: 13052


🏁 Script executed:

# Look for the specific parameter that was added (before globalTrtllmGenBatchedGemmModuleCache)
rg -n "bmm\.run\s*\(" csrc/ -B 2 -A 2 --type cpp

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

cat -n include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Repository: flashinfer-ai/flashinfer

Length of output: 49914


🏁 Script executed:

# Also check the implementation file (if it exists) to understand the parameter usage
find . -name "*BatchedGemmInterface*" -type f

Repository: flashinfer-ai/flashinfer

Length of output: 152


🏁 Script executed:

# Find and read the trtllm_batched_gemm_runner.cu file around lines 262-264
fd trtllm_batched_gemm_runner.cu | xargs -I {} sh -c 'wc -l "{}"; head -280 "{}" | tail -50'

Repository: flashinfer-ai/flashinfer

Length of output: 2547


🏁 Script executed:

# Also search for where pinnedHostBuffer or pinned host memory is used/allocated in the codebase
rg -n "pinnedHostBuffer|pinnedCounterValue|cudaMallocHost" csrc/ include/ -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 4291


🏁 Script executed:

# Search for CUDA graph usage patterns
rg -n "cuGraphCreate|cudaGraphCreate|cuGraphLaunch|cudaGraphLaunch" csrc/ include/ -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check enable_pdl and scheduler usage to understand if CUDA graphs are relevant
rg -n "enable_pdl|PersistentSm90|cudaGraph" csrc/trtllm_batched_gemm_runner.cu -B 2 -A 2

Repository: flashinfer-ai/flashinfer

Length of output: 2752


Clarify the CUDA graph safety requirement for the nullptr argument.

The nullptr value passed to pinnedHostBuffer is handled by the implementation via automatic allocation, so it won't cause a runtime failure. However, the implementation's own code comment (lines 574–575 in BatchedGemmInterface.h) explicitly warns: "Use pinned host memory for the source value to support CUDA graph, or it silently breaks during CUDA graph replay and corrupts the data." If CUDA graphs are used with dynamic scheduling (PersistentSm90), a valid pinned host buffer must be provided instead of nullptr.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_batched_gemm_runner.cu` around lines 262 - 264, The call to
bmm.run(...) is passing nullptr for the pinnedHostBuffer parameter which works
because the implementation auto-allocates, but per BatchedGemmInterface's
comment this breaks CUDA graph replay when CUDA graphs are used with
dynamic/persistent scheduling (PersistentSm90); update the call site around
bmm.run to: detect when CUDA graph mode or dynamic scheduling/persistent SM
(e.g., using enable_pdl or multiProcessorCount flag or a dedicated
isPersistentSm90 check) and in that case require or allocate a real pinned host
buffer and pass it instead of nullptr, otherwise keep nullptr; alternatively add
an explicit runtime check/ASSERT logging that prevents using nullptr in
graph+persistent modes to avoid silent corruption.

Comment on lines +569 to +586
static uint32_t* pinnedCounterValue_ = nullptr;
if (pinnedHostBuffer) {
  // Use the pre-allocated pinned host memory if provided.
  pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
}
// Use pinned host memory for the source value to support CUDA graph, or it silently breaks
// during CUDA graph replay and corrupts the data.
if (pinnedHostBuffer == nullptr && pinnedCounterValue_ == nullptr) {
  // Allocate pinned host memory if not provided.
  auto err = cudaMallocHost((void**)&pinnedCounterValue_, sizeof(uint32_t));
  if (err != cudaSuccess) {
    return 1;
  }
}
*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;
auto err = cudaMemcpyAsync((void*)batchedGemmData.mInputBuffers.mPtrDynamicTileCounter,
                           pinnedCounterValue_, sizeof(uint32_t), cudaMemcpyHostToDevice,
                           reinterpret_cast<cudaStream_t>(cudaStream));

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, get the file size and examine the critical section
wc -l include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Repository: flashinfer-ai/flashinfer

Length of output: 154


🏁 Script executed:

# Examine the code around lines 569-586
sed -n '560,595p' include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Repository: flashinfer-ai/flashinfer

Length of output: 1910


🏁 Script executed:

# Look for the full run() function signature and context
sed -n '540,600p' include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Repository: flashinfer-ai/flashinfer

Length of output: 3412


🏁 Script executed:

# Search for pinnedCounterValue_ usage across the entire file
rg -n "pinnedCounterValue_|pinnedHostBuffer|mPtrDynamicTileCounter" include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h | head -30

Repository: flashinfer-ai/flashinfer

Length of output: 1031


🏁 Script executed:

# Look for stream synchronization patterns and cudaMemcpyAsync usage
rg -n -B2 -A2 "cudaMemcpyAsync|cudaStreamSynchronize|cudaDeviceSynchronize|cudaStreamWaitEvent" include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 428


Remove the static shared host buffer; require caller to provide isolated pinned host memory per call.

The static pinnedCounterValue_ at line 569 is reused across invocations of run(). When pinnedHostBuffer is not provided or provided inconsistently, the same buffer receives the new counter value at line 583 before the previous async copy at line 584 completes, corrupting the dynamic tile counter for in-flight kernels.

The fix is to eliminate the static storage, make pinnedHostBuffer mandatory (return error if null), and use a stack-local variable. This ensures each call gets an isolated source buffer, preventing data races.
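The shape of that fix can be sketched host-side as follows. This is a hedged illustration, not the repository's code: the function name writeTileCounter is made up, and plain std::memcpy stands in for the real cudaMemcpyAsync host-to-device copy.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Race-free pattern: no static storage, the caller owns the pinned buffer,
// and each call writes through its own isolated pointer before the copy.
int writeTileCounter(void* pinnedHostBuffer, uint32_t counter, uint32_t* deviceDst) {
  if (pinnedHostBuffer == nullptr) {
    return 1;  // fail early: dynamic scheduling requires a caller-provided buffer
  }
  uint32_t* pinnedCounterValue = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
  *pinnedCounterValue = counter;
  // Stand-in for cudaMemcpyAsync(..., cudaMemcpyHostToDevice, stream).
  std::memcpy(deviceDst, pinnedCounterValue, sizeof(uint32_t));
  return 0;
}
```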

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h`
around lines 569 - 586, Remove the static shared pinnedCounterValue_ and require
the caller to supply per-call pinnedHostBuffer: in run() fail early (return
error) if pinnedHostBuffer is null, then declare a stack-local pointer uint32_t*
pinnedCounterValue = reinterpret_cast<uint32_t*>(pinnedHostBuffer) (no
cudaMallocHost or static storage), write the counter into *pinnedCounterValue
and use that pointer as the source for cudaMemcpyAsync to
batchedGemmData.mInputBuffers.mPtrDynamicTileCounter; this guarantees each
invocation uses caller-provided isolated pinned memory and avoids races on
pinnedCounterValue_ across in-flight kernels.

Comment on lines +743 to +748
auto numCtasXy = multiProcessorCount / options.mNumSlicesForSplitK;
// Round down to the nearest multiple of the cluster size.
numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
// Account for available occupancy for the kernel. Defaults to 1 CTA per multiprocessor.
numCtasXy *= multiProcessorOccupancy;
return std::make_tuple(numCtasXy, 1, options.mNumSlicesForSplitK);

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h | sed -n '730,760p'

Repository: flashinfer-ai/flashinfer

Length of output: 2072


🏁 Script executed:

rg -n "getFixedGridDim" include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h -A 30

Repository: flashinfer-ai/flashinfer

Length of output: 5536


🏁 Script executed:

# Check for Torch headers in this file
rg -n "torch" include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Repository: flashinfer-ai/flashinfer

Length of output: 148


🏁 Script executed:

# Understand the data types and parameter constraints
rg -n "multiProcessorCount|clusterSizeXy|multiProcessorOccupancy|mNumSlicesForSplitK" include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h -B 3 -A 1

Repository: flashinfer-ai/flashinfer

Length of output: 3203


Guard against zero CTA dimensions in fixed-grid calculation.

After floor-rounding to cluster granularity (line 745), numCtasXy can become 0 when multiProcessorCount / options.mNumSlicesForSplitK < clusterSizeXy. This produces invalid CUDA launch grids. Add a check to ensure at least one cluster is available:

   // The maximum number of CTAs a GPU can run across the XY dimension.
   auto numCtasXy = multiProcessorCount / options.mNumSlicesForSplitK;
   // Round down to the nearest multiple of the cluster size.
   numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
+  // Keep at least one full cluster to avoid invalid zero-sized launches.
+  if (numCtasXy == 0) {
+    numCtasXy = clusterSizeXy;
+  }
   // Account for available occupancy for the kernel. Defaults to 1 CTA per multiprocessor.
   numCtasXy *= multiProcessorOccupancy;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
auto numCtasXy = multiProcessorCount / options.mNumSlicesForSplitK;
// Round down to the nearest multiple of the cluster size.
numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
// Account for available occupancy for the kernel. Defaults to 1 CTA per multiprocessor.
numCtasXy *= multiProcessorOccupancy;
return std::make_tuple(numCtasXy, 1, options.mNumSlicesForSplitK);

auto numCtasXy = multiProcessorCount / options.mNumSlicesForSplitK;
// Round down to the nearest multiple of the cluster size.
numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
// Keep at least one full cluster to avoid invalid zero-sized launches.
if (numCtasXy == 0) {
  numCtasXy = clusterSizeXy;
}
// Account for available occupancy for the kernel. Defaults to 1 CTA per multiprocessor.
numCtasXy *= multiProcessorOccupancy;
return std::make_tuple(numCtasXy, 1, options.mNumSlicesForSplitK);
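The guarded fixed-grid arithmetic can be checked numerically with a standalone sketch. The function name fixedGridCtasXy is made up for illustration; the real method is getFixedGridDim and returns a tuple.

```cpp
#include <cassert>

// Standalone version of the guarded CTA-count math: divide SMs across
// split-K slices, round down to cluster granularity, keep at least one
// full cluster, then scale by occupancy.
int fixedGridCtasXy(int multiProcessorCount, int numSlicesForSplitK,
                    int clusterSizeXy, int multiProcessorOccupancy) {
  int numCtasXy = multiProcessorCount / numSlicesForSplitK;
  numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
  if (numCtasXy == 0) {
    numCtasXy = clusterSizeXy;  // the guard suggested above
  }
  return numCtasXy * multiProcessorOccupancy;
}
```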


@aleozlx aleozlx left a comment


lgtm

double checked hash


@aleozlx aleozlx left a comment


looks clean

conflicts resolved against refactor change #2235

approved again


aleozlx commented Mar 2, 2026

/bot run

@flashinfer-bot

GitLab MR !362 has been updated with latest changes, and the CI pipeline #45171727 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


♻️ Duplicate comments (1)
csrc/trtllm_batched_gemm_runner.cu (1)

261-263: ⚠️ Potential issue | 🟠 Major

nullptr for the new run argument is still risky in PDL/graph-like execution paths.

Lines 262-263 still pass nullptr for the newly inserted BatchedGemmInterface::run parameter while enable_pdl is enabled. Please gate this mode or plumb a real pinned host buffer to avoid silent replay/data issues in persistent scheduling paths.

#!/bin/bash
# Verify API intent + all callsites for the new BatchedGemmInterface::run parameter.

set -euo pipefail

# 1) Locate BatchedGemmInterface::run declaration and nearby comments/docs.
fd "BatchedGemmInterface.h" | while read -r f; do
  echo "=== $f ==="
  rg -n "void\\s+run\\s*\\(" "$f" -n -A20 -B8
done

# 2) Find all bmm.run(...) callsites and inspect the argument passed in that slot.
rg -nP --type=cpp -C3 "\\bbmm\\.run\\s*\\(" csrc include

# 3) Check whether this runner path is used with PDL / persistent / graph-related code paths.
rg -n -C3 "enable_pdl|Persistent|cudaGraph|graph replay|runInitBeforeWorldSync" csrc include

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8b0cf4b and 0a10131.

📒 Files selected for processing (2)
  • csrc/trtllm_batched_gemm_runner.cu
  • flashinfer/artifacts.py

@flashinfer-bot

[FAILED] Pipeline #45171727: 1/20 passed


aleozlx commented Mar 2, 2026

the errors seem like an infra issue. trying again


aleozlx commented Mar 2, 2026

/bot run

@flashinfer-bot

GitLab MR !362 has been updated with latest changes, and the CI pipeline #45173693 is currently running. I'll report back once the pipeline job completes.

@aleozlx aleozlx changed the title Update trtllm-gen batched GEMM kernels - faster, more NVFP4 tile dims, MXFP8 with relu2 act perf: Update trtllm-gen batched GEMM kernels - faster, more NVFP4 tile dims, MXFP8 with relu2 act Mar 2, 2026
@flashinfer-bot

[CANCELING] Pipeline #45173693: canceled


aleozlx commented Mar 2, 2026

well, a bunch of jobs got cancelled for some reason
results seem clean enough

@aleozlx aleozlx added the ready label Mar 2, 2026

aleozlx commented Mar 3, 2026

@flashinfer-bot run


aleozlx commented Mar 3, 2026

per "remove-label" error ^^

Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
@aleozlx aleozlx enabled auto-merge (squash) March 3, 2026 03:00
@aleozlx aleozlx merged commit 4f422f5 into flashinfer-ai:main Mar 3, 2026
30 of 33 checks passed
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…e dims, MXFP8 with relu2 act (flashinfer-ai#2667)

(Commit body duplicates the PR description above.)

---------

Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Co-authored-by: jimmzhou <jimmzhou@nvidia.com>
Co-authored-by: Jimmy Zhou <79552142+jimmyzho@users.noreply.github.com>
Co-authored-by: Alex Yang <aleozlx@gmail.com>
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>

Labels

ready run-ci v0.6.5 release blocker label for v0.6.5


5 participants