Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
702 changes: 702 additions & 0 deletions benchmarks/bench_mxfp4_quantize_backend_comparison.py

Large diffs are not rendered by default.

683 changes: 683 additions & 0 deletions benchmarks/bench_mxfp8_quantize_backend_comparison.py

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,19 +484,19 @@ def dtype_str_to_torch_dtype(dtype_str):
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
"10.0": ["cuda", "cute-dsl"],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why is cute-dsl only enabled above 10.0?

Is it just a future to-do for more testing/benchmarking for <10.0 before enabling?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardware-accelerated MXFP8-related instructions are a feature of the Blackwell generation. Hopper supports (non-MX) FP8 but is not able to run these kernels.

As such, on Hopper or prior, we do not expect users to use MXFP8 (software-emulated MXFP8 is possible, but perf would likely be unsatisfactory).

Copy link
Copy Markdown
Collaborator

@kahyunnam kahyunnam Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ... this makes a lot of sense πŸ˜…

"10.3": ["cuda", "cute-dsl"],
"12.0": ["cuda", "cute-dsl"],
},
"mxfp4_quantize": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
"10.0": ["cuda", "cute-dsl"],
"10.3": ["cuda", "cute-dsl"],
"12.0": ["cuda", "cute-dsl"],
},
"nvfp4_quantize": {
"7.5": [],
Expand Down
30 changes: 16 additions & 14 deletions benchmarks/routines/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def parse_quantization_args(line, parser):
required=False,
nargs="+",
default=["cuda"],
choices=["cuda"],
choices=["cuda", "cute-dsl"],
help="Backend to test. Default: cuda",
)
# FP4 quantization specific arguments
Expand Down Expand Up @@ -231,15 +231,13 @@ def testMxfp8Quantize(args):
print(f"[VVERBOSE] {enable_pdl = }")

def run_backend(backend, input_tensor):
if backend == "cuda":
return flashinfer.mxfp8_quantize(
input_tensor,
is_sf_swizzled_layout=is_sf_swizzled_layout,
alignment=alignment,
enable_pdl=enable_pdl,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
return flashinfer.mxfp8_quantize(
input_tensor,
is_sf_swizzled_layout=is_sf_swizzled_layout,
alignment=alignment,
enable_pdl=enable_pdl,
backend=backend,
)

# Reference check via dequantize round-trip
has_reference_output = False
Expand Down Expand Up @@ -391,6 +389,7 @@ def testMxfp4Quantize(args):
backends = args.backends[:] # Make a copy to avoid modifying the original
m = args.m
k = args.k
enable_pdl = args.enable_pdl
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
res = []
Expand Down Expand Up @@ -421,12 +420,14 @@ def testMxfp4Quantize(args):
if args.verbose >= 2:
print(f"[VVERBOSE] {input_tensor.shape = }")
print(f"[VVERBOSE] {input_tensor.dtype = }")
print(f"[VVERBOSE] {enable_pdl = }")

def run_backend(backend, input_tensor):
if backend == "cuda":
return flashinfer.mxfp4_quantize(input_tensor)
else:
raise ValueError(f"Unsupported backend: {backend}")
return flashinfer.mxfp4_quantize(
input_tensor,
backend=backend,
enable_pdl=enable_pdl,
)

# Reference check via dequantize round-trip
has_reference_output = False
Expand Down Expand Up @@ -529,6 +530,7 @@ def run_backend(backend, input_tensor):
cur_res["m"] = m
cur_res["k"] = k
cur_res["input_dtype"] = str(input_dtype)
cur_res["enable_pdl"] = enable_pdl
cur_res["backend"] = backend
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
Expand Down
5 changes: 3 additions & 2 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
)
from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache
from .fp4_quantization import (
from .quantization.fp4_quantization import (
SfLayout,
block_scale_interleave,
nvfp4_block_scale_interleave,
Expand All @@ -73,10 +73,11 @@
shuffle_matrix_a,
shuffle_matrix_sf_a,
scaled_fp4_grouped_quantize,
get_fp4_quantization_module,
nvfp4_kv_dequantize,
nvfp4_kv_quantize,
)
from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize
from .quantization.fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize
from .fused_moe import (
ActivationType,
RoutingMethodType,
Expand Down
2 changes: 1 addition & 1 deletion flashinfer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
register_fake_op,
get_compute_capability,
)
from .fp4_quantization import get_fp4_quantization_module
from .quantization.fp4_quantization import get_fp4_quantization_module


@functools.cache
Expand Down
Loading
Loading