Skip to content

[feat] Trtllm-gen Per-token Nvfp4 MoE#3027

Open
IwakuraRein wants to merge 27 commits intoflashinfer-ai:mainfrom
IwakuraRein:per-token-fp4
Open

[feat] Trtllm-gen Per-token Nvfp4 MoE#3027
IwakuraRein wants to merge 27 commits intoflashinfer-ai:mainfrom
IwakuraRein:per-token-fp4

Conversation

@IwakuraRein
Copy link
Copy Markdown
Collaborator

@IwakuraRein IwakuraRein commented Apr 9, 2026

📌 Description

coauthor: @mxz297

This PR aims to enable the per-token quantization for Trtllm-gen MoE.

  • In trtllm_fp4_block_scale_moe and trtllm_fp4_block_scale_routed_moe, added a new optional argument per_token_scale.
  • Optimize fp4 quantization kernel. Use 256bit vectorized load. Add cvt_warp_fp16_to_fp4_with_vec_max to use the cached local amax.
  • Add explicit amax and quantization kernel after FC1 to generate the per-token scales for FC2.
  • Generate the new cubins.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Per‑token FP4 quantization and a new nvfp4_quant_and_per_token_scale API exposed to Python.
    • Per‑token scaling integrated across fused MoE flow and runners; FP4 quantize gains a flag to invert global scale and row‑wise/inverse scaling controls.
  • Chores

    • Updated backend artifact reference and checksum for batched GEMM.
  • Tests

    • Added end‑to‑end per‑token routed fused MoE test.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Apr 9, 2026

📝 Walkthrough

Walkthrough

Adds NVFP4 per-token quantization and per-token scale outputs, extends FP4 quantization with row-wise and inverse-scale options, refactors quantization kernels/helpers and host dispatch, threads per-token scaling through fused MoE runner/launcher/GEMM, and exposes new Python API and tests.

Changes

Cohort / File(s) Summary
Quant kernels & helpers
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh, csrc/nv_internal/cpp/kernels/quantization.cu, csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
Added packed-vector load helpers and CVT_FP16_TO_FP4_ELTS_PER_THREAD; added cvt_warp_fp16_to_fp4_with_vec_max; introduced NVFP4 per-token kernels (nvfp4QuantAndPerTokenScale*); refactored FP4/MXFP8 kernels to support USE_ROW_WISE_SCALE/USE_INVERSE_SCALE, changed SF pointer name to SFOutput, and adjusted block sizing and packed-output types.
Host launchers & headers
csrc/nv_internal/tensorrt_llm/kernels/quantization.h, csrc/nv_internal/cpp/kernels/quantization.cu
Added invokeNvfp4QuantAndPerTokenScale and invokeRowWiseAmax; extended invokeFP4Quantization/invokeMxFP8Quantization signatures (renamed SFOuputSFOutput, added use_row_wise_scale/inverse_scale) and updated kernel dispatch/attributes.
FP4 Python ops & bindings
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp, csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h, flashinfer/quantization/fp4_quantization.py, flashinfer/quantization/__init__.py, flashinfer/fp4_quantization.py, flashinfer/__init__.py
Added SM100 custom op nvfp4_quant_and_per_token_scale, registered/exposed new Python API, added is_global_scale_inversed parameter to FP4 quantize callsites, and re-exported the new function.
Fused MoE runner / launcher / headers
include/flashinfer/trtllm/fused_moe/runner.h, csrc/trtllm_fused_moe_runner.cu, csrc/trtllm_fused_moe_kernel_launcher.cu
Threaded per-token scaling through MoE: added usePerTokenScaling and useExplicitQuantization, updated PermuteGemm1/Gemm2 runner signatures and run() args to accept per-token scales, added token_scales_fc2 workspace, and added explicit-quant path invoking NVFP4 per-token kernel.
Batched GEMM runner
include/flashinfer/trtllm/batched_gemm/KernelRunner.h, csrc/trtllm_batched_gemm_runner.cu
Added usePerTokenScaling option and gating logic to skip configs incompatible with per-token scaling.
Python MoE API & tests
flashinfer/fused_moe/core.py, tests/moe/test_trtllm_gen_per_token_moe.py
Added per_token_scale field and use_per_token_scaling plumbing in MoE inputs/runner/APIs, updated public fused-MoE op signatures to accept per-token scales, and added a new routed per-token MoE test.
Build/artifacts/jit & enums
flashinfer/artifacts.py, flashinfer/tllm_enums.py, flashinfer/jit/fused_moe.py
Updated TRTLLM_GEN_BMM artifact path/checksum, adjusted dtype deduction to use element counts, and added quantization.cu to fused-MoE JIT compilation units.
Exports & re-exports
flashinfer/quantization/__init__.py, flashinfer/fp4_quantization.py, flashinfer/__init__.py
Re-exported nvfp4_quant_and_per_token_scale and updated public FP4 quantization APIs and module exports.

Sequence Diagram

sequenceDiagram
    participant Host as Host
    participant Launcher as MoE Launcher
    participant NVKernel as NVFP4 Kernel
    participant FP4Kernel as FP4 Quant Kernel
    participant Gemm2 as Gemm2 / FC2

    Host->>Launcher: call fused MoE with input + per_token_scales?
    Launcher->>Launcher: allocate workspace (token_scales_fc2) if needed
    Launcher->>NVKernel: invokeNvfp4QuantAndPerTokenScale(input, globalScaleInv, sfLayout, ...)
    activate NVKernel
    NVKernel->>NVKernel: reduce per-row amax → compute per-token scales
    NVKernel-->>Launcher: write perTokenScaleOutput, weightOutput, scaleOutput
    deactivate NVKernel

    Launcher->>FP4Kernel: invokeFP4Quantization(FC1_output, perTokenScales?, use_row_wise/inverse flags)
    activate FP4Kernel
    FP4Kernel->>FP4Kernel: apply per-token or row-wise/inverse scale, quantize to FP4, emit block scales
    FP4Kernel-->>Launcher: packed FP4 activations + block scales
    deactivate FP4Kernel

    Launcher->>Gemm2: run Gemm2(perTokenScales_fc2, quantized_activation, weights, scales)
    activate Gemm2
    Gemm2-->>Host: final MoE output
    deactivate Gemm2
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • bkryu
  • cyx-6
  • aleozlx
  • yzh119
  • djmmoss
  • jimmyzho
  • kahyunnam

Poem

🐇 I hopped through kernels, scaled each token small,
Packed FP4 whispers, answered every call,
Per-token scales in burrowed stacks,
MoE and quant sing — no regressions, no cracks,
A rabbit cheers for merges, big and small!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 21.05% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly summarizes the main change: enabling per-token quantization for Trtllm-gen MoE kernel.
Description check ✅ Passed The description covers key changes (per-token scale API, kernel optimizations, explicit quantization after FC1) but pre-commit and test checkboxes are incomplete.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for per-token scaling in FP4 quantization for MoE models, including updates to the quantization kernels, the MoE runner, and the Python interface. The changes enable row-wise Amax calculation and quantization for FP4, and integrate these into the fused MoE pipeline. My review identified several issues: missing header includes for cuda::std::maximum, potential integer overflow in the Amax kernel, efficiency improvements for the Amax kernel, and the need to expand type support for NvFP4. Additionally, I pointed out unused variables and the use of magic numbers that should be replaced with named constants.

#include <cudaTypedefs.h>
#include <float.h>

#include <cub/cub.cuh>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The use of cuda::maximum (or cuda::std::maximum) in the rowWiseAmaxKernel requires including <cuda/std/functional>. Without this, compilation might fail depending on the CUDA toolkit version and transitive includes.

#include <cub/cub.cuh>
#include <cuda/std/functional>

Comment on lines +754 to +755
FLASHINFER_CHECK(mGemm2.mDtypeAct == btg::Dtype::E2m1,
"Currently only support NvFP4 when using explicit quantization.");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check mGemm2.mDtypeAct == btg::Dtype::E2m1 is too restrictive. NvFP4 also supports MxE2m1 (vector size 32). This check should be expanded to allow MxE2m1, and the subsequent call to invokeFP4Quantization should dispatch to the correct template instantiation (16 or 32) based on the actual type of mGemm2.mDtypeAct.

Comment on lines +239 to +265
template <typename T, uint32_t BLOCK_SIZE>
__global__ void rowWiseAmaxKernel(uint32_t m, uint32_t n, T const* input, float* amaxOutput, float scale) {
uint32_t rowIdx = blockIdx.x;
if (rowIdx >= m) return;

float localMax = 0.f;
for (uint32_t colIdx = threadIdx.x; colIdx < n; colIdx += blockDim.x) {
T element = input[rowIdx * n + colIdx];
localMax = fmaxf(localMax, fabsf(static_cast<float>(element) * scale));
}

using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
__shared__ typename BlockReduce::TempStorage tempStorage;
float blockMax = BlockReduce(tempStorage)
.Reduce(
localMax,
#if CUDART_VERSION >= 12090
cuda::maximum<> {}
#else
cub::Max(),
#endif
);

if (threadIdx.x == 0) {
amaxOutput[rowIdx] = blockMax;
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There are two issues in rowWiseAmaxKernel:

  1. Efficiency: Multiplying by scale inside the loop performs n multiplications per thread. It is more efficient to compute the maximum of absolute values first and multiply by scale once at the end when writing to global memory.
  2. Correctness (Overflow): The indexing rowIdx * n + colIdx uses uint32_t. If the product exceeds $2^{32}$ (possible with large token counts and intermediate sizes), it will overflow. Using static_cast<size_t>(rowIdx) * n prevents this.
template <typename T, uint32_t BLOCK_SIZE>
__global__ void rowWiseAmaxKernel(uint32_t m, uint32_t n, T const* input, float* amaxOutput, float scale) {
  uint32_t rowIdx = blockIdx.x;
  if (rowIdx >= m) return;

  float localMax = 0.f;
  for (uint32_t colIdx = threadIdx.x; colIdx < n; colIdx += blockDim.x) {
    T element = input[static_cast<size_t>(rowIdx) * n + colIdx];
    localMax = fmaxf(localMax, fabsf(static_cast<float>(element)));
  }

  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
  __shared__ typename BlockReduce::TempStorage tempStorage;
  float blockMax = BlockReduce(tempStorage)
                       .Reduce(
                           localMax,
#if CUDART_VERSION >= 12090
                           cuda::std::maximum<> {}
#else
                           cub::Max(),
#endif
                       );

  if (threadIdx.x == 0) {
    amaxOutput[rowIdx] = blockMax * scale;
  }
}

Comment thread csrc/trtllm_fused_moe_runner.cu Outdated
workspace.token_scales_fc2 != nullptr,
"workspace.token_scales_fc2 must be provided When using explicit quantization.");
const int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
int intermediate_size_factor = isGatedActivation(args.activation_type) ? 2 : 1;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable intermediate_size_factor is calculated but never used in the subsequent logic. It should be removed to keep the code clean.

Comment on lines +762 to +763
auto sfLayout = mGemm2.mTileTokensDim >= 128 ? QuantizationSFLayout::SWIZZLED_128x4
: QuantizationSFLayout::SWIZZLED_8x4;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The selection of sfLayout based on a hardcoded threshold (mTileTokensDim >= 128) is a heuristic. As noted in the FIXME, this should ideally be determined from the actual kernel configuration to ensure compatibility with the GEMM2 kernel's expected layout.

Comment thread csrc/trtllm_fused_moe_runner.cu Outdated
invokeRowWiseAmax<__nv_bfloat16>(workspace.total_max_padded_tokens, args.intermediate_size,
reinterpret_cast<__nv_bfloat16*>(workspace.gemm1_output),
reinterpret_cast<float*>(workspace.token_scales_fc2),
1.f / 448.f / 6.f, stream);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The scale factor 1.f / 448.f / 6.f is a magic number. It should be defined as a named constant (e.g., based on FP8 and FP4 max values) to improve code clarity and maintainability.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein IwakuraRein marked this pull request as ready for review April 15, 2026 21:24
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
flashinfer/quantization/fp4_quantization.py (3)

773-782: ⚠️ Potential issue | 🟠 Major

backend="cute-dsl" currently drops the new inverse-scale flag.

The CUDA path forwards is_global_scale_inversed, but the CuTe-DSL path still calls _fp4_quantize_cute_dsl() with the old parameter set. The same API now has backend-dependent semantics.

🛠️ Short-term guard
     if backend == "cute-dsl":
+        if is_global_scale_inversed:
+            raise ValueError(
+                "backend='cute-dsl' does not support is_global_scale_inversed yet."
+            )
         return _fp4_quantize_cute_dsl(
             input,
             global_scale,
             sf_vec_size,
             sf_use_ue8m0,

Also applies to: 822-830

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 773 - 782, The
"cute-dsl" branch of the backend call to _fp4_quantize_cute_dsl is missing the
new is_global_scale_inversed argument and thus drops the inverse-scale flag;
update the cute-dsl call(s) (the one shown and the similar call around the other
backend branch) to pass the is_global_scale_inversed parameter through (matching
the CUDA path) so both backends receive the same API/semantics and preserve the
inverse-scale behavior.

176-185: ⚠️ Potential issue | 🟠 Major

Update _fake_fp4_quantize_sm100 to the new op schema.

The real op now accepts is_sf_8x4_layout, is_global_scale_inversed, and enable_pdl, but the fake op still exposes the old signature. FakeTensor/torch.compile execution will break once the new argument list is exercised.

🧩 Minimal fix
 def _fake_fp4_quantize_sm100(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
     sf_vec_size: int = 16,
     sf_use_ue8m0: bool = False,
     is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    is_global_scale_inversed: bool = False,
+    enable_pdl: Optional[bool] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:

Based on learnings, functions decorated with register_fake_op must exactly mirror the corresponding real op signatures.

Also applies to: 237-244

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 176 - 185, The fake
op function _fake_fp4_quantize_sm100 must mirror the real op signature: add the
new parameters is_sf_8x4_layout: bool, is_global_scale_inversed: bool, and
enable_pdl: Optional[bool] to the decorated function signature (both occurrences
around the current definitions), accept them as arguments, and propagate them
(or at least accept and ignore) in the function body so FakeTensor/torch.compile
sees the same signature as the real op; ensure parameter names match exactly and
update any internal call/return wiring so the new params are accepted without
changing behavior.

721-730: ⚠️ Potential issue | 🟠 Major

Preserve the existing positional fp4_quantize ABI.

is_global_scale_inversed was inserted before enable_pdl and backend, so older positional calls now mis-bind arguments.

💡 Backward-compatible shape
 def fp4_quantize(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
     sf_vec_size: int = 16,
     sf_use_ue8m0: bool = False,
     is_sf_swizzled_layout: bool = True,
     is_sf_8x4_layout: bool = False,
-    is_global_scale_inversed: bool = False,
     enable_pdl: Optional[bool] = None,
     backend: str = "cuda",
+    *,
+    is_global_scale_inversed: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 721 - 730, The
function fp4_quantize signature was changed such that is_global_scale_inversed
was inserted before enable_pdl (and backend), breaking older positional calls;
restore the original positional ABI by moving is_global_scale_inversed to its
prior position (after enable_pdl and backend or at the end of the parameter list
as it was previously), keep its default value unchanged, and update any
calls/tests accordingly; verify by adding or updating a minimal unit test that
calls fp4_quantize with positional arguments to ensure arguments bind as before.
flashinfer/fused_moe/core.py (1)

2916-2949: ⚠️ Potential issue | 🟠 Major

Keep per_token_scale out of the positional ABI.

Adding per_token_scale before output changes the public call signature for both APIs. Existing positional callers that used the old signature will now bind their output tensor or tune_max_num_tokens integer to per_token_scale.

💡 Safer signature shape
 def trtllm_fp4_block_scale_moe(
     ...
     activation_type: int = ActivationType.Swiglu.value,
-    per_token_scale: Optional[torch.Tensor] = None,
     output: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
     norm_topk_prob: bool = True,
     routing_replay_out: Optional[torch.Tensor] = None,
+    *,
+    per_token_scale: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:
 def trtllm_fp4_block_scale_routed_moe(
     ...
     activation_type: int = ActivationType.Swiglu.value,
-    per_token_scale: Optional[torch.Tensor] = None,
     output: Optional[torch.Tensor] = None,
     tune_max_num_tokens: int = 8192,
+    *,
+    per_token_scale: Optional[torch.Tensor] = None,
 ) -> List[torch.Tensor]:

Also applies to: 3056-3089

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/fused_moe/core.py` around lines 2916 - 2949, The new parameter
per_token_scale was inserted before output, changing the positional ABI and
breaking existing callers; move per_token_scale out of the positional argument
list (e.g., place it at the end of the parameter list or after
routing_replay_out) so existing positional bindings continue to work, and apply
the same change to the duplicate signature around lines 3056-3089 (function name
trtllm_fp4_block_scale_moe) to keep both signatures consistent.
csrc/trtllm_fused_moe_kernel_launcher.cu (2)

874-892: ⚠️ Potential issue | 🟠 Major

getValidConfigs() is using the wrong scaling mode for FP8 per-tensor.

prepare_moe_common() only turns on usePerTokenScaling for this path when routing is Llama4, but Line 891 hardcodes true for every Fp8PerTensorLauncher::getValidConfigs() call. That means the autotuner can cache tactics for a different runner capability set than the one used at execution time on Default/Renormalize/TopK routes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 874 - 892,
getValidConfigs currently hardcodes usePerTokenScaling=true when constructing
tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner, which causes FP8 per-tensor
runners to be tuned with per-token scaling even when prepare_moe_common only
enables per-token scaling for Llama4 routing; change getValidConfigs
(Fp8PerTensorLauncher::getValidConfigs) to determine the usePerTokenScaling
value from the same condition used in prepare_moe_common (or add a parameter
passed from the caller indicating whether per-token scaling is enabled) and pass
that boolean into the MoE::Runner constructor instead of the hardcoded true so
autotuner config matches runtime routing modes.

1652-1684: ⚠️ Potential issue | 🔴 Critical

The explicit NvFP4 branch still lets MxE2m1 weights through.

Line 1652 sizes gemm1_output_scale from mDtypeWeights, so MxE2m1 allocates 32-value SF blocks. The per-token branch only checks mDtypeAct == btg::Dtype::E2m1, but the explicit quantization path later emits NVFP4 16-value SFs into that same buffer. That makes workspace.activation_output_scale too small/misaligned whenever callers pass NvFP4 activations with MXFP4 weights.

Suggested guard
-      TVM_FFI_ICHECK(mDtypeAct == btg::Dtype::E2m1)
-          << "NvFP4 MoE: currently only support NvFP4 x NvFP4 when using per-token scaling.";
+      TVM_FFI_ICHECK(mDtypeAct == btg::Dtype::E2m1 &&
+                     mDtypeWeights == btg::Dtype::E2m1)
+          << "NvFP4 MoE: per-token scaling currently only supports NvFP4 activations with NvFP4 weights.";
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_kernel_launcher.cu` around lines 1652 - 1684, The bug
is a dtype mismatch: sf_vec_size is computed from mDtypeWeights but
gemm1_output_scale and later quant paths use mDtypeAct, allowing MxE2m1 weights
to produce a too-small SF buffer when activations are NvFP4; fix by computing
the swizzle vector size from mDtypeAct (not mDtypeWeights) or by guarding the
per-token/quant path with both mDtypeAct and mDtypeWeights checks — e.g., change
the sf_vec_size derivation or the gemm1_output_scale allocation condition so the
branch that computes sf_size and allocates gemm1_output_scale uses mDtypeAct
(and/or requires mDtypeWeights to match) to ensure the allocated SF block size
(32 vs 16) matches the actual activation/quantization format used by
alloc_tensor/gemm1_output and activation_output_scale.
🧹 Nitpick comments (1)
tests/moe/test_trtllm_gen_per_token_moe.py (1)

59-62: Use the standard GPU-support skip helper here.

The raw compute_capability[0] == 10 gate bypasses the test-suite skip helpers and makes the supported-arch contract harder to read.

As per coding guidelines, tests/**/*.py: Skip test execution on unsupported GPU architectures using flashinfer.utils check functions (is_sm90a_supported(), is_sm100a_supported(), etc.) or API methods like api_name.is_compute_capability_supported(cc).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_per_token_moe.py` around lines 59 - 62, Replace the
manual compute_capability check around get_compute_capability/compute_capability
with the project test helper: import and call the appropriate flashinfer.utils
function (e.g., is_sm100a_supported()) or use
api_name.is_compute_capability_supported(cc) to decide skipping; specifically,
remove the direct compute_capability[0] check and instead call
is_sm100a_supported() (or the matching helper for SM100/SM103) and call
pytest.skip(...) when it returns False so the test uses the standard GPU-support
skip helper rather than a raw capability comparison.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh`:
- Around line 53-60: The preprocessor gate for CVT_FP16_TO_FP4_ELTS_PER_THREAD
is incorrectly parsed due to operator precedence, so wrap the CUDA-version check
together with the __CUDACC_VER_* defines and parentheses so the 16-element path
requires both SM100+ (__CUDA_ARCH__ >= 1000) and a supported compiler version;
update the `#if` condition that references __CUDA_ARCH__, __CUDACC_VER_MAJOR__,
and __CUDACC_VER_MINOR__ so the version clause is parenthesized with its defined
checks (e.g., group the __CUDACC_VER_MAJOR__/__CUDACC_VER_MINOR__ existence and
the (12.9+ or 13+) comparison) before ANDing with the __CUDA_ARCH__ test,
leaving the constexpr CVT_FP16_TO_FP4_ELTS_PER_THREAD branches unchanged.

In `@csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp`:
- Around line 285-323: The code currently enforces
output_per_token_scale.numel() == m but still accepts
expanded_idx_to_permuted_idx which can map row indices > m-1 and cause OOB
writes; update the validation so when expanded_idx_to_permuted_idx.has_value()
you either (A) require output_per_token_scale to be sized to the permuted/padded
domain (check against expanded domain size instead of m) or (B) validate all
entries of expanded_idx_to_permuted_idx are within [0, m) and reject otherwise.
Concretely, in the function handling these variables (use symbols
output_per_token_scale, expanded_idx_to_permuted_idx, m and
expanded_idx_to_permuted_idx_ptr) add a check: if
expanded_idx_to_permuted_idx.has_value() then read the int32 mapping buffer and
assert 0 <= map[i] < m for all entries (or change the output_per_token_scale
size check to the permuted size) and TVM_FFI_LOG_AND_THROW on failure.

In `@tests/moe/test_trtllm_gen_per_token_moe.py`:
- Around line 314-318: The relative-difference mask uses
torch.reciprocal(reference) which blows up on zeros; change the computation of
mask (and thus mismatch_rate) to clamp the denominator: compute a safe
denominator from reference (e.g., denom = torch.clamp(torch.abs(reference),
min=eps) or use torch.where to replace near-zero reference with eps) and then
compute mask = torch.abs((reference - result) / denom) < rtol so rtol, mask,
mismatch_rate and the variables reference/result remain used but avoid division
by zero.

---

Outside diff comments:
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 874-892: getValidConfigs currently hardcodes
usePerTokenScaling=true when constructing
tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner, which causes FP8 per-tensor
runners to be tuned with per-token scaling even when prepare_moe_common only
enables per-token scaling for Llama4 routing; change getValidConfigs
(Fp8PerTensorLauncher::getValidConfigs) to determine the usePerTokenScaling
value from the same condition used in prepare_moe_common (or add a parameter
passed from the caller indicating whether per-token scaling is enabled) and pass
that boolean into the MoE::Runner constructor instead of the hardcoded true so
autotuner config matches runtime routing modes.
- Around line 1652-1684: The bug is a dtype mismatch: sf_vec_size is computed
from mDtypeWeights but gemm1_output_scale and later quant paths use mDtypeAct,
allowing MxE2m1 weights to produce a too-small SF buffer when activations are
NvFP4; fix by computing the swizzle vector size from mDtypeAct (not
mDtypeWeights) or by guarding the per-token/quant path with both mDtypeAct and
mDtypeWeights checks — e.g., change the sf_vec_size derivation or the
gemm1_output_scale allocation condition so the branch that computes sf_size and
allocates gemm1_output_scale uses mDtypeAct (and/or requires mDtypeWeights to
match) to ensure the allocated SF block size (32 vs 16) matches the actual
activation/quantization format used by alloc_tensor/gemm1_output and
activation_output_scale.

In `@flashinfer/fused_moe/core.py`:
- Around line 2916-2949: The new parameter per_token_scale was inserted before
output, changing the positional ABI and breaking existing callers; move
per_token_scale out of the positional argument list (e.g., place it at the end
of the parameter list or after routing_replay_out) so existing positional
bindings continue to work, and apply the same change to the duplicate signature
around lines 3056-3089 (function name trtllm_fp4_block_scale_moe) to keep both
signatures consistent.

In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 773-782: The "cute-dsl" branch of the backend call to
_fp4_quantize_cute_dsl is missing the new is_global_scale_inversed argument and
thus drops the inverse-scale flag; update the cute-dsl call(s) (the one shown
and the similar call around the other backend branch) to pass the
is_global_scale_inversed parameter through (matching the CUDA path) so both
backends receive the same API/semantics and preserve the inverse-scale behavior.
- Around line 176-185: The fake op function _fake_fp4_quantize_sm100 must mirror
the real op signature: add the new parameters is_sf_8x4_layout: bool,
is_global_scale_inversed: bool, and enable_pdl: Optional[bool] to the decorated
function signature (both occurrences around the current definitions), accept
them as arguments, and propagate them (or at least accept and ignore) in the
function body so FakeTensor/torch.compile sees the same signature as the real
op; ensure parameter names match exactly and update any internal call/return
wiring so the new params are accepted without changing behavior.
- Around line 721-730: The function fp4_quantize signature was changed such that
is_global_scale_inversed was inserted before enable_pdl (and backend), breaking
older positional calls; restore the original positional ABI by moving
is_global_scale_inversed to its prior position (after enable_pdl and backend or
at the end of the parameter list as it was previously), keep its default value
unchanged, and update any calls/tests accordingly; verify by adding or updating
a minimal unit test that calls fp4_quantize with positional arguments to ensure
arguments bind as before.

---

Nitpick comments:
In `@tests/moe/test_trtllm_gen_per_token_moe.py`:
- Around line 59-62: Replace the manual compute_capability check around
get_compute_capability/compute_capability with the project test helper: import
and call the appropriate flashinfer.utils function (e.g., is_sm100a_supported())
or use api_name.is_compute_capability_supported(cc) to decide skipping;
specifically, remove the direct compute_capability[0] check and instead call
is_sm100a_supported() (or the matching helper for SM100/SM103) and call
pytest.skip(...) when it returns False so the test uses the standard GPU-support
skip helper rather than a raw capability comparison.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: add4842f-6d94-40f3-ada7-11b624b69fd8

📥 Commits

Reviewing files that changed from the base of the PR and between 25b324d and 884500f.

📒 Files selected for processing (20)
  • csrc/nv_internal/cpp/kernels/quantization.cu
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.h
  • csrc/nv_internal/tensorrt_llm/kernels/quantization_utils.cuh
  • csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp
  • csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h
  • csrc/trtllm_batched_gemm_runner.cu
  • csrc/trtllm_fused_moe_kernel_launcher.cu
  • csrc/trtllm_fused_moe_runner.cu
  • flashinfer/__init__.py
  • flashinfer/artifacts.py
  • flashinfer/fp4_quantization.py
  • flashinfer/fused_moe/core.py
  • flashinfer/jit/fused_moe.py
  • flashinfer/quantization/__init__.py
  • flashinfer/quantization/fp4_quantization.py
  • flashinfer/tllm_enums.py
  • include/flashinfer/trtllm/batched_gemm/KernelRunner.h
  • include/flashinfer/trtllm/fused_moe/runner.h
  • tests/moe/test_trtllm_gen_per_token_moe.py

Comment thread csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
Comment thread csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp Outdated
Comment on lines +314 to +318
rtol = 0.2
mask = torch.abs((reference - result) * torch.reciprocal(reference)) < rtol
mismatch_rate = (~mask).float().mean().item()
print(f"Mismatch: {mismatch_rate * 100}%")
assert mismatch_rate < 0.3
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Clamp the denominator in the mismatch metric.

torch.reciprocal(reference) explodes on zero or near-zero reference values, so this can fail even when the kernel output is fine.

🔧 Safer comparison
     rtol = 0.2
-    mask = torch.abs((reference - result) * torch.reciprocal(reference)) < rtol
+    denom = reference.abs().clamp_min(1e-6)
+    mask = ((reference - result).abs() / denom) < rtol
     mismatch_rate = (~mask).float().mean().item()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_trtllm_gen_per_token_moe.py` around lines 314 - 318, The
relative-difference mask uses torch.reciprocal(reference) which blows up on
zeros; change the computation of mask (and thus mismatch_rate) to clamp the
denominator: compute a safe denominator from reference (e.g., denom =
torch.clamp(torch.abs(reference), min=eps) or use torch.where to replace
near-zero reference with eps) and then compute mask = torch.abs((reference -
result) / denom) < rtol so rtol, mask, mismatch_rate and the variables
reference/result remain used but avoid division by zero.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
@IwakuraRein
Copy link
Copy Markdown
Collaborator Author

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !555 has been created, and the CI pipeline #48633822 is currently running. I'll report back once the pipeline job completes.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (2)
csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp (1)

285-287: ⚠️ Potential issue | 🟠 Major

Output buffer size validation is incomplete when expanded_idx_to_permuted_idx is provided.

When the mapping tensor is present, output_per_token_scale.numel() is not validated. If any mapped index exceeds the actual output buffer size, the kernel will write out-of-bounds. Either:

  1. Validate that output_per_token_scale.numel() covers the maximum possible mapped index, or
  2. Validate that all entries in expanded_idx_to_permuted_idx are within [0, output_per_token_scale.numel()).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp` around lines 285 - 287,
The current check only validates output_per_token_scale.numel() when
expanded_idx_to_permuted_idx is absent; when expanded_idx_to_permuted_idx is
present you must also bound-check the mapped indices to avoid OOB writes. Modify
the logic around expanded_idx_to_permuted_idx and output_per_token_scale (the
variables/functions referenced in this block) to either assert
output_per_token_scale.numel() >= m (as before) and additionally verify that
every value in expanded_idx_to_permuted_idx is < output_per_token_scale.numel(),
or compute the maximum value in expanded_idx_to_permuted_idx and ensure
output_per_token_scale.numel() > max_mapped_index; use TVM_FFI_ICHECK_EQ/ICHECK
or equivalent to fail with a clear error message if the condition is violated.
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

696-700: ⚠️ Potential issue | 🟠 Major

Mapped rowIdx is not bounds-checked against output buffer dimensions.

When expandedIdxToPermutedIdx is provided, the kernel remaps rowIdx but only checks for rowIdx < 0. If the mapped index exceeds the output buffer's row dimension, writes to perTokenScaleOutput[rowIdx], weightOutput, and scaleOutput will be out-of-bounds. The calling code should either validate that all mapped indices are within [0, output_rows), or the kernel should add an upper-bound check here.

Suggested kernel-side guard
 int rowIdx = blockIdx.x;
 if (rowIdx >= m) return;
 if (expandedIdxToPermutedIdx != nullptr) {
   rowIdx = expandedIdxToPermutedIdx[rowIdx];
 }
-if (rowIdx < 0) return;
+if (rowIdx < 0 || rowIdx >= static_cast<int>(m)) return;

Note: This relates to the validation issue flagged in the C++ wrapper (fp4Quantize.cpp). A defense-in-depth approach would add checks in both places.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh` around lines 696 -
700, After remapping rowIdx using expandedIdxToPermutedIdx in the kernel, add an
upper-bound guard to ensure the mapped index is within the output buffer row
count before any writes; specifically, after "rowIdx =
expandedIdxToPermutedIdx[rowIdx];" check "if (rowIdx < 0 || rowIdx >=
output_rows) return;" (use the actual output row-count variable available in
this kernel) so subsequent writes to perTokenScaleOutput[rowIdx], weightOutput
and scaleOutput cannot go out-of-bounds.
🧹 Nitpick comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

60-67: Consider checking CUDA API return values.

cudaGetDevice and cudaDeviceGetAttribute can fail, leaving device at -1 or maxThreadsPerSM at the default 1024. While failures are unlikely in practice, unchecked errors could produce incorrect block counts on unusual setups.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh` around lines 60 - 67,
The function runtimeBlocksPerSM does not check CUDA API return values for
cudaGetDevice and cudaDeviceGetAttribute; update runtimeBlocksPerSM to check the
return status of these calls (cudaError_t), handle errors by logging or
returning a safe default (e.g., assume 1 block per SM or fall back to a known
maxThreadsPerSM) and only use device/maxThreadsPerSM when the calls succeed;
reference the function runtimeBlocksPerSM and the
cudaGetDevice/cudaDeviceGetAttribute calls so you locate and wrap their calls,
returning a safe value or propagating an error when CUDA APIs fail.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh`:
- Around line 696-700: After remapping rowIdx using expandedIdxToPermutedIdx in
the kernel, add an upper-bound guard to ensure the mapped index is within the
output buffer row count before any writes; specifically, after "rowIdx =
expandedIdxToPermutedIdx[rowIdx];" check "if (rowIdx < 0 || rowIdx >=
output_rows) return;" (use the actual output row-count variable available in
this kernel) so subsequent writes to perTokenScaleOutput[rowIdx], weightOutput
and scaleOutput cannot go out-of-bounds.

In `@csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp`:
- Around line 285-287: The current check only validates
output_per_token_scale.numel() when expanded_idx_to_permuted_idx is absent; when
expanded_idx_to_permuted_idx is present you must also bound-check the mapped
indices to avoid OOB writes. Modify the logic around
expanded_idx_to_permuted_idx and output_per_token_scale (the variables/functions
referenced in this block) to either assert output_per_token_scale.numel() >= m
(as before) and additionally verify that every value in
expanded_idx_to_permuted_idx is < output_per_token_scale.numel(), or compute the
maximum value in expanded_idx_to_permuted_idx and ensure
output_per_token_scale.numel() > max_mapped_index; use TVM_FFI_ICHECK_EQ/ICHECK
or equivalent to fail with a clear error message if the condition is violated.

---

Nitpick comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh`:
- Around line 60-67: The function runtimeBlocksPerSM does not check CUDA API
return values for cudaGetDevice and cudaDeviceGetAttribute; update
runtimeBlocksPerSM to check the return status of these calls (cudaError_t),
handle errors by logging or returning a safe default (e.g., assume 1 block per
SM or fall back to a known maxThreadsPerSM) and only use device/maxThreadsPerSM
when the calls succeed; reference the function runtimeBlocksPerSM and the
cudaGetDevice/cudaDeviceGetAttribute calls so you locate and wrap their calls,
returning a safe value or propagating an error when CUDA APIs fail.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: bfa762e6-3385-4f57-91b8-9b6f32efa427

📥 Commits

Reviewing files that changed from the base of the PR and between 884500f and 3ecaf6f.

📒 Files selected for processing (2)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
  • csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

589-677: ⚠️ Potential issue | 🔴 Critical

Store the SM100 16-element expert pack as 64 bits.

CVT_FP16_TO_FP4_ELTS_PER_THREAD can now be 16, but this kernel still indexes out as uint32_t*. On that path, Line 675 writes only half of each packed FP4 vector and the row stride becomes incorrect. This needs the same FP4OutT handling you added to quantize_with_block_size.

🐛 Proposed fix
-    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, uint32_t* out,
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, void* out,
     uint32_t* SFout, int32_t* mask, bool use_silu_and_mul, int n_experts) {
 `#if` defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using PackedVecT = PackedVec<Type, CVT_FP16_TO_FP4_ELTS_PER_THREAD>;
+  using FP4OutT =
+      std::conditional_t<CVT_FP16_TO_FP4_ELTS_PER_THREAD == 16, uint64_t, uint32_t>;
   static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
       (CVT_FP4_SF_VEC_SIZE / CVT_FP16_TO_FP4_ELTS_PER_THREAD);
   static_assert(sizeof(PackedVecT) == sizeof(Type) * CVT_FP16_TO_FP4_ELTS_PER_THREAD,
                 "Vec size is not matched.");
...
-    auto& out_pos = out[outOffset];
+    auto& out_pos = reinterpret_cast<FP4OutT*>(out)[outOffset];
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh` around lines 589 -
677, The kernel currently writes packed FP4 outputs into out declared as
uint32_t*, which breaks when CVT_FP16_TO_FP4_ELTS_PER_THREAD == 16 (SM100 path)
because each packed vector is 64 bits; update this code to use the same FP4OutT
pattern as quantize_with_block_size: define an FP4OutT type that is uint32_t
when CVT_FP16_TO_FP4_ELTS_PER_THREAD == 8 and uint64_t when == 16, change the
out pointer to FP4OutT* (or cast to it for writing), compute outOffset in units
of FP4OutT (adjusting the divisor/multipliers accordingly), and assign out_pos
as an FP4OutT from cvt_warp_fp16_to_fp4 before writing so the full 64-bit pack
is stored for the 16-element path; keep other logic (SFScaleVal, sf_out, etc.)
unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh`:
- Around line 280-289: The row-wise scale lookup uses SFScale[rowIdx] but the
kernels flatten rows with a batch offset (batchIdx * numRows), so for
numBatches>1 scales from batch 0 are reused; update the indexing to read
SFScale[batchIdx * numRows + rowIdx] (or compute a flattenedRow = batchIdx *
numRows + rowIdx) wherever SFScale is accessed (the block guarded by
USE_ROW_WISE_SCALE and USE_INVERSE_SCALE around SFScaleVal, and the analogous
block at lines ~506-515), then keep the existing inverse-scale path
(reciprocal_approximate_ftz) and bounds checks (rowIdx < numRows && SFScale !=
nullptr) but use the flattened index to apply the correct per-batch row scale.

---

Outside diff comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh`:
- Around line 589-677: The kernel currently writes packed FP4 outputs into out
declared as uint32_t*, which breaks when CVT_FP16_TO_FP4_ELTS_PER_THREAD == 16
(SM100 path) because each packed vector is 64 bits; update this code to use the
same FP4OutT pattern as quantize_with_block_size: define an FP4OutT type that is
uint32_t when CVT_FP16_TO_FP4_ELTS_PER_THREAD == 8 and uint64_t when == 16,
change the out pointer to FP4OutT* (or cast to it for writing), compute
outOffset in units of FP4OutT (adjusting the divisor/multipliers accordingly),
and assign out_pos as an FP4OutT from cvt_warp_fp16_to_fp4 before writing so the
full 64-bit pack is stored for the 16-element path; keep other logic
(SFScaleVal, sf_out, etc.) unchanged.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 8f262bb4-74f5-4a9e-9b1a-eafb399deb20

📥 Commits

Reviewing files that changed from the base of the PR and between 3ecaf6f and 1225fdd.

📒 Files selected for processing (1)
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh

Comment on lines +280 to +289
if constexpr (USE_ROW_WISE_SCALE) {
if (rowIdx < numRows && SFScale != nullptr) {
SFScaleVal = SFScale[rowIdx];
if constexpr (USE_INVERSE_SCALE) {
SFScaleVal = reciprocal_approximate_ftz(SFScaleVal);
}
} else {
SFScaleVal = 1.f;
}
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Apply row-wise scales with the batch offset.

These branches still read SFScale by row only, but the surrounding input/output addressing is already flattened as batchIdx * numRows + row. For numbatches > 1, every batch after 0 will reuse batch-0 scales and quantize with the wrong alpha.

Also applies to: 506-515

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh` around lines 280 -
289, The row-wise scale lookup uses SFScale[rowIdx] but the kernels flatten rows
with a batch offset (batchIdx * numRows), so for numBatches>1 scales from batch
0 are reused; update the indexing to read SFScale[batchIdx * numRows + rowIdx]
(or compute a flattenedRow = batchIdx * numRows + rowIdx) wherever SFScale is
accessed (the block guarded by USE_ROW_WISE_SCALE and USE_INVERSE_SCALE around
SFScaleVal, and the analogous block at lines ~506-515), then keep the existing
inverse-scale path (reciprocal_approximate_ftz) and bounds checks (rowIdx <
numRows && SFScale != nullptr) but use the flattened index to apply the correct
per-batch row scale.

Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

♻️ Duplicate comments (1)
csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh (1)

43-52: ⚠️ Potential issue | 🟠 Major

loadPackedVec assumes 16/32-byte alignment that the FP32 kernel does not provide.

The loadPackedVec helper (line 48) reinterprets VecT& as PackedU32x4 or PackedU32x8, both of which require 16-byte and 32-byte alignment respectively. However, line 799 defines VecType as std::array<float, ELTS_PER_THREAD> (where ELTS_PER_THREAD=8 at line 793), which when stack-allocated has natural alignment of only alignof(float) (4 bytes). Passing this under-aligned vec to loadPackedVec at lines 809, 820, and 830 violates alignment requirements and is undefined behavior.

Either make the FP32 VecType explicitly 32-byte aligned (e.g., std::array<float, ELTS_PER_THREAD> alignas(32)) or add a static_assert in loadPackedVec to reject under-aligned types and ensure all call sites use properly aligned types.

Row-wise scale indexing does not account for batch dimension. When USE_ROW_WISE_SCALE=true, lines 280 and 506 look up scales as SFScale[rowIdx] and SFScale[threadRowIdxGlobal] respectively, but the output tensor is indexed with batch-flattened offsets like batchIdx * numRows + rowIdx (lines 305, 321, 500). This asymmetry means per-batch scale factors (if they exist) are not accessed correctly.

Fold the batch index into the scale lookup: SFScale[batchIdx * numRows + rowIdx] to match the flattened output addressing.

Also applies to: 799–812, 506–515

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh` around lines 43 - 52,
loadPackedVec reinterprets VecT as PackedU32x4/PackedU32x8 which require
16/32-byte alignment but FP32 VecType (std::array<float, ELTS_PER_THREAD>) is
only 4-byte aligned, causing UB; fix by either marking the FP32 VecType as
alignas(32) where VecType is defined (ELTS_PER_THREAD=8) or add a static_assert
inside loadPackedVec that checks alignof(VecT) >= sizeof(VecT)==16?16:32 (i.e.,
ensure 16/32-byte alignment) and update all callers to pass properly aligned
types; additionally, when USE_ROW_WISE_SCALE is true change scale lookups from
SFScale[rowIdx] or SFScale[threadRowIdxGlobal] to include the batch dimension
(e.g., SFScale[batchIdx * numRows + rowIdx] or SFScale[batchIdx * numRows +
threadRowIdxGlobal]) so per-batch scales match the flattened output indexing.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh`:
- Around line 43-52: loadPackedVec reinterprets VecT as PackedU32x4/PackedU32x8
which require 16/32-byte alignment but FP32 VecType (std::array<float,
ELTS_PER_THREAD>) is only 4-byte aligned, causing UB; fix by either marking the
FP32 VecType as alignas(32) where VecType is defined (ELTS_PER_THREAD=8) or add
a static_assert inside loadPackedVec that checks alignof(VecT) >=
sizeof(VecT)==16?16:32 (i.e., ensure 16/32-byte alignment) and update all
callers to pass properly aligned types; additionally, when USE_ROW_WISE_SCALE is
true change scale lookups from SFScale[rowIdx] or SFScale[threadRowIdxGlobal] to
include the batch dimension (e.g., SFScale[batchIdx * numRows + rowIdx] or
SFScale[batchIdx * numRows + threadRowIdxGlobal]) so per-batch scales match the
flattened output indexing.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 18891fd0-cbcf-4032-bd19-eec04d39b99c

📥 Commits

Reviewing files that changed from the base of the PR and between 1225fdd and 691a6ee.

📒 Files selected for processing (2)
  • csrc/nv_internal/cpp/kernels/quantization.cu
  • csrc/nv_internal/tensorrt_llm/kernels/quantization.cuh
✅ Files skipped from review due to trivial changes (1)
  • csrc/nv_internal/cpp/kernels/quantization.cu

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants