Optimize quantization function in large problem size #2343
yzh119 merged 5 commits into flashinfer-ai:main from
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds a comprehensive quantization utilities header, templatizes device quantization kernels around per-thread element packing and SF-vector sizes, and adds host-side cuTensorMap/TMA dispatch for high-throughput FP4/MXFP8 quantization while updating internal call sites and tests.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Host as Host
    participant Map as CUtensorMap/TMA
    participant GPU as QuantKernel
    Host->>Host: choose path (TMA vs non-TMA) based on SF_VEC_SIZE and m
    Host->>Map: make_3d_tma_copy_desc(...) (build descriptor)
    Host->>GPU: launchFP4QuantizationTma(...) / cuLaunchKernelEx with CUtensorMap
    GPU->>Map: load input via TMA descriptor
    GPU->>GPU: perform packed quantization using PackedVecT and cvt_warp_* helpers
    GPU->>Host: write quantized output and SF outputs to global memory
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

✅ Passed checks (2 passed)
Summary of Changes

Hello @Shunkangz, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request focuses on optimizing the quantization function for large problem sizes by significantly refactoring the existing CUDA kernel code. Key helper functions, data structures, and conversion routines have been extracted into a new utility header, promoting better modularity. Furthermore, the quantization kernels have been generalized to support more flexible element-per-thread configurations, which is crucial for leveraging warp-specialized programming and potentially Tensor Memory Accelerator (TMA) features to achieve performance gains on large-scale computations.

Highlights
Code Review
This pull request refactors quantization helper functions into a new quantization_utils.cuh header file, which is a great move for modularity. The changes also generalize several functions and data structures using templates to support variable vector sizes, enabling optimizations like warp-specialized programming and TMA. The code is well-structured and the changes significantly improve flexibility. My review includes a few suggestions to enhance comment clarity and accuracy for better long-term maintainability.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh:
- Around line 32-35: The primary template DstVec<T, NUM_ELTS> uses an ill-formed
static_assert with a string literal; replace it with a dependent-false pattern
so the assertion only fires when the template is instantiated (e.g., introduce a
template variable like dependent_false_v<T> that is constexpr false and use
static_assert(dependent_false_v<T>, "not implemented.") in the DstVec primary
template) so compilation succeeds until a specialization is required.
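A host-compilable sketch of the dependent-false pattern this instruction describes (the helper name `dependent_false_v` is the one suggested above; `Packed2f` and the `float` specialization are illustrative placeholders):

```cpp
#include <type_traits>

// Hypothetical helper: a variable template that is always false but depends
// on T, so the static_assert below is only evaluated on instantiation.
template <typename T>
inline constexpr bool dependent_false_v = false;

// Primary template: well-formed as written; the assert fires only when an
// unspecialized DstVec<T, N> is actually instantiated.
template <typename T, int NUM_ELTS>
struct DstVec {
  static_assert(dependent_false_v<T>, "DstVec not implemented for this type.");
};

// Example specialization (illustrative only): pack two floats.
struct Packed2f {
  float x, y;
};
template <>
struct DstVec<float, 2> {
  using Type = Packed2f;
};
```

With this shape, existing specializations keep compiling, and only an unhandled `DstVec<T, N>` instantiation fails with the intended message.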
🧹 Nitpick comments (4)
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (4)
24-24: Avoid `using namespace` directive in header files.

This `using namespace` directive will propagate to all translation units that include this header, potentially causing name collisions or unexpected symbol resolution. Consider qualifying names explicitly (e.g., `tensorrt_llm::common::cuda_clamp`) or moving the directive inside function bodies where needed.
133-137: Consider adding runtime protection or clearer documentation for unsupported architectures.

The fallback `return 0` for `__CUDA_ARCH__ < 1000` could lead to silent incorrect results if these functions are inadvertently called on older GPU architectures. While the calling code likely has architecture guards, adding a comment noting this dependency or an assert would help future maintainers. Based on learnings, for performance-critical hot paths, leaving comments explaining special algorithmic choices and potential alternatives is recommended.

Also applies to: 159-163, 199-203
229-232: Add comment explaining the `exp == 0` edge case handling.

The special case returning `1` when `exp == 0` (instead of computing `exp2f(127)`) deserves a brief comment explaining the rationale, as this deviates from the mathematical formula and may confuse future readers.

💡 Suggested documentation

```diff
 __device__ __forceinline__ float exp2f_rcp(uint8_t exp) {
   constexpr uint32_t FP32_EXPONENT_BIAS = 127;
+  // When exp == 0 (smallest positive scale factor), return 1.0 to avoid
+  // computing 2^127 which would cause overflow in subsequent operations.
   return (exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast<float>(exp));
 }
```
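A host-side sketch of this helper's arithmetic (assumption: it mirrors the device function discussed above, with the `exp == 0` case mapped to 1 so the code never evaluates 2^127):

```cpp
#include <cmath>

// Reciprocal of an E8M0-style scale: 2^(bias - exp), with exp == 0
// special-cased to 1.0f per the review comment above.
float exp2f_rcp_host(unsigned char exp) {
  constexpr unsigned FP32_EXPONENT_BIAS = 127;
  return (exp == 0) ? 1.0f
                    : std::exp2f(FP32_EXPONENT_BIAS - static_cast<float>(exp));
}
```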
469-470: Consider documenting the magic constant 448.0f.

The constant `448.0f` represents the maximum representable value in E4M3 format. Adding a brief comment or named constant would improve readability.

💡 Suggested improvement

```diff
+  // 448.0f is the maximum representable value in E4M3 FP8 format
   float SFValue = vecMax * reciprocal_approximate_ftz(448.0f);
```
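The named-constant variant could look like the following sketch (constant and function names are assumptions, not the PR's actual identifiers):

```cpp
// Largest finite value representable in FP8 E4M3.
constexpr float kE4M3MaxFinite = 448.0f;

// The scale factor maps the per-vector max onto the E4M3 representable
// range; a plain division stands in for reciprocal_approximate_ftz here.
float compute_sf_value(float vecMax) { return vecMax / kE4M3MaxFinite; }
```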
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:

- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Applied to files:

- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Applied to files:
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🔇 Additional comments (11)

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (4)

25-25: LGTM! The new include for `quantization_utils.cuh` properly brings in the extracted helper functions and type traits.

175-204: LGTM! The `PackedVecT` alias with the accompanying static_assert provides clean, type-safe access to the packed vector type while ensuring compile-time validation of size consistency.

292-304: LGTM! The template parameters `SF_VEC_SIZE`, `ELTS_PER_THREAD`, and `UE8M0_SF` are correctly propagated to the quantization utility functions, maintaining consistency with the kernel's configuration.

326-329: LGTM! The `cvt_fp16_to_fp4_expert` kernel correctly uses the `CVT_FP4_*` constants with the templated utility functions, maintaining consistent parameterization throughout the quantization pipeline. Also applies to: 380-405
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (7)

37-66: LGTM! The `DstVec` specializations correctly map source types to appropriate destination vector types, with full specializations taking precedence for known types and partial specializations providing fallbacks with size validation.

70-108: LGTM! The helper functions are cleanly extracted with appropriate template parameters. The `USE_SMEM` optimization avoids redundant clamping for data already processed through shared memory.

237-271: LGTM! The `TypeConverter` and `PackedVec` templates provide clean type mappings with compile-time size validation. The FP8 specialization correctly handles the different underlying type.

275-359: LGTM! The `cvt_warp_fp16_to_fp4` function is well-designed with:

- Compile-time return type selection based on element count
- Proper warp-level reduction for computing the scale factor
- Clean handling of both UE8M0 and E4M3 scale factor formats
- Type-safe conversion dispatch for half vs bfloat16 inputs

361-443: LGTM! The FP8-to-FP4 conversion correctly implements the two-stage quantization (FP8→FP16→FP4) with appropriate scale factor handling. The `static_assert` enforces the required 16-element constraint.

510-628: LGTM! The scale factor offset calculations correctly implement the swizzled tile layouts (128x4 and 8x4) and linear layout. The thread gating ensures only one thread per scale factor group performs the write, avoiding race conditions.

676-699: LGTM! The `silu` and `silu_and_mul` functions are correctly implemented. The in-place modification of `x_vec` in `silu_and_mul` is efficient for the fused gate+activation pattern common in LLM architectures.
/bot run
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh:
- Around line 32-35: The primary template DstVec<T, NUM_ELTS> uses an invalid
static_assert with a string literal; replace it with a dependent-false
static_assert so the assertion only fires for instantiations (e.g., introduce a
template helper like dependent_false_v<T> and use
static_assert(dependent_false_v<T>, "DstVec not implemented for this type/size")
in the primary template), ensuring the specializations remain usable.
🧹 Nitpick comments (5)
csrc/nv_internal/cpp/kernels/quantization.cu (3)
219-223: Redundant fallback branch for unknown types.

The `else` branch at lines 221-222 defaults to `CU_TENSOR_MAP_DATA_TYPE_UINT8` for any type not explicitly handled. However, since this is a template function and only instantiated with `half`, `__nv_bfloat16`, and `__nv_fp8_e4m3`, the else branch duplicates line 220's FP8 case. Consider adding a `static_assert` to catch unintended instantiations.

Suggested improvement

```diff
 } else if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
   data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
 } else {
-  data_type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+  static_assert(sizeof(T) == 0, "Unsupported type for TMA tensor map");
 }
```
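A host-compilable sketch of the full compile-time dispatch this suggests; the placeholder structs stand in for `half` / `__nv_bfloat16` / `__nv_fp8_e4m3`, and the enum is a stand-in for the driver's `CUtensorMapDataType`:

```cpp
#include <type_traits>

enum class TmaDataType { kFloat16, kBFloat16, kUint8 };

// Placeholder types so this compiles without CUDA headers.
struct half_t {};
struct bfloat16_t {};
struct fp8_e4m3_t {};

template <typename T>
constexpr TmaDataType tma_data_type() {
  if constexpr (std::is_same_v<T, half_t>) {
    return TmaDataType::kFloat16;
  } else if constexpr (std::is_same_v<T, bfloat16_t>) {
    return TmaDataType::kBFloat16;
  } else if constexpr (std::is_same_v<T, fp8_e4m3_t>) {
    // FP8 tensors are described to TMA as raw bytes.
    return TmaDataType::kUint8;
  } else {
    // Dependent condition: fires only for an unsupported instantiation.
    static_assert(sizeof(T) == 0, "Unsupported type for TMA tensor map");
  }
}
```

Any instantiation outside the three supported types now fails at compile time instead of silently mapping to UINT8.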
300-302: Missing error check after `cudaLaunchKernelEx`.

The return value of `cudaLaunchKernelEx` is not checked. While other kernel launches in this file also lack explicit error checks (relying on synchronization elsewhere), for consistency with the defensive error handling style shown for `cuTensorMapEncodeTiled`, consider adding a check here, especially since TMA kernel failures may be harder to diagnose.

Suggested fix

```diff
-  cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
-                     reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
-                     layout, tensor_map);
+  cudaError_t err = cudaLaunchKernelEx(&config, kernel_instance, b, m, n, n, input, SFScale,
+                                       reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput),
+                                       layout, tensor_map);
+  TLLM_CHECK_WITH_INFO(err == cudaSuccess, "Failed to launch TMA quantization kernel");
```
312-321: Document the threshold rationale for TMA path selection.

The hardcoded threshold `m > 1024` determines when to use the TMA-optimized path. Per the learnings, for performance-critical hot paths, leave comments explaining special algorithmic choices. Consider adding a brief comment explaining why 1024 was chosen (e.g., amortizing TMA overhead, occupancy considerations).

Suggested documentation

```diff
   // Use TMA kernel for large m (high throughput mode)
   // Use if constexpr for SF_VEC_SIZE to avoid instantiating TMA kernel for unsupported sizes
   if constexpr (SF_VEC_SIZE == 16) {
+    // TMA path provides better throughput for larger problem sizes where the setup cost
+    // is amortized. Threshold of 1024 rows was empirically determined.
     if (m > 1024) {
```

csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (1)
771-787: Consider adding a comment explaining the SWIZZLE_128B indexing math.

The swizzle index calculation in `load_input_vec` is non-trivial. Per learnings, performance-critical hot paths should include comments explaining algorithmic choices. A brief explanation of how `col_after_swizzle` relates to the 128-byte swizzle pattern would help future maintainers.

Suggested documentation

```diff
 template <typename PackedVecT>
 __device__ static PackedVecT load_input_vec(float4 const* base_float4, int threadRowIdxLocal,
                                             int threadColIdxLocal) {
-  // Compute swizzled indices for SWIZZLE_128B
+  // Compute swizzled indices for SWIZZLE_128B:
+  // - Each row is 64 elements = 128 bytes (8 float4s)
+  // - SWIZZLE_128B XORs the column index with the row index within the 128-byte tile
+  // - This provides conflict-free access for warp-wide loads
   int swizzled_col = threadColIdxLocal * 2;  // Each thread reads 2 float4s
```

csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)
487-489: Potentially redundant `threadOutOffset` computation.

Line 489 recomputes `threadOutOffset` after `advance_col()`, but this value is immediately overwritten at lines 432-434 in the next loop iteration. Consider removing this line to avoid confusion.

Suggested cleanup

```diff
   // Update column offset
   tidx.advance_col();
-  threadOutOffset = (rowOffset + tidx.colIdx) >> 4;
   if (laneIdx == 0) {
```
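The SWIZZLE_128B arithmetic referenced in the `load_input_vec` comment above can be modeled on the host. This is a sketch under the assumption of the canonical 128-byte swizzle (XOR of the 16-byte-chunk column index with the row index modulo 8; 8 float4 chunks make one 128-byte row), not the PR's exact code:

```cpp
// col_float4 is the float4 (16-byte chunk) index within a 128-byte row,
// so it ranges over 0..7; row is the shared-memory row index.
int swizzle_128b_col(int row, int col_float4) {
  return col_float4 ^ (row % 8);
}
```

Because XOR with a fixed value is a permutation of 0..7, each row still touches every bank group exactly once, which is what makes warp-wide loads conflict-free.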
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)

- csrc/nv_internal/cpp/kernels/quantization.cu
- csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
- csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🧰 Additional context used
📓 Path-based instructions (1)
csrc/**/*.cu
📄 CodeRabbit inference engine (CLAUDE.md)
Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers
Files:
csrc/nv_internal/cpp/kernels/quantization.cu
🧠 Learnings (6)
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.
Applied to files:
csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
🧬 Code graph analysis (1)

csrc/nv_internal/cpp/kernels/quantization.cu (4)

- csrc/trtllm_gemm_runner.cu (8): m (111-126), m (111-111), m (128-179), m (128-130), m (181-236), m (181-181), m (238-250), m (238-238)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws.h (1): input (494-494)
- csrc/nv_internal/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h (1): layout (29-47)
- csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1): enable_pdl (220-220)
🔇 Additional comments (13)
csrc/nv_internal/cpp/kernels/quantization.cu (2)
17-18: LGTM: Required includes for TMA functionality. The added headers `<cuda.h>` and `<cudaTypedefs.h>` are necessary for the `CUtensorMap` type and the `PFN_cuTensorMapEncodeTiled_v12000` function pointer typedef used in the TMA tensor map creation.

98-99: Acknowledge the TODO for MXFP8 TMA support. The TODO indicates TMA quantization is not yet supported for MXFP8 (SF_VEC_SIZE=32). This is a reasonable scope limitation for this PR.

csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh (5)

1-27: LGTM: New quantization utilities header with proper structure. The header is well-organized with clear sections for different functionality. The `#pragma once` guard and namespace organization follow the project conventions.

254-270: LGTM: PackedVec template provides type-safe vector abstraction. The template correctly handles different element types with appropriate packed representations. The `static_assert` ensures size invariants are maintained at compile time.

276-361: LGTM: FP16-to-FP4 quantization with proper warp-level reduction. The implementation correctly:

- Computes local max across packed elements
- Uses warp shuffles for cross-thread reduction based on `CVT_NUM_THREADS_PER_SF`
- Handles both UE8M0 and UE4M3 scale factor formats
- Guards against division by zero when `vecMax == 0`
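The warp-shuffle max reduction described in the bullet above can be modeled on the host. This sketch replaces `__shfl_xor_sync` with array indexing; `lanes.size()` plays the role of the reduction width (e.g., `CVT_NUM_THREADS_PER_SF`) and must be a power of two:

```cpp
#include <algorithm>
#include <vector>

// Butterfly (XOR-pattern) max reduction: after log2(n) rounds, every
// "lane" holds the maximum, mirroring the cross-thread shuffle reduction.
float butterfly_max(std::vector<float> lanes) {
  for (size_t offset = lanes.size() / 2; offset > 0; offset /= 2) {
    std::vector<float> next(lanes.size());
    for (size_t i = 0; i < lanes.size(); ++i) {
      // On device this pair exchange is __shfl_xor_sync(mask, v, offset).
      next[i] = std::max(lanes[i], lanes[i ^ offset]);
    }
    lanes = next;
  }
  return lanes[0];
}
```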
863-869: LGTM: Shared memory size calculation for TMA kernel. The helper correctly computes total shared memory requirements including barrier storage and data staging area.

133-136: The calling code is fully guarded at multiple architectural levels. All kernels and wrapper functions that use `fp32_vec_to_e2m1()` and `fp32_vec_to_e4m3()` are conditionally compiled with `#if __CUDA_ARCH__ >= 1000`, so these conversion functions will never be called on pre-Blackwell GPUs. The `return 0` fallback in the `#else` branches serves as a safety measure and is an intentional design choice (evidenced by the commented-out `static_assert`). No data corruption risk exists.

Likely an incorrect or invalid review comment.
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (6)

17-31: LGTM: Required includes and type aliases for TMA and barrier support. The new includes for `cutlass/arch/barrier.h` and `cute/arch/copy_sm90_tma.hpp` are necessary for the TMA-based quantization kernel. The `Barrier` type alias simplifies usage.

206-208: LGTM: Updated to use template-driven PackedVecT. The kernel correctly uses the new `PackedVecT` type alias and maintains the size assertion for safety.

398-398: Verify `elect_one_sync()` is the intended pattern for single-producer selection. The producer warp uses `elect_one_sync()` to ensure only one thread issues TMA loads. This is correct for TMA operations, which require a single issuing thread.

516-519: LGTM: Expert quantization kernel updated to use new template API. The kernel correctly uses `PackedVecT` and the updated function call signatures with explicit template parameters.

347-347: Good: Compile-time constraint for TMA kernel SF_VEC_SIZE. The `static_assert` correctly enforces that TMA quantization only supports `SF_VEC_SIZE = 16`, preventing accidental misuse at compile time.

381-393: The transaction size at line 409 is correct. The `arrive_and_expect_tx(SMEM_STAGE_SIZE * sizeof(SmemType))` accurately reflects the full data loaded per TMA operation. The TMA_LOAD_3D loads across all 8 consumer warps' columns (NUM_CONSUMER_WARPS * TMA_COL_TILE), which matches the entire SMEM_STAGE_SIZE allocation (SMEM_ROWS * SMEM_COLS = TMA_ROW_TILE * 8 * TMA_COL_TILE). The column loop increment and consumer warp access patterns both confirm this full-stage load per TMA transaction.
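The transaction-size bookkeeping in the comment above can be sketched as plain arithmetic (names taken from the comment; tile sizes and element width are hypothetical inputs, not the PR's actual values):

```cpp
// One TMA load fills the whole shared-memory stage, so the byte count
// passed to arrive_and_expect_tx must equal the full stage size:
// SMEM_STAGE_SIZE = TMA_ROW_TILE * NUM_CONSUMER_WARPS * TMA_COL_TILE elements.
constexpr int kNumConsumerWarps = 8;

int expected_tx_bytes(int tma_row_tile, int tma_col_tile, int elt_bytes) {
  int smem_stage_elems = tma_row_tile * kNumConsumerWarps * tma_col_tile;
  return smem_stage_elems * elt_bytes;
}
```

If the kernel's loop structure ever changes so that one TMA operation covers fewer warps' columns, this expected-bytes value must shrink in lockstep or the barrier will wait forever.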
[FAILED] Pipeline #41612397: 1/20 passed
@yzh119 I refactored the code format and added related UTs. Could you please help me run the CI again? Thank you.
yzh119
left a comment
LGTM, we should also create benchmarking scripts for fp4 tma quantization with different problems.
/bot run
@Shunkangz would you mind taking a look at the compilation errors:
Hi @yzh119, I fixed this issue. Please help me run CI. Thanks.
/bot run
Hi @Shunkangz, the unit test results are showing a high rate of mismatches in
[FAILED] Pipeline #41820142: 11/20 passed
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@csrc/nv_internal/cpp/kernels/quantization.cu`:
- Around line 296-299: The call to cudaFuncSetAttribute(kernel_instance,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) is unchecked; capture
its return value (e.g., cudaError_t rc = cudaFuncSetAttribute(...)) and handle
failures before proceeding to create cudaLaunchConfig_t config or launching the
kernel: on non-success return log or propagate the error (including smem_size
and kernel_instance identifiers) and abort/return so the kernel is not launched
with insufficient dynamic shared memory.
- Around line 187-230: The template make_3d_tma_copy_desc currently silently
maps any unknown T to CU_TENSOR_MAP_DATA_TYPE_UINT8; add a compile-time check so
only the supported types (half, __nv_bfloat16, __nv_fp8_e4m3) are allowed and
all other instantiations fail to compile. Implement this by introducing a
constexpr predicate (e.g., is_supported_type<T>) or individual constexpr
booleans and then a static_assert near the top of make_3d_tma_copy_desc that
references T and emits a clear message like "Unsupported data type for
cuTensorMapDataType" if false; keep the existing explicit mapping to
CU_TENSOR_MAP_DATA_TYPE_UINT8 only for __nv_fp8_e4m3 and remove the catch-all
else branch. Ensure the static_assert and mapping touch the variables data_type
and template T so unsupported types fail at compile time.
🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)
419-423: Tie consumer-warp guard to `Traits::NUM_CONSUMER_WARPS`.

Line 419 hardcodes `<= 8`, which will break if Traits changes. Prefer the trait constant to keep the kernel self-consistent.

♻️ Proposed change

```diff
-  } else if (warpIdx >= 1 and warpIdx <= 8) {
+  } else if (warpIdx >= 1 and warpIdx <= NUM_CONSUMER_WARPS) {
```
Hi @bkryu, thank you for pointing this out. There are two potential issues. First, I fixed the TMA loading issue with batch size larger than 1 and added the UT config as well. The other problem is a padding issue when N is not divisible by 512; I want to propose another PR to fix it. For now, I let the code fall back to the original kernel. Does it make sense to you?
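Putting the constraints from this thread together, the fallback predicate could be sketched as follows (the function name is hypothetical; the three conditions come from the MXFP8 TODO, the `m > 1024` threshold, and the padding limitation described in this comment):

```cpp
// Take the TMA kernel only when the problem shape supports it; otherwise
// fall back to the original quantization kernel.
bool can_use_tma_kernel(int sf_vec_size, int m, int n) {
  return sf_vec_size == 16  // NVFP4 only; MXFP8 (SF_VEC_SIZE=32) not yet supported
         && m > 1024        // amortize TMA setup cost over many rows
         && n % 512 == 0;   // padding issue when N is not divisible by 512
}
```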
/bot run
Hi @Shunkangz, thanks for looking into this. If you mean falling back to the original kernel when N is not divisible by 512, that sounds reasonable. I have re-triggered the unit tests to see what state we are in now.
[FAILED] Pipeline #42125407: 3/20 passed
Hi @bkryu, it seems that the 3 failed tests are not related to my change. Could you please help me double-check this? Thank you.
bkryu
left a comment
Thank @Shunkangz, the failures are indeed unrelated that should have already been fixed in the main branch. LGTM
📌 Description
Optimize quantization function in large problem size by using TMA and warp specialized programming.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Refactor
Documentation
Tests