[Perf][Feature] Add SM103-specific schedulers for NVFP4 CUTLASS kernels#2303
Conversation
Signed-off-by: LopezCastroRoberto <robertol.c510@gmail.com>
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

This PR adds FP4 GEMM support for SM103 GPUs by introducing SM103-specific kernel implementations, tile configurations, type adapters, and a JIT module generator. Changes span CUDA kernels, C++ templates, configuration headers, and Python integration for runtime dispatch.
Sequence Diagrams

```mermaid
sequenceDiagram
    participant PythonAPI as Python API<br/>(mm_fp4)
    participant Dispatcher as Compute Capability<br/>Dispatcher
    participant SM103Gen as SM103 Module<br/>Generator
    participant JITCompiler as JIT Compiler
    participant CUDARuntime as CUDA Runtime
    PythonAPI->>Dispatcher: Extract (major, minor) capability
    alt sm_minor == 3
        Dispatcher->>SM103Gen: get_gemm_sm103_module_cutlass_fp4()
        SM103Gen->>SM103Gen: Generate fp4_gemm_cutlass_sm103.cu<br/>with tile configs<br/>(128x128x768, 128x192x768,<br/>128x256x768)
        SM103Gen->>JITCompiler: Render templates &<br/>compile sources
        JITCompiler->>JITCompiler: Build SM103 kernels<br/>with ENABLE_FP4,<br/>ENABLE_BF16 flags
        JITCompiler-->>SM103Gen: Compiled module
        SM103Gen-->>Dispatcher: cutlass_fp4_gemm_runner()
    else sm_minor != 3
        Dispatcher->>Dispatcher: Fallback to SM100 path
    end
    Dispatcher->>PythonAPI: Return kernel runner
    PythonAPI->>CUDARuntime: Dispatch FP4 GEMM<br/>via tactic selection
```
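The compute-capability routing in the dispatch flow above reduces to a small check on (major, minor). A minimal Python sketch; the function name and string return values here are illustrative stand-ins, not the actual flashinfer API:

```python
def select_fp4_gemm_module(major: int, minor: int) -> str:
    """Hypothetical sketch of the FP4 GEMM module routing.

    Mirrors the dispatch described for flashinfer/gemm/gemm_base.py:
    SM103 (10.3) gets its own module; other SM10x/SM11x fall back to
    the SM100 path; SM12x uses the SM120 variant.
    """
    if major == 10 and minor == 3:
        return "sm103"  # B300: SM103-specific schedulers
    if major in (10, 11):
        return "sm100"  # other SM10x/SM11x: SM100 fallback path
    if major == 12:
        return "sm120"
    raise ValueError(f"FP4 GEMM unsupported on sm_{major}{minor}")
```

The key point is that only the exact (10, 3) pair takes the specialized path, so SM110 and plain SM100 devices keep their existing kernels.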
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: 1 passed, 2 failed (1 warning, 1 inconclusive).
Summary of Changes

Hello @LopezCastroRoberto, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the performance of NVFP4 General Matrix Multiplication (GEMM) operations on NVIDIA's SM103 architecture (B300 GPUs). By introducing specialized CUTLASS kernel schedulers tailored for SM103, the changes aim to unlock greater efficiency and speedup, especially for larger batch sizes in deep learning workloads. The integration ensures that the system intelligently selects the most performant kernel for the given hardware, without replacing existing configurations for other architectures.
Code Review
This pull request introduces SM103-specific schedulers for NVFP4 CUTLASS kernels to enhance performance, particularly for larger batch sizes. The changes are well-structured, adding new kernel configurations and the necessary C++ and Python logic to dispatch to them based on the GPU architecture. The overall approach is sound. My review has identified a high-severity issue in the build scripts that could cause file conflicts, along with a couple of medium-severity issues related to misleading documentation and error messages. Addressing these points will improve the correctness and maintainability of the code.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
include/flashinfer/gemm/cutlass_gemm_configs.h (2)
284-301: Missing `get_cluster_shape_name()` cases for several cluster shapes.

The `ClusterShape` enum includes `ClusterShape_1x4x1`, `ClusterShape_4x2x1`, `ClusterShape_2x4x1`, and `ClusterShape_4x4x1`, but `get_cluster_shape_name()` does not handle these cases, returning "Unknown shape" for them. This may cause confusion during debugging or logging. Consider adding the missing cases for completeness.
Proposed fix:

```diff
 static auto get_cluster_shape_name(ClusterShape Shape_MNK) {
   if (Shape_MNK == ClusterShape::ClusterShape_1x1x1) {
     return "1x1x1";
   } else if (Shape_MNK == ClusterShape::ClusterShape_2x1x1) {
     return "2x1x1";
   } else if (Shape_MNK == ClusterShape::ClusterShape_1x2x1) {
     return "1x2x1";
   } else if (Shape_MNK == ClusterShape::ClusterShape_2x2x1) {
     return "2x2x1";
+  } else if (Shape_MNK == ClusterShape::ClusterShape_1x4x1) {
+    return "1x4x1";
+  } else if (Shape_MNK == ClusterShape::ClusterShape_4x2x1) {
+    return "4x2x1";
+  } else if (Shape_MNK == ClusterShape::ClusterShape_2x4x1) {
+    return "2x4x1";
+  } else if (Shape_MNK == ClusterShape::ClusterShape_4x4x1) {
+    return "4x4x1";
   } else if (Shape_MNK == ClusterShape::ClusterShape_1x8x1) {
     return "1x8x1";
   } else if (Shape_MNK == ClusterShape::ClusterShape_8x1x1) {
     return "8x1x1";
   } else if (Shape_MNK == ClusterShape::ClusterShape_4x1x1) {
     return "4x1x1";
   }
   return "Unknown shape";
 }
```
303-321: Missing `get_cluster_shape()` cases cause undefined behavior.

The template function `get_cluster_shape()` does not handle `ClusterShape_1x4x1`, `ClusterShape_4x2x1`, `ClusterShape_2x4x1`, and `ClusterShape_4x4x1`. For unmatched cases, the function has no return statement, resulting in undefined behavior.

Proposed fix:
```diff
 template <ClusterShape Shape_MNK>
 constexpr auto get_cluster_shape() {
   using namespace cute;
   if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x1x1) {
     return cute::Shape<_1, _1, _1>{};
   } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x1x1) {
     return cute::Shape<_2, _1, _1>{};
   } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x2x1) {
     return cute::Shape<_1, _2, _1>{};
   } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x2x1) {
     return cute::Shape<_2, _2, _1>{};
+  } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x4x1) {
+    return cute::Shape<_1, _4, _1>{};
+  } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x2x1) {
+    return cute::Shape<_4, _2, _1>{};
+  } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x4x1) {
+    return cute::Shape<_2, _4, _1>{};
+  } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x4x1) {
+    return cute::Shape<_4, _4, _1>{};
   } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x8x1) {
     return cute::Shape<_1, _8, _1>{};
   } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) {
     return cute::Shape<_8, _1, _1>{};
   } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x1x1) {
     return cute::Shape<_4, _1, _1>{};
+  } else {
+    static_assert(sizeof(Shape_MNK) == 0, "Unsupported ClusterShape");
   }
 }
```
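The exhaustiveness issue is the same in any language: a shape lookup should cover every enum value and fail loudly on unknown keys rather than silently fall through. A Python sketch of that idea (names and the string keys are illustrative, not the C++ enum):

```python
# Hypothetical Python mirror of get_cluster_shape(); the table covers every
# ClusterShape value named in the review, so lookups are exhaustive.
CLUSTER_SHAPES = {
    "1x1x1": (1, 1, 1), "2x1x1": (2, 1, 1), "1x2x1": (1, 2, 1),
    "2x2x1": (2, 2, 1), "1x4x1": (1, 4, 1), "4x2x1": (4, 2, 1),
    "2x4x1": (2, 4, 1), "4x4x1": (4, 4, 1), "1x8x1": (1, 8, 1),
    "8x1x1": (8, 1, 1), "4x1x1": (4, 1, 1),
}

def get_cluster_shape(name: str) -> tuple:
    try:
        return CLUSTER_SHAPES[name]
    except KeyError:
        # Loud failure, unlike the C++ fall-through (undefined behavior).
        raise ValueError(f"Unsupported ClusterShape: {name}") from None
```

In C++ the `static_assert` in the `else` branch plays the role of this exception, turning a silent missing-return into a compile-time error.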
🤖 Fix all issues with AI agents
In @flashinfer/gemm/gemm_base.py:
- Around line 525-531: The docstring for get_gemm_sm103_module_cutlass_fp4() is
incorrect (it references SM100/103/110); update it to accurately describe this
function as returning the SM103 FP4 GEMM module (e.g., "Get the SM103 FP4 GEMM
module.") so it matches the function name and behavior in
gen_gemm_sm103_module_cutlass_fp4() and the _create_cutlass_fp4_gemm_module
call.
In @flashinfer/jit/gemm/cutlass/cutlass_library.py:
- Line 627: Remove the personal annotation "#RLC:" from the
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm103 mapping in the cutlass mapping
table (the entry that maps to the long cutlass::gemm class name) and add a
corresponding suffix entry to the KernelScheduleSuffixes dictionary for
KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm103 with the value
"_o_vs16_2sm_sm103" so the suffix map includes this schedule type.
In @include/flashinfer/gemm/fp4_gemm_template_sm103.h:
- Around line 270-281: Error messages reference the wrong architecture string;
update the messages constructed after the gemm.initialize (initStatus) and
gemm.run (runStatus) checks to say "sm103" instead of "sm100". Locate the blocks
using gemm.initialize(args, workspace, stream) and gemm.run(args, workspace,
stream, nullptr, /*enablePDL=*/true) and change the human-readable text in the
std::string errMsg concatenations that include "Failed to initialize/run cutlass
FP4 gemm on sm100" to "Failed to initialize/run cutlass FP4 gemm on sm103" while
keeping the rest of the error handling (cutlassGetStatusString, throwing
std::runtime_error) unchanged.
🧹 Nitpick comments (6)
csrc/fp4_gemm_cutlass.jinja (1)

29-29: LGTM! New cluster configuration correctly instantiated.

The new (4,1,1) cluster configuration with the `_2SM` scheduler is correctly instantiated and complements the existing configurations. This aligns with the PR objective to improve SM103 NVFP4 performance.
♻️ Optional: Consider reordering for better organization
For improved readability, you might place the (4,1,1) configuration before (4,2,1) to maintain a consistent ordering pattern (cluster_m=4, then cluster_n in ascending order: 1, 2, 4).

```diff
 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 4, 1, _2SM)
+INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 1, 1, _2SM)
 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 2, 1, _2SM)
 INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 4, 1, _2SM)
-INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 1, 1, _2SM)
```

This is purely cosmetic and doesn't affect functionality.
include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h (3)
17-18: Include guard name may conflict with other headers.

The include guard `FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_H_` is generic and doesn't include "SM103". If there's another `fp4_gemm_cutlass_template.h` (e.g., for SM100), this could cause include guard collisions.

Suggested fix:
```diff
-#ifndef FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_H_
-#define FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_H_
+#ifndef FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_SM103_H_
+#define FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_SM103_H_
```

And at the end of the file:

```diff
-#endif  // FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_H_
+#endif  // FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_SM103_H_
```
357-364: Weak hash function with high collision probability.

The hash function XORs all four values directly without bit shifting, which leads to poor distribution. For example, (1,2,3,4) and (2,1,4,3) would produce the same hash.

Proposed fix using a better hash combination:
```diff
 struct MNKHash {
   size_t operator()(const MNK& mnk) const {
     auto h1 = std::hash<int>{}(std::get<0>(mnk));
     auto h2 = std::hash<int>{}(std::get<1>(mnk));
     auto h3 = std::hash<int>{}(std::get<2>(mnk));
     auto h4 = std::hash<int>{}(std::get<3>(mnk));
-    return h1 ^ h2 ^ h3 ^ h4;
+    // Combine hashes with bit shifting to reduce collisions
+    size_t seed = h1;
+    seed ^= h2 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    seed ^= h3 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    seed ^= h4 + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    return seed;
   }
 };
```
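The collision the comment mentions is easy to reproduce. A small pure-Python demonstration (not the actual C++ code) of why plain XOR is symmetric under permutation, and how boost-style `hash_combine` mixing breaks that symmetry:

```python
MASK = (1 << 64) - 1  # emulate size_t wraparound

def xor_hash(mnk):
    # Plain XOR of the element hashes: symmetric, so permuted tuples collide.
    h = 0
    for v in mnk:
        h ^= hash(v)
    return h & MASK

def combined_hash(mnk):
    # Boost-style hash_combine: each step mixes in shifts of the running
    # seed, so the result depends on element order.
    seed = 0
    for v in mnk:
        seed ^= (hash(v) + 0x9E3779B9 + ((seed << 6) & MASK) + (seed >> 2)) & MASK
    return seed
```

Here `xor_hash((1, 2, 3, 4))` equals `xor_hash((2, 1, 4, 3))`, while the combined form distinguishes orderings.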
287-329: Review the `getConfigs()` tactic ordering.

The `best_tactics_index` list `{22, 20, 29, 4, 18}` references specific indices in `candidateConfigs`. This assumes the configuration list order is stable. Any changes to the `tilesSm100` or `clusterShapes` vectors will invalidate these indices, leading to incorrect tactic prioritization. Consider using a more robust approach, such as storing the actual configuration tuples rather than indices.
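One way to implement the suggested decoupling is to key preferred tactics by their configuration tuple rather than by list position. A Python sketch with made-up config values (the real `candidateConfigs` list lives in C++):

```python
# Hypothetical (tile_shape, cluster_shape) configs.
candidate_configs = [
    ((128, 128, 256), (1, 1, 1)),
    ((128, 192, 768), (2, 1, 1)),
    ((128, 256, 768), (4, 2, 1)),
]

# Preferred tactics identified by value, not by index into the list above.
best_tactics = [((128, 256, 768), (4, 2, 1)), ((128, 192, 768), (2, 1, 1))]

def prioritized(configs, preferred):
    """Return configs with preferred ones first; robust to reordering."""
    front = [c for c in preferred if c in configs]
    rest = [c for c in configs if c not in front]
    return front + rest
```

Because the priority list names configurations by value, inserting or reordering entries in `candidate_configs` no longer silently changes which tactics are tried first.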
flashinfer/jit/gemm/cutlass/generate_kernels.py (1)
22-22: Unused import.

The `logger` import does not appear to be used anywhere in this file.

Proposed fix:

```diff
-from ...core import logger
```

csrc/fp4_gemm_cutlass_sm103.cu (1)
103-103: Consider removing or documenting the unused variable.

`mat2_k_scale` is set to 1 and used in dimension checks, but its purpose isn't clear. If it's a placeholder for future scaling functionality, a comment explaining this would help. If it's truly unused, consider removing it.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- csrc/fp4_gemm_cutlass.jinja
- csrc/fp4_gemm_cutlass_sm103.cu
- csrc/fp4_gemm_cutlass_sm103.jinja
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/__init__.py
- flashinfer/jit/gemm/core.py
- flashinfer/jit/gemm/cutlass/cutlass_library.py
- flashinfer/jit/gemm/cutlass/generate_kernels.py
- include/flashinfer/gemm/cutlass_gemm_configs.h
- include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h
- include/flashinfer/gemm/fp4_gemm_template_sm103.h
🧰 Additional context used
📓 Path-based instructions (4)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).

Files:
- flashinfer/jit/gemm/cutlass/generate_kernels.py
- flashinfer/jit/gemm/core.py
- flashinfer/jit/gemm/cutlass/cutlass_library.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/__init__.py
flashinfer/jit/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/jit/**/*.py: JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec.
Use the `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`.
Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer).

Files:
- flashinfer/jit/gemm/cutlass/generate_kernels.py
- flashinfer/jit/gemm/core.py
- flashinfer/jit/gemm/cutlass/cutlass_library.py
- flashinfer/jit/gemm/__init__.py
csrc/**/*.jinja
📄 CodeRabbit inference engine (CLAUDE.md)
csrc/**/*.jinja: Use dispatch macros (e.g., `DISPATCH_DTYPE`, `DISPATCH_BLOCK_SIZE`) in `.jinja` template files to handle combinatorial parameter spaces in CUDA kernels.
Use `DISPATCH_DTYPE`, `DISPATCH_BLOCK_SIZE`, and similar macros to reduce code duplication when handling multiple dtype and template parameter combinations.

Files:
- csrc/fp4_gemm_cutlass.jinja
- csrc/fp4_gemm_cutlass_sm103.jinja
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers.

Files:
- csrc/fp4_gemm_cutlass_sm103.cu
🧠 Learnings (12)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
- csrc/fp4_gemm_cutlass.jinja
- csrc/fp4_gemm_cutlass_sm103.cu
- csrc/fp4_gemm_cutlass_sm103.jinja
- include/flashinfer/gemm/fp4_gemm_template_sm103.h
- include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Use `FLASHINFER_CUDA_ARCH_LIST` environment variable to specify target GPU architectures (e.g., '8.0 9.0a') and `FLASHINFER_NVCC_THREADS` to control parallel compilation threads
Applied to files:
- csrc/fp4_gemm_cutlass.jinja
- csrc/fp4_gemm_cutlass_sm103.jinja
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`
Applied to files:
- flashinfer/jit/gemm/core.py
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Applied to files:
flashinfer/jit/gemm/core.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Applied to files:
flashinfer/jit/gemm/core.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
- flashinfer/jit/gemm/core.py
- flashinfer/jit/gemm/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
- flashinfer/jit/gemm/core.py
- csrc/fp4_gemm_cutlass_sm103.jinja
- flashinfer/jit/gemm/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*_jit_binding.cu : Create TVM-FFI bindings in files matching the pattern `csrc/*_jit_binding.cu` using the `TVM_FFI_DLL_EXPORT_TYPED_FUNC(name, func)` macro to expose C++ functions
Applied to files:
csrc/fp4_gemm_cutlass_sm103.cu
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
csrc/fp4_gemm_cutlass_sm103.cu
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:
- csrc/fp4_gemm_cutlass_sm103.jinja
- include/flashinfer/gemm/fp4_gemm_template_sm103.h
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
flashinfer/jit/gemm/__init__.py
🧬 Code graph analysis (4)
flashinfer/jit/gemm/core.py (2)
- flashinfer/jit/core.py (2): `JitSpec` (216-397), `gen_jit_spec` (400-466)
- flashinfer/compilation_context.py (1): `get_nvcc_flags_list` (50-68)

flashinfer/jit/gemm/__init__.py (1)
- flashinfer/jit/gemm/core.py (1): `gen_gemm_sm103_module_cutlass_fp4` (97-165)

include/flashinfer/gemm/fp4_gemm_template_sm103.h (2)
- include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h (1): `gemm` (42-381)
- include/flashinfer/gemm/cutlass_gemm_configs.h (1): `ClusterShape` (270-412)

include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h (2)
- include/flashinfer/gemm/fp4_gemm_template_sm103.h (4): `gemm` (38-288), `void` (151-283), `_1SM_sm103` (55-60), `_2SM_sm103` (63-68)
- include/flashinfer/gemm/fp4_gemm_cutlass.h (1): `FP4GemmType` (59-88)
🪛 Ruff (0.14.10)
flashinfer/jit/gemm/core.py
157-161: Consider [*nvcc_flags, "-DENABLE_BF16", "-DENABLE_FP4"] instead of concatenation
Replace with [*nvcc_flags, "-DENABLE_BF16", "-DENABLE_FP4"]
(RUF005)
🔇 Additional comments (17)
flashinfer/jit/gemm/__init__.py (1)
20-20: LGTM!

The new `gen_gemm_sm103_module_cutlass_fp4` symbol is correctly imported and exported, following the established pattern for other SM-specific module generators.

Also applies to: 36-36
include/flashinfer/gemm/cutlass_gemm_configs.h (1)
136-140: LGTM!

The SM103-specific tile configurations (128x128x768, 128x192x768, 128x256x768) are correctly added to the `CutlassTileConfigSM100` enum, the `TileShape` enum, and the corresponding `get_tile_shape()` and `get_tile_shape_name()` functions.

Also applies to: 196-200, 228-233, 260-265
include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h (1)
45-112: Missing cluster shape cases in dispatch functions.

Both `dispatchNVFP4xNVFP4GemmClusterShapeSm100` and `dispatchNVFP4xNVFP4GemmClusterShapeSm103` handle most cluster shapes but miss `ClusterShape::ClusterShape_1x8x1` and `ClusterShape::ClusterShape_8x1x1`. These are present in the `ClusterShape` enum and used in `getConfigs()`. If these shapes are selected during autotuning, the dispatch will throw a runtime error.

Please verify whether `ClusterShape_1x8x1` and `ClusterShape_8x1x1` should be supported for SM103 FP4 GEMM, or whether they should be excluded from the config list at lines 300-306.

Also applies to: 114-181
csrc/fp4_gemm_cutlass_sm103.jinja (1)
1-32: LGTM!

The Jinja template correctly instantiates SM103 FP4 Ultra GEMM kernel launchers for the supported cluster shape configurations, with appropriate SM type suffixes (`_1SM_sm103`, `_2SM_sm103`).

flashinfer/jit/gemm/core.py (3)
97-99: Shared generation directory may cause confusion.

The `gen_directory` is set to `"gen_gemm_sm100_cutlass_fp4"`, the same as `gen_gemm_sm100_module_cutlass_fp4()`. While this may be intentional (the SM103 module includes SM100 configurations), it could lead to file collisions or confusion during incremental builds. Consider using a distinct directory like `"gen_gemm_sm103_cutlass_fp4"`.
127-149: SM103 module includes SM100 kernel configurations.

The SM103 module generator also renders kernels using `fp4_gemm_cutlass.jinja` with SM100 tile configurations. This creates a superset module containing both SM100 and SM103 kernels. Please confirm this is the intended design: the SM103 module should support both SM100 base configurations and SM103-specific optimized configurations, so autotuning can select the best one.
151-165: LGTM - follows established JIT module pattern.

The function correctly:
- Specifies `supported_major_versions=[10, 11, 12]` per coding guidelines
- Uses `gen_jit_spec()` to return a properly configured JitSpec
- Includes appropriate CUDA flags for BF16 and FP4 support

Based on learnings, the `supported_major_versions` specification aligns with JIT module conventions.

csrc/fp4_gemm_cutlass_sm103.cu (4)
1-43: LGTM - File structure and template instantiations are correct.

The file correctly implements TVM-FFI bindings for SM103 FP4 GEMM as per the coding guidelines. Template instantiations for both `__nv_bfloat16` and `half` types are properly declared.
49-58: LGTM - Config retrieval with proper bounds checking.

The static config caching and bounds validation are correctly implemented.

176-193: LGTM - Public API functions and TVM FFI exports are correct.

The `fp4_gemm` wrapper and `fp4_gemm_tactic_num` functions are cleanly implemented. The TVM FFI exports follow the correct pattern.
78-84: Verify that `ffi::Tensor` reference counting prevents premature deallocation of the temporary workspace.

The async GEMM kernel receives a pointer to `new_workspace`, which goes out of scope before the kernel completes. This is safe only if TVM's `ffi::Tensor` uses reference counting or environment-managed memory that extends the tensor's lifetime beyond the local scope. Verify against TVM's FFI documentation or implementation to confirm the memory lifetime guarantees, or add explicit stream synchronization as a safeguard.

flashinfer/gemm/gemm_base.py (2)
542-554: LGTM - SM103 routing logic is correct.

The routing correctly identifies SM103 (major=10, minor=3) and routes to the specialized module. Other SM10x/SM11x variants correctly fall back to the SM100 path.

2288-2295: LGTM - Compute capability extraction updated correctly.

The change properly extracts both major and minor versions and passes them to enable SM103-specific module selection.
include/flashinfer/gemm/fp4_gemm_template_sm103.h (4)
1-46: LGTM - Header structure and type definitions are correct.

The header guard, includes, namespace structure, and `SafeBF16_sm103` definition follow established patterns. The conditional BF16 handling is appropriate.

47-68: LGTM - SM type adapters correctly specialized for SM103.

The 1SM and 2SM configurations properly define their respective scales, thread shapes, and SM103-specific schedule types.
148-163: LGTM - Architecture guard correctly enforces SM103 execution.

The `Sm103Only` wrapper provides a safety mechanism to prevent execution on incompatible architectures. The use of `is_match_v<103>` and `__trap()` follows established patterns for architecture enforcement.
254-257: LGTM - Workspace size query pattern is correct.

The null-pointer check pattern for querying workspace size without running the kernel is a standard CUTLASS convention.
Signed-off-by: LopezCastroRoberto <robertol.c510@gmail.com>
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @flashinfer/gemm/gemm_base.py:
- Line 519: Update the docstring for the function that currently reads "Get the
SM100/103/110 FP4 GEMM module." to reflect that SM103 is now separate; change
the text to "Get the SM100/110 FP4 GEMM module." so it matches the dedicated
get_gemm_sm103_module_cutlass_fp4() handler and avoids confusion when locating
get_gemm_sm100_110_module_cutlass_fp4().
🧹 Nitpick comments (1)
flashinfer/jit/gemm/core.py (1)
157-161: Consider using the spread operator for list concatenation.

The static analysis tool suggests using the spread operator for cleaner list concatenation.
♻️ Suggested refactor

```diff
 return gen_jit_spec(
     "fp4_gemm_cutlass_sm103",
     source_paths,
-    extra_cuda_cflags=nvcc_flags
-    + [
-        "-DENABLE_BF16",
-        "-DENABLE_FP4",
-    ],
+    extra_cuda_cflags=[
+        *nvcc_flags,
+        "-DENABLE_BF16",
+        "-DENABLE_FP4",
+    ],
     extra_cflags=[
         "-DFAST_BUILD",
     ],
 )
```
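In isolation, the RUF005 rule is just asking for iterable unpacking instead of list concatenation; both produce the same list. A tiny sketch with illustrative flag values (not the real nvcc flag list):

```python
nvcc_flags = ["-O3", "--use_fast_math"]  # illustrative, not the actual flags

# Concatenation, as currently written (flagged by RUF005):
flags_concat = nvcc_flags + ["-DENABLE_BF16", "-DENABLE_FP4"]

# Iterable unpacking, as Ruff suggests; same result, one list display:
flags_spread = [*nvcc_flags, "-DENABLE_BF16", "-DENABLE_FP4"]
```

The spread form builds the list in a single expression, which Ruff considers clearer and marginally cheaper than allocating an intermediate list for the `+`.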
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
- include/flashinfer/gemm/fp4_gemm_template_sm103.h
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/gemm/fp4_gemm_template_sm103.h
🧰 Additional context used
📓 Path-based instructions (2)
flashinfer/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/**/*.py: Use the `@functools.cache` decorator on Python API functions to implement module-level caching and avoid recompilation.
Use the `@flashinfer_api` decorator for debugging API calls; enable via the `FLASHINFER_LOGLEVEL` environment variable (0=off, 1=basic, 3=detailed, 5=with stats).

Files:
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
flashinfer/jit/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
flashinfer/jit/**/*.py: JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec.
Use the `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`.
Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer).
Files:
flashinfer/jit/gemm/core.py
🧠 Learnings (6)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Use `gen_jit_spec()` function to return a properly configured JitSpec from module generators with appropriate `sources` and `extra_cuda_cflags`
Applied to files:
- flashinfer/gemm/gemm_base.py
- flashinfer/jit/gemm/core.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : JIT module generators in `flashinfer/jit/` must follow the pattern: compute URI → create directory → (optional) render Jinja template → copy sources → return JitSpec
Applied to files:
flashinfer/jit/gemm/core.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/jit/**/*.py : Specify `supported_major_versions` in JitSpec to restrict kernel compilation to supported GPU architectures (e.g., SM versions 9, 10, 11, 12 for Hopper/newer)
Applied to files:
flashinfer/jit/gemm/core.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/aot.py : Register new operations in `flashinfer/aot.py` by calling the `gen_*_module()` function for AOT (Ahead-Of-Time) pre-compilation support
Applied to files:
flashinfer/jit/gemm/core.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
Applied to files:
flashinfer/jit/gemm/core.py
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
flashinfer/jit/gemm/core.py
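The probing pattern from that learning, transposed into a Python sketch for illustration (the real code is the C++ `getWorkspaceSizeImpl`; `probe_workspace_size`, the config dicts, and the SMEM threshold here are hypothetical):

```python
def probe_workspace_size(config: dict) -> int:
    # Hypothetical probe: configurations with oversized K tiles are
    # assumed to exceed the SMEM budget and raise, as some CUTLASS
    # configurations legitimately do.
    if config["tile_k"] > 512:
        raise RuntimeError("SMEM exceeds maximum allowed")
    return config["tile_k"] * 1024

def max_workspace_size(configs: list) -> int:
    best = 0
    for cfg in configs:
        try:
            best = max(best, probe_workspace_size(cfg))
        except RuntimeError:
            # Swallow errors when SMEM exceeds maximum allowed:
            # such configurations are simply skipped during probing.
            pass
    return best

configs = [{"tile_k": 256}, {"tile_k": 768}, {"tile_k": 512}]
print(max_workspace_size(configs))  # 524288
```

The point of the documented comment is to make clear that the silent `except` is deliberate, not an overlooked error path.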
🪛 Ruff (0.14.10)
flashinfer/jit/gemm/core.py
157-161: Consider `[*nvcc_flags, "-DENABLE_BF16", "-DENABLE_FP4"]` instead of concatenation
Replace with `[*nvcc_flags, "-DENABLE_BF16", "-DENABLE_FP4"]`
(RUF005)
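The RUF005 suggestion replaces list concatenation with iterable unpacking, which avoids `+` on a literal and reads more uniformly when mixing existing flags with new ones; a minimal illustration (the flag values are just examples):

```python
nvcc_flags = ["-O3", "--use_fast_math"]

# Concatenation (what RUF005 flags):
flags_concat = nvcc_flags + ["-DENABLE_BF16", "-DENABLE_FP4"]

# Preferred: iterable unpacking inside the list literal.
flags_unpack = [*nvcc_flags, "-DENABLE_BF16", "-DENABLE_FP4"]

# Both produce the same list.
assert flags_concat == flags_unpack
print(flags_unpack)
```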
🔇 Additional comments (4)
flashinfer/jit/gemm/core.py (1)
97-166: LGTM! SM103 module generator correctly implements dual-configuration strategy.

The implementation properly generates SM103-specific optimizations alongside fallback configurations by using both `fp4_gemm_cutlass_sm103.jinja` (with larger K-dimension tiles: 768) and `fp4_gemm_cutlass.jinja` (standard tiles). This approach aligns with the PR objectives of providing SM103-specific schedulers while maintaining compatibility.

The separate directory `gen_gemm_sm103_cutlass_fp4` correctly addresses the previous review concern about file collisions.

flashinfer/gemm/gemm_base.py (3)
525-531: LGTM! SM103 module accessor correctly implemented.

The function properly builds and loads the SM103-specific FP4 GEMM module with a correct docstring and caching. The implementation follows the established pattern from the SM100 and SM120 variants.

542-554: LGTM! SM103 routing logic correctly implemented.

The updated function properly routes to the SM103-specific module when `sm_minor == 3` (compute capability 10.3), while maintaining backward compatibility for SM100/110. The conditional logic clearly separates the three variants (SM10x with/without SM103, SM12x).
2288-2295: LGTM! Compute capability extraction correctly updated.

The code now properly extracts both major and minor compute capability values and passes them to the module selector, enabling correct routing to SM103-specific kernels when `minor == 3`.
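The routing behavior these comments describe amounts to a dispatch on the `(major, minor)` compute capability pair; a simplified sketch (the function name and returned labels are illustrative, not the actual `gemm_base.py` symbols):

```python
def select_fp4_module(major: int, minor: int) -> str:
    # SM10x family: minor == 3 selects the SM103-optimized kernels
    # (compute capability 10.3); other minors (SM100/SM110) keep
    # the existing SM100 path.
    if major == 10:
        return "sm103" if minor == 3 else "sm100"
    # SM12x family uses its own variant.
    if major == 12:
        return "sm120"
    raise ValueError(f"FP4 GEMM unsupported on sm_{major}{minor}")

print(select_fp4_module(10, 3))  # sm103
print(select_fp4_module(10, 0))  # sm100
print(select_fp4_module(12, 0))  # sm120
```

Keeping the SM103 branch inside the SM10x case preserves the fallback behavior for other Blackwell variants.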
IwakuraRein
left a comment
LGTM. Thanks for the contributions!
|
/bot run |
|
Thanks! I'll also review, but today might be hard |
|
[FAILED] Pipeline #41923518: 14/20 passed |
aleozlx
left a comment
LGTM as well, but I want to give some time for other comments to be resolved
|
/bot run |
|
[CANCELING] Pipeline #43135280: canceled |
bkryu
left a comment
LGTM. Unit tests are also coming back as passing.
[Perf][Feature] Add SM103-specific schedulers for NVFP4 CUTLASS kernels (flashinfer-ai#2303)

## Summary

This PR adds new template specializations for SM103 NVFP4 CUTLASS GEMM kernels using architecture-specific tile shapes, cluster shapes, and schedulers.

## Motivation

SM103 specifications show a higher NVFP4-over-BF16 speedup ratio than B200 (6× vs. 4×), but current kernels remain far from this limit. This PR introduces SM103-optimized templates to improve the achieved performance on this architecture.

The performance gains are more pronounced at larger batch sizes, while the previous SM100 configurations remain preferable in other cases. For this reason, SM103-specific configurations were added alongside the existing ones rather than replacing them, and the optimal configuration is automatically selected as part of the autotuning process.

## Performance results examples

Llama-3.1-70B, N=8192 K=28672, BF16 vs NVFP4 GEMMs TFLOP/s:

| Batch Size | Torch BF16 | NVFP4 Before | NVFP4 After |
|-----------:|------------:|-------------:|------------:|
| 8 | 50.418336 | 110.598008 | 124.005817 |
| 16 | 99.350151 | 219.649654 | 260.502226 |
| 32 | 193.884850 | 445.840601 | 519.291059 |
| 64 | 385.790757 | 978.451544 | 1011.614080 |
| 128 | 692.915989 | 2072.797941 | 2076.017433 |
| 256 | 1211.413202 | 3817.738538 | 3868.924511 |
| 512 | 1464.015616 | 5141.532768 | 5503.664311 |
| 1024 | 1600.983748 | 5659.831320 | 6341.013002 |
| 2048 | 1625.639619 | 5991.840134 | 6630.757403 |
| 4096 | 1602.978834 | 6160.806595 | 6898.878407 |
| 8192 | 1691.174722 | 5939.220913 | 6653.915111 |
| 16384 | 1688.224044 | 5926.519222 | 6595.387600 |
| 24576 | 1706.774619 | 5905.301100 | 6617.486211 |
| 32768 | 1678.225402 | 5913.806010 | 6592.762922 |

Llama-3.1-70B, N=8192 K=8192, BF16 vs NVFP4 GEMMs TFLOP/s:

| Batch Size | Torch BF16 | NVFP4 Before | NVFP4 After |
|-----------:|------------:|-------------:|------------:|
| 8 | 47.780647 | 124.774241 | 124.760324 |
| 16 | 95.671633 | 249.502165 | 249.131125 |
| 32 | 189.224266 | 497.991489 | 497.277802 |
| 64 | 373.320912 | 993.731451 | 989.446041 |
| 128 | 707.096994 | 1959.258553 | 1970.430179 |
| 256 | 1126.908748 | 4037.558967 | 4159.515720 |
| 512 | 1407.884777 | 5045.981883 | 4958.698763 |
| 1024 | 1491.747576 | 5654.694949 | 5614.133004 |
| 2048 | 1546.322959 | 5898.291400 | 6204.813491 |
| 4096 | 1610.656216 | 6312.498418 | 6605.534723 |
| 8192 | 1623.748353 | 6392.424296 | 6803.660138 |
| 16384 | 1627.947338 | 6438.789701 | 6947.466217 |
| 24576 | 1614.582791 | 6469.307368 | 6991.331576 |
| 32768 | 1617.601164 | 6515.312895 | 7010.746651 |

## Summary by CodeRabbit

* **New Features**
  * Added support for NVIDIA SM103 GPU architecture in FP4 operations with specialized kernel configurations and optimized launcher implementations, extending hardware compatibility and enabling efficient computation on additional GPU variants.

Signed-off-by: LopezCastroRoberto <robertol.c510@gmail.com>
Co-authored-by: yzh119 <zihaoy@nvidia.com>