diff --git a/benchmarks/bench_mxfp4_quantize_backend_comparison.py b/benchmarks/bench_mxfp4_quantize_backend_comparison.py index 5c51c5c18f..a14ae56811 100644 --- a/benchmarks/bench_mxfp4_quantize_backend_comparison.py +++ b/benchmarks/bench_mxfp4_quantize_backend_comparison.py @@ -16,9 +16,10 @@ Benchmark: MXFP4 Quantization Backend Comparison (CUDA vs CuTe-DSL) Compares the performance of CUDA and CuTe-DSL backends for MXFP4 quantization -across different M and K dimensions. Each configuration is verified for -correctness before timing. Generates heatmaps showing relative performance -(speedup of CuTe-DSL over CUDA). +across different M and K dimensions. Supports both swizzled 128x4 and linear +scale factor layouts. Each configuration is verified for correctness before +timing. Generates heatmaps showing relative performance (speedup of CuTe-DSL +over CUDA). Can also measure achieved memory bandwidth in TB/s for the CuTe-DSL backend. @@ -55,6 +56,7 @@ def verify_mxfp4_correctness( m: int, k: int, dtype: torch.dtype, + is_sf_swizzled_layout: bool, ) -> Tuple[bool, str, float, float]: """ Verify that both backends produce correct outputs via roundtrip test. @@ -63,19 +65,51 @@ def verify_mxfp4_correctness( Tuple of (success, message, quant_match_pct, scale_match_pct) On failure, quant_match_pct and scale_match_pct are 0.0 """ - import flashinfer + from flashinfer.quantization.fp4_quantization import ( + e2m1_and_ufp8sf_scale_to_float, + fp4_quantize, + ) torch.manual_seed(42) x = torch.randn(m, k, device="cuda", dtype=dtype) + global_sf = ((448 * 6) / x.float().abs().nan_to_num().max()).cuda() try: # Test CUDA backend - quant_cuda, scale_cuda = flashinfer.mxfp4_quantize(x, backend="cuda") - dq_cuda = flashinfer.mxfp4_dequantize(quant_cuda, scale_cuda) + quant_cuda, scale_cuda = fp4_quantize( + x, + global_sf, + sf_vec_size=32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend="cuda", + ) + dq_cuda = e2m1_and_ufp8sf_scale_to_float( + quant_cuda.cpu().view(torch.uint8), + scale_cuda.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0]), + 32, + 0, + is_sf_swizzled_layout, + ) # Test CuTe-DSL backend - quant_cute, scale_cute = flashinfer.mxfp4_quantize(x, backend="cute-dsl") - dq_cute = flashinfer.mxfp4_dequantize(quant_cute, scale_cute) + quant_cute, scale_cute = fp4_quantize( + x, + global_sf, + sf_vec_size=32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend="cute-dsl", + ) + dq_cute = e2m1_and_ufp8sf_scale_to_float( + quant_cute.cpu().view(torch.uint8), + scale_cute.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0]), + 32, + 0, + is_sf_swizzled_layout, + ) # Check shapes match if quant_cuda.shape != quant_cute.shape: @@ -131,6 +165,7 @@ def bench_mxfp4_quantize( m: int, k: int, dtype: torch.dtype, + is_sf_swizzled_layout: bool, backend: str, ) -> float: """ @@ -140,22 +175,38 @@ def bench_mxfp4_quantize( m: Number of rows k: Number of columns dtype: Input dtype (torch.float16 or torch.bfloat16) + is_sf_swizzled_layout: Whether to use swizzled scale factor layout backend: "cuda" or "cute-dsl" Returns: Median execution time in milliseconds """ - import flashinfer + from flashinfer.quantization.fp4_quantization import fp4_quantize # Create input tensor x = torch.randn(m, k, device="cuda", dtype=dtype) - - # Warmup and get output shapes - _ = flashinfer.mxfp4_quantize(x, backend=backend) + global_sf = ((448 * 6) / x.float().abs().nan_to_num().max()).cuda() + + # Warmup + _ = fp4_quantize( + x, + global_sf, + sf_vec_size=32, 
+ sf_use_ue8m0=True, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend=backend, + ) # Benchmark def run_kernel(): - flashinfer.mxfp4_quantize(x, backend=backend) + fp4_quantize( + x, + global_sf, + sf_vec_size=32, + sf_use_ue8m0=True, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend=backend, + ) times = bench_gpu_time( fn=run_kernel, @@ -210,6 +261,7 @@ def run_bandwidth_sweep( m_values: List[int], k_values: List[int], dtype: torch.dtype, + is_sf_swizzled_layout: bool, ) -> Dict[Tuple[int, int], float]: """ Run bandwidth benchmark sweep for CuTe-DSL backend only. @@ -222,7 +274,10 @@ def run_bandwidth_sweep( total = len(m_values) * len(k_values) current = 0 - print(f"\nBenchmarking MXFP4 swizzled layout, dtype={dtype} (CuTe-DSL bandwidth)") + layout_str = "swizzled" if is_sf_swizzled_layout else "linear" + print( + f"\nBenchmarking MXFP4 {layout_str} layout, dtype={dtype} (CuTe-DSL bandwidth)" + ) print("=" * 60) for m in m_values: @@ -231,7 +286,9 @@ def run_bandwidth_sweep( print(f"[{current}/{total}] M={m:5d}, K={k:5d} ... ", end="", flush=True) # Benchmark CuTe-DSL backend only - time_ms = bench_mxfp4_quantize(m, k, dtype, backend="cute-dsl") + time_ms = bench_mxfp4_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl" + ) # Compute bandwidth bandwidth = compute_bandwidth_tb_per_sec(m, k, dtype, time_ms) @@ -246,6 +303,7 @@ def run_benchmark_sweep( m_values: List[int], k_values: List[int], dtype: torch.dtype, + is_sf_swizzled_layout: bool, ) -> Tuple[Dict[Tuple[int, int], float], Dict[Tuple[int, int], float]]: """ Run benchmark sweep for both backends with inline correctness verification. @@ -254,6 +312,7 @@ def run_benchmark_sweep( m_values: List of M dimensions to benchmark k_values: List of K dimensions to benchmark dtype: Input dtype + is_sf_swizzled_layout: Whether to use swizzled scale factor layout Returns: Tuple of (cuda_times, cute_dsl_times) dictionaries @@ -265,7 +324,8 @@ def run_benchmark_sweep( total = len(m_values) * len(k_values) current = 0 - print(f"\nBenchmarking MXFP4 swizzled layout, dtype={dtype}") + layout_str = "swizzled" if is_sf_swizzled_layout else "linear" + print(f"\nBenchmarking MXFP4 {layout_str} layout, dtype={dtype}") print("=" * 95) print( f"{'Progress':<12} {'M':>5} {'K':>5} | " @@ -285,7 +345,7 @@ def run_benchmark_sweep( # Verify correctness first success, verify_msg, quant_match, scale_match = verify_mxfp4_correctness( - m, k, dtype + m, k, dtype, is_sf_swizzled_layout ) if not success: failures.append((m, k, verify_msg)) @@ -293,11 +353,15 @@ def run_benchmark_sweep( continue # Benchmark CUDA backend - cuda_time = bench_mxfp4_quantize(m, k, dtype, backend="cuda") + cuda_time = bench_mxfp4_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cuda" + ) cuda_times[(m, k)] = cuda_time # Benchmark CuTe-DSL backend - cute_dsl_time = bench_mxfp4_quantize(m, k, dtype, backend="cute-dsl") + cute_dsl_time = bench_mxfp4_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl" + ) cute_dsl_times[(m, k)] = cute_dsl_time # Compute speedup @@ -497,10 +561,11 @@ def print_bandwidth_summary_table( m_values: List[int], k_values: List[int], bandwidth_results: Dict[Tuple[int, int], float], + layout_name: str = "Swizzled Layout", ): """Print a summary table of bandwidth results.""" print(f"\n{'=' * 80}") - print("Bandwidth Summary: MXFP4 Swizzled Layout (TB/s)") + print(f"Bandwidth Summary: MXFP4 {layout_name} (TB/s)") print(f"{'=' * 80}") # Header @@ -537,10 +602,11 @@ def print_summary_table( k_values: List[int], cuda_times: 
Dict[Tuple[int, int], float], cute_dsl_times: Dict[Tuple[int, int], float], + layout_name: str = "Swizzled Layout", ): """Print a summary table of results.""" print(f"\n{'=' * 80}") - print("Summary: MXFP4 Swizzled Layout (Speedup: CUDA time / CuTe-DSL time)") + print(f"Summary: MXFP4 {layout_name} (Speedup: CUDA time / CuTe-DSL time)") print(f"{'=' * 80}") # Header @@ -618,12 +684,20 @@ def main(): print(f"Data type: {dtype}") # Define sweep ranges (powers of 2 + common transformer hidden dimensions) - # Note: K must be a multiple of 128 for MXFP4 swizzled layout because: - # - SF vec size is 32, so K/32 gives number of SF blocks per row - # - Swizzled layout pads SF blocks to multiples of 4 - # - The CUDA backend's reshape assumes unpadded SF dimensions - # So K/32 must already be a multiple of 4, i.e., K must be multiple of 128 + # K constraints: + # - Linear layout: K must be a multiple of 32 (SF_VEC_SIZE) + # - Swizzled layout: K must be a multiple of 128 because K/32 (SF blocks + # per row) must be a multiple of 4 for the swizzled padding to work + # correctly with the CUDA backend's reshape + # We use K values that satisfy both constraints (multiples of 128) m_values = [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, 128, 256, 384, @@ -667,30 +741,97 @@ def main(): print("BANDWIDTH MEASUREMENT MODE (CuTe-DSL only)") print("=" * 80) - bandwidth_results = run_bandwidth_sweep(m_values, k_values, dtype) - print_bandwidth_summary_table(m_values, k_values, bandwidth_results) + # Benchmark linear layout + print("\n" + "=" * 80) + print("BENCHMARKING LINEAR (NON-SWIZZLED) LAYOUT - BANDWIDTH") + print("=" * 80) - # Generate bandwidth heatmap + bandwidth_linear = run_bandwidth_sweep( + m_values, k_values, dtype, is_sf_swizzled_layout=False + ) + print_bandwidth_summary_table( + m_values, k_values, bandwidth_linear, "Linear Layout" + ) create_bandwidth_heatmap( m_values, k_values, - bandwidth_results, - f"MXFP4 Quantization CuTe-DSL Bandwidth ({args.dtype})", - f"{args.output_prefix}_bandwidth_{args.dtype}.png", + bandwidth_linear, + f"MXFP4 Quantization Bandwidth (CuTe-DSL) - Linear Layout - {args.dtype}", + f"{args.output_prefix}_bandwidth_linear_{args.dtype}.png", + ) + + # Benchmark swizzled layout + print("\n" + "=" * 80) + print("BENCHMARKING SWIZZLED LAYOUT - BANDWIDTH") + print("=" * 80) + + bandwidth_swizzled = run_bandwidth_sweep( + m_values, k_values, dtype, is_sf_swizzled_layout=True + ) + print_bandwidth_summary_table( + m_values, k_values, bandwidth_swizzled, "Swizzled Layout" + ) + create_bandwidth_heatmap( + m_values, + k_values, + bandwidth_swizzled, + f"MXFP4 Quantization Bandwidth (CuTe-DSL) - Swizzled Layout - {args.dtype}", + f"{args.output_prefix}_bandwidth_swizzled_{args.dtype}.png", ) else: - # Run comparison benchmark (with inline correctness verification) - cuda_times, cute_dsl_times = run_benchmark_sweep(m_values, k_values, dtype) - print_summary_table(m_values, k_values, cuda_times, cute_dsl_times) + # Speedup comparison mode: CUDA vs CuTe-DSL + # Benchmark linear layout (non-swizzled) + print("\n" + "=" * 80) + print("BENCHMARKING LINEAR (NON-SWIZZLED) LAYOUT") + print("=" * 80) + + cuda_times_linear, cute_dsl_times_linear = run_benchmark_sweep( + m_values, + k_values, + dtype, + is_sf_swizzled_layout=False, + ) + print_summary_table( + m_values, + k_values, + cuda_times_linear, + cute_dsl_times_linear, + "Linear Layout", + ) + create_heatmap( + m_values, + k_values, + cuda_times_linear, + cute_dsl_times_linear, + f"MXFP4 Quantization Speedup (CuTe-DSL vs CUDA) - Linear 
Layout - {args.dtype}", + f"{args.output_prefix}_comparison_linear_{args.dtype}.png", + ) + + # Benchmark swizzled layout + print("\n" + "=" * 80) + print("BENCHMARKING SWIZZLED LAYOUT") + print("=" * 80) - # Generate heatmap + cuda_times_swizzled, cute_dsl_times_swizzled = run_benchmark_sweep( + m_values, + k_values, + dtype, + is_sf_swizzled_layout=True, + ) + print_summary_table( + m_values, + k_values, + cuda_times_swizzled, + cute_dsl_times_swizzled, + "Swizzled Layout", + ) create_heatmap( m_values, k_values, - cuda_times, - cute_dsl_times, - f"MXFP4 Quantization Backend Comparison ({args.dtype})", - f"{args.output_prefix}_comparison_{args.dtype}.png", + cuda_times_swizzled, + cute_dsl_times_swizzled, + f"MXFP4 Quantization Speedup (CuTe-DSL vs CUDA) - Swizzled Layout - {args.dtype}", + f"{args.output_prefix}_comparison_swizzled_{args.dtype}.png", ) print("\n" + "=" * 80) diff --git a/benchmarks/bench_mxfp8_quantize_backend_comparison.py b/benchmarks/bench_mxfp8_quantize_backend_comparison.py index 25ae1b6abb..5abd5472d8 100644 --- a/benchmarks/bench_mxfp8_quantize_backend_comparison.py +++ b/benchmarks/bench_mxfp8_quantize_backend_comparison.py @@ -50,6 +50,88 @@ def get_cc(): return major * 10 + minor +def verify_mxfp8_correctness( + m: int, + k: int, + dtype: torch.dtype, + is_sf_swizzled_layout: bool, +) -> Tuple[bool, str, float, float]: + """ + Verify that both backends produce correct outputs. + + Returns: + Tuple of (success, message, quant_match_pct, scale_match_pct) + On failure, quant_match_pct and scale_match_pct are 0.0 + """ + import flashinfer + + torch.manual_seed(42) + x = torch.randn(m, k, device="cuda", dtype=dtype) + + try: + # Test CUDA backend + quant_cuda, scale_cuda = flashinfer.mxfp8_quantize( + x, is_sf_swizzled_layout=is_sf_swizzled_layout, backend="cuda" + ) + + # Test CuTe-DSL backend + quant_cute, scale_cute = flashinfer.mxfp8_quantize( + x, is_sf_swizzled_layout=is_sf_swizzled_layout, backend="cute-dsl" + ) + + # Check shapes match + if quant_cuda.shape != quant_cute.shape: + return ( + False, + f"Quant shape mismatch: CUDA={quant_cuda.shape}, CuTe={quant_cute.shape}", + 0.0, + 0.0, + ) + if scale_cuda.shape != scale_cute.shape: + return ( + False, + f"Scale shape mismatch: CUDA={scale_cuda.shape}, CuTe={scale_cute.shape}", + 0.0, + 0.0, + ) + + # Check backend agreement (exact byte-level match) + quant_cuda_u8 = quant_cuda.view(torch.uint8) + quant_cute_u8 = quant_cute.view(torch.uint8) + quant_match_pct = (quant_cuda_u8 == quant_cute_u8).float().mean().item() * 100 + scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100 + + # FP8 quantization: check roundtrip quality via cosine similarity + dq_cuda = quant_cuda.to(torch.float32).view(1, -1) + dq_cute = quant_cute.to(torch.float32).view(1, -1) + x_f32 = x.cpu().to(torch.float32).view(1, -1) + dq_cuda_cpu = dq_cuda.cpu() + dq_cute_cpu = dq_cute.cpu() + + cos_sim_cuda = torch.nn.functional.cosine_similarity(x_f32, dq_cuda_cpu).item() + cos_sim_cute = torch.nn.functional.cosine_similarity(x_f32, dq_cute_cpu).item() + + if cos_sim_cuda < 0.9: + return ( + False, + f"CUDA roundtrip quality too low: cos_sim={cos_sim_cuda:.4f}", + quant_match_pct, + scale_match_pct, + ) + if cos_sim_cute < 0.9: + return ( + False, + f"CuTe-DSL roundtrip quality too low: cos_sim={cos_sim_cute:.4f}", + quant_match_pct, + scale_match_pct, + ) + + return True, "OK", quant_match_pct, scale_match_pct + + except Exception as e: + return False, f"Exception: {e}", 0.0, 0.0 + + def bench_mxfp8_quantize( m: int, k: 
int, @@ -186,25 +268,45 @@ def run_benchmark_sweep( is_sf_swizzled_layout: bool, ) -> Tuple[Dict[Tuple[int, int], float], Dict[Tuple[int, int], float]]: """ - Run benchmark sweep for both backends. + Run benchmark sweep for both backends with inline correctness verification. Returns: Tuple of (cuda_times, cute_dsl_times) dictionaries """ cuda_times = {} cute_dsl_times = {} + failures = [] total = len(m_values) * len(k_values) current = 0 layout_str = "swizzled" if is_sf_swizzled_layout else "linear" - print(f"\nBenchmarking {layout_str} layout, dtype={dtype}") - print("=" * 60) + print(f"\nBenchmarking MXFP8 {layout_str} layout, dtype={dtype}") + print("=" * 95) + print( + f"{'Progress':<12} {'M':>5} {'K':>5} | " + f"{'--Match--':^14} | " + f"{'-------Timing-------':^28}" + ) + print( + f"{'':12} {'':>5} {'':>5} | " + f"{'quant':>6} {'scale':>6} | " + f"{'CUDA':>8} {'CuTe':>8} {'Speedup':>10}" + ) + print("-" * 95) for m in m_values: for k in k_values: current += 1 - print(f"[{current}/{total}] M={m:5d}, K={k:5d} ... ", end="", flush=True) + + # Verify correctness first + success, verify_msg, quant_match, scale_match = verify_mxfp8_correctness( + m, k, dtype, is_sf_swizzled_layout + ) + if not success: + failures.append((m, k, verify_msg)) + print(f"[{current:3d}/{total}] {m:5d} {k:5d} | FAIL: {verify_msg}") + continue # Benchmark CUDA backend cuda_time = bench_mxfp8_quantize( @@ -224,10 +326,16 @@ def run_benchmark_sweep( f"{speedup:.2f}x" if speedup >= 1 else f"{1 / speedup:.2f}x slower" ) print( - f"CUDA={cuda_time:.3f}ms, CuTe-DSL={cute_dsl_time:.3f}ms, " - f"Speedup={speedup_str}" + f"[{current:3d}/{total}] {m:5d} {k:5d} | " + f"{quant_match:5.1f}% {scale_match:6.1f}% | " + f"{cuda_time:7.3f}ms {cute_dsl_time:7.3f}ms {speedup_str:>10}" ) + if failures: + print(f"\nWARNING: {len(failures)}/{total} configurations failed verification:") + for m, k, msg in failures: + print(f" - M={m}, K={k}: {msg}") + return cuda_times, cute_dsl_times @@ -535,6 +643,13 @@ def main(): # Define sweep ranges (powers of 2 + common transformer hidden dimensions) m_values = [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, 128, 256, 384, @@ -628,7 +743,10 @@ def main(): print("=" * 80) cuda_times_linear, cute_dsl_times_linear = run_benchmark_sweep( - m_values, k_values, dtype, is_sf_swizzled_layout=False + m_values, + k_values, + dtype, + is_sf_swizzled_layout=False, ) print_summary_table( @@ -654,7 +772,10 @@ def main(): print("=" * 80) cuda_times_swizzled, cute_dsl_times_swizzled = run_benchmark_sweep( - m_values, k_values, dtype, is_sf_swizzled_layout=True + m_values, + k_values, + dtype, + is_sf_swizzled_layout=True, ) print_summary_table( diff --git a/benchmarks/bench_nvfp4_quantize_backend_comparison.py b/benchmarks/bench_nvfp4_quantize_backend_comparison.py new file mode 100644 index 0000000000..521320397e --- /dev/null +++ b/benchmarks/bench_nvfp4_quantize_backend_comparison.py @@ -0,0 +1,759 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Benchmark: NVFP4 Quantization Backend Comparison (CUDA vs CuTe-DSL) + +Compares the performance of CUDA and CuTe-DSL backends for NVFP4 quantization +across different M and K dimensions. Supports both swizzled 128x4 and linear +scale factor layouts. Each configuration is verified for correctness before +timing. Generates heatmaps showing relative performance (speedup of CuTe-DSL +over CUDA). + +Can also measure achieved memory bandwidth in TB/s for the CuTe-DSL backend. + +Usage: + # Speedup comparison mode (default, includes correctness verification) + python bench_nvfp4_quantize_backend_comparison.py + + # Bandwidth measurement mode (cute-dsl only) + python bench_nvfp4_quantize_backend_comparison.py --bandwidth + +Requirements: + - Blackwell GPU (SM100+) for CuTe-DSL backend + - matplotlib for visualization +""" + +import argparse +import numpy as np +import torch +from typing import Dict, List, Tuple + +from flashinfer.testing.utils import bench_gpu_time + +# Constants for NVFP4 +NVFP4_SF_VEC_SIZE = 16 +FLOAT4_E2M1_MAX = 6.0 +FLOAT8_E4M3_MAX = float(torch.finfo(torch.float8_e4m3fn).max) + + +def get_cc(): + """Get CUDA compute capability.""" + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def verify_nvfp4_correctness( + m: int, + k: int, + dtype: torch.dtype, + is_sf_swizzled_layout: bool, +) -> Tuple[bool, str, float, float]: + """ + Verify that both backends produce correct outputs via roundtrip test. + + Returns: + Tuple of (success, message, quant_match_pct, scale_match_pct) + On failure, quant_match_pct and scale_match_pct are 0.0 + """ + from flashinfer.quantization.fp4_quantization import ( + e2m1_and_ufp8sf_scale_to_float, + fp4_quantize, + ) + + torch.manual_seed(42) + x = torch.randn(m, k, device="cuda", dtype=dtype) + amax = x.abs().max().to(torch.float32) + global_sf = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax).cuda() + + try: + # Test CUDA backend + quant_cuda, scale_cuda = fp4_quantize( + x, + global_sf, + sf_vec_size=16, + sf_use_ue8m0=False, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend="cuda", + ) + dq_cuda = e2m1_and_ufp8sf_scale_to_float( + quant_cuda.cpu().view(torch.uint8), + scale_cuda.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0]), + 16, + 1, + is_sf_swizzled_layout, + ) + + # Test CuTe-DSL backend + quant_cute, scale_cute = fp4_quantize( + x, + global_sf, + sf_vec_size=16, + sf_use_ue8m0=False, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend="cute-dsl", + ) + dq_cute = e2m1_and_ufp8sf_scale_to_float( + quant_cute.cpu().view(torch.uint8), + scale_cute.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0]), + 16, + 1, + is_sf_swizzled_layout, + ) + + # Check shapes match + if quant_cuda.shape != quant_cute.shape: + return ( + False, + f"Quant shape mismatch: CUDA={quant_cuda.shape}, CuTe={quant_cute.shape}", + 0.0, + 0.0, + ) + if scale_cuda.shape != scale_cute.shape: + return ( + False, + f"Scale shape mismatch: CUDA={scale_cuda.shape}, CuTe={scale_cute.shape}", + 0.0, + 0.0, + ) + + # Check roundtrip quality for both backends (cosine similarity) + x_f32 = x.cpu().to(torch.float32).view(1, -1) + dq_cuda_f32 = dq_cuda.cpu().to(torch.float32).view(1, -1) + dq_cute_f32 = dq_cute.cpu().to(torch.float32).view(1, -1) + + cos_sim_cuda = torch.nn.functional.cosine_similarity(x_f32, dq_cuda_f32).item() + cos_sim_cute = torch.nn.functional.cosine_similarity(x_f32, dq_cute_f32).item() + + # Check backend agreement + quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100 + 
scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100 + + # FP4 quantization should have cosine similarity > 0.9 + if cos_sim_cuda < 0.9: + return ( + False, + f"CUDA roundtrip quality too low: cos_sim={cos_sim_cuda:.4f}", + quant_match_pct, + scale_match_pct, + ) + if cos_sim_cute < 0.9: + return ( + False, + f"CuTe-DSL roundtrip quality too low: cos_sim={cos_sim_cute:.4f}", + quant_match_pct, + scale_match_pct, + ) + + return True, "OK", quant_match_pct, scale_match_pct + + except Exception as e: + return False, f"Exception: {e}", 0.0, 0.0 + + +def bench_nvfp4_quantize( + m: int, + k: int, + dtype: torch.dtype, + is_sf_swizzled_layout: bool, + backend: str, +) -> float: + """ + Benchmark NVFP4 quantization for a specific configuration. + + Returns: + Median execution time in milliseconds + """ + from flashinfer.quantization.fp4_quantization import fp4_quantize + + x = torch.randn(m, k, device="cuda", dtype=dtype) + amax = x.abs().max().to(torch.float32) + global_sf = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax).cuda() + + # Warmup + _ = fp4_quantize( + x, + global_sf, + sf_vec_size=16, + sf_use_ue8m0=False, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend=backend, + ) + + def run_kernel(): + fp4_quantize( + x, + global_sf, + sf_vec_size=16, + sf_use_ue8m0=False, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend=backend, + ) + + times = bench_gpu_time( + fn=run_kernel, + enable_cupti=True, + dry_run_iters=5, + repeat_iters=30, + cold_l2_cache=True, + use_cuda_graph=False, + ) + + return np.median(times) + + +def compute_bandwidth_tb_per_sec( + m: int, k: int, dtype: torch.dtype, time_ms: float +) -> float: + """ + Compute achieved memory bandwidth in TB/s. + + Memory bandwidth calculation for nvfp4_quantize: + - Read: input tensor (2 bytes per element for fp16/bf16) + - Write: quantized tensor (0.5 bytes per element, since fp4 = 4 bits) + - Write: scale factors (1 byte per scale factor) + """ + input_dtype_bytes = 2 # fp16 or bf16 + + num_elements = m * k + num_scale_factors = num_elements // NVFP4_SF_VEC_SIZE + + problem_bytes = ( + num_elements * input_dtype_bytes # input read + + num_elements // 2 # fp4 output write + + num_scale_factors * 1 # scale factors write + ) + + tb_per_sec = problem_bytes / (1e9 * time_ms) + return tb_per_sec + + +def run_bandwidth_sweep( + m_values: List[int], + k_values: List[int], + dtype: torch.dtype, + is_sf_swizzled_layout: bool, +) -> Dict[Tuple[int, int], float]: + """Run bandwidth benchmark sweep for CuTe-DSL backend only.""" + bandwidth_results = {} + + total = len(m_values) * len(k_values) + current = 0 + + layout_str = "swizzled" if is_sf_swizzled_layout else "linear" + print( + f"\nBenchmarking NVFP4 {layout_str} layout, dtype={dtype} (CuTe-DSL bandwidth)" + ) + print("=" * 60) + + for m in m_values: + for k in k_values: + current += 1 + print(f"[{current}/{total}] M={m:5d}, K={k:5d} ... 
", end="", flush=True) + + time_ms = bench_nvfp4_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl" + ) + + bandwidth = compute_bandwidth_tb_per_sec(m, k, dtype, time_ms) + bandwidth_results[(m, k)] = bandwidth + + print(f"time={time_ms:.3f}ms, bandwidth={bandwidth:.2f} TB/s") + + return bandwidth_results + + +def run_benchmark_sweep( + m_values: List[int], + k_values: List[int], + dtype: torch.dtype, + is_sf_swizzled_layout: bool, +) -> Tuple[Dict[Tuple[int, int], float], Dict[Tuple[int, int], float]]: + """Run benchmark sweep for both backends with inline correctness verification.""" + cuda_times = {} + cute_dsl_times = {} + failures = [] + + total = len(m_values) * len(k_values) + current = 0 + + layout_str = "swizzled" if is_sf_swizzled_layout else "linear" + print(f"\nBenchmarking NVFP4 {layout_str} layout, dtype={dtype}") + print("=" * 95) + print( + f"{'Progress':<12} {'M':>5} {'K':>5} | " + f"{'--Match--':^14} | " + f"{'-------Timing-------':^28}" + ) + print( + f"{'':12} {'':>5} {'':>5} | " + f"{'quant':>6} {'scale':>6} | " + f"{'CUDA':>8} {'CuTe':>8} {'Speedup':>10}" + ) + print("-" * 95) + + for m in m_values: + for k in k_values: + current += 1 + + # Verify correctness first + success, verify_msg, quant_match, scale_match = verify_nvfp4_correctness( + m, k, dtype, is_sf_swizzled_layout + ) + if not success: + failures.append((m, k, verify_msg)) + print(f"[{current:3d}/{total}] {m:5d} {k:5d} | FAIL: {verify_msg}") + continue + + # Benchmark CUDA backend + cuda_time = bench_nvfp4_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cuda" + ) + cuda_times[(m, k)] = cuda_time + + # Benchmark CuTe-DSL backend + cute_dsl_time = bench_nvfp4_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl" + ) + cute_dsl_times[(m, k)] = cute_dsl_time + + # Compute speedup + speedup = cuda_time / cute_dsl_time + speedup_str = ( + f"{speedup:.2f}x" if speedup >= 1 else f"{1 / speedup:.2f}x slower" + ) + print( + f"[{current:3d}/{total}] {m:5d} {k:5d} | " + f"{quant_match:5.1f}% {scale_match:6.1f}% | " + f"{cuda_time:7.3f}ms {cute_dsl_time:7.3f}ms {speedup_str:>10}" + ) + + if failures: + print(f"\nWARNING: {len(failures)}/{total} configurations failed verification:") + for m, k, msg in failures: + print(f" - M={m}, K={k}: {msg}") + + return cuda_times, cute_dsl_times + + +def create_heatmap( + m_values: List[int], + k_values: List[int], + cuda_times: Dict[Tuple[int, int], float], + cute_dsl_times: Dict[Tuple[int, int], float], + title: str, + output_file: str, +): + """Create a heatmap showing relative performance (CuTe-DSL speedup over CUDA).""" + try: + import matplotlib.pyplot as plt + import matplotlib.colors as mcolors + except ImportError: + print("matplotlib not installed, skipping heatmap generation") + return + + speedup_matrix = np.zeros((len(m_values), len(k_values))) + + for i, m in enumerate(m_values): + for j, k in enumerate(k_values): + cuda_time = cuda_times.get((m, k), float("nan")) + cute_dsl_time = cute_dsl_times.get((m, k), float("nan")) + if cute_dsl_time > 0: + speedup_matrix[i, j] = cuda_time / cute_dsl_time + else: + speedup_matrix[i, j] = float("nan") + + fig, ax = plt.subplots(figsize=(12, 10)) + + vmin = min(0.5, np.nanmin(speedup_matrix)) + vmax = max(2.0, np.nanmax(speedup_matrix)) + norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=1.0, vmax=vmax) + + im = ax.imshow(speedup_matrix, cmap="RdYlGn", norm=norm, aspect="auto") + + cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8) + cbar.ax.set_ylabel("Speedup (CUDA time / CuTe-DSL time)", 
rotation=-90, va="bottom") + + ax.set_xticks(np.arange(len(k_values))) + ax.set_yticks(np.arange(len(m_values))) + ax.set_xticklabels([str(k) for k in k_values]) + ax.set_yticklabels([str(m) for m in m_values]) + + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + for i in range(len(m_values)): + for j in range(len(k_values)): + value = speedup_matrix[i, j] + if not np.isnan(value): + text_color = "white" if value < 0.7 or value > 1.5 else "black" + ax.text( + j, + i, + f"{value:.2f}", + ha="center", + va="center", + color=text_color, + fontsize=8, + ) + + ax.set_xlabel("K (columns)") + ax.set_ylabel("M (rows)") + ax.set_title(title + "\n(>1.0 = CuTe-DSL faster, <1.0 = CUDA faster)") + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches="tight") + print(f"Saved heatmap to {output_file}") + plt.close() + + +def create_bandwidth_heatmap( + m_values: List[int], + k_values: List[int], + bandwidth_results: Dict[Tuple[int, int], float], + title: str, + output_file: str, +): + """Create a heatmap showing achieved memory bandwidth in TB/s.""" + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed, skipping heatmap generation") + return + + bandwidth_matrix = np.zeros((len(m_values), len(k_values))) + + for i, m in enumerate(m_values): + for j, k in enumerate(k_values): + bandwidth_matrix[i, j] = bandwidth_results.get((m, k), float("nan")) + + fig, ax = plt.subplots(figsize=(12, 10)) + + vmin = np.nanmin(bandwidth_matrix) + vmax = np.nanmax(bandwidth_matrix) + + im = ax.imshow(bandwidth_matrix, cmap="YlGn", vmin=vmin, vmax=vmax, aspect="auto") + + cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8) + cbar.ax.set_ylabel("Achieved Bandwidth (TB/s)", rotation=-90, va="bottom") + + ax.set_xticks(np.arange(len(k_values))) + ax.set_yticks(np.arange(len(m_values))) + ax.set_xticklabels([str(k) for k in k_values]) + ax.set_yticklabels([str(m) for m in m_values]) + + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + for i in range(len(m_values)): + for j in range(len(k_values)): + value = bandwidth_matrix[i, j] + if not np.isnan(value): + normalized = (value - vmin) / (vmax - vmin) if vmax > vmin else 0.5 + text_color = "white" if normalized > 0.6 else "black" + ax.text( + j, + i, + f"{value:.1f}", + ha="center", + va="center", + color=text_color, + fontsize=8, + ) + + ax.set_xlabel("K (columns)") + ax.set_ylabel("M (rows)") + ax.set_title(title) + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches="tight") + print(f"Saved heatmap to {output_file}") + plt.close() + + +def print_bandwidth_summary_table( + m_values: List[int], + k_values: List[int], + bandwidth_results: Dict[Tuple[int, int], float], + layout_name: str = "Swizzled Layout", +): + """Print a summary table of bandwidth results.""" + print(f"\n{'=' * 80}") + print(f"Bandwidth Summary: NVFP4 {layout_name} (TB/s)") + print(f"{'=' * 80}") + + header = "M\\K".ljust(8) + for k in k_values: + header += f"{k:>8}" + print(header) + print("-" * (8 + 8 * len(k_values))) + + for m in m_values: + row = f"{m:<8}" + for k in k_values: + bandwidth = bandwidth_results.get((m, k), float("nan")) + if not np.isnan(bandwidth): + row += f"{bandwidth:>8.1f}" + else: + row += f"{'N/A':>8}" + print(row) + + bandwidths = [b for b in bandwidth_results.values() if not np.isnan(b)] + if bandwidths: + print("\nStatistics:") + print(f" Mean bandwidth: {np.mean(bandwidths):.2f} TB/s") + print(f" Min bandwidth: {min(bandwidths):.2f} TB/s") 
+ print(f" Max bandwidth: {max(bandwidths):.2f} TB/s") + print(f" Std deviation: {np.std(bandwidths):.2f} TB/s") + + +def print_summary_table( + m_values: List[int], + k_values: List[int], + cuda_times: Dict[Tuple[int, int], float], + cute_dsl_times: Dict[Tuple[int, int], float], + layout_name: str = "Swizzled Layout", +): + """Print a summary table of results.""" + print(f"\n{'=' * 80}") + print(f"Summary: NVFP4 {layout_name} (Speedup: CUDA time / CuTe-DSL time)") + print(f"{'=' * 80}") + + header = "M\\K".ljust(8) + for k in k_values: + header += f"{k:>8}" + print(header) + print("-" * (8 + 8 * len(k_values))) + + for m in m_values: + row = f"{m:<8}" + for k in k_values: + cuda_time = cuda_times.get((m, k), float("nan")) + cute_dsl_time = cute_dsl_times.get((m, k), float("nan")) + if cute_dsl_time > 0 and not np.isnan(cuda_time): + speedup = cuda_time / cute_dsl_time + row += f"{speedup:>8.2f}" + else: + row += f"{'N/A':>8}" + print(row) + + speedups = [] + for m in m_values: + for k in k_values: + cuda_time = cuda_times.get((m, k)) + cute_dsl_time = cute_dsl_times.get((m, k)) + if cuda_time and cute_dsl_time and cute_dsl_time > 0: + speedups.append(cuda_time / cute_dsl_time) + + if speedups: + print("\nStatistics:") + print(f" Geometric mean speedup: {np.exp(np.mean(np.log(speedups))):.2f}x") + print(f" Min speedup: {min(speedups):.2f}x") + print(f" Max speedup: {max(speedups):.2f}x") + print( + f" Cases where CuTe-DSL faster: {sum(1 for s in speedups if s > 1)}/{len(speedups)}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark NVFP4 quantization backends" + ) + parser.add_argument( + "--bandwidth", + action="store_true", + help="Run bandwidth benchmark (CuTe-DSL only) instead of comparison", + ) + parser.add_argument( + "--dtype", + choices=["float16", "bfloat16"], + default="bfloat16", + help="Input data type (default: bfloat16)", + ) + parser.add_argument( + "--output-prefix", + type=str, + default="nvfp4_quantize_backend", + help="Output file prefix for heatmaps", + ) + args = parser.parse_args() + + # Check compute capability + cc = get_cc() + print(f"GPU Compute Capability: SM{cc}") + + if cc < 100: + print("ERROR: CuTe-DSL backend requires Blackwell GPU (SM100+)") + return + + # Get dtype + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + print(f"Data type: {dtype}") + + # Define sweep ranges + # K constraints: + # - Linear layout: K must be a multiple of 16 (NVFP4_SF_VEC_SIZE) + # - Swizzled layout: K must be a multiple of 64 because K/16 (SF blocks + # per row) must be a multiple of 4 for the swizzled padding + # We use K values that satisfy both constraints (multiples of 64) + m_values = [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 384, + 512, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, + 6144, + 8192, + 12288, + 16384, + 32768, + ] + k_values = [ + 128, + 256, + 384, + 512, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, + 5120, + 6144, + 8192, + 12288, + 16384, + ] + + print(f"\nM values: {m_values}") + print(f"K values: {k_values}") + + if args.bandwidth: + print("\n" + "=" * 80) + print("BANDWIDTH MEASUREMENT MODE (CuTe-DSL only)") + print("=" * 80) + + # Benchmark linear layout + print("\n" + "=" * 80) + print("BENCHMARKING LINEAR (NON-SWIZZLED) LAYOUT - BANDWIDTH") + print("=" * 80) + + bandwidth_linear = run_bandwidth_sweep( + m_values, k_values, dtype, is_sf_swizzled_layout=False + ) + print_bandwidth_summary_table( + m_values, k_values, bandwidth_linear, "Linear Layout" + ) + create_bandwidth_heatmap( 
+ m_values, + k_values, + bandwidth_linear, + f"NVFP4 Quantization Bandwidth (CuTe-DSL) - Linear Layout - {args.dtype}", + f"{args.output_prefix}_bandwidth_linear_{args.dtype}.png", + ) + + # Benchmark swizzled layout + print("\n" + "=" * 80) + print("BENCHMARKING SWIZZLED LAYOUT - BANDWIDTH") + print("=" * 80) + + bandwidth_swizzled = run_bandwidth_sweep( + m_values, k_values, dtype, is_sf_swizzled_layout=True + ) + print_bandwidth_summary_table( + m_values, k_values, bandwidth_swizzled, "Swizzled Layout" + ) + create_bandwidth_heatmap( + m_values, + k_values, + bandwidth_swizzled, + f"NVFP4 Quantization Bandwidth (CuTe-DSL) - Swizzled Layout - {args.dtype}", + f"{args.output_prefix}_bandwidth_swizzled_{args.dtype}.png", + ) + else: + # Speedup comparison mode: CUDA vs CuTe-DSL + # Benchmark linear layout + print("\n" + "=" * 80) + print("BENCHMARKING LINEAR (NON-SWIZZLED) LAYOUT") + print("=" * 80) + + cuda_times_linear, cute_dsl_times_linear = run_benchmark_sweep( + m_values, + k_values, + dtype, + is_sf_swizzled_layout=False, + ) + print_summary_table( + m_values, + k_values, + cuda_times_linear, + cute_dsl_times_linear, + "Linear Layout", + ) + create_heatmap( + m_values, + k_values, + cuda_times_linear, + cute_dsl_times_linear, + f"NVFP4 Quantization Speedup (CuTe-DSL vs CUDA) - Linear Layout - {args.dtype}", + f"{args.output_prefix}_comparison_linear_{args.dtype}.png", + ) + + # Benchmark swizzled layout + print("\n" + "=" * 80) + print("BENCHMARKING SWIZZLED LAYOUT") + print("=" * 80) + + cuda_times_swizzled, cute_dsl_times_swizzled = run_benchmark_sweep( + m_values, + k_values, + dtype, + is_sf_swizzled_layout=True, + ) + print_summary_table( + m_values, + k_values, + cuda_times_swizzled, + cute_dsl_times_swizzled, + "Swizzled Layout", + ) + create_heatmap( + m_values, + k_values, + cuda_times_swizzled, + cute_dsl_times_swizzled, + f"NVFP4 Quantization Speedup (CuTe-DSL vs CUDA) - Swizzled Layout - {args.dtype}", + f"{args.output_prefix}_comparison_swizzled_{args.dtype}.png", + ) + + print("\n" + "=" * 80) + print("BENCHMARK COMPLETE") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/flashinfer/cute_dsl/fp4_common.py b/flashinfer/cute_dsl/fp4_common.py index 150658822c..c381324927 100644 --- a/flashinfer/cute_dsl/fp4_common.py +++ b/flashinfer/cute_dsl/fp4_common.py @@ -196,6 +196,23 @@ def st_global_u64(base_ptr: Int64, value: Uint64, *, loc=None, ip=None): ) +@dsl_user_op +def st_global_u32(base_ptr: Int64, value: Uint32, *, loc=None, ip=None): + """Store 32 bits to global memory.""" + llvm.inline_asm( + None, + [ + Int64(base_ptr).ir_value(loc=loc, ip=ip), + Uint32(value).ir_value(loc=loc, ip=ip), + ], + "st.global.u32 [$0], $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + @dsl_user_op def get_ptr_as_int64(tensor: cute.Tensor, offset: Int32, *, loc=None, ip=None) -> Int64: """Get the memory address of tensor[offset] as Int64. 
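
Editor's note (not part of the patch): the MXFP4 kernel changes below repeatedly refer to the per-block UE8M0 scale math — the removed docstring in mxfp4_quantize.py spells it out as "ceil(log2(max / 6.0)) + 127", and the new 1T/SF and 4T/SF paths compute `float_to_ue8m0_fast(block_max * rcp(6.0))` followed by an inverse-scale multiply and E2M1 conversion. The sketch that follows is a minimal host-side reference of that per-block scale computation only, assuming standard UE8M0 semantics (power-of-two scale, exponent bias 127) and simplified zero-block/clamp handling; `ref_mxfp4_scale_ue8m0` is a hypothetical helper for illustration, not a FlashInfer API and not code from this patch.

# Hypothetical host-side reference; edge-case handling is an assumption, not taken from the kernels.
import torch

MXFP4_SF_VEC_SIZE = 32   # elements per scale-factor block (matches the patch)
E2M1_MAX = 6.0           # largest magnitude representable in FP4 E2M1

def ref_mxfp4_scale_ue8m0(x: torch.Tensor) -> torch.Tensor:
    """Per-32-element UE8M0 scales for an [M, K] tensor with K % 32 == 0."""
    m, k = x.shape
    blocks = x.float().abs().reshape(m, k // MXFP4_SF_VEC_SIZE, MXFP4_SF_VEC_SIZE)
    amax = blocks.amax(dim=-1)
    # Scale exponent: smallest power of two p such that amax / p <= 6.0,
    # i.e. exponent = ceil(log2(amax / 6.0)); zero blocks map to exponent 0 here (assumed).
    exp = torch.ceil(torch.log2(amax / E2M1_MAX))
    exp = torch.where(amax > 0, exp, torch.zeros_like(exp))
    # Biased UE8M0 byte: exponent + 127, clamped to the uint8 range.
    return (exp + 127.0).clamp(0, 255).to(torch.uint8)

# Example: ref_mxfp4_scale_ue8m0(torch.randn(4, 128)) has shape (4, 4) — one byte per 32-element block.
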
diff --git a/flashinfer/quantization/kernels/__init__.py b/flashinfer/quantization/kernels/__init__.py index 5df0f078d3..0f30455b43 100644 --- a/flashinfer/quantization/kernels/__init__.py +++ b/flashinfer/quantization/kernels/__init__.py @@ -27,7 +27,8 @@ """ from .mxfp4_quantize import ( - MXFP4QuantizeKernel, + MXFP4QuantizeLinearKernel, + MXFP4QuantizeSwizzledKernel, mxfp4_quantize_cute_dsl, ) from .mxfp8_quantize import ( @@ -41,7 +42,8 @@ ) __all__ = [ - "MXFP4QuantizeKernel", + "MXFP4QuantizeLinearKernel", + "MXFP4QuantizeSwizzledKernel", "mxfp4_quantize_cute_dsl", "MXFP8QuantizeLinearKernel", "MXFP8QuantizeSwizzledKernel", diff --git a/flashinfer/quantization/kernels/mxfp4_quantize.py b/flashinfer/quantization/kernels/mxfp4_quantize.py index 88fa3837c8..6b0c4e8a75 100644 --- a/flashinfer/quantization/kernels/mxfp4_quantize.py +++ b/flashinfer/quantization/kernels/mxfp4_quantize.py @@ -19,6 +19,12 @@ MXFP4 quantization kernel using CuTe-DSL. Supports multiple scale factor layouts: swizzled 128x4 and linear. +Dual-path optimization following the MXFP8 pattern: +- Linear layout: flat SF-block iteration for 100% thread utilization + - Adaptive 1T/4T per SF block dispatch based on GPU SM count: + - High-SM GPUs (>80 SMs): 1T/SF — sufficient outstanding memory requests + - Low-SM GPUs (<=80 SMs): 4T/SF — coalesced 128-bit loads for bandwidth +- Swizzled layout: row-based iteration with multi-row / column-loop paths """ import functools @@ -27,16 +33,33 @@ import cutlass import cutlass.cute as cute import torch -from cutlass import Int32, Uint8 +from cutlass import Float32, Int32, Uint8 from ...api_logging import flashinfer_api -from ...cute_dsl.fp4_common import get_ptr_as_int64, st_global_u64 +from ...cute_dsl.fp4_common import ( + get_ptr_as_int64, + ld_global_v4_u32, + rcp_approx_ftz, + st_global_u32, + st_global_u64, +) from ...cute_dsl.utils import get_num_sm from ..quantization_cute_dsl_utils import ( # MXFP4 Constants MXFP4_SF_VEC_SIZE, + WARP_SIZE, ROW_TILE_SIZE, # Low-level intrinsics + hmax_reduce_to_f32, + bfloat2_hmax_reduce_to_f32, + half2_max_abs_4, + bfloat2_max_abs_4, + reduce_max_4threads, + float_to_ue8m0_fast, + ue8m0_to_inv_scale_fast, + half2_to_float2_scaled, + bfloat2_to_float2_scaled, + cvt_e2m1x8_f32, compute_sf_index_swizzled_128x4_gpu, compute_sf_index_linear_gpu, # High-level helpers (MXFP4) @@ -54,118 +77,343 @@ # Maximum threads per block _MAX_THREADS_PER_BLOCK = 1024 -# Thread configuration bounds -_MIN_THREADS = 128 # Minimum for reasonable occupancy -_MAX_THREADS = 512 # Maximum to avoid register pressure -_DEFAULT_THREADS = 256 # Default thread count +# Thread count bounds for swizzled kernel +_MIN_THREADS = 128 +_MAX_THREADS = 512 + +# Linear kernel: fixed 16 warps (512 threads), 1 SF block per thread +_LINEAR_WARPS_PER_BLOCK = 16 +_LINEAR_SF_BLOCKS_PER_TB = _LINEAR_WARPS_PER_BLOCK * WARP_SIZE # 512 + +# 4T/SF configuration for low-SM-count GPUs +_4T_THREADS_PER_SF = 4 +_4T_SF_PER_WARP = WARP_SIZE // _4T_THREADS_PER_SF # 8 +_4T_SF_BLOCKS_PER_TB = _LINEAR_WARPS_PER_BLOCK * _4T_SF_PER_WARP # 128 +# SM count threshold: use 4T/SF when num_sm <= this value +_LOW_SM_THRESHOLD = 80 -def _compute_optimal_threads_for_k(K: int) -> int: + +def _compute_optimal_threads(K: int, threads_per_sf: int = 1) -> int: """ - Compute optimal thread count for 100% thread utilization. + Compute optimal thread count for 100% utilization in the swizzled kernel. 
+ + For MXFP4: + threads_per_row = (K / 32) * threads_per_sf - For MXFP4, each thread processes one SF block (32 elements). - threads_per_row = K / 32 = num_sf_blocks_per_row + We want num_threads to be a multiple of threads_per_row so that + rows_per_block = num_threads / threads_per_row is an integer. - For 100% utilization when processing multiple rows: - threads_per_block % threads_per_row == 0 + We prefer LARGER thread counts (up to _MAX_THREADS) for better occupancy. - We prefer LARGER thread counts (up to _MAX_THREADS) for better occupancy, - while maintaining 100% thread utilization. + If threads_per_row > _MAX_THREADS, we use _MAX_THREADS with a column loop. Args: K: Number of columns (must be divisible by 32) + threads_per_sf: Threads per SF block (1 for 1T/SF, 4 for 4T/SF) Returns: Optimal number of threads per block """ - threads_per_row = K // MXFP4_SF_VEC_SIZE # K / 32 - - # For 100% utilization: threads_per_block % threads_per_row == 0 - # threads_per_block must be a multiple of threads_per_row + threads_per_row = (K // MXFP4_SF_VEC_SIZE) * threads_per_sf - if threads_per_row >= _MAX_THREADS: - # Large K: use max threads, will need column loop + if threads_per_row > _MAX_THREADS: + # Column loop mode: use maximum threads return _MAX_THREADS - # threads_per_block should be a multiple of threads_per_row - if threads_per_row <= _MAX_THREADS: - # Find largest multiple of threads_per_row <= _MAX_THREADS - threads = (_MAX_THREADS // threads_per_row) * threads_per_row - if threads >= _MIN_THREADS: - return threads - # If largest multiple is below _MIN_THREADS, use the smallest valid one - threads = threads_per_row - while threads < _MIN_THREADS: - threads += threads_per_row - if threads <= _MAX_THREADS: - return threads + # Find largest multiple of threads_per_row in [_MIN_THREADS, _MAX_THREADS] + largest = (_MAX_THREADS // threads_per_row) * threads_per_row + if largest >= _MIN_THREADS: + return largest - # Fallback to default - return _DEFAULT_THREADS + # If largest multiple is below _MIN_THREADS, use smallest valid one + candidate = threads_per_row + while candidate < _MIN_THREADS: + candidate += threads_per_row + if candidate <= _MAX_THREADS: + return candidate - -def _compute_swizzled_layout_sf_size( - total_row: int, total_column: int, row_size: int = 128 -) -> int: - """Compute size of swizzled scale factor buffer.""" - padded_row = (total_row + row_size - 1) // row_size * row_size - padded_column = (total_column + 3) // 4 * 4 - return padded_row * padded_column + # Fallback (shouldn't happen for reasonable K) + return _MAX_THREADS # ============================================================================= -# CuTe-DSL Kernel Class for MXFP4 Swizzled Layout +# CuTe-DSL Kernel Class for Linear Layout — Flat SF-Block Iteration # ============================================================================= -class MXFP4QuantizeKernel: +class MXFP4QuantizeLinearKernel: """ - MXFP4 quantization kernel supporting multiple scale factor layouts. + MXFP4 quantization kernel optimized for LINEAR layout. - Supported layouts: - - 128x4 (swizzled): Optimized for GEMM with large tileN - - linear: Simple row-major layout, no swizzling + Uses flat SF-block iteration for efficient memory access. Row and + column indices are derived from the flat SF index via integer division. 
- Key features: - - UE8M0 scale factors (unsigned 8-bit exponent-only) - - sf_vec_size=32 (each thread processes 32 elements) - - Multi-row processing when K is small, column loop when K is large - - Row-based iteration with grid-stride loop - - Padding row fast path for zeroing scale factors + No padding passes are needed since for linear layout: + - padded_m == m (no row padding) + - padded_sf_cols == num_sf_blocks_per_row (no column padding) - This kernel is M-agnostic: compiled once per (K, dtype, sf_layout, pdl) - combination. M-dependent values (M, padded_M) are passed at runtime. + Adaptive thread configuration (compile-time selected via use_4t_per_sf): + - 1T/SF (high-SM GPUs): 1 thread per SF block, 32 elements per thread, + no shuffle reduction, 512 SF blocks per TB + - 4T/SF (low-SM GPUs): 4 threads per SF block, 8 elements per thread, + 2 shuffle reductions, 128 SF blocks per TB — better memory coalescing + when fewer SMs can't generate enough outstanding memory requests + + This kernel is M-agnostic: compiled once per (K, dtype, pdl, use_4t) + combination. """ + WARPS_PER_BLOCK = _LINEAR_WARPS_PER_BLOCK + def __init__( self, dtype: cutlass.Numeric, K: int, - sf_layout: int = SF_LAYOUT_128x4, enable_pdl: bool = False, + use_4t_per_sf: bool = False, ): self.dtype = dtype self.K = K self.is_bfloat16 = dtype == cutlass.BFloat16 self.enable_pdl = enable_pdl - self.sf_layout = sf_layout - self.sf_is_128x4 = sf_layout == SF_LAYOUT_128x4 + self.use_4t_per_sf = use_4t_per_sf + + if use_4t_per_sf: + self.SF_BLOCKS_PER_TB = _4T_SF_BLOCKS_PER_TB # 128 + else: + self.SF_BLOCKS_PER_TB = _LINEAR_SF_BLOCKS_PER_TB # 512 assert K % MXFP4_SF_VEC_SIZE == 0 self.num_sf_blocks_per_row = K // MXFP4_SF_VEC_SIZE - if sf_layout == SF_LAYOUT_LINEAR: - self.padded_sf_cols = self.num_sf_blocks_per_row - self.row_tile_size = 1 + @cute.jit + def __call__( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + total_sf_blocks: Int32, + num_blocks: Int32, + stream, + ): + threads_per_block = self.WARPS_PER_BLOCK * WARP_SIZE + + self.kernel(mInput, mOutput, mScales, M, total_sf_blocks).launch( + grid=[num_blocks, 1, 1], + block=[threads_per_block, 1, 1], + max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1], + min_blocks_per_mp=_BLOCKS_PER_SM, + stream=stream, + use_pdl=self.enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + total_sf_blocks: Int32, + ): + """ + MXFP4 quantization with flat SF-block iteration for linear layout. + + Compile-time branching selects 1T/SF or 4T/SF path: + - 1T/SF: Each thread handles one SF block (32 elements). + - 4T/SF: 4 threads cooperate per SF block (8 elements each), + with coalesced 128-bit loads and 32-bit stores. + Row and column indices are derived from the flat SF index. 
+ """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + grid_dim_x, _, _ = cute.arch.grid_dim() + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + + num_sf_blocks_per_row = self.num_sf_blocks_per_row + sf_blocks_per_tb = self.SF_BLOCKS_PER_TB + + if cutlass.const_expr(self.use_4t_per_sf): + # ============================================================= + # 4T/SF path: 4 threads cooperate per SF block + # Each thread loads 8 elements (1x128-bit) -- coalesced + # 2-shuffle reduction for cross-thread max + # Each thread converts 8 elements and stores 4 bytes (u32) + # ============================================================= + warp_idx = tidx // WARP_SIZE + lane_idx = tidx % WARP_SIZE + sf_in_warp = lane_idx // Int32(_4T_THREADS_PER_SF) # 0..7 + thread_in_sf = lane_idx % Int32(_4T_THREADS_PER_SF) # 0..3 + + sf_per_warp = Int32(_4T_SF_PER_WARP) + sf_idx = bidx * sf_blocks_per_tb + warp_idx * sf_per_warp + sf_in_warp + stride = grid_dim_x * sf_blocks_per_tb + + while sf_idx < total_sf_blocks: + row_idx = sf_idx // num_sf_blocks_per_row + col_idx = sf_idx % num_sf_blocks_per_row + + # Each thread loads 8 elements (1x128-bit load) + # thread_in_sf 0: [0..7], 1: [8..15], 2: [16..23], 3: [24..31] + # Adjacent threads load adjacent 16-byte chunks → COALESCED + elem_base = col_idx * MXFP4_SF_VEC_SIZE + thread_in_sf * Int32(8) + row_input = mInput[row_idx, None] + h0, h1, h2, h3 = ld_global_v4_u32( + get_ptr_as_int64(row_input, elem_base) + ) + + # Max-abs of 8 elements (within this thread) + if cutlass.const_expr(self.is_bfloat16): + local_max_h2 = bfloat2_max_abs_4(h0, h1, h2, h3) + local_max = bfloat2_hmax_reduce_to_f32(local_max_h2) + else: + local_max_h2 = half2_max_abs_4(h0, h1, h2, h3) + local_max = hmax_reduce_to_f32(local_max_h2) + + # 4-thread reduction (2 shuffles) + global_max = reduce_max_4threads(local_max) + + # UE8M0 scale + ue = float_to_ue8m0_fast(global_max * rcp_approx_ftz(Float32(6.0))) + inv = ue8m0_to_inv_scale_fast(ue) + + # Scale and convert 8 elements to E2M1 + if cutlass.const_expr(self.is_bfloat16): + s0, s1 = bfloat2_to_float2_scaled(h0, inv) + s2, s3 = bfloat2_to_float2_scaled(h1, inv) + s4, s5 = bfloat2_to_float2_scaled(h2, inv) + s6, s7 = bfloat2_to_float2_scaled(h3, inv) + else: + s0, s1 = half2_to_float2_scaled(h0, inv) + s2, s3 = half2_to_float2_scaled(h1, inv) + s4, s5 = half2_to_float2_scaled(h2, inv) + s6, s7 = half2_to_float2_scaled(h3, inv) + packed_u32 = cvt_e2m1x8_f32(s0, s1, s2, s3, s4, s5, s6, s7) + + # Each thread stores 4 bytes at its position (coalesced) + row_output = mOutput[row_idx, None] + out_base = col_idx * (MXFP4_SF_VEC_SIZE // 2) + thread_in_sf * Int32(4) + st_global_u32(get_ptr_as_int64(row_output, out_base), packed_u32) + + # SF: only thread 0 of the 4-thread group writes the scale + if thread_in_sf == Int32(0): + sf_offset = compute_sf_index_linear_gpu( + row_idx, col_idx, num_sf_blocks_per_row + ) + mScales[sf_offset] = ue.to(Uint8) + + sf_idx = sf_idx + stride else: - self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 - self.row_tile_size = ROW_TILE_SIZE # 128 + # ============================================================= + # 1T/SF path: 1 thread per SF block (original path) + # Each thread loads 32 elements (4x128-bit) independently + # ============================================================= + stride = grid_dim_x * sf_blocks_per_tb + + # Flat SF-block iteration + sf_idx = bidx * sf_blocks_per_tb + tidx + + while sf_idx < total_sf_blocks: + row_idx = sf_idx 
// num_sf_blocks_per_row + col_idx = sf_idx % num_sf_blocks_per_row + + elem_base = col_idx * MXFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + + # Process block: load, compute scale, convert to E2M1 + if cutlass.const_expr(self.is_bfloat16): + ( + _, + scale_ue8m0, + packed64_0, + packed64_1, + ) = process_mxfp4_block_bfloat(row_input, elem_base) + else: + ( + _, + scale_ue8m0, + packed64_0, + packed64_1, + ) = process_mxfp4_block_half(row_input, elem_base) + + # Write scale factor using linear indexing + sf_offset = compute_sf_index_linear_gpu( + row_idx, col_idx, num_sf_blocks_per_row + ) + mScales[sf_offset] = scale_ue8m0 + + # Store 16 bytes (32 FP4 values = 2 x st.global.u64) + row_output = mOutput[row_idx, None] + out_base = col_idx * (MXFP4_SF_VEC_SIZE // 2) + out_ptr0 = get_ptr_as_int64(row_output, out_base) + out_ptr1 = get_ptr_as_int64(row_output, out_base + Int32(8)) + st_global_u64(out_ptr0, packed64_0) + st_global_u64(out_ptr1, packed64_1) + + sf_idx = sf_idx + stride + + # PDL: Signal that dependent kernels can start early + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# CuTe-DSL Kernel Class for Swizzled Layout — Row-Based Iteration +# ============================================================================= + + +class MXFP4QuantizeSwizzledKernel: + """ + MXFP4 quantization kernel optimized for SWIZZLED (128x4) layout. + + Key optimizations: + - Multi-row processing: threads process multiple rows per block when K is small + - Row-based iteration with grid-stride loop + - Padding row fast path - only zero out scale factors + + Thread utilization optimization: + - Dynamic thread count based on K for 100% thread utilization + - For small K: Multiple rows processed per block iteration + - For large K: Single row with column loop - self.num_threads = _compute_optimal_threads_for_k(K) + Adaptive thread configuration (compile-time selected via use_4t_per_sf): + - 1T/SF (high-SM GPUs): 1 thread per SF block, 4 loads per thread + - 4T/SF (low-SM GPUs): 4 threads per SF block, 1 load per thread, + coalesced access for better memory utilization on few-SM devices - self.threads_per_row = self.num_sf_blocks_per_row + This kernel is M-agnostic: compiled once per (K, dtype, pdl, use_4t) + combination. M-dependent values (M, padded_M) are passed at runtime. 
+ """ + def __init__( + self, + dtype: cutlass.Numeric, + K: int, + enable_pdl: bool = False, + use_4t_per_sf: bool = False, + ): + self.dtype = dtype + self.K = K + self.is_bfloat16 = dtype == cutlass.BFloat16 + self.enable_pdl = enable_pdl + self.use_4t_per_sf = use_4t_per_sf + + assert K % MXFP4_SF_VEC_SIZE == 0 + self.num_sf_blocks_per_row = K // MXFP4_SF_VEC_SIZE + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 + + # threads_per_sf: 4 for 4T/SF, 1 for 1T/SF + self._threads_per_sf = _4T_THREADS_PER_SF if use_4t_per_sf else 1 + + # Compute optimal thread count for 100% utilization + self.num_threads = _compute_optimal_threads(K, self._threads_per_sf) + self.threads_per_row = self.num_sf_blocks_per_row * self._threads_per_sf + + # Multi-row processing constants (compile-time) if self.threads_per_row <= self.num_threads: self.rows_per_block = self.num_threads // self.threads_per_row self.needs_col_loop = False @@ -173,16 +421,6 @@ def __init__( self.rows_per_block = 1 self.needs_col_loop = True - @cute.jit - def _compute_sf_offset( - self, row_idx: Int32, col_idx: Int32, padded_cols: Int32 - ) -> Int32: - """Compute scale factor offset based on layout (compile-time dispatch).""" - if cutlass.const_expr(self.sf_is_128x4): - return compute_sf_index_swizzled_128x4_gpu(row_idx, col_idx, padded_cols) - else: - return compute_sf_index_linear_gpu(row_idx, col_idx, padded_cols) - @cute.jit def __call__( self, @@ -194,11 +432,9 @@ def __call__( num_blocks: Int32, stream, ): - threads_per_block = self.num_threads - self.kernel(mInput, mOutput, mScales, M, padded_M).launch( grid=[num_blocks, 1, 1], - block=[threads_per_block, 1, 1], + block=[self.num_threads, 1, 1], max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1], min_blocks_per_mp=_BLOCKS_PER_SM, stream=stream, @@ -215,23 +451,10 @@ def kernel( padded_M: Int32, ): """ - MXFP4 quantization kernel with configurable scale factor layout. - - Dual-path kernel with compile-time selection: - - Small K path: Multi-row processing for improved thread utilization - - Large K path: Single row with column loop - - Each thread processes one SF block (32 elements): - 1. Load 32 bf16/fp16 elements (4 x 128-bit loads) - 2. Compute max absolute value using SIMD reduction - 3. Compute UE8M0 scale: ceil(log2(max / 6.0)) + 127 - 4. Store scale factor using layout-specific indexing - 5. Scale elements and convert to E2M1 - 6. Store 16 bytes (32 FP4 values) - - Note: For MXFP4 (UE8M0 scale format), global scale is NOT used in - the scale computation, unlike NVFP4 (E4M3 scale format). The UE8M0 - format directly captures the per-block dynamic range. + Row-based kernel for swizzled layout. + + When K is small: each block processes multiple rows simultaneously. + When K is large: each block processes one row with column loop. 
""" tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() @@ -240,46 +463,115 @@ def kernel( if cutlass.const_expr(self.enable_pdl): cute.arch.griddepcontrol_wait() + # Compile-time constants num_sf_blocks_per_row = self.num_sf_blocks_per_row padded_sf_cols = self.padded_sf_cols - num_threads = self.num_threads - rows_per_block = self.rows_per_block threads_per_row = self.threads_per_row + rows_per_block = self.rows_per_block - if cutlass.const_expr(not self.needs_col_loop): - # ===== SMALL K PATH: Multi-row processing ===== - # Each block processes rows_per_block rows simultaneously - # Thread maps to: row_in_block = tidx // threads_per_row - # sf_idx = tidx % threads_per_row - row_in_block = tidx // threads_per_row - sf_idx_in_row = tidx % threads_per_row + _threads_per_sf = self._threads_per_sf - # Grid-stride loop over row batches - row_batch_idx = bidx - total_row_batches = cute.ceil_div(padded_M, rows_per_block) + if cutlass.const_expr(self.needs_col_loop): + # Large K path: single row per block iteration with column loop + num_threads = self.num_threads - while row_batch_idx < total_row_batches: - base_row = row_batch_idx * rows_per_block - row_idx = base_row + row_in_block + if cutlass.const_expr(self.use_4t_per_sf): + # 4T/SF: 4 threads per SF, stride over SF columns + col_unit_idx = tidx // _threads_per_sf + thread_in_sf = tidx % _threads_per_sf + col_units_per_block = num_threads // _threads_per_sf + else: + col_unit_idx = tidx + col_units_per_block = num_threads - if row_idx < padded_M: - is_padding_row = row_idx >= M + row_idx = bidx + while row_idx < padded_M: + is_padding_row = row_idx >= M - if is_padding_row: - local_sf_idx = sf_idx_in_row - while local_sf_idx < padded_sf_cols: - sf_offset = self._compute_sf_offset( - row_idx, local_sf_idx, padded_sf_cols + if is_padding_row: + # Fast path: padding row - only zero out scale factors + if cutlass.const_expr(self.use_4t_per_sf): + sf_col_idx = col_unit_idx + while sf_col_idx < padded_sf_cols: + if thread_in_sf == Int32(0): + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_col_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + sf_col_idx = sf_col_idx + col_units_per_block + else: + sf_col_idx = tidx + while sf_col_idx < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_col_idx, padded_sf_cols ) mScales[sf_offset] = Uint8(0) - local_sf_idx = local_sf_idx + threads_per_row + sf_col_idx = sf_col_idx + num_threads + else: + if cutlass.const_expr(self.use_4t_per_sf): + # 4T/SF: each 4-thread group processes one SF block + sf_col_idx = col_unit_idx + while sf_col_idx < num_sf_blocks_per_row: + elem_base = ( + sf_col_idx * MXFP4_SF_VEC_SIZE + thread_in_sf * Int32(8) + ) + row_input = mInput[row_idx, None] + h0, h1, h2, h3 = ld_global_v4_u32( + get_ptr_as_int64(row_input, elem_base) + ) + if cutlass.const_expr(self.is_bfloat16): + local_max = bfloat2_hmax_reduce_to_f32( + bfloat2_max_abs_4(h0, h1, h2, h3) + ) + else: + local_max = hmax_reduce_to_f32( + half2_max_abs_4(h0, h1, h2, h3) + ) + global_max = reduce_max_4threads(local_max) + ue = float_to_ue8m0_fast( + global_max * rcp_approx_ftz(Float32(6.0)) + ) + inv = ue8m0_to_inv_scale_fast(ue) + if cutlass.const_expr(self.is_bfloat16): + s0, s1 = bfloat2_to_float2_scaled(h0, inv) + s2, s3 = bfloat2_to_float2_scaled(h1, inv) + s4, s5 = bfloat2_to_float2_scaled(h2, inv) + s6, s7 = bfloat2_to_float2_scaled(h3, inv) + else: + s0, s1 = half2_to_float2_scaled(h0, inv) + s2, s3 = half2_to_float2_scaled(h1, inv) + s4, s5 = 
half2_to_float2_scaled(h2, inv) + s6, s7 = half2_to_float2_scaled(h3, inv) + packed_u32 = cvt_e2m1x8_f32(s0, s1, s2, s3, s4, s5, s6, s7) + row_output = mOutput[row_idx, None] + out_base = sf_col_idx * ( + MXFP4_SF_VEC_SIZE // 2 + ) + thread_in_sf * Int32(4) + st_global_u32( + get_ptr_as_int64(row_output, out_base), packed_u32 + ) + if thread_in_sf == Int32(0): + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_col_idx, padded_sf_cols + ) + mScales[sf_offset] = ue.to(Uint8) + sf_col_idx = sf_col_idx + col_units_per_block + + # Handle padding columns + sf_col_idx = num_sf_blocks_per_row + col_unit_idx + while sf_col_idx < padded_sf_cols: + if thread_in_sf == Int32(0): + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_col_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + sf_col_idx = sf_col_idx + col_units_per_block else: - # Normal path: process actual data row - if sf_idx_in_row < num_sf_blocks_per_row: - elem_base = sf_idx_in_row * MXFP4_SF_VEC_SIZE + # 1T/SF: each thread processes one full SF block + sf_col_idx = tidx + while sf_col_idx < num_sf_blocks_per_row: + elem_base = sf_col_idx * MXFP4_SF_VEC_SIZE row_input = mInput[row_idx, None] - - # Process block: load, compute scale, convert to E2M1 if cutlass.const_expr(self.is_bfloat16): ( _, @@ -294,92 +586,180 @@ def kernel( packed64_0, packed64_1, ) = process_mxfp4_block_half(row_input, elem_base) - - sf_offset = self._compute_sf_offset( - row_idx, sf_idx_in_row, padded_sf_cols + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_col_idx, padded_sf_cols ) mScales[sf_offset] = scale_ue8m0 - row_output = mOutput[row_idx, None] - out_base = sf_idx_in_row * (MXFP4_SF_VEC_SIZE // 2) + out_base = sf_col_idx * (MXFP4_SF_VEC_SIZE // 2) out_ptr0 = get_ptr_as_int64(row_output, out_base) out_ptr1 = get_ptr_as_int64(row_output, out_base + Int32(8)) st_global_u64(out_ptr0, packed64_0) st_global_u64(out_ptr1, packed64_1) + sf_col_idx = sf_col_idx + num_threads - padding_sf_start = num_sf_blocks_per_row + sf_idx_in_row - while padding_sf_start < padded_sf_cols: - sf_offset = self._compute_sf_offset( - row_idx, padding_sf_start, padded_sf_cols + # Handle padding columns + sf_col_idx = num_sf_blocks_per_row + tidx + while sf_col_idx < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_col_idx, padded_sf_cols ) mScales[sf_offset] = Uint8(0) - padding_sf_start = padding_sf_start + threads_per_row - - row_batch_idx = row_batch_idx + grid_dim_x + sf_col_idx = sf_col_idx + num_threads + row_idx = row_idx + grid_dim_x else: - # ===== LARGE K PATH: Single row with column loop ===== - row_idx = bidx - while row_idx < padded_M: - is_padding_row = row_idx >= M - - sf_idx = Int32(tidx) - - if is_padding_row: - while sf_idx < padded_sf_cols: - sf_offset = self._compute_sf_offset( - row_idx, sf_idx, padded_sf_cols - ) - mScales[sf_offset] = Uint8(0) - sf_idx = sf_idx + num_threads - else: - num_sf_iters = ( - num_sf_blocks_per_row + num_threads - 1 - ) // num_threads - - for sf_iter in range(num_sf_iters): - local_sf_idx = sf_iter * num_threads + tidx - - if local_sf_idx < num_sf_blocks_per_row: - elem_base = local_sf_idx * MXFP4_SF_VEC_SIZE - row_input = mInput[row_idx, None] - - if cutlass.const_expr(self.is_bfloat16): - ( - _, - scale_ue8m0, - packed64_0, - packed64_1, - ) = process_mxfp4_block_bfloat(row_input, elem_base) - else: - ( - _, - scale_ue8m0, - packed64_0, - packed64_1, - ) = process_mxfp4_block_half(row_input, elem_base) + # Small K path: multi-row processing + # Thread 
mapping: tidx -> (row_in_block, local_tidx) + row_in_block = tidx // threads_per_row + local_tidx = tidx % threads_per_row - sf_offset = self._compute_sf_offset( - row_idx, local_sf_idx, padded_sf_cols - ) - mScales[sf_offset] = scale_ue8m0 + if cutlass.const_expr(self.use_4t_per_sf): + sf_idx_in_row = local_tidx // _threads_per_sf + thread_in_sf = local_tidx % _threads_per_sf + else: + sf_idx_in_row = local_tidx + thread_in_sf = Int32(0) - row_output = mOutput[row_idx, None] - out_base = local_sf_idx * (MXFP4_SF_VEC_SIZE // 2) - out_ptr0 = get_ptr_as_int64(row_output, out_base) - out_ptr1 = get_ptr_as_int64(row_output, out_base + Int32(8)) - st_global_u64(out_ptr0, packed64_0) - st_global_u64(out_ptr1, packed64_1) + # Grid-stride loop over row batches + row_batch_idx = bidx + row_idx = row_batch_idx * rows_per_block + row_in_block + while row_batch_idx * rows_per_block < padded_M: + if row_idx < padded_M: + is_padding_row = row_idx >= M - padding_sf_start = num_sf_blocks_per_row + tidx - while padding_sf_start < padded_sf_cols: - sf_offset = self._compute_sf_offset( - row_idx, padding_sf_start, padded_sf_cols - ) - mScales[sf_offset] = Uint8(0) - padding_sf_start = padding_sf_start + num_threads + if is_padding_row: + # Fast path: padding row - zero ALL padded_sf_cols + # Thread-stride loop for padding + if cutlass.const_expr(self.use_4t_per_sf): + if thread_in_sf == Int32(0): + local_sf = sf_idx_in_row + while local_sf < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, local_sf, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + local_sf = local_sf + num_sf_blocks_per_row + else: + local_sf_idx = sf_idx_in_row + while local_sf_idx < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, local_sf_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + local_sf_idx = local_sf_idx + threads_per_row + else: + if cutlass.const_expr(self.use_4t_per_sf): + # 4T/SF: process data + if sf_idx_in_row < num_sf_blocks_per_row: + elem_base = ( + sf_idx_in_row * MXFP4_SF_VEC_SIZE + + thread_in_sf * Int32(8) + ) + row_input = mInput[row_idx, None] + h0, h1, h2, h3 = ld_global_v4_u32( + get_ptr_as_int64(row_input, elem_base) + ) + if cutlass.const_expr(self.is_bfloat16): + local_max = bfloat2_hmax_reduce_to_f32( + bfloat2_max_abs_4(h0, h1, h2, h3) + ) + else: + local_max = hmax_reduce_to_f32( + half2_max_abs_4(h0, h1, h2, h3) + ) + global_max = reduce_max_4threads(local_max) + ue = float_to_ue8m0_fast( + global_max * rcp_approx_ftz(Float32(6.0)) + ) + inv = ue8m0_to_inv_scale_fast(ue) + if cutlass.const_expr(self.is_bfloat16): + s0, s1 = bfloat2_to_float2_scaled(h0, inv) + s2, s3 = bfloat2_to_float2_scaled(h1, inv) + s4, s5 = bfloat2_to_float2_scaled(h2, inv) + s6, s7 = bfloat2_to_float2_scaled(h3, inv) + else: + s0, s1 = half2_to_float2_scaled(h0, inv) + s2, s3 = half2_to_float2_scaled(h1, inv) + s4, s5 = half2_to_float2_scaled(h2, inv) + s6, s7 = half2_to_float2_scaled(h3, inv) + packed_u32 = cvt_e2m1x8_f32( + s0, s1, s2, s3, s4, s5, s6, s7 + ) + row_output = mOutput[row_idx, None] + out_base = sf_idx_in_row * ( + MXFP4_SF_VEC_SIZE // 2 + ) + thread_in_sf * Int32(4) + st_global_u32( + get_ptr_as_int64(row_output, out_base), + packed_u32, + ) + if thread_in_sf == Int32(0): + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, + sf_idx_in_row, + padded_sf_cols, + ) + mScales[sf_offset] = ue.to(Uint8) + + # Padding columns (4T) + if cutlass.const_expr( + self.num_sf_blocks_per_row != self.padded_sf_cols + ): + if thread_in_sf == Int32(0): + 
pad_col = num_sf_blocks_per_row + sf_idx_in_row + while pad_col < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, pad_col, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + pad_col = pad_col + (num_sf_blocks_per_row) + else: + # 1T/SF: process data + if sf_idx_in_row < num_sf_blocks_per_row: + elem_base = sf_idx_in_row * MXFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + if cutlass.const_expr(self.is_bfloat16): + ( + _, + scale_ue8m0, + packed64_0, + packed64_1, + ) = process_mxfp4_block_bfloat(row_input, elem_base) + else: + ( + _, + scale_ue8m0, + packed64_0, + packed64_1, + ) = process_mxfp4_block_half(row_input, elem_base) + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_idx_in_row, padded_sf_cols + ) + mScales[sf_offset] = scale_ue8m0 + row_output = mOutput[row_idx, None] + out_base = sf_idx_in_row * (MXFP4_SF_VEC_SIZE // 2) + out_ptr0 = get_ptr_as_int64(row_output, out_base) + out_ptr1 = get_ptr_as_int64( + row_output, out_base + Int32(8) + ) + st_global_u64(out_ptr0, packed64_0) + st_global_u64(out_ptr1, packed64_1) + + # Padding columns (1T) + if cutlass.const_expr( + self.num_sf_blocks_per_row != self.padded_sf_cols + ): + pad_col = num_sf_blocks_per_row + sf_idx_in_row + while pad_col < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, pad_col, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + pad_col = pad_col + threads_per_row - row_idx = row_idx + grid_dim_x + row_batch_idx = row_batch_idx + grid_dim_x + row_idx = row_batch_idx * rows_per_block + row_in_block # PDL: Signal that dependent kernels can start early if cutlass.const_expr(self.enable_pdl): @@ -397,49 +777,72 @@ def _get_compiled_kernel_mxfp4( K: int, sf_layout: int = SF_LAYOUT_128x4, enable_pdl: bool = False, + use_4t_per_sf: bool = False, ) -> Tuple[Callable, int]: """ Get or compile MXFP4 kernel with TVM-FFI. - Cached by (K, dtype, sf_layout, pdl) - M-agnostic, device-independent - compilation. + Cached by (K, dtype, sf_layout, pdl, use_4t_per_sf) - M-agnostic, + device-independent compilation. Returns: - Tuple of (compiled_kernel, rows_per_block) where rows_per_block - is used by the caller to compute num_blocks at runtime. 
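# Sketch (not part of the diff): how a caller is expected to turn the second
# return value into a launch grid, mirroring the host-side code further below.
# `is_bfloat16`, `m`, `k`, and `device` are assumed to already be in scope.
kernel_fn, block_unit = _get_compiled_kernel_mxfp4(
    is_bfloat16, k, SF_LAYOUT_LINEAR, enable_pdl=False, use_4t_per_sf=False
)
total_sf_blocks = m * (k // MXFP4_SF_VEC_SIZE)      # linear layout: one SF per 32 elements
target_grid = get_num_sm(device) * _BLOCKS_PER_SM   # occupancy-based grid cap
num_blocks = min((total_sf_blocks + block_unit - 1) // block_unit, target_grid)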
+ For linear layout: (compiled_kernel, sf_blocks_per_tb) + For swizzled layout: (compiled_kernel, rows_per_block) """ cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16 - kernel_obj = MXFP4QuantizeKernel(cutlass_dtype, K, sf_layout, enable_pdl) # Use symbolic M for dynamic batch sizes sym_m = cute.sym_int() + sym_scale_size = cute.sym_int() - # Create fake tensors for compilation + # Common fake tensors input_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, (sym_m, K), stride_order=(1, 0), assumed_align=16 ) output_fake = cute.runtime.make_fake_compact_tensor( cutlass.Uint8, (sym_m, K // 2), stride_order=(1, 0), assumed_align=16 ) - sym_scale_size = cute.sym_int() scales_fake = cute.runtime.make_fake_compact_tensor( cutlass.Uint8, (sym_scale_size,), assumed_align=16 ) stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - compiled_kernel = cute.compile( - kernel_obj, - input_fake, - output_fake, - scales_fake, - Int32(1), # Dummy M - Int32(128), # Dummy padded_M - Int32(1), # Dummy num_blocks - stream_fake, - options="--enable-tvm-ffi", - ) + if sf_layout == SF_LAYOUT_LINEAR: + linear_obj = MXFP4QuantizeLinearKernel( + cutlass_dtype, K, enable_pdl, use_4t_per_sf + ) + + compiled_kernel = cute.compile( + linear_obj, + input_fake, + output_fake, + scales_fake, + Int32(1), # Dummy M + Int32(1), # Dummy total_sf_blocks + Int32(1), # Dummy num_blocks + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel, linear_obj.SF_BLOCKS_PER_TB + else: + swizzled_obj = MXFP4QuantizeSwizzledKernel( + cutlass_dtype, K, enable_pdl, use_4t_per_sf + ) + + compiled_kernel = cute.compile( + swizzled_obj, + input_fake, + output_fake, + scales_fake, + Int32(1), # Dummy M + Int32(128), # Dummy padded_M + Int32(1), # Dummy num_blocks + stream_fake, + options="--enable-tvm-ffi", + ) - return compiled_kernel, kernel_obj.rows_per_block + return compiled_kernel, swizzled_obj.rows_per_block @flashinfer_api @@ -451,14 +854,16 @@ def mxfp4_quantize_cute_dsl( """ Quantize input tensor to MXFP4 format using CuTe-DSL kernel. - This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior: - - Global scale computed as (448 * 6) / max(|input|) - - UE8M0 scale factors - - E2M1 output format (4-bit, 2 values per byte) - - Supports 128x4 (swizzled) and linear scale factor layouts + This is a GPU implementation with dual-path optimization: + - LINEAR layout: flat SF-block based iteration with adaptive 1T/4T per SF + block dispatch — uses 4T/SF on low-SM GPUs (<=80 SMs) for coalesced + memory access, and 1T/SF on high-SM GPUs where enough SMs generate + sufficient outstanding memory requests + - SWIZZLED layout: row-based iteration with padding fast path (optimized) - The kernel is compiled once per (K, dtype, sf_layout, pdl) combination and - handles varying M (batch size) at runtime without recompilation. + The kernel is compiled once per (K, dtype, sf_layout, pdl, use_4t) + combination and handles varying M (batch size) at runtime without + recompilation. 
Args: input: Input tensor of shape [M, K] with dtype fp16/bf16 @@ -483,8 +888,7 @@ def mxfp4_quantize_cute_dsl( ) assert input.is_cuda, "Input must be on CUDA device" - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.device) if enable_pdl is not False else False if input.dim() > 2: m = input.numel() // input.shape[-1] @@ -500,37 +904,57 @@ def mxfp4_quantize_cute_dsl( input = input.contiguous() is_bfloat16 = input.dtype == torch.bfloat16 - target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM + num_sm = get_num_sm(input.device) + target_grid = num_sm * _BLOCKS_PER_SM num_sf_blocks_per_row = k // MXFP4_SF_VEC_SIZE + # Use 4T/SF on low-SM-count GPUs for better memory coalescing. + # On GPUs with many SMs (B200, RTX 5090), 1T/SF is faster because + # enough SMs generate sufficient outstanding memory requests. + # On GPUs with few SMs (DGX Spark), 4T/SF's coalesced access is needed. + use_4t = num_sm <= _LOW_SM_THRESHOLD + + # Get or compile kernel (device-independent) + kernel_fn, block_unit = _get_compiled_kernel_mxfp4( + is_bfloat16, k, sf_layout, enable_pdl, use_4t + ) + if sf_layout == SF_LAYOUT_LINEAR: - row_tile_size = 1 - # NOTE: When adding a TMA-based kernel, padded_m must be rounded up to the - # TMA tile row dimension (e.g. round_up(m, tma_tile_rows)) and scale_output - # must be trimmed to m * num_sf_blocks_per_row before returning. - # See PR f4d10d9 for the analogous CUDA fix. padded_m = m padded_sf_cols = num_sf_blocks_per_row - else: - row_tile_size = ROW_TILE_SIZE # 128 - padded_m = ((m + row_tile_size - 1) // row_tile_size) * row_tile_size - padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + total_sf_blocks = m * num_sf_blocks_per_row + scale_output_size = total_sf_blocks - scale_output_size = padded_m * padded_sf_cols + sf_blocks_per_tb = block_unit + num_blocks = min( + (total_sf_blocks + sf_blocks_per_tb - 1) // sf_blocks_per_tb, + target_grid, + ) - kernel_fn, rows_per_block = _get_compiled_kernel_mxfp4( - is_bfloat16, k, sf_layout, enable_pdl - ) + fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) + scale_output = torch.empty( + scale_output_size, dtype=torch.uint8, device=input.device + ) - num_blocks = min((padded_m + rows_per_block - 1) // rows_per_block, target_grid) + kernel_fn(input, fp4_output, scale_output, m, total_sf_blocks, num_blocks) + else: + padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE + padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + scale_output_size = padded_m * padded_sf_cols - fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) - scale_output = torch.empty( - scale_output_size, dtype=torch.uint8, device=input.device - ) + rows_per_block = block_unit + num_blocks = min( + (padded_m + rows_per_block - 1) // rows_per_block, + target_grid, + ) + + fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) + scale_output = torch.empty( + scale_output_size, dtype=torch.uint8, device=input.device + ) - kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks) + kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks) scale_output = scale_output.reshape(-1, num_sf_blocks_per_row) @@ -540,7 +964,8 @@ def mxfp4_quantize_cute_dsl( __all__ = [ "SF_LAYOUT_128x4", "SF_LAYOUT_LINEAR", - "MXFP4QuantizeKernel", + "MXFP4QuantizeLinearKernel", + "MXFP4QuantizeSwizzledKernel", "mxfp4_quantize_cute_dsl", "_get_compiled_kernel_mxfp4", ] diff --git 
a/flashinfer/quantization/kernels/mxfp8_quantize.py b/flashinfer/quantization/kernels/mxfp8_quantize.py index 445ee81fb7..b64127ba9b 100644 --- a/flashinfer/quantization/kernels/mxfp8_quantize.py +++ b/flashinfer/quantization/kernels/mxfp8_quantize.py @@ -50,17 +50,24 @@ ELTS_PER_THREAD, THREADS_PER_SF, SF_BLOCKS_PER_WARP, + ELTS_PER_THREAD_SMALL, + THREADS_PER_SF_SMALL, + SF_BLOCKS_PER_WARP_SMALL, + MXFP8_2T_SF_THRESHOLD, ROW_TILE_SIZE, # Low-level intrinsics hmax_reduce_to_f32, bfloat2_hmax_reduce_to_f32, float_to_ue8m0_fast, ue8m0_to_inv_scale_fast, + reduce_max_2threads, reduce_max_4threads, compute_sf_index_swizzled_128x4_gpu, # High-level helpers half2_max_abs_4, + half2_max_abs_8, bfloat2_max_abs_4, + bfloat2_max_abs_8, half2x4_to_fp8x8_packed, bfloat2x4_to_fp8x8_packed, ) @@ -79,12 +86,12 @@ _DEFAULT_WARPS = 16 # Default when no optimization needed -def _compute_optimal_warps_for_k(K: int) -> int: +def _compute_optimal_warps(K: int, sf_blocks_per_warp: int = SF_BLOCKS_PER_WARP) -> int: """ Compute optimal WARPS_PER_BLOCK for 100% thread utilization. For the swizzled kernel, we need: - (WARPS × 8) % num_sf_blocks == 0 + (WARPS x sf_blocks_per_warp) % num_sf_blocks == 0 where num_sf_blocks = K / 32. @@ -96,6 +103,8 @@ def _compute_optimal_warps_for_k(K: int) -> int: Args: K: Number of columns (must be divisible by 32) + sf_blocks_per_warp: SF blocks per warp for the selected thread + configuration (16 for 2T/SF, 8 for 4T/SF) Returns: Optimal number of warps per block @@ -104,9 +113,9 @@ def _compute_optimal_warps_for_k(K: int) -> int: num_sf_blocks = K // SF_VEC_SIZE # K / 32 - # For 100% utilization: (WARPS * 8) % num_sf_blocks == 0 - # WARPS must be a multiple of: num_sf_blocks / gcd(num_sf_blocks, 8) - gcd_val = math.gcd(num_sf_blocks, 8) + # For 100% utilization: (WARPS * sf_blocks_per_warp) % num_sf_blocks == 0 + # WARPS must be a multiple of: num_sf_blocks / gcd(num_sf_blocks, sf_blocks_per_warp) + gcd_val = math.gcd(num_sf_blocks, sf_blocks_per_warp) warp_multiple = num_sf_blocks // gcd_val # Find LARGEST valid WARPS in range [_MIN_WARPS, _MAX_WARPS] @@ -129,32 +138,54 @@ def _compute_optimal_warps_for_k(K: int) -> int: # ============================================================================= -# CuTe-DSL Kernel Class for Linear Layout +# CuTe-DSL Kernel Class for Linear Layout — Flat SF-Block Iteration # ============================================================================= class MXFP8QuantizeLinearKernel: """ MXFP8 quantization kernel optimized for LINEAR layout. - Uses SF-block based iteration for efficient memory access. - This kernel is M-agnostic: compiled once per (K, dtype, pdl) combination. - M-dependent values (total_sf_blocks, num_blocks) are passed at runtime. + Uses flat SF-block iteration for efficient memory access. Row and + column indices are derived from the flat SF index via integer division. + + No padding passes are needed since for linear layout: + - padded_m == m (no row padding) + - padded_sf_cols == num_sf_blocks_per_row (no column padding) + + Adaptive thread configuration (compile-time selected via use_2t_per_sf): + - 2T/SF (large problems): 2 threads per SF block, 16 elements per thread, + 1 shuffle reduction, 16 SF blocks per warp + - 4T/SF (small problems): 4 threads per SF block, 8 elements per thread, + 2 shuffle reductions, 8 SF blocks per warp + + This kernel is M-agnostic: compiled once per (K, dtype, pdl, use_2t) + combination. 
""" WARPS_PER_BLOCK = 16 # 16 warps = 512 threads per block - SF_BLOCKS_PER_TB = WARPS_PER_BLOCK * SF_BLOCKS_PER_WARP def __init__( self, dtype: cutlass.Numeric, K: int, enable_pdl: bool = False, + use_2t_per_sf: bool = True, ): - self.dtype = dtype - self.K = K self.is_bfloat16 = dtype == cutlass.BFloat16 self.enable_pdl = enable_pdl + self.use_2t_per_sf = use_2t_per_sf + + if use_2t_per_sf: + self._elts_per_thread = ELTS_PER_THREAD + self._threads_per_sf = THREADS_PER_SF + self._sf_blocks_per_warp = SF_BLOCKS_PER_WARP + else: + self._elts_per_thread = ELTS_PER_THREAD_SMALL + self._threads_per_sf = THREADS_PER_SF_SMALL + self._sf_blocks_per_warp = SF_BLOCKS_PER_WARP_SMALL + + self.SF_BLOCKS_PER_TB = self.WARPS_PER_BLOCK * self._sf_blocks_per_warp assert K % SF_VEC_SIZE == 0 self.num_sf_blocks_per_row = K // SF_VEC_SIZE @@ -198,14 +229,18 @@ def kernel( warp_idx = tidx // WARP_SIZE lane_idx = tidx % WARP_SIZE - sf_idx_in_warp = lane_idx // THREADS_PER_SF - thread_in_sf = lane_idx % THREADS_PER_SF + threads_per_sf = self._threads_per_sf + sf_blocks_per_warp = self._sf_blocks_per_warp + elts_per_thread = self._elts_per_thread + + sf_idx_in_warp = lane_idx // threads_per_sf + thread_in_sf = lane_idx % threads_per_sf - sf_blocks_per_tb = self.WARPS_PER_BLOCK * SF_BLOCKS_PER_WARP + sf_blocks_per_tb = self.WARPS_PER_BLOCK * sf_blocks_per_warp num_sf_blocks_per_row = self.num_sf_blocks_per_row sf_idx_base = ( - bidx * sf_blocks_per_tb + warp_idx * SF_BLOCKS_PER_WARP + sf_idx_in_warp + bidx * sf_blocks_per_tb + warp_idx * sf_blocks_per_warp + sf_idx_in_warp ) sf_idx = sf_idx_base @@ -214,24 +249,39 @@ def kernel( col_idx = sf_idx % num_sf_blocks_per_row base_elem = col_idx * SF_VEC_SIZE - thread_elem_offset = thread_in_sf * ELTS_PER_THREAD + thread_elem_offset = thread_in_sf * elts_per_thread elem_idx = base_elem + thread_elem_offset row_input = mInput[row_idx, None] - input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx) - v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64) + if cutlass.const_expr(self.use_2t_per_sf): + # 2T/SF path: load 16 elements (2x128-bit), 1-shuffle reduction + input_ptr_lo = get_ptr_as_int64(row_input, elem_idx) + input_ptr_hi = get_ptr_as_int64(row_input, elem_idx + Int32(8)) + v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_lo) + v4, v5, v6, v7 = ld_global_v4_u32(input_ptr_hi) - # Compute max absolute value across 8 elements (4 half2/bfloat2) - if cutlass.const_expr(self.is_bfloat16): - max0123 = bfloat2_max_abs_4(v0, v1, v2, v3) - local_max = bfloat2_hmax_reduce_to_f32(max0123) + if cutlass.const_expr(self.is_bfloat16): + max_all = bfloat2_max_abs_8(v0, v1, v2, v3, v4, v5, v6, v7) + local_max = bfloat2_hmax_reduce_to_f32(max_all) + else: + max_all = half2_max_abs_8(v0, v1, v2, v3, v4, v5, v6, v7) + local_max = hmax_reduce_to_f32(max_all) + + global_max = reduce_max_2threads(local_max) else: - max0123 = half2_max_abs_4(v0, v1, v2, v3) - local_max = hmax_reduce_to_f32(max0123) + # 4T/SF path: load 8 elements (1x128-bit), 2-shuffle reduction + input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx) + v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64) + + if cutlass.const_expr(self.is_bfloat16): + max0123 = bfloat2_max_abs_4(v0, v1, v2, v3) + local_max = bfloat2_hmax_reduce_to_f32(max0123) + else: + max0123 = half2_max_abs_4(v0, v1, v2, v3) + local_max = hmax_reduce_to_f32(max0123) - # 4-thread reduction for this SF block - global_max = reduce_max_4threads(local_max) + global_max = reduce_max_4threads(local_max) # Compute UE8M0 scale factor inv_e4m3_max = Float32(INV_FLOAT8_E4M3_MAX) 
@@ -243,14 +293,26 @@ def kernel( inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32) # Quantize to FP8 E4M3 and pack for vectorized store - if cutlass.const_expr(self.is_bfloat16): - fp8_packed = bfloat2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale) + if cutlass.const_expr(self.use_2t_per_sf): + if cutlass.const_expr(self.is_bfloat16): + fp8_lo = bfloat2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale) + fp8_hi = bfloat2x4_to_fp8x8_packed(v4, v5, v6, v7, inv_scale) + else: + fp8_lo = half2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale) + fp8_hi = half2x4_to_fp8x8_packed(v4, v5, v6, v7, inv_scale) + + row_output = mOutput[row_idx, None] + st_global_u64(get_ptr_as_int64(row_output, elem_idx), fp8_lo) + st_global_u64(get_ptr_as_int64(row_output, elem_idx + Int32(8)), fp8_hi) else: - fp8_packed = half2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale) + if cutlass.const_expr(self.is_bfloat16): + fp8_packed = bfloat2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale) + else: + fp8_packed = half2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale) - row_output = mOutput[row_idx, None] - output_ptr_i64 = get_ptr_as_int64(row_output, elem_idx) - st_global_u64(output_ptr_i64, fp8_packed) + row_output = mOutput[row_idx, None] + output_ptr_i64 = get_ptr_as_int64(row_output, elem_idx) + st_global_u64(output_ptr_i64, fp8_packed) if thread_in_sf == Int32(0): mScales[sf_idx] = scale_ue8m0 @@ -263,7 +325,7 @@ def kernel( # ============================================================================= -# CuTe-DSL Kernel Class for Swizzled Layout +# CuTe-DSL Kernel Class for Swizzled Layout — Row-Based Iteration # ============================================================================= @@ -281,8 +343,11 @@ class MXFP8QuantizeSwizzledKernel: - For small K: Multiple rows processed per block iteration - For large K: Single row with column loop - This kernel is M-agnostic: compiled once per (K, dtype, pdl) combination. - M-dependent values (M, padded_M) are passed at runtime. + For MXFP8, each SF block (32 elements) is processed by _threads_per_sf + threads (2 or 4), so threads_per_row = num_sf_blocks_per_row * _threads_per_sf. + + This kernel is M-agnostic: compiled once per (K, dtype, pdl, use_2t) + combination. M-dependent values (M, padded_M) are passed at runtime. 
""" def __init__( @@ -290,25 +355,34 @@ def __init__( dtype: cutlass.Numeric, K: int, enable_pdl: bool = False, + use_2t_per_sf: bool = True, ): - self.dtype = dtype - self.K = K self.is_bfloat16 = dtype == cutlass.BFloat16 self.enable_pdl = enable_pdl + self.use_2t_per_sf = use_2t_per_sf + + if use_2t_per_sf: + self._elts_per_thread = ELTS_PER_THREAD + self._threads_per_sf = THREADS_PER_SF + self._sf_blocks_per_warp = SF_BLOCKS_PER_WARP + else: + self._elts_per_thread = ELTS_PER_THREAD_SMALL + self._threads_per_sf = THREADS_PER_SF_SMALL + self._sf_blocks_per_warp = SF_BLOCKS_PER_WARP_SMALL assert K % SF_VEC_SIZE == 0 self.num_sf_blocks_per_row = K // SF_VEC_SIZE self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 # Compute optimal warps for 100% thread utilization - self.warps_per_block = _compute_optimal_warps_for_k(K) + self.warps_per_block = _compute_optimal_warps(K, self._sf_blocks_per_warp) # Multi-row processing constants (compile-time) threads_per_block = self.warps_per_block * WARP_SIZE - col_units_per_block = threads_per_block // THREADS_PER_SF + col_units_per_block = threads_per_block // self._threads_per_sf - # threads_per_row = num_sf_blocks_per_row * THREADS_PER_SF = K / 8 - self.threads_per_row = self.num_sf_blocks_per_row * THREADS_PER_SF + # threads_per_row = num_sf_blocks_per_row * _threads_per_sf + self.threads_per_row = self.num_sf_blocks_per_row * self._threads_per_sf # rows_per_block = col_units_per_block // num_sf_blocks_per_row # With optimal warps, this should divide evenly for small K @@ -369,13 +443,15 @@ def kernel( threads_per_row = self.threads_per_row rows_per_block = self.rows_per_block + _threads_per_sf = self._threads_per_sf + _elts_per_thread = self._elts_per_thread + if cutlass.const_expr(self.needs_col_loop): # Large K path: single row per block iteration with column loop - # This is the original algorithm for K > 4096 - col_unit_idx = tidx // THREADS_PER_SF - thread_in_unit = tidx % THREADS_PER_SF + col_unit_idx = tidx // _threads_per_sf + thread_in_unit = tidx % _threads_per_sf threads_per_block = self.warps_per_block * WARP_SIZE - col_units_per_block = threads_per_block // THREADS_PER_SF + col_units_per_block = threads_per_block // _threads_per_sf row_idx = bidx while row_idx < padded_M: @@ -396,40 +472,82 @@ def kernel( sf_col_idx = col_unit_idx while sf_col_idx < num_sf_blocks_per_row: elem_idx = ( - sf_col_idx * SF_VEC_SIZE + thread_in_unit * ELTS_PER_THREAD + sf_col_idx * SF_VEC_SIZE + thread_in_unit * _elts_per_thread ) row_input = mInput[row_idx, None] - input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx) - v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64) - if cutlass.const_expr(self.is_bfloat16): - max0123 = bfloat2_max_abs_4(v0, v1, v2, v3) - local_max = bfloat2_hmax_reduce_to_f32(max0123) + if cutlass.const_expr(self.use_2t_per_sf): + input_ptr_lo = get_ptr_as_int64(row_input, elem_idx) + input_ptr_hi = get_ptr_as_int64( + row_input, elem_idx + Int32(8) + ) + v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_lo) + v4, v5, v6, v7 = ld_global_v4_u32(input_ptr_hi) + if cutlass.const_expr(self.is_bfloat16): + local_max = bfloat2_hmax_reduce_to_f32( + bfloat2_max_abs_8(v0, v1, v2, v3, v4, v5, v6, v7) + ) + else: + local_max = hmax_reduce_to_f32( + half2_max_abs_8(v0, v1, v2, v3, v4, v5, v6, v7) + ) + global_max = reduce_max_2threads(local_max) else: - max0123 = half2_max_abs_4(v0, v1, v2, v3) - local_max = hmax_reduce_to_f32(max0123) + input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx) + v0, v1, v2, v3 = 
ld_global_v4_u32(input_ptr_i64) + if cutlass.const_expr(self.is_bfloat16): + local_max = bfloat2_hmax_reduce_to_f32( + bfloat2_max_abs_4(v0, v1, v2, v3) + ) + else: + local_max = hmax_reduce_to_f32( + half2_max_abs_4(v0, v1, v2, v3) + ) + global_max = reduce_max_4threads(local_max) - global_max = reduce_max_4threads(local_max) inv_e4m3_max = Float32(INV_FLOAT8_E4M3_MAX) normalized_max = global_max * inv_e4m3_max scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max) scale_ue8m0 = scale_ue8m0_u32.to(Uint8) inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32) - if cutlass.const_expr(self.is_bfloat16): - fp8_packed = bfloat2x4_to_fp8x8_packed( - v0, v1, v2, v3, inv_scale + row_output = mOutput[row_idx, None] + if cutlass.const_expr(self.use_2t_per_sf): + if cutlass.const_expr(self.is_bfloat16): + fp8_lo = bfloat2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + fp8_hi = bfloat2x4_to_fp8x8_packed( + v4, v5, v6, v7, inv_scale + ) + else: + fp8_lo = half2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + fp8_hi = half2x4_to_fp8x8_packed( + v4, v5, v6, v7, inv_scale + ) + st_global_u64( + get_ptr_as_int64(row_output, elem_idx), fp8_lo + ) + st_global_u64( + get_ptr_as_int64(row_output, elem_idx + Int32(8)), + fp8_hi, ) else: - fp8_packed = half2x4_to_fp8x8_packed( - v0, v1, v2, v3, inv_scale + if cutlass.const_expr(self.is_bfloat16): + fp8_packed = bfloat2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + else: + fp8_packed = half2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + st_global_u64( + get_ptr_as_int64(row_output, elem_idx), fp8_packed ) - row_output = mOutput[row_idx, None] - output_ptr_i64 = get_ptr_as_int64(row_output, elem_idx) - st_global_u64(output_ptr_i64, fp8_packed) - if thread_in_unit == Int32(0): sf_offset = compute_sf_index_swizzled_128x4_gpu( row_idx, sf_col_idx, padded_sf_cols @@ -455,8 +573,8 @@ def kernel( # Thread mapping: tidx -> (row_in_block, sf_col_idx, thread_in_unit) row_in_block = tidx // threads_per_row local_tidx = tidx % threads_per_row - sf_col_idx = local_tidx // THREADS_PER_SF - thread_in_unit = local_tidx % THREADS_PER_SF + sf_col_idx = local_tidx // _threads_per_sf + thread_in_unit = local_tidx % _threads_per_sf # Grid-stride loop over row batches row_batch_idx = bidx @@ -469,69 +587,120 @@ def kernel( is_padding_row = row_idx >= M if is_padding_row: - # Fast path: padding row - zero out scale factors - # Each thread handles one SF column (no column loop needed) - if sf_col_idx < padded_sf_cols and thread_in_unit == Int32(0): - sf_offset = compute_sf_index_swizzled_128x4_gpu( - row_idx, sf_col_idx, padded_sf_cols - ) - mScales[sf_offset] = Uint8(0) + # Fast path: padding row - zero out ALL scale factors + # Thread-stride loop since padded_sf_cols may exceed + # num_sf_blocks_per_row (when K/32 is not a multiple of 4) + if thread_in_unit == Int32(0): + pad_col = sf_col_idx + while pad_col < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, pad_col, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + pad_col = pad_col + num_sf_blocks_per_row else: # Normal path: process actual data if sf_col_idx < num_sf_blocks_per_row: elem_idx = ( sf_col_idx * SF_VEC_SIZE - + thread_in_unit * ELTS_PER_THREAD + + thread_in_unit * _elts_per_thread ) row_input = mInput[row_idx, None] - input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx) - v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64) - if cutlass.const_expr(self.is_bfloat16): - max0123 = bfloat2_max_abs_4(v0, v1, v2, v3) - local_max = bfloat2_hmax_reduce_to_f32(max0123) + if 
cutlass.const_expr(self.use_2t_per_sf): + input_ptr_lo = get_ptr_as_int64(row_input, elem_idx) + input_ptr_hi = get_ptr_as_int64( + row_input, elem_idx + Int32(8) + ) + v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_lo) + v4, v5, v6, v7 = ld_global_v4_u32(input_ptr_hi) + if cutlass.const_expr(self.is_bfloat16): + local_max = bfloat2_hmax_reduce_to_f32( + bfloat2_max_abs_8( + v0, v1, v2, v3, v4, v5, v6, v7 + ) + ) + else: + local_max = hmax_reduce_to_f32( + half2_max_abs_8(v0, v1, v2, v3, v4, v5, v6, v7) + ) + global_max = reduce_max_2threads(local_max) else: - max0123 = half2_max_abs_4(v0, v1, v2, v3) - local_max = hmax_reduce_to_f32(max0123) + input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx) + v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64) + if cutlass.const_expr(self.is_bfloat16): + local_max = bfloat2_hmax_reduce_to_f32( + bfloat2_max_abs_4(v0, v1, v2, v3) + ) + else: + local_max = hmax_reduce_to_f32( + half2_max_abs_4(v0, v1, v2, v3) + ) + global_max = reduce_max_4threads(local_max) - global_max = reduce_max_4threads(local_max) inv_e4m3_max = Float32(INV_FLOAT8_E4M3_MAX) - normalized_max = global_max * inv_e4m3_max - scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max) + scale_ue8m0_u32 = float_to_ue8m0_fast( + global_max * inv_e4m3_max + ) scale_ue8m0 = scale_ue8m0_u32.to(Uint8) inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32) - if cutlass.const_expr(self.is_bfloat16): - fp8_packed = bfloat2x4_to_fp8x8_packed( - v0, v1, v2, v3, inv_scale + row_output = mOutput[row_idx, None] + if cutlass.const_expr(self.use_2t_per_sf): + if cutlass.const_expr(self.is_bfloat16): + fp8_lo = bfloat2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + fp8_hi = bfloat2x4_to_fp8x8_packed( + v4, v5, v6, v7, inv_scale + ) + else: + fp8_lo = half2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + fp8_hi = half2x4_to_fp8x8_packed( + v4, v5, v6, v7, inv_scale + ) + st_global_u64( + get_ptr_as_int64(row_output, elem_idx), fp8_lo + ) + st_global_u64( + get_ptr_as_int64(row_output, elem_idx + Int32(8)), + fp8_hi, ) else: - fp8_packed = half2x4_to_fp8x8_packed( - v0, v1, v2, v3, inv_scale + if cutlass.const_expr(self.is_bfloat16): + fp8_packed = bfloat2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + else: + fp8_packed = half2x4_to_fp8x8_packed( + v0, v1, v2, v3, inv_scale + ) + st_global_u64( + get_ptr_as_int64(row_output, elem_idx), fp8_packed ) - row_output = mOutput[row_idx, None] - output_ptr_i64 = get_ptr_as_int64(row_output, elem_idx) - st_global_u64(output_ptr_i64, fp8_packed) - if thread_in_unit == Int32(0): sf_offset = compute_sf_index_swizzled_128x4_gpu( row_idx, sf_col_idx, padded_sf_cols ) mScales[sf_offset] = scale_ue8m0 - # Handle padding SF columns (for this row) - # Threads with sf_col_idx in [num_sf_blocks_per_row, padded_sf_cols) - if ( - sf_col_idx >= num_sf_blocks_per_row - and sf_col_idx < padded_sf_cols - and thread_in_unit == Int32(0) + # Handle padding SF columns for this row + # Thread-stride loop starting from first padding column + if cutlass.const_expr( + self.num_sf_blocks_per_row != self.padded_sf_cols ): - sf_offset = compute_sf_index_swizzled_128x4_gpu( - row_idx, sf_col_idx, padded_sf_cols - ) - mScales[sf_offset] = Uint8(0) + if thread_in_unit == Int32(0): + pad_col = num_sf_blocks_per_row + sf_col_idx + while pad_col < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, pad_col, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + pad_col = pad_col + num_sf_blocks_per_row row_batch_idx = row_batch_idx + grid_dim_x # Update row_idx for 
next iteration @@ -548,22 +717,24 @@ def kernel( @functools.cache -def _get_compiled_kernel_linear( +def _get_compiled_kernel_mxfp8_linear( is_bfloat16: bool, K: int, enable_pdl: bool = False, + use_2t_per_sf: bool = True, ) -> Tuple[Callable, int]: """ Get or compile LINEAR layout kernel with TVM-FFI. - Cached by (K, dtype, pdl) - M-agnostic, device-independent compilation. + Cached by (K, dtype, pdl, use_2t) - M-agnostic, device-independent + compilation. Returns: Tuple of (compiled_kernel, sf_blocks_per_tb) where sf_blocks_per_tb is used by the caller to compute num_blocks at runtime. """ cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16 - kernel_obj = MXFP8QuantizeLinearKernel(cutlass_dtype, K, enable_pdl) + kernel_obj = MXFP8QuantizeLinearKernel(cutlass_dtype, K, enable_pdl, use_2t_per_sf) # Use symbolic M for dynamic batch sizes sym_m = cute.sym_int() @@ -596,22 +767,26 @@ def _get_compiled_kernel_linear( @functools.cache -def _get_compiled_kernel_swizzled( +def _get_compiled_kernel_mxfp8_swizzled( is_bfloat16: bool, K: int, enable_pdl: bool = False, + use_2t_per_sf: bool = True, ) -> Tuple[Callable, int]: """ Get or compile SWIZZLED layout kernel with TVM-FFI. - Cached by (K, dtype, pdl) - M-agnostic, device-independent compilation. + Cached by (K, dtype, pdl, use_2t) - M-agnostic, device-independent + compilation. Returns: Tuple of (compiled_kernel, rows_per_block) where rows_per_block is used by the caller to compute num_blocks at runtime. """ cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16 - kernel_obj = MXFP8QuantizeSwizzledKernel(cutlass_dtype, K, enable_pdl) + kernel_obj = MXFP8QuantizeSwizzledKernel( + cutlass_dtype, K, enable_pdl, use_2t_per_sf + ) # Use symbolic M for dynamic batch sizes sym_m = cute.sym_int() @@ -683,9 +858,9 @@ def mxfp8_quantize_cute_dsl( f"alignment must be divisible by SF_VEC_SIZE={SF_VEC_SIZE}" ) - # Auto-detect PDL support based on device capability - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + # Auto-detect PDL support based on device capability. + # If caller passes True explicitly, still check hardware support. + enable_pdl = device_support_pdl(input.device) if enable_pdl is not False else False if input.dim() > 2: m = input.numel() // input.shape[-1] @@ -715,14 +890,20 @@ def mxfp8_quantize_cute_dsl( # Compute M-dependent values outside the cached kernel num_sf_blocks_per_row = padded_k // SF_VEC_SIZE + # Choose 2T/SF (optimized) vs 4T/SF (legacy) based on problem size. + # 2T/SF doubles memory-level parallelism per warp but halves the grid, + # so it only helps when there are enough SF blocks to fill all SMs. 
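# Illustrative sketch (not part of the diff): how the two thread
# configurations split one 32-element SF block. The derived numbers match the
# constants imported above (ELTS_PER_THREAD*, SF_BLOCKS_PER_WARP*); the helper
# itself is made up for illustration only.
WARP_SIZE = 32
for name, threads_per_sf in (("2T/SF", 2), ("4T/SF", 4)):
    elts_per_thread = 32 // threads_per_sf            # 16 vs 8 input elements
    vec_loads_per_thread = elts_per_thread // 8       # 128-bit (8 x fp16/bf16) loads
    sf_blocks_per_warp = WARP_SIZE // threads_per_sf  # 16 vs 8
    shuffle_steps = threads_per_sf.bit_length() - 1   # 1 vs 2 reductions for block max
    print(name, elts_per_thread, vec_loads_per_thread, sf_blocks_per_warp, shuffle_steps)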
+ total_sf_blocks_for_dispatch = m * num_sf_blocks_per_row + use_2t = total_sf_blocks_for_dispatch >= MXFP8_2T_SF_THRESHOLD + if is_sf_swizzled_layout: # Swizzled layout: compute padded_M and scale_output_size padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 scale_output_size = padded_m * padded_sf_cols - kernel_fn, rows_per_block = _get_compiled_kernel_swizzled( - is_bfloat16, padded_k, enable_pdl + kernel_fn, rows_per_block = _get_compiled_kernel_mxfp8_swizzled( + is_bfloat16, padded_k, enable_pdl, use_2t ) num_blocks = min((padded_m + rows_per_block - 1) // rows_per_block, target_grid) @@ -738,8 +919,8 @@ def mxfp8_quantize_cute_dsl( total_sf_blocks = m * num_sf_blocks_per_row scale_output_size = total_sf_blocks - kernel_fn, sf_blocks_per_tb = _get_compiled_kernel_linear( - is_bfloat16, padded_k, enable_pdl + kernel_fn, sf_blocks_per_tb = _get_compiled_kernel_mxfp8_linear( + is_bfloat16, padded_k, enable_pdl, use_2t ) num_blocks = min( @@ -762,6 +943,6 @@ def mxfp8_quantize_cute_dsl( "MXFP8QuantizeLinearKernel", "MXFP8QuantizeSwizzledKernel", "mxfp8_quantize_cute_dsl", - "_get_compiled_kernel_linear", - "_get_compiled_kernel_swizzled", + "_get_compiled_kernel_mxfp8_linear", + "_get_compiled_kernel_mxfp8_swizzled", ] diff --git a/flashinfer/quantization/kernels/nvfp4_quantize.py b/flashinfer/quantization/kernels/nvfp4_quantize.py index 52a4aa6cfa..96422df4ce 100644 --- a/flashinfer/quantization/kernels/nvfp4_quantize.py +++ b/flashinfer/quantization/kernels/nvfp4_quantize.py @@ -19,10 +19,15 @@ NVFP4 quantization kernel using CuTe-DSL. Supports multiple scale factor layouts: swizzled 128x4, swizzled 8x4, and linear. +Dual-path optimization following the MXFP4 pattern: +- Linear layout: flat SF-block iteration for 100% thread utilization +- Swizzled layout: row-based iteration with multi-row / column-loop paths + Key differences from MXFP4: - sf_vec_size=16 (vs 32 for MXFP4) - E4M3 scale factors (vs UE8M0 for MXFP4) - User-provided global_scale (vs auto-computed for MXFP4) +- 3 layouts: 128x4, 8x4, linear (vs 128x4, linear for MXFP4) """ import functools @@ -41,6 +46,7 @@ from ...cute_dsl.utils import get_num_sm from ..quantization_cute_dsl_utils import ( NVFP4_SF_VEC_SIZE, + WARP_SIZE, ROW_TILE_SIZE, compute_sf_index_swizzled_128x4_gpu, compute_sf_index_swizzled_8x4_gpu, @@ -60,74 +66,221 @@ SF_LAYOUT_8x4 = 1 SF_LAYOUT_LINEAR = 2 +# Blocks per SM for occupancy target _BLOCKS_PER_SM = 4 + +# Maximum threads per block _MAX_THREADS_PER_BLOCK = 1024 + +# Thread count bounds for swizzled kernel _MIN_THREADS = 128 _MAX_THREADS = 512 -_DEFAULT_THREADS = 256 + +# Linear kernel: fixed 16 warps (512 threads), 1 SF block per thread +_LINEAR_WARPS_PER_BLOCK = 16 +_LINEAR_SF_BLOCKS_PER_TB = _LINEAR_WARPS_PER_BLOCK * WARP_SIZE # 512 -def _compute_optimal_threads_for_k(K: int) -> int: +def _compute_optimal_threads(K: int) -> int: """ - Compute optimal thread count for 100% thread utilization. + Compute optimal thread count for 100% utilization in the swizzled kernel. + + For NVFP4, each thread processes 1 SF block (16 elements), so: + threads_per_row = K / 16 + + We want num_threads to be a multiple of threads_per_row so that + rows_per_block = num_threads / threads_per_row is an integer. + + We prefer LARGER thread counts (up to _MAX_THREADS) for better occupancy. + + If threads_per_row > _MAX_THREADS, we use _MAX_THREADS with a column loop. - For NVFP4, each thread processes one SF block (16 elements). 
- threads_per_row = K / 16 = num_sf_blocks_per_row + Args: + K: Number of columns (must be divisible by 16) - We prefer LARGER thread counts (up to _MAX_THREADS) for better occupancy, - while maintaining 100% thread utilization. + Returns: + Optimal number of threads per block """ - threads_per_row = K // NVFP4_SF_VEC_SIZE + threads_per_row = K // NVFP4_SF_VEC_SIZE # K / 16 - if threads_per_row >= _MAX_THREADS: + if threads_per_row > _MAX_THREADS: + # Column loop mode: use maximum threads return _MAX_THREADS - if threads_per_row <= _MAX_THREADS: - threads = (_MAX_THREADS // threads_per_row) * threads_per_row - if threads >= _MIN_THREADS: - return threads - threads = threads_per_row - while threads < _MIN_THREADS: - threads += threads_per_row - if threads <= _MAX_THREADS: - return threads + # Find largest multiple of threads_per_row in [_MIN_THREADS, _MAX_THREADS] + largest = (_MAX_THREADS // threads_per_row) * threads_per_row + if largest >= _MIN_THREADS: + return largest + + # If largest multiple is below _MIN_THREADS, use smallest valid one + candidate = threads_per_row + while candidate < _MIN_THREADS: + candidate += threads_per_row + if candidate <= _MAX_THREADS: + return candidate + + # Fallback (shouldn't happen for reasonable K) + return _MAX_THREADS + + +# ============================================================================= +# CuTe-DSL Kernel Class for Linear Layout — Flat SF-Block Iteration +# ============================================================================= + + +class NVFP4QuantizeLinearKernel: + """ + NVFP4 quantization kernel optimized for LINEAR layout. + + Uses flat SF-block iteration for efficient memory access. Row and + column indices are derived from the flat SF index via integer division. + + No padding passes are needed since for linear layout: + - padded_m == m (no row padding) + - padded_sf_cols == num_sf_blocks_per_row (no column padding) + + This kernel is M-agnostic: compiled once per (K, dtype, pdl) combination. + Each thread handles one SF block (16 elements). + """ + + WARPS_PER_BLOCK = _LINEAR_WARPS_PER_BLOCK + SF_BLOCKS_PER_TB = _LINEAR_SF_BLOCKS_PER_TB + + def __init__( + self, + dtype: cutlass.Numeric, + K: int, + enable_pdl: bool = False, + ): + self.dtype = dtype + self.K = K + self.is_bfloat16 = dtype == cutlass.BFloat16 + self.is_fp8 = dtype == cutlass.Float8E4M3FN + self.enable_pdl = enable_pdl + + assert K % NVFP4_SF_VEC_SIZE == 0 + self.num_sf_blocks_per_row = K // NVFP4_SF_VEC_SIZE + + @cute.jit + def __call__( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + total_sf_blocks: Int32, + num_blocks: Int32, + mGlobalScale: cute.Tensor, + stream, + ): + threads_per_block = self.WARPS_PER_BLOCK * WARP_SIZE + + self.kernel(mInput, mOutput, mScales, M, total_sf_blocks, mGlobalScale).launch( + grid=[num_blocks, 1, 1], + block=[threads_per_block, 1, 1], + max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1], + min_blocks_per_mp=_BLOCKS_PER_SM, + stream=stream, + use_pdl=self.enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + total_sf_blocks: Int32, + mGlobalScale: cute.Tensor, + ): + """ + NVFP4 quantization with flat SF-block iteration for linear layout. + + Each thread handles one SF block (16 elements). Row and column + indices are derived from the flat SF index. 
+ """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + grid_dim_x, _, _ = cute.arch.grid_dim() + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + + # Read global_scale from device memory (avoids CPU-GPU sync at launch) + global_scale = Float32(mGlobalScale[Int32(0)]) + + num_sf_blocks_per_row = self.num_sf_blocks_per_row + sf_blocks_per_tb = self.SF_BLOCKS_PER_TB + stride = grid_dim_x * sf_blocks_per_tb + + # Flat SF-block iteration + sf_idx = bidx * sf_blocks_per_tb + tidx + + while sf_idx < total_sf_blocks: + row_idx = sf_idx // num_sf_blocks_per_row + col_idx = sf_idx % num_sf_blocks_per_row + + elem_base = col_idx * NVFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + + # Process block: load, compute scale, convert to E2M1 + if cutlass.const_expr(self.is_fp8): + scale_fp8, packed64 = process_nvfp4_block_fp8( + row_input, elem_base, global_scale + ) + elif cutlass.const_expr(self.is_bfloat16): + scale_fp8, packed64 = process_nvfp4_block_bfloat( + row_input, elem_base, global_scale + ) + else: + scale_fp8, packed64 = process_nvfp4_block_half( + row_input, elem_base, global_scale + ) + + # Write scale factor using linear indexing + sf_offset = compute_sf_index_linear_gpu( + row_idx, col_idx, num_sf_blocks_per_row + ) + mScales[sf_offset] = scale_fp8 - return _DEFAULT_THREADS + # Store 8 bytes (16 FP4 values = 1 x st.global.u64) + row_output = mOutput[row_idx, None] + out_base = col_idx * (NVFP4_SF_VEC_SIZE // 2) + out_ptr = get_ptr_as_int64(row_output, out_base) + st_global_u64(out_ptr, packed64) + sf_idx = sf_idx + stride -def _compute_swizzled_layout_sf_size( - total_row: int, total_column: int, row_size: int = 128 -) -> int: - """Compute size of swizzled scale factor buffer.""" - padded_row = (total_row + row_size - 1) // row_size * row_size - padded_column = (total_column + 3) // 4 * 4 - return padded_row * padded_column + # PDL: Signal that dependent kernels can start early + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() # ============================================================================= -# CuTe-DSL Kernel Class for NVFP4 Swizzled Layout +# CuTe-DSL Kernel Class for Swizzled Layout — Row-Based Iteration # ============================================================================= class NVFP4QuantizeSwizzledKernel: """ - NVFP4 quantization kernel supporting multiple scale factor layouts. - - Supported layouts: - - 128x4 (swizzled): Optimized for GEMM with large tileN - - 8x4 (swizzled): Optimized for GEMM with small tileN - - linear: Simple row-major layout, no swizzling + NVFP4 quantization kernel optimized for SWIZZLED (128x4 or 8x4) layout. - Key features: - - E4M3 scale factors (FP8 format) with user-provided global_scale - - sf_vec_size=16 (each thread processes 16 elements) - - Multi-row processing when K is small, column loop when K is large + Key optimizations: + - Multi-row processing: threads process multiple rows per block when K is small - Row-based iteration with grid-stride loop - - Padding row fast path for zeroing scale factors + - Padding row fast path - only zero out scale factors + + Thread utilization optimization: + - Dynamic thread count based on K for 100% thread utilization + - For small K: Multiple rows processed per block iteration + - For large K: Single row with column loop + + For NVFP4, each thread processes 1 SF block (16 elements) independently, + so threads_per_row = num_sf_blocks_per_row = K/16. 
This kernel is M-agnostic: compiled once per (K, dtype, sf_layout, pdl) - combination. M-dependent values (M, padded_M) and global_scale are passed - at runtime. + combination. M-dependent values (M, padded_M) are passed at runtime. """ def __init__( @@ -148,21 +301,13 @@ def __init__( assert K % NVFP4_SF_VEC_SIZE == 0 self.num_sf_blocks_per_row = K // NVFP4_SF_VEC_SIZE + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 - if sf_layout == SF_LAYOUT_LINEAR: - self.padded_sf_cols = self.num_sf_blocks_per_row - self.row_tile_size = 1 - elif sf_layout == SF_LAYOUT_8x4: - self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 - self.row_tile_size = 8 - else: - self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 - self.row_tile_size = ROW_TILE_SIZE # 128 - - self.num_threads = _compute_optimal_threads_for_k(K) - - self.threads_per_row = self.num_sf_blocks_per_row + # Compute optimal thread count for 100% utilization + self.num_threads = _compute_optimal_threads(K) + self.threads_per_row = self.num_sf_blocks_per_row # 1 thread per SF block + # Multi-row processing constants (compile-time) if self.threads_per_row <= self.num_threads: self.rows_per_block = self.num_threads // self.threads_per_row self.needs_col_loop = False @@ -178,10 +323,7 @@ def _compute_sf_offset( if cutlass.const_expr(self.sf_is_128x4): return compute_sf_index_swizzled_128x4_gpu(row_idx, col_idx, padded_cols) else: - if cutlass.const_expr(self.sf_is_8x4): - return compute_sf_index_swizzled_8x4_gpu(row_idx, col_idx, padded_cols) - else: - return compute_sf_index_linear_gpu(row_idx, col_idx, padded_cols) + return compute_sf_index_swizzled_8x4_gpu(row_idx, col_idx, padded_cols) @cute.jit def __call__( @@ -195,11 +337,9 @@ def __call__( mGlobalScale: cute.Tensor, stream, ): - threads_per_block = self.num_threads - self.kernel(mInput, mOutput, mScales, M, padded_M, mGlobalScale).launch( grid=[num_blocks, 1, 1], - block=[threads_per_block, 1, 1], + block=[self.num_threads, 1, 1], max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1], min_blocks_per_mp=_BLOCKS_PER_SM, stream=stream, @@ -217,20 +357,10 @@ def kernel( mGlobalScale: cute.Tensor, ): """ - NVFP4 quantization kernel with swizzled scale factor layout. - - Dual-path kernel with compile-time selection: - - Small K path: Multi-row processing for improved thread utilization - - Large K path: Single row with column loop - - Each thread processes one SF block (16 elements): - 1. Load 16 elements (2 x 128-bit for fp16/bf16, 1 x 128-bit for fp8) - 2. Compute max absolute value using SIMD reduction - 3. Compute E4M3 scale: cvt_f32_to_e4m3(global_scale * max / 6.0) - 4. Store scale factor using layout-specific indexing - 5. Back-convert E4M3, compute output_scale = global_scale / scale_back - 6. Scale elements and convert to E2M1 - 7. Store 8 bytes (16 FP4 values) + Row-based kernel for swizzled layout. + + When K is small: each block processes multiple rows simultaneously. + When K is large: each block processes one row with column loop. 
""" tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() @@ -242,28 +372,95 @@ def kernel( # Read global_scale from device memory (avoids CPU-GPU sync at launch) global_scale = Float32(mGlobalScale[Int32(0)]) + # Compile-time constants num_sf_blocks_per_row = self.num_sf_blocks_per_row padded_sf_cols = self.padded_sf_cols - num_threads = self.num_threads - rows_per_block = self.rows_per_block threads_per_row = self.threads_per_row + rows_per_block = self.rows_per_block + + if cutlass.const_expr(self.needs_col_loop): + # Large K path: single row per block iteration with column loop + # Each thread maps to one SF block; threads stride over columns + num_threads = self.num_threads + + row_idx = bidx + while row_idx < padded_M: + is_padding_row = row_idx >= M + + if is_padding_row: + # Fast path: padding row - only zero out scale factors + sf_col_idx = tidx + while sf_col_idx < padded_sf_cols: + sf_offset = self._compute_sf_offset( + row_idx, sf_col_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + sf_col_idx = sf_col_idx + num_threads + else: + # Normal path: process actual data row with column loop + sf_col_idx = tidx + while sf_col_idx < num_sf_blocks_per_row: + elem_base = sf_col_idx * NVFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + + # Process block: load, compute scale, convert to E2M1 + if cutlass.const_expr(self.is_fp8): + scale_fp8, packed64 = process_nvfp4_block_fp8( + row_input, elem_base, global_scale + ) + elif cutlass.const_expr(self.is_bfloat16): + scale_fp8, packed64 = process_nvfp4_block_bfloat( + row_input, elem_base, global_scale + ) + else: + scale_fp8, packed64 = process_nvfp4_block_half( + row_input, elem_base, global_scale + ) + + # Write scale factor using swizzled indexing + sf_offset = self._compute_sf_offset( + row_idx, sf_col_idx, padded_sf_cols + ) + mScales[sf_offset] = scale_fp8 + + # Store 8 bytes (16 FP4 values = 1 x st.global.u64) + row_output = mOutput[row_idx, None] + out_base = sf_col_idx * (NVFP4_SF_VEC_SIZE // 2) + out_ptr = get_ptr_as_int64(row_output, out_base) + st_global_u64(out_ptr, packed64) - if cutlass.const_expr(not self.needs_col_loop): - # ===== SMALL K PATH: Multi-row processing ===== + sf_col_idx = sf_col_idx + num_threads + + # Handle padding columns for this row + sf_col_idx = num_sf_blocks_per_row + tidx + while sf_col_idx < padded_sf_cols: + sf_offset = self._compute_sf_offset( + row_idx, sf_col_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + sf_col_idx = sf_col_idx + num_threads + + row_idx = row_idx + grid_dim_x + else: + # Small K path: multi-row processing + # Thread mapping: tidx -> (row_in_block, sf_idx_in_row) row_in_block = tidx // threads_per_row sf_idx_in_row = tidx % threads_per_row + # Grid-stride loop over row batches row_batch_idx = bidx - total_row_batches = cute.ceil_div(padded_M, rows_per_block) - - while row_batch_idx < total_row_batches: - base_row = row_batch_idx * rows_per_block - row_idx = base_row + row_in_block - + # Initialize row_idx before while loop (CuTe DSL requires variables + # modified in while loops to be defined before the loop) + row_idx = row_batch_idx * rows_per_block + row_in_block + while row_batch_idx * rows_per_block < padded_M: if row_idx < padded_M: is_padding_row = row_idx >= M if is_padding_row: + # Fast path: padding row - zero out ALL padded_sf_cols + # Thread-stride loop since padded_sf_cols may exceed + # threads_per_row (e.g. 
K=32: threads_per_row=2, + # padded_sf_cols=4) local_sf_idx = sf_idx_in_row while local_sf_idx < padded_sf_cols: sf_offset = self._compute_sf_offset( @@ -272,10 +469,12 @@ def kernel( mScales[sf_offset] = Uint8(0) local_sf_idx = local_sf_idx + threads_per_row else: + # Normal path: process actual data if sf_idx_in_row < num_sf_blocks_per_row: elem_base = sf_idx_in_row * NVFP4_SF_VEC_SIZE row_input = mInput[row_idx, None] + # Process block: load, compute scale, convert to E2M1 if cutlass.const_expr(self.is_fp8): scale_fp8, packed64 = process_nvfp4_block_fp8( row_input, elem_base, global_scale @@ -289,86 +488,36 @@ def kernel( row_input, elem_base, global_scale ) + # Write scale factor using swizzled indexing sf_offset = self._compute_sf_offset( row_idx, sf_idx_in_row, padded_sf_cols ) mScales[sf_offset] = scale_fp8 + # Store 8 bytes (16 FP4 values = 1 x st.global.u64) row_output = mOutput[row_idx, None] out_base = sf_idx_in_row * (NVFP4_SF_VEC_SIZE // 2) out_ptr = get_ptr_as_int64(row_output, out_base) st_global_u64(out_ptr, packed64) - padding_sf_start = num_sf_blocks_per_row + sf_idx_in_row - while padding_sf_start < padded_sf_cols: - sf_offset = self._compute_sf_offset( - row_idx, padding_sf_start, padded_sf_cols - ) - mScales[sf_offset] = Uint8(0) - padding_sf_start = padding_sf_start + threads_per_row - - row_batch_idx = row_batch_idx + grid_dim_x - - else: - # ===== LARGE K PATH: Single row with column loop ===== - row_idx = bidx - while row_idx < padded_M: - is_padding_row = row_idx >= M - - sf_idx = Int32(tidx) - - if is_padding_row: - while sf_idx < padded_sf_cols: - sf_offset = self._compute_sf_offset( - row_idx, sf_idx, padded_sf_cols - ) - mScales[sf_offset] = Uint8(0) - sf_idx = sf_idx + num_threads - else: - num_sf_iters = ( - num_sf_blocks_per_row + num_threads - 1 - ) // num_threads - - for sf_iter in range(num_sf_iters): - local_sf_idx = sf_iter * num_threads + tidx - - if local_sf_idx < num_sf_blocks_per_row: - elem_base = local_sf_idx * NVFP4_SF_VEC_SIZE - row_input = mInput[row_idx, None] - - if cutlass.const_expr(self.is_fp8): - scale_fp8, packed64 = process_nvfp4_block_fp8( - row_input, elem_base, global_scale - ) - elif cutlass.const_expr(self.is_bfloat16): - scale_fp8, packed64 = process_nvfp4_block_bfloat( - row_input, elem_base, global_scale - ) - else: - scale_fp8, packed64 = process_nvfp4_block_half( - row_input, elem_base, global_scale + # Handle padding SF columns for this row + # Thread-stride loop starting from first padding column + if cutlass.const_expr( + self.num_sf_blocks_per_row != self.padded_sf_cols + ): + pad_col = num_sf_blocks_per_row + sf_idx_in_row + while pad_col < padded_sf_cols: + sf_offset = self._compute_sf_offset( + row_idx, pad_col, padded_sf_cols ) + mScales[sf_offset] = Uint8(0) + pad_col = pad_col + threads_per_row - sf_offset = self._compute_sf_offset( - row_idx, local_sf_idx, padded_sf_cols - ) - mScales[sf_offset] = scale_fp8 - - row_output = mOutput[row_idx, None] - out_base = local_sf_idx * (NVFP4_SF_VEC_SIZE // 2) - out_ptr = get_ptr_as_int64(row_output, out_base) - st_global_u64(out_ptr, packed64) - - padding_sf_start = num_sf_blocks_per_row + tidx - while padding_sf_start < padded_sf_cols: - sf_offset = self._compute_sf_offset( - row_idx, padding_sf_start, padded_sf_cols - ) - mScales[sf_offset] = Uint8(0) - padding_sf_start = padding_sf_start + num_threads - - row_idx = row_idx + grid_dim_x + row_batch_idx = row_batch_idx + grid_dim_x + # Update row_idx for next iteration + row_idx = row_batch_idx * rows_per_block + 
row_in_block + # PDL: Signal that dependent kernels can start early if cutlass.const_expr(self.enable_pdl): cute.arch.griddepcontrol_launch_dependents() @@ -431,13 +580,10 @@ def __init__( if sf_layout == SF_LAYOUT_LINEAR: self.padded_sf_cols = self.num_sf_blocks_per_row - self.row_tile_size = 1 elif sf_layout == SF_LAYOUT_8x4: self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 - self.row_tile_size = 8 else: self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 - self.row_tile_size = ROW_TILE_SIZE self.num_consumer_warps = _TMA_NUM_CONSUMER_WARPS # 8 self.num_stages = _TMA_NUM_STAGES @@ -451,7 +597,6 @@ def __init__( # Thread indexing constants (matches CUDA TmaKernelTraitsTwoBytes) self.THREADS_PER_ROW = 4 # laneIdx % 4 self.ROWS_PER_WARP = 8 # 32 / 4 - self.ROW_ITERATIONS = _TMA_ROW_TILE // self.ROWS_PER_WARP # 2 self.ELTS_PER_THREAD = NVFP4_SF_VEC_SIZE # 16 @cute.jit @@ -693,7 +838,7 @@ def kernel( pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) - # ---- TMA partition (3D: rows × warps × cols_per_warp) ---- + # ---- TMA partition (3D: rows x warps x cols_per_warp) ---- gSrc_tiled = cute.local_tile( gInput_tma, (_TMA_ROW_TILE, _TMA_NUM_CONSUMER_WARPS, _TMA_COL_TILE), @@ -911,8 +1056,8 @@ def _get_compiled_kernel_nvfp4( dtype_key: One of "float16", "bfloat16", "float8_e4m3fn". Returns: - Tuple of (compiled_kernel, rows_per_block) where rows_per_block - is used by the caller to compute num_blocks at runtime. + For linear layout: (compiled_kernel, sf_blocks_per_tb) + For swizzled layout: (compiled_kernel, rows_per_block) """ _dtype_map = { "float16": cutlass.Float16, @@ -920,19 +1065,18 @@ def _get_compiled_kernel_nvfp4( "float8_e4m3fn": cutlass.Float8E4M3FN, } cutlass_dtype = _dtype_map[dtype_key] - kernel_obj = NVFP4QuantizeSwizzledKernel( - cutlass_dtype, K, sf_layout=sf_layout, enable_pdl=enable_pdl - ) + # Use symbolic M for dynamic batch sizes sym_m = cute.sym_int() + sym_scale_size = cute.sym_int() + # Common fake tensors input_fake = cute.runtime.make_fake_compact_tensor( cutlass_dtype, (sym_m, K), stride_order=(1, 0), assumed_align=16 ) output_fake = cute.runtime.make_fake_compact_tensor( cutlass.Uint8, (sym_m, K // 2), stride_order=(1, 0), assumed_align=16 ) - sym_scale_size = cute.sym_int() scales_fake = cute.runtime.make_fake_compact_tensor( cutlass.Uint8, (sym_scale_size,), assumed_align=16 ) @@ -941,20 +1085,42 @@ def _get_compiled_kernel_nvfp4( ) stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - compiled_kernel = cute.compile( - kernel_obj, - input_fake, - output_fake, - scales_fake, - Int32(1), # Dummy M - Int32(128), # Dummy padded_M - Int32(1), # Dummy num_blocks - global_scale_fake, - stream_fake, - options="--enable-tvm-ffi", - ) + if sf_layout == SF_LAYOUT_LINEAR: + linear_obj = NVFP4QuantizeLinearKernel(cutlass_dtype, K, enable_pdl) + + compiled_kernel = cute.compile( + linear_obj, + input_fake, + output_fake, + scales_fake, + Int32(1), # Dummy M + Int32(1), # Dummy total_sf_blocks + Int32(1), # Dummy num_blocks + global_scale_fake, + stream_fake, + options="--enable-tvm-ffi", + ) - return compiled_kernel, kernel_obj.rows_per_block + return compiled_kernel, linear_obj.SF_BLOCKS_PER_TB + else: + swizzled_obj = NVFP4QuantizeSwizzledKernel( + cutlass_dtype, K, sf_layout=sf_layout, enable_pdl=enable_pdl + ) + + compiled_kernel = cute.compile( + swizzled_obj, + input_fake, + output_fake, + scales_fake, + Int32(1), # Dummy M + 
Int32(128), # Dummy padded_M + Int32(1), # Dummy num_blocks + global_scale_fake, + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel, swizzled_obj.rows_per_block _TMA_MIN_M = 1024 @@ -1077,8 +1243,7 @@ def nvfp4_quantize_cute_dsl( ) assert input.is_cuda, "Input must be on CUDA device" - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) + enable_pdl = device_support_pdl(input.device) if enable_pdl is not False else False if input.dim() > 2: m = input.numel() // input.shape[-1] @@ -1171,39 +1336,69 @@ def nvfp4_quantize_cute_dsl( return fp4_output, scale_output - # Non-TMA path + # Non-TMA path: dual-path dispatch + # Get or compile kernel (device-independent) + kernel_fn, block_unit = _get_compiled_kernel_nvfp4( + dtype_key, k, sf_layout, enable_pdl + ) + + target_grid = num_sm * _BLOCKS_PER_SM + if sf_layout == SF_LAYOUT_LINEAR: - row_tile_size = 1 padded_m = m padded_sf_cols = num_sf_blocks_per_row - elif sf_layout == SF_LAYOUT_8x4: - row_tile_size = 8 - padded_m = ((m + row_tile_size - 1) // row_tile_size) * row_tile_size - padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + total_sf_blocks = m * num_sf_blocks_per_row + scale_output_size = total_sf_blocks + + sf_blocks_per_tb = block_unit + num_blocks = min( + (total_sf_blocks + sf_blocks_per_tb - 1) // sf_blocks_per_tb, + target_grid, + ) + + fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) + scale_output = torch.empty( + scale_output_size, dtype=torch.uint8, device=input.device + ) + + kernel_fn( + input, + fp4_output, + scale_output, + m, + total_sf_blocks, + num_blocks, + global_scale_tensor, + ) else: - row_tile_size = ROW_TILE_SIZE # 128 + if sf_layout == SF_LAYOUT_8x4: + row_tile_size = 8 + else: + row_tile_size = ROW_TILE_SIZE # 128 padded_m = ((m + row_tile_size - 1) // row_tile_size) * row_tile_size padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + scale_output_size = padded_m * padded_sf_cols - scale_output_size = padded_m * padded_sf_cols - - kernel_fn, rows_per_block = _get_compiled_kernel_nvfp4( - dtype_key, k, sf_layout, enable_pdl - ) - - default_target_grid = num_sm * _BLOCKS_PER_SM - num_blocks = min( - (padded_m + rows_per_block - 1) // rows_per_block, default_target_grid - ) + rows_per_block = block_unit + num_blocks = min( + (padded_m + rows_per_block - 1) // rows_per_block, + target_grid, + ) - fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) - scale_output = torch.empty( - scale_output_size, dtype=torch.uint8, device=input.device - ) + fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) + scale_output = torch.empty( + scale_output_size, dtype=torch.uint8, device=input.device + ) - kernel_fn( - input, fp4_output, scale_output, m, padded_m, num_blocks, global_scale_tensor - ) + kernel_fn( + input, + fp4_output, + scale_output, + m, + padded_m, + num_blocks, + global_scale_tensor, + ) # Reshape using padded_sf_cols: for swizzled layouts the buffer includes # column padding; for linear layout padded_sf_cols == num_sf_blocks_per_row. 
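Note on the non-TMA dispatch above: launch geometry is computed differently per layout. The linear path sizes the grid from the total number of SF blocks, while the 8x4 and 128x4 paths pad M up to the row tile and the SF columns up to a multiple of 4 before sizing the grid per row batch. The following is a minimal host-side sketch of that arithmetic only; the helper name is hypothetical, num_sm and blocks_per_sm are illustrative stand-ins for the device query and _BLOCKS_PER_SM, and sf_vec_size = 16 mirrors NVFP4_SF_VEC_SIZE.

def nvfp4_launch_geometry(m, k, sf_layout, block_unit, num_sm=132, blocks_per_sm=4):
    # One NVFP4 SF block covers 16 input elements (NVFP4_SF_VEC_SIZE).
    num_sf_blocks_per_row = k // 16
    target_grid = num_sm * blocks_per_sm  # stands in for num_sm * _BLOCKS_PER_SM
    if sf_layout == "linear":
        # Linear layout: no row/column padding; grid sized from total SF blocks.
        total_sf_blocks = m * num_sf_blocks_per_row
        num_blocks = min((total_sf_blocks + block_unit - 1) // block_unit, target_grid)
        return m, num_sf_blocks_per_row, num_blocks
    # Swizzled layouts: pad rows to the tile (8 for 8x4, 128 for 128x4) and
    # SF columns to a multiple of 4, then size the grid per padded-row batch.
    row_tile = 8 if sf_layout == "8x4" else 128
    padded_m = ((m + row_tile - 1) // row_tile) * row_tile
    padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
    num_blocks = min((padded_m + block_unit - 1) // block_unit, target_grid)
    return padded_m, padded_sf_cols, num_blocks

# e.g. m=3, k=1024, 128x4 layout -> (padded_m=128, padded_sf_cols=64, num_blocks)
print(nvfp4_launch_geometry(3, 1024, "128x4", block_unit=32))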
@@ -1216,6 +1411,7 @@ def nvfp4_quantize_cute_dsl( "SF_LAYOUT_128x4", "SF_LAYOUT_8x4", "SF_LAYOUT_LINEAR", + "NVFP4QuantizeLinearKernel", "NVFP4QuantizeSwizzledKernel", "NVFP4QuantizeTMAKernel", "nvfp4_quantize_cute_dsl", diff --git a/flashinfer/quantization/quantization_cute_dsl_utils.py b/flashinfer/quantization/quantization_cute_dsl_utils.py index 8de4acdcf9..e0c86972dc 100644 --- a/flashinfer/quantization/quantization_cute_dsl_utils.py +++ b/flashinfer/quantization/quantization_cute_dsl_utils.py @@ -39,9 +39,19 @@ # Thread organization constants WARP_SIZE = 32 -ELTS_PER_THREAD = 8 # Each thread handles 8 FP16 elements (128 bits) -THREADS_PER_SF = SF_VEC_SIZE // ELTS_PER_THREAD # 32 / 8 = 4 threads per SF block -SF_BLOCKS_PER_WARP = WARP_SIZE // THREADS_PER_SF # 32 / 4 = 8 SF blocks per warp + +# Default: optimized 2-thread-per-SF configuration for large problems +ELTS_PER_THREAD = 16 # Each thread handles 16 FP16 elements (2 × 128-bit loads) +THREADS_PER_SF = SF_VEC_SIZE // ELTS_PER_THREAD # 32 / 16 = 2 threads per SF block +SF_BLOCKS_PER_WARP = WARP_SIZE // THREADS_PER_SF # 32 / 2 = 16 SF blocks per warp + +# Legacy: 4-thread-per-SF configuration for small problems (better grid occupancy) +ELTS_PER_THREAD_SMALL = 8 +THREADS_PER_SF_SMALL = SF_VEC_SIZE // ELTS_PER_THREAD_SMALL # 32 / 8 = 4 +SF_BLOCKS_PER_WARP_SMALL = WARP_SIZE // THREADS_PER_SF_SMALL # 32 / 4 = 8 + +# Threshold: use 2T/SF when total_sf_blocks >= this value (M*K >= 2M elements) +MXFP8_2T_SF_THRESHOLD = 65536 # Row tiling for swizzled layout (128x4 pattern) ROW_TILE_SIZE = 128 @@ -187,9 +197,10 @@ def float_to_ue8m0_fast(value: Float32, *, loc=None, ip=None) -> Uint32: @dsl_user_op def ue8m0_to_inv_scale_fast(ue8m0_val: Uint32, *, loc=None, ip=None) -> Float32: """ - Convert UE8M0 to inverse scale using fast ex2.approx. + Convert UE8M0 to inverse scale using integer bit construction. - Inverse scale = 2^(127 - ue8m0) + Constructs a float32 with exponent = (254 - ue8m0) and zero mantissa, + which is exactly 2^(127 - ue8m0). No SFU dependency. Returns 0 for ue8m0 == 0. """ return Float32( @@ -198,15 +209,16 @@ def ue8m0_to_inv_scale_fast(ue8m0_val: Uint32, *, loc=None, ip=None) -> Float32: [Uint32(ue8m0_val).ir_value(loc=loc, ip=ip)], """ { + .reg .s32 new_exp; + .reg .b32 float_bits; .reg .pred p_zero; - .reg .s32 neg_exp; - .reg .f32 neg_exp_f, result; setp.eq.u32 p_zero, $1, 0; - sub.s32 neg_exp, 127, $1; - cvt.rn.f32.s32 neg_exp_f, neg_exp; - ex2.approx.f32 result, neg_exp_f; - selp.f32 $0, 0f00000000, result, p_zero; + sub.s32 new_exp, 254, $1; + max.s32 new_exp, new_exp, 0; + shl.b32 float_bits, new_exp, 23; + mov.b32 $0, float_bits; + @p_zero mov.b32 $0, 0; } """, "=f,r", @@ -480,9 +492,22 @@ def shuffle_xor_f32(val: Float32, offset: int) -> Float32: return cute.arch.shuffle_sync_bfly(val, offset=offset) +@cute.jit +def reduce_max_2threads(val: Float32) -> Float32: + """Reduce max across 2 consecutive threads using 1 XOR shuffle.""" + from ..cute_dsl.fp4_common import fmax_f32 + + other = shuffle_xor_f32(val, 1) + val = fmax_f32(val, other) + return val + + @cute.jit def reduce_max_4threads(val: Float32) -> Float32: - """Reduce max across 4 consecutive threads using 2 XOR shuffles.""" + """Reduce max across 4 consecutive threads using 2 XOR shuffles. + + Kept for backward compatibility with MXFP4 kernels. 
+ """ from ..cute_dsl.fp4_common import fmax_f32 other = shuffle_xor_f32(val, 1) @@ -1388,6 +1413,10 @@ def process_nvfp4_block_fp8( "ELTS_PER_THREAD", "THREADS_PER_SF", "SF_BLOCKS_PER_WARP", + "ELTS_PER_THREAD_SMALL", + "THREADS_PER_SF_SMALL", + "SF_BLOCKS_PER_WARP_SMALL", + "MXFP8_2T_SF_THRESHOLD", "ROW_TILE_SIZE", # MXFP4 Constants "MXFP4_SF_VEC_SIZE", @@ -1399,6 +1428,7 @@ def process_nvfp4_block_fp8( "bfloat2_hmax_reduce_to_f32", "float_to_ue8m0_fast", "ue8m0_to_inv_scale_fast", + "reduce_max_2threads", "reduce_max_4threads", "compute_sf_index_swizzled_128x4_gpu", "compute_sf_index_swizzled_8x4_gpu", diff --git a/tests/utils/test_fp4_quantize.py b/tests/utils/test_fp4_quantize.py index 292cacbd44..7f04a6af3c 100644 --- a/tests/utils/test_fp4_quantize.py +++ b/tests/utils/test_fp4_quantize.py @@ -346,7 +346,24 @@ def test_e2m1_dequantization( # MXFP4 Quantization Tests (Both Backends) # ============================================================================= -MXFP4_SHAPES = [(128, 64), (256, 128), (512, 256), (128, 1024), (1024, 2048)] +MXFP4_SHAPES = [ + # K must be a multiple of 128 so K/32 is a multiple of 4 (CUDA reshape + # constraint for swizzled layout). + # Small M with swizzled layout: padded_M >> M (row padding dominance) + (1, 128), # padded_M=128, 127 padding rows + (1, 1024), # padded_M=128, large K + (3, 256), # padded_M=128, odd M + (16, 128), # padded_M=128, 112 padding rows + (64, 128), # padded_M=128, 64 padding rows + # Standard sizes + (128, 128), + (256, 128), + (512, 256), + (128, 1024), + (1024, 2048), + # Large K (column loop path in swizzled kernel) + (128, 16384), +] MXFP4_BACKENDS = ["cuda", "cute-dsl"] @@ -491,7 +508,24 @@ def test_mxfp4_quantize_backend_parity( # NVFP4 Quantization Tests (Both Backends) # ============================================================================= -NVFP4_SHAPES = [(128, 64), (256, 128), (512, 256), (128, 1024), (1024, 2048)] +NVFP4_SHAPES = [ + # K must be a multiple of 64 so K/16 is a multiple of 4 (CUDA reshape + # constraint for swizzled layout). 
+ # Small M with swizzled layout: padded_M >> M (row padding dominance) + (1, 64), # padded_M=128, 127 padding rows + (1, 1024), # padded_M=128, large K + (3, 128), # padded_M=128 (128x4) or 8 (8x4), odd M + (16, 64), # padded_M=128, 112 padding rows + (64, 128), # padded_M=128, 64 padding rows + # Standard sizes + (128, 64), + (256, 128), + (512, 256), + (128, 1024), + (1024, 2048), + # Large K (column loop path in swizzled kernel) + (128, 16384), +] NVFP4_BACKENDS = ["cuda", "cute-dsl"] NVFP4_SF_LAYOUTS = [SfLayout.layout_128x4, SfLayout.layout_8x4, SfLayout.layout_linear] # Roundtrip test only for layouts the dequantizer supports (128x4 and linear) diff --git a/tests/utils/test_fp4_quantize_padding.py b/tests/utils/test_fp4_quantize_padding.py index bd60b031cf..edcb936ade 100644 --- a/tests/utils/test_fp4_quantize_padding.py +++ b/tests/utils/test_fp4_quantize_padding.py @@ -22,6 +22,7 @@ (1025, 1024), (1025, 6144), ] +BACKENDS = ["cuda", "cute-dsl"] SEEDS = [42] CUDA_DEVICES = ["cuda:0"] @@ -31,24 +32,41 @@ BLOCK_SIZE = 16 +def _is_fp4_supported(device: torch.device) -> bool: + return ( + is_sm100a_supported(device) + or is_sm110a_supported(device) + or is_sm12x_supported(device) + ) + + +def _is_cute_dsl_available() -> bool: + try: + from flashinfer.cute_dsl import is_cute_dsl_available + + return is_cute_dsl_available() + except ImportError: + return False + + @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("shape", UNALIGNED_M_SHAPES) +@pytest.mark.parametrize("backend", BACKENDS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_fp4_quantize_unaligned_m_non_swizzled( dtype: torch.dtype, shape: tuple[int, int], + backend: str, seed: int, device: str, ) -> None: """Regression test: fp4_quantize with M not a multiple of 16 for linear SF.""" - if not ( - is_sm100a_supported(torch.device(device)) - or is_sm110a_supported(torch.device(device)) - or is_sm12x_supported(torch.device(device)) - ): + if not _is_fp4_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if backend == "cute-dsl" and not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") torch.set_default_device(device) torch.manual_seed(seed) @@ -60,7 +78,9 @@ def test_fp4_quantize_unaligned_m_non_swizzled( tensor_amax = torch.abs(x).max().to(torch.float32) global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - out_val, out_sf = fp4_quantize(x, global_scale, sf_vec_size, False, False) + out_val, out_sf = fp4_quantize( + x, global_scale, sf_vec_size, False, False, backend=backend + ) assert out_val.shape == (m, n // 2), ( f"Expected val shape {(m, n // 2)}, got {out_val.shape}" @@ -76,7 +96,3 @@ def test_fp4_quantize_unaligned_m_non_swizzled( # atol=0.5 accounts for FP4 E2M1 rounding at the 0/0.5 boundary torch.testing.assert_close(out_ans, out_ref, rtol=1e0, atol=5e-1) torch.testing.assert_close(out_scale, scale_ref, rtol=1e-1, atol=1e-1) - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/utils/test_fp8_quantize.py b/tests/utils/test_fp8_quantize.py index c5da042010..c38ccf7506 100644 --- a/tests/utils/test_fp8_quantize.py +++ b/tests/utils/test_fp8_quantize.py @@ -15,8 +15,8 @@ def is_cute_dsl_available(): return False -@pytest.mark.parametrize("m", [1, 1024]) -@pytest.mark.parametrize("k", [1024]) +@pytest.mark.parametrize("m", [1, 3, 16, 64, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 8192]) 
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("device", ["cuda", "cpu"]) @@ -111,8 +111,8 @@ def test_mxfp8_quantize_torch_host(m, k, dtype, is_sf_swizzled_layout): mxfp8_quantize_check_accuracy(a_pt, a, 8, 0, 0.999) -@pytest.mark.parametrize("m", [1, 2, 16, 1024]) -@pytest.mark.parametrize("k", [512, 1024]) +@pytest.mark.parametrize("m", [1, 2, 3, 16, 64, 1024]) +@pytest.mark.parametrize("k", [128, 512, 1024, 8192]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) @@ -182,8 +182,8 @@ def test_mxfp8_quantize_alignment_torch_device( ) -@pytest.mark.parametrize("m", [1, 128, 2048]) -@pytest.mark.parametrize("k", [1024]) +@pytest.mark.parametrize("m", [1, 3, 128, 2048]) +@pytest.mark.parametrize("k", [128, 1024]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) @@ -343,15 +343,15 @@ def test_cute_dsl_compilation_cache_m_agnostic(is_sf_swizzled_layout): pytest.skip("CuTe-DSL is not available") from flashinfer.quantization.kernels.mxfp8_quantize import ( - _get_compiled_kernel_linear, - _get_compiled_kernel_swizzled, + _get_compiled_kernel_mxfp8_linear, + _get_compiled_kernel_mxfp8_swizzled, ) # Get the appropriate cache based on layout if is_sf_swizzled_layout: - cache_fn = _get_compiled_kernel_swizzled + cache_fn = _get_compiled_kernel_mxfp8_swizzled else: - cache_fn = _get_compiled_kernel_linear + cache_fn = _get_compiled_kernel_mxfp8_linear # Clear the cache to start fresh cache_fn.cache_clear() @@ -409,15 +409,15 @@ def test_cute_dsl_compilation_cache_k_specific(is_sf_swizzled_layout): pytest.skip("CuTe-DSL is not available") from flashinfer.quantization.kernels.mxfp8_quantize import ( - _get_compiled_kernel_linear, - _get_compiled_kernel_swizzled, + _get_compiled_kernel_mxfp8_linear, + _get_compiled_kernel_mxfp8_swizzled, ) # Get the appropriate cache based on layout if is_sf_swizzled_layout: - cache_fn = _get_compiled_kernel_swizzled + cache_fn = _get_compiled_kernel_mxfp8_swizzled else: - cache_fn = _get_compiled_kernel_linear + cache_fn = _get_compiled_kernel_mxfp8_linear # Clear the cache to start fresh cache_fn.cache_clear() @@ -469,15 +469,15 @@ def test_cute_dsl_compilation_cache_dtype_specific(is_sf_swizzled_layout): pytest.skip("CuTe-DSL is not available") from flashinfer.quantization.kernels.mxfp8_quantize import ( - _get_compiled_kernel_linear, - _get_compiled_kernel_swizzled, + _get_compiled_kernel_mxfp8_linear, + _get_compiled_kernel_mxfp8_swizzled, ) # Get the appropriate cache based on layout if is_sf_swizzled_layout: - cache_fn = _get_compiled_kernel_swizzled + cache_fn = _get_compiled_kernel_mxfp8_swizzled else: - cache_fn = _get_compiled_kernel_linear + cache_fn = _get_compiled_kernel_mxfp8_linear # Clear the cache to start fresh cache_fn.cache_clear()