
Add cute-dsl backends to mxfp[8,4]_quantization for future refactor#2443

Merged
bkryu merged 17 commits into flashinfer-ai:main from bkryu:cute-dsl-mxfp8
Mar 18, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Jan 30, 2026

📌 Description

This PR adds CuTe-DSL backend support for the MXFP8 and MXFP4 quantization kernels as an alternative to the JIT-compiled CUDA backends.

Key changes:

  • Add CuTe-DSL MXFP8 and MXFP4 quantization kernels
  • Reorganize quantization module structure for better maintainability
  • Add benchmarks and unit tests for backend comparison

File Structure Reorganization
Quantization files are now organized in flashinfer/quantization/:

flashinfer/quantization/
├── __init__.py                    # Package exports
├── fp4_quantization.py            # MXFP4 public API
├── fp8_quantization.py            # MXFP8 public API  
├── packbits.py                    # Utility functions
├── quantization_cute_dsl_utils.py # Shared PTX intrinsics
└── kernels/
    ├── __init__.py                # Kernel exports (EXPERIMENTAL)
    ├── mxfp4_quantize.py          # MXFP4 CuTe-DSL kernel
    └── mxfp8_quantize.py          # MXFP8 CuTe-DSL kernel

Performance
The CuTe-DSL kernels perform strongly against their CUDA counterparts:

  • mxfp4_quantization - Geomean 12x speedup; beats the CUDA backend in every case in bench_mxfp4_quantize_backend_comparison.py
  • mxfp8_quantization - Geomean ~1.3x speedup; beats the CUDA backend in every case in bench_mxfp8_quantize_backend_comparison.py
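The geomean figures above are geometric means of per-case speedup ratios. As a quick illustration (not code from this PR), a geometric mean over CuTe-DSL/CUDA time ratios can be computed as:

```python
import math

def geomean(speedups):
    # Geometric mean: exponentiate the average of the logs.
    return math.exp(sum(math.log(s) for s in speedups) / len(speedups))

# e.g. hypothetical per-case speedup ratios from a benchmark sweep
print(geomean([2.0, 8.0]))  # about 4.0
```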

Expand below for performance heatmaps:

CuTe DSL Backend outperforms CUDA backend on every single case benchmarked in bench_mxfp8_quantize_backend_comparison.py. Click to see performance comparison data

BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster
sm100_mxfp8_swizzled_bfloat16

BF16 input; Linear cases. > 1.0 means CuTe DSL is faster
sm100_mxfp8_linear_bfloat16

BF16 input; Linear cases. Annotated values are achieved TB/s
sm100_mxfp8_bandwidth_linear_bfloat16

BF16 input; Swizzled cases. Annotated values are achieved TB/s
sm100_mxfp8_bandwidth_swizzled_bfloat16

CuTe DSL Backend outperforms CUDA backend on every single case benchmarked in bench_mxfp4_quantize_backend_comparison.py. Click to see performance comparison data

BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster
sm100_mxfp4_comparison_bfloat16

BF16 input; Swizzled cases. Annotated values are achieved TB/s
sm100_mxfp4_bandwidth_bfloat16

🔍 Related Issues

#2496

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • CuTe-DSL backend added for MXFP8 and MXFP4 quantization alongside CUDA.
    • Consolidated quantization package exposing unified FP4/FP8 interfaces and conditional CuTe-DSL exports.
    • New end-to-end benchmarking tools for MXFP4 and MXFP8 (correctness, performance, bandwidth, heatmaps).
  • Bug Fixes / Compatibility

    • Backwards-compatible shims preserve existing public API while delegating implementations to the new package.
  • Tests

    • Expanded tests to cover CUDA and CuTe-DSL, availability gating, compilation cache, and backend parity.

@coderabbitai
Contributor

coderabbitai bot commented Jan 30, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Walkthrough

Reorganizes quantization into a new flashinfer.quantization subpackage, adds CuTe‑DSL MXFP4/MXFP8 kernel implementations with compilation caching, updates top‑level shims to re-export new modules, and extends tests/benchmarks to support and compare "cuda" and "cute-dsl" backends.

Changes

  • Top-level shims & exports (flashinfer/fp4_quantization.py, flashinfer/fp8_quantization.py, flashinfer/__init__.py): Replaced in-module implementations with backward-compatibility shims that re-export implementations from flashinfer.quantization.* and updated public exports.
  • Quantization package core (flashinfer/quantization/__init__.py, flashinfer/quantization/fp4_quantization.py, flashinfer/quantization/fp8_quantization.py, flashinfer/quantization/packbits.py): Adds the consolidated quantization package, re-exports FP4/FP8 APIs, conditionally exposes CuTe-DSL kernels, and fixes relative imports.
  • CuTe-DSL kernels & utils (flashinfer/quantization/quantization_cute_dsl_utils.py, flashinfer/quantization/kernels/mxfp4_quantize.py, flashinfer/quantization/kernels/mxfp8_quantize.py): Introduces low-level CuTe-DSL intrinsics, MXFP4/MXFP8 kernel classes, TVM-FFI compile/cache helpers, and mxfp*_quantize_cute_dsl high-level entrypoints.
  • Kernels package init (flashinfer/quantization/kernels/__init__.py): New package initializer that re-exports kernel classes and CuTe-DSL quantize entrypoints.
  • FP4 module shim (flashinfer/fp4_quantization.py): Converted to a compatibility shim that re-exports public FP4 symbols from flashinfer.quantization.fp4_quantization.
  • Benchmarks & routine mappings (benchmarks/bench_mxfp4_quantize_backend_comparison.py, benchmarks/bench_mxfp8_quantize_backend_comparison.py, benchmarks/routines/flashinfer_benchmark_utils.py): Added two backend-comparison benchmarking scripts and updated benchmark routine backend mappings to include "cute-dsl" and record enable_pdl.
  • Tests (FP4/FP8) (tests/utils/test_fp4_quantize.py, tests/utils/test_fp8_quantize.py): Parameterized tests to exercise cuda and cute-dsl, added CuTe-DSL availability helpers, compilation-cache tests, and propagated backend/enable_pdl through test flows with device gating.
  • Misc. imports / small updates (flashinfer/activation.py, flashinfer/quantization/packbits.py): Adjusted import paths to the new package layout; minor import refactors.

Sequence Diagram(s)

sequenceDiagram
  participant Bench as Benchmark/Test
  participant API as flashinfer.quantization
  participant Kernel as Kernel (cute-dsl / cuda)
  participant Device as GPU Device/Driver

  Bench->>API: call mxfp8_quantize(..., backend)
  API->>Kernel: select backend, determine enable_pdl, compile or fetch cached kernel
  Kernel->>Device: launch compiled kernel / call CUDA kernel
  Device-->>Kernel: execution complete (writes outputs)
  Kernel-->>API: return (quantized_tensor, scales)
  API-->>Bench: return results

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

cute-dsl, benchmark, run-ci

Suggested reviewers

  • yzh119
  • nv-yunzheq
  • kaixih
  • aleozlx
  • cyx-6

Poem

🐰 I hopped from CUDA into CuTe‑land,
kernels compiled by my little pawed hand,
scales swizzled neat, bytes tucked in rows,
benchmarks hum where the fast wind blows,
carrots for speed — hop on, take a stand!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 70.25%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title 'Add cute-dsl backends to mxfp[8,4]_quantization for future refactor' clearly and specifically describes the main change: adding CuTe-DSL backend support to MXFP quantization functions.
  • Description check: ✅ Passed. The PR description is comprehensive and well-structured, covering all required sections: a detailed description of changes, related issues (#2496), and verification of all pre-commit checks.



@gemini-code-assist
Contributor

Summary of Changes

Hello @bkryu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the MXFP8 quantization implementation by introducing a new, highly optimized backend based on CuTe-DSL. This change provides an alternative, potentially more performant path for quantization operations, enhancing the flexibility and efficiency of the FlashInfer library. The integration ensures that users can seamlessly switch between CUDA and CuTe-DSL implementations, while comprehensive testing validates the correctness and caching mechanisms of the new kernels.

Highlights

  • New CuTe-DSL MXFP8 Quantization Kernels: Introduced new high-performance MXFP8 quantization kernels implemented using CuTe-DSL, supporting both linear and swizzled (128x4) scale factor layouts. These kernels feature Half2/BFloat2 SIMD for max-abs computation, 4-thread cooperation per scale factor block, vectorized 128-bit global loads/stores, and M-agnostic compilation.
  • Backend Selection for MXFP8 Quantization: The main mxfp8_quantize API now accepts a backend argument, allowing users to choose between the existing JIT-compiled CUDA kernel ('cuda') and the new CuTe-DSL kernel ('cute-dsl'). A runtime check ensures CuTe-DSL is available if selected.
  • Shared Quantization Utilities and Intrinsics: A new quantization_utils.py module was added to house common constants (e.g., SF_VEC_SIZE, FLOAT8_E4M3_MAX) and PTX intrinsics for efficient GPU operations, including Half2/BFloat2 SIMD for max reduction, fast UE8M0 conversion, FP8 conversion with scaling, and warp shuffle for 4-thread reduction.
  • Enhanced Benchmarking and Testing Infrastructure: The benchmarking and testing utilities were updated to support and validate the new 'cute-dsl' backend. This includes adding 'cute-dsl' as a choice in benchmark arguments and extending unit tests to cover the new backend, including specific tests for CuTe-DSL's M-agnostic, K-specific, and dtype-specific compilation caching behavior.
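For readers unfamiliar with the UE8M0 scaling mentioned above, a pure-Python reference for the per-block scale exponent might look like the following. The constants match the MXFP8 format, but the rounding mode here is an assumption, and the real kernels use PTX intrinsics instead:

```python
import math

SF_VEC_SIZE = 32          # elements sharing one scale factor (MX block size)
FLOAT8_E4M3_MAX = 448.0   # largest finite e4m3 magnitude

def ue8m0_scale_exponent(block):
    # Choose a power-of-two exponent e so that block / 2**e fits
    # within the e4m3 representable range.
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 0
    return math.ceil(math.log2(amax / FLOAT8_E4M3_MAX))

block = [0.5] * SF_VEC_SIZE
block[0] = 896.0                    # amax = 2 * 448
print(ue8m0_scale_exponent(block))  # 1, i.e. scale = 2.0
```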


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@bkryu bkryu marked this pull request as draft January 30, 2026 17:08
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new cute-dsl backend for MXFP8 quantization, refactoring the existing CUDA implementation. The changes are well-structured, adding new CuTe-DSL kernels for both linear and swizzled layouts, and updating the public API, benchmarks, and tests accordingly. The new kernels correctly use M-agnostic compilation for better performance with varying batch sizes. My review includes a couple of suggestions to improve the maintainability of the new kernel code by explaining a magic number and refactoring a duplicated logic block. The accompanying test updates are comprehensive and include valuable checks for the compilation cache behavior.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/mxfp8_quantize.py`:
- Around line 585-635: The code flattens inputs when input.dim()>2 but never
restores batch dimensions or uses orig_shape; after computing
fp8_output/scale_output and before returning, reshape fp8_tensor and
scale_output back to the original batch shape using orig_shape: for fp8_tensor,
view/reshape to (*orig_shape[:-1], padded_k); for scale_output, convert the 1D
buffer into per-row blocks and then reshape to (*orig_shape[:-1],
num_sf_blocks_per_row) for the linear path (use total_sf_blocks -> view(m,
num_sf_blocks_per_row)), and for the swizzled path convert scale_output via
view(padded_m, padded_sf_cols) then take the first m rows ([:m,
:padded_sf_cols]) and reshape to (*orig_shape[:-1], padded_sf_cols); ensure you
reference orig_shape, padded_k, m, num_sf_blocks_per_row, padded_m and
padded_sf_cols when making these changes.
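The shape bookkeeping in this fix can be illustrated with a small shape-only sketch; the names mirror the review comment, and the function itself is hypothetical:

```python
import math

def restore_batch_shapes(orig_shape, padded_k, num_sf_blocks_per_row):
    # An input of shape (*batch_dims, K) is flattened to (m, K) before the
    # kernel runs, so outputs must be viewed back to the batch dimensions.
    m = math.prod(orig_shape[:-1])
    fp8_shape = (*orig_shape[:-1], padded_k)                 # quantized data
    scale_shape = (*orig_shape[:-1], num_sf_blocks_per_row)  # linear layout
    assert math.prod(fp8_shape) == m * padded_k
    return fp8_shape, scale_shape
```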

In `@flashinfer/cute_dsl/quantization_utils.py`:
- Around line 22-23: Remove the unused Uint8 import from the top-level imports
in quantization_utils.py: update the import line that currently reads "from
cutlass import Float32, Int32, Uint32, Uint64, Uint8" to exclude Uint8 so only
used symbols (Float32, Int32, Uint32, Uint64) remain; this will resolve the F401
lint error while leaving functions/classes that reference
Float32/Int32/Uint32/Uint64 untouched.
🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)

203-210: Silence unused a_sf warnings in denormal/zero/mixed tests.

Ruff flags a_sf as unused in several tests. Consider replacing it with _ (or _a_sf) to avoid lint noise; same pattern applies to the other occurrences in this file.

♻️ Example fix
-    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)
+    a_fp8, _ = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)

@bkryu
Collaborator Author

bkryu commented Jan 30, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !283 has been created, and the CI pipeline #42939528 is currently running. I'll report back once the pipeline job completes.

@bkryu
Collaborator Author

bkryu commented Jan 31, 2026

/bot stop

@flashinfer-bot
Collaborator

The GitLab CI pipeline #42939528 has been cancelled.

@flashinfer-bot
Collaborator

The GitLab CI pipeline #43311884 has been cancelled.

@bkryu
Collaborator Author

bkryu commented Feb 5, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !283 has been updated with latest changes, and the CI pipeline #43313609 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #43313609: 14/20 passed

@bkryu bkryu marked this pull request as ready for review February 5, 2026 17:04
@bkryu bkryu requested a review from kahyunnam as a code owner February 5, 2026 17:04
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@flashinfer/quantization/__init__.py`:
- Around line 56-86: The __all__ list in flashinfer/quantization/__init__.py is
unsorted (RUF022); either alphabetically sort the symbols in the __all__ list
(e.g., ensure entries like "block_scale_interleave",
"e2m1_and_ufp8sf_scale_to_float", "get_fp4_quantization_module",
"mxfp4_quantize", "mxfp8_quantize_cute_dsl", etc. are in ASCII order) or if the
current grouped ordering is intentional add an explicit ruff suppression for
RUF022 (e.g., a module-level ruff noqa for RUF022) so the linter is satisfied.
Ensure the change targets the __all__ variable and preserves the conditional
addition when _cute_dsl_available is true.

In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 413-441: The variable name `l` is ambiguous (ruff E741); rename it
to a descriptive identifier (e.g., `batch_size` or `num_groups`) throughout this
function and the related functions to fix the lint error and improve
readability: update the unpacking line `l, m, k_by_2 = input.shape` to
`batch_size, m, k_by_2 = input.shape` (or `num_groups`), update all subsequent
uses of `l` (construction of `output`, `output_scales`, reshapes, the call to
module.silu_and_mul_scaled_nvfp4_experts_quantize, and the final permute/view
lines), and apply the same consistent rename in the other affected blocks (the
functions around lines 449-468, 498-530, 538-557) so every reference (e.g.,
output.view(l * m, ...), output_scales.view(..., l, ...), return output,
output_scales) uses the new identifier.
- Around line 563-600: The call to module.e2m1_and_ufp8sf_scale_to_float_sm100
unconditionally calls global_scale_tensor.cpu() but global_scale_tensor is
optional; guard it and pass a CPU tensor when it's None. Update the invocation
in e2m1_and_ufp8sf_scale_to_float_sm100 so that you pass
(global_scale_tensor.cpu() if global_scale_tensor is not None else a default CPU
float32 tensor, e.g. torch.tensor([1.0], dtype=torch.float32, device='cpu')),
ensuring the default matches the expected shape/dtype the custom op requires.

In `@flashinfer/quantization/fp8_quantization.py`:
- Around line 91-101: The fake op function _fake_mxfp8_quantize_sm100 is missing
the enable_pdl parameter present on the real implementation, causing a signature
mismatch; update the _fake_mxfp8_quantize_sm100 definition to add an enable_pdl:
bool = False parameter (keeping existing defaults for is_sf_swizzled_layout and
alignment) and ensure any callers or the returned tensors behavior remain
unchanged so the fake op signature matches the real mxfp8_quantize
implementation.

In `@tests/utils/test_fp4_quantize.py`:
- Around line 20-29: The helper is_fp4_supported in
tests/utils/test_fp4_quantize.py directly calls torch.cuda.get_device_capability
and should instead use flashinfer.utils.get_compute_capability; update the
function to import get_compute_capability and replace the torch call with
get_compute_capability(device) (keep the existing CUDA version parsing and the
same support logic), ensuring the rest of is_fp4_supported still uses
cuda_version from torch.version.cuda and the same major/minor comparisons.
🧹 Nitpick comments (6)
flashinfer/quantization/kernels/__init__.py (1)

39-45: Consider sorting __all__ for consistency.

Static analysis flagged that __all__ is not sorted. While minor, sorting it alphabetically improves readability and maintainability.

🔧 Suggested fix
 __all__ = [
     "MXFP4QuantizeSwizzledKernel",
+    "MXFP8QuantizeLinearKernel",
+    "MXFP8QuantizeSwizzledKernel",
     "mxfp4_quantize_cute_dsl",
-    "MXFP8QuantizeLinearKernel",
-    "MXFP8QuantizeSwizzledKernel",
     "mxfp8_quantize_cute_dsl",
 ]
flashinfer/quantization/kernels/mxfp4_quantize.py (2)

98-116: Redundant condition in thread count optimization.

Line 103's condition if threads_per_row <= _MAX_THREADS is always true because line 98-100 already handles the case when threads_per_row >= _MAX_THREADS and returns early. The if block can be simplified.

🔧 Suggested simplification
     if threads_per_row >= _MAX_THREADS:
         # Large K: use max threads, will need column loop
         return _MAX_THREADS

-    # threads_per_block should be a multiple of threads_per_row
-    if threads_per_row <= _MAX_THREADS:
-        # Find largest multiple of threads_per_row <= _MAX_THREADS
-        threads = (_MAX_THREADS // threads_per_row) * threads_per_row
-        if threads >= _MIN_THREADS:
-            return threads
-        # If largest multiple is below _MIN_THREADS, use the smallest valid one
-        threads = threads_per_row
-        while threads < _MIN_THREADS:
-            threads += threads_per_row
-        if threads <= _MAX_THREADS:
-            return threads
+    # threads_per_block should be a multiple of threads_per_row
+    # Find largest multiple of threads_per_row <= _MAX_THREADS
+    threads = (_MAX_THREADS // threads_per_row) * threads_per_row
+    if threads >= _MIN_THREADS:
+        return threads
+    # If largest multiple is below _MIN_THREADS, use the smallest valid one
+    threads = threads_per_row
+    while threads < _MIN_THREADS:
+        threads += threads_per_row
+    if threads <= _MAX_THREADS:
+        return threads

     # Fallback to default
     return _DEFAULT_THREADS
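A runnable version of the simplified heuristic, with assumed values for the thread-count constants (the PR's actual values may differ):

```python
_MAX_THREADS = 1024     # assumed upper bound per block
_MIN_THREADS = 128      # assumed lower bound for good occupancy
_DEFAULT_THREADS = 256  # assumed fallback

def choose_threads_per_block(threads_per_row):
    if threads_per_row >= _MAX_THREADS:
        # Large K: use max threads; the kernel loops over columns.
        return _MAX_THREADS
    # Largest multiple of threads_per_row that fits in a block.
    threads = (_MAX_THREADS // threads_per_row) * threads_per_row
    if threads >= _MIN_THREADS:
        return threads
    # Otherwise take the smallest multiple at or above the minimum.
    threads = threads_per_row
    while threads < _MIN_THREADS:
        threads += threads_per_row
    return threads if threads <= _MAX_THREADS else _DEFAULT_THREADS
```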

155-168: Use explicit None union syntax for type hints.

PEP 484 prohibits implicit Optional. The target_grid parameter should use explicit union syntax for consistency with the rest of the codebase (e.g., line 467 uses bool | None).

🔧 Suggested fix
     def __init__(
         self,
         dtype: cutlass.Numeric,
         K: int,
         enable_pdl: bool = False,
-        target_grid: int = None,
+        target_grid: int | None = None,
     ):
flashinfer/quantization/kernels/mxfp8_quantize.py (2)

75-88: Consider consolidating duplicated _get_target_grid function.

This function is identical to _get_target_grid in mxfp4_quantize.py (lines 58-71). Consider moving it to quantization_cute_dsl_utils.py to avoid code duplication.

#!/bin/bash
# Verify the duplication
echo "=== mxfp4_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp4_quantize.py

echo ""
echo "=== mxfp8_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp8_quantize.py

162-176: Use explicit None union syntax for type hints.

For consistency with other parts of the codebase (line 678 uses bool | None), update the target_grid parameter type annotation.

🔧 Suggested fix
     def __init__(
         self,
         dtype: cutlass.Numeric,
         K: int,
         enable_pdl: bool = False,
-        target_grid: int = None,
+        target_grid: int | None = None,
     ):

Apply the same change to MXFP8QuantizeSwizzledKernel.__init__ (line 314), _get_compiled_kernel_linear (line 583), and _get_compiled_kernel_swizzled (line 630).

flashinfer/quantization/quantization_cute_dsl_utils.py (1)

964-1002: Consider sorting __all__ for maintainability.

While the current organization by category (constants, intrinsics, helpers) is logical, sorting alphabetically or at minimum keeping consistent ordering would help with maintainability as the module grows.

Comment on lines +56 to +86
__all__ = [
# Packbits
"packbits",
"segment_packbits",
# JIT module generator
"gen_quantization_module",
# FP8
"mxfp8_quantize",
"mxfp8_dequantize_host",
# FP4
"SfLayout",
"block_scale_interleave",
"nvfp4_block_scale_interleave",
"e2m1_and_ufp8sf_scale_to_float",
"fp4_quantize",
"mxfp4_dequantize_host",
"mxfp4_dequantize",
"mxfp4_quantize",
"nvfp4_quantize",
"nvfp4_batched_quantize",
"shuffle_matrix_a",
"shuffle_matrix_sf_a",
"scaled_fp4_grouped_quantize",
"get_fp4_quantization_module",
]

if _cute_dsl_available:
__all__ += [
"mxfp8_quantize_cute_dsl",
"mxfp4_quantize_cute_dsl",
]
Contributor


⚠️ Potential issue | 🟡 Minor

Ruff RUF022: __all__ is not sorted.
Consider sorting to satisfy lint, or explicitly suppress if the grouped ordering is intentional.

🔧 Optional suppression to keep grouped ordering
-__all__ = [
+__all__ = [  # noqa: RUF022 - keep grouped exports
     # Packbits
     "packbits",
     "segment_packbits",
@@
-if _cute_dsl_available:
-    __all__ += [
+if _cute_dsl_available:
+    __all__ += [  # noqa: RUF022 - keep grouped exports
         "mxfp8_quantize_cute_dsl",
         "mxfp4_quantize_cute_dsl",
     ]
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 56-80: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


[warning] 83-86: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)

🤖 Prompt for AI Agents
In `@flashinfer/quantization/__init__.py` around lines 56 - 86, The __all__ list
in flashinfer/quantization/__init__.py is unsorted (RUF022); either
alphabetically sort the symbols in the __all__ list (e.g., ensure entries like
"block_scale_interleave", "e2m1_and_ufp8sf_scale_to_float",
"get_fp4_quantization_module", "mxfp4_quantize", "mxfp8_quantize_cute_dsl", etc.
are in ASCII order) or if the current grouped ordering is intentional add an
explicit ruff suppression for RUF022 (e.g., a module-level ruff noqa for RUF022)
so the linter is satisfied. Ensure the change targets the __all__ variable and
preserves the conditional addition when _cute_dsl_available is true.

@vincentzed
Contributor

@bkryu Q:
Did you have success making 256-bit loads in CuTe-DSL? I found that this did not seem to be possible. For example, in pure CUTLASS:

https://github.com/HydraQYH/hp_rms_norm/blob/master/hp_rms_norm/csrc/cuda/hp_rms_norm.cuh

Collaborator

@kahyunnam kahyunnam left a comment


LGTM, I just left a few questions about compute capability heuristic

"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
"10.0": ["cuda", "cute-dsl"],
Collaborator


Just curious, why is cute-dsl only enabled above 10.0?

Is it just a future to-do for more testing/benchmarking for <10.0 before enabling?

Collaborator Author


Hardware-accelerated MXFP8-related instructions are a feature of the Blackwell generation. Hopper handles (non-MX) FP8 well, but it cannot run these kernels.

As such, on Hopper or earlier we do not expect users to use MXFP8 (software-emulated MXFP8 is possible, but performance would likely be unsatisfactory).

Collaborator

@kahyunnam kahyunnam Feb 9, 2026


Oh ... this makes a lot of sense 😅

It may change or be removed in future versions without notice.
Use at your own risk for production workloads.
"""
if backend == "cute-dsl":
Collaborator


Should we also add a compute capability check here for current compilation context (/ current device) being compute >= 10.0, since it seems from benchmarking that we're only testing cute-dsl on 10.0 and above?

bkryu added 2 commits March 11, 2026 22:40
Made-with: Cursor

# Conflicts:
#	flashinfer/fp4_quantization.py
#	tests/utils/test_fp4_quantize.py
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (1)
flashinfer/quantization/fp4_quantization.py (1)

608-611: ⚠️ Potential issue | 🟠 Major

Guard global_scale_tensor=None before calling .cpu().

The public API defaults this argument to None and forwards it here unchanged, so this dereference raises AttributeError before the host dequantization op runs.

🧩 Suggested fix
+        global_scale_cpu = (
+            torch.tensor([1.0], dtype=torch.float32, device="cpu")
+            if global_scale_tensor is None
+            else global_scale_tensor.cpu()
+        )
         module.e2m1_and_ufp8sf_scale_to_float_sm100(
             e2m1_tensor.cpu(),
             ufp8_scale_tensor.cpu().reshape(-1),
-            global_scale_tensor.cpu(),
+            global_scale_cpu,
             out,
             sf_vec_size,
             ufp8_type,
             is_sf_swizzled_layout,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 608 - 611, The call
to module.e2m1_and_ufp8sf_scale_to_float_sm100 dereferences
global_scale_tensor.cpu() but the public API may pass global_scale_tensor=None;
guard that before calling .cpu() by computing a local value (e.g.
global_scale_cpu = global_scale_tensor.cpu() if global_scale_tensor is not None
else None) and pass global_scale_cpu to e2m1_and_ufp8sf_scale_to_float_sm100; do
the same pattern if any other tensor arguments may be None (keep
e2m1_tensor.cpu() and ufp8_scale_tensor.cpu().reshape(-1) unchanged).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 931-934: The denominator used to compute a_global_sf can be zero
(for all-zero or all-NaN inputs), producing inf; fix by clamping the max to a
small positive epsilon on the same device before dividing: compute denom =
a.float().abs().nan_to_num().max(), then replace denom with
denom.clamp_min(eps_tensor) where eps_tensor = torch.tensor(1e-6,
device=a.device, dtype=denom.dtype) (or use torch.finfo(denom.dtype).eps), then
compute a_global_sf = (448 * 6) / denom_clamped and pass that into fp4_quantize
(same a.cuda(), a_global_sf.cuda(), 32, True, True).
- Around line 459-484: The fake op
_fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100 currently returns output
in (l, m, k//2) order while the eager path yields (m, k//2, l); before
returning, permute output to match the eager logical layout (e.g., output =
output.permute(1, 2, 0)). Also ensure output_scales is permuted to the exact
same final layout the eager implementation exposes (adjust or add a final
.permute(...) on output_scales to match the eager caller expectation) so both
outputs have identical shapes/order between eager and compiled/fake paths.
- Around line 301-307: The fake implementation
_fake_block_scale_interleave_sm100 currently returns a hard-coded 1-D uint8
tensor sized as if input were 2-D, which breaks shape/dtype inference for 3-D
inputs and non-uint8 dtypes; update the register_fake_op implementation to
mirror the eager function's behavior by using unswizzled_sf.dtype (not
torch.uint8) and compute the output length from unswizzled_sf.shape handling
both 2-D and 3-D (e.g., multiply leading dims then divide by 16) so the fake op
returns the same flat buffer shape and dtype used by the real operator for
compile-time inference.
- Around line 228-240: The fake op _fake_fp4_quantize_sm100 must exactly mirror
the real op signature and output types: add the missing parameters
is_sf_8x4_layout: bool = False and enable_pdl: bool = False to the function
signature (keep existing is_sf_swizzled_layout), change the first returned
tensor to dtype=torch.uint8 with shape [m, k // 2] and change the scale-factors
tensor to dtype=torch.uint8 and sized to account for padded SF vectors using
sf_count = (k + sf_vec_size - 1) // sf_vec_size so the second tensor is
input.new_empty([m * sf_count], dtype=torch.uint8); keep other argument defaults
the same so torch.compile infers the correct schema and metadata for
fp4_quantize.

In `@tests/utils/test_fp4_quantize.py`:
- Line 172: The skip guards call _is_fp4_supported(torch.device("cuda")) which
can probe the wrong GPU; change each occurrence to use the parameterized device
(i.e., _is_fp4_supported(torch.device(device))) so the checks respect the test's
device parameterization—replace all instances in this file (including the guards
around test_fp4_quantization and the other two occurrences) to use
torch.device(device) instead of torch.device("cuda").

---

Duplicate comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 608-611: The call to module.e2m1_and_ufp8sf_scale_to_float_sm100
dereferences global_scale_tensor.cpu() but the public API may pass
global_scale_tensor=None; guard that before calling .cpu() by computing a local
value (e.g. global_scale_cpu = global_scale_tensor.cpu() if global_scale_tensor
is not None else None) and pass global_scale_cpu to
e2m1_and_ufp8sf_scale_to_float_sm100; do the same pattern if any other tensor
arguments may be None (keep e2m1_tensor.cpu() and
ufp8_scale_tensor.cpu().reshape(-1) unchanged).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: b4412e6b-338d-4618-8f83-45476dee8435

📥 Commits

Reviewing files that changed from the base of the PR and between f75a24e and 6355473.

📒 Files selected for processing (5)
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • flashinfer/__init__.py
  • flashinfer/fp4_quantization.py
  • flashinfer/quantization/fp4_quantization.py
  • tests/utils/test_fp4_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/__init__.py
  • benchmarks/routines/flashinfer_benchmark_utils.py

Comment on lines +228 to +240
@register_fake_op("flashinfer::fp4_quantize_sm100")
def _fake_fp4_quantize_sm100(
    input: torch.Tensor,
    global_scale: Optional[torch.Tensor] = None,
    sf_vec_size: int = 16,
    sf_use_ue8m0: bool = False,
    is_sf_swizzled_layout: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    m, k = input.shape
    return (
        input.new_empty([m, k // 2], dtype=torch.int64),  # FLOAT4_E2M1X2
        input.new_empty([m * k // sf_vec_size], dtype=torch.int32),  # Scale factors
    )

⚠️ Potential issue | 🟠 Major

Make _fake_fp4_quantize_sm100 mirror the real op.

This fake op omits is_sf_8x4_layout and enable_pdl, returns int64/int32 instead of uint8/uint8, and always uses the unpadded SF size. fp4_quantize() calls the real op with both extra arguments on Lines 693-700 and defaults to swizzled SFs, so torch.compile will infer the wrong schema and metadata here.

🧩 Suggested fix
 @register_fake_op("flashinfer::fp4_quantize_sm100")
 def _fake_fp4_quantize_sm100(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
     sf_vec_size: int = 16,
     sf_use_ue8m0: bool = False,
     is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    enable_pdl: Optional[bool] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    m, k = input.shape
+    m = input.numel() // input.shape[-1]
+    k = input.shape[-1]
+    if is_sf_swizzled_layout:
+        out_sf_size = _compute_swizzled_layout_sf_size(
+            m, k // sf_vec_size, 8 if is_sf_8x4_layout else 128
+        )
+    else:
+        out_sf_size = m * k // sf_vec_size
     return (
-        input.new_empty([m, k // 2], dtype=torch.int64),  # FLOAT4_E2M1X2
-        input.new_empty([m * k // sf_vec_size], dtype=torch.int32),  # Scale factors
+        input.new_empty((*input.shape[:-1], k // 2), dtype=torch.uint8),
+        input.new_empty((out_sf_size,), dtype=torch.uint8),
     )

Based on learnings, functions decorated with register_fake_op are abstract implementations for torch.compile shape/dtype inference, and their signatures must exactly mirror the corresponding real op.

🧰 Tools
🪛 Ruff (0.15.5)

[warning] 231-231: Unused function argument: global_scale

(ARG001)


[warning] 233-233: Unused function argument: sf_use_ue8m0

(ARG001)


[warning] 234-234: Unused function argument: is_sf_swizzled_layout

(ARG001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 228 - 240, The fake
op _fake_fp4_quantize_sm100 must exactly mirror the real op signature and output
types: add the missing parameters is_sf_8x4_layout: bool = False and enable_pdl:
bool = False to the function signature (keep existing is_sf_swizzled_layout),
change the first returned tensor to dtype=torch.uint8 with shape [m, k // 2] and
change the scale-factors tensor to dtype=torch.uint8 and sized to account for
padded SF vectors using sf_count = (k + sf_vec_size - 1) // sf_vec_size so the
second tensor is input.new_empty([m * sf_count], dtype=torch.uint8); keep other
argument defaults the same so torch.compile infers the correct schema and
metadata for fp4_quantize.
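The padded sizing the suggested fix relies on can be checked with plain integer arithmetic. Below is a hypothetical stand-in for the `_compute_swizzled_layout_sf_size` helper the diff references (the 128-row / 4-column swizzle tile values are assumptions taken from the review comments, not verified against the codebase):

```python
def swizzled_sf_size(m: int, sf_cols: int, row_tile: int = 128, col_tile: int = 4) -> int:
    """Size of the flat scale-factor buffer after padding rows and columns
    up to the swizzle tile boundaries (ceil-division, then multiply back)."""
    padded_m = (m + row_tile - 1) // row_tile * row_tile
    padded_cols = (sf_cols + col_tile - 1) // col_tile * col_tile
    return padded_m * padded_cols

# A 5x96 input with sf_vec_size=16 has 6 SF columns; both dims get padded.
size = swizzled_sf_size(5, 96 // 16)
```

This is why the fake op must not size the buffer as the unpadded `m * k // sf_vec_size`: the real kernel writes into the padded region as well.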

Comment on lines +459 to +484
@register_fake_op("flashinfer::silu_and_mul_scaled_nvfp4_experts_quantize_sm100")
def _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100(
    input: torch.Tensor,
    mask: torch.Tensor,
    global_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    device = input.device
    l, m, k_by_2 = input.shape
    k = k_by_2 // 2
    sf_vec_size = 16
    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."

    scale_k = k // sf_vec_size
    padded_k = (scale_k + (4 - 1)) // 4 * 4
    padded_k_int32 = padded_k // 4
    padded_m = (m + (128 - 1)) // 128 * 128
    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
    output_scales = torch.empty(
        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
    )

    output_scales = output_scales.view(torch.float8_e4m3fn).view(
        l, padded_m // 128, padded_k // 4, 32, 4, 4
    )
    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
    return (output, output_scales)

⚠️ Potential issue | 🟠 Major

Return the same logical layout from the fake expert-quantize op.

The eager path permutes output to (m, k // 2, l) on Line 452, but the fake path returns the unpermuted (l, m, k // 2) buffer. Any compiled caller will observe a different output shape than eager mode.

🧩 Suggested fix
     output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
     output_scales = torch.empty(
         l, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
 
+    output = output.permute(1, 2, 0)
     output_scales = output_scales.view(torch.float8_e4m3fn).view(
         l, padded_m // 128, padded_k // 4, 32, 4, 4
     )
     output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
     return (output, output_scales)

Based on learnings, functions decorated with register_fake_op are abstract implementations for torch.compile shape/dtype inference.

🧰 Tools
🪛 Ruff (0.15.5)

[warning] 462-462: Unused function argument: mask

(ARG001)


[warning] 463-463: Unused function argument: global_scale

(ARG001)


[error] 466-466: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 459 - 484, The fake
op _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100 currently returns
output in (l, m, k//2) order while the eager path yields (m, k//2, l); before
returning, permute output to match the eager logical layout (e.g., output =
output.permute(1, 2, 0)). Also ensure output_scales is permuted to the exact
same final layout the eager implementation exposes (adjust or add a final
.permute(...) on output_scales to match the eager caller expectation) so both
outputs have identical shapes/order between eager and compiled/fake paths.
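The layout mismatch the review flags can be verified on plain shape tuples without any tensors. This hypothetical helper (not part of the codebase) shows how `permute(1, 2, 0)` maps the `(l, m, k // 2)` buffer onto the eager `(m, k // 2, l)` layout:

```python
def permuted_shape(shape, perm):
    """Shape that tensor.permute(*perm) would produce, computed on a plain tuple."""
    return tuple(shape[p] for p in perm)

# (l, m, k // 2) as allocated inside the fake op (illustrative sizes):
l_dim, m, half_k = 4, 64, 32
eager_shape = permuted_shape((l_dim, m, half_k), (1, 2, 0))  # matches eager (m, k//2, l)
```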

Comment on lines +931 to +934
elif backend == "cuda":
    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
    a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
    return a_fp4, a_sf

⚠️ Potential issue | 🟠 Major

Clamp the MXFP4 scale denominator for zero/NaN-only inputs.

If a is all zeros or all NaNs, a.float().abs().nan_to_num().max() becomes 0, so this computes an infinite a_global_sf. That bad scale then flows straight into fp4_quantize.

🧩 Suggested fix
     elif backend == "cuda":
-        a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
+        max_abs = a.float().abs().nan_to_num().max()
+        a_global_sf = (448 * 6) / max_abs.clamp_min(torch.finfo(max_abs.dtype).tiny)
         a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
         return a_fp4, a_sf
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/fp4_quantization.py` around lines 931 - 934, The
denominator used to compute a_global_sf can be zero (for all-zero or all-NaN
inputs), producing inf; fix by clamping the max to a small positive epsilon on
the same device before dividing: compute denom =
a.float().abs().nan_to_num().max(), then replace denom with
denom.clamp_min(eps_tensor) where eps_tensor = torch.tensor(1e-6,
device=a.device, dtype=denom.dtype) (or use torch.finfo(denom.dtype).eps), then
compute a_global_sf = (448 * 6) / denom_clamped and pass that into fp4_quantize
(same a.cuda(), a_global_sf.cuda(), 32, True, True).
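The failure mode and the fix are visible without torch at all. A minimal pure-Python sketch of the clamped scale (the `1e-38` floor is an illustrative stand-in for `torch.finfo(...).tiny`):

```python
import math

FP4_SCALE_NUMERATOR = 448 * 6  # max-representable product used by the benchmark helper

def global_scale(max_abs: float, floor: float = 1e-38) -> float:
    """Global scale factor; clamping the denominator keeps all-zero inputs finite."""
    return FP4_SCALE_NUMERATOR / max(max_abs, floor)

# Unclamped 2688 / 0.0 would overflow to inf; the clamped version stays finite.
assert math.isfinite(global_scale(0.0))
```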

or is_sm110a_supported(torch.device("cuda"))
or is_sm12x_supported(torch.device("cuda"))
):
if not _is_fp4_supported(torch.device("cuda")):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Use the parameterized device in these skip guards.

These checks run before torch.set_default_device(device), so torch.device("cuda") can probe the wrong GPU on multi-device or heterogeneous hosts. Use torch.device(device) consistently, as you already do in test_fp4_quantization.

Suggested fix
-    if not _is_fp4_supported(torch.device("cuda")):
+    if not _is_fp4_supported(torch.device(device)):

Also applies to: 208-208, 257-257

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/utils/test_fp4_quantize.py` at line 172, The skip guards call
_is_fp4_supported(torch.device("cuda")) which can probe the wrong GPU; change
each occurrence to use the parameterized device (i.e.,
_is_fp4_supported(torch.device(device))) so the checks respect the test's device
parameterization—replace all instances in this file (including the guards around
test_fp4_quantization and the other two occurrences) to use torch.device(device)
instead of torch.device("cuda").

Collaborator

@nv-yunzheq left a comment

LGTM

# =============================================================================


@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
Collaborator

Do we need to have these specific tests in the unit test file?

@bkryu
Collaborator Author

bkryu commented Mar 17, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !283 has been updated with latest changes, and the CI pipeline #46363680 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
flashinfer/quantization/kernels/mxfp8_quantize.py (2)

690-695: Consider adding explicit validation for input tensor dimensions.

The reshaping logic handles dim() > 2 by flattening to 2D, and the else branch assumes exactly 2 dimensions. If a 1D tensor is passed, m, k = input.shape on line 695 will raise a ValueError with a confusing message.

Consider adding an explicit check:

🛠️ Suggested validation
+    assert input.dim() >= 2, f"Input must be at least 2D, got {input.dim()}D tensor"
+
     if input.dim() > 2:
         m = input.numel() // input.shape[-1]
         k = input.shape[-1]
         input = input.reshape(m, k)
     else:
         m, k = input.shape
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/mxfp8_quantize.py` around lines 690 - 695,
The current reshape logic assumes input.dim() == 2 or >2 and will crash with a
confusing ValueError if a 1D tensor is passed; add an explicit validation before
the reshape branch that checks input.dim() and raises a clear ValueError (or
documents/handles 1D inputs) if dim < 2, e.g., validate input.dim() >= 2 and
include the tensor shape in the error message; update the block that computes m
and k (the variables input, m, k and the reshape behavior) so callers get a
deterministic error instead of an ambiguous unpacking failure.
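The flattening rule itself is just shape arithmetic. This hypothetical helper mirrors the branch the comment describes and raises a clear error for the 1-D case instead of an ambiguous unpacking failure:

```python
def collapse_to_2d(shape: tuple) -> tuple:
    """Collapse all leading dims into M, keep the last dim as K; reject <2-D shapes."""
    if len(shape) < 2:
        raise ValueError(f"Input must be at least 2D, got shape {shape}")
    k = shape[-1]
    m = 1
    for d in shape[:-1]:
        m *= d  # product of leading dims becomes the flattened row count
    return (m, k)
```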

761-767: Consider sorting __all__ for consistency.

Static analysis suggests applying isort-style sorting. This is purely stylistic and optional.

🔧 Optional sort
 __all__ = [
     "MXFP8QuantizeLinearKernel",
     "MXFP8QuantizeSwizzledKernel",
+    "_get_compiled_kernel_linear",
+    "_get_compiled_kernel_swizzled",
     "mxfp8_quantize_cute_dsl",
-    "_get_compiled_kernel_linear",
-    "_get_compiled_kernel_swizzled",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/mxfp8_quantize.py` around lines 761 - 767,
The __all__ list is unsorted; please alphabetically sort the exported symbol
names in the __all__ list (e.g., "MXFP8QuantizeLinearKernel",
"MXFP8QuantizeSwizzledKernel", "mxfp8_quantize_cute_dsl",
"_get_compiled_kernel_linear", "_get_compiled_kernel_swizzled") so the order is
consistent with isort-style conventions and easier to maintain.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 487-520: The K validation must ensure K is non-zero and that
num_sf_blocks_per_row = K // MXFP4_SF_VEC_SIZE is divisible by 4 before
compiling/using kernels; update the early assertions/checks (near the existing
assert for MXFP4_SF_VEC_SIZE) to raise a clear error when K == 0 and when
num_sf_blocks_per_row % 4 != 0 so downstream logic in _get_compiled_kernel_mxfp4
and _compute_optimal_threads_for_k won't divide by zero or produce an
unmatchable padded_sf_cols/reshape; reference MXFP4_SF_VEC_SIZE,
num_sf_blocks_per_row, padded_sf_cols, _get_compiled_kernel_mxfp4 and
_compute_optimal_threads_for_k when making the change.
- Around line 443-491: The function mxfp4_quantize_cute_dsl is shadowing
Python's built-in input by using the parameter/local variable name "input";
rename that parameter and all local references (e.g., the reshaping/contiguous
uses and device checks) to a non-conflicting name like "tensor" or "src"
throughout the function (and update the docstring parameter name) to eliminate
the Ruff A002/A001 warnings while preserving all existing logic and assertions
(retain checks for dtype, is_cuda, PDL detection, shape handling,
MXFP4_SF_VEC_SIZE assertion, and the final contiguous call).

---

Nitpick comments:
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 690-695: The current reshape logic assumes input.dim() == 2 or >2
and will crash with a confusing ValueError if a 1D tensor is passed; add an
explicit validation before the reshape branch that checks input.dim() and raises
a clear ValueError (or documents/handles 1D inputs) if dim < 2, e.g., validate
input.dim() >= 2 and include the tensor shape in the error message; update the
block that computes m and k (the variables input, m, k and the reshape behavior)
so callers get a deterministic error instead of an ambiguous unpacking failure.
- Around line 761-767: The __all__ list is unsorted; please alphabetically sort
the exported symbol names in the __all__ list (e.g.,
"MXFP8QuantizeLinearKernel", "MXFP8QuantizeSwizzledKernel",
"mxfp8_quantize_cute_dsl", "_get_compiled_kernel_linear",
"_get_compiled_kernel_swizzled") so the order is consistent with isort-style
conventions and easier to maintain.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1299c992-5494-49f9-9fef-84370de5f153

📥 Commits

Reviewing files that changed from the base of the PR and between 6f6e61e and 8df3348.

📒 Files selected for processing (2)
  • flashinfer/quantization/kernels/mxfp4_quantize.py
  • flashinfer/quantization/kernels/mxfp8_quantize.py

Comment on lines +443 to +491
def mxfp4_quantize_cute_dsl(
    input: torch.Tensor,
    enable_pdl: bool | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to MXFP4 format using CuTe-DSL kernel.

    This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
    - Global scale computed as (448 * 6) / max(|input|)
    - UE8M0 scale factors
    - E2M1 output format (4-bit, 2 values per byte)
    - Swizzled (128x4) scale factor layout

    The kernel is compiled once per (K, dtype, pdl) combination and handles
    varying M (batch size) at runtime without recompilation.

    Args:
        input: Input tensor of shape [M, K] with dtype fp16/bf16
        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
            If None, automatically detects based on device capability (SM >= 9.0).

    Returns:
        Tuple of:
        - fp4_tensor: Quantized tensor of shape [M, K/2] with dtype uint8
        - scale_tensor: Scale factors as uint8 tensor (swizzled layout)
    """
    from ...utils import device_support_pdl

    assert input.dtype in (torch.float16, torch.bfloat16), (
        f"Input dtype must be float16 or bfloat16, got {input.dtype}"
    )
    assert input.is_cuda, "Input must be on CUDA device"

    # Auto-detect PDL support based on device capability
    if enable_pdl is None:
        enable_pdl = device_support_pdl(input.device)

    if input.dim() > 2:
        m = input.numel() // input.shape[-1]
        k = input.shape[-1]
        input = input.reshape(m, k)
    else:
        m, k = input.shape

    assert k % MXFP4_SF_VEC_SIZE == 0, (
        f"K ({k}) must be divisible by MXFP4_SF_VEC_SIZE={MXFP4_SF_VEC_SIZE}"
    )

    input = input.contiguous()

⚠️ Potential issue | 🟡 Minor

Avoid shadowing Python’s built-in input.

Line 444 (parameter) and reassignment at Lines 483/491 shadow the built-in input, which is currently flagged by Ruff (A002/A001).

Proposed fix
-def mxfp4_quantize_cute_dsl(
-    input: torch.Tensor,
+def mxfp4_quantize_cute_dsl(
+    x: torch.Tensor,
     enable_pdl: bool | None = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@
-    assert input.dtype in (torch.float16, torch.bfloat16), (
-        f"Input dtype must be float16 or bfloat16, got {input.dtype}"
+    assert x.dtype in (torch.float16, torch.bfloat16), (
+        f"Input dtype must be float16 or bfloat16, got {x.dtype}"
     )
-    assert input.is_cuda, "Input must be on CUDA device"
+    assert x.is_cuda, "Input must be on CUDA device"
@@
-        enable_pdl = device_support_pdl(input.device)
+        enable_pdl = device_support_pdl(x.device)
@@
-    if input.dim() > 2:
-        m = input.numel() // input.shape[-1]
-        k = input.shape[-1]
-        input = input.reshape(m, k)
+    if x.dim() > 2:
+        m = x.numel() // x.shape[-1]
+        k = x.shape[-1]
+        x = x.reshape(m, k)
     else:
-        m, k = input.shape
+        m, k = x.shape
@@
-    input = input.contiguous()
-    is_bfloat16 = input.dtype == torch.bfloat16
+    x = x.contiguous()
+    is_bfloat16 = x.dtype == torch.bfloat16
@@
-    target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM
+    target_grid = get_num_sm(x.device) * _BLOCKS_PER_SM
@@
-    fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device)
+    fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=x.device)
     scale_output = torch.empty(
-        scale_output_size, dtype=torch.uint8, device=input.device
+        scale_output_size, dtype=torch.uint8, device=x.device
     )
@@
-    kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks)
+    kernel_fn(x, fp4_output, scale_output, m, padded_m, num_blocks)
🧰 Tools
🪛 Ruff (0.15.6)

[error] 444-444: Function argument input is shadowing a Python builtin

(A002)


[error] 483-483: Variable input is shadowing a Python builtin

(A001)


[error] 491-491: Variable input is shadowing a Python builtin

(A001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/mxfp4_quantize.py` around lines 443 - 491,
The function mxfp4_quantize_cute_dsl is shadowing Python's built-in input by
using the parameter/local variable name "input"; rename that parameter and all
local references (e.g., the reshaping/contiguous uses and device checks) to a
non-conflicting name like "tensor" or "src" throughout the function (and update
the docstring parameter name) to eliminate the Ruff A002/A001 warnings while
preserving all existing logic and assertions (retain checks for dtype, is_cuda,
PDL detection, shape handling, MXFP4_SF_VEC_SIZE assertion, and the final
contiguous call).

Comment on lines +487 to +520
    assert k % MXFP4_SF_VEC_SIZE == 0, (
        f"K ({k}) must be divisible by MXFP4_SF_VEC_SIZE={MXFP4_SF_VEC_SIZE}"
    )

    input = input.contiguous()
    is_bfloat16 = input.dtype == torch.bfloat16

    # Cached device-specific target grid for grid size computation
    target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM

    # Compute M-dependent values
    num_sf_blocks_per_row = k // MXFP4_SF_VEC_SIZE
    padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE
    padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
    scale_output_size = padded_m * padded_sf_cols

    # Get or compile kernel (device-independent)
    kernel_fn, rows_per_block = _get_compiled_kernel_mxfp4(is_bfloat16, k, enable_pdl)

    # Compute grid size in Python (runtime, device-specific)
    num_blocks = min((padded_m + rows_per_block - 1) // rows_per_block, target_grid)

    # Allocate outputs
    fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device)
    scale_output = torch.empty(
        scale_output_size, dtype=torch.uint8, device=input.device
    )

    # Launch kernel
    kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks)

    # Reshape scale output to match CUDA backend format: [padded_total, num_sf_per_row]
    scale_output = scale_output.reshape(-1, num_sf_blocks_per_row)


⚠️ Potential issue | 🔴 Critical

Validate K for swizzled-column compatibility before compiling.

At Line 487, the check allows any K % 32 == 0, but Line 500 pads scale columns to multiples of 4 and Line 519 reshapes using unpadded num_sf_blocks_per_row. For K/32 not divisible by 4, this can break reshape semantics; for K=0, _compute_optimal_threads_for_k hits division by zero at Line 90.

Proposed fix
-    assert k % MXFP4_SF_VEC_SIZE == 0, (
-        f"K ({k}) must be divisible by MXFP4_SF_VEC_SIZE={MXFP4_SF_VEC_SIZE}"
-    )
+    if k <= 0 or k % MXFP4_SF_VEC_SIZE != 0:
+        raise ValueError(
+            f"K ({k}) must be a positive multiple of {MXFP4_SF_VEC_SIZE}"
+        )
+    # Swizzled 128x4 layout requires 4 scale-factor blocks per swizzle group.
+    if (k // MXFP4_SF_VEC_SIZE) % 4 != 0:
+        raise ValueError(
+            "CuTe-DSL MXFP4 swizzled backend currently requires K divisible by 128."
+        )
🧰 Tools
🪛 Ruff (0.15.6)

[error] 491-491: Variable input is shadowing a Python builtin

(A001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/mxfp4_quantize.py` around lines 487 - 520,
The K validation must ensure K is non-zero and that num_sf_blocks_per_row = K //
MXFP4_SF_VEC_SIZE is divisible by 4 before compiling/using kernels; update the
early assertions/checks (near the existing assert for MXFP4_SF_VEC_SIZE) to
raise a clear error when K == 0 and when num_sf_blocks_per_row % 4 != 0 so
downstream logic in _get_compiled_kernel_mxfp4 and
_compute_optimal_threads_for_k won't divide by zero or produce an unmatchable
padded_sf_cols/reshape; reference MXFP4_SF_VEC_SIZE, num_sf_blocks_per_row,
padded_sf_cols, _get_compiled_kernel_mxfp4 and _compute_optimal_threads_for_k
when making the change.
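Put together, the validation the comment asks for is a few lines of integer checks. A sketch under the assumption (taken from the proposed fix) that the swizzled layout needs 4 scale-factor blocks per group, i.e. K divisible by 128:

```python
MXFP4_SF_VEC_SIZE = 32  # elements covered by one MXFP4 scale factor

def validate_k(k: int) -> None:
    """Reject K values the swizzled CuTe-DSL MXFP4 path cannot handle."""
    if k <= 0 or k % MXFP4_SF_VEC_SIZE != 0:
        raise ValueError(
            f"K ({k}) must be a positive multiple of {MXFP4_SF_VEC_SIZE}"
        )
    if (k // MXFP4_SF_VEC_SIZE) % 4 != 0:
        raise ValueError(
            f"K ({k}) must be divisible by {4 * MXFP4_SF_VEC_SIZE} for the swizzled layout"
        )
```

Raising before kernel compilation turns the downstream division-by-zero and reshape mismatches into a single deterministic error at the API boundary.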

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #46363680: 14/20 passed

@bkryu
Collaborator Author

bkryu commented Mar 18, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !283 has been updated with latest changes, and the CI pipeline #46401068 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu merged commit 8e53cce into flashinfer-ai:main Mar 18, 2026
28 of 33 checks passed
frankwang28 pushed a commit to frankwang28/flashinfer that referenced this pull request Mar 18, 2026
…lashinfer-ai#2443)

<!-- .github/pull_request_template.md -->

## 📌 Description
This PR adds CuTe-DSL backend support for MXFP8 and MXFP4 quantization
kernels as alternatives to JIT-compiled CUDA backends

Key changes:
- Add CuTe-DSL MXFP8 and MXFP4 quantization kernels
- Reorganize quantization module structure for better maintainability
- Add benchmarks and unit tests for backend comparison

**File Structure Reorganization**
Quantization files are now organized in `flashinfer/quantization/`:
```
flashinfer/quantization/
├── __init__.py                    # Package exports
├── fp4_quantization.py            # MXFP4 public API
├── fp8_quantization.py            # MXFP8 public API  
├── packbits.py                    # Utility functions
├── quantization_cute_dsl_utils.py # Shared PTX intrinsics
└── kernels/
    ├── __init__.py                # Kernel exports (EXPERIMENTAL)
    ├── mxfp4_quantize.py          # MXFP4 CuTe-DSL kernel
    └── mxfp8_quantize.py          # MXFP8 CuTe-DSL kernel
```

**Performance**
CuTe DSL kernels are strong compared to CUDA counterparts:
- mxfp4_quantization - Geomean 12x speedup; beats cuda backend in all
cases in `bench_mxfp4_quantize_backend_comparison.py`
- mxfp8_quantization - Geomean ~1.3x speedup; beats cuda backend in all
cases in `bench_mxfp8_quantize_backend_comparison.py`

Expand below for performance heatmaps:

<details>
<summary>CuTe DSL Backend outperforms CUDA backend on every single case
benchmarked in bench_mxfp8_quantize_backend_comparison.py. Click to see
performance comparison data</summary>


**BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster**
<img width="1644" height="1477" alt="sm100_mxfp8_swizzled_bfloat16"
src="https://github.com/user-attachments/assets/107279a6-8fc4-4aba-843d-34a83a12acb0"
/>

**BF16 input; Linear cases. > 1.0 means CuTe DSL is faster**
<img width="1644" height="1477" alt="sm100_mxfp8_linear_bfloat16"
src="https://github.com/user-attachments/assets/1317ab55-c9ac-4284-bf9a-5127070fe0ad"
/>

**BF16 input; Swizzled cases. Annotated values are achieved TB/s**
<img width="1646" height="1481"
alt="sm100_mxfp8_bandwidth_linear_bfloat16"
src="https://github.com/user-attachments/assets/033e0692-2eef-4ff7-95f6-94a1d098dbe7"
/>

**BF16 input; Linear cases. Annotated values are achieved TB/s**
<img width="1646" height="1481"
alt="sm100_mxfp8_bandwidth_swizzled_bfloat16"
src="https://github.com/user-attachments/assets/543f7cd2-0d3a-4f7b-b465-7423f1738d9c"
/>


</details>

<details>
<summary>CuTe DSL Backend outperforms CUDA backend on every single case
benchmarked in ‎bench_mxfp4_quantize_backend_comparison.py. Click to see
performance comparison data</summary>

**BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster**
<img width="1658" height="1477" alt="sm100_mxfp4_comparison_bfloat16"
src="https://github.com/user-attachments/assets/bbaae310-581a-4035-9e06-0c437263da55"
/>


**BF16 input; Swizzled cases. Annotated values are achieved TB/s**
<img width="1646" height="1481" alt="sm100_mxfp4_bandwidth_bfloat16"
src="https://github.com/user-attachments/assets/d7798935-2112-4b73-b127-4095fede8b18"
/>


</details>

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

flashinfer-ai#2496
<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
  * CuTe-DSL backend added for MXFP8 and MXFP4 quantization alongside CUDA.
  * Consolidated quantization package exposing unified FP4/FP8 interfaces and conditional CuTe-DSL exports.
  * New end-to-end benchmarking tools for MXFP4 and MXFP8 (correctness, performance, bandwidth, heatmaps).

* **Bug Fixes / Compatibility**
  * Backwards-compatible shims preserve the existing public API while delegating implementations to the new package.

* **Tests**
  * Expanded tests cover CUDA and CuTe-DSL backends, availability gating, the compilation cache, and backend parity.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
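
The backwards-compatible shims mentioned above follow a standard module-forwarding pattern; a generic sketch of that pattern using PEP 562 module-level `__getattr__` (this is not FlashInfer's actual shim code, and `legacy_math` is a made-up module name for the demo):

```python
import sys
import types
import warnings

def install_shim(old_name: str, new_module: types.ModuleType) -> types.ModuleType:
    """Register a module `old_name` whose attribute lookups forward to
    `new_module`, so old import paths keep working after a file move."""
    shim = types.ModuleType(old_name)

    def __getattr__(attr: str):
        warnings.warn(
            f"{old_name}.{attr} has moved to {new_module.__name__}.{attr}",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(new_module, attr)  # delegate to the new location

    shim.__getattr__ = __getattr__  # PEP 562: module-level __getattr__
    sys.modules[old_name] = shim
    return shim

# Demo with the stdlib `math` module standing in for the relocated package.
import math
install_shim("legacy_math", math)
import legacy_math
print(legacy_math.sqrt(9.0))  # forwarded to math.sqrt
```

Forwarding lazily through `__getattr__` (rather than `from ... import *`) keeps the shim cheap to import and lets it warn on each deprecated access.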
ameynaik-hub pushed a commit to ameynaik-hub/flashinfer that referenced this pull request Mar 18, 2026
…lashinfer-ai#2443)

Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>