feat: Add FP8/NVFP4 quant fusion for MNNVL Allreduce #2263

Open
timlee0212 wants to merge 15 commits into flashinfer-ai:main from timlee0212:mnnvlar_quant_fusion

Conversation

@timlee0212
Contributor

@timlee0212 timlee0212 commented Dec 24, 2025

📌 Description

  • Add FP8/NVFP4 quant fusion to MNNVL Allreduce
  • Support all 5 fusion patterns defined in the unified allreduce interface.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • FP8 and FP4 quantized outputs for fused AllReduce + RMSNorm, with configurable scale-factor layouts and optional output tensors.
  • Refactor

    • Pattern-driven dispatch, consolidated trait/configuration types, and updated kernel/grid sizing to support quantized fusion paths.
  • Public API

    • Added quantization/format enums, expanded fusion parameter fields and function signatures, and consolidated re-exports.
  • Tests

    • Expanded tests to cover RMSNorm+quantized patterns and added quantize/dequantize helpers.
  • Stability

    • Added runtime validation, layout/shape checks and CUDA-version guards for quantized paths.


@timlee0212 timlee0212 marked this pull request as draft December 24, 2025 05:09
@coderabbitai
Contributor

coderabbitai bot commented Dec 24, 2025

📝 Walkthrough

Walkthrough

Adds FP8/FP4 quantization to the TRT-LLM MNNVL AllReduce ± RMSNorm fusion: new QuantType/QuantizationSFLayout enums, extended AllReduceFusionParams and host signatures, quantization kernels/utilities, templated kernel dispatch macros, trait-driven Python APIs, and updated tests covering quantized patterns.

Changes

Cohort / File(s) Summary
Core CUDA headers & kernels
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
Added QuantType enum and QuantizationSFLayout; extended AllReduceFusionParams with quant fields and config constants; added quant:: helpers (FP8/FP4/e2m1 conversions); templated fusion kernels (oneshot/twoshot/rmsNormLamport_fusion) and LAUNCH/DISPATCH macros routing by RMSNorm and QuantType.
CUDA implementation / Host dispatch
csrc/trtllm_mnnvl_allreduce.cu
Expanded trtllm_mnnvl_allreduce_fusion to accept optional output and quant params (quant_type, quant_out, sf_out, output_scale, layout_code); added runtime validations for quant/RMSNorm combos; populate new AllReduceFusionParams pointers and dispatch quant-aware kernels.
Python shared types
flashinfer/comm/_types.py
New canonical enums/types: AllReduceFusionPattern, QuantFusionType, QuantizationSFLayout, FusionPatternTraits, and get_pattern_traits mapping pattern capabilities and quant traits.
Python public API & re-exports
flashinfer/comm/__init__.py, flashinfer/comm/allreduce.py, flashinfer/comm/trtllm_ar.py, flashinfer/comm/trtllm_mnnvl_ar.py
Re-exported canonical types; removed duplicate local definitions; switched to trait-driven routing/validation; added quantized fused path (trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant); updated public function signatures to accept optional outputs and quant parameters.
Tests
tests/comm/test_trtllm_mnnvl_allreduce.py
Refactored tests to use AllReduceFusionPattern (including FP8/NVFP4 variants), added FP8/NVFP4 helpers and dequant checks, expanded parametrization and assertions to cover quantized and trait-driven expectations.
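The quantize/dequantize test helpers can be pictured with a pure-Python analog. The real helpers operate on torch tensors with actual FP8 storage, so the e4m3 max of 448.0 and the per-tensor scaling scheme below are simplifying assumptions, not the PR's exact implementation:

```python
# Illustrative per-tensor FP8 (e4m3) quantize/dequantize round trip.
# Assumption: the test helpers use amax-based per-tensor scaling so that
# the largest magnitude maps to the e4m3 max of 448.0. Real FP8 rounding
# to 8-bit storage is omitted here for clarity.

FP8_E4M3_MAX = 448.0


def quantize_fp8_per_tensor(values):
    """Scale values into the e4m3 range and clamp (bit-level rounding omitted)."""
    amax = max(abs(v) for v in values)
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    quantized = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in values]
    return quantized, scale


def dequantize_fp8_per_tensor(quantized, scale):
    """Invert the per-tensor scaling."""
    return [q * scale for q in quantized]
```

A dequant check in a test would then compare the round-tripped values against the reference output within a tolerance sized to the format's precision.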

Sequence Diagram(s)

sequenceDiagram
  actor Python
  participant Frontend as Python API
  participant Launcher as C++ Host Dispatcher
  participant GPU as Kernel
  participant NCCL as Comm

  Python->>Frontend: call fused_allreduce(..., quant_type?, quant_out?, sf_out?, output_scale?, layout_code?)
  Frontend->>Launcher: build AllReduceFusionParams (rmsnorm, quantType, sfLayout, pointers)
  Launcher->>GPU: launch oneshot/twoshot kernel templated by QuantType
  GPU->>NCCL: perform AllReduce across ranks
  NCCL-->>GPU: reduced data
  GPU->>GPU: optional RMSNorm fusion -> optional quantize (FP8/FP4) -> write quant_out & sf_out & outputs
  GPU-->>Launcher: completion / output pointers
  Launcher-->>Frontend: wrap results into tensors and return to Python

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • yzh119
  • djmmoss
  • nvmbreughe
  • cyx-6
  • jimmyzho
  • wenscarl

Poem

🐰 I hop through kernels, scales in paw,

I pack FP8 and FP4 without a flaw.
RMSNorm hums, AllReduce takes flight,
I swizzle scales by moonlit byte.
Carrots, bits, and speedy light!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 57.14%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title "feat: Add FP8/NVFP4 quant fusion for MNNVL Allreduce" is concise and clearly describes the main feature addition to the MNNVL Allreduce implementation.
  • Description check ✅ Passed: The description covers the key objectives (adding FP8/NVFP4 quant fusion, supporting all 5 fusion patterns) and includes completed checklist items for pre-commit checks and tests, though it lacks a detailed explanation of changes and related issues are not linked.



@gemini-code-assist
Contributor

Summary of Changes

Hello @timlee0212, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the MNNVL Allreduce functionality by integrating FP8 and NVFP4 quantization fusion. This allows for more memory-efficient and potentially faster distributed computations, especially when combined with RMSNorm. The changes involve updating core data structures, introducing a new set of quantization utility functions, and modifying existing CUDA kernels and dispatch mechanisms to support these new quantization types across all five unified allreduce interface fusion patterns.

Highlights

  • FP8/NVFP4 Quantization Fusion: Introduced support for FP8 and NVFP4 quantization directly into the MNNVL Allreduce operation, enabling more efficient data handling.
  • Unified Interface Support: The implementation now supports all five fusion patterns defined in the unified allreduce interface, enhancing versatility.
  • Quantization Parameters Extension: Extended the AllReduceFusionParams struct to include new fields for managing quantization outputs, scaling factors, and layout configurations.
  • Dedicated Quantization Functions: Added a new quant namespace with specialized CUDA device functions for performing FP8 and NVFP4 conversions, scaling, and other related mathematical operations.
  • Kernel Integration: Modified existing Allreduce and RMSNorm kernels (oneshotAllreduceFusionKernel, rmsNormLamport_fusion) to seamlessly incorporate the new quantization logic and parameters.
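Pulling the runtime validations mentioned in the walkthrough together, the host-side gating might behave like the sketch below. The function name and string-based enums are hypothetical; the actual checks are TVM_FFI_ICHECKs in csrc/trtllm_mnnvl_allreduce.cu:

```python
def validate_fusion_args(quant_type, rmsnorm_fusion, dtype, quant_out, sf_out):
    """Hypothetical analog of the host-side validation: quantization is only
    legal on top of the RMSNorm fusion path with a 16-bit input dtype, and
    each quant mode needs its output buffers (assumed semantics)."""
    if quant_type != "none":
        if not rmsnorm_fusion or dtype not in ("float16", "bfloat16"):
            raise ValueError(
                "Quant fusion is only supported with RMSNorm fusion and FP16/BF16 dtype."
            )
        if quant_out is None:
            raise ValueError(
                "quant_out must be provided when quantization fusion is enabled."
            )
        # NVFP4 additionally produces a per-block scale-factor tensor
        if quant_type == "nvfp4" and sf_out is None:
            raise ValueError("sf_out must be provided for NVFP4 quantization.")
```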



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds FP8/NVFP4 quantization fusion for MNNVL Allreduce. The changes are extensive and introduce new quantization logic. My review focuses on correctness and maintainability. I've identified a critical bug in the kernel dispatch logic that needs to be fixed. Additionally, there is significant code duplication that should be addressed by refactoring the shared code into a common header file to improve maintainability. The author has already noted these duplications with TODO comments, and I've formalized them as review comments.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2)

40-45: Address the TODO: Consolidate duplicated enum definition.

The TODO comment indicates this QuantType enum is duplicated elsewhere. Consider moving it to a shared header to maintain a single source of truth and avoid potential divergence.

Would you like me to help identify the other location(s) where this enum is defined and suggest a consolidation approach?


1745-1745: Minor: Inconsistent naming kELTS_SIZE vs kELT_SIZE used elsewhere.

This constant is named kELTS_SIZE but similar constants in other kernels (e.g., line 1178 in oneshotAllreduceFusionKernel) use kELT_SIZE. Consider using consistent naming across kernels for maintainability.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25de38e and 4a4d184.

📒 Files selected for processing (1)
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (5)

47-72: LGTM!

The struct expansion is well-organized with sensible defaults. The explicit initialization of pointer fields to nullptr and enum fields to their "off" states is good practice for optional fusion parameters.


922-959: LGTM!

The SF offset calculation handles both batched and non-batched scenarios correctly. The architecture guard for SM100+ and graceful nullptr return for unsupported layouts are appropriate.


1162-1176: LGTM!

The kernel template extension is well-designed. The static_assert correctly enforces that quantization requires RMSNorm fusion, preventing invalid usage patterns at compile time.


1857-1927: LGTM!

The two-shot dispatch macros correctly extend the existing pattern to support quantization. The switch-based dispatch on QuantType and the macro cleanup with #undef are well-structured.


1140-1146: The reinterpret_cast at line 1141 is safe and correctly casts between compatible memory layouts. PackedVec<float4, half> (16 bytes: a union of float4 and half[8]) and vec_t<half, 8> (16 bytes: containing int4 data[1]) both occupy identical 128-bit memory with no padding, making direct reinterpretation valid.
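A quick size check along the lines of this argument can be done with Python's struct module standing in for the CUDA types; it only confirms the 16-byte footprint of both views, not the CUDA type definitions themselves:

```python
import struct

# Eight fp16 values and one float4 both occupy 128 bits (16 bytes),
# which is the size-compatibility premise behind the reinterpret_cast.
halves = struct.pack("<8e", *[1.0] * 8)   # 8 x half -> 16 bytes
floats = struct.pack("<4f", *[1.0] * 4)   # float4   -> 16 bytes
```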

@yzh119 yzh119 marked this pull request as ready for review December 24, 2025 06:24
@yzh119
Collaborator

yzh119 commented Dec 24, 2025

Hi @timlee0212 is this PR ready? I noticed that you marked it as draft.

@timlee0212
Contributor Author

> Hi @timlee0212 is this PR ready? I noticed that you marked it as draft.

No, it's still WIP. I'll convert it back to draft.

@timlee0212 timlee0212 marked this pull request as draft December 31, 2025 06:07
@timlee0212 timlee0212 force-pushed the mnnvlar_quant_fusion branch from c66b72b to 4bf89f4 Compare January 5, 2026 07:43
@timlee0212 timlee0212 marked this pull request as ready for review January 9, 2026 09:38

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Fix all issues with AI agents
In @csrc/trtllm_mnnvl_allreduce.cu:
- Around lines 112-123: The code lacks FP4 output shape validation and does not validate the FP8 scale-factor output. Add TVM_FFI_ICHECK checks similar to the existing FP8 quant_out check: for QuantType::kFP4, ensure quant_out.has_value() and quant_out.value().size(0) == num_tokens && quant_out.value().size(1) == token_dim; for QuantType::kFP8, also validate sf_out when provided (sf_out.has_value()), ensuring sf_out.value().size(0) == num_tokens && sf_out.value().size(1) == token_dim (or the correct expected second dimension for scale factors in your design). Produce informative error messages that include expected vs. actual shapes, and place these checks inside the switch branches for QuantType::kFP4 and QuantType::kFP8, alongside the existing quant_out validation.
- Around lines 95-96: The error message in the TVM_FFI_ICHECK call inside trtllm_mnnvl_allreduce (the check using quant_type_enum, rmsnorm_fusion, and sizeof(c_type)) contains a typo, "Qaunt fusion". Update the string to "Quant fusion is only supported with RMSNorm fusion and FP16/BF16 dtype." so the diagnostic text is correct.

In @flashinfer/comm/allreduce.py:
- Around lines 696-706: The call to trtllm_mnnvl_fused_allreduce_add_rmsnorm returns (norm_result, residual_result), but residual_result is unused, causing a lint warning. Either rename the unused value to _residual_result (and similarly _quant_result if the quantized path has an unused value) to signal an intentional discard, or, if the residual must be propagated, include residual_result in the function's return so the non-quantized path matches the quantized path (ensuring callers handle the extra return value). Update the references to norm_result, residual_result, and quant_result accordingly.
🧹 Nitpick comments (6)
csrc/trtllm_mnnvl_allreduce.cu (1)

78-85: Consider validating quantization output tensors when quantization is enabled.

Currently, the code validates that output is provided when quantization is disabled (lines 78-79). However, when quantization fusion is enabled (quant_type_enum != QuantType::kNone), the code should also validate that quant_out is provided, since that becomes the required output. This validation is performed later (lines 114-118) but only within the RMSNorm fusion block, which could allow invalid configurations to pass initial validation.

🔍 Suggested validation enhancement

Consider adding an explicit check after line 79:

 TVM_FFI_ICHECK(quant_type_enum != QuantType::kNone || output.has_value())
     << "Output tensor must be provided when quantization fusion is disabled";
+TVM_FFI_ICHECK(quant_type_enum == QuantType::kNone || quant_out.has_value())
+    << "Quantized output tensor must be provided when quantization fusion is enabled";
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2)

50-55: Address or track the TODO comment about enum duplication.

The QuantType enum is marked with a TODO indicating it's duplicated. Duplicated enum definitions can lead to maintenance issues and inconsistencies across the codebase. Based on the PR context, this enum appears to be defined in multiple places (Python types in _types.py, potentially in other C++ headers).

Consider either:

  1. Moving this to a shared header (e.g., a common types header)
  2. Creating an issue to track the consolidation work if immediate refactoring is out of scope
  3. Removing the TODO if the duplication is intentional and necessary for build/dependency reasons

Based on learnings, for performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.


549-1216: Add high-level documentation for the quantization utilities namespace.

The quant namespace introduces a substantial amount of new quantization logic (FP8/FP4 conversion, scale factor computation, layout handling). While individual functions have some inline comments, the namespace lacks a high-level overview explaining:

  • The quantization workflow and how these utilities fit together
  • Key algorithmic choices (e.g., why reciprocal_6 for e2m1, the SF layout computation strategy)
  • Performance considerations and optimizations applied

Based on learnings, for performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers.

💡 Suggested namespace-level documentation

Consider adding documentation at the namespace level (around line 557):

/**
 * @brief Quantization utilities for FP8 and FP4 output quantization.
 *
 * Workflow:
 * 1. quant_fp8/quant_nvfp4 are called from kernel fusion paths to quantize normalized outputs
 * 2. For FP4: cvt_warp_fp16_to_fp4 performs per-warp quantization with dynamic scaling
 *    - Computes local max across elements
 *    - Derives scale factor using e2m1 max value (6.0)
 *    - Applies scale and converts to e2m1 format via fp32_vec_to_e2m1
 * 3. SF layout helpers (cvt_quant_to_fp4_get_sf_out_offset) compute swizzled memory offsets
 *
 * Performance notes:
 * - Inline PTX for e2m1 conversion reduces register pressure
 * - Pipelined float2 processing minimizes intermediate storage
 * - Swizzled layouts optimize for downstream CUTLASS kernels
 */
namespace quant {
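As a loose illustration of the workflow that suggested comment describes, the per-block NVFP4 scheme can be sketched in pure Python. The 16-element block size, the e2m1 representable magnitudes, and nearest-value rounding are assumptions drawn from the description above, not the kernel's exact bit-level behavior:

```python
# Per-block NVFP4 quantization sketch: derive a scale factor from the block
# max and the e2m1 max value (6.0), then snap each scaled element to the
# nearest e2m1-representable magnitude.

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # e2m1 magnitudes


def quantize_fp4_block(block):
    """Quantize one block (nominally 16 elements) to signed e2m1 values."""
    amax = max(abs(v) for v in block)
    sf = amax / 6.0 if amax > 0 else 1.0  # block max maps to e2m1 max (6.0)
    out = []
    for v in block:
        mag = min(E2M1_VALUES, key=lambda e: abs(abs(v) / sf - e))
        out.append(mag if v >= 0 else -mag)
    return out, sf


def dequantize_fp4_block(codes, sf):
    """Invert the per-block scaling."""
    return [c * sf for c in codes]
```

The largest gap between adjacent e2m1 magnitudes is 2.0 (between 4.0 and 6.0), so the worst-case absolute error after dequantization is bounded by one scale factor.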
flashinfer/comm/allreduce.py (1)

642-647: Improve layout validation to handle None case.

The validation at line 644 checks if layout_code == QuantizationSFLayout.SWIZZLED_8x4, but layout_code is Optional[int] and could be None. When layout_code is None, the condition will be False and no error will be raised, which is correct. However, the logic would be clearer if this intent was explicit.

♻️ Optional clarification
-        if layout_code == QuantizationSFLayout.SWIZZLED_8x4:
+        if layout_code is not None and layout_code == QuantizationSFLayout.SWIZZLED_8x4:
             raise ValueError(
                 "MNNVL AllReduce does not support 8x4 swizzled sf layout. Please use 128x4 or linear layout instead."
             )

This makes the intent clearer that None is acceptable and only the specific SWIZZLED_8x4 value is problematic.

flashinfer/comm/_types.py (1)

171-187: Consider adding input validation for robustness.

The function currently raises a KeyError if an invalid pattern is passed. For a public API function, explicit validation with a descriptive error message would improve the developer experience.

🛡️ Suggested error handling
 def get_pattern_traits(pattern: int) -> FusionPatternTraits:
     """
     Get traits for an AllReduceFusionPattern.
 
     Args:
         pattern: AllReduceFusionPattern constant (0-5)
 
     Returns:
         FusionPatternTraits with all trait flags for the pattern
 
     Example:
         >>> traits = get_pattern_traits(AllReduceFusionPattern.kARResidualRMSNormFP8Quant)
         >>> traits.has_quant  # True
         >>> traits.has_rmsnorm  # True
         >>> traits.quant_type  # QuantFusionType.FP8
     """
+    if pattern not in _PATTERN_TRAITS:
+        raise ValueError(
+            f"Invalid fusion pattern: {pattern}. "
+            f"Expected one of {list(_PATTERN_TRAITS.keys())}"
+        )
     return _PATTERN_TRAITS[pattern]
flashinfer/comm/trtllm_mnnvl_ar.py (1)

613-656: Clarify sf_out usage in docstring.

The buffer allocation logic shows that sf_out is only allocated for NVFP4 quantization, not for FP8. However, the docstring at line 569 doesn't clarify this distinction. Consider updating the docstring to make this explicit.

📝 Suggested documentation improvement

Update the docstring around line 569:

         quant_out: Quantized output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided.
-        sf_out: Scaling factor output tensor [num_tokens, hidden_dim], empty tensor will be created if not provided.
+        sf_out: Scaling factor output tensor (FP4 only; None for FP8), empty tensor will be created if not provided.
         output_scale: The global scale applied to quant output.

And in the Returns section around line 578:

     Returns:
         quant_out: Quantized output tensor [num_tokens, hidden_dim]
-        sf_out: Scaling factor output tensor [num_tokens, hidden_dim]
+        sf_out: Scaling factor output tensor (FP4 only; None for FP8)
         residual_out: Add-residual tensor [num_tokens, hidden_dim]
         output: Add-residual and normalized tensor [num_tokens, hidden_dim]
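In caller-side terms, an allocation helper consistent with that behavior could look like this hedged sketch; the function name and the per-16-element scale-factor granularity are assumptions for illustration:

```python
def allocate_sf_out(quant_type, num_tokens, hidden_dim, sf_vec_size=16):
    """Hypothetical sf_out allocation: only NVFP4 carries a per-block
    scale-factor tensor; FP8 uses a single per-tensor scale, so the FP8
    path returns None."""
    if quant_type != "nvfp4":
        return None  # FP8: no per-block scale-factor tensor
    # one scale factor per sf_vec_size-element block along the hidden dim
    return [[0.0] * (hidden_dim // sf_vec_size) for _ in range(num_tokens)]
```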
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4a4d184 and 809b222.

📒 Files selected for processing (8)
  • csrc/trtllm_mnnvl_allreduce.cu
  • flashinfer/comm/__init__.py
  • flashinfer/comm/_types.py
  • flashinfer/comm/allreduce.py
  • flashinfer/comm/trtllm_ar.py
  • flashinfer/comm/trtllm_mnnvl_ar.py
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
  • tests/comm/test_trtllm_mnnvl_allreduce.py
🧰 Additional context used
📓 Path-based instructions (4)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/comm/__init__.py
  • flashinfer/comm/_types.py
  • flashinfer/comm/trtllm_mnnvl_ar.py
  • flashinfer/comm/allreduce.py
  • flashinfer/comm/trtllm_ar.py
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/comm/test_trtllm_mnnvl_allreduce.py
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/trtllm_mnnvl_allreduce.cu
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
🧠 Learnings (4)
📚 From CLAUDE.md (repo guidelines, 2025-12-30T09:34:39.900Z):
  • Applies to flashinfer/__init__.py: Export new operations in `flashinfer/__init__.py` to make them available as public API. (Applied to: flashinfer/comm/__init__.py)
  • Applies to csrc/**/*.cu: Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers. (Applied to: flashinfer/comm/__init__.py, include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh)
  • Applies to include/**/*.cuh: For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers. (Applied to: include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh)
📚 From PR #2070 (raayandhar, 2025-11-12T03:35:17.583Z, include/flashinfer/gemm/bf16_gemm_cutlass_template.h): In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale. (Applied to: include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh)
🧬 Code graph analysis (6)
flashinfer/comm/__init__.py (1)
flashinfer/comm/_types.py (4)
  • AllReduceFusionPattern (37-55)
  • QuantizationSFLayout (75-88)
  • get_pattern_traits (171-187)
  • QuantFusionType (63-72)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)
flashinfer/comm/_types.py (4)
  • AllReduceFusionPattern (37-55)
  • QuantFusionType (63-72)
  • get_pattern_traits (171-187)
  • has_quant (113-115)
csrc/trtllm_mnnvl_allreduce.cu (1)
flashinfer/comm/_types.py (1)
  • QuantizationSFLayout (75-88)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
flashinfer/fp4_quantization.py (1)
  • _compute_swizzled_layout_sf_size (47-50)
flashinfer/fp8_quantization.py (1)
  • _compute_swizzled_layout_sf_size (16-19)
flashinfer/comm/_types.py (2)
  • QuantFusionType (63-72)
  • QuantizationSFLayout (75-88)
flashinfer/comm/mapping.py (2)
  • rank (311-312)
  • rank (315-322)
flashinfer/comm/allreduce.py (1)
  • is_buffer_size_sufficient (159-174)
flashinfer/comm/allreduce.py (1)
flashinfer/comm/_types.py (4)
  • AllReduceFusionPattern (37-55)
  • QuantizationSFLayout (75-88)
  • get_pattern_traits (171-187)
  • has_quant (113-115)
flashinfer/comm/trtllm_ar.py (1)
flashinfer/comm/_types.py (2)
  • AllReduceFusionPattern (37-55)
  • QuantizationSFLayout (75-88)
🪛 Ruff (0.14.10)
flashinfer/comm/trtllm_mnnvl_ar.py

  • Lines 587-589, 591-593, 595-597, 602-604, 619-621, 630-632, 653-655, 666-668: Avoid specifying long messages outside the exception class (TRY003)

flashinfer/comm/allreduce.py

696-696: Unpacked variable residual_result is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (23)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (1)

1371-1385: LGTM! Proper handling of quantization paths with appropriate guards.

The conditional quantization calls are correctly guarded by:

  1. Compile-time template parameter (QuantType QType)
  2. CUDA version check for FP4 (CUDA_VERSION >= 12080)
  3. Runtime null checks for output pointers

This ensures the quantization paths are only executed when properly supported and configured.

flashinfer/comm/trtllm_ar.py (1)

65-67: LGTM! Type definitions properly consolidated.

The refactoring correctly imports AllReduceFusionPattern and QuantizationSFLayout from the canonical location (_types.py) and re-exports them for backwards compatibility. This eliminates duplication and provides a single source of truth for these type definitions.

flashinfer/comm/__init__.py (1)

5-9: LGTM! Public API properly updated with consolidated types.

The changes correctly update the public API exports to use the centralized type definitions from _types.py (via allreduce.py). The explicit as aliasing ensures backwards compatibility while moving to a single source of truth for these types.

Key improvements:

  • Removes duplicate type definitions previously imported from trtllm_ar
  • Adds QuantFusionType to the public API surface
  • Adds get_pattern_traits utility for pattern-based trait queries

This aligns with the PR's goal of consolidating type definitions across the codebase.

Based on learnings, export new operations in flashinfer/__init__.py to make them available as public API.

flashinfer/comm/allreduce.py (3)

60-62: LGTM! Proper type consolidation and public re-exports.

The changes correctly import the canonical type definitions from _types.py and re-export them for public API convenience. This provides a clean, unified interface for consumers of the allreduce module.


662-706: Pattern trait definitions correctly implement RMSNorm fusion requirements.

All patterns with RMSNorm fusion properly set the has_rmsnorm trait in _types.py. The validation logic (lines 665-668) correctly validates residual_in and rms_gamma for all five RMSNorm fusion patterns. The pattern traits appropriately distinguish between patterns with norm_out (non-quantized and quantized-with-norm variants) and those without (quantized-without-norm variants), and the code properly allocates norm_out based on pattern_traits.has_norm_out.


676-693: Remove the FIXME comment—buffer allocation is properly handled by the function.

The function trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant() explicitly handles buffer allocation internally. Its documentation states "empty tensor will be created if not provided" for quant_out and sf_out, and the implementation confirms this with explicit allocation when these parameters are None. The calling code correctly omits these parameters to rely on this internal allocation. The FIXME comment's uncertainty is unfounded—the function design is correct.

No additional validation is needed for scale_factor (has default value 1.0 and internal tensor conversion) or layout_code (has default QuantizationSFLayout.SWIZZLED_128x4 and built-in validation for FP4 paths).

Likely an incorrect or invalid review comment.
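The "empty tensor will be created if not provided" contract described above can be sketched in plain Python (hypothetical names; lists stand in for tensor allocation, and the body does not attempt real quantization):

```python
def fused_quant_op(x, quant_out=None, sf_out=None):
    """Sketch of the optional-output contract: outputs the caller
    does not pass are allocated internally, passed buffers are reused."""
    if quant_out is None:
        quant_out = [0.0] * len(x)               # stands in for torch.empty_like(...)
    if sf_out is None:
        sf_out = [1.0] * ((len(x) + 15) // 16)   # one scale per 16-element block
    for i, v in enumerate(x):
        quant_out[i] = v                          # a real kernel would quantize here
    return quant_out, sf_out
```

Callers that omit quant_out/sf_out rely on internal allocation, which is why the calling code in allreduce.py can safely leave them out.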

flashinfer/comm/_types.py (5)

37-56: LGTM: Well-structured fusion pattern enumeration.

The fusion pattern constants are clearly documented and logically organized from basic AllReduce through progressively more complex fusion patterns with quantization variants.
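The progression can be sketched as an IntEnum; the member names and values below are illustrative assumptions, not necessarily the exact identifiers in _types.py:

```python
from enum import IntEnum

class AllReduceFusionPattern(IntEnum):
    # Illustrative reconstruction of the progression described above.
    kAllReduce = 0                     # plain allreduce
    kARResidualRMSNorm = 1             # + residual add + RMSNorm
    kARResidualRMSNormFP8Quant = 2     # + FP8 quantized output, no norm_out
    kARResidualRMSNormFP4Quant = 3     # + NVFP4 quantized output, no norm_out
    kARResidualRMSNormOutFP8Quant = 4  # FP8 quant plus normalized output
    kARResidualRMSNormOutFP4Quant = 5  # NVFP4 quant plus normalized output
```

This matches the six patterns exercised by the refactored tests (base, RMSNorm fusion, and four quantized variants).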


63-72: LGTM: Clear quantization type enumeration.


75-88: LGTM: Scale factor layout enumeration with helpful documentation.

The detailed comment about swizzled layout memory organization (512-byte blocks, 128x4 FP8 values) will be valuable for future maintainers.
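As a rough model of that layout, a 512-byte tile holds the scale factors for a 128x4 block; the inner 32x4x4 permutation below is an assumption patterned on common TRT-LLM swizzles, not the verified kernel formula:

```python
def swizzled_sf_offset(row, col, num_k_tiles):
    """Byte offset of the FP8 scale factor for (row, col) in a swizzled
    layout built from 512-byte tiles of 128x4 values. The inner ordering
    is an assumed 32x4x4 permutation, shown for intuition only."""
    m_tile, k_tile = row // 128, col // 4
    inner = (row % 32) * 16 + ((row % 128) // 32) * 4 + (col % 4)
    return (m_tile * num_k_tiles + k_tile) * 512 + inner
```

Whatever the exact permutation, the invariant is that each 128x4 tile maps bijectively onto its 512 bytes.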


96-115: LGTM: Well-designed immutable traits dataclass.

Using frozen=True ensures traits remain constant, and the has_quant property provides a convenient query method.
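The pattern can be shown with a minimal sketch (field names are illustrative, not copied from _types.py):

```python
from dataclasses import dataclass, FrozenInstanceError

@dataclass(frozen=True)
class PatternTraits:
    """Immutable per-pattern traits; frozen=True rejects mutation."""
    has_rmsnorm: bool = False
    has_norm_out: bool = False
    quant_type: int = 0  # 0 = none, 1 = FP8, 2 = FP4

    @property
    def has_quant(self) -> bool:
        return self.quant_type != 0
```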


119-168: LGTM: Comprehensive and accurate pattern trait definitions.

The trait mappings correctly distinguish between fusion patterns, especially the subtle difference between quantization patterns with and without normalized output (has_norm_out).

flashinfer/comm/trtllm_mnnvl_ar.py (7)

10-10: LGTM: Necessary imports for quantization support.

The imports bring in required types and utilities for the new quantization fusion functionality.

Also applies to: 20-20, 34-37


280-353: LGTM: Well-organized expansion of the fusion API.

The signature is logically grouped (Primary I/O, Communication infrastructure, Distributed configuration, Kernel control flags, RMSNorm fusion, Quantization) with clear comments. The default values are sensible and backward-compatible.


417-429: LGTM: Clear use of named arguments.

Explicitly passing rmsnorm_fusion=False makes the intent clear that this function is for basic AllReduce without fusion.


509-526: LGTM: Consistent named argument usage.

The explicit rmsnorm_fusion=True clarifies this path performs fused RMSNorm operations.


583-611: LGTM: Thorough input validation.

The shape validation and output_scale normalization to float32 tensor are handled correctly.


657-693: LGTM: Proper kernel invocation and return handling.

The strategy selection, workspace validation, and kernel call are consistent with the existing codebase patterns.


813-825: LGTM: Legacy functions updated consistently.

The deprecated functions are appropriately updated with named arguments while maintaining backward compatibility.

Also applies to: 902-918

tests/comm/test_trtllm_mnnvl_allreduce.py (5)

2-4: LGTM: Well-structured test helpers and imports.

The path manipulation enables direct script execution, and the quantization helpers (fp8_quant, dequant) provide clear reference implementations for validation.

Also applies to: 12-12, 19-23, 29-40
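A reference fp8 quant/dequant pair like the test helpers can be written in pure Python by rounding onto the E4M3 grid; this is a simplified per-tensor-scale sketch, not the exact helpers in the test file:

```python
def _e4m3_grid():
    # All non-negative finite FP8 E4M3 values (bias 7, max 448; e=15,m=7 is NaN).
    vals = []
    for e in range(16):
        for m in range(8):
            if e == 15 and m == 7:                 # NaN encoding, not a value
                continue
            if e == 0:
                vals.append((m / 8) * 2.0 ** -6)   # subnormals
            else:
                vals.append((1 + m / 8) * 2.0 ** (e - 7))
    return sorted(vals)

_GRID = _e4m3_grid()

def fp8_quant(x, scale):
    """Round x/scale to the nearest representable E4M3 value, saturating at 448."""
    v = x / scale
    s = -1.0 if v < 0 else 1.0
    a = min(abs(v), 448.0)
    return s * min(_GRID, key=lambda g: abs(g - a))

def dequant(q, scale):
    return q * scale
```

With 3 mantissa bits the worst-case relative rounding error is about 2^-4, which motivates the loose mismatch thresholds used in the tests.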


42-228: LGTM: Comprehensive pattern-based test validation.

The test logic correctly branches on fusion patterns and thoroughly validates quantized outputs with appropriate tolerance thresholds (0.2% for FP8, 1% for FP4). The mismatch ratio calculation is a good approach for handling quantization errors.

Note: The FIXME at line 187 about global_scale appears to be a pre-existing issue with the quantization API, not introduced by this PR.
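The mismatch-ratio idea can be sketched as follows (a stand-alone illustration, not the test file's exact code):

```python
def mismatch_ratio(ref, out, rtol):
    """Fraction of elements whose relative error exceeds rtol."""
    bad = sum(1 for r, o in zip(ref, out)
              if abs(r - o) > rtol * max(abs(r), 1e-8))
    return bad / len(ref)

# e.g. accept a quantized output if at most 0.2% of elements miss the bound:
#   assert mismatch_ratio(ref, out, rtol=0.05) < 0.002
```

Bounding the fraction of outliers rather than the worst-case error tolerates the occasional large rounding step inherent to low-bit formats.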


360-400: LGTM: Clean test data preparation.

The function correctly prepares reference outputs based on the fusion pattern, using rmsnorm from flashinfer for ground truth.


403-569: LGTM: Consistent pattern parameter usage.

The function is updated consistently to use the pattern enum instead of boolean fusion flag, improving test clarity.


577-625: LGTM: Excellent test coverage.

The parametrization covers all six fusion patterns (including FP8/FP4 variants) in the refactored API test, while appropriately limiting the legacy test to the two patterns it supports. The variety of sequence lengths and hidden sizes provides thorough validation.

Comment on lines +696 to 706
norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
input=input,
residual_in=residual_in,
gamma=rms_gamma,
workspace=workspace,
epsilon=rms_eps,
output=norm_out,
residual_out=residual_out,
launch_with_pdl=launch_with_pdl,
)
return norm_result
Contributor

@coderabbitai coderabbitai bot Jan 9, 2026


⚠️ Potential issue | 🟡 Minor

Fix unused variable warning.

Static analysis reports that residual_result (returned from line 696) is unused. This suggests one of two issues:

  1. The variable should be used or returned somewhere
  2. The function returns an extra value that isn't needed

Looking at the non-quantized path (lines 696-706), it returns norm_result. The quantized path (line 693) returns quant_result. This asymmetry might indicate missing logic.

🔧 Suggested fix

If residual_result and norm_result from the quantized path are not needed, use underscore prefixes:

-                quant_result, scale_out, residual_out, norm_result = (
+                quant_result, scale_out, _residual_out, _norm_result = (
                     trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(

Alternatively, if these outputs should be validated or used, add appropriate logic.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.10)

696-696: Unpacked variable residual_result is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In @flashinfer/comm/allreduce.py around lines 696 - 706, The call to
trtllm_mnnvl_fused_allreduce_add_rmsnorm returns (norm_result, residual_result)
but residual_result is unused, causing a warning; fix this by either renaming
the unused value to _residual_result (and similarly _quant_result if the
quantized path returns an unused value) to signal intentional discard, or if the
residual must be propagated, include residual_result in the function's return so
the non-quantized path matches the quantized path (ensure callers handle the
extra return); update the references to norm_result, residual_result, and
quant_result accordingly.

Contributor Author


This is out of the scope of this PR as it is related to the unified allreduce API.

Contributor


@timlee0212, understood! I'll leave this for the unified allreduce API work. Thanks for clarifying the scope.

Contributor
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (3)

1647-1831: RMSNorm fusion kernel: make kELTS_PER_LOAD constexpr + validate required outputs for each fusion pattern.

Minor:

  • static int const kELTS_PER_LOAD = ... should be static constexpr int ... in device code.

Correctness/API:

  • outputPreNorm is written unconditionally, outputNorm is optional; if patterns allow “norm-only” or “quant-only”, make sure host validates pointer expectations (especially now that quantType can request extra outputs).

498-542: Enforce warp-multiple blockSize alignment for FP4 quantization path.

The adjustGridConfig() function can return a blockSize that is not a multiple of 32 (e.g., ceil_div(50, 2) = 25). When the FP4 quantization path is enabled (QuantType::kFP4), the kernel uses __shfl_xor_sync(uint32_t(-1), ...) for warp-level reductions in cvt_warp_fp16_to_fp4(). With a partial warp (fewer than 32 active lanes), the full-warp mask causes undefined behavior and corrupts quantization results.

Fix: Add alignment constraint in adjustGridConfig() to ensure blockSize % 32 == 0 when the kernel will be instantiated with QuantType::kFP4, or update the shuffle operations to use __activemask() to handle partial warps correctly.
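The alignment fix amounts to rounding the proposed block size up to a warp multiple; a minimal sketch (function names are hypothetical):

```python
def ceil_div(a, b):
    return -(-a // b)

def warp_aligned_block_size(n, warp=32):
    """Round a proposed blockSize up to a warp multiple so that full-warp
    shuffles (__shfl_xor_sync with mask -1) never see partial warps."""
    return ceil_div(n, warp) * warp
```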


1224-1386: Add precondition checks to dispatch: outputScale null-safety and tokenDim divisibility.

Three real safety gaps in oneshotAllreduceFusionDispatch:

  • FP8/FP4 branches dereference outputScale (*outputScale) without null validation; must check it's non-null when quantType != QuantType::kNone.
  • FP4 conversion uses __shfl_xor_sync(uint32_t(-1), ...) in cvt_warp_fp16_to_fp4 with no guard for partial warps; add __activemask() or verify blockSize guarantees 32-thread occupancy.
  • tokenDim / ELTS_PER_THREAD integer division in quant_nvfp4 (line 1211) requires tokenDim % ELTS_PER_THREAD == 0; add FLASHINFER_CHECK matching the requirement in twoshotAllreduceFusionDispatch.
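The three checks above can be mirrored as a host-side validator; this Python sketch uses illustrative names, not the actual C++ param struct:

```python
K_NONE, K_FP8, K_FP4 = 0, 1, 2

def validate_quant_params(quant_type, rmsnorm_fusion, quant_out, output_scale,
                          sf_out, token_dim, elts_per_thread=8):
    """Sketch of the suggested dispatch-time preconditions."""
    if quant_type == K_NONE:
        return
    if not rmsnorm_fusion:
        raise ValueError("quantization requires RMSNorm fusion")
    if quant_out is None or output_scale is None:
        raise ValueError("quant_out and output_scale must be provided")
    if quant_type == K_FP4 and sf_out is None:
        raise ValueError("FP4 requires a scale-factor output buffer")
    if token_dim % elts_per_thread != 0:
        raise ValueError("token_dim must be divisible by ELTS_PER_THREAD")
```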
🤖 Fix all issues with AI agents
In @include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh:
- Around line 63-100: The host must validate quantization-related fields in
AllReduceFusionParams before dispatch: add checks where the params are consumed
(e.g., the allreduce launch/dispatch function that accepts
AllReduceFusionParams) to enforce that if params.quantType != QuantType::kNone
then params.rmsNormFusion == true, and for each quant type require non-null
pointers: if quantType == QuantType::kFP8 ensure params.quantOut != nullptr &&
params.outputScale != nullptr, and if quantType == QuantType::kFP4 ensure
params.quantOut != nullptr && params.scalingFactorOut != nullptr &&
params.outputScale != nullptr; return an error or abort launch with a clear
message when any check fails to avoid silent ignores or kernel dereferences.
- Around line 571-715: The explicit specializations for cuda_cast are malformed
(they omit the template-id) and will not compile; change each to the form
cuda_cast<OUT,IN>(IN) (e.g., cuda_cast<__nv_bfloat16,int32_t>(int32_t) and
cuda_cast<__nv_bfloat16,int8_t>(int8_t),
cuda_cast<int8_t,__nv_bfloat16>(__nv_bfloat16), etc.) matching the pattern used
elsewhere (refer to existing cuda_cast<float2,int2>, cuda_cast<half2,float2>,
etc.), and ensure bf1622float2 remains unchanged. Replace any runtime
assert(false) fallback in cuda_abs with a compile-time failure using a dependent
static_assert (e.g., use a dependent_false<T> helper and
static_assert(dependent_false<T>::value, "Unsupported type for cuda_abs")) so
unsupported types fail at compile time rather than relying on assert/include
availability.
- Around line 30-35: The device functions get_sf_out_offset_128x4 and
cvt_quant_to_fp4_get_sf_out_offset use std::optional which is not safe in CUDA
device code; replace with cuda::std::optional by including the CUDA-compatible
header and updating parameter types and any usages of std::optional (e.g.,
change std::optional<int> to cuda::std::optional<int> and keep .value_or() calls
as-is). Specifically, add or switch to the CUDA optional header (e.g., include
<cuda/std/optional> or ensure cuda::std::optional is available), and update the
function signatures and any local variable declarations from std::optional to
cuda::std::optional to match the pattern used in trtllm_allreduce_fusion.cuh and
trtllm_moe_allreduce_fusion.cuh.
- Around line 1435-1465: The dispatch currently ignores requested quantization
when params.rmsNormFusion is false; in oneshotAllreduceFusionDispatch add
runtime validation before the DISPATCH_ALLREDUCE_KERNEL call: if
(!params.rmsNormFusion && params.quantType != QuantType::kNone) emit a
FLASHINFER_ERROR and return cudaErrorInvalidValue; if (params.quantType !=
QuantType::kNone) check that params.quantOut and params.outputScale are
non-null/valid and error+return if not; and if (params.quantType ==
QuantType::kFP4) check CUDA_VERSION >= 12080 and error+return otherwise. Use the
same error reporting pattern (FLASHINFER_ERROR(...) followed by return
cudaErrorInvalidValue) and reference oneshotAllreduceFusionDispatch,
params.quantType, params.rmsNormFusion, params.quantOut, params.outputScale, and
the DISPATCH_ALLREDUCE_KERNEL macro.
- Around line 50-55: QuantType enum is duplicated; create a single shared header
(e.g., trtllm_common.cuh) that defines enum class QuantType : int { kNone=0,
kFP8=1, kFP4=2 }; remove the duplicate enum definitions from
trtllm_allreduce_fusion.cuh and trtllm_mnnvl_allreduce.cuh and instead include
the new trtllm_common.cuh in both files; ensure the new header has proper
include guards/pragma once and update any includes or forward declarations so
all translation units use the single QuantType definition.
🧹 Nitpick comments (1)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (1)

549-557: Add brief rationale comments for the inline PTX + scaling math choices (future-proofing).

Given these are hot-path kernels and you’re using inline PTX (cvt.*e2m1x2, rcp.approx.ftz) and a non-obvious SF recipe (SFScaleVal / quantized_sf), a short “why this is safe / alternative considered” comment will help reviewers maintain it across CUDA/SM changes. Based on learnings, this is especially useful in include/**/*.cuh hot paths.

Also applies to: 1005-1168

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 809b222 and 5b1d53b.

📒 Files selected for processing (1)
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
🧠 Learnings (3)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (2)

1510-1512: Twoshot block sizing change looks consistent with launch config.

twoshotAllreduceKernel now has __launch_bounds__(config::kPreferredBlockSize) and dispatch always launches with config::kPreferredBlockSize, so the compiler hint matches runtime.

Also applies to: 1850-1861


1934-2010: Add CUDA_VERSION >= 12080 guard for FP4 at host dispatch.

FP4 quantization requires CUDA 12.8+, but the dispatch code only checks the output type at compile-time with if constexpr (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>). Add a runtime CUDA_VERSION guard (CUDA_VERSION >= 12080) to the FP4 case in the dispatch to prevent compiling/launching unsupported kernels on earlier CUDA versions.

Note: Param validation for quantType != kNone requiring params.quantOut, params.outputScale, and params.scalingFactorOut (for FP4) already exists at lines 1870-1932.

Likely an incorrect or invalid review comment.

Comment on lines +50 to +55
// TODO: This enum definition is duplicated
enum class QuantType : int {
kNone = 0,
kFP8 = 1,
kFP4 = 2,
};
Contributor

@coderabbitai coderabbitai bot Jan 9, 2026


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

rg -n --hidden --no-ignore-vcs 'enum class QuantType\b|QuantType::kFP4|QuantType::kFP8'

Repository: flashinfer-ai/flashinfer

Length of output: 3229


🏁 Script executed:

sed -n '729,735p' include/flashinfer/comm/trtllm_allreduce_fusion.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 179


🏁 Script executed:

sed -n '51,57p' include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 141


🏁 Script executed:

rg -l 'trtllm_allreduce_fusion\.cuh|trtllm_mnnvl_allreduce\.cuh' --include='*.cu' --include='*.cuh'

Repository: flashinfer-ai/flashinfer

Length of output: 509


🏁 Script executed:

rg -l 'include.*trtllm_allreduce_fusion\.cuh|include.*trtllm_mnnvl_allreduce\.cuh'

Repository: flashinfer-ai/flashinfer

Length of output: 130


🏁 Script executed:

rg -n '#include.*trtllm_allreduce|#include.*trtllm_mnnvl' csrc/

Repository: flashinfer-ai/flashinfer

Length of output: 319


🏁 Script executed:

head -20 include/flashinfer/comm/trtllm_allreduce_fusion.cuh && echo "---" && head -20 include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 1155


🏁 Script executed:

rg -l 'trtllm_allreduce_fusion\|trtllm_mnnvl_allreduce' include/flashinfer/

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

ls -la include/flashinfer/comm/

Repository: flashinfer-ai/flashinfer

Length of output: 715


🏁 Script executed:

rg -n '#include.*trtllm_allreduce_fusion|#include.*trtllm_mnnvl_allreduce' include/

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

rg -n 'include.*trtllm_allreduce_fusion|include.*trtllm_mnnvl_allreduce' .

Repository: flashinfer-ai/flashinfer

Length of output: 247


🏁 Script executed:

python3 << 'EOF'
import re
import os

# Find all files that include both headers
trtllm_fusion_includers = set()
trtllm_mnnvl_includers = set()

for root, dirs, files in os.walk('.'):
    # Skip build directories
    dirs[:] = [d for d in dirs if d not in ['build', 'dist', 'node_modules', '.git']]
    
    for file in files:
        if file.endswith(('.cuh', '.cu', '.h', '.cpp', '.hpp')):
            filepath = os.path.join(root, file)
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()
                    if 'trtllm_allreduce_fusion.cuh' in content:
                        trtllm_fusion_includers.add(filepath)
                    if 'trtllm_mnnvl_allreduce.cuh' in content:
                        trtllm_mnnvl_includers.add(filepath)
            except:
                pass

print("Files including trtllm_allreduce_fusion.cuh:")
for f in sorted(trtllm_fusion_includers):
    print(f"  {f}")

print("\nFiles including trtllm_mnnvl_allreduce.cuh:")
for f in sorted(trtllm_mnnvl_includers):
    print(f"  {f}")

print("\nFiles including BOTH:")
both = trtllm_fusion_includers & trtllm_mnnvl_includers
if both:
    for f in sorted(both):
        print(f"  {f}")
else:
    print("  None")
EOF

Repository: flashinfer-ai/flashinfer

Length of output: 154


🏁 Script executed:

rg 'QuantType' include/flashinfer/ -l

Repository: flashinfer-ai/flashinfer

Length of output: 229


🏁 Script executed:

rg -n 'enum class QuantType|QuantType::' include/flashinfer/pos_enc.cuh include/flashinfer/norm.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

rg -n 'include.*trtllm_allreduce_fusion' include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh include/flashinfer/comm/trtllm_allreduce_fusion.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

rg -n 'include.*trtllm' include/flashinfer/pos_enc.cuh include/flashinfer/norm.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 320


🏁 Script executed:

grep -n 'QuantType' include/flashinfer/pos_enc.cuh include/flashinfer/norm.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 4179


Consolidate duplicated QuantType enum to a single shared header.

The enum is defined identically in both trtllm_allreduce_fusion.cuh (line 729) and trtllm_mnnvl_allreduce.cuh (line 51). While currently these headers are included in separate translation units, duplicating the definition creates maintenance burden and risk of divergence if values or order change in the future. Move this to a shared header (e.g., a new trtllm_common.cuh) and reuse it across both files.

🤖 Prompt for AI Agents
In @include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh around lines 50 - 55,
QuantType enum is duplicated; create a single shared header (e.g.,
trtllm_common.cuh) that defines enum class QuantType : int { kNone=0, kFP8=1,
kFP4=2 }; remove the duplicate enum definitions from trtllm_allreduce_fusion.cuh
and trtllm_mnnvl_allreduce.cuh and instead include the new trtllm_common.cuh in
both files; ensure the new header has proper include guards/pragma once and
update any includes or forward declarations so all translation units use the
single QuantType definition.

Contributor Author


This should be left as a TODO for this framework and is out of the scope of this PR.

Contributor


@timlee0212, understood! I'll defer this refactoring to a future PR. The existing TODO comment in the code already documents this for follow-up work.

If you'd like, I can create a GitHub issue to track the consolidation of the QuantType enum definition into a shared header once this PR is merged. Just let me know!


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Comment on lines +1025 to +1215
#if CUDA_VERSION >= 12080
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
// NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), "f"(array[4]), "f"(array[5]),
"f"(array[6]), "f"(array[7]));
return val;
#else
uint32_t val;
__nv_fp4x2_storage_t vals[4];
#pragma unroll
for (int i = 0; i < 4; i++) {
vals[i] = __nv_cvt_float2_to_fp4x2(*(((float2*)array) + i), __NV_E2M1, cudaRoundNearest);
}
val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
return val;
#endif
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), "f"(array[2].x),
"f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
uint32_t val;
__nv_fp4x2_storage_t vals[4];
#pragma unroll
for (int i = 0; i < 4; i++) {
vals[i] = __nv_cvt_float2_to_fp4x2(array[i], __NV_E2M1, cudaRoundNearest);
}
val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
return val;
#endif
}

// Quantizes the provided PackedVec into the uint32_t output
template <typename T, uint32_t VEC_SIZE, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(vec_t<T, VEC_SIZE>& vec, float SFScaleVal,
uint8_t* SFout) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Pre-compute constant: reciprocal of 6.0 (maximum value of e2m1)
static constexpr float RECIPROCAL_6 = 1.0f / 6.0f;
// Get absolute maximum values among the local 8 values.
auto localMax = maths::cuda_abs(get_vec2_element(vec, 0));

#pragma unroll
for (int i = 1; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = maths::cuda_max(localMax, maths::cuda_abs(get_vec2_element(vec, i)));
}

// Get the absolute maximum among all 16 values (two threads).
localMax = maths::cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
// Get the final absolute maximum values.
float vecMax = float(maths::cuda_max(localMax.x, localMax.y));

// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// Optimization: compute quantized SF directly, avoid storing intermediate SFValue
uint8_t fp8SFVal;
float quantized_sf;

if constexpr (UE8M0_SF) {
#if (__CUDACC_VER_MAJOR__ * 1000 + __CUDACC_VER_MINOR__ * 10 >= 12080)
__nv_fp8_e8m0 tmp;
float sf_value = SFScaleVal * (vecMax * RECIPROCAL_6);
tmp.__x = __nv_cvt_float_to_e8m0(sf_value, __NV_SATFINITE, cudaRoundPosInf);
quantized_sf = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
#else
#error "FP8 E8M0 support requires CUDA 12.8 or newer."
#endif
} else {
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFScaleVal * (vecMax * RECIPROCAL_6));
fp8SFVal = tmp.__x;
quantized_sf = static_cast<float>(tmp);
}
// Get the output scale directly (optimization: avoid storing intermediate SFValue)
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal))
// Optimization: mathematically equivalent to SFScaleVal / quantized_sf, but more efficient
// (reduces 1 reciprocal call and 1 multiply operation)
float outputScale = quantized_sf != 0 ? SFScaleVal / quantized_sf : 0.0f;

if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}

// Convert the input to float and quantize (pipelined to reduce register usage).
// Optimization: use single float2 instead of array to reduce register pressure from 32 bytes to 8
// bytes
uint32_t e2m1Vec = 0;

#pragma unroll
for (int i = 0; i < details::CVT_FP4_ELTS_PER_THREAD / 2; i++) {
// Reuse single float2 register instead of array
float2 fp2Val;
if constexpr (std::is_same_v<T, half>) {
fp2Val = __half22float2(get_vec2_element(vec, i));
} else {
fp2Val = __bfloat1622float2(get_vec2_element(vec, i));
}
fp2Val.x *= outputScale;
fp2Val.y *= outputScale;

// Convert pair immediately and pack into result
uint8_t e2m1Pair = fp32_pair_to_e2m1(fp2Val);
e2m1Vec |= (static_cast<uint32_t>(e2m1Pair) << (i * 8));
}

// Write the e2m1 values to global memory.
return e2m1Vec;
#else
return 0;
#endif
}

#endif

// ============================== Quant Device Function ==============================
template <typename T, typename PackedType, int ELTS_PER_THREAD>
inline __device__ void quant_fp8(PackedVec<PackedType, T> packedAccum, void* quantOutPtr,
float invOutputScale, uint32_t threadOffset) {
static_assert(ELTS_PER_THREAD == 8 || ELTS_PER_THREAD == 4, "ELTS_PER_THREAD must be 8 or 4");
using QuantizedPackedType = std::conditional_t<ELTS_PER_THREAD == 8, float2, float>;

auto quantOut = reinterpret_cast<__nv_fp8_e4m3*>(quantOutPtr);
PackedVec<QuantizedPackedType, __nv_fp8_e4m3> quantizedAccum;
#pragma unroll
for (int i = 0; i < ELTS_PER_THREAD; i++) {
quantizedAccum.elements[i] =
__nv_fp8_e4m3(toFloat<T>(packedAccum.elements[i]) * invOutputScale);
}
reinterpret_cast<QuantizedPackedType*>(&quantOut[threadOffset])[0] = quantizedAccum.packed;
}

template <typename T, typename PackedType, int ELTS_PER_THREAD>
inline __device__ void quant_nvfp4(PackedVec<PackedType, T> packedAccum, void* quantOutPtr,
void* sfOutPtr, float* outputScale, uint32_t tokenIdx,
uint32_t tokenDim, uint32_t packedIdx,
QuantizationSFLayout sfLayout) {
static_assert(
ELTS_PER_THREAD == 8 && (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>),
"NVFP4 quantization fusion is only supported for FP16/BF16!");

// Cast the packed accumulator
auto packedAccum_ = *reinterpret_cast<vec_t<T, ELTS_PER_THREAD>*>(&packedAccum);
// SFType is only the pointer type; It does not affect the internal logic of offset calculation.
// Get the target pointer to the SF output.
auto sfOut =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t, details::CVT_FP4_SF_VEC_SIZE / ELTS_PER_THREAD>(
std::nullopt, tokenIdx, packedIdx, std::nullopt, tokenDim /* numCols, don't divide*/,
reinterpret_cast<uint32_t*>(sfOutPtr), sfLayout);

// Calculate the offset in packed item granularity for the quant output
uint32_t quantOutOffset = tokenIdx * tokenDim / ELTS_PER_THREAD + packedIdx;
// Each PackedVec holds 8 elements -> 1 float4 of input -> 1 uint32_t of output
reinterpret_cast<uint32_t*>(quantOutPtr)[quantOutOffset] =
cvt_warp_fp16_to_fp4<T, ELTS_PER_THREAD, false>(packedAccum_, *outputScale, sfOut);
}
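As a mental model for what `cvt_warp_fp16_to_fp4` produces, each element is rounded to the nearest FP4 E2M1 value (magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6) and packed into 4 bits, eight values per `uint32_t`. A minimal host-side sketch (hypothetical helpers; plain round-to-nearest without the hardware's tie-breaking, and the per-16-element scale-factor handling is omitted):

```python
# E2M1 magnitude table indexed by the low 3 bits of the nibble; bit 3 is the sign.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def encode_e2m1(x):
    """Round x to the nearest representable E2M1 value; return the 4-bit code."""
    sign = 0b1000 if x < 0 else 0
    mag = abs(x)
    code = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
    return sign | code

def decode_e2m1(code):
    mag = E2M1_VALUES[code & 0b111]
    return -mag if code & 0b1000 else mag

def pack8(values):
    """Pack 8 values into one 32-bit word, mirroring the uint32_t-per-8-elements output."""
    word = 0
    for i, v in enumerate(values):
        word |= encode_e2m1(v) << (4 * i)
    return word

print(hex(pack8([1.0] * 8)))  # 0x22222222 (code 2 == 1.0 in each nibble)
```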
Contributor

@coderabbitai coderabbitai bot Jan 9, 2026
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's find the exact location and guard status of quant_nvfp4 and check for numKTiles
rg -A 20 "template.*quant_nvfp4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for numKTiles and LINEAR layout code
rg -B 5 -A 5 "numKTiles.*CVT_FP4_SF_VEC_SIZE" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 566


🏁 Script executed:

# Check if quant_nvfp4 is inside or outside #if CUDA_VERSION guards
rg -B 30 "inline __device__ void quant_nvfp4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1295


🏁 Script executed:

# Get the full quant_nvfp4 function definition to see if it calls cvt_warp_fp16_to_fp4
rg -A 30 "inline __device__ void quant_nvfp4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 1653


🏁 Script executed:

# Check what guards surround quant_nvfp4 - look for broader context
rg -B 50 "inline __device__ void quant_nvfp4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | tail -60

Repository: flashinfer-ai/flashinfer

Length of output: 2014


🏁 Script executed:

# Check the broader context around quant_nvfp4 to see if there are any guards
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1200,1250p'

Repository: flashinfer-ai/flashinfer

Length of output: 2878


🏁 Script executed:

# Check CVT_FP4_SF_VEC_SIZE definition and context to understand divisibility guarantees
rg "CVT_FP4_SF_VEC_SIZE" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 329


🏁 Script executed:

# Look for ceil_div or similar utility functions in the file
rg "ceil_div|ceildiv" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 716


🏁 Script executed:

# Find the opening of the file section where quant_nvfp4 is defined to check for guards
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1150,1200p'

Repository: flashinfer-ai/flashinfer

Length of output: 2547


🏁 Script executed:

# Search for where quant_nvfp4 is actually called/instantiated to understand if the guard issue manifests
rg "quant_nvfp4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 342


🏁 Script executed:

# Check the context around the numKTiles calculation to see if there are comments about divisibility guarantees
rg -B 10 -A 10 "numKTiles = numCols" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 888


🏁 Script executed:

# Check if tokenDim is guaranteed to be divisible by CVT_FP4_SF_VEC_SIZE
rg -B 5 -A 5 "tokenDim" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 2029


🏁 Script executed:

# Search for any comments or assertions about divisibility requirements
rg "divisib|multiple of|guarantee" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 501


🏁 Script executed:

# Check if quant_nvfp4 is called anywhere and what context it's called in
rg -B 5 "quant::quant_nvfp4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 757


🏁 Script executed:

# Check the context around that FIXME comment to understand the padding guarantees
rg -B 10 -A 10 "numCols.*multiple of 64" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 852


🏁 Script executed:

# Check where numCols is padded or adjusted before being used in LINEAR layout
rg "numCols.*pad|round.*numCols" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 181


🏁 Script executed:

# Find the exact line numbers for the #if CUDA_VERSION >= 12080 guard and #endif
grep -n "#if CUDA_VERSION >= 12080\|#endif" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 331


🏁 Script executed:

# Search for any divisibility requirements or assertions for LINEAR layout or tokenDim
rg -B 5 -A 5 "LINEAR.*divisib|tokenDim.*divisib|CVT_FP4_SF_VEC_SIZE.*divisib" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if there are initialization or validation functions that ensure divisibility
rg "tokenDim|numCols" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | grep -i "check\|valid\|assert\|require"

Repository: flashinfer-ai/flashinfer

Length of output: 120


🏁 Script executed:

# Check what numEltsPerThread is and if there's any relationship to CVT_FP4_SF_VEC_SIZE
rg "numEltsPerThread|CVT_FP4_SF_VEC_SIZE" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 828


🏁 Script executed:

# Search for LINEAR layout usage and if there are any preconditions documented
rg -B 10 "QuantizationSFLayout::LINEAR" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -40

Repository: flashinfer-ai/flashinfer

Length of output: 632


🏁 Script executed:

# Check the complete context around where numKTiles is calculated in LINEAR
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1062,1074p'

Repository: flashinfer-ai/flashinfer

Length of output: 588


🏁 Script executed:

# Check if LINEAR layout is actually instantiated or used anywhere
rg "QuantizationSFLayout::LINEAR" include/flashinfer/ -r

Repository: flashinfer-ai/flashinfer

Length of output: 481


🏁 Script executed:

# Check numEltsPerThread relationship to CVT_FP4_SF_VEC_SIZE (16)
rg -A 5 "numEltsPerThread.*sizeof" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 432


🏁 Script executed:

# Check if there's padding or rounding applied to tokenDim before it's passed to cvt_quant_to_fp4_get_sf_out_offset
rg -B 5 "dimPadded\|round_up.*tokenDim" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Check if LINEAR layout enum is defined and used elsewhere
rg "QuantizationSFLayout" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | grep -i "linear\|enum"

Repository: flashinfer-ai/flashinfer

Length of output: 124


🏁 Script executed:

# Check the line numbers more carefully: where exactly does #if CUDA_VERSION >= 12080 end?
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1025,1030p'

Repository: flashinfer-ai/flashinfer

Length of output: 411


🏁 Script executed:

# Look at the closing of the CUDA version guard that contains cvt_warp_fp16_to_fp4
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1080,1095p'

Repository: flashinfer-ai/flashinfer

Length of output: 759


🏁 Script executed:

# Verify the line numbers of the CUDA_VERSION guard closure
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1170,1195p'

Repository: flashinfer-ai/flashinfer

Length of output: 1494


🏁 Script executed:

# Check the full call stack to see if quant_nvfp4 instantiation is guarded
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1285,1310p'

Repository: flashinfer-ai/flashinfer

Length of output: 1045


🏁 Script executed:

# Find the actual call to quant_nvfp4 with surrounding guards
rg -B 5 -A 5 "quant::quant_nvfp4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1157


🏁 Script executed:

# Check the full context around line 1293 where quant_nvfp4 is called
cat -n include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | sed -n '1320,1350p'

Repository: flashinfer-ai/flashinfer

Length of output: 1693


Guard quant_nvfp4 template definition with #if CUDA_VERSION >= 12080 to prevent compile failures on CUDA < 12.8.

quant_nvfp4 (line 1193) calls cvt_warp_fp16_to_fp4, which is only defined inside #if CUDA_VERSION >= 12080 (lines 1025–1173). Although instantiation sites are guarded, the unguarded template definition will attempt to resolve cvt_warp_fp16_to_fp4 during parsing on CUDA < 12.8, causing a compile-time error.

Additionally, in LINEAR SF layout, numKTiles = numCols / details::CVT_FP4_SF_VEC_SIZE truncates instead of rounding up; the SWIZZLED_128x4 layout correctly uses ceil_div for the same calculation. Without a guarantee that numCols is divisible by 16, use ceil_div here as well.
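The truncation hazard is easy to see with a hidden size that is not a multiple of 16. A quick Python check of the two divisions (numbers chosen purely for illustration):

```python
# Truncating vs. ceiling division for the per-row scale-factor tile count.
CVT_FP4_SF_VEC_SIZE = 16  # one FP4 scale factor covers 16 elements

def ceil_div(a, b):
    return (a + b - 1) // b

num_cols = 1000  # not a multiple of 16
trunc_tiles = num_cols // CVT_FP4_SF_VEC_SIZE        # 62 -> last 8 columns lose their SF slot
ceil_tiles = ceil_div(num_cols, CVT_FP4_SF_VEC_SIZE)  # 63 -> partial tile still addressed
print(trunc_tiles, ceil_tiles)  # 62 63
```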

Proposed fixes
-template <typename T, typename PackedType, int ELTS_PER_THREAD>
-inline __device__ void quant_nvfp4(PackedVec<PackedType, T> packedAccum, void* quantOutPtr,
+#if CUDA_VERSION >= 12080
+template <typename T, typename PackedType, int ELTS_PER_THREAD>
+inline __device__ void quant_nvfp4(PackedVec<PackedType, T> packedAccum, void* quantOutPtr,
   void* sfOutPtr, float* outputScale, uint32_t tokenIdx,
   uint32_t tokenDim, uint32_t packedIdx,
   QuantizationSFLayout sfLayout) {
   // ...
 }
+#endif
     } else if (layout == QuantizationSFLayout::LINEAR) {
       // Linear row-major layout, no padding required.
       int32_t KTileIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
-      int32_t numKTiles = numCols / details::CVT_FP4_SF_VEC_SIZE;
+      int32_t numKTiles = (numCols + details::CVT_FP4_SF_VEC_SIZE - 1) / details::CVT_FP4_SF_VEC_SIZE;
       int64_t mTileStride = numKTiles;

Contributor Author
Done.

The SF layout function is from trtllm_allreduce_fusion. Need to check whether we need ceil_div here.

Contributor
🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if trtllm_allreduce_fusion.cuh exists and how it handles LINEAR layout numKTiles calculation
fd "trtllm_allreduce_fusion.cuh" --exec cat {} | grep -A 10 -B 5 "LINEAR.*numKTiles\|numKTiles.*LINEAR"

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for the get_sf_out_offset function in trtllm_allreduce_fusion.cuh to compare implementations
fd "trtllm_allreduce_fusion.cuh" --exec rg -A 15 "get_sf_out_offset.*LINEAR|LINEAR.*layout.*numKTiles" {}

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check if there's any divisibility validation or comments about numCols requirements
fd "trtllm_allreduce_fusion.cuh" --exec rg -B 5 -A 5 "CVT_FP4_SF_VEC_SIZE.*divisib|numCols.*divisib|LINEAR.*require" {}

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look for the complete LINEAR layout handling in trtllm_allreduce_fusion.cuh
fd "trtllm_allreduce_fusion.cuh" --exec rg -B 10 -A 10 "QuantizationSFLayout::LINEAR" {}

Length of output: 1006


🏁 Script executed:

#!/bin/bash
# Search for any validation or assertions about tokenDim/numCols divisibility in the codebase
rg "tokenDim.*16|tokenDim.*CVT_FP4|numCols.*16|numCols.*divisib" include/flashinfer/comm/ --type cpp

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check if there are FLASHINFER_CHECK calls that validate dimensions before quantization
rg "FLASHINFER_CHECK.*tokenDim|FLASHINFER_CHECK.*numCols|FLASHINFER_CHECK.*dim.*16" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh

Length of output: 52


🏁 Script executed:

#!/bin/bash
# Look for how LINEAR layout is validated or documented in Python bindings
rg "LINEAR.*layout|QuantizationSFLayout.*LINEAR" python/flashinfer/comm/ --type py -A 5 -B 5

Length of output: 63


🏁 Script executed:

#!/bin/bash
# Check the dispatch code to see if there are dimension checks before calling quant_nvfp4
rg "quant_nvfp4|QType.*kFP4" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh -B 10 | grep -A 10 "CHECK\|assert\|divisib"

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Find Python comm files in the correct location
fd "\.py$" flashinfer/comm --exec basename {}

Length of output: 217


🏁 Script executed:

#!/bin/bash
# Check the test files to see what dimensions are used with LINEAR layout
rg "LINEAR|QuantizationSFLayout" tests/comm/ -A 5 -B 5

Length of output: 6176


🏁 Script executed:

#!/bin/bash
# Look at the full context of the tokenDim divisibility check
rg "FLASHINFER_CHECK.*tokenDim" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh -B 3 -A 3

Length of output: 359


🏁 Script executed:

#!/bin/bash
# Check if CVT_FP4_SF_VEC_SIZE (16) has any relationship to the validated 8-element constraint
rg "CVT_FP4_SF_VEC_SIZE|numEltsPerThread" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh | head -20

Length of output: 760


@timlee0212 Thanks for confirming the guard fix is done!

Regarding ceil_div for the LINEAR layout: I checked the codebase and found:

Current situation:

  • Python layer enforces hidden_dim % 16 == 0 (SF_VEC_SIZE) in tests
  • CUDA layer only validates tokenDim % 8 == 0 (numEltsPerThread for FP16/BF16)
  • trtllm_allreduce_fusion.cuh also uses truncating division for LINEAR: numKTiles = numCols / CVT_FP4_SF_VEC_SIZE
  • SWIZZLED_128x4 correctly uses: numKTiles = (numCols + factor - 1) / factor (ceil_div)

Recommendation:
Yes, use ceil_div for LINEAR layout for consistency and safety:

int32_t numKTiles = (numCols + details::CVT_FP4_SF_VEC_SIZE - 1) / details::CVT_FP4_SF_VEC_SIZE;

Reasons:

  1. Consistency with SWIZZLED_128x4 which handles the same calculation with ceiling division
  2. Defensive coding - doesn't rely on Python-layer validation; protects against direct CUDA API calls
  3. Should fix both files - trtllm_allreduce_fusion.cuh has the same issue and likely needs the same fix

While current usage happens to satisfy the divisibility requirement (Python tests enforce it), using ceil_div makes the code more robust and aligns with the SWIZZLED layout's approach.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @csrc/trtllm_mnnvl_allreduce.cu:
- Around line 34-36: Fix the typo in the parameter comment for
Optional<TensorView> output in trtllm_mnnvl_allreduce.cu: change "thr normed" to
"the normed" so the comment reads "... This tensor can be empty if quant fusion
is enabled and the normed result is not needed." Reference the
Optional<TensorView> output parameter to locate the comment.
- Around line 162-164: The code currently casts output_scale to float* with
reinterpret_cast<float*> when setting params.outputScale, but there is no dtype
validation; add a self-defensive check to ensure output_scale (the TVM/FFI
tensor) is float32 before casting — e.g., use
TVM_FFI_ICHECK(output_scale.has_value()) and a dtype check on
output_scale.value().dtype() (or equivalent) to assert it's kDLFloat &&
bits==32, and only then set params.outputScale =
reinterpret_cast<float*>(output_scale.value().data_ptr()); alternatively
document this float32 requirement on the function comment referencing
params.outputScale and output_scale if you prefer documentation over a runtime
check.
🧹 Nitpick comments (1)
csrc/trtllm_mnnvl_allreduce.cu (1)

112-134: Consider adding explicit handling for QuantType::kNone.

The switch statement handles kFP8 and kFP4 but lacks an explicit case QuantType::kNone: or default: clause. While the current behavior (no validation for kNone) is correct, adding an explicit case would make the intent clearer and avoid compiler warnings about unhandled enum values.

Suggested fix
       switch (quant_type_enum) {
         case QuantType::kFP8:
           TVM_FFI_ICHECK(quant_out.has_value() && quant_out.value().size(0) == num_tokens &&
                          quant_out.value().size(1) == token_dim)
               << "quant_out shape mismatch: expected (" << num_tokens << ", " << token_dim
               << ") but got (" << quant_out.value().size(0) << ", " << quant_out.value().size(1)
               << ")";
           break;
         case QuantType::kFP4:
           // FP4 packs 2 elements per byte, assuming input tensor is of uint8 or FP4X2 dtype, so
           // quant_out has half the token_dim
           TVM_FFI_ICHECK(quant_out.has_value() && quant_out.value().size(0) == num_tokens &&
                          quant_out.value().size(1) == token_dim / 2)
               << "quant_out shape mismatch for FP4: expected (" << num_tokens << ", "
               << token_dim / 2 << ") but got (" << quant_out.value().size(0) << ", "
               << quant_out.value().size(1) << ")";
           // We only check the sf out size to be large enough.
           TVM_FFI_ICHECK(sf_out.has_value() &&
                          sf_out.value().numel() >= (num_tokens * token_dim / 16))
               << "sf_out size mismatch for FP4: expected at least " << num_tokens * token_dim / 16
               << " elements but got " << sf_out.value().numel();
           break;
+        case QuantType::kNone:
+          // No quantization validation needed
+          break;
       }
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5b1d53b and 6cec217.

📒 Files selected for processing (1)
  • csrc/trtllm_mnnvl_allreduce.cu
🧰 Additional context used
📓 Path-based instructions (1)
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/trtllm_mnnvl_allreduce.cu
🧬 Code graph analysis (1)
csrc/trtllm_mnnvl_allreduce.cu (1)
flashinfer/comm/_types.py (1)
  • QuantizationSFLayout (75-88)
🔇 Additional comments (2)
csrc/trtllm_mnnvl_allreduce.cu (2)

70-96: Quantization validation logic looks correct.

The conversion from Optional<int64_t> to QuantType via static_cast relies on the Python layer passing valid enum values. The validation that quant fusion requires RMSNorm fusion and 16-bit dtype is appropriately enforced.


178-189: LGTM!

The dispatch logic correctly routes to oneshot/twoshot variants based on the flag, with proper CUDA error handling and TVM-FFI export.

@timlee0212 timlee0212 force-pushed the mnnvlar_quant_fusion branch from 6cec217 to 1afe388 Compare January 12, 2026 11:14
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/comm/test_trtllm_mnnvl_allreduce.py (1)

577-604: Add GPU architecture guards for quantization patterns.

Per coding guidelines, tests should use flashinfer.utils functions to skip on unsupported architectures. FP8 quantization typically requires SM90+ (Hopper), and NVFP4 may require SM100+ (Blackwell). Without these guards, tests will fail on older GPUs rather than being skipped gracefully.

💡 Suggested implementation
from flashinfer.utils import get_compute_capability, is_sm90a_supported

# Add skip markers or inline checks:
@pytest.mark.parametrize(...)
def test_mnnvl_allreduce_refactored(...):
    # Skip FP8/FP4 patterns on unsupported architectures
    traits = get_pattern_traits(pattern)
    if traits.has_quant:
        compute_cap = get_compute_capability()
        if compute_cap < (9, 0):
            pytest.skip("FP8/FP4 quantization requires SM90+ (Hopper or newer)")
        if traits.quant_type == QuantFusionType.NVFP4 and compute_cap < (10, 0):
            pytest.skip("NVFP4 quantization requires SM100+ (Blackwell)")
    ...
🤖 Fix all issues with AI agents
In @flashinfer/comm/trtllm_mnnvl_ar.py:
- Around line 622-627: When handling QuantFusionType.NVFP4 in the branch where
quant_out is None, the created FP4 buffer is missing the device spec and will
default to CPU; update the torch.empty call that constructs quant_out (the one
using input.shape and dtype=torch.uint8) to allocate on the same device as input
by passing device=input.device (or input.device.type if needed) so quant_out
lives on the correct CUDA device before it's used by the CUDA kernel.

In @include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh:
- Around line 1166-1171: The function cvt_warp_fp16_to_fp4 currently returns 0
on unsupported architectures which can silently corrupt quantized data; change
this to fail fast by adding a compile-time static_assert (or an architecture
check using #error) to prevent building for unsupported targets, or replace the
fallback with a runtime error/abort (e.g., assert or device-side trap) so the
host cannot use FP4 quantization on unsupported GPUs; update any dispatch logic
that calls cvt_warp_fp16_to_fp4 to guard usage (so callers never select FP4
quantization when the architecture check fails) and reference
cvt_warp_fp16_to_fp4 and the e2m1Vec return path when making the change.
🧹 Nitpick comments (10)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (5)

50-55: Track the duplicated QuantType enum definition.

The TODO comment indicates this enum is duplicated elsewhere. Consider consolidating into a shared header to avoid maintenance issues and potential divergence.

Would you like me to help identify the duplicate location and create a shared header for this enum?


819-823: Consider using static_assert instead of runtime assert(false).

The base template uses assert(false) which only triggers at runtime. A static_assert would catch unsupported types at compile time, similar to other templates in this codebase (e.g., negZero at line 154).

♻️ Suggested fix
 template <typename T>
 __device__ inline T cuda_abs(T val) {
-  assert(false);
+  static_assert(sizeof(T) == 0, "cuda_abs not specialized for this type");
   return {};
 }

1435-1465: Dispatch macros are well-structured but DISPATCH_ALLREDUCE_KERNEL is not undefined.

The LAUNCH_ALLREDUCE_KERNEL macro is correctly undefined at line 1500, but DISPATCH_ALLREDUCE_KERNEL is not. This could cause naming conflicts if this header is included in files that define a macro with the same name.

♻️ Suggested fix
 #undef LAUNCH_ALLREDUCE_KERNEL
+#undef DISPATCH_ALLREDUCE_KERNEL
   return cudaSuccess;
 }

1798-1798: Variable gamma shadows the kernel parameter of the same name.

The local variable gamma at line 1798 shadows the kernel input parameter gamma (line 1652). While functionally correct since the parameter is loaded into smemGamma, this can cause confusion during maintenance.

♻️ Suggested rename
-      PackedVec<float4, T> gamma = {.packed = loadPacked<float4>(&smemGamma[threadLoadOffset])};
+      PackedVec<float4, T> gammaVec = {.packed = loadPacked<float4>(&smemGamma[threadLoadOffset])};
 
 #pragma unroll
       for (uint32_t j = 0; j < kELTS_PER_LOAD; j++) {
         rOut.elements[j] =
-            fromFloat<T>(toFloat<T>(gamma.elements[j]) * rInput[i * kELTS_PER_LOAD + j] * rcpRms);
+            fromFloat<T>(toFloat<T>(gammaVec.elements[j]) * rInput[i * kELTS_PER_LOAD + j] * rcpRms);
       }

549-556: Track code duplication with related allreduce files.

The TODO indicates these quantization utilities are shared with trtllm_allreduce_fusion.cuh and moe_allreduce_fusion. Consider extracting the common quant namespace to a shared header to reduce maintenance burden and ensure consistency.

Would you like me to help identify the common code and propose a shared header structure?

csrc/trtllm_mnnvl_allreduce.cu (1)

112-134: Missing default case in switch statement could cause issues with QuantType::kNone.

The switch statement handles kFP8 and kFP4 but doesn't have a default or kNone case. While kNone doesn't require validation, adding a default case improves defensive coding and prevents unhandled future enum values.

♻️ Suggested improvement
       switch (quant_type_enum) {
         case QuantType::kFP8:
           TVM_FFI_ICHECK(quant_out.has_value() && quant_out.value().size(0) == num_tokens &&
                          quant_out.value().size(1) == token_dim)
               << "quant_out shape mismatch: expected (" << num_tokens << ", " << token_dim
               << ") but got (" << quant_out.value().size(0) << ", " << quant_out.value().size(1)
               << ")";
           break;
         case QuantType::kFP4:
           // FP4 packs 2 elements per byte, assuming input tensor is of uint8 or FP4X2 dtype, so
           // quant_out has half the token_dim
           TVM_FFI_ICHECK(quant_out.has_value() && quant_out.value().size(0) == num_tokens &&
                          quant_out.value().size(1) == token_dim / 2)
               << "quant_out shape mismatch for FP4: expected (" << num_tokens << ", "
               << token_dim / 2 << ") but got (" << quant_out.value().size(0) << ", "
               << quant_out.value().size(1) << ")";
           // We only check the sf out size to be large enough.
           TVM_FFI_ICHECK(sf_out.has_value() &&
                          sf_out.value().numel() >= (num_tokens * token_dim / 16))
               << "sf_out size mismatch for FP4: expected at least " << num_tokens * token_dim / 16
               << " elements but got " << sf_out.value().numel();
           break;
+        case QuantType::kNone:
+        default:
+          // No quantization validation needed
+          break;
       }
flashinfer/comm/allreduce.py (1)

697-709: Unused variable residual_result should be prefixed with underscore.

The residual_result variable is unpacked but never used. Use underscore prefix to indicate it's intentionally unused.

♻️ Proposed fix
-                norm_result, residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
+                norm_result, _residual_result = trtllm_mnnvl_fused_allreduce_add_rmsnorm(
flashinfer/comm/trtllm_mnnvl_ar.py (1)

634-655: Missing explicit handling for SWIZZLED_8x4 layout.

The code handles SWIZZLED_128x4 and LINEAR layouts but SWIZZLED_8x4 (value 1 per _types.py) falls through to the generic error. While allreduce.py blocks this layout at a higher level (line 647-650), adding an explicit check here provides better defense-in-depth.

♻️ Suggested improvement
             if layout_code == QuantizationSFLayout.SWIZZLED_128x4:
                 sf_out = torch.empty(
                     (
                         _compute_swizzled_layout_sf_size(
                             input.shape[0], input.shape[1] // 16
                         )
                     ),
                     device=input.device,
                     dtype=torch.float8_e4m3fn,
                 )
             elif layout_code == QuantizationSFLayout.LINEAR:
                 # linear layout
                 sf_out = torch.empty(
                     (input.shape[0], input.shape[1] // 16),
                     device=input.device,
                     dtype=torch.float8_e4m3fn,
                 )
+            elif layout_code == QuantizationSFLayout.SWIZZLED_8x4:
+                raise ValueError(
+                    "SWIZZLED_8x4 layout is not supported for MNNVL AllReduce quantization. "
+                    "Use SWIZZLED_128x4 or LINEAR layout instead."
+                )
             else:
                 raise ValueError(
                     f"Unsupported scaling factor layout code: {layout_code}."
                 )
tests/comm/test_trtllm_mnnvl_allreduce.py (2)

403-411: Unused monkeypatch parameter.

The monkeypatch parameter is declared in the function signature (line 404) and docstring (line 415) but is never used within the function body. Consider removing it unless it's intended for future use.

♻️ Suggested fix
 def run_mnnvl_ar_full(
-    monkeypatch,
     seq_lens: list[int],
     pattern: AllReduceFusionPattern,
     dtype: torch.dtype,
     hidden_size: int,
     legacy_explicit_workspace_bytes: Optional[int] = None,
     legacy_api: bool = False,
 ):
     """Core test logic for MNNVL AllReduce operations.

     Args:
-        monkeypatch: pytest monkeypatch fixture
         seq_lens: List of sequence lengths to test

And update the callers at lines 602-604, 623-625, and 629.


628-635: Direct execution block uses hardcoded pattern.

The __main__ block only tests kAllReduce pattern. For debugging quantized patterns during development, consider parameterizing this or adding commented examples for other patterns.

Also, monkeypatch is passed as None which works because the parameter is unused (as noted earlier), but this confirms the parameter should be removed.

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6cec217 and 1afe388.

📒 Files selected for processing (8)
  • csrc/trtllm_mnnvl_allreduce.cu
  • flashinfer/comm/__init__.py
  • flashinfer/comm/_types.py
  • flashinfer/comm/allreduce.py
  • flashinfer/comm/trtllm_ar.py
  • flashinfer/comm/trtllm_mnnvl_ar.py
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
  • tests/comm/test_trtllm_mnnvl_allreduce.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/comm/_types.py
🧰 Additional context used
📓 Path-based instructions (4)
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/trtllm_mnnvl_allreduce.cu
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/comm/__init__.py
  • flashinfer/comm/trtllm_ar.py
  • flashinfer/comm/trtllm_mnnvl_ar.py
  • flashinfer/comm/allreduce.py
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/comm/test_trtllm_mnnvl_allreduce.py
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
🧠 Learnings (6)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • csrc/trtllm_mnnvl_allreduce.cu
  • flashinfer/comm/__init__.py
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • flashinfer/comm/__init__.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
🧬 Code graph analysis (5)
csrc/trtllm_mnnvl_allreduce.cu (1)
flashinfer/comm/_types.py (1)
  • QuantizationSFLayout (75-88)
flashinfer/comm/__init__.py (1)
flashinfer/comm/_types.py (4)
  • AllReduceFusionPattern (37-55)
  • QuantizationSFLayout (75-88)
  • get_pattern_traits (171-187)
  • QuantFusionType (63-72)
flashinfer/comm/trtllm_ar.py (1)
flashinfer/comm/_types.py (2)
  • AllReduceFusionPattern (37-55)
  • QuantizationSFLayout (75-88)
flashinfer/comm/trtllm_mnnvl_ar.py (4)
flashinfer/fp4_quantization.py (1)
  • _compute_swizzled_layout_sf_size (47-50)
flashinfer/fp8_quantization.py (1)
  • _compute_swizzled_layout_sf_size (16-19)
flashinfer/comm/workspace_base.py (1)
  • is_buffer_size_sufficient (53-61)
flashinfer/comm/_types.py (2)
  • QuantFusionType (63-72)
  • QuantizationSFLayout (75-88)
flashinfer/comm/allreduce.py (3)
flashinfer/comm/_types.py (4)
  • AllReduceFusionPattern (37-55)
  • QuantizationSFLayout (75-88)
  • get_pattern_traits (171-187)
  • has_quant (113-115)
flashinfer/comm/workspace_base.py (1)
  • AllReduceFusionWorkspace (23-89)
flashinfer/comm/trtllm_mnnvl_ar.py (5)
  • MNNVLAllReduceFusionWorkspace (59-248)
  • MNNVLAllreduceFusionStrategy (39-52)
  • trtllm_mnnvl_allreduce (360-431)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm (434-526)
  • trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant (529-692)
🪛 Ruff (0.14.10)
flashinfer/comm/trtllm_mnnvl_ar.py

587-589: Avoid specifying long messages outside the exception class

(TRY003)


591-593: Avoid specifying long messages outside the exception class

(TRY003)


595-597: Avoid specifying long messages outside the exception class

(TRY003)


602-604: Avoid specifying long messages outside the exception class

(TRY003)


619-621: Avoid specifying long messages outside the exception class

(TRY003)


630-632: Avoid specifying long messages outside the exception class

(TRY003)


653-655: Avoid specifying long messages outside the exception class

(TRY003)


666-668: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/comm/allreduce.py

699-699: Unpacked variable residual_result is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (21)
include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (8)

17-38: LGTM!

Headers are appropriately selected for framework-agnostic CUDA kernels - no Torch headers are included as required by the coding guidelines. The conditional inclusion of cuda_fp4.h for CUDA 12.8+ is correctly guarded.


57-100: LGTM!

The AllReduceFusionParams struct is well-documented and logically organized. The parameter grouping and default values are appropriate for the fusion operations.


498-542: LGTM!

The adjustGridConfig function correctly uses the new configuration constants from the config namespace. The grid configuration tuning logic for handling different SM versions and cluster sizes is sound.


1224-1238: LGTM!

The kernel template correctly enforces that quantization-only patterns require RMSNorm fusion via the static assertion. The template structure allows for efficient compile-time dispatch of the different fusion patterns.


1371-1386: LGTM!

The quantization paths are correctly conditional:

  • Output write is guarded by null check
  • FP8 quantization uses inverse scale correctly
  • FP4 quantization is guarded by CUDA_VERSION >= 12080

The control flow correctly handles all fusion patterns.
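That per-element epilogue can be sketched in pure Python (an illustrative stand-in for the kernel code, not the kernel itself; `FP8_E4M3_MAX` and the function name are assumptions):

```python
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3 magnitude

def fused_epilogue(norm_val, scale_inv, quant_type, out=None):
    """Mirrors the epilogue control flow described above."""
    if out is not None:      # output write is guarded by a null check
        out.append(norm_val)
    if quant_type == "fp8":  # FP8 multiplies by the inverse scale, then clamps
        return max(-FP8_E4M3_MAX, min(norm_val * scale_inv, FP8_E4M3_MAX))
    return norm_val          # FP4 path omitted here: it additionally needs CUDA >= 12.8
```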


1510-1512: LGTM!

The two-shot allreduce kernel correctly uses the configuration constants and implements the scatter-broadcast pattern with proper Lamport synchronization. The separation of quantization to the fusion kernel is a clean design.


1805-1819: LGTM!

The quantization paths in rmsNormLamport_fusion correctly mirror the pattern from the oneshot kernel, with proper null checks for output and appropriate guards for FP4 on CUDA 12.8+.


1934-1966: LGTM!

The dispatch macros for the two-shot RMSNorm fusion are well-structured with proper #undef cleanup. The quantization type dispatch correctly handles FP8, FP4, and None cases with appropriate type constraints.

csrc/trtllm_mnnvl_allreduce.cu (2)

162-164: Verify output_scale tensor dtype before casting to float*.

The code uses reinterpret_cast<float*> without verifying that output_scale is actually a float32 tensor. If a different dtype is passed, this could cause undefined behavior.

Consider adding a dtype check:

if (output_scale.has_value()) {
  TVM_FFI_ICHECK(output_scale.value().dtype() == float32_code)
      << "output_scale must be float32";
}

31-61: Well-structured function signature with clear parameter groupings.

The expanded signature follows good practices with logical groupings (Primary I/O, Communication, Config, RMSNorm fusion, Quantization) and helpful inline documentation. This aligns with the coding guideline that framework bindings should be implemented in csrc/ via TVM-FFI.

flashinfer/comm/trtllm_ar.py (1)

65-67: Clean consolidation of type imports from canonical location.

Importing from ._types and re-exporting maintains backward compatibility while centralizing type definitions. This is good API hygiene.

flashinfer/comm/__init__.py (1)

5-9: Good public API consolidation with canonical type re-exports.

The explicit X as X import pattern and comment clarifying the canonical source (_types.py) provide clear documentation. This follows the guideline to export operations in __init__.py for public API. Based on learnings.

flashinfer/comm/allreduce.py (1)

645-665: Good trait-based routing for fusion pattern dispatch.

Using get_pattern_traits(pattern) to derive capabilities (has_rmsnorm, has_norm_out, has_quant) is cleaner than hard-coded pattern checks. This makes the code more maintainable when new patterns are added.

flashinfer/comm/trtllm_mnnvl_ar.py (2)

417-429: Good use of named arguments for the module call.

Using named arguments (e.g., input=input, output=output) in the module call makes the code more readable and less error-prone when the function signature is complex.


529-581: Comprehensive documentation for the new quantization function.

The docstring thoroughly documents:

  • Behavior differences from the non-quant version
  • Supported quantization types and their constraints
  • Scale factor layout options with details
  • All parameters and return values

This is excellent for a complex fusion API.

tests/comm/test_trtllm_mnnvl_allreduce.py (6)

29-40: LGTM!

The FP8 quantization and dequantization helper functions are correctly implemented with proper clamping to FP8 representable range.
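The round trip those helpers implement can be sketched in pure Python (a stand-in for the torch-based test helpers; the 448 bound is the float8_e4m3fn finite maximum):

```python
FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

def fp8_quant(values, scale):
    # divide by the scale and clamp to the representable range
    # (the real helper then casts to torch.float8_e4m3fn)
    return [max(-FP8_E4M3_MAX, min(v / scale, FP8_E4M3_MAX)) for v in values]

def fp8_dequant(values, scale):
    # widen back and undo the scale
    return [v * scale for v in values]
```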


98-127: Verify output tuple ordering consistency.

The quantized path returns (quant_out, residual_out, sf_out, output) at lines 122-127, while the comment on line 121 says "We alter the order here to be compatible with the non-quant case." However:

  • kARResidualRMSNorm returns (output, residual_out) - 2 elements
  • kAllReduce returns (output,) - 1 element
  • Quantized patterns return (quant_out, residual_out, sf_out, output) - 4 elements

The checking logic at lines 152-219 accesses output[0], output[1], output[2], output[3] for quantized paths, which matches this 4-element tuple. The logic appears correct, but the comment about "compatibility" is somewhat misleading since the tuple structures are quite different.
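The per-pattern tuple shapes described above, written out as a small table (names mirror the test, not a library API):

```python
# return-tuple layout per fusion pattern, as exercised by the test
PATTERN_OUTPUTS = {
    "kAllReduce": ("output",),
    "kARResidualRMSNorm": ("output", "residual_out"),
    "kARResidualRMSNormFP8Quant": ("quant_out", "residual_out", "sf_out", "output"),
    "kARResidualRMSNormFP4Quant": ("quant_out", "residual_out", "sf_out", "output"),
}
```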


231-232: LGTM - Legacy function preserved for backward compatibility.

The TODO comment appropriately marks this for removal when the deprecated API is cleaned up.


387-399: LGTM!

The condition correctly identifies that all patterns except kAllReduce require the fused reference output (allreduce + residual + rmsnorm). This correctly handles the new quantized patterns which also need the RMSNorm output as the reference for comparison.
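A minimal sketch of that fused reference in pure Python (assuming a list-of-lists shard layout; names are illustrative, not the test's actual helpers):

```python
import math

def rmsnorm(x, gamma, eps=1e-6):
    # RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by gamma
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, gamma)]

def fused_reference(shards, residual, gamma, eps=1e-6):
    # allreduce (sum over ranks) -> residual add -> rmsnorm
    reduced = [sum(col) for col in zip(*shards)]
    added = [r + v for r, v in zip(residual, reduced)]
    return added, rmsnorm(added, gamma, eps)
```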


176-198: CPU-GPU transfers are necessary by design and can remain as-is.

The e2m1_and_ufp8sf_scale_to_float function is a CPU-only operation (as noted in the code comment: "this is another cpu op, should decouple it from cuda ops in the future"). The .cpu() and .cuda() transfers are intentional and required—the underlying implementation explicitly creates CPU tensors and operates on CPU data. The test code's redundant .cpu() calls before invoking the function are a minor inefficiency but pose no correctness issue and don't warrant refactoring at this stage.


19-24: Add skip decorator for MNNVL availability check.

This test file should use the @pytest.mark.skipif(not mnnvl_available(), ...) decorator on test functions, matching the pattern established in other MNNVL tests (test_mnnvl_memory.py, test_mnnvl_moe_alltoall.py). The mnnvl_available() function verifies both NVLink and container permissions (pidfd), which are required for MNNVL operations.

Import mnnvl_available from tests.comm.conftest and apply the decorator to test_mnnvl_allreduce_refactored() and test_mnnvl_allreduce_legacy().

Comment on lines +679 to +696
if pattern_traits.has_quant:
    # FIXME: I have the function below to handle the creation of the buffer, is this correct?
    quant_result, scale_out, residual_out, norm_result = (
        trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(
            input=input,
            residual_in=residual_in,
            gamma=rms_gamma,
            workspace=workspace,
            epsilon=rms_eps,
            output=norm_out,
            residual_out=residual_out,
            launch_with_pdl=launch_with_pdl,
            output_scale=scale_factor,
            layout_code=layout_code,
            quant_type=pattern_traits.quant_type,
        )
    )
    return quant_result
Contributor

@coderabbitai coderabbitai bot Jan 12, 2026

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

cat -n flashinfer/comm/allreduce.py | sed -n '670,710p'

Repository: flashinfer-ai/flashinfer

Length of output: 2156


🏁 Script executed:

rg "def trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant" --type py -A 20

Repository: flashinfer-ai/flashinfer

Length of output: 1861


🏁 Script executed:

rg "residual_result" flashinfer/comm/allreduce.py -A 5 -B 5

Repository: flashinfer-ai/flashinfer

Length of output: 492


🏁 Script executed:

rg "trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant|trtllm_mnnvl_fused_allreduce_add_rmsnorm" flashinfer/comm/trtllm_mnnvl_ar.py -B 3 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 874


🏁 Script executed:

cat -n flashinfer/comm/trtllm_mnnvl_ar.py | sed -n '1,150p' | grep -E "@|def trtllm_mnnvl"

Repository: flashinfer-ai/flashinfer

Length of output: 92


Resolve the FIXME about buffer creation and fix the unused variable.

The FIXME at line 680-681 indicates uncertainty about whether pre-allocated buffers are being managed correctly. The code allocates norm_out and residual_out at lines 674-677, then passes them to the function which reassigns residual_out from the return value. Clarify whether the function creates buffers internally or expects pre-allocated ones, and update the code accordingly.

Additionally, at line 699, residual_result is unpacked but never used. Use _residual_result instead to indicate intentional non-use.

Contributor Author

Contributor
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In @tests/comm/test_trtllm_mnnvl_allreduce.py:
- Around line 24-25: The test file mixes absolute and relative imports which
breaks running as a script (ImportError on attempted relative import); change
the relative import "from .conftest import mnnvl_available" to an absolute
import (e.g., "from tests.comm.conftest import mnnvl_available") so both imports use
the same absolute style and the __main__ execution path that calls the module
(around the script entry at line 636) can import successfully.
🧹 Nitpick comments (2)
tests/comm/test_trtllm_mnnvl_allreduce.py (2)

187-192: Track the FIXME about optional global_scale parameter.

The comment indicates that the global_scale parameter is documented as optional but providing None leads to an error. This should be tracked in an issue to ensure it gets addressed.

Would you like me to help create an issue to track this bug in the e2m1_and_ufp8sf_scale_to_float API?


598-608: The monkeypatch parameter is unused.

The monkeypatch fixture is passed through from test_mnnvl_allreduce_refactored to run_mnnvl_ar_full but is never actually used in the test logic. Consider removing it if not needed, or add a comment explaining future intended use.

Suggested fix if not needed
 def test_mnnvl_allreduce_refactored(
-    monkeypatch,
     seq_lens: list[int],
     pattern: AllReduceFusionPattern,
     dtype: torch.dtype,
     hidden_size: int,
 ):
     """Test MNNVL AllReduce with refactored API."""
     run_mnnvl_ar_full(
-        monkeypatch, seq_lens, pattern, dtype, hidden_size, legacy_api=False
+        seq_lens, pattern, dtype, hidden_size, legacy_api=False
     )

And update run_mnnvl_ar_full signature accordingly.

Also applies to: 636-643

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1afe388 and 84ae387.

📒 Files selected for processing (1)
  • tests/comm/test_trtllm_mnnvl_allreduce.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/comm/test_trtllm_mnnvl_allreduce.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (5)
tests/comm/test_trtllm_mnnvl_allreduce.py (5)

29-40: LGTM!

The FP8 quantization and dequantization helpers are correctly implemented with proper clamping to the float8_e4m3fn bounds and appropriate type conversions.


98-127: LGTM!

The pattern-based branching using get_pattern_traits is well-structured. The conditional output tensor allocation based on traits.has_norm_out and the tuple reordering for API compatibility are handled correctly.


360-401: LGTM!

The pattern-based reference output computation correctly handles both AllReduce-only and fused patterns. The reference calculation for the fused case (AllReduce → Residual Add → RMSNorm) aligns with the expected kernel behavior.


611-633: LGTM!

The legacy test is appropriately scoped to only the patterns supported by the deprecated API. The TODO comment clearly marks this for removal, and the pattern-to-boolean conversion in run_mnnvl_ar_full (line 515) correctly maintains backward compatibility.


577-608: Remove FP4 quantization patterns from test parametrization—MNNVL only supports patterns 0 and 1.

The test parametrizes with kARResidualRMSNormFP4Quant and kARResidualRMSNormOutFP4Quant, but per the documentation in allreduce.py:501, "MNNVL only supports patterns 0 and 1" (kAllReduce and kARResidualRMSNorm). All quantization patterns (including FP4/FP8) are explicitly unsupported and will raise ValueError at runtime. The test should either skip these patterns for MNNVL tests or remove them from the parametrization.

Likely an incorrect or invalid review comment.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In @flashinfer/comm/trtllm_mnnvl_ar.py:
- Around line 614-635: Update the function docstring to state that sf_out will
be None for FP8 quantization: when quant_type == QuantFusionType.FP8 the code
path does not allocate or return scaling-factor output (sf_out) so document that
sf_out is None in this case, and keep existing behavior for other quant_types;
mention sf_out, quant_out, quant_type and QuantFusionType.FP8 in the docstring
so callers know the conditional return semantics.
🧹 Nitpick comments (8)
flashinfer/comm/trtllm_mnnvl_ar.py (2)

529-544: Missing @functools.cache decorator per coding guidelines.

As per the coding guidelines for flashinfer/**/*.py, API functions should use @functools.cache decorator for module-level caching to avoid recompilation. This new public function trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant should follow the same pattern as other public functions in this module.

♻️ Suggested fix
+@functools.cache
 def trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant(
     input: torch.Tensor,
     residual_in: torch.Tensor,

Note: If caching is inappropriate because the function takes tensor arguments (per-call data, unsuitable as cache keys), consider whether this function truly needs caching or if the internal get_trtllm_mnnvl_comm_module() call already handles the caching concern.


576-581: Consider documenting the different return order compared to non-quant function.

The return order (quant_out, sf_out, residual_out, output) differs from trtllm_mnnvl_fused_allreduce_add_rmsnorm which returns (output, residual_out). While documented, adding a note about this difference could help users transitioning between the two APIs.

Also applies to: 695-695

tests/comm/test_trtllm_mnnvl_allreduce.py (1)

171-174: Local import inside test function is acceptable but consider top-level with guard.

The import of nvfp4_quantize and e2m1_and_ufp8sf_scale_to_float inside the function body is reasonable to avoid import errors on systems without CUDA 12.8+. However, if this pattern is used frequently, consider a top-level import with a try/except guard for clarity.

include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (4)

50-55: Duplicated QuantType enum definition needs resolution.

The TODO comment acknowledges this duplication. Consider consolidating the enum definition in a shared header (e.g., include/flashinfer/fp4_layout.cuh or a dedicated types header) to ensure consistency and avoid maintenance burden.

#!/bin/bash
# Find other QuantType definitions in the codebase
rg -n "enum.*QuantType|class QuantType" -g '*.h' -g '*.cu' -g '*.cuh'  # ripgrep has no built-in 'cuh' type, so use globs

1471-1503: Runtime validation placed after macro definitions but before dispatch - consider reordering for clarity.

The runtime validation for quantization parameters (null checks, CUDA version check) is placed between macro definitions and the actual kernel dispatch. While functionally correct, consider moving these checks to the top of the function for better readability and early failure.


1495-1503: FP4 CUDA version check may be unreachable dead code.

The runtime check for CUDA_VERSION < 12080 at lines 1496-1502 is inside a compile-time #if CUDA_VERSION < 12080 block. If CUDA version is >= 12080, this block doesn't compile. If CUDA version is < 12080, the check will always trigger. The check is defense-in-depth but the structure is slightly confusing.

Consider restructuring:

if (params.quantType == QuantType::kFP4) {
#if CUDA_VERSION < 12080
    FLASHINFER_ERROR(...);
    return cudaErrorInvalidValue;
#endif
    // FP4-specific validation that requires CUDA 12080+
}

1093-1171: Consider adding algorithm overview comment for cvt_warp_fp16_to_fp4.

Per coding guidelines for performance-critical hot paths, leave comments explaining special algorithmic choices. This function implements warp-level FP4 quantization with scale factor computation. A brief overview comment explaining the algorithm (warp reduction for max, scale factor quantization, pipelined conversion) would help future reviewers.

csrc/trtllm_mnnvl_allreduce.cu (1)

116-138: Consider adding default case for completeness.

The switch statement handles kFP8 and kFP4 but has no default case. While kNone correctly falls through without validation (quant outputs not needed), adding an explicit default: break; improves code clarity and prevents compiler warnings in some configurations.

♻️ Suggested fix
       case QuantType::kFP4:
         // ... validation ...
         break;
+      default:
+        break;  // kNone: no quant output validation needed
     }
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 84ae387 and 43a904d.

📒 Files selected for processing (4)
  • csrc/trtllm_mnnvl_allreduce.cu
  • flashinfer/comm/trtllm_mnnvl_ar.py
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
  • tests/comm/test_trtllm_mnnvl_allreduce.py
🧰 Additional context used
📓 Path-based instructions (4)
flashinfer/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

flashinfer/**/*.py: Use @functools.cache decorator on Python API functions to implement module-level caching and avoid recompilation
Use @flashinfer_api decorator for debugging API calls, enable via FLASHINFER_LOGLEVEL environment variable (0=off, 1=basic, 3=detailed, 5=with stats)

Files:

  • flashinfer/comm/trtllm_mnnvl_ar.py
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/comm/test_trtllm_mnnvl_allreduce.py
csrc/**/*.cu

📄 CodeRabbit inference engine (CLAUDE.md)

Framework bindings and PyTorch tensor handling should be implemented in csrc/ via TVM-FFI, not in include/ headers

Files:

  • csrc/trtllm_mnnvl_allreduce.cu
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
🧠 Learnings (7)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`

Applied to files:

  • tests/comm/test_trtllm_mnnvl_allreduce.py
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to csrc/**/*.cu : Framework bindings and PyTorch tensor handling should be implemented in `csrc/` via TVM-FFI, not in `include/` headers

Applied to files:

  • csrc/trtllm_mnnvl_allreduce.cu
  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-11-12T03:35:17.583Z
Learnt from: raayandhar
Repo: flashinfer-ai/flashinfer PR: 2070
File: include/flashinfer/gemm/bf16_gemm_cutlass_template.h:145-160
Timestamp: 2025-11-12T03:35:17.583Z
Learning: In flashinfer GEMM implementations (e.g., include/flashinfer/gemm/bf16_gemm_cutlass_template.h, fp8_gemm_cutlass_template.h), it is acceptable to catch and silently ignore std::runtime_error exceptions in getWorkspaceSizeImpl when probing multiple GEMM configurations, as some configurations may legitimately fail due to SMEM constraints. This pattern should include a comment like "// Swallow errors when SMEM exceeds maximum allowed" to document the rationale.

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Match coding style, efficiency, complexity, verbosity, and defensiveness by learning from existing code - document intentional departures with rationale and mention 'AI-assisted' in git commit messages

Applied to files:

  • include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh
🧬 Code graph analysis (2)
flashinfer/comm/trtllm_mnnvl_ar.py (3)
flashinfer/fp4_quantization.py (1)
  • _compute_swizzled_layout_sf_size (47-50)
flashinfer/fp8_quantization.py (1)
  • _compute_swizzled_layout_sf_size (16-19)
flashinfer/comm/_types.py (2)
  • QuantFusionType (63-72)
  • QuantizationSFLayout (75-88)
csrc/trtllm_mnnvl_allreduce.cu (2)
flashinfer/comm/_types.py (1)
  • QuantizationSFLayout (75-88)
csrc/tvm_ffi_utils.h (1)
  • encode_dlpack_dtype (30-32)
🪛 Ruff (0.14.10)
flashinfer/comm/trtllm_mnnvl_ar.py

587-589: Avoid specifying long messages outside the exception class

(TRY003)


591-593: Avoid specifying long messages outside the exception class

(TRY003)


595-597: Avoid specifying long messages outside the exception class

(TRY003)


602-604: Avoid specifying long messages outside the exception class

(TRY003)


619-621: Avoid specifying long messages outside the exception class

(TRY003)


633-635: Avoid specifying long messages outside the exception class

(TRY003)


656-658: Avoid specifying long messages outside the exception class

(TRY003)


669-671: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

606-611: LGTM: output_scale conversion handles both tensor and scalar inputs.

The logic correctly converts a scalar float to a 1-element float32 tensor on the input device, ensuring compatibility with the kernel's float* expectation.
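That conversion can be sketched as (pure-Python stand-in; the real code builds a 1-element float32 torch tensor on the input's device):

```python
def as_scale_container(output_scale):
    # accept a bare float or an existing 1-element container
    if isinstance(output_scale, float):
        # real code: torch.tensor([output_scale], dtype=torch.float32, device=input.device)
        return [output_scale]
    return output_scale
```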

tests/comm/test_trtllm_mnnvl_allreduce.py (4)

22-31: LGTM: Helper functions for FP8 quantization testing.

The fp8_quant and dequant helper functions are correctly implemented for testing quantization accuracy. The use of torch.finfo for clamping and proper dtype conversions follows best practices.


600-622: LGTM: Legacy test correctly limited to non-quantization patterns.

The legacy test parameterization only includes kAllReduce and kARResidualRMSNorm patterns, which is correct since the deprecated API doesn't support quantization fusion.


165-166: Verify FP8 tolerance is appropriate.

FP8 pct_tol = 0.002 (0.2% allowed mismatch) seems tight. The NVFP4 tolerance of 1% is documented as matching TRT-LLM's test case. Consider whether the FP8 tolerance should also be aligned with TRT-LLM's reference or if this value has been empirically validated.

Also applies to: 190-190


625-632: LGTM: Direct test invocation for debugging.

The __main__ block allows direct test execution for debugging. The monkeypatch=None is safe since monkeypatch is not used within run_mnnvl_ar_full.

include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh (3)

1240-1241: LGTM: Static assertion enforces quant-only without RMSNorm is unsupported.

The static_assert correctly enforces at compile-time that quantization patterns must be combined with RMSNorm fusion, matching the documented API constraints.


1720-1723: LGTM: Shared memory layout for RMSNorm fusion kernel.

The shared memory calculation correctly allocates three buffers (input, residual, gamma) each of size blockSize * elemsPerThread * sizeof(T).
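The per-block arithmetic is simple enough to sanity-check on the host; a hypothetical helper (function and parameter names are assumed, not taken from the kernel) under the layout described above:

```python
def rmsnorm_fusion_smem_bytes(block_size: int, elems_per_thread: int,
                              elem_size: int) -> int:
    # Three staging buffers (input, residual, gamma), each holding
    # block_size * elems_per_thread elements of elem_size bytes.
    return 3 * block_size * elems_per_thread * elem_size
```

For example, a 128-thread block with 8 half-precision (2-byte) elements per thread needs 6 KiB of dynamic shared memory.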


1375-1377: LGTM: Optional norm output write with nullptr check.

The nullptr check correctly allows skipping the norm output write when only quantized output is needed, supporting the optional output parameter in the quantization fusion path.

csrc/trtllm_mnnvl_allreduce.cu (3)

6-7: LGTM: Using declarations expose quantization types.

The using declarations correctly bring QuantizationSFLayout and QuantType into scope for the TVM-FFI bindings.


95-96: LGTM: FP4 dtype constraint correctly enforced.

The validation correctly enforces that FP4 quantization requires both RMSNorm fusion and FP16/BF16 dtype (sizeof(c_type) == 2). Float32 inputs will fail this check as expected.


155-179: LGTM: Parameter wiring is complete and consistent.

All quantization-related parameters are correctly wired from the TVM-FFI interface to the AllReduceFusionParams struct. The default sfLayout of SWIZZLED_128x4 matches the Python interface default.

Comment on lines 257 to 278
mutates_args=[
"input",
"output",
"multicast_buffer_ptr",
"buffer_ptrs_dev",
"buffer_ptr_local",
"buffer_flags_mnnvl",
"nranks",
"rank",
"rmsnorm_fusion",
"launch_with_pdl",
"use_oneshot",
"output",
"residual_out",
"launch_with_pdl",
"rmsnorm_fusion",
"residual_in",
"residual_out",
"gamma",
"epsilon",
"quant_type",
"quant_out",
"sf_out",
"output_scale",
"layout_code",
],

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for mutates_args usage in the codebase to understand patterns
rg "mutates_args" flashinfer/ -B2 -A10 | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 192


🏁 Script executed:

#!/bin/bash
# Look at the specific file and function context
grep -n "register_custom_op\|mutates_args\|def " flashinfer/comm/trtllm_mnnvl_ar.py | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 727


🏁 Script executed:

#!/bin/bash
# Get more context around lines 257-278 in the target file
sed -n '240,290p' flashinfer/comm/trtllm_mnnvl_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 1573


🏁 Script executed:

#!/bin/bash
# Get the complete function signature to see all parameter types
sed -n '280,330p' flashinfer/comm/trtllm_mnnvl_ar.py

Repository: flashinfer-ai/flashinfer

Length of output: 2492


🏁 Script executed:

#!/bin/bash
# Check PyTorch's register_custom_op documentation and patterns
# Look for other usages of register_custom_op in flashinfer
find flashinfer -name "*.py" -type f -exec grep -l "register_custom_op" {} \;

Repository: flashinfer-ai/flashinfer

Length of output: 710


🏁 Script executed:

#!/bin/bash
# Check the utils to understand what register_custom_op expects
grep -n "register_custom_op\|def register_custom_op" flashinfer/utils.py -A30 | head -80

Repository: flashinfer-ai/flashinfer

Length of output: 1741


🏁 Script executed:

#!/bin/bash
# Check if there are other register_custom_op usages to see the pattern
rg "register_custom_op" flashinfer/comm/trtllm_ar.py -A20 | head -60

Repository: flashinfer-ai/flashinfer

Length of output: 1760


🏁 Script executed:

#!/bin/bash
# Verify the parameter types more clearly by looking at the type hints
sed -n '280,310p' flashinfer/comm/trtllm_mnnvl_ar.py | grep -E ":\s*(int|bool|float|torch.Tensor)" | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 611


🏁 Script executed:

#!/bin/bash
# Check the exact status of torch.library usage in utils.py
sed -n '315,370p' flashinfer/utils.py

Repository: flashinfer-ai/flashinfer

Length of output: 1672


🏁 Script executed:

#!/bin/bash
# Search for any comments or TODOs about enabling custom_op in the future
rg "custom_op|torch.library" flashinfer/utils.py -B2 -A2

Repository: flashinfer-ai/flashinfer

Length of output: 815


🏁 Script executed:

#!/bin/bash
# Check if trtllm_ar.py's register_custom_op also has non-tensor params in mutates_args
sed -n '180,230p' flashinfer/comm/trtllm_ar.py | grep -A30 "register_custom_op"

Repository: flashinfer-ai/flashinfer

Length of output: 50


🌐 Web query:

PyTorch torch.library.custom_op mutates_args semantic requirements tensor vs scalar parameters

💡 Result:

Short answer — mutates_args only describes which Tensor arguments your custom_op will mutate; it must list the argument names (or be "unknown") and be exactly correct for mutated Tensors, otherwise behavior is undefined. Scalar/primitive arguments (ints/floats/tuples) are not Tensors and therefore should not be listed as mutated; only Tensor inputs that are actually written to (e.g., an out tensor or an in-place target) belong in mutates_args. If you can’t precisely enumerate mutated Tensor args you may pass "unknown" (pessimistically assumes all inputs may be mutated). [1][2][3]

Key points:

  • mutates_args: Iterable[str] naming Tensor parameters that the op mutates, or "unknown". MUST be accurate. [1]
  • If your op mutates an out Tensor, include that Tensor’s argument name (e.g., "out"). If it returns None and writes into an output argument, that output must be listed. [2]
  • Non-Tensor parameters (scalars, sequences of ints, floats, etc.) are not Tensors and should not be listed as mutated. Only Tensors count. [2]
  • If you pass "unknown", PyTorch assumes all inputs may be mutated (pessimistic). [1]
  • Use torch.library.opcheck(...) to validate your registration against sample inputs. [3]

Sources: PyTorch torch.library.custom_op docs; Python custom-ops tutorial; custom C++/mutable-ops guidance. [1][2][3]

References
[1] torch.library.custom_op docs.
[2] PyTorch tutorial — Custom Python Operators (mutable operators examples).
[3] PyTorch tutorial — Custom C++ and CUDA Operators (mutable operator schema).


mutates_args lists non-Tensor scalar parameters that should not be included.

Per PyTorch's torch.library.custom_op semantics, mutates_args must only name Tensor arguments that are actually mutated in place. Scalar parameters such as nranks, rank, use_oneshot, launch_with_pdl, rmsnorm_fusion, epsilon, quant_type, and layout_code are non-Tensor types (int, bool, float) and cannot be mutated, so they should be removed from this list. The list also contains duplicates (output, launch_with_pdl, rmsnorm_fusion, and residual_out each appear twice), which should be deduplicated. While register_custom_op is currently disabled in the codebase (the torch.library.custom_op call is commented out), this will cause failures or undefined behavior once the decorator is re-enabled.
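One way to keep the list honest is to derive it mechanically from the signature. A sketch with hypothetical type annotations mirroring the registration above (the names match the quoted diff, but the type mapping is an assumption, not the real schema):

```python
ARG_TYPES = {
    "input": "Tensor", "output": "Tensor", "multicast_buffer_ptr": "int",
    "buffer_ptrs_dev": "int", "buffer_ptr_local": "int",
    "buffer_flags_mnnvl": "Tensor", "nranks": "int", "rank": "int",
    "use_oneshot": "bool", "launch_with_pdl": "bool",
    "rmsnorm_fusion": "bool", "residual_in": "Tensor",
    "residual_out": "Tensor", "gamma": "Tensor", "epsilon": "float",
    "quant_type": "int", "quant_out": "Tensor", "sf_out": "Tensor",
    "output_scale": "Tensor", "layout_code": "int",
}

def valid_mutates_args(candidates, arg_types):
    # Keep only Tensor-typed names, dropping scalars and duplicates,
    # since mutates_args may list only Tensors that are written in place.
    out, seen = [], set()
    for name in candidates:
        if arg_types.get(name) == "Tensor" and name not in seen:
            seen.add(name)
            out.append(name)
    return out
```

Running the quoted mutates_args list through such a filter would drop the scalar entries and the repeated names in one pass.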

Comment on lines +614 to +635
if quant_type == QuantFusionType.FP8:
if quant_out is None:
quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
# TODO: Do we need further check on the shape?
elif len(quant_out.shape) != 2:
raise ValueError(
f"The quantized output tensor must be 2D, got {len(quant_out.shape)}D. The shape is {quant_out.shape}."
)
elif quant_type == QuantFusionType.NVFP4:
if quant_out is None:
# TODO: PyTorch supports fp4x2 dtype, do we want to use that?
quant_out = torch.empty(
input.shape[0],
input.shape[1] // 2,
dtype=torch.uint8,
device=input.device,
)
# TODO: Do we need further check on the shape?
elif len(quant_out.shape) != 2:
raise ValueError(
f"The quantized output tensor must be 2D, got {len(quant_out.shape)}D. The shape is {quant_out.shape}."
)

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the kernel file for sf_out/scalingFactorOut handling
rg -n "sf_out|scalingFactorOut" include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh -A2 -B2

Repository: flashinfer-ai/flashinfer

Length of output: 3930


🏁 Script executed:

#!/bin/bash
# Get context around line 695 to see the return statement
cat -n flashinfer/comm/trtllm_mnnvl_ar.py | sed -n '690,700p'

Repository: flashinfer-ai/flashinfer

Length of output: 530


🏁 Script executed:

#!/bin/bash
# Get the full function signature and context to understand the flow
cat -n flashinfer/comm/trtllm_mnnvl_ar.py | sed -n '550,650p'

Repository: flashinfer-ai/flashinfer

Length of output: 5562


🏁 Script executed:

#!/bin/bash
# Check where this function is called and how sf_out is used
rg -n "trtllm_mnnvl_fused_allreduce_add_rmsnorm_quant" flashinfer/ -A3 -B1 | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 1220


🏁 Script executed:

#!/bin/bash
# See what happens after line 650 in the function
cat -n flashinfer/comm/trtllm_mnnvl_ar.py | sed -n '650,695p'

Repository: flashinfer-ai/flashinfer

Length of output: 2248


Clarify in docstring that sf_out is None for FP8 quantization.

The FP8 quantization path does not allocate sf_out (it remains None by design, since FP8 does not use scaling factors). However, the function docstring lists sf_out as a return value without clarifying this conditional behavior. Update the docstring to explicitly document that sf_out is None when quant_type == QuantFusionType.FP8.
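The NVFP4 branch quoted above allocates input.shape[1] // 2 uint8 columns because two 4-bit codes share one byte. A minimal packing sketch (low-nibble-first order is an assumption here; the kernel's actual bit layout may differ):

```python
def pack_fp4_pairs(nibbles):
    # Pack consecutive pairs of 4-bit codes into bytes, low nibble first:
    # a row of N FP4 values occupies N // 2 uint8 slots.
    assert len(nibbles) % 2 == 0 and all(0 <= v < 16 for v in nibbles)
    return bytes(nibbles[i] | (nibbles[i + 1] << 4)
                 for i in range(0, len(nibbles), 2))
```

This is why the pre-allocated quant_out for NVFP4 halves the column count while FP8 keeps the input shape one-to-one.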

🧰 Tools
🪛 Ruff (0.14.10)

619-621: Avoid specifying long messages outside the exception class

(TRY003)


633-635: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In @flashinfer/comm/trtllm_mnnvl_ar.py around lines 614-635, update the function docstring to state that sf_out will be None for FP8 quantization: when quant_type == QuantFusionType.FP8 the code path does not allocate or return a scaling-factor output (sf_out), so document that sf_out is None in this case and keep the existing behavior for other quant_types. Mention sf_out, quant_out, quant_type, and QuantFusionType.FP8 in the docstring so callers know the conditional return semantics.
