Optimize quantization function for large problem sizes #2343

Merged
yzh119 merged 5 commits into flashinfer-ai:main from Shunkangz:new_quant
Jan 23, 2026

Conversation

@Shunkangz
Contributor

@Shunkangz Shunkangz commented Jan 13, 2026

📌 Description

Optimize the quantization function for large problem sizes by using TMA and warp-specialized programming.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added a high-throughput TMA-backed quantization path for large tensors to improve scalability and throughput.
  • Refactor

    • Reworked quantization internals to a template-driven, per-thread packing architecture, unifying element packing and scale-factor handling across FP16/FP8→FP4 and MXFP8 paths.
  • Documentation

    • Updated comments and dispatch logic to describe templated and TMA-backed quantization flows.
  • Tests

    • Extended quantization test coverage with a larger tensor shape.


@coderabbitai
Contributor

coderabbitai bot commented Jan 13, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Adds a comprehensive quantization utilities header, templatizes device quantization kernels around per-thread element packing and SF-vector sizes, and adds host-side cuTensorMap/TMA dispatch for high-throughput FP4/MXFP8 quantization while updating internal call sites and tests.

Changes

Cohort / File(s) Summary
New utilities header
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
Adds DstVec type-traits, TypeConverter/PackedVec templates, fp32→FP4/FP8 conversion helpers, clamp/quantize helpers, SF-offset calculators, silu utilities, TMA kernel traits, indexing/threading helpers, and TMA shared-memory sizing.
Quantization kernel refactor
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
Replaces legacy DstVec helpers with templated PackedVec<Type, ELTS_PER_THREAD> (PackedVecT), introduces Barrier alias, propagates ELTS_PER_THREAD and SF-vector template params through cvt_* helpers, updates static_asserts, loads, and SF offset/getters.
Host TMA dispatch & helpers
csrc/nv_internal/cpp/kernels/quantization.cu
Adds make_3d_tma_copy_desc and launchFP4QuantizationTma templates, pins SF_VEC_SIZE for MXFP8 path, and conditionally dispatches TMA-backed quantization for large m (SF_VEC_SIZE == 16) while retaining non-TMA fallback.
Call-site template updates
csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
Internal quantizePackedFPXValue calls updated to pass CVT_ELTS_PER_THREAD to cvt_* helpers (FP4/MXFP8 paths).
Tests
tests/utils/test_fp4_quantize.py
Expands SHAPES and BATCH_SHAPES to include (2048, 2048) cases for larger input coverage.

Sequence Diagram(s)

sequenceDiagram
    participant Host as Host
    participant Map as CUtensorMap/TMA
    participant GPU as QuantKernel
    Host->>Host: choose path (TMA vs non-TMA) based on SF_VEC_SIZE and m
    Host->>Map: make_3d_tma_copy_desc(...) (build descriptor)
    Host->>GPU: launchFP4QuantizationTma(...) / cuLaunchKernelEx with CUtensorMap
    GPU->>Map: load input via TMA descriptor
    GPU->>GPU: perform packed quantization using PackedVecT and cvt_warp_* helpers
    GPU->>Host: write quantized output and SF outputs to global memory

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • joker-eph
  • djmmoss
  • yongwww
  • bkryu
  • cyx-6

Poem

🐰 I hop through PackedVecT and tidy arrays,
I count SF vectors in quickspring rays,
TMA maps like carrots neatly laid,
Kernels nibble numbers, no time to fade,
A rabbit cheers: templates bloom today!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main change: optimizing quantization for large problem sizes using TMA and warp specialization.
Description check ✅ Passed The description explains the optimization approach (TMA and warp specialization), indicates tests are added/updated and passing, and confirms pre-commit checks are complete.



@gemini-code-assist
Contributor

Summary of Changes

Hello @Shunkangz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on optimizing the quantization function for large problem sizes by significantly refactoring the existing CUDA kernel code. Key helper functions, data structures, and conversion routines have been extracted into a new utility header, promoting better modularity. Furthermore, the quantization kernels have been generalized to support more flexible element-per-thread configurations, which is crucial for leveraging warp-specialized programming and potentially Tensor Memory Accelerator (TMA) features to achieve performance gains on large-scale computations.

Highlights

  • Code Refactoring and Modularity: A significant portion of quantization-related helper functions, data structures, and kernels have been moved from quantization.cuh to a new dedicated utility file, quantization_utils.cuh. This improves code organization and reusability.
  • Enhanced Genericity for Warp-Specialized Programming: The PackedVec template and several quantization kernels (cvt_warp_fp16_to_fp4, cvt_warp_fp8_to_fp4, cvt_warp_fp16_to_mxfp8) have been made more generic by introducing an ELTS_PER_THREAD template parameter. This allows for more flexible vector processing sizes, supporting advanced warp-specialized programming techniques for better performance.
  • Improved Scale Factor Offset Calculation: The cvt_quant_to_fp4_get_sf_out_offset function now includes CVT_FP4_SF_VEC_SIZE as a template parameter, further enhancing the flexibility and adaptability of scale factor calculations.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors quantization helper functions into a new quantization_utils.cuh header file, which is a great move for modularity. The changes also generalize several functions and data structures using templates to support variable vector sizes, enabling optimizations like warp-specialized programming and TMA. The code is well-structured and the changes significantly improve flexibility. My review includes a few suggestions to enhance comment clarity and accuracy for better long-term maintainability.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh:
- Around line 32-35: The primary template DstVec<T, NUM_ELTS> uses an ill-formed
static_assert with a string literal; replace it with a dependent-false pattern
so the assertion only fires when the template is instantiated (e.g., introduce a
template variable like dependent_false_v<T> that is constexpr false and use
static_assert(dependent_false_v<T>, "not implemented.") in the DstVec primary
template) so compilation succeeds until a specialization is required.
🧹 Nitpick comments (4)
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (4)

24-24: Avoid using namespace directive in header files.

This using namespace directive will propagate to all translation units that include this header, potentially causing name collisions or unexpected symbol resolution. Consider qualifying names explicitly (e.g., tensorrt_llm::common::cuda_clamp) or moving the directive inside function bodies where needed.


133-137: Consider adding runtime protection or clearer documentation for unsupported architectures.

The fallback return 0 for __CUDA_ARCH__ < 1000 could lead to silent incorrect results if these functions are inadvertently called on older GPU architectures. While the calling code likely has architecture guards, adding a comment noting this dependency or an assert would help future maintainers.

Based on learnings, for performance-critical hot paths, leaving comments explaining special algorithmic choices and potential alternatives is recommended.

Also applies to: 159-163, 199-203


229-232: Add comment explaining the exp == 0 edge case handling.

The special case returning 1 when exp == 0 (instead of computing exp2f(127)) deserves a brief comment explaining the rationale, as this deviates from the mathematical formula and may confuse future readers.

💡 Suggested documentation
 __device__ __forceinline__ float exp2f_rcp(uint8_t exp) {
   constexpr uint32_t FP32_EXPONENT_BIAS = 127;
+  // When exp == 0 (smallest positive scale factor), return 1.0 to avoid
+  // computing 2^127 which would cause overflow in subsequent operations.
   return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(exp));
 }

469-470: Consider documenting the magic constant 448.0f.

The constant 448.0f represents the maximum representable value in E4M3 format. Adding a brief comment or named constant would improve readability.

💡 Suggested improvement
+  // 448.0f is the maximum representable value in E4M3 FP8 format
   float SFValue = vecMax * reciprocal_approximate_ftz(448.0f);
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2062dec and 2d42135.

📒 Files selected for processing (2)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🔇 Additional comments (11)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (4)

25-25: LGTM!

The new include for quantization_utils.cuh properly brings in the extracted helper functions and type traits.


175-204: LGTM!

The PackedVecT alias with the accompanying static_assert provides clean, type-safe access to the packed vector type while ensuring compile-time validation of size consistency.


292-304: LGTM!

The template parameters SF_VEC_SIZE, ELTS_PER_THREAD, and UE8M0_SF are correctly propagated to the quantization utility functions, maintaining consistency with the kernel's configuration.


326-329: LGTM!

The cvt_fp16_to_fp4_expert kernel correctly uses the CVT_FP4_* constants with the templated utility functions, maintaining consistent parameterization throughout the quantization pipeline.

Also applies to: 380-405

csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (7)

37-66: LGTM!

The DstVec specializations correctly map source types to appropriate destination vector types, with full specializations taking precedence for known types and partial specializations providing fallbacks with size validation.


70-108: LGTM!

The helper functions are cleanly extracted with appropriate template parameters. The USE_SMEM optimization avoids redundant clamping for data already processed through shared memory.


237-271: LGTM!

The TypeConverter and PackedVec templates provide clean type mappings with compile-time size validation. The FP8 specialization correctly handles the different underlying type.


275-359: LGTM!

The cvt_warp_fp16_to_fp4 function is well-designed with:

  • Compile-time return type selection based on element count
  • Proper warp-level reduction for computing the scale factor
  • Clean handling of both UE8M0 and E4M3 scale factor formats
  • Type-safe conversion dispatch for half vs bfloat16 inputs

361-443: LGTM!

The FP8-to-FP4 conversion correctly implements the two-stage quantization (FP8→FP16→FP4) with appropriate scale factor handling. The static_assert enforces the required 16-element constraint.


510-628: LGTM!

The scale factor offset calculations correctly implement the swizzled tile layouts (128x4 and 8x4) and linear layout. The thread gating ensures only one thread per scale factor group performs the write, avoiding race conditions.


676-699: LGTM!

The silu and silu_and_mul functions are correctly implemented. The in-place modification of x_vec in silu_and_mul is efficient for the fused gate+activation pattern common in LLM architectures.

@yzh119
Collaborator

yzh119 commented Jan 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !239 has been created, and the CI pipeline #41612397 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh:
- Around line 32-35: The primary template DstVec<T, NUM_ELTS> uses an invalid
static_assert with a string literal; replace it with a dependent-false
static_assert so the assertion only fires for instantiations (e.g., introduce a
template helper like dependent_false_v<T> and use
static_assert(dependent_false_v<T>, "DstVec not implemented for this type/size")
in the primary template), ensuring the specializations remain usable.
🧹 Nitpick comments (5)
csrc/nv_internal/cpp/kernels/quantization.cu (3)

219-223: Redundant fallback branch for unknown types.

The else branch at line 221-222 defaults to CU_TENSOR_MAP_DATA_TYPE_UINT8 for any type not explicitly handled. However, since this is a template function and only instantiated with half, __nv_bfloat16, and __nv_fp8_e4m3, the else branch duplicates line 220's FP8 case. Consider adding a static_assert to catch unintended instantiations.

Suggested improvement
   } else if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
     data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
   } else {
-    data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+    static_assert(sizeof(T) == 0, "Unsupported type for TMA tensor map");
   }

300-302: Missing error check after cudaLaunchKernelEx.

The return value of cudaLaunchKernelEx is not checked. While other kernel launches in this file also lack explicit error checks (relying on synchronization elsewhere), for consistency with the defensive error handling style shown for cuTensorMapEncodeTiled, consider adding a check here—especially since TMA kernel failures may be harder to diagnose.

Suggested fix
-  cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
-                     reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
-                     layout, tensor_map);
+  cudaError_t err = cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
+                     reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
+                     layout, tensor_map);
+  TLLM_CHECK_WITH_INFO(err == cudaSuccess, "Failed to launch TMA quantization kernel");

312-321: Document the threshold rationale for TMA path selection.

The hardcoded threshold m > 1024 determines when to use the TMA-optimized path. Per the learnings, for performance-critical hot paths, leave comments explaining special algorithmic choices. Consider adding a brief comment explaining why 1024 was chosen (e.g., amortizing TMA overhead, occupancy considerations).

Suggested documentation
     // Use TMA kernel for large m (high throughput mode)
     // Use if constexpr for SF_VEC_SIZE to avoid instantiating TMA kernel for unsupported sizes
     if constexpr (SF_VEC_SIZE == 16) {
+      // TMA path provides better throughput for larger problem sizes where the setup cost
+      // is amortized. Threshold of 1024 rows was empirically determined.
       if (m > 1024) {
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (1)

771-787: Consider adding a comment explaining the SWIZZLE_128B indexing math.

The swizzle index calculation in load_input_vec is non-trivial. Per learnings, performance-critical hot paths should include comments explaining algorithmic choices. A brief explanation of how col_after_swizzle relates to the 128-byte swizzle pattern would help future maintainers.

Suggested documentation
   template <typename PackedVecT>
   __device__ static PackedVecT load_input_vec(float4 const* base_float4, int threadRowIdxLocal,
                                               int threadColIdxLocal) {
-    // Compute swizzled indices for SWIZZLE_128B
+    // Compute swizzled indices for SWIZZLE_128B:
+    // - Each row is 64 elements = 128 bytes (8 float4s)
+    // - SWIZZLE_128B XORs the column index with the row index within the 128-byte tile
+    // - This provides conflict-free access for warp-wide loads
     int swizzled_col = threadColIdxLocal * 2;  // Each thread reads 2 float4s
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

487-489: Potentially redundant threadOutOffset computation.

Line 489 recomputes threadOutOffset after advance_col(), but this value is immediately overwritten at lines 432-434 in the next loop iteration. Consider removing this line to avoid confusion.

Suggested cleanup
           // Update column offset
           tidx.advance_col();
-          threadOutOffset = (rowOffset + tidx.colIdx) >> 4;

           if (laneIdx == 0) {
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2d42135 and 7ba41a6.

📒 Files selected for processing (3)
  • csrc/nv_internal/cpp/kernels/quantization.cu
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🧰 Additional context used
📓 Path-based instructions (1)
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/nv_internal/cpp/kernels/quantization.cu
🧠 Learnings (6)
📓 Common learnings
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/cpp/kernels/quantization.cu
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/cpp/kernels/quantization.cu
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • csrc/nv_internal/cpp/kernels/quantization.cu
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • csrc/nv_internal/cpp/kernels/quantization.cu
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🧬 Code graph analysis (1)
csrc/nv_internal/cpp/kernels/quantization.cu (4)
csrc/trtllm_gemm_runner.cu (8)
  • m (111-126)
  • m (111-111)
  • m (128-179)
  • m (128-130)
  • m (181-236)
  • m (181-181)
  • m (238-250)
  • m (238-238)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1)
  • input (494-494)
csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h (1)
  • layout (29-47)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
🔇 Additional comments (13)
csrc/nv_internal/cpp/kernels/quantization.cu (2)

17-18: LGTM: Required includes for TMA functionality.

The added headers <cuda.h> and <cudaTypedefs.h> are necessary for the CUtensorMap type and the PFN_cuTensorMapEncodeTiled_v12000 function pointer typedef used in the TMA tensor map creation.


98-99: Acknowledge the TODO for MXFP8 TMA support.

The TODO indicates TMA quantization is not yet supported for MXFP8 (SF_VEC_SIZE=32). This is a reasonable scope limitation for this PR.

csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (5)

1-27: LGTM: New quantization utilities header with proper structure.

The header is well-organized with clear sections for different functionality. The #pragma once guard and namespace organization follow the project conventions.


254-270: LGTM: PackedVec template provides type-safe vector abstraction.

The template correctly handles different element types with appropriate packed representations. The static_assert ensures size invariants are maintained at compile time.


276-361: LGTM: FP16-to-FP4 quantization with proper warp-level reduction.

The implementation correctly:

  • Computes local max across packed elements
  • Uses warp shuffles for cross-thread reduction based on CVT_NUM_THREADS_PER_SF
  • Handles both UE8M0 and UE4M3 scale factor formats
  • Guards against division by zero when vecMax == 0

863-869: LGTM: Shared memory size calculation for TMA kernel.

The helper correctly computes total shared memory requirements including barrier storage and data staging area.


133-136: The calling code is fully guarded at multiple architectural levels. All kernels and wrapper functions that use fp32_vec_to_e2m1() and fp32_vec_to_e4m3() are conditionally compiled with #if __CUDA_ARCH__ >= 1000, so these conversion functions will never be called on pre-Blackwell GPUs. The return 0 fallback in the #else branches serves as a safety measure and is an intentional design choice (evidenced by the commented-out static_assert). No data corruption risk exists.

Likely an incorrect or invalid review comment.

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (6)

17-31: LGTM: Required includes and type aliases for TMA and barrier support.

The new includes for cutlass/arch/barrier.h and cute/arch/copy_sm90_tma.hpp are necessary for the TMA-based quantization kernel. The Barrier type alias simplifies usage.


206-208: LGTM: Updated to use template-driven PackedVecT.

The kernel correctly uses the new PackedVecT type alias and maintains the size assertion for safety.


398-398: Verify elect_one_sync() is the intended pattern for single-producer selection.

The producer warp uses elect_one_sync() to ensure only one thread issues TMA loads. This is correct for TMA operations which require a single issuing thread.


516-519: LGTM: Expert quantization kernel updated to use new template API.

The kernel correctly uses PackedVecT and the updated function call signatures with explicit template parameters.


347-347: Good: Compile-time constraint for TMA kernel SF_VEC_SIZE.

The static_assert correctly enforces that TMA quantization only supports SF_VEC_SIZE = 16, preventing accidental misuse at compile time.


381-393: The transaction size at line 409 is correct. The arrive_and_expect_tx(SMEM_STAGE_SIZE * sizeof(SmemType)) accurately reflects the full data loaded per TMA operation. The TMA_LOAD_3D loads across all 8 consumer warps' columns (NUM_CONSUMER_WARPS * TMA_COL_TILE), which matches the entire SMEM_STAGE_SIZE allocation (SMEM_ROWS * SMEM_COLS = TMA_ROW_TILE * 8 * TMA_COL_TILE). The column loop increment and consumer warp access patterns both confirm this full-stage load per TMA transaction.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #41612397: 1/20 passed

@Shunkangz
Contributor Author

@yzh119 I refactored the code format and added the related unit tests. Could you please help me run the CI again? Thank you.


@yzh119 yzh119 left a comment


LGTM. We should also create benchmarking scripts for FP4 TMA quantization across different problem sizes.

@yzh119
Collaborator

yzh119 commented Jan 15, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !239 has been updated with latest changes, and the CI pipeline #41787430 is currently running. I'll report back once the pipeline job completes.

@yzh119
Collaborator

yzh119 commented Jan 15, 2026

@Shunkangz would you mind taking a look at the compilation errors:

[2026-01-14T07:57:11.856Z] /workspace/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh(998): error: cannot determine which instance of function template "tensorrt_llm::kernels::cvt_warp_fp16_to_fp4" is intended
[2026-01-14T07:57:11.856Z]                    ? &cvt_warp_fp16_to_fp4<GemmOutputType, VecSize, false>
[2026-01-14T07:57:11.856Z]                      ^
[2026-01-14T07:57:11.856Z]           detected during:
[2026-01-14T07:57:11.856Z]             instantiation of "auto tensorrt_llm::kernels::cutlass_kernels::quantizePackedFPXValue<GemmOutputType,QuantizedType,ComputeElem,VecSize>(ComputeElem &, float, int64_t, int64_t, int64_t, int64_t, int64_t, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF *, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType) [with GemmOutputType=half, QuantizedType=__nv_fp4_e2m1, ComputeElem=cutlass::Array<float, 8, true>, VecSize=16]" at line 2190
[2026-01-14T07:57:11.856Z]             instantiation of "void tensorrt_llm::kernels::cutlass_kernels::doActivationKernel<T,GemmOutputType,ScaleBiasType,ActFn,BlockScalingType>(T *, const GemmOutputType *, const float *, const ScaleBiasType *, __nv_bool, const int64_t *, int, int64_t, const float *, __nv_bool, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF *, tensorrt_llm::kernels::cutlass_kernels::ActivationParams) [with T=__nv_fp4_e2m1, GemmOutputType=half, ScaleBiasType=half, ActFn=tensorrt_llm::kernels::cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, BlockScalingType=tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4]" at line 2330
[2026-01-14T07:57:11.856Z]             instantiation of "void tensorrt_llm::kernels::cutlass_kernels::doActivation(T *, const GemmOutputType *, const float *, const ScaleBiasType *, __nv_bool, const int64_t *, int, int64_t, int64_t, tensorrt_llm::kernels::cutlass_kernels::ActivationParams, const tensorrt_llm::kernels::cutlass_kernels::QuantParams &, __nv_bool, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF *, __nv_bool, cudaStream_t) [with T=__nv_fp4_e2m1, GemmOutputType=half, ScaleBiasType=half]" at line 3061
[2026-01-14T07:57:11.856Z]             instantiation of "void tensorrt_llm::kernels::cutlass_kernels::CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::gemm1(tensorrt_llm::kernels::cutlass_kernels::MoeGemmRunner<T, WeightType, OutputType, tensorrt_llm::kernels::cutlass_kernels::CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::ScaleBiasType> &, tensorrt_llm::kernels::cutlass_kernels::CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::DeepSeekBlockScaleGemmRunner *, const T *, T *, void *, const int64_t *, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput, const WeightType *, const tensorrt_llm::kernels::cutlass_kernels::CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::ScaleBiasType *, const int64_t *, const tensorrt_llm::kernels::cutlass_kernels::CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::ScaleBiasType *, const float *, const float *, const tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF *, tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput::ElementSF *, tensorrt_llm::kernels::cutlass_kernels::QuantParams, int64_t, int64_t, int64_t, int64_t, int, tensorrt_llm::kernels::cutlass_kernels::ActivationParams, const float **, __nv_bool, cudaStream_t, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, __nv_bool, int *, int *, __nv_bool) [with T=__nv_fp4_e2m1, WeightType=__nv_fp4_e2m1, OutputType=half, InputType=__nv_fp4_e2m1, BackBoneType=half, Enable=void]" at line 3719

@Shunkangz Shunkangz requested a review from jimmyzho as a code owner January 15, 2026 08:27
@Shunkangz
Contributor Author

cvt_warp_fp16_to_fp4

Hi @yzh119 , I fixed this issue. Please help me run the CI again. Thanks.

@bkryu
Collaborator

bkryu commented Jan 15, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !239 has been updated with latest changes, and the CI pipeline #41820142 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator

bkryu commented Jan 16, 2026

Hi @Shunkangz, the unit test results are showing a high rate of mismatches in tests/moe/test_trtllm_gen_fused_moe.py on SM100 & 103 devices. Can you look into this?

@flashinfer-bot
Collaborator

[FAILED] Pipeline #41820142: 11/20 passed

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@csrc/nv_internal/cpp/kernels/quantization.cu`:
- Around line 296-299: The call to cudaFuncSetAttribute(kernel_instance,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) is unchecked; capture
its return value (e.g., cudaError_t rc = cudaFuncSetAttribute(...)) and handle
failures before proceeding to create cudaLaunchConfig_t config or launching the
kernel: on non-success return log or propagate the error (including smem_size
and kernel_instance identifiers) and abort/return so the kernel is not launched
with insufficient dynamic shared memory.
- Around line 187-230: The template make_3d_tma_copy_desc currently silently
maps any unknown T to CU_TENSOR_MAP_DATA_TYPE_UINT8; add a compile-time check so
only the supported types (half, __nv_bfloat16, __nv_fp8_e4m3) are allowed and
all other instantiations fail to compile. Implement this by introducing a
constexpr predicate (e.g., is_supported_type<T>) or individual constexpr
booleans and then a static_assert near the top of make_3d_tma_copy_desc that
references T and emits a clear message like "Unsupported data type for
cuTensorMapDataType" if false; keep the existing explicit mapping to
CU_TENSOR_MAP_DATA_TYPE_UINT8 only for __nv_fp8_e4m3 and remove the catch-all
else branch. Ensure the static_assert and mapping touch the variables data_type
and template T so unsupported types fail at compile time.
🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

419-423: Tie consumer-warp guard to Traits::NUM_CONSUMER_WARPS.
Line 419 hardcodes <= 8, which will break if Traits changes. Prefer the trait constant to keep the kernel self-consistent.

♻️ Proposed change
-  } else if (warpIdx >= 1 and warpIdx <= 8) {
+  } else if (warpIdx >= 1 and warpIdx <= NUM_CONSUMER_WARPS) {

@Shunkangz
Contributor Author

Hi @Shunkangz, the unit test results are showing a high rate of mismatches in tests/moe/test_trtllm_gen_fused_moe.py on SM100 & 103 devices. Can you look into this?

Hi @bkryu , thank you for pointing this out. There are two potential issues. First, I fixed the TMA loading issue with batch sizes larger than 1 and added the corresponding UT config. The other problem is a padding issue when N is not divisible by 512; I want to propose another PR to fix that, and for now I let the code fall back to the original kernel. Does that make sense to you?

@bkryu
Collaborator

bkryu commented Jan 20, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !239 has been updated with latest changes, and the CI pipeline #42125407 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator

bkryu commented Jan 20, 2026

Hi @bkryu , thank you for pointing this out. There are two potential issues. First, I fixed the TMA loading issue with batch sizes larger than 1 and added the corresponding UT config. The other problem is a padding issue when N is not divisible by 512; I want to propose another PR to fix that, and for now I let the code fall back to the original kernel. Does that make sense to you?

Hi @Shunkangz, thanks for looking into this. If you mean falling back to the original kernel when N % 512 != 0 so that there are no functional issues and unit tests are passing while you prepare a followup PR, I'd say it should be fine. We don't want the main branch to be failing unit tests.

I have re-triggered the unit tests to see what state we are in now.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42125407: 3/20 passed

@Shunkangz
Contributor Author

Hi @bkryu , thank you for pointing this out. There are two potential issues. First, I fixed the TMA loading issue with batch sizes larger than 1 and added the corresponding UT config. The other problem is a padding issue when N is not divisible by 512; I want to propose another PR to fix that, and for now I let the code fall back to the original kernel. Does that make sense to you?

Hi @Shunkangz, thanks for looking into this. If you mean falling back to the original kernel when N % 512 != 0 so that there are no functional issues and unit tests are passing while you prepare a followup PR, I'd say it should be fine. We don't want the main branch to be failing unit tests.

I have re-triggered the unit tests to see what state we are in now.

Hi @bkryu , it seems that the 3 failed tests are not related to my change. Could you please help me double-check this? Thank you.

Collaborator

@bkryu bkryu left a comment


Thanks @Shunkangz, the failures are indeed unrelated and should have already been fixed in the main branch. LGTM

@yzh119 yzh119 merged commit aa8af85 into flashinfer-ai:main Jan 23, 2026
20 of 23 checks passed
@coderabbitai coderabbitai bot mentioned this pull request Feb 22, 2026