Add cute-dsl backends to mxfp[8,4]_quantization for future refactor #2443

bkryu merged 17 commits into flashinfer-ai:main
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Reorganizes quantization into a new `flashinfer/quantization/` package.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Bench as Benchmark/Test
    participant API as flashinfer.quantization
    participant Kernel as Kernel (cute-dsl / cuda)
    participant Device as GPU Device/Driver

    Bench->>API: call mxfp8_quantize(..., backend)
    API->>Kernel: select backend, determine enable_pdl, compile or fetch cached kernel
    Kernel->>Device: launch compiled kernel / call CUDA kernel
    Device-->>Kernel: execution complete (writes outputs)
    Kernel-->>API: return (quantized_tensor, scales)
    API-->>Bench: return results
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refactors the MXFP8 quantization implementation by introducing a new, highly optimized backend based on CuTe-DSL. This change provides an alternative, potentially more performant path for quantization operations, enhancing the flexibility and efficiency of the FlashInfer library. The integration ensures that users can seamlessly switch between CUDA and CuTe-DSL implementations, while comprehensive testing validates the correctness and caching mechanisms of the new kernels.
Code Review
This pull request introduces a new cute-dsl backend for MXFP8 quantization, refactoring the existing CUDA implementation. The changes are well-structured, adding new CuTe-DSL kernels for both linear and swizzled layouts, and updating the public API, benchmarks, and tests accordingly. The new kernels correctly use M-agnostic compilation for better performance with varying batch sizes. My review includes a couple of suggestions to improve the maintainability of the new kernel code by explaining a magic number and refactoring a duplicated logic block. The accompanying test updates are comprehensive and include valuable checks for the compilation cache behavior.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@flashinfer/cute_dsl/mxfp8_quantize.py`:
- Around line 585-635: The code flattens inputs when input.dim()>2 but never
restores batch dimensions or uses orig_shape; after computing
fp8_output/scale_output and before returning, reshape fp8_tensor and
scale_output back to the original batch shape using orig_shape: for fp8_tensor,
view/reshape to (*orig_shape[:-1], padded_k); for scale_output, convert the 1D
buffer into per-row blocks and then reshape to (*orig_shape[:-1],
num_sf_blocks_per_row) for the linear path (use total_sf_blocks -> view(m,
num_sf_blocks_per_row)), and for the swizzled path convert scale_output via
view(padded_m, padded_sf_cols) then take the first m rows ([:m,
:padded_sf_cols]) and reshape to (*orig_shape[:-1], padded_sf_cols); ensure you
reference orig_shape, padded_k, m, num_sf_blocks_per_row, padded_m and
padded_sf_cols when making these changes.
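The shape bookkeeping this fix asks for can be illustrated in pure Python. The helper below is hypothetical; `orig_shape`, `padded_k`, and `num_sf_blocks_per_row` follow the names used in the comment above, and only the linear-path scale shape is shown.

```python
# Hypothetical sketch: restore batch dimensions after a (*batch, K) -> (M, K)
# flatten, as the fix above describes for the linear scale layout.
def restore_batch_shapes(orig_shape, padded_k, num_sf_blocks_per_row):
    m = 1
    for d in orig_shape[:-1]:
        m *= d  # number of rows after flattening
    fp8_shape = (*orig_shape[:-1], padded_k)
    scale_shape = (*orig_shape[:-1], num_sf_blocks_per_row)
    return m, fp8_shape, scale_shape
```

In the real code these tuples would feed `fp8_tensor.reshape(...)` and the `scale_output.view(m, num_sf_blocks_per_row)` step before the final reshape.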
In `@flashinfer/cute_dsl/quantization_utils.py`:
- Around line 22-23: Remove the unused Uint8 import from the top-level imports
in quantization_utils.py: update the import line that currently reads "from
cutlass import Float32, Int32, Uint32, Uint64, Uint8" to exclude Uint8 so only
used symbols (Float32, Int32, Uint32, Uint64) remain; this will resolve the F401
lint error while leaving functions/classes that reference
Float32/Int32/Uint32/Uint64 untouched.
🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)
203-210: Silence unused `a_sf` warnings in denormal/zero/mixed tests. Ruff flags `a_sf` as unused in several tests. Consider replacing it with `_` (or `_a_sf`) to avoid lint noise; the same pattern applies to the other occurrences in this file.

♻️ Example fix

```diff
-    a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)
+    a_fp8, _ = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend)
```
/bot run

/bot stop

The GitLab CI pipeline #42939528 has been cancelled.

The GitLab CI pipeline #43311884 has been cancelled.

/bot run

[FAILED] Pipeline #43313609: 14/20 passed
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@flashinfer/quantization/__init__.py`:
- Around line 56-86: The __all__ list in flashinfer/quantization/__init__.py is
unsorted (RUF022); either alphabetically sort the symbols in the __all__ list
(e.g., ensure entries like "block_scale_interleave",
"e2m1_and_ufp8sf_scale_to_float", "get_fp4_quantization_module",
"mxfp4_quantize", "mxfp8_quantize_cute_dsl", etc. are in ASCII order) or if the
current grouped ordering is intentional add an explicit ruff suppression for
RUF022 (e.g., a module-level ruff noqa for RUF022) so the linter is satisfied.
Ensure the change targets the __all__ variable and preserves the conditional
addition when _cute_dsl_available is true.
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 413-441: The variable name `l` is ambiguous (ruff E741); rename it
to a descriptive identifier (e.g., `batch_size` or `num_groups`) throughout this
function and the related functions to fix the lint error and improve
readability: update the unpacking line `l, m, k_by_2 = input.shape` to
`batch_size, m, k_by_2 = input.shape` (or `num_groups`), update all subsequent
uses of `l` (construction of `output`, `output_scales`, reshapes, the call to
module.silu_and_mul_scaled_nvfp4_experts_quantize, and the final permute/view
lines), and apply the same consistent rename in the other affected blocks (the
functions around lines 449-468, 498-530, 538-557) so every reference (e.g.,
output.view(l * m, ...), output_scales.view(..., l, ...), return output,
output_scales) uses the new identifier.
- Around line 563-600: The call to module.e2m1_and_ufp8sf_scale_to_float_sm100
unconditionally calls global_scale_tensor.cpu() but global_scale_tensor is
optional; guard it and pass a CPU tensor when it's None. Update the invocation
in e2m1_and_ufp8sf_scale_to_float_sm100 so that you pass
(global_scale_tensor.cpu() if global_scale_tensor is not None else a default CPU
float32 tensor, e.g. torch.tensor([1.0], dtype=torch.float32, device='cpu')),
ensuring the default matches the expected shape/dtype the custom op requires.
In `@flashinfer/quantization/fp8_quantization.py`:
- Around line 91-101: The fake op function _fake_mxfp8_quantize_sm100 is missing
the enable_pdl parameter present on the real implementation, causing a signature
mismatch; update the _fake_mxfp8_quantize_sm100 definition to add an enable_pdl:
bool = False parameter (keeping existing defaults for is_sf_swizzled_layout and
alignment) and ensure any callers or the returned tensors behavior remain
unchanged so the fake op signature matches the real mxfp8_quantize
implementation.
In `@tests/utils/test_fp4_quantize.py`:
- Around line 20-29: The helper is_fp4_supported in
tests/utils/test_fp4_quantize.py directly calls torch.cuda.get_device_capability
and should instead use flashinfer.utils.get_compute_capability; update the
function to import get_compute_capability and replace the torch call with
get_compute_capability(device) (keep the existing CUDA version parsing and the
same support logic), ensuring the rest of is_fp4_supported still uses
cuda_version from torch.version.cuda and the same major/minor comparisons.
🧹 Nitpick comments (6)

flashinfer/quantization/kernels/__init__.py (1)

39-45: Consider sorting `__all__` for consistency. Static analysis flagged that `__all__` is not sorted. While minor, sorting it alphabetically improves readability and maintainability.

🔧 Suggested fix

```diff
 __all__ = [
     "MXFP4QuantizeSwizzledKernel",
+    "MXFP8QuantizeLinearKernel",
+    "MXFP8QuantizeSwizzledKernel",
     "mxfp4_quantize_cute_dsl",
-    "MXFP8QuantizeLinearKernel",
-    "MXFP8QuantizeSwizzledKernel",
     "mxfp8_quantize_cute_dsl",
 ]
```

flashinfer/quantization/kernels/mxfp4_quantize.py (2)
98-116: Redundant condition in the thread count optimization. Line 103's condition `if threads_per_row <= _MAX_THREADS` is always true because lines 98-100 already handle the case when `threads_per_row >= _MAX_THREADS` and return early. The `if` block can be simplified.

🔧 Suggested simplification

```diff
     if threads_per_row >= _MAX_THREADS:
         # Large K: use max threads, will need column loop
         return _MAX_THREADS

-    # threads_per_block should be a multiple of threads_per_row
-    if threads_per_row <= _MAX_THREADS:
-        # Find largest multiple of threads_per_row <= _MAX_THREADS
-        threads = (_MAX_THREADS // threads_per_row) * threads_per_row
-        if threads >= _MIN_THREADS:
-            return threads
-        # If largest multiple is below _MIN_THREADS, use the smallest valid one
-        threads = threads_per_row
-        while threads < _MIN_THREADS:
-            threads += threads_per_row
-        if threads <= _MAX_THREADS:
-            return threads
+    # threads_per_block should be a multiple of threads_per_row
+    # Find largest multiple of threads_per_row <= _MAX_THREADS
+    threads = (_MAX_THREADS // threads_per_row) * threads_per_row
+    if threads >= _MIN_THREADS:
+        return threads
+    # If largest multiple is below _MIN_THREADS, use the smallest valid one
+    threads = threads_per_row
+    while threads < _MIN_THREADS:
+        threads += threads_per_row
+    if threads <= _MAX_THREADS:
+        return threads

     # Fallback to default
     return _DEFAULT_THREADS
```
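The simplified selection logic reads as follows when extracted into a standalone function. The constants here are illustrative, not necessarily the kernel's actual values; the structure matches the suggestion above.

```python
# Illustrative thread-count selection: pick a multiple of threads_per_row
# within [_MIN_THREADS, _MAX_THREADS], mirroring the simplified logic above.
_MIN_THREADS, _MAX_THREADS, _DEFAULT_THREADS = 128, 1024, 256


def pick_threads_per_block(threads_per_row: int) -> int:
    if threads_per_row >= _MAX_THREADS:
        # Large K: use max threads; the kernel then loops over columns.
        return _MAX_THREADS
    # Largest multiple of threads_per_row that still fits in a block.
    threads = (_MAX_THREADS // threads_per_row) * threads_per_row
    if threads >= _MIN_THREADS:
        return threads
    # Otherwise take the smallest multiple that reaches _MIN_THREADS.
    threads = threads_per_row
    while threads < _MIN_THREADS:
        threads += threads_per_row
    if threads <= _MAX_THREADS:
        return threads
    return _DEFAULT_THREADS
```

The early return makes the second `if threads_per_row <= _MAX_THREADS` guard unnecessary, which is exactly the redundancy the nitpick points out.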
155-168: Use explicit `None` union syntax for type hints. PEP 484 prohibits implicit `Optional`. The `target_grid` parameter should use explicit union syntax for consistency with the rest of the codebase (e.g., line 467 uses `bool | None`).

🔧 Suggested fix

```diff
 def __init__(
     self,
     dtype: cutlass.Numeric,
     K: int,
     enable_pdl: bool = False,
-    target_grid: int = None,
+    target_grid: int | None = None,
 ):
```

flashinfer/quantization/kernels/mxfp8_quantize.py (2)
75-88: Consider consolidating the duplicated `_get_target_grid` function. This function is identical to `_get_target_grid` in `mxfp4_quantize.py` (lines 58-71). Consider moving it to `quantization_cute_dsl_utils.py` to avoid code duplication.

```shell
#!/bin/bash
# Verify the duplication
echo "=== mxfp4_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp4_quantize.py
echo ""
echo "=== mxfp8_quantize.py _get_target_grid ==="
rg -A 15 "def _get_target_grid" flashinfer/quantization/kernels/mxfp8_quantize.py
```
162-176: Use explicit `None` union syntax for type hints. For consistency with other parts of the codebase (line 678 uses `bool | None`), update the `target_grid` parameter type annotation.

🔧 Suggested fix

```diff
 def __init__(
     self,
     dtype: cutlass.Numeric,
     K: int,
     enable_pdl: bool = False,
-    target_grid: int = None,
+    target_grid: int | None = None,
 ):
```

Apply the same change to `MXFP8QuantizeSwizzledKernel.__init__` (line 314), `_get_compiled_kernel_linear` (line 583), and `_get_compiled_kernel_swizzled` (line 630).

flashinfer/quantization/quantization_cute_dsl_utils.py (1)
964-1002: Consider sorting `__all__` for maintainability. While the current organization by category (constants, intrinsics, helpers) is logical, sorting alphabetically, or at minimum keeping a consistent ordering, would help with maintainability as the module grows.
```python
__all__ = [
    # Packbits
    "packbits",
    "segment_packbits",
    # JIT module generator
    "gen_quantization_module",
    # FP8
    "mxfp8_quantize",
    "mxfp8_dequantize_host",
    # FP4
    "SfLayout",
    "block_scale_interleave",
    "nvfp4_block_scale_interleave",
    "e2m1_and_ufp8sf_scale_to_float",
    "fp4_quantize",
    "mxfp4_dequantize_host",
    "mxfp4_dequantize",
    "mxfp4_quantize",
    "nvfp4_quantize",
    "nvfp4_batched_quantize",
    "shuffle_matrix_a",
    "shuffle_matrix_sf_a",
    "scaled_fp4_grouped_quantize",
    "get_fp4_quantization_module",
]


if _cute_dsl_available:
    __all__ += [
        "mxfp8_quantize_cute_dsl",
        "mxfp4_quantize_cute_dsl",
    ]
```
Ruff RUF022: __all__ is not sorted.
Consider sorting to satisfy lint, or explicitly suppress if the grouped ordering is intentional.
🔧 Optional suppression to keep grouped ordering

```diff
-__all__ = [
+__all__ = [  # noqa: RUF022 - keep grouped exports
     # Packbits
     "packbits",
     "segment_packbits",
 @@
-if _cute_dsl_available:
-    __all__ += [
+if _cute_dsl_available:
+    __all__ += [  # noqa: RUF022 - keep grouped exports
         "mxfp8_quantize_cute_dsl",
         "mxfp4_quantize_cute_dsl",
     ]
```

🧰 Tools
🪛 Ruff (0.14.14)
[warning] 56-80: `__all__` is not sorted. Apply an isort-style sorting to `__all__` (RUF022)
[warning] 83-86: `__all__` is not sorted. Apply an isort-style sorting to `__all__` (RUF022)
🤖 Prompt for AI Agents
In `@flashinfer/quantization/__init__.py` around lines 56 - 86, The __all__ list
in flashinfer/quantization/__init__.py is unsorted (RUF022); either
alphabetically sort the symbols in the __all__ list (e.g., ensure entries like
"block_scale_interleave", "e2m1_and_ufp8sf_scale_to_float",
"get_fp4_quantization_module", "mxfp4_quantize", "mxfp8_quantize_cute_dsl", etc.
are in ASCII order) or if the current grouped ordering is intentional add an
explicit ruff suppression for RUF022 (e.g., a module-level ruff noqa for RUF022)
so the linter is satisfied. Ensure the change targets the __all__ variable and
preserves the conditional addition when _cute_dsl_available is true.
@bkryu Q:
kahyunnam
left a comment
LGTM, I just left a few questions about compute capability heuristic
```diff
-    "10.0": ["cuda"],
+    "10.0": ["cuda", "cute-dsl"],
     "10.3": ["cuda"],
     "12.0": ["cuda"],
```
Just curious, why is cute-dsl only enabled above 10.0?
Is it just a future to-do for more testing/benchmarking for <10.0 before enabling?
Hardware-accelerated MXFP8-related instructions are a feature of the Blackwell generation. Hopper is good for (non-MX-) FP8, hence should not be able to run these kernels.
As such, on Hopper or prior we do not expect users to use MXFP8 (software-emulated MXFP8 is possible, but performance would likely be unsatisfactory).
Oh ... this makes a lot of sense 😅
```python
    It may change or be removed in future versions without notice.
    Use at your own risk for production workloads.
    """
    if backend == "cute-dsl":
```
Should we also add a compute capability check here for current compilation context (/ current device) being compute >= 10.0, since it seems from benchmarking that we're only testing cute-dsl on 10.0 and above?
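A minimal sketch of such a guard, assuming the capability is available as a `(major, minor)` tuple; the real flashinfer helper that produces it (e.g. `get_compute_capability`) may differ in name and return type.

```python
# Illustrative guard: cute-dsl MXFP8 kernels require Blackwell-class hardware,
# i.e. compute capability 10.0 or newer, per the discussion above.
def cute_dsl_supported(compute_capability: tuple[int, int]) -> bool:
    return compute_capability >= (10, 0)
```

Placed at the top of the `backend == "cute-dsl"` branch, this would fail fast with a clear error on Hopper instead of a kernel-launch failure later.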
Made-with: Cursor

# Conflicts:
#	flashinfer/fp4_quantization.py
#	tests/utils/test_fp4_quantize.py
Actionable comments posted: 5
♻️ Duplicate comments (1)
flashinfer/quantization/fp4_quantization.py (1)
608-611: ⚠️ Potential issue | 🟠 Major — Guard `global_scale_tensor=None` before calling `.cpu()`.

The public API defaults this argument to `None` and forwards it here unchanged, so this dereference raises `AttributeError` before the host dequantization op runs.

🧩 Suggested fix

```diff
+    global_scale_cpu = (
+        torch.tensor([1.0], dtype=torch.float32, device="cpu")
+        if global_scale_tensor is None
+        else global_scale_tensor.cpu()
+    )
     module.e2m1_and_ufp8sf_scale_to_float_sm100(
         e2m1_tensor.cpu(),
         ufp8_scale_tensor.cpu().reshape(-1),
-        global_scale_tensor.cpu(),
+        global_scale_cpu,
         out,
         sf_vec_size,
         ufp8_type,
         is_sf_swizzled_layout,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/fp4_quantization.py` around lines 608 - 611, The call to module.e2m1_and_ufp8sf_scale_to_float_sm100 dereferences global_scale_tensor.cpu() but the public API may pass global_scale_tensor=None; guard that before calling .cpu() by computing a local value (e.g. global_scale_cpu = global_scale_tensor.cpu() if global_scale_tensor is not None else None) and pass global_scale_cpu to e2m1_and_ufp8sf_scale_to_float_sm100; do the same pattern if any other tensor arguments may be None (keep e2m1_tensor.cpu() and ufp8_scale_tensor.cpu().reshape(-1) unchanged).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 931-934: The denominator used to compute a_global_sf can be zero
(for all-zero or all-NaN inputs), producing inf; fix by clamping the max to a
small positive epsilon on the same device before dividing: compute denom =
a.float().abs().nan_to_num().max(), then replace denom with
denom.clamp_min(eps_tensor) where eps_tensor = torch.tensor(1e-6,
device=a.device, dtype=denom.dtype) (or use torch.finfo(denom.dtype).eps), then
compute a_global_sf = (448 * 6) / denom_clamped and pass that into fp4_quantize
(same a.cuda(), a_global_sf.cuda(), 32, True, True).
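The clamp can be modeled in pure Python: `448 * 6` comes from the code under review, while the epsilon value and helper name here are illustrative, and NaNs are zeroed to mimic `nan_to_num()`.

```python
import math


# Illustrative model of the clamped global-scale computation described above.
def global_scale(values, eps=1e-6):
    # nan_to_num(): treat NaNs as 0 before taking the max magnitude.
    max_abs = max((abs(v) for v in values if not math.isnan(v)), default=0.0)
    # Clamp the denominator so all-zero / all-NaN inputs yield a finite scale.
    return (448 * 6) / max(max_abs, eps)
```

Without the clamp, an all-zero input divides by 0 and produces `inf`, which then poisons every quantized value downstream.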
- Around line 459-484: The fake op
_fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100 currently returns output
in (l, m, k//2) order while the eager path yields (m, k//2, l); before
returning, permute output to match the eager logical layout (e.g., output =
output.permute(1, 2, 0)). Also ensure output_scales is permuted to the exact
same final layout the eager implementation exposes (adjust or add a final
.permute(...) on output_scales to match the eager caller expectation) so both
outputs have identical shapes/order between eager and compiled/fake paths.
- Around line 301-307: The fake implementation
_fake_block_scale_interleave_sm100 currently returns a hard-coded 1-D uint8
tensor sized as if input were 2-D, which breaks shape/dtype inference for 3-D
inputs and non-uint8 dtypes; update the register_fake_op implementation to
mirror the eager function's behavior by using unswizzled_sf.dtype (not
torch.uint8) and compute the output length from unswizzled_sf.shape handling
both 2-D and 3-D (e.g., multiply leading dims then divide by 16) so the fake op
returns the same flat buffer shape and dtype used by the real operator for
compile-time inference.
- Around line 228-240: The fake op _fake_fp4_quantize_sm100 must exactly mirror
the real op signature and output types: add the missing parameters
is_sf_8x4_layout: bool = False and enable_pdl: bool = False to the function
signature (keep existing is_sf_swizzled_layout), change the first returned
tensor to dtype=torch.uint8 with shape [m, k // 2] and change the scale-factors
tensor to dtype=torch.uint8 and sized to account for padded SF vectors using
sf_count = (k + sf_vec_size - 1) // sf_vec_size so the second tensor is
input.new_empty([m * sf_count], dtype=torch.uint8); keep other argument defaults
the same so torch.compile infers the correct schema and metadata for
fp4_quantize.
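The padding arithmetic referenced here can be shown with standalone helpers. The 128-row and 4-column granularity follows the padding used elsewhere in the code under review; the function names themselves are hypothetical.

```python
# Illustrative scale-factor sizing for the fake op's shape inference.
def sf_count(k: int, sf_vec_size: int = 16) -> int:
    # Ceil-divide so non-multiple K still gets a full trailing SF block.
    return (k + sf_vec_size - 1) // sf_vec_size


def swizzled_sf_size(m: int, k: int, sf_vec_size: int = 16) -> int:
    padded_m = (m + 127) // 128 * 128       # rows padded to 128
    padded_k = (sf_count(k, sf_vec_size) + 3) // 4 * 4  # SF columns padded to 4
    return padded_m * padded_k
```

The point of the fix is that the fake op must allocate the padded size, not `m * k // sf_vec_size`, or torch.compile will infer buffers smaller than the real op writes.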
In `@tests/utils/test_fp4_quantize.py`:
- Line 172: The skip guards call _is_fp4_supported(torch.device("cuda")) which
can probe the wrong GPU; change each occurrence to use the parameterized device
(i.e., _is_fp4_supported(torch.device(device))) so the checks respect the test's
device parameterization—replace all instances in this file (including the guards
around test_fp4_quantization and the other two occurrences) to use
torch.device(device) instead of torch.device("cuda").
---
Duplicate comments:
In `@flashinfer/quantization/fp4_quantization.py`:
- Around line 608-611: The call to module.e2m1_and_ufp8sf_scale_to_float_sm100
dereferences global_scale_tensor.cpu() but the public API may pass
global_scale_tensor=None; guard that before calling .cpu() by computing a local
value (e.g. global_scale_cpu = global_scale_tensor.cpu() if global_scale_tensor
is not None else None) and pass global_scale_cpu to
e2m1_and_ufp8sf_scale_to_float_sm100; do the same pattern if any other tensor
arguments may be None (keep e2m1_tensor.cpu() and
ufp8_scale_tensor.cpu().reshape(-1) unchanged).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b4412e6b-338d-4618-8f83-45476dee8435
📒 Files selected for processing (5)
- benchmarks/routines/flashinfer_benchmark_utils.py
- flashinfer/__init__.py
- flashinfer/fp4_quantization.py
- flashinfer/quantization/fp4_quantization.py
- tests/utils/test_fp4_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/__init__.py
- benchmarks/routines/flashinfer_benchmark_utils.py
```python
@register_fake_op("flashinfer::fp4_quantize_sm100")
def _fake_fp4_quantize_sm100(
    input: torch.Tensor,
    global_scale: Optional[torch.Tensor] = None,
    sf_vec_size: int = 16,
    sf_use_ue8m0: bool = False,
    is_sf_swizzled_layout: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    m, k = input.shape
    return (
        input.new_empty([m, k // 2], dtype=torch.int64),  # FLOAT4_E2M1X2
        input.new_empty([m * k // sf_vec_size], dtype=torch.int32),  # Scale factors
    )
```
Make _fake_fp4_quantize_sm100 mirror the real op.
This fake op omits is_sf_8x4_layout and enable_pdl, returns int64/int32 instead of uint8/uint8, and always uses the unpadded SF size. fp4_quantize() calls the real op with both extra arguments on Lines 693-700 and defaults to swizzled SFs, so torch.compile will infer the wrong schema and metadata here.
🧩 Suggested fix

```diff
 @register_fake_op("flashinfer::fp4_quantize_sm100")
 def _fake_fp4_quantize_sm100(
     input: torch.Tensor,
     global_scale: Optional[torch.Tensor] = None,
     sf_vec_size: int = 16,
     sf_use_ue8m0: bool = False,
     is_sf_swizzled_layout: bool = True,
+    is_sf_8x4_layout: bool = False,
+    enable_pdl: Optional[bool] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    m, k = input.shape
+    m = input.numel() // input.shape[-1]
+    k = input.shape[-1]
+    if is_sf_swizzled_layout:
+        out_sf_size = _compute_swizzled_layout_sf_size(
+            m, k // sf_vec_size, 8 if is_sf_8x4_layout else 128
+        )
+    else:
+        out_sf_size = m * k // sf_vec_size
     return (
-        input.new_empty([m, k // 2], dtype=torch.int64),  # FLOAT4_E2M1X2
-        input.new_empty([m * k // sf_vec_size], dtype=torch.int32),  # Scale factors
+        input.new_empty((*input.shape[:-1], k // 2), dtype=torch.uint8),
+        input.new_empty((out_sf_size,), dtype=torch.uint8),
     )
```

Based on learnings, functions decorated with `register_fake_op` are abstract implementations for torch.compile shape/dtype inference, and their signatures must exactly mirror the corresponding real op.
🧰 Tools
🪛 Ruff (0.15.5)
[warning] 231-231: Unused function argument: global_scale
(ARG001)
[warning] 233-233: Unused function argument: sf_use_ue8m0
(ARG001)
[warning] 234-234: Unused function argument: is_sf_swizzled_layout
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/quantization/fp4_quantization.py` around lines 228 - 240, The fake
op _fake_fp4_quantize_sm100 must exactly mirror the real op signature and output
types: add the missing parameters is_sf_8x4_layout: bool = False and enable_pdl:
bool = False to the function signature (keep existing is_sf_swizzled_layout),
change the first returned tensor to dtype=torch.uint8 with shape [m, k // 2] and
change the scale-factors tensor to dtype=torch.uint8 and sized to account for
padded SF vectors using sf_count = (k + sf_vec_size - 1) // sf_vec_size so the
second tensor is input.new_empty([m * sf_count], dtype=torch.uint8); keep other
argument defaults the same so torch.compile infers the correct schema and
metadata for fp4_quantize.
```python
@register_fake_op("flashinfer::silu_and_mul_scaled_nvfp4_experts_quantize_sm100")
def _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100(
    input: torch.Tensor,
    mask: torch.Tensor,
    global_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    device = input.device
    l, m, k_by_2 = input.shape
    k = k_by_2 // 2
    sf_vec_size = 16
    assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."

    scale_k = k // sf_vec_size
    padded_k = (scale_k + (4 - 1)) // 4 * 4
    padded_k_int32 = padded_k // 4
    padded_m = (m + (128 - 1)) // 128 * 128
    output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
    output_scales = torch.empty(
        l, padded_m, padded_k_int32, device=device, dtype=torch.int32
    )

    output_scales = output_scales.view(torch.float8_e4m3fn).view(
        l, padded_m // 128, padded_k // 4, 32, 4, 4
    )
    output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
    return (output, output_scales)
```
Return the same logical layout from the fake expert-quantize op.
The eager path permutes output to (m, k // 2, l) on Line 452, but the fake path returns the unpermuted (l, m, k // 2) buffer. Any compiled caller will observe a different output shape than eager mode.
🧩 Suggested fix
```diff
     output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8)
     output_scales = torch.empty(
         l, padded_m, padded_k_int32, device=device, dtype=torch.int32
     )
+    output = output.permute(1, 2, 0)
     output_scales = output_scales.view(torch.float8_e4m3fn).view(
         l, padded_m // 128, padded_k // 4, 32, 4, 4
     )
     output_scales = output_scales.permute(3, 4, 1, 5, 2, 0)
     return (output, output_scales)
```

Based on learnings, functions decorated with `register_fake_op` are abstract implementations for torch.compile shape/dtype inference.
🧰 Tools
🪛 Ruff (0.15.5)
[warning] 462-462: Unused function argument: mask
(ARG001)
[warning] 463-463: Unused function argument: global_scale
(ARG001)
[error] 466-466: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/quantization/fp4_quantization.py` around lines 459 - 484, The fake
op _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100 currently returns
output in (l, m, k//2) order while the eager path yields (m, k//2, l); before
returning, permute output to match the eager logical layout (e.g., output =
output.permute(1, 2, 0)). Also ensure output_scales is permuted to the exact
same final layout the eager implementation exposes (adjust or add a final
.permute(...) on output_scales to match the eager caller expectation) so both
outputs have identical shapes/order between eager and compiled/fake paths.
```python
    elif backend == "cuda":
        a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
        a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
        return a_fp4, a_sf
```
Clamp the MXFP4 scale denominator for zero/NaN-only inputs.
If a is all zeros or all NaNs, a.float().abs().nan_to_num().max() becomes 0, so this computes an infinite a_global_sf. That bad scale then flows straight into fp4_quantize.
🧩 Suggested fix
🧩 Suggested fix

```diff
 elif backend == "cuda":
-    a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max()
+    max_abs = a.float().abs().nan_to_num().max()
+    a_global_sf = (448 * 6) / max_abs.clamp_min(torch.finfo(max_abs.dtype).tiny)
     a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True)
     return a_fp4, a_sf
```
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/quantization/fp4_quantization.py` around lines 931 - 934, The
denominator used to compute a_global_sf can be zero (for all-zero or all-NaN
inputs), producing inf; fix by clamping the max to a small positive epsilon on
the same device before dividing: compute denom =
a.float().abs().nan_to_num().max(), then replace denom with
denom.clamp_min(eps_tensor) where eps_tensor = torch.tensor(1e-6,
device=a.device, dtype=denom.dtype) (or use torch.finfo(denom.dtype).eps), then
compute a_global_sf = (448 * 6) / denom_clamped and pass that into fp4_quantize
(same a.cuda(), a_global_sf.cuda(), 32, True, True).
```python
        or is_sm110a_supported(torch.device("cuda"))
        or is_sm12x_supported(torch.device("cuda"))
    ):
    if not _is_fp4_supported(torch.device("cuda")):
```
Use the parameterized device in these skip guards.
These checks run before torch.set_default_device(device), so torch.device("cuda") can probe the wrong GPU on multi-device or heterogeneous hosts. Use torch.device(device) consistently, as you already do in test_fp4_quantization.
Suggested fix

```diff
-    if not _is_fp4_supported(torch.device("cuda")):
+    if not _is_fp4_supported(torch.device(device)):
```

Also applies to: 208-208, 257-257
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/utils/test_fp4_quantize.py` at line 172, The skip guards call
_is_fp4_supported(torch.device("cuda")) which can probe the wrong GPU; change
each occurrence to use the parameterized device (i.e.,
_is_fp4_supported(torch.device(device))) so the checks respect the test's device
parameterization—replace all instances in this file (including the guards around
test_fp4_quantization and the other two occurrences) to use torch.device(device)
instead of torch.device("cuda").
```python
# =============================================================================


@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
```
Do we need to have these specific tests in the unit test file?
/bot run
Actionable comments posted: 2
🧹 Nitpick comments (2)
flashinfer/quantization/kernels/mxfp8_quantize.py (2)
690-695: Consider adding explicit validation for input tensor dimensions.The reshaping logic handles
dim() > 2by flattening to 2D, and the else branch assumes exactly 2 dimensions. If a 1D tensor is passed,m, k = input.shapeon line 695 will raise aValueErrorwith a confusing message.Consider adding an explicit check:
🛠️ Suggested validation
+    assert input.dim() >= 2, f"Input must be at least 2D, got {input.dim()}D tensor"
+
     if input.dim() > 2:
         m = input.numel() // input.shape[-1]
         k = input.shape[-1]
         input = input.reshape(m, k)
     else:
         m, k = input.shape

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/kernels/mxfp8_quantize.py` around lines 690 - 695, The current reshape logic assumes input.dim() == 2 or >2 and will crash with a confusing ValueError if a 1D tensor is passed; add an explicit validation before the reshape branch that checks input.dim() and raises a clear ValueError (or documents/handles 1D inputs) if dim < 2, e.g., validate input.dim() >= 2 and include the tensor shape in the error message; update the block that computes m and k (the variables input, m, k and the reshape behavior) so callers get a deterministic error instead of an ambiguous unpacking failure.
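The failure mode the nitpick describes is easy to reproduce with shapes alone. A hand-rolled sketch (names like `infer_mk` are illustrative, mirroring the reshape logic rather than the real kernel code): unpacking a 1D shape raises an opaque ValueError, while an explicit dim check gives an actionable message.

```python
from math import prod

def infer_mk(shape):
    # Mirrors the reshape branch: flatten >2D, otherwise unpack two dims.
    if len(shape) > 2:
        return prod(shape[:-1]), shape[-1]
    m, k = shape  # for a 1D shape: "not enough values to unpack (expected 2, got 1)"
    return m, k

def infer_mk_checked(shape):
    # The suggested guard: fail early with a message that names the problem.
    if len(shape) < 2:
        raise ValueError(f"Input must be at least 2D, got shape {shape}")
    return infer_mk(shape)

print(infer_mk((4, 8, 32)))  # (32, 32)
try:
    infer_mk((32,))
except ValueError as e:
    print(e)  # confusing unpacking error, no mention of dimensionality
```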
761-767: Consider sorting __all__ for consistency. Static analysis suggests applying isort-style sorting. This is purely stylistic and optional.
🔧 Optional sort
  __all__ = [
      "MXFP8QuantizeLinearKernel",
      "MXFP8QuantizeSwizzledKernel",
+     "_get_compiled_kernel_linear",
+     "_get_compiled_kernel_swizzled",
      "mxfp8_quantize_cute_dsl",
-     "_get_compiled_kernel_linear",
-     "_get_compiled_kernel_swizzled",
  ]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/quantization/kernels/mxfp8_quantize.py` around lines 761 - 767, The __all__ list is unsorted; please alphabetically sort the exported symbol names in the __all__ list (e.g., "MXFP8QuantizeLinearKernel", "MXFP8QuantizeSwizzledKernel", "mxfp8_quantize_cute_dsl", "_get_compiled_kernel_linear", "_get_compiled_kernel_swizzled") so the order is consistent with isort-style conventions and easier to maintain.
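A quick check that the ordering in the suggestion is exactly Python's default string sort (ASCII: uppercase < underscore < lowercase), so hand-editing and isort-style tools agree. `exports` here is a local copy of the module's __all__ for demonstration.

```python
exports = [
    "MXFP8QuantizeLinearKernel",
    "MXFP8QuantizeSwizzledKernel",
    "mxfp8_quantize_cute_dsl",
    "_get_compiled_kernel_linear",
    "_get_compiled_kernel_swizzled",
]
# sorted() places the uppercase class names first, then the underscore-prefixed
# helpers, then the lowercase function name -- the order the review proposes.
print(sorted(exports))
```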
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 487-520: The K validation must ensure K is non-zero and that
num_sf_blocks_per_row = K // MXFP4_SF_VEC_SIZE is divisible by 4 before
compiling/using kernels; update the early assertions/checks (near the existing
assert for MXFP4_SF_VEC_SIZE) to raise a clear error when K == 0 and when
num_sf_blocks_per_row % 4 != 0 so downstream logic in _get_compiled_kernel_mxfp4
and _compute_optimal_threads_for_k won't divide by zero or produce an
unmatchable padded_sf_cols/reshape; reference MXFP4_SF_VEC_SIZE,
num_sf_blocks_per_row, padded_sf_cols, _get_compiled_kernel_mxfp4 and
_compute_optimal_threads_for_k when making the change.
- Around line 443-491: The function mxfp4_quantize_cute_dsl is shadowing
Python's built-in input by using the parameter/local variable name "input";
rename that parameter and all local references (e.g., the reshaping/contiguous
uses and device checks) to a non-conflicting name like "tensor" or "src"
throughout the function (and update the docstring parameter name) to eliminate
the Ruff A002/A001 warnings while preserving all existing logic and assertions
(retain checks for dtype, is_cuda, PDL detection, shape handling,
MXFP4_SF_VEC_SIZE assertion, and the final contiguous call).
---
Nitpick comments:
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 690-695: The current reshape logic assumes input.dim() == 2 or >2
and will crash with a confusing ValueError if a 1D tensor is passed; add an
explicit validation before the reshape branch that checks input.dim() and raises
a clear ValueError (or documents/handles 1D inputs) if dim < 2, e.g., validate
input.dim() >= 2 and include the tensor shape in the error message; update the
block that computes m and k (the variables input, m, k and the reshape behavior)
so callers get a deterministic error instead of an ambiguous unpacking failure.
- Around line 761-767: The __all__ list is unsorted; please alphabetically sort
the exported symbol names in the __all__ list (e.g.,
"MXFP8QuantizeLinearKernel", "MXFP8QuantizeSwizzledKernel",
"mxfp8_quantize_cute_dsl", "_get_compiled_kernel_linear",
"_get_compiled_kernel_swizzled") so the order is consistent with isort-style
conventions and easier to maintain.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1299c992-5494-49f9-9fef-84370de5f153
📒 Files selected for processing (2)
flashinfer/quantization/kernels/mxfp4_quantize.py
flashinfer/quantization/kernels/mxfp8_quantize.py
    def mxfp4_quantize_cute_dsl(
        input: torch.Tensor,
        enable_pdl: bool | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize input tensor to MXFP4 format using CuTe-DSL kernel.

        This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior:
        - Global scale computed as (448 * 6) / max(|input|)
        - UE8M0 scale factors
        - E2M1 output format (4-bit, 2 values per byte)
        - Swizzled (128x4) scale factor layout

        The kernel is compiled once per (K, dtype, pdl) combination and handles
        varying M (batch size) at runtime without recompilation.

        Args:
            input: Input tensor of shape [M, K] with dtype fp16/bf16
            enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
                If None, automatically detects based on device capability (SM >= 9.0).

        Returns:
            Tuple of:
            - fp4_tensor: Quantized tensor of shape [M, K/2] with dtype uint8
            - scale_tensor: Scale factors as uint8 tensor (swizzled layout)
        """
        from ...utils import device_support_pdl

        assert input.dtype in (torch.float16, torch.bfloat16), (
            f"Input dtype must be float16 or bfloat16, got {input.dtype}"
        )
        assert input.is_cuda, "Input must be on CUDA device"

        # Auto-detect PDL support based on device capability
        if enable_pdl is None:
            enable_pdl = device_support_pdl(input.device)

        if input.dim() > 2:
            m = input.numel() // input.shape[-1]
            k = input.shape[-1]
            input = input.reshape(m, k)
        else:
            m, k = input.shape

        assert k % MXFP4_SF_VEC_SIZE == 0, (
            f"K ({k}) must be divisible by MXFP4_SF_VEC_SIZE={MXFP4_SF_VEC_SIZE}"
        )

        input = input.contiguous()
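The docstring above promises compilation once per (K, dtype, pdl) combination with M varying freely at runtime. A minimal sketch of that caching scheme, assuming a memoized factory like the PR's `_get_compiled_kernel_mxfp4`; the "compilation" here is a placeholder object, not real cute-dsl output.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def get_compiled_kernel_sketch(is_bfloat16: bool, k: int, enable_pdl: bool):
    # Stand-in for an expensive cute.compile(...) call; note M is NOT a key.
    return object()

a = get_compiled_kernel_sketch(True, 4096, True)
b = get_compiled_kernel_sketch(True, 4096, True)  # cache hit: same object back
c = get_compiled_kernel_sketch(True, 8192, True)  # new K: fresh "compile"
print(a is b, a is c)  # True False
```

Keying the cache on shape-invariant parameters only is what lets the kernel serve any batch size M without recompiling.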
Avoid shadowing Python’s built-in input.
Line 444 (parameter) and reassignment at Lines 483/491 shadow the built-in input, which is currently flagged by Ruff (A002/A001).
Proposed fix
-def mxfp4_quantize_cute_dsl(
- input: torch.Tensor,
+def mxfp4_quantize_cute_dsl(
+ x: torch.Tensor,
enable_pdl: bool | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@
- assert input.dtype in (torch.float16, torch.bfloat16), (
- f"Input dtype must be float16 or bfloat16, got {input.dtype}"
+ assert x.dtype in (torch.float16, torch.bfloat16), (
+ f"Input dtype must be float16 or bfloat16, got {x.dtype}"
)
- assert input.is_cuda, "Input must be on CUDA device"
+ assert x.is_cuda, "Input must be on CUDA device"
@@
- enable_pdl = device_support_pdl(input.device)
+ enable_pdl = device_support_pdl(x.device)
@@
- if input.dim() > 2:
- m = input.numel() // input.shape[-1]
- k = input.shape[-1]
- input = input.reshape(m, k)
+ if x.dim() > 2:
+ m = x.numel() // x.shape[-1]
+ k = x.shape[-1]
+ x = x.reshape(m, k)
else:
- m, k = input.shape
+ m, k = x.shape
@@
- input = input.contiguous()
- is_bfloat16 = input.dtype == torch.bfloat16
+ x = x.contiguous()
+ is_bfloat16 = x.dtype == torch.bfloat16
@@
- target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM
+ target_grid = get_num_sm(x.device) * _BLOCKS_PER_SM
@@
- fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device)
+ fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=x.device)
scale_output = torch.empty(
- scale_output_size, dtype=torch.uint8, device=input.device
+ scale_output_size, dtype=torch.uint8, device=x.device
)
@@
- kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks)
+ kernel_fn(x, fp4_output, scale_output, m, padded_m, num_blocks)

🧰 Tools
🪛 Ruff (0.15.6)
[error] 444-444: Function argument input is shadowing a Python builtin
(A002)
[error] 483-483: Variable input is shadowing a Python builtin
(A001)
[error] 491-491: Variable input is shadowing a Python builtin
(A001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/quantization/kernels/mxfp4_quantize.py` around lines 443 - 491,
The function mxfp4_quantize_cute_dsl is shadowing Python's built-in input by
using the parameter/local variable name "input"; rename that parameter and all
local references (e.g., the reshaping/contiguous uses and device checks) to a
non-conflicting name like "tensor" or "src" throughout the function (and update
the docstring parameter name) to eliminate the Ruff A002/A001 warnings while
preserving all existing logic and assertions (retain checks for dtype, is_cuda,
PDL detection, shape handling, MXFP4_SF_VEC_SIZE assertion, and the final
contiguous call).
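A tiny demonstration of why Ruff flags A002: inside the function body, `input` is the parameter, so the builtin of the same name is unreachable there, while at module scope it remains intact. Purely illustrative; not FlashInfer code.

```python
def shadowed(input):
    # `input` here is the argument, not builtins.input
    return callable(input)

print(shadowed(3))      # False: checks the integer argument
print(callable(input))  # True: at module scope the builtin survives
```

Renaming the parameter (as the review suggests with `x`) removes the ambiguity at zero runtime cost.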
        assert k % MXFP4_SF_VEC_SIZE == 0, (
            f"K ({k}) must be divisible by MXFP4_SF_VEC_SIZE={MXFP4_SF_VEC_SIZE}"
        )

        input = input.contiguous()
        is_bfloat16 = input.dtype == torch.bfloat16

        # Cached device-specific target grid for grid size computation
        target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM

        # Compute M-dependent values
        num_sf_blocks_per_row = k // MXFP4_SF_VEC_SIZE
        padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE
        padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
        scale_output_size = padded_m * padded_sf_cols

        # Get or compile kernel (device-independent)
        kernel_fn, rows_per_block = _get_compiled_kernel_mxfp4(is_bfloat16, k, enable_pdl)

        # Compute grid size in Python (runtime, device-specific)
        num_blocks = min((padded_m + rows_per_block - 1) // rows_per_block, target_grid)

        # Allocate outputs
        fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device)
        scale_output = torch.empty(
            scale_output_size, dtype=torch.uint8, device=input.device
        )

        # Launch kernel
        kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks)

        # Reshape scale output to match CUDA backend format: [padded_total, num_sf_per_row]
        scale_output = scale_output.reshape(-1, num_sf_blocks_per_row)
Validate K for swizzled-column compatibility before compiling.
At Line 487, the check allows any K % 32 == 0, but Line 500 pads scale columns to multiples of 4 and Line 519 reshapes using unpadded num_sf_blocks_per_row. For K/32 not divisible by 4, this can break reshape semantics; for K=0, _compute_optimal_threads_for_k hits division by zero at Line 90.
Proposed fix
- assert k % MXFP4_SF_VEC_SIZE == 0, (
- f"K ({k}) must be divisible by MXFP4_SF_VEC_SIZE={MXFP4_SF_VEC_SIZE}"
- )
+ if k <= 0 or k % MXFP4_SF_VEC_SIZE != 0:
+ raise ValueError(
+ f"K ({k}) must be a positive multiple of {MXFP4_SF_VEC_SIZE}"
+ )
+ # Swizzled 128x4 layout requires 4 scale-factor blocks per swizzle group.
+ if (k // MXFP4_SF_VEC_SIZE) % 4 != 0:
+ raise ValueError(
+ "CuTe-DSL MXFP4 swizzled backend currently requires K divisible by 128."
+ )

🧰 Tools
🪛 Ruff (0.15.6)
[error] 491-491: Variable input is shadowing a Python builtin
(A001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/quantization/kernels/mxfp4_quantize.py` around lines 487 - 520,
The K validation must ensure K is non-zero and that num_sf_blocks_per_row = K //
MXFP4_SF_VEC_SIZE is divisible by 4 before compiling/using kernels; update the
early assertions/checks (near the existing assert for MXFP4_SF_VEC_SIZE) to
raise a clear error when K == 0 and when num_sf_blocks_per_row % 4 != 0 so
downstream logic in _get_compiled_kernel_mxfp4 and
_compute_optimal_threads_for_k won't divide by zero or produce an unmatchable
padded_sf_cols/reshape; reference MXFP4_SF_VEC_SIZE, num_sf_blocks_per_row,
padded_sf_cols, _get_compiled_kernel_mxfp4 and _compute_optimal_threads_for_k
when making the change.
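The arithmetic behind this finding can be checked directly. A sketch using the same constants and variable names as the kernel (values assumed from the code context): when K/32 is not a multiple of 4, the padded scale buffer's element count is not divisible by the unpadded column count, so the final reshape(-1, num_sf_blocks_per_row) must fail.

```python
MXFP4_SF_VEC_SIZE = 32
ROW_TILE_SIZE = 128

def scale_shapes(m: int, k: int):
    num_sf_blocks_per_row = k // MXFP4_SF_VEC_SIZE
    padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE
    padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
    total = padded_m * padded_sf_cols
    # The reshape only succeeds when total divides evenly by the column count.
    return total, num_sf_blocks_per_row, total % num_sf_blocks_per_row == 0

print(scale_shapes(100, 128))  # (512, 4, True)  -- K a multiple of 128: safe
print(scale_shapes(100, 96))   # (512, 3, False) -- 512 not divisible by 3
```

This is why the proposed fix tightens the guard to K divisible by 128 (4 scale-factor blocks per swizzle group) rather than just K % 32 == 0.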
[SUCCESS] Pipeline #46363680: 14/20 passed

/bot run
…lashinfer-ai#2443) <!-- .github/pull_request_template.md --> ## 📌 Description This PR adds CuTe-DSL backend support for MXFP8 and MXFP4 quantization kernels as alternatives to JIT-compiled CUDA backends Key changes: - Add CuTe-DSL MXFP8 and MXFP4 quantization kernels - Reorganize quantization module structure for better maintainability - Add benchmarks and unit tests for backend comparison **File Structure Reorganization** Quantization files are now organized in `flashinfer/quantization/`: ``` flashinfer/quantization/ ├── __init__.py # Package exports ├── fp4_quantization.py # MXFP4 public API ├── fp8_quantization.py # MXFP8 public API ├── packbits.py # Utility functions ├── quantization_cute_dsl_utils.py # Shared PTX intrinsics └── kernels/ ├── __init__.py # Kernel exports (EXPERIMENTAL) ├── mxfp4_quantize.py # MXFP4 CuTe-DSL kernel └── mxfp8_quantize.py # MXFP8 CuTe-DSL kernel ``` **Performance** CuTe DSL kernels are strong compared to CUDA counterparts: - mxfp4_quantization - Geomean 12x speedup; beats cuda backend in all cases in `bench_mxfp4_quantize_backend_comparison.py` - mxfp8_quantization - Geomean ~1.3x speedup; beats cuda backend in all cases in `bench_mxfp8_quantize_backend_comparison.py` Expand below for performance heatmaps: <details> <summary>CuTe DSL Backend outperforms CUDA backend on every single case benchmarked in bench_mxfp8_quantize_backend_comparison.py. Click to see performance comparison data</summary> **BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster** <img width="1644" height="1477" alt="sm100_mxfp8_swizzled_bfloat16" src="https://github.com/user-attachments/assets/107279a6-8fc4-4aba-843d-34a83a12acb0" /> **BF16 input; Linear cases. > 1.0 means CuTe DSL is faster** <img width="1644" height="1477" alt="sm100_mxfp8_linear_bfloat16" src="https://github.com/user-attachments/assets/1317ab55-c9ac-4284-bf9a-5127070fe0ad" /> **BF16 input; Swizzled cases. 
Annotated values are achieved TB/s** <img width="1646" height="1481" alt="sm100_mxfp8_bandwidth_linear_bfloat16" src="https://github.com/user-attachments/assets/033e0692-2eef-4ff7-95f6-94a1d098dbe7" /> **BF16 input; Linear cases. Annotated values are achieved TB/s** <img width="1646" height="1481" alt="sm100_mxfp8_bandwidth_swizzled_bfloat16" src="https://github.com/user-attachments/assets/543f7cd2-0d3a-4f7b-b465-7423f1738d9c" /> </details> <details> <summary>CuTe DSL Backend outperforms CUDA backend on every single case benchmarked in bench_mxfp4_quantize_backend_comparison.py. Click to see performance comparison data</summary> **BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster** <img width="1658" height="1477" alt="sm100_mxfp4_comparison_bfloat16" src="https://github.com/user-attachments/assets/bbaae310-581a-4035-9e06-0c437263da55" /> **BF16 input; Swizzled cases. Annotated values are achieved TB/s** <img width="1646" height="1481" alt="sm100_mxfp4_bandwidth_bfloat16" src="https://github.com/user-attachments/assets/d7798935-2112-4b73-b127-4095fede8b18" /> </details> <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues flashinfer-ai#2496 <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. 
- [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * CuTe-DSL backend added for MXFP8 and MXFP4 quantization alongside CUDA. * Consolidated quantization package exposing unified FP4/FP8 interfaces and conditional CuTe-DSL exports. * New end-to-end benchmarking tools for MXFP4 and MXFP8 (correctness, performance, bandwidth, heatmaps). * **Bug Fixes / Compatibility** * Backwards-compatible shims preserve existing public API while delegating implementations to the new package. * **Tests** * Expanded tests to cover CUDA and CuTe-DSL, availability gating, compilation cache, and backend parity. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
Signed-off-by: Amey Naik <212485788+ameynaik-hub@users.noreply.github.com>
📌 Description
This PR adds CuTe-DSL backend support for MXFP8 and MXFP4 quantization kernels as alternatives to JIT-compiled CUDA backends
Key changes:
File Structure Reorganization
Quantization files are now organized in
flashinfer/quantization/:

Performance
CuTe DSL kernels are strong compared to CUDA counterparts:

- mxfp4_quantization: geomean 12x speedup; beats the CUDA backend in all cases in bench_mxfp4_quantize_backend_comparison.py
- mxfp8_quantization: geomean ~1.3x speedup; beats the CUDA backend in all cases in bench_mxfp8_quantize_backend_comparison.py

Expand below for performance heatmaps:
CuTe DSL Backend outperforms CUDA backend on every single case benchmarked in bench_mxfp8_quantize_backend_comparison.py. Click to see performance comparison data
BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster

BF16 input; Linear cases. > 1.0 means CuTe DSL is faster

BF16 input; Swizzled cases. Annotated values are achieved TB/s

BF16 input; Linear cases. Annotated values are achieved TB/s

CuTe DSL Backend outperforms CUDA backend on every single case benchmarked in bench_mxfp4_quantize_backend_comparison.py. Click to see performance comparison data
BF16 input; Swizzled cases. > 1.0 means CuTe DSL is faster

BF16 input; Swizzled cases. Annotated values are achieved TB/s

🔍 Related Issues
#2496
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed.
- All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features

- CuTe-DSL backend added for MXFP8 and MXFP4 quantization alongside CUDA.
- Consolidated quantization package exposing unified FP4/FP8 interfaces and conditional CuTe-DSL exports.
- New end-to-end benchmarking tools for MXFP4 and MXFP8 (correctness, performance, bandwidth, heatmaps).

Bug Fixes / Compatibility

- Backwards-compatible shims preserve existing public API while delegating implementations to the new package.

Tests

- Expanded tests to cover CUDA and CuTe-DSL, availability gating, compilation cache, and backend parity.