perf: Update trtllm-gen batched GEMM kernels - faster, more NVFP4 tile dims, MXFP8 with relu2 act #2667
Conversation
… NVFP4 tile dims, MXFP8 with relu2 non-gated activation Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
|
📝 Walkthrough

Updates the `BatchedGemmInterface::run` signature and call site by inserting a new parameter (passed as `nullptr` at the call site), and updates artifact paths and checksum values for TRTLLM batched GEMM artifacts.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
|
|
/bot run |
Code Review
This pull request updates the trtllm-gen batched GEMM kernels, introducing performance improvements and new features like additional tile dimensions for NVFP4 and MXFP8 with Relu^2 activation. The changes primarily involve updating kernel hashes and modifying the batched GEMM interface to support new dynamic scheduling features. While most of the changes correctly implement the new functionalities, I've identified a critical issue in BatchedGemmInterface.h concerning memory management and thread safety. A static local variable used for a pinned host buffer results in a memory leak and is not thread-safe. My review includes a comment with a suggested fix to address this by making the caller responsible for the buffer's lifecycle.
```cpp
static uint32_t* pinnedCounterValue_ = nullptr;
if (pinnedHostBuffer) {
  // Use the pre-allocated pinned host memory if provided.
  pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
}
// Use pinned host memory for the source value to support CUDA graph, or it silently breaks
// during CUDA graph replay and corrupts the data.
if (pinnedHostBuffer == nullptr && pinnedCounterValue_ == nullptr) {
  // Allocate pinned host memory if not provided.
  auto err = cudaMallocHost((void**)&pinnedCounterValue_, sizeof(uint32_t));
  if (err != cudaSuccess) {
    return 1;
  }
}
*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;
```
The use of a static local variable `pinnedCounterValue_` here introduces two critical issues:

- **Memory leak:** The memory allocated with `cudaMallocHost` is never freed. The `static` pointer holds the allocated address for the lifetime of the application.
- **Thread safety:** `static` local variables are shared across all calls to the function, even from different threads and on different object instances. This can lead to race conditions if multiple threads call this `run` method concurrently.

It's recommended to make the caller responsible for managing the lifecycle of this pinned buffer; this utility class should not be allocating memory it doesn't own.

A possible fix is to require `pinnedHostBuffer` to be non-null when dynamic scheduling is enabled and remove the allocation logic from this function. This will require the caller (`TrtllmGenBatchedGemmRunner`) to manage this buffer.
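The shared-state hazard in the thread-safety point can be reproduced in a few lines of plain C++ (a minimal illustration, not code from this PR — `Runner` and `counter` are made-up names):

```cpp
#include <cstdint>

// A static local inside a member function is a single object shared across
// every instance of the class and every thread that calls the function.
struct Runner {
  uint32_t* counter() {
    static uint32_t value = 0;  // one shared counter, program-wide
    return &value;
  }
};
```

Two distinct `Runner` objects return the same address, so a write through one is visible through the other — which is why concurrent `run()` calls could race on `pinnedCounterValue_`.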
```diff
-static uint32_t* pinnedCounterValue_ = nullptr;
-if (pinnedHostBuffer) {
-  // Use the pre-allocated pinned host memory if provided.
-  pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
-}
-// Use pinned host memory for the source value to support CUDA graph, or it silently breaks
-// during CUDA graph replay and corrupts the data.
-if (pinnedHostBuffer == nullptr && pinnedCounterValue_ == nullptr) {
-  // Allocate pinned host memory if not provided.
-  auto err = cudaMallocHost((void**)&pinnedCounterValue_, sizeof(uint32_t));
-  if (err != cudaSuccess) {
-    return 1;
-  }
-}
-*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;
+if (!pinnedHostBuffer) {
+  // A pinned host buffer must be provided for dynamic scheduling with PersistentSm90.
+  return 1;
+}
+uint32_t* pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
+*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;
```
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/trtllm_batched_gemm_runner.cu`:
- Around line 262-264: The call to bmm.run(...) is passing nullptr for the
pinnedHostBuffer parameter which works because the implementation
auto-allocates, but per BatchedGemmInterface's comment this breaks CUDA graph
replay when CUDA graphs are used with dynamic/persistent scheduling
(PersistentSm90); update the call site around bmm.run to: detect when CUDA graph
mode or dynamic scheduling/persistent SM (e.g., using enable_pdl or
multiProcessorCount flag or a dedicated isPersistentSm90 check) and in that case
require or allocate a real pinned host buffer and pass it instead of nullptr,
otherwise keep nullptr; alternatively add an explicit runtime check/ASSERT
logging that prevents using nullptr in graph+persistent modes to avoid silent
corruption.
In
`@include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h`:
- Around line 743-748: After rounding numCtasXy to clusterSizeXy granularity in
the fixed-grid calculation, guard against it becoming zero: ensure numCtasXy is
at least clusterSizeXy (or set to 1*clusterSizeXy) before multiplying by
multiProcessorOccupancy so the computed CTA grid is never zero; adjust the code
around the variables numCtasXy, clusterSizeXy, multiProcessorCount,
options.mNumSlicesForSplitK, and multiProcessorOccupancy to apply this min-check
prior to returning the tuple.
- Around line 569-586: Remove the static shared pinnedCounterValue_ and require
the caller to supply per-call pinnedHostBuffer: in run() fail early (return
error) if pinnedHostBuffer is null, then declare a stack-local pointer uint32_t*
pinnedCounterValue = reinterpret_cast<uint32_t*>(pinnedHostBuffer) (no
cudaMallocHost or static storage), write the counter into *pinnedCounterValue
and use that pointer as the source for cudaMemcpyAsync to
batchedGemmData.mInputBuffers.mPtrDynamicTileCounter; this guarantees each
invocation uses caller-provided isolated pinned memory and avoids races on
pinnedCounterValue_ across in-flight kernels.
ℹ️ Review info
Configuration used: defaults
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (7)
The seven headers under `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/` (`CommonUtils.h`, `CudaArchDecl.h`, `CudaKernelLauncher.h`, `DtypeDecl.h`, `MmaDecl.h`, `SfLayoutDecl.h`, `SparsityDecl.h`) are excluded by `!**/gen/**`.
📒 Files selected for processing (12)

- `csrc/trtllm_batched_gemm_runner.cu`
- `flashinfer/artifacts.py`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmEnums.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmInterface.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/BatchedGemmOptions.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/Enums.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmOptions.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParamsDecl.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelTraits.h`
- `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/TmaDescriptor.h`
csrc/trtllm_batched_gemm_runner.cu
Outdated
```cpp
auto const err =
    bmm.run(config, workspace, gemmData, static_cast<void*>(stream), multiProcessorCount,
            enable_pdl, nullptr, globalTrtllmGenBatchedGemmModuleCache);
```
Clarify the CUDA graph safety requirement for the nullptr argument.
The nullptr value passed to pinnedHostBuffer is handled by the implementation via automatic allocation, so it won't cause a runtime failure. However, the implementation's own code comment (lines 574–575 in BatchedGemmInterface.h) explicitly warns: "Use pinned host memory for the source value to support CUDA graph, or it silently breaks during CUDA graph replay and corrupts the data." If CUDA graphs are used with dynamic scheduling (PersistentSm90), a valid pinned host buffer must be provided instead of nullptr.
```cpp
static uint32_t* pinnedCounterValue_ = nullptr;
if (pinnedHostBuffer) {
  // Use the pre-allocated pinned host memory if provided.
  pinnedCounterValue_ = reinterpret_cast<uint32_t*>(pinnedHostBuffer);
}
// Use pinned host memory for the source value to support CUDA graph, or it silently breaks
// during CUDA graph replay and corrupts the data.
if (pinnedHostBuffer == nullptr && pinnedCounterValue_ == nullptr) {
  // Allocate pinned host memory if not provided.
  auto err = cudaMallocHost((void**)&pinnedCounterValue_, sizeof(uint32_t));
  if (err != cudaSuccess) {
    return 1;
  }
}
*pinnedCounterValue_ = numFixedCtaBatch * numFixedCtaTile;
auto err = cudaMemcpyAsync((void*)batchedGemmData.mInputBuffers.mPtrDynamicTileCounter,
                           pinnedCounterValue_, sizeof(uint32_t), cudaMemcpyHostToDevice,
                           reinterpret_cast<cudaStream_t>(cudaStream));
```
Remove the static shared host buffer; require caller to provide isolated pinned host memory per call.
The static pinnedCounterValue_ at line 569 is reused across invocations of run(). When pinnedHostBuffer is not provided or provided inconsistently, the same buffer receives the new counter value at line 583 before the previous async copy at line 584 completes, corrupting the dynamic tile counter for in-flight kernels.
The fix is to eliminate the static storage, make pinnedHostBuffer mandatory (return error if null), and use a stack-local variable. This ensures each call gets an isolated source buffer, preventing data races.
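The caller-side ownership this implies can be sketched with an RAII handle (illustrative only — in the real runner the allocation would use `cudaMallocHost`/`cudaFreeHost`; plain `new`/`delete` stands in here so the pattern is visible without a GPU, and `PinnedCounter`/`makePinnedCounter` are invented names):

```cpp
#include <cstdint>
#include <memory>

// Custom deleter so the handle releases the buffer exactly once when the
// owning runner is destroyed. Real code would call cudaFreeHost here.
struct CounterDeleter {
  void operator()(uint32_t* p) const { delete p; }
};

using PinnedCounter = std::unique_ptr<uint32_t, CounterDeleter>;

// Allocate one counter per runner (real code: cudaMallocHost) so each run()
// call writes to memory it exclusively owns for the lifetime of the copy.
PinnedCounter makePinnedCounter() { return PinnedCounter(new uint32_t(0)); }
```

The runner would hold a `PinnedCounter` member and pass its `.get()` pointer as `pinnedHostBuffer`, removing both the leak and the cross-call sharing.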
```cpp
auto numCtasXy = multiProcessorCount / options.mNumSlicesForSplitK;
// Round down to the nearest multiple of the cluster size.
numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
// Account for available occupancy for the kernel. Defaults to 1 CTA per multiprocessor.
numCtasXy *= multiProcessorOccupancy;
return std::make_tuple(numCtasXy, 1, options.mNumSlicesForSplitK);
```
Guard against zero CTA dimensions in fixed-grid calculation.
After floor-rounding to cluster granularity (line 745), numCtasXy can become 0 when multiProcessorCount / options.mNumSlicesForSplitK < clusterSizeXy. This produces invalid CUDA launch grids. Add a check to ensure at least one cluster is available:
```diff
   // The maximum number of CTAs a GPU can run across the XY dimension.
   auto numCtasXy = multiProcessorCount / options.mNumSlicesForSplitK;
   // Round down to the nearest multiple of the cluster size.
   numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
+  // Keep at least one full cluster to avoid invalid zero-sized launches.
+  if (numCtasXy == 0) {
+    numCtasXy = clusterSizeXy;
+  }
   // Account for available occupancy for the kernel. Defaults to 1 CTA per multiprocessor.
   numCtasXy *= multiProcessorOccupancy;
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
auto numCtasXy = multiProcessorCount / options.mNumSlicesForSplitK;
// Round down to the nearest multiple of the cluster size.
numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
// Keep at least one full cluster to avoid invalid zero-sized launches.
if (numCtasXy == 0) {
  numCtasXy = clusterSizeXy;
}
// Account for available occupancy for the kernel. Defaults to 1 CTA per multiprocessor.
numCtasXy *= multiProcessorOccupancy;
return std::make_tuple(numCtasXy, 1, options.mNumSlicesForSplitK);
```
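To see the zero-grid case concretely, the fixed-grid arithmetic can be extracted into a standalone function (a sketch with the guard applied; the function name and the parameter values in the test are illustrative, not from the header):

```cpp
#include <tuple>

// Mirrors the rounding logic above: divide SMs across split-K slices, round
// down to cluster granularity, then guard before scaling by occupancy.
std::tuple<int, int, int> fixedGridDim(int multiProcessorCount, int numSlicesForSplitK,
                                       int clusterSizeXy, int multiProcessorOccupancy) {
  int numCtasXy = multiProcessorCount / numSlicesForSplitK;
  numCtasXy = (numCtasXy / clusterSizeXy) * clusterSizeXy;
  // Without this guard, e.g. 4 SMs with 2 split-K slices and cluster size 4
  // rounds down to 0 and would produce a zero-sized launch grid.
  if (numCtasXy == 0) {
    numCtasXy = clusterSizeXy;
  }
  numCtasXy *= multiProcessorOccupancy;
  return std::make_tuple(numCtasXy, 1, numSlicesForSplitK);
}
```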
aleozlx
left a comment
lgtm
double checked hash
|
/bot run |
♻️ Duplicate comments (1)
csrc/trtllm_batched_gemm_runner.cu (1)
261-263: ⚠️ Potential issue | 🟠 Major

`nullptr` for the new `run` argument is still risky in PDL/graph-like execution paths. Lines 262-263 still pass `nullptr` for the newly inserted `BatchedGemmInterface::run` parameter while `enable_pdl` is enabled. Please gate this mode or plumb a real pinned host buffer to avoid silent replay/data issues in persistent scheduling paths.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@csrc/trtllm_batched_gemm_runner.cu`:
- Around line 261-263: The callsite is passing nullptr for the new
BatchedGemmInterface::run parameter (bmm.run(..., enable_pdl, nullptr,
globalTrtllmGenBatchedGemmModuleCache)), which is unsafe when enable_pdl is
true; change the caller to either (a) gate this path so that when enable_pdl is
true we do not pass nullptr (early-return or disable PDL for this runner), or
(b) allocate and pass a real pinned host buffer/structure and pass that instead
of nullptr; update the bmm.run invocation site to check enable_pdl and supply
the pinned host buffer (or fail with a clear error) so persistent/graph replay
paths never receive a nullptr.
…v/flashinfer into update-batched-gemm-faster-gemms
|
[FAILED] Pipeline #45171727: 1/20 passed |
|
The errors seem like an infra issue; trying again |
|
/bot run |
|
[CANCELING] Pipeline #45173693: canceled |
|
well, a bunch of jobs got cancelled for some reason |
|
@flashinfer-bot run |
|
per "remove-label" error ^^ |
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
…e dims, MXFP8 with relu2 act (flashinfer-ai#2667) ## 📌 Description Updates the trtllm-gen batched GEMM kernel hashes and `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export` directory for: * Faster batched GEMM kernels * Additional NVFP4 tile dims kernels * Additional MXFP8 with Relu^2 non-gated activation kernels ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Updated an internal batched matrix‑multiplication interface (call signature changed). * Refreshed native artifact references and checksums for GPU-accelerated components. * No end-user behavioral changes expected; these are internal compatibility and maintenance updates. 
<!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Co-authored-by: jimmzhou <jimmzhou@nvidia.com> Co-authored-by: Jimmy Zhou <79552142+jimmyzho@users.noreply.github.com> Co-authored-by: Alex Yang <aleozlx@gmail.com> Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com> Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
📌 Description

Updates the trtllm-gen batched GEMM kernel hashes and `include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export` directory for:

- Faster batched GEMM kernels
- Additional NVFP4 tile dims kernels
- Additional MXFP8 with Relu^2 non-gated activation kernels

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes

Summary by CodeRabbit

- **Chores**
  - Updated an internal batched matrix-multiplication interface (call signature changed).
  - Refreshed native artifact references and checksums for GPU-accelerated components.
  - No end-user behavioral changes expected; these are internal compatibility and maintenance updates.