diff --git a/benchmarks/bench_mxfp4_quantize_backend_comparison.py b/benchmarks/bench_mxfp4_quantize_backend_comparison.py new file mode 100644 index 0000000000..5c51c5c18f --- /dev/null +++ b/benchmarks/bench_mxfp4_quantize_backend_comparison.py @@ -0,0 +1,702 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Benchmark: MXFP4 Quantization Backend Comparison (CUDA vs CuTe-DSL) + +Compares the performance of CUDA and CuTe-DSL backends for MXFP4 quantization +across different M and K dimensions. Each configuration is verified for +correctness before timing. Generates heatmaps showing relative performance +(speedup of CuTe-DSL over CUDA). + +Can also measure achieved memory bandwidth in TB/s for the CuTe-DSL backend. 
+ +Usage: + # Speedup comparison mode (default, includes correctness verification) + python bench_mxfp4_quantize_backend_comparison.py + + # Bandwidth measurement mode (cute-dsl only) + python bench_mxfp4_quantize_backend_comparison.py --bandwidth + +Requirements: + - Blackwell GPU (SM100+) for CuTe-DSL backend + - matplotlib for visualization +""" + +import argparse +import numpy as np +import torch +from typing import Dict, List, Tuple + +from flashinfer.testing.utils import bench_gpu_time + +# Constants for bandwidth calculation +SF_VEC_SIZE = 32 # Scale factor vector size for MXFP4 + + +def get_cc(): + """Get CUDA compute capability.""" + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def verify_mxfp4_correctness( + m: int, + k: int, + dtype: torch.dtype, +) -> Tuple[bool, str, float, float]: + """ + Verify that both backends produce correct outputs via roundtrip test. + + Returns: + Tuple of (success, message, quant_match_pct, scale_match_pct) + On failure, quant_match_pct and scale_match_pct are 0.0 + """ + import flashinfer + + torch.manual_seed(42) + x = torch.randn(m, k, device="cuda", dtype=dtype) + + try: + # Test CUDA backend + quant_cuda, scale_cuda = flashinfer.mxfp4_quantize(x, backend="cuda") + dq_cuda = flashinfer.mxfp4_dequantize(quant_cuda, scale_cuda) + + # Test CuTe-DSL backend + quant_cute, scale_cute = flashinfer.mxfp4_quantize(x, backend="cute-dsl") + dq_cute = flashinfer.mxfp4_dequantize(quant_cute, scale_cute) + + # Check shapes match + if quant_cuda.shape != quant_cute.shape: + return ( + False, + f"Quant shape mismatch: CUDA={quant_cuda.shape}, CuTe={quant_cute.shape}", + 0.0, + 0.0, + ) + if scale_cuda.shape != scale_cute.shape: + return ( + False, + f"Scale shape mismatch: CUDA={scale_cuda.shape}, CuTe={scale_cute.shape}", + 0.0, + 0.0, + ) + + # Check roundtrip quality for both backends (cosine similarity) + x_f32 = x.cpu().to(torch.float32).view(1, -1) + dq_cuda_f32 = 
dq_cuda.cpu().to(torch.float32).view(1, -1) + dq_cute_f32 = dq_cute.cpu().to(torch.float32).view(1, -1) + + cos_sim_cuda = torch.nn.functional.cosine_similarity(x_f32, dq_cuda_f32).item() + cos_sim_cute = torch.nn.functional.cosine_similarity(x_f32, dq_cute_f32).item() + + # Check backend agreement + quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100 + scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100 + + # FP4 quantization should have cosine similarity > 0.9 + if cos_sim_cuda < 0.9: + return ( + False, + f"CUDA roundtrip quality too low: cos_sim={cos_sim_cuda:.4f}", + quant_match_pct, + scale_match_pct, + ) + if cos_sim_cute < 0.9: + return ( + False, + f"CuTe-DSL roundtrip quality too low: cos_sim={cos_sim_cute:.4f}", + quant_match_pct, + scale_match_pct, + ) + + return True, "OK", quant_match_pct, scale_match_pct + + except Exception as e: + return False, f"Exception: {e}", 0.0, 0.0 + + +def bench_mxfp4_quantize( + m: int, + k: int, + dtype: torch.dtype, + backend: str, +) -> float: + """ + Benchmark MXFP4 quantization for a specific configuration. + + Args: + m: Number of rows + k: Number of columns + dtype: Input dtype (torch.float16 or torch.bfloat16) + backend: "cuda" or "cute-dsl" + + Returns: + Median execution time in milliseconds + """ + import flashinfer + + # Create input tensor + x = torch.randn(m, k, device="cuda", dtype=dtype) + + # Warmup and get output shapes + _ = flashinfer.mxfp4_quantize(x, backend=backend) + + # Benchmark + def run_kernel(): + flashinfer.mxfp4_quantize(x, backend=backend) + + times = bench_gpu_time( + fn=run_kernel, + enable_cupti=True, + dry_run_iters=5, + repeat_iters=30, + cold_l2_cache=True, + use_cuda_graph=False, + ) + + return np.median(times) + + +def compute_bandwidth_tb_per_sec( + m: int, k: int, dtype: torch.dtype, time_ms: float +) -> float: + """ + Compute achieved memory bandwidth in TB/s. 
+ + Memory bandwidth calculation for mxfp4_quantize: + - Read: input tensor (2 bytes per element for fp16/bf16) + - Write: quantized tensor (0.5 bytes per element, since fp4 = 4 bits) + - Write: scale factors (1 byte per scale factor) + + Args: + m: Number of rows + k: Number of columns + dtype: Input dtype (determines bytes per element) + time_ms: Execution time in milliseconds + + Returns: + Achieved bandwidth in TB/s + """ + input_dtype_bytes = 2 # fp16 or bf16 + + num_elements = m * k + num_scale_factors = num_elements // SF_VEC_SIZE + + # Total bytes transferred + problem_bytes = ( + num_elements * input_dtype_bytes # input read + + num_elements // 2 # fp4 output write (2 fp4 values per byte) + + num_scale_factors * 1 # scale factors write + ) + + # Convert ms to seconds, bytes to TB + tb_per_sec = problem_bytes / (1e9 * time_ms) # 1e9 = 10^12 bytes/TB / 10^3 ms/s + return tb_per_sec + + +def run_bandwidth_sweep( + m_values: List[int], + k_values: List[int], + dtype: torch.dtype, +) -> Dict[Tuple[int, int], float]: + """ + Run bandwidth benchmark sweep for CuTe-DSL backend only. + + Returns: + Dictionary mapping (m, k) to achieved bandwidth in TB/s + """ + bandwidth_results = {} + + total = len(m_values) * len(k_values) + current = 0 + + print(f"\nBenchmarking MXFP4 swizzled layout, dtype={dtype} (CuTe-DSL bandwidth)") + print("=" * 60) + + for m in m_values: + for k in k_values: + current += 1 + print(f"[{current}/{total}] M={m:5d}, K={k:5d} ... 
", end="", flush=True) + + # Benchmark CuTe-DSL backend only + time_ms = bench_mxfp4_quantize(m, k, dtype, backend="cute-dsl") + + # Compute bandwidth + bandwidth = compute_bandwidth_tb_per_sec(m, k, dtype, time_ms) + bandwidth_results[(m, k)] = bandwidth + + print(f"time={time_ms:.3f}ms, bandwidth={bandwidth:.2f} TB/s") + + return bandwidth_results + + +def run_benchmark_sweep( + m_values: List[int], + k_values: List[int], + dtype: torch.dtype, +) -> Tuple[Dict[Tuple[int, int], float], Dict[Tuple[int, int], float]]: + """ + Run benchmark sweep for both backends with inline correctness verification. + + Args: + m_values: List of M dimensions to benchmark + k_values: List of K dimensions to benchmark + dtype: Input dtype + + Returns: + Tuple of (cuda_times, cute_dsl_times) dictionaries + """ + cuda_times = {} + cute_dsl_times = {} + failures = [] + + total = len(m_values) * len(k_values) + current = 0 + + print(f"\nBenchmarking MXFP4 swizzled layout, dtype={dtype}") + print("=" * 95) + print( + f"{'Progress':<12} {'M':>5} {'K':>5} | " + f"{'--Match--':^14} | " + f"{'-------Timing-------':^28}" + ) + print( + f"{'':12} {'':>5} {'':>5} | " + f"{'quant':>6} {'scale':>6} | " + f"{'CUDA':>8} {'CuTe':>8} {'Speedup':>10}" + ) + print("-" * 95) + + for m in m_values: + for k in k_values: + current += 1 + + # Verify correctness first + success, verify_msg, quant_match, scale_match = verify_mxfp4_correctness( + m, k, dtype + ) + if not success: + failures.append((m, k, verify_msg)) + print(f"[{current:3d}/{total}] {m:5d} {k:5d} | FAIL: {verify_msg}") + continue + + # Benchmark CUDA backend + cuda_time = bench_mxfp4_quantize(m, k, dtype, backend="cuda") + cuda_times[(m, k)] = cuda_time + + # Benchmark CuTe-DSL backend + cute_dsl_time = bench_mxfp4_quantize(m, k, dtype, backend="cute-dsl") + cute_dsl_times[(m, k)] = cute_dsl_time + + # Compute speedup + speedup = cuda_time / cute_dsl_time + speedup_str = ( + f"{speedup:.2f}x" if speedup >= 1 else f"{1 / speedup:.2f}x slower" + 
) + print( + f"[{current:3d}/{total}] {m:5d} {k:5d} | " + f"{quant_match:5.1f}% {scale_match:6.1f}% | " + f"{cuda_time:7.3f}ms {cute_dsl_time:7.3f}ms {speedup_str:>10}" + ) + + if failures: + print(f"\nWARNING: {len(failures)}/{total} configurations failed verification:") + for m, k, msg in failures: + print(f" - M={m}, K={k}: {msg}") + + return cuda_times, cute_dsl_times + + +def create_heatmap( + m_values: List[int], + k_values: List[int], + cuda_times: Dict[Tuple[int, int], float], + cute_dsl_times: Dict[Tuple[int, int], float], + title: str, + output_file: str, +): + """ + Create a heatmap showing relative performance (CuTe-DSL speedup over CUDA). + + Values > 1.0 mean CuTe-DSL is faster, < 1.0 means CUDA is faster. + """ + try: + import matplotlib.pyplot as plt + import matplotlib.colors as mcolors + except ImportError: + print("matplotlib not installed, skipping heatmap generation") + return + + # Create speedup matrix (CUDA time / CuTe-DSL time) + # > 1.0 means CuTe-DSL is faster + speedup_matrix = np.zeros((len(m_values), len(k_values))) + + for i, m in enumerate(m_values): + for j, k in enumerate(k_values): + cuda_time = cuda_times.get((m, k), float("nan")) + cute_dsl_time = cute_dsl_times.get((m, k), float("nan")) + if cute_dsl_time > 0: + speedup_matrix[i, j] = cuda_time / cute_dsl_time + else: + speedup_matrix[i, j] = float("nan") + + # Create figure + fig, ax = plt.subplots(figsize=(12, 10)) + + # Create diverging colormap centered at 1.0 + # Green = CuTe-DSL faster (>1), Red = CUDA faster (<1) + vmin = min(0.5, np.nanmin(speedup_matrix)) + vmax = max(2.0, np.nanmax(speedup_matrix)) + + # Use log scale centered at 1.0 for better visualization + norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=1.0, vmax=vmax) + + # Create heatmap + im = ax.imshow( + speedup_matrix, + cmap="RdYlGn", # Red-Yellow-Green: red=CUDA faster, green=CuTe-DSL faster + norm=norm, + aspect="auto", + ) + + # Add colorbar + cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8) + 
cbar.ax.set_ylabel("Speedup (CUDA time / CuTe-DSL time)", rotation=-90, va="bottom") + + # Set ticks and labels + ax.set_xticks(np.arange(len(k_values))) + ax.set_yticks(np.arange(len(m_values))) + ax.set_xticklabels([str(k) for k in k_values]) + ax.set_yticklabels([str(m) for m in m_values]) + + # Rotate x-axis labels + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + # Add value annotations + for i in range(len(m_values)): + for j in range(len(k_values)): + value = speedup_matrix[i, j] + if not np.isnan(value): + # Choose text color based on background + text_color = "white" if value < 0.7 or value > 1.5 else "black" + ax.text( + j, + i, + f"{value:.2f}", + ha="center", + va="center", + color=text_color, + fontsize=8, + ) + + # Labels and title + ax.set_xlabel("K (columns)") + ax.set_ylabel("M (rows)") + ax.set_title(title + "\n(>1.0 = CuTe-DSL faster, <1.0 = CUDA faster)") + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches="tight") + print(f"Saved heatmap to {output_file}") + plt.close() + + +def create_bandwidth_heatmap( + m_values: List[int], + k_values: List[int], + bandwidth_results: Dict[Tuple[int, int], float], + title: str, + output_file: str, +): + """ + Create a heatmap showing achieved memory bandwidth in TB/s. 
+ """ + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed, skipping heatmap generation") + return + + # Create bandwidth matrix + bandwidth_matrix = np.zeros((len(m_values), len(k_values))) + + for i, m in enumerate(m_values): + for j, k in enumerate(k_values): + bandwidth_matrix[i, j] = bandwidth_results.get((m, k), float("nan")) + + # Create figure + fig, ax = plt.subplots(figsize=(12, 10)) + + # Use sequential colormap (higher bandwidth = better = greener) + vmin = np.nanmin(bandwidth_matrix) + vmax = np.nanmax(bandwidth_matrix) + + # Create heatmap with viridis colormap (good for sequential data) + im = ax.imshow( + bandwidth_matrix, + cmap="YlGn", # Yellow-Green: darker green = higher bandwidth + vmin=vmin, + vmax=vmax, + aspect="auto", + ) + + # Add colorbar + cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8) + cbar.ax.set_ylabel("Achieved Bandwidth (TB/s)", rotation=-90, va="bottom") + + # Set ticks and labels + ax.set_xticks(np.arange(len(k_values))) + ax.set_yticks(np.arange(len(m_values))) + ax.set_xticklabels([str(k) for k in k_values]) + ax.set_yticklabels([str(m) for m in m_values]) + + # Rotate x-axis labels + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + # Add value annotations + for i in range(len(m_values)): + for j in range(len(k_values)): + value = bandwidth_matrix[i, j] + if not np.isnan(value): + # Choose text color based on background brightness + normalized = (value - vmin) / (vmax - vmin) if vmax > vmin else 0.5 + text_color = "white" if normalized > 0.6 else "black" + ax.text( + j, + i, + f"{value:.1f}", + ha="center", + va="center", + color=text_color, + fontsize=8, + ) + + # Labels and title + ax.set_xlabel("K (columns)") + ax.set_ylabel("M (rows)") + ax.set_title(title) + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches="tight") + print(f"Saved heatmap to {output_file}") + plt.close() + + +def print_bandwidth_summary_table( + m_values: 
List[int], + k_values: List[int], + bandwidth_results: Dict[Tuple[int, int], float], +): + """Print a summary table of bandwidth results.""" + print(f"\n{'=' * 80}") + print("Bandwidth Summary: MXFP4 Swizzled Layout (TB/s)") + print(f"{'=' * 80}") + + # Header + header = "M\\K".ljust(8) + for k in k_values: + header += f"{k:>8}" + print(header) + print("-" * (8 + 8 * len(k_values))) + + # Data rows + for m in m_values: + row = f"{m:<8}" + for k in k_values: + bandwidth = bandwidth_results.get((m, k), float("nan")) + if not np.isnan(bandwidth): + row += f"{bandwidth:>8.1f}" + else: + row += f"{'N/A':>8}" + print(row) + + # Compute overall statistics + bandwidths = [b for b in bandwidth_results.values() if not np.isnan(b)] + + if bandwidths: + print("\nStatistics:") + print(f" Mean bandwidth: {np.mean(bandwidths):.2f} TB/s") + print(f" Min bandwidth: {min(bandwidths):.2f} TB/s") + print(f" Max bandwidth: {max(bandwidths):.2f} TB/s") + print(f" Std deviation: {np.std(bandwidths):.2f} TB/s") + + +def print_summary_table( + m_values: List[int], + k_values: List[int], + cuda_times: Dict[Tuple[int, int], float], + cute_dsl_times: Dict[Tuple[int, int], float], +): + """Print a summary table of results.""" + print(f"\n{'=' * 80}") + print("Summary: MXFP4 Swizzled Layout (Speedup: CUDA time / CuTe-DSL time)") + print(f"{'=' * 80}") + + # Header + header = "M\\K".ljust(8) + for k in k_values: + header += f"{k:>8}" + print(header) + print("-" * (8 + 8 * len(k_values))) + + # Data rows + for m in m_values: + row = f"{m:<8}" + for k in k_values: + cuda_time = cuda_times.get((m, k), float("nan")) + cute_dsl_time = cute_dsl_times.get((m, k), float("nan")) + if cute_dsl_time > 0 and not np.isnan(cuda_time): + speedup = cuda_time / cute_dsl_time + row += f"{speedup:>8.2f}" + else: + row += f"{'N/A':>8}" + print(row) + + # Compute overall statistics + speedups = [] + for m in m_values: + for k in k_values: + cuda_time = cuda_times.get((m, k)) + cute_dsl_time = cute_dsl_times.get((m, 
k)) + if cuda_time and cute_dsl_time and cute_dsl_time > 0: + speedups.append(cuda_time / cute_dsl_time) + + if speedups: + print("\nStatistics:") + print(f" Geometric mean speedup: {np.exp(np.mean(np.log(speedups))):.2f}x") + print(f" Min speedup: {min(speedups):.2f}x") + print(f" Max speedup: {max(speedups):.2f}x") + print( + f" Cases where CuTe-DSL faster: {sum(1 for s in speedups if s > 1)}/{len(speedups)}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Benchmark MXFP4 quantization backends" + ) + parser.add_argument( + "--bandwidth", + action="store_true", + help="Run bandwidth benchmark (CuTe-DSL only) instead of comparison", + ) + parser.add_argument( + "--dtype", + choices=["float16", "bfloat16"], + default="bfloat16", + help="Input data type (default: bfloat16)", + ) + parser.add_argument( + "--output-prefix", + type=str, + default="mxfp4_quantize_backend", + help="Output file prefix for heatmaps", + ) + args = parser.parse_args() + + # Check compute capability + cc = get_cc() + print(f"GPU Compute Capability: SM{cc}") + + if cc < 100: + print("ERROR: CuTe-DSL backend requires Blackwell GPU (SM100+)") + return + + # Get dtype + dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16 + print(f"Data type: {dtype}") + + # Define sweep ranges (powers of 2 + common transformer hidden dimensions) + # Note: K must be a multiple of 128 for MXFP4 swizzled layout because: + # - SF vec size is 32, so K/32 gives number of SF blocks per row + # - Swizzled layout pads SF blocks to multiples of 4 + # - The CUDA backend's reshape assumes unpadded SF dimensions + # So K/32 must already be a multiple of 4, i.e., K must be multiple of 128 + m_values = [ + 128, + 256, + 384, + 512, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, + 6144, + 8192, + 12288, + 16384, + 32768, + ] + k_values = [ + 128, + 256, + 384, + 512, + 768, + 1024, + 1536, + 2048, + 3072, + 4096, + 5120, + 6144, + 8192, + 12288, + 16384, + ] + + print(f"\nM values: 
{m_values}") + print(f"K values: {k_values}") + + if args.bandwidth: + # Bandwidth measurement mode: CuTe-DSL only + print("\n" + "=" * 80) + print("BANDWIDTH MEASUREMENT MODE (CuTe-DSL only)") + print("=" * 80) + + bandwidth_results = run_bandwidth_sweep(m_values, k_values, dtype) + print_bandwidth_summary_table(m_values, k_values, bandwidth_results) + + # Generate bandwidth heatmap + create_bandwidth_heatmap( + m_values, + k_values, + bandwidth_results, + f"MXFP4 Quantization CuTe-DSL Bandwidth ({args.dtype})", + f"{args.output_prefix}_bandwidth_{args.dtype}.png", + ) + else: + # Run comparison benchmark (with inline correctness verification) + cuda_times, cute_dsl_times = run_benchmark_sweep(m_values, k_values, dtype) + print_summary_table(m_values, k_values, cuda_times, cute_dsl_times) + + # Generate heatmap + create_heatmap( + m_values, + k_values, + cuda_times, + cute_dsl_times, + f"MXFP4 Quantization Backend Comparison ({args.dtype})", + f"{args.output_prefix}_comparison_{args.dtype}.png", + ) + + print("\n" + "=" * 80) + print("BENCHMARK COMPLETE") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_mxfp8_quantize_backend_comparison.py b/benchmarks/bench_mxfp8_quantize_backend_comparison.py new file mode 100644 index 0000000000..25ae1b6abb --- /dev/null +++ b/benchmarks/bench_mxfp8_quantize_backend_comparison.py @@ -0,0 +1,683 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +Benchmark: MXFP8 Quantization Backend Comparison (CUDA vs CuTe-DSL) + +Compares the performance of CUDA and CuTe-DSL backends for MXFP8 quantization +across different M and K dimensions. Generates heatmaps showing relative +performance (speedup of CuTe-DSL over CUDA). + +Can also measure achieved memory bandwidth in TB/s for the CuTe-DSL backend. + +Usage: + # Speedup comparison mode (default) + python bench_mxfp8_quantize_backend_comparison.py + + # Bandwidth measurement mode (cute-dsl only) + python bench_mxfp8_quantize_backend_comparison.py --bandwidth + +Requirements: + - Blackwell GPU (SM100+) for CuTe-DSL backend + - matplotlib for visualization +""" + +import argparse +import numpy as np +import torch +from typing import Dict, List, Tuple + +from flashinfer.testing.utils import bench_gpu_time + +# Constants for bandwidth calculation +SF_VEC_SIZE = 32 # Scale factor vector size for MXFP8 + + +def get_cc(): + """Get CUDA compute capability.""" + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def bench_mxfp8_quantize( + m: int, + k: int, + dtype: torch.dtype, + is_sf_swizzled_layout: bool, + backend: str, +) -> float: + """ + Benchmark MXFP8 quantization for a specific configuration. 
+ + Args: + m: Number of rows + k: Number of columns + dtype: Input dtype (torch.float16 or torch.bfloat16) + is_sf_swizzled_layout: Whether to use swizzled scale factor layout + backend: "cuda" or "cute-dsl" + + Returns: + Median execution time in milliseconds + """ + import flashinfer + + # Create input tensor + x = torch.randn(m, k, device="cuda", dtype=dtype) + + # Warmup and get output shapes + _ = flashinfer.mxfp8_quantize( + x, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend=backend, + ) + + # Benchmark + def run_kernel(): + flashinfer.mxfp8_quantize( + x, + is_sf_swizzled_layout=is_sf_swizzled_layout, + backend=backend, + ) + + times = bench_gpu_time( + fn=run_kernel, + enable_cupti=True, + dry_run_iters=5, + repeat_iters=30, + cold_l2_cache=True, + use_cuda_graph=False, + ) + + return np.median(times) + + +def compute_bandwidth_tb_per_sec( + m: int, k: int, dtype: torch.dtype, time_ms: float +) -> float: + """ + Compute achieved memory bandwidth in TB/s. + + Memory bandwidth calculation for mxfp8_quantize: + - Read: input tensor (2 bytes per element for fp16/bf16) + - Write: quantized tensor (1 byte per element, fp8) + - Write: scale factors (1 byte per scale factor) + + Args: + m: Number of rows + k: Number of columns + dtype: Input dtype (determines bytes per element) + time_ms: Execution time in milliseconds + + Returns: + Achieved bandwidth in TB/s + """ + input_dtype_bytes = 2 # fp16 or bf16 + + num_elements = m * k + num_scale_factors = num_elements // SF_VEC_SIZE + + # Total bytes transferred + problem_bytes = ( + num_elements * input_dtype_bytes # input read + + num_elements * 1 # fp8 output write + + num_scale_factors * 1 # scale factors write + ) + + # Convert ms to seconds, bytes to TB + tb_per_sec = problem_bytes / (1e9 * time_ms) # 1e9 = 10^12 bytes/TB / 10^3 ms/s + return tb_per_sec + + +def run_bandwidth_sweep( + m_values: List[int], + k_values: List[int], + dtype: torch.dtype, + is_sf_swizzled_layout: bool, +) -> Dict[Tuple[int, 
int], float]: + """ + Run bandwidth benchmark sweep for CuTe-DSL backend only. + + Returns: + Dictionary mapping (m, k) to achieved bandwidth in TB/s + """ + bandwidth_results = {} + + total = len(m_values) * len(k_values) + current = 0 + + layout_str = "swizzled" if is_sf_swizzled_layout else "linear" + print(f"\nBenchmarking {layout_str} layout, dtype={dtype} (CuTe-DSL bandwidth)") + print("=" * 60) + + for m in m_values: + for k in k_values: + current += 1 + print(f"[{current}/{total}] M={m:5d}, K={k:5d} ... ", end="", flush=True) + + # Benchmark CuTe-DSL backend only + time_ms = bench_mxfp8_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl" + ) + + # Compute bandwidth + bandwidth = compute_bandwidth_tb_per_sec(m, k, dtype, time_ms) + bandwidth_results[(m, k)] = bandwidth + + print(f"time={time_ms:.3f}ms, bandwidth={bandwidth:.2f} TB/s") + + return bandwidth_results + + +def run_benchmark_sweep( + m_values: List[int], + k_values: List[int], + dtype: torch.dtype, + is_sf_swizzled_layout: bool, +) -> Tuple[Dict[Tuple[int, int], float], Dict[Tuple[int, int], float]]: + """ + Run benchmark sweep for both backends. + + Returns: + Tuple of (cuda_times, cute_dsl_times) dictionaries + """ + cuda_times = {} + cute_dsl_times = {} + + total = len(m_values) * len(k_values) + current = 0 + + layout_str = "swizzled" if is_sf_swizzled_layout else "linear" + print(f"\nBenchmarking {layout_str} layout, dtype={dtype}") + print("=" * 60) + + for m in m_values: + for k in k_values: + current += 1 + print(f"[{current}/{total}] M={m:5d}, K={k:5d} ... 
", end="", flush=True) + + # Benchmark CUDA backend + cuda_time = bench_mxfp8_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cuda" + ) + cuda_times[(m, k)] = cuda_time + + # Benchmark CuTe-DSL backend + cute_dsl_time = bench_mxfp8_quantize( + m, k, dtype, is_sf_swizzled_layout, backend="cute-dsl" + ) + cute_dsl_times[(m, k)] = cute_dsl_time + + # Compute speedup + speedup = cuda_time / cute_dsl_time + speedup_str = ( + f"{speedup:.2f}x" if speedup >= 1 else f"{1 / speedup:.2f}x slower" + ) + print( + f"CUDA={cuda_time:.3f}ms, CuTe-DSL={cute_dsl_time:.3f}ms, " + f"Speedup={speedup_str}" + ) + + return cuda_times, cute_dsl_times + + +def create_heatmap( + m_values: List[int], + k_values: List[int], + cuda_times: Dict[Tuple[int, int], float], + cute_dsl_times: Dict[Tuple[int, int], float], + title: str, + output_file: str, +): + """ + Create a heatmap showing relative performance (CuTe-DSL speedup over CUDA). + + Values > 1.0 mean CuTe-DSL is faster, < 1.0 means CUDA is faster. + """ + try: + import matplotlib.pyplot as plt + import matplotlib.colors as mcolors + except ImportError: + print("matplotlib not installed, skipping heatmap generation") + return + + # Create speedup matrix (CUDA time / CuTe-DSL time) + # > 1.0 means CuTe-DSL is faster + speedup_matrix = np.zeros((len(m_values), len(k_values))) + + for i, m in enumerate(m_values): + for j, k in enumerate(k_values): + cuda_time = cuda_times.get((m, k), float("nan")) + cute_dsl_time = cute_dsl_times.get((m, k), float("nan")) + if cute_dsl_time > 0: + speedup_matrix[i, j] = cuda_time / cute_dsl_time + else: + speedup_matrix[i, j] = float("nan") + + # Create figure + fig, ax = plt.subplots(figsize=(12, 10)) + + # Create diverging colormap centered at 1.0 + # Green = CuTe-DSL faster (>1), Red = CUDA faster (<1) + vmin = min(0.5, np.nanmin(speedup_matrix)) + vmax = max(2.0, np.nanmax(speedup_matrix)) + + # Use log scale centered at 1.0 for better visualization + norm = mcolors.TwoSlopeNorm(vmin=vmin, 
vcenter=1.0, vmax=vmax) + + # Create heatmap + im = ax.imshow( + speedup_matrix, + cmap="RdYlGn", # Red-Yellow-Green: red=CUDA faster, green=CuTe-DSL faster + norm=norm, + aspect="auto", + ) + + # Add colorbar + cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8) + cbar.ax.set_ylabel("Speedup (CUDA time / CuTe-DSL time)", rotation=-90, va="bottom") + + # Set ticks and labels + ax.set_xticks(np.arange(len(k_values))) + ax.set_yticks(np.arange(len(m_values))) + ax.set_xticklabels([str(k) for k in k_values]) + ax.set_yticklabels([str(m) for m in m_values]) + + # Rotate x-axis labels + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + # Add value annotations + for i in range(len(m_values)): + for j in range(len(k_values)): + value = speedup_matrix[i, j] + if not np.isnan(value): + # Choose text color based on background + text_color = "white" if value < 0.7 or value > 1.5 else "black" + ax.text( + j, + i, + f"{value:.2f}", + ha="center", + va="center", + color=text_color, + fontsize=8, + ) + + # Labels and title + ax.set_xlabel("K (columns)") + ax.set_ylabel("M (rows)") + ax.set_title(title + "\n(>1.0 = CuTe-DSL faster, <1.0 = CUDA faster)") + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches="tight") + print(f"Saved heatmap to {output_file}") + plt.close() + + +def create_bandwidth_heatmap( + m_values: List[int], + k_values: List[int], + bandwidth_results: Dict[Tuple[int, int], float], + title: str, + output_file: str, +): + """ + Create a heatmap showing achieved memory bandwidth in TB/s. 
+ """ + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed, skipping heatmap generation") + return + + # Create bandwidth matrix + bandwidth_matrix = np.zeros((len(m_values), len(k_values))) + + for i, m in enumerate(m_values): + for j, k in enumerate(k_values): + bandwidth_matrix[i, j] = bandwidth_results.get((m, k), float("nan")) + + # Create figure + fig, ax = plt.subplots(figsize=(12, 10)) + + # Use sequential colormap (higher bandwidth = better = greener) + vmin = np.nanmin(bandwidth_matrix) + vmax = np.nanmax(bandwidth_matrix) + + # Create heatmap with viridis colormap (good for sequential data) + im = ax.imshow( + bandwidth_matrix, + cmap="YlGn", # Yellow-Green: darker green = higher bandwidth + vmin=vmin, + vmax=vmax, + aspect="auto", + ) + + # Add colorbar + cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8) + cbar.ax.set_ylabel("Achieved Bandwidth (TB/s)", rotation=-90, va="bottom") + + # Set ticks and labels + ax.set_xticks(np.arange(len(k_values))) + ax.set_yticks(np.arange(len(m_values))) + ax.set_xticklabels([str(k) for k in k_values]) + ax.set_yticklabels([str(m) for m in m_values]) + + # Rotate x-axis labels + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor") + + # Add value annotations + for i in range(len(m_values)): + for j in range(len(k_values)): + value = bandwidth_matrix[i, j] + if not np.isnan(value): + # Choose text color based on background brightness + normalized = (value - vmin) / (vmax - vmin) if vmax > vmin else 0.5 + text_color = "white" if normalized > 0.6 else "black" + ax.text( + j, + i, + f"{value:.1f}", + ha="center", + va="center", + color=text_color, + fontsize=8, + ) + + # Labels and title + ax.set_xlabel("K (columns)") + ax.set_ylabel("M (rows)") + ax.set_title(title) + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches="tight") + print(f"Saved heatmap to {output_file}") + plt.close() + + +def print_bandwidth_summary_table( + m_values: 
List[int], + k_values: List[int], + bandwidth_results: Dict[Tuple[int, int], float], + layout_name: str, +): + """Print a summary table of bandwidth results.""" + print(f"\n{'=' * 80}") + print(f"Bandwidth Summary: {layout_name} (TB/s)") + print(f"{'=' * 80}") + + # Header + header = "M\\K".ljust(8) + for k in k_values: + header += f"{k:>8}" + print(header) + print("-" * (8 + 8 * len(k_values))) + + # Data rows + for m in m_values: + row = f"{m:<8}" + for k in k_values: + bandwidth = bandwidth_results.get((m, k), float("nan")) + if not np.isnan(bandwidth): + row += f"{bandwidth:>8.1f}" + else: + row += f"{'N/A':>8}" + print(row) + + # Compute overall statistics + bandwidths = [b for b in bandwidth_results.values() if not np.isnan(b)] + + if bandwidths: + print("\nStatistics:") + print(f" Mean bandwidth: {np.mean(bandwidths):.2f} TB/s") + print(f" Min bandwidth: {min(bandwidths):.2f} TB/s") + print(f" Max bandwidth: {max(bandwidths):.2f} TB/s") + print(f" Std deviation: {np.std(bandwidths):.2f} TB/s") + + +def print_summary_table( + m_values: List[int], + k_values: List[int], + cuda_times: Dict[Tuple[int, int], float], + cute_dsl_times: Dict[Tuple[int, int], float], + layout_name: str, +): + """Print a summary table of results.""" + print(f"\n{'=' * 80}") + print(f"Summary: {layout_name}") + print(f"{'=' * 80}") + + # Header + header = "M\\K".ljust(8) + for k in k_values: + header += f"{k:>8}" + print(header) + print("-" * (8 + 8 * len(k_values))) + + # Data rows + for m in m_values: + row = f"{m:<8}" + for k in k_values: + cuda_time = cuda_times.get((m, k), float("nan")) + cute_dsl_time = cute_dsl_times.get((m, k), float("nan")) + if cute_dsl_time > 0: + speedup = cuda_time / cute_dsl_time + row += f"{speedup:>8.2f}" + else: + row += f"{'N/A':>8}" + print(row) + + # Compute overall statistics + speedups = [] + for m in m_values: + for k in k_values: + cuda_time = cuda_times.get((m, k)) + cute_dsl_time = cute_dsl_times.get((m, k)) + if cuda_time and cute_dsl_time 
def print_summary_table(
    m_values: List[int],
    k_values: List[int],
    cuda_times: Dict[Tuple[int, int], float],
    cute_dsl_times: Dict[Tuple[int, int], float],
    layout_name: str,
) -> None:
    """Print an M x K table of CUDA/CuTe-DSL speedups plus summary stats.

    Each cell is cuda_time / cute_dsl_time, so values > 1.0 mean the
    CuTe-DSL backend is faster. Missing or NaN timings print as N/A.

    Args:
        m_values: Row-dimension values (table rows).
        k_values: Column-dimension values (table columns).
        cuda_times: Mapping of (m, k) -> CUDA backend time.
        cute_dsl_times: Mapping of (m, k) -> CuTe-DSL backend time.
        layout_name: Human-readable layout label for the table header.
    """
    print(f"\n{'=' * 80}")
    print(f"Summary: {layout_name}")
    print(f"{'=' * 80}")

    # Header row: K values across the top.
    header = "M\\K".ljust(8)
    for k in k_values:
        header += f"{k:>8}"
    print(header)
    print("-" * (8 + 8 * len(k_values)))

    # One data row per M value. NaN comparisons are False, so a NaN
    # cute_dsl_time falls through to N/A; the cuda_time NaN check is
    # explicit (previously a NaN cuda_time printed "nan").
    for m in m_values:
        row = f"{m:<8}"
        for k in k_values:
            cuda_time = cuda_times.get((m, k), float("nan"))
            cute_dsl_time = cute_dsl_times.get((m, k), float("nan"))
            if cute_dsl_time > 0 and not np.isnan(cuda_time):
                row += f"{cuda_time / cute_dsl_time:>8.2f}"
            else:
                row += f"{'N/A':>8}"
        print(row)

    # Collect valid speedups. NaN is truthy, so the original
    # `if cuda_time and cute_dsl_time` guard let NaN timings through and
    # poisoned the geometric mean; check isnan explicitly.
    speedups = []
    for m in m_values:
        for k in k_values:
            cuda_time = cuda_times.get((m, k))
            cute_dsl_time = cute_dsl_times.get((m, k))
            if (
                cuda_time is not None
                and cute_dsl_time is not None
                and not np.isnan(cuda_time)
                and not np.isnan(cute_dsl_time)
                and cuda_time > 0
                and cute_dsl_time > 0
            ):
                speedups.append(cuda_time / cute_dsl_time)

    if speedups:
        print("\nStatistics:")
        print(f"  Geometric mean speedup: {np.exp(np.mean(np.log(speedups))):.2f}x")
        print(f"  Min speedup: {min(speedups):.2f}x")
        print(f"  Max speedup: {max(speedups):.2f}x")
        print(
            f"  Cases where CuTe-DSL faster: {sum(1 for s in speedups if s > 1)}/{len(speedups)}"
        )


def _run_bandwidth_mode(args, dtype, m_values, k_values):
    """Bandwidth mode: measure CuTe-DSL TB/s for both scale-factor layouts."""
    print("\n" + "=" * 80)
    print("BANDWIDTH MEASUREMENT MODE (CuTe-DSL only)")
    print("=" * 80)

    for is_swizzled, layout_name in ((False, "Linear"), (True, "Swizzled")):
        banner = "SWIZZLED" if is_swizzled else "LINEAR (NON-SWIZZLED)"
        print("\n" + "=" * 80)
        print(f"BENCHMARKING {banner} LAYOUT - BANDWIDTH")
        print("=" * 80)

        bandwidth = run_bandwidth_sweep(
            m_values, k_values, dtype, is_sf_swizzled_layout=is_swizzled
        )
        print_bandwidth_summary_table(
            m_values, k_values, bandwidth, f"{layout_name} Layout"
        )
        create_bandwidth_heatmap(
            m_values,
            k_values,
            bandwidth,
            f"MXFP4 Quantization Bandwidth (CuTe-DSL) - {layout_name} Layout - {args.dtype}",
            f"{args.output_prefix}_bandwidth_{layout_name.lower()}_{args.dtype}.png",
        )


def _run_speedup_mode(args, dtype, m_values, k_values):
    """Speedup mode: compare CUDA vs CuTe-DSL for both scale-factor layouts."""
    for is_swizzled, layout_name in ((False, "Linear"), (True, "Swizzled")):
        banner = "SWIZZLED" if is_swizzled else "LINEAR (NON-SWIZZLED)"
        print("\n" + "=" * 80)
        print(f"BENCHMARKING {banner} LAYOUT")
        print("=" * 80)

        cuda_times, cute_dsl_times = run_benchmark_sweep(
            m_values, k_values, dtype, is_sf_swizzled_layout=is_swizzled
        )
        print_summary_table(
            m_values, k_values, cuda_times, cute_dsl_times, f"{layout_name} Layout"
        )
        create_heatmap(
            m_values,
            k_values,
            cuda_times,
            cute_dsl_times,
            f"MXFP4 Quantization Speedup (CuTe-DSL vs CUDA) - {layout_name} Layout - {args.dtype}",
            f"{args.output_prefix}_{layout_name.lower()}_{args.dtype}.png",
        )


def main():
    """Entry point: parse CLI args, check GPU support, and run the sweeps.

    Fixes a copy-paste bug from the MXFP8 benchmark: this file benchmarks
    mxfp4_quantize, so the description, default output prefix, and heatmap
    titles now say MXFP4.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark MXFP4 Quantization: CUDA vs CuTe-DSL"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float16", "bfloat16"],
        help="Input data type",
    )
    parser.add_argument(
        "--output-prefix",
        type=str,
        default="mxfp4_backend_comparison",
        help="Prefix for output files",
    )
    parser.add_argument(
        "--bandwidth",
        action="store_true",
        help="Measure achieved memory bandwidth (TB/s) for CuTe-DSL backend only, "
        "instead of comparing speedup between CUDA and CuTe-DSL",
    )
    args = parser.parse_args()

    # The CuTe-DSL backend requires Blackwell (SM100+); bail out early on
    # older architectures.
    cc = get_cc()
    print(f"GPU Compute Capability: SM{cc}")
    if cc < 100:
        print("ERROR: CuTe-DSL backend requires Blackwell GPU (SM100+)")
        return

    dtype = torch.float16 if args.dtype == "float16" else torch.bfloat16
    print(f"Data type: {dtype}")

    # Sweep ranges: powers of 2 plus common transformer hidden dimensions.
    m_values = [
        128, 256, 384, 512, 768, 1024, 1536, 2048,
        3072, 4096, 6144, 8192, 12288, 16384, 32768,
    ]
    k_values = [
        128, 256, 384, 512, 768, 1024, 1536, 2048,
        3072, 4096, 5120, 6144, 8192, 12288, 16384,
    ]

    print(f"\nM values: {m_values}")
    print(f"K values: {k_values}")

    if args.bandwidth:
        _run_bandwidth_mode(args, dtype, m_values, k_values)
    else:
        _run_speedup_mode(args, dtype, m_values, k_values)

    print("\n" + "=" * 80)
    print("BENCHMARK COMPLETE")
    print("=" * 80)


if __name__ == "__main__":
    main()
Default: cuda", ) # FP4 quantization specific arguments @@ -231,15 +231,13 @@ def testMxfp8Quantize(args): print(f"[VVERBOSE] {enable_pdl = }") def run_backend(backend, input_tensor): - if backend == "cuda": - return flashinfer.mxfp8_quantize( - input_tensor, - is_sf_swizzled_layout=is_sf_swizzled_layout, - alignment=alignment, - enable_pdl=enable_pdl, - ) - else: - raise ValueError(f"Unsupported backend: {backend}") + return flashinfer.mxfp8_quantize( + input_tensor, + is_sf_swizzled_layout=is_sf_swizzled_layout, + alignment=alignment, + enable_pdl=enable_pdl, + backend=backend, + ) # Reference check via dequantize round-trip has_reference_output = False @@ -391,6 +389,7 @@ def testMxfp4Quantize(args): backends = args.backends[:] # Make a copy to avoid modifying the original m = args.m k = args.k + enable_pdl = args.enable_pdl is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck res = [] @@ -421,12 +420,14 @@ def testMxfp4Quantize(args): if args.verbose >= 2: print(f"[VVERBOSE] {input_tensor.shape = }") print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {enable_pdl = }") def run_backend(backend, input_tensor): - if backend == "cuda": - return flashinfer.mxfp4_quantize(input_tensor) - else: - raise ValueError(f"Unsupported backend: {backend}") + return flashinfer.mxfp4_quantize( + input_tensor, + backend=backend, + enable_pdl=enable_pdl, + ) # Reference check via dequantize round-trip has_reference_output = False @@ -529,6 +530,7 @@ def run_backend(backend, input_tensor): cur_res["m"] = m cur_res["k"] = k cur_res["input_dtype"] = str(input_dtype) + cur_res["enable_pdl"] = enable_pdl cur_res["backend"] = backend cur_res["case_tag"] = args.case_tag res.append(cur_res) diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 6c8619580f..6ac32efe69 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -59,7 +59,7 @@ ) from .decode import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache 
from .decode import single_decode_with_kv_cache as single_decode_with_kv_cache -from .fp4_quantization import ( +from .quantization.fp4_quantization import ( SfLayout, block_scale_interleave, nvfp4_block_scale_interleave, @@ -73,10 +73,11 @@ shuffle_matrix_a, shuffle_matrix_sf_a, scaled_fp4_grouped_quantize, + get_fp4_quantization_module, nvfp4_kv_dequantize, nvfp4_kv_quantize, ) -from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize +from .quantization.fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize from .fused_moe import ( ActivationType, RoutingMethodType, diff --git a/flashinfer/activation.py b/flashinfer/activation.py index 35abb2fdba..3bdd3df769 100644 --- a/flashinfer/activation.py +++ b/flashinfer/activation.py @@ -28,7 +28,7 @@ register_fake_op, get_compute_capability, ) -from .fp4_quantization import get_fp4_quantization_module +from .quantization.fp4_quantization import get_fp4_quantization_module @functools.cache diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index b2019f9352..d6e8f0bf1f 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -1,1170 +1,69 @@ """ -Copyright (c) 2025 by FlashInfer team. +Backwards compatibility stub for flashinfer.fp4_quantization. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +This module re-exports all symbols from flashinfer.quantization.fp4_quantization +to maintain backwards compatibility with existing code that imports from +flashinfer.fp4_quantization. - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
+New code should import from flashinfer.quantization.fp4_quantization directly. """ -import functools -from enum import Enum -from types import SimpleNamespace -from typing import List, Optional, Tuple - -import torch - -from .api_logging import flashinfer_api -from .jit import JitSpec -from .jit import env as jit_env -from .jit import ( - gen_jit_spec, - sm121a_nvcc_flags, - sm120a_nvcc_flags, - sm120f_nvcc_flags, - sm110a_nvcc_flags, - sm103a_nvcc_flags, - sm100a_nvcc_flags, - sm90a_nvcc_flags, +# Re-export everything from the new location +from .quantization.fp4_quantization import ( + SfLayout, + block_scale_interleave, + nvfp4_block_scale_interleave, + e2m1_and_ufp8sf_scale_to_float, + fp4_quantize, + mxfp4_dequantize_host, + mxfp4_dequantize, + mxfp4_quantize, + nvfp4_quantize, + nvfp4_batched_quantize, + shuffle_matrix_a, + shuffle_matrix_sf_a, + scaled_fp4_grouped_quantize, + get_fp4_quantization_module, + gen_fp4_quantization_module, + gen_fp4_quantization_sm90_module, + gen_fp4_quantization_sm100_module, + gen_fp4_quantization_sm103_module, + gen_fp4_quantization_sm110_module, + gen_fp4_quantization_sm120_module, + gen_fp4_quantization_sm120f_module, + gen_fp4_quantization_sm121_module, + nvfp4_kv_dequantize, + nvfp4_kv_quantize, + # Private functions needed by some tests + _pad_scale_factors, + _compute_swizzled_layout_sf_size, ) -from .jit.cpp_ext import is_cuda_version_at_least -from .utils import ( - backend_requirement, - device_support_pdl, - get_compute_capability, - get_shuffle_matrix_a_row_indices, - get_shuffle_matrix_sf_a_row_indices, - register_custom_op, - register_fake_op, - supported_compute_capability, - round_up, -) - - -def _compute_swizzled_layout_sf_size(total_row, total_column, row_size=128): - padded_row = round_up(total_row, row_size) - padded_column = round_up(total_column, 4) - return padded_row * padded_column - - -def _pad_scale_factors( - unswizzled_sf: torch.Tensor, m: int, n: int, sf_vec_size: int = 16 -) -> torch.Tensor: - 
"""Pad scale factors tensor to meet alignment requirements. - - Args: - unswizzled_sf (torch.Tensor): Input scale factors tensor with dtype uint8. - m (int): M dimension. - n (int): N dimension. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - - Returns: - torch.Tensor: Padded scale factors tensor. - """ - factor = sf_vec_size * 4 - padded_row = round_up(m, 128) - padded_col = round_up(n, factor) - - # Pad the input tensor to [padded_row, padded_col // scaling_vector_size] - pad_rows = padded_row - m - pad_cols = (padded_col - n) // sf_vec_size - if pad_rows == 0 and pad_cols == 0: - return unswizzled_sf - else: - return torch.nn.functional.pad( - unswizzled_sf, (0, pad_cols, 0, pad_rows), mode="constant", value=0 - ).contiguous() - - -def gen_fp4_quantization_sm100_module() -> JitSpec: - return gen_fp4_quantization_module(sm100a_nvcc_flags, "100") - - -def gen_fp4_quantization_sm103_module() -> JitSpec: - return gen_fp4_quantization_module(sm103a_nvcc_flags, "103") - - -def gen_fp4_quantization_sm90_module() -> JitSpec: - return gen_fp4_quantization_module(sm90a_nvcc_flags, "90") - - -def gen_fp4_quantization_sm110_module() -> JitSpec: - return gen_fp4_quantization_module(sm110a_nvcc_flags, "110") - - -def gen_fp4_quantization_sm120_module() -> JitSpec: - return gen_fp4_quantization_module(sm120a_nvcc_flags, "120") - - -def gen_fp4_quantization_sm120f_module() -> JitSpec: - return gen_fp4_quantization_module(sm120f_nvcc_flags, "120f") - - -def gen_fp4_quantization_sm121_module() -> JitSpec: - return gen_fp4_quantization_module(sm121a_nvcc_flags, "121") - - -def gen_fp4_quantization_module(nvcc_flags: List[str], device_arch: str) -> JitSpec: - return gen_jit_spec( - f"fp4_quantization_{device_arch}", - [ - jit_env.FLASHINFER_CSRC_DIR - / "nv_internal/tensorrt_llm/thop/fp4Quantize.cpp", - jit_env.FLASHINFER_CSRC_DIR / "nv_internal/tensorrt_llm/thop/fp4Op.cpp", - jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/kernels/quantization.cu", - 
jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", - jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", - jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", - jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", - ], - extra_cuda_cflags=nvcc_flags - + [ - "-DENABLE_BF16", - "-DENABLE_FP8", - "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "", - ], - extra_cflags=[ - "-DENABLE_BF16", - "-DENABLE_FP8", - "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "", - ], - extra_include_paths=[ - jit_env.FLASHINFER_CSRC_DIR / "nv_internal", - jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", - ], - ) - - -@functools.cache -def get_fp4_quantization_module(backend: str = "100"): - backend_modules = { - "121": gen_fp4_quantization_sm121_module, - "120f": gen_fp4_quantization_sm120f_module, - "120": gen_fp4_quantization_sm120_module, - "110": gen_fp4_quantization_sm110_module, - "103": gen_fp4_quantization_sm103_module, - "100": gen_fp4_quantization_sm100_module, - "90": gen_fp4_quantization_sm90_module, - } - - # Prefer 'f' (family / feature-set) variant for SM12x when CUDA >= 12.9, - # as it enables native FP4 conversion instructions (cvt.rn.satfinite.e2m1x2.f32). - # sm_120f covers the entire SM12x family (both SM120 and SM121). 
- # See: https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/ - if backend in ("120", "121"): - from .utils import version_at_least - - if version_at_least(torch.version.cuda, "12.9"): - backend = "120f" - - if backend not in backend_modules: - raise ValueError(f"Invalid backend: {backend}") - - module = backend_modules[backend]().build_and_load() - - @register_custom_op( - "flashinfer::fp4_quantize_sm100", - mutates_args=(""), - ) - def fp4_quantize_sm100( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - is_sf_swizzled_layout: bool = True, - is_sf_8x4_layout: bool = False, - enable_pdl: Optional[bool] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to FP4 format. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. 
- - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - out_val = torch.empty( - (*input.shape[:-1], input.shape[-1] // 2), - dtype=torch.uint8, - device=input.device, - ) - m = input.numel() // input.shape[-1] - k = input.shape[-1] - if is_sf_swizzled_layout: - out_sf_size = _compute_swizzled_layout_sf_size( - m, k // sf_vec_size, 8 if is_sf_8x4_layout else 128 - ) - out_sf_size_padded = out_sf_size - else: - out_sf_size = m * k // sf_vec_size - out_sf_size_padded = round_up(m, 16) * k // sf_vec_size - out_sf = torch.empty( - (out_sf_size_padded,), dtype=torch.uint8, device=input.device - ) - module.fp4_quantize( - input, - global_scale, - out_val, - out_sf, - sf_vec_size, - sf_use_ue8m0, - is_sf_swizzled_layout, - is_sf_8x4_layout, - enable_pdl, - ) - return out_val, out_sf[:out_sf_size] - - @register_fake_op("flashinfer::fp4_quantize_sm100") - def _fake_fp4_quantize_sm100( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - is_sf_swizzled_layout: bool = True, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m, k = input.shape - return ( - input.new_empty([m, k // 2], dtype=torch.int64), # FLOAT4_E2M1X2 - input.new_empty([m * k // sf_vec_size], dtype=torch.int32), # Scale factors - ) - - @register_custom_op( - "flashinfer::mxfp4_dequantize_host", - mutates_args=(""), - ) - def mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, - ) -> torch.Tensor: - out = torch.empty( - (weight.shape[0], weight.shape[1] * 2), - dtype=torch.float32, - device=weight.device, - ) - module.mxfp4_dequantize_host( - weight, - scale, - out, - group_size, - ) - return out - - @register_fake_op("flashinfer::mxfp4_dequantize_host") - def 
_fake_mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, - ) -> torch.Tensor: - return weight.new_empty( - [weight.shape[0], weight.shape[1] * 2], dtype=torch.float32 - ) - - @register_custom_op( - "flashinfer::block_scale_interleave_sm100", - mutates_args=("",), - ) - def block_scale_interleave_sm100( - unswizzled_sf: torch.Tensor, - ) -> torch.Tensor: - """Swizzle block scale tensor for FP4 format. - - Args: - unswizzled_sf (torch.Tensor): unswizzled block scale tensor with dtype uint8 or bfloat16. - - Returns: - torch.Tensor: output tensor for swizzled block scale with dtype uint8 or bfloat16. - """ - num_experts = unswizzled_sf.shape[0] if unswizzled_sf.dim() == 3 else 1 - expert_out_size = _compute_swizzled_layout_sf_size( - unswizzled_sf.shape[-2], unswizzled_sf.shape[-1], 128 - ) - out = torch.empty( - (num_experts * expert_out_size,), - dtype=unswizzled_sf.dtype, - device=unswizzled_sf.device, - ) - module.block_scale_interleave_sm100(unswizzled_sf, out) - return out - - @register_fake_op("flashinfer::block_scale_interleave_sm100") - def _fake_block_scale_interleave_sm100( - unswizzled_sf: torch.Tensor, - ) -> torch.Tensor: - return unswizzled_sf.new_empty( - [unswizzled_sf.shape[0] * unswizzled_sf.shape[1] // 16], dtype=torch.uint8 - ) - - @register_custom_op( - "flashinfer::fp4_batched_quantize_sm100", - mutates_args=("",), - ) - def fp4_batched_quantize_sm100( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize a batched tensor to FP4 (E2M1x2) with per-block scale factors. - - This function converts a float/bfloat16 (or FP8-quantized) input tensor into a - packed FP4 tensor using the E2M1 format (two 4-bit values per byte), along with - per-block scale factors. Scale factors are encoded as UE4M3 by default, or UE8M0 - when requested, and an optional global scale can be applied. 
- - Args: - input (torch.Tensor): Input tensor of shape [B, M, K] with dtype torch.float16, - torch.bfloat16, or an FP8-quantized dtype supported by the kernel. - global_scale (torch.Tensor, optional): Global scale factor of shape [1] and - dtype float32. - sf_vec_size (int, optional): Scale-factor vector size and alignment unit along K. - Supported/expected values: - - 16 (NVFP4 path; supported) - - 32 (MXFP4 path; not supported yet) - Defaults to 16. - sf_use_ue8m0 (bool, optional): Scale-factor encoding type. - False → UE4M3 (default), True → UE8M0. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - self_fp4 (torch.Tensor): Packed FP4 tensor in E2M1x2 format of shape - [B, M, K // 2] with dtype torch.uint8 (two FP4 lanes per byte). - - self_block_scale_factors (torch.Tensor): Block scale factors with dtype - uint8 (UE4M3 or UE8M0), laid out as a flat buffer of shape - [B, ceil(M / 128) * 128 * ceil(K / sf_vec_size / 4) * 4]. - - Notes: - - K must be even (because outputs pack two FP4 values per byte). - - For best performance, K should be a multiple of sf_vec_size; the scale-factor - buffer is aligned to sf_vec_size along K, pads M to multiples of 128, and - rounds (K / sf_vec_size) up to a multiple of 4 for storage. - - The batch dimension B is preserved for both outputs. 
- """ - b, m, k = input.shape - out_val = torch.empty( - (b, m, k // 2), - dtype=torch.uint8, - device=input.device, - ) - out_sf = torch.empty( - (b, _compute_swizzled_layout_sf_size(m, k // sf_vec_size, 128)), - dtype=torch.uint8, - device=input.device, - ) - module.fp4_batched_quantize( - input, - global_scale, - out_val, - out_sf, - sf_vec_size, - sf_use_ue8m0, - ) - return out_val, out_sf - - @register_fake_op("flashinfer::fp4_batched_quantize_sm100") - def _fake_fp4_batched_quantize_sm100( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - b, m, k = input.shape - return ( - input.new_empty([b, m, k // 2], dtype=torch.uint8), # FLOAT4_E2M1X2 - input.new_empty( - [b, _compute_swizzled_layout_sf_size(m, k // sf_vec_size, 128)], - dtype=torch.uint8, - ), # swizzled SF buffer - ) - - @register_custom_op( - "flashinfer::silu_and_mul_scaled_nvfp4_experts_quantize_sm100", - mutates_args=("",), - ) - def silu_and_mul_scaled_nvfp4_experts_quantize_sm100( - input: torch.Tensor, - mask: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize a silu and matmul with masked batched tensor to FP4 (E2M1x2) with per-block scale factors. - - This function first does silu and matmul to a float/bfloat16 input tensor then convect the result - into a packed FP4 tensor using the E2M1 format (two 4-bit values per byte), along with - per-block scale factors. Scale factors are encoded as UE4M3 by default, or UE8M0 - when requested, and an optional global scale can be applied. - - Args: - input (torch.Tensor): Input tensor of shape [B, M, K] with dtype torch.float16, - torch.bfloat16, or an FP8-quantized dtype supported by the kernel. - mask (torch.Tensor): mask tensor of shape [B] with dtype torch.int32. - global_scale (torch.Tensor, optional): Global scale factor of shape [1] and - dtype float32. 
- - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - self_fp4 (torch.Tensor): Packed FP4 tensor in E2M1x2 format of shape - [B, M, K // 2] with dtype torch.uint8 (two FP4 lanes per byte). - - self_block_scale_factors (torch.Tensor): Block scale factors with dtype - uint8 (UE4M3 or UE8M0), laid out as a flat buffer of shape - [B, ceil(M / 128) * 128 * ceil(K / sf_vec_size / 4) * 4]. - - Notes: - - K must be even (because outputs pack two FP4 values per byte). - - For best performance, K should be a multiple of sf_vec_size; the scale-factor - buffer is aligned to sf_vec_size along K, pads M to multiples of 128, and - rounds (K / sf_vec_size) up to a multiple of 4 for storage. - - The batch dimension B is preserved for both outputs. - """ - device = input.device - l, m, k_by_2 = input.shape - k = k_by_2 // 2 - sf_vec_size = 16 - assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." - - scale_k = k // sf_vec_size - padded_k = round_up(scale_k, 4) - padded_k_int32 = padded_k // 4 - padded_m = round_up(m, 128) - output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) - output_scales = torch.empty( - l, padded_m, padded_k_int32, device=device, dtype=torch.int32 - ) - - module.silu_and_mul_scaled_nvfp4_experts_quantize( - output.view(l * m, k // 2), - output_scales.view(l * padded_m, padded_k_int32), - input.view(l * m, k_by_2), - global_scale, - mask, - True, - ) - output = output.permute(1, 2, 0) - output_scales = output_scales.view(torch.float8_e4m3fn).view( - l, padded_m // 128, padded_k // 4, 32, 4, 4 - ) - output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) - return output, output_scales - - @register_fake_op("flashinfer::silu_and_mul_scaled_nvfp4_experts_quantize_sm100") - def _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100( - input: torch.Tensor, - mask: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - device = input.device - l, m, k_by_2 = input.shape - k = k_by_2 // 
2 - sf_vec_size = 16 - assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." - - scale_k = k // sf_vec_size - padded_k = round_up(scale_k, 4) - padded_k_int32 = padded_k // 4 - padded_m = round_up(m, 128) - output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) - output_scales = torch.empty( - l, padded_m, padded_k_int32, device=device, dtype=torch.int32 - ) - - output_scales = output_scales.view(torch.float8_e4m3fn).view( - l, padded_m // 128, padded_k // 4, 32, 4, 4 - ) - output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) - return (output, output_scales) - - @register_custom_op( - "flashinfer::scaled_fp4_grouped_quant_sm100", - mutates_args=("",), - ) - def scaled_fp4_grouped_quant_sm100( - input_tensor: torch.Tensor, - input_global_scale: torch.Tensor, - mask: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP4 and return quantized tensor and scale, for - grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer). - Args: - input: The input tensor to be quantized to FP4, with shape (l, m, k) - l is number of groups, m is number of tokens per group, k is number of features. - input_global_scale: A scalar scaling factor for the entire tensor, with - shape (l,). - Outputs: - output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical - layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into - an uint8. - output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l) - but the physical layout is (l, rm, rk, 32, 4, 4). - Note: - For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128. - `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are - required by the NVIDIA Blackwell MMA operations. - """ - device = input_tensor.device - l, m, k = input_tensor.shape - sf_vec_size = 16 - assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." 
- - scale_k = k // sf_vec_size - padded_k = round_up(scale_k, 4) - padded_k_int32 = padded_k // 4 - padded_m = round_up(m, 128) - output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) - output_scales = torch.empty( - l, padded_m, padded_k_int32, device=device, dtype=torch.int32 - ) - - module.silu_and_mul_scaled_nvfp4_experts_quantize( - output.view(l * m, k // 2), - output_scales.view(l * padded_m, padded_k_int32), - input_tensor.view(l * m, k), - input_global_scale, - mask, - False, - ) - # The physical layout of the output is (l, m, k // 2), but we want to return a - # logical layout (m, k // 2, l) required by the flashinfer masked group gemm. - output = output.permute(1, 2, 0) - # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a - # requirement for the flashinfer masked group gemm, where rm=m/128 and rk=k/4. The logic - # layout is (32, 4, rm, 4, rk, l). - output_scales = output_scales.view(torch.float8_e4m3fn).view( - l, padded_m // 128, padded_k // 4, 32, 4, 4 - ) - output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) - return output, output_scales - - @register_fake_op("flashinfer::scaled_fp4_grouped_quant_sm100") - def _fake_scaled_fp4_grouped_quant_sm100( - input_tensor: torch.Tensor, - input_global_scale: torch.Tensor, - mask: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - device = input_tensor.device - l, m, k = input_tensor.shape - sf_vec_size = 16 - assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." 
- - scale_k = k // sf_vec_size - padded_k = round_up(scale_k, 4) - padded_k_int32 = padded_k // 4 - padded_m = round_up(m, 128) - output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) - output_scales = torch.empty( - l, padded_m, padded_k_int32, device=device, dtype=torch.int32 - ) - - output = output.permute(1, 2, 0) - output_scales = output_scales.view(torch.float8_e4m3fn).view( - l, padded_m // 128, padded_k // 4, 32, 4, 4 - ) - output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) - return output, output_scales - - @register_custom_op( - "flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100", - mutates_args=(""), - ) - def e2m1_and_ufp8sf_scale_to_float_sm100( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - ufp8_type: int = 1, - is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: - """Convert E2M1 format tensor and UFP8 scale factors to float tensor. - - This function performs dequantization by converting a packed FP4 tensor in E2M1 format - back to float values using the associated UFP8 scale factors and global scale. - - Args: - e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. - ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. - global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. - - Returns: - torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. 
- """ - out = torch.zeros( - (e2m1_tensor.shape[0], e2m1_tensor.shape[1] * 2), - dtype=torch.float32, - device="cpu", - ) - module.e2m1_and_ufp8sf_scale_to_float_sm100( - e2m1_tensor.cpu(), - ufp8_scale_tensor.cpu().reshape(-1), - global_scale_tensor.cpu(), - out, - sf_vec_size, - ufp8_type, - is_sf_swizzled_layout, - ) - return out - - @register_fake_op("flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100") - def _fake_e2m1_and_ufp8sf_scale_to_float_sm100( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - ufp8_type: int = 1, - is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: - return e2m1_tensor.new_empty( - [e2m1_tensor.shape[0], e2m1_tensor.shape[1] * 2], dtype=torch.float32 - ) - - # Register the module - return SimpleNamespace( - fp4_quantize_sm100=fp4_quantize_sm100, - block_scale_interleave_sm100=block_scale_interleave_sm100, - e2m1_and_ufp8sf_scale_to_float_sm100=e2m1_and_ufp8sf_scale_to_float_sm100, - mxfp4_dequantize_host=mxfp4_dequantize_host, - fp4_batched_quantize_sm100=fp4_batched_quantize_sm100, - silu_and_mul_scaled_nvfp4_experts_quantize_sm100=silu_and_mul_scaled_nvfp4_experts_quantize_sm100, - scaled_fp4_grouped_quant_sm100=scaled_fp4_grouped_quant_sm100, - ) - - -@flashinfer_api -def fp4_quantize( - input: torch.Tensor, - global_scale: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - sf_use_ue8m0: bool = False, - is_sf_swizzled_layout: bool = True, - is_sf_8x4_layout: bool = False, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to FP4 format. - - This function implements FP4 quantization that converts input tensors to a compressed FP4 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. 
- global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - - Raises: - NotImplementedError: If any of the following features are requested but not implemented: - - BFloat16 input when BFloat16 is not enabled - - FP8 input when FP8 is not enabled - - sf_vec_size other than 16 or 32 - """ - if sf_vec_size != 16 and sf_vec_size != 32: - raise NotImplementedError("sf_vec_size can only be 16 or 32") - - # for column major input, we need to transpose the input - is_column_major = input.stride(-2) == 1 - if is_column_major: - input = input.transpose(-2, -1) - - assert input.shape[-1] % sf_vec_size == 0 - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - # get input device sm version - major, minor = get_compute_capability(input.device) - x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100( - input, - global_scale, - sf_vec_size, - sf_use_ue8m0, - is_sf_swizzled_layout, - is_sf_8x4_layout, - enable_pdl, - ) - sf = sf.reshape((-1, input.shape[-1] // sf_vec_size)) - if is_column_major: - x_q = x_q.transpose(-2, -1) - sf = sf.transpose(-2, -1) - - return x_q, sf - - -@flashinfer_api -def block_scale_interleave(unswizzled_sf: 
torch.Tensor) -> torch.Tensor: - """Swizzle block scale tensor for FP4 format. - - This function swizzles the block scale tensor to optimize memory access patterns - for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. - - Args: - unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. - - Returns: - torch.Tensor: Swizzled tensor with the same shape as input. - - Raises: - AssertionError: If input dtype is not uint8 or bfloat16. - """ - # TODO(shuw): check input dtype is uint8 - assert ( - unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 - ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" - - major, minor = get_compute_capability(unswizzled_sf.device) - device_arch = f"{major * 10 + minor}" - - return get_fp4_quantization_module(device_arch).block_scale_interleave_sm100( - unswizzled_sf, - ) - - -# Maintain compatibility with libraries using the old name -nvfp4_block_scale_interleave = block_scale_interleave - - -@flashinfer_api -def e2m1_and_ufp8sf_scale_to_float( - e2m1_tensor: torch.Tensor, - ufp8_scale_tensor: torch.Tensor, - global_scale_tensor: Optional[torch.Tensor] = None, - sf_vec_size: int = 16, - ufp8_type: int = 1, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Convert E2M1 format tensor and UFP8 scale factors to float tensor. - - This function performs dequantization by converting a packed FP4 tensor in E2M1 format - back to float values using the associated UFP8 scale factors and global scale. - - Args: - e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. - ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. - global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
- ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. - - Returns: - torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. - - """ - # NOTE(Zihao): this is another cpu op, should decouple it from cuda ops in the future - major, minor = get_compute_capability( - torch.device("cuda:0") - ) # select any cuda device to get a compute capability - device_arch = f"{major * 10 + minor}" - return get_fp4_quantization_module( - device_arch - ).e2m1_and_ufp8sf_scale_to_float_sm100( - e2m1_tensor, - ufp8_scale_tensor, - global_scale_tensor, - sf_vec_size, - ufp8_type, - is_sf_swizzled_layout, - ) - - -@flashinfer_api -def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: - """ - PyTorch equivalent of trtllm-gen `shuffleMatrixA` - """ - row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) - - return input_tensor[row_indices.to(input_tensor.device)] - - -@flashinfer_api -def shuffle_matrix_sf_a( - input_tensor: torch.Tensor, - epilogue_tile_m: int, - num_elts_per_sf: int = 16, -): - """ - Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. - `shuffleMatrixSfA` expects the input to be in 128x4 layout and then - apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 - layout. - This function expects the input to be in linear layout. It's done this - way because the scaling factors in the NVFP4 checkpoints are quantized - and are in linear layout. - This function doesn't add padding. - """ - - row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) - - w_shuffled = input_tensor[row_indices.to(input_tensor.device)] - - # 128x4 - return block_scale_interleave(w_shuffled) - - -class SfLayout(Enum): - """ - Layout of scale factors for NVFP4. 
- """ - - layout_128x4 = 0 - layout_8x4 = 1 - layout_linear = 2 - - -@flashinfer_api -def nvfp4_quantize( - a, - a_global_sf, - sfLayout=SfLayout.layout_128x4, - do_shuffle=False, - sf_vec_size=16, - enable_pdl=None, -): - """ - Quantize input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. - do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - - if do_shuffle: - # Weights 128x4 + shuffle. 
It is done during the model load and we do not care much about the perf - assert sfLayout == SfLayout.layout_128x4 - a_fp4, a_sf = fp4_quantize( - a.cuda(), - a_global_sf.cuda(), - sf_vec_size, - sf_use_ue8m0=False, - is_sf_swizzled_layout=False, - is_sf_8x4_layout=False, - enable_pdl=enable_pdl, - ) - - epilogue_tile_m = 128 - a_fp4 = shuffle_matrix_a(a_fp4.view(torch.uint8), epilogue_tile_m) - a_sf = shuffle_matrix_sf_a(a_sf.view(torch.uint8), epilogue_tile_m).reshape( - a_sf.shape - ) - else: - # Activations with 8x4 layout for SFs (GEMM with small tileN) - # Activations with 128x4 layout for SFs (GEMM with large tileN) - a_fp4, a_sf = fp4_quantize( - a.cuda(), - a_global_sf.cuda(), - sf_vec_size, - sf_use_ue8m0=False, - is_sf_swizzled_layout=sfLayout != SfLayout.layout_linear, - is_sf_8x4_layout=sfLayout == SfLayout.layout_8x4, - enable_pdl=enable_pdl, - ) - - return a_fp4, a_sf - - -@flashinfer_api -def mxfp4_quantize(a): - """ - Quantize input tensor to MXFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - - Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - """ - a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() - a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True) - return a_fp4, a_sf - - -@flashinfer_api -def mxfp4_dequantize(a_fp4, a_sf): - """ - Dequantize input tensor from MXFP4 format. - - Parameters: - a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. 
- """ - return e2m1_and_ufp8sf_scale_to_float( - a_fp4.cpu().view(torch.uint8), - a_sf.cpu().view(torch.uint8).reshape(-1), - torch.tensor([1.0], device=a_fp4.device), - 32, - 0, - True, - ) - - -@flashinfer_api -def mxfp4_dequantize_host( - weight: torch.Tensor, - scale: torch.Tensor, - group_size: int = 32, -) -> torch.Tensor: - """ - Dequantize input tensor from MXFP4 format on host. - - Parameters: - weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) - scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) - group_size (int, optional): Group size for dequantization. Defaults to 32. - - Returns: - torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. - """ - # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future - major, minor = get_compute_capability( - torch.device("cuda:0") - ) # use any cuda device to get a compute capability - device_arch = f"{major * 10 + minor}" - return get_fp4_quantization_module(device_arch).mxfp4_dequantize_host( - weight, - scale, - group_size, - ) - - -@flashinfer_api -def nvfp4_batched_quantize( - a, - a_global_sf, - sf_vec_size=16, -): - """ - Quantize batched input tensor to NVFP4 format. - - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
- - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" - a_fp4, a_sf = get_fp4_quantization_module(device_arch).fp4_batched_quantize_sm100( - a, - a_global_sf, - sf_vec_size, - False, - ) - return a_fp4, a_sf - - -@flashinfer_api -def scaled_fp4_grouped_quantize( - a, - mask, - a_global_sf, -): - """ - quantize batched input tensor to NVFP4 format with mask. - Parameters: - a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. - a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. - mask (torch.Tensor): Mask tensor to apply before quantization. - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - major, minor = get_compute_capability(a.device) - device_arch = f"{major * 10 + minor}" - a_fp4, a_sf = get_fp4_quantization_module( - device_arch - ).scaled_fp4_grouped_quant_sm100( - a, - a_global_sf, - mask, - ) - return a_fp4, a_sf - - -# --------------------------------------------------------------------------- -# NVFP4 KV cache quant/dequant with linear (non-swizzled) block scale layout -# --------------------------------------------------------------------------- - - -@functools.cache -def get_fp4_kv_dequantization_module(): - from .jit.fp4_kv_dequantization import gen_fp4_kv_dequantization_module - - module = gen_fp4_kv_dequantization_module().build_and_load() - - @register_custom_op( - "flashinfer::nvfp4_kv_dequant", - mutates_args=("output",), - ) - def nvfp4_kv_dequant( - fp4_data: torch.Tensor, - block_scales: torch.Tensor, - global_scale: torch.Tensor, - output: torch.Tensor, - ) -> None: - 
module.nvfp4_kv_dequant(fp4_data, block_scales, global_scale, output) - - @register_fake_op("flashinfer::nvfp4_kv_dequant") - def _fake_nvfp4_kv_dequant( - fp4_data: torch.Tensor, - block_scales: torch.Tensor, - global_scale: torch.Tensor, - output: torch.Tensor, - ) -> None: - pass - - return SimpleNamespace(nvfp4_kv_dequant=nvfp4_kv_dequant) - - -@functools.cache -def get_fp4_kv_quantization_module(): - from .jit.fp4_kv_quantization import gen_fp4_kv_quantization_module - - module = gen_fp4_kv_quantization_module().build_and_load() - - @register_custom_op( - "flashinfer::nvfp4_kv_quant", - mutates_args=("fp4_output", "block_scales"), - ) - def nvfp4_kv_quant( - input: torch.Tensor, - global_scale: torch.Tensor, - fp4_output: torch.Tensor, - block_scales: torch.Tensor, - ) -> None: - module.nvfp4_kv_quant(input, global_scale, fp4_output, block_scales) - - @register_fake_op("flashinfer::nvfp4_kv_quant") - def _fake_nvfp4_kv_quant( - input: torch.Tensor, - global_scale: torch.Tensor, - fp4_output: torch.Tensor, - block_scales: torch.Tensor, - ) -> None: - pass - - return SimpleNamespace(nvfp4_kv_quant=nvfp4_kv_quant) - - -_NVFP4_BLOCK_SIZE = 16 - - -@supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120, 121]) -def _nvfp4_kv_dequant_check(fp4_data, block_scales, global_scale, output_dtype=None): - return True - - -@backend_requirement({}, common_check=_nvfp4_kv_dequant_check) -@flashinfer_api -def nvfp4_kv_dequantize( - fp4_data: torch.Tensor, - block_scales: torch.Tensor, - global_scale: torch.Tensor, - output_dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """GPU dequantization of NVFP4 KV cache data with linear block scale layout. - - Requires SM80+. - - Args: - fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. - block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]`` - with dtype uint8. 
- global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, - on the same CUDA device as fp4_data. - output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``. - - Returns: - torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype. - """ - M = fp4_data.size(0) - K = fp4_data.size(1) * 2 - if K % _NVFP4_BLOCK_SIZE != 0: - raise ValueError(f"K dimension ({K}) must be divisible by {_NVFP4_BLOCK_SIZE}") - output = torch.empty((M, K), dtype=output_dtype, device=fp4_data.device) - get_fp4_kv_dequantization_module().nvfp4_kv_dequant( - fp4_data, block_scales, global_scale, output - ) - return output - - -@supported_compute_capability([100, 103, 110, 120, 121]) -def _nvfp4_kv_quant_check(input, global_scale): - return True - - -@backend_requirement({}, common_check=_nvfp4_kv_quant_check) -@flashinfer_api -def nvfp4_kv_quantize( - input: torch.Tensor, - global_scale: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """GPU quantization to NVFP4 KV cache format with linear block scale layout. - - Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16. - K must be divisible by 16. - global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, - on the same CUDA device as input. - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. - - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8. 
- """ - M, K = input.shape - if K % _NVFP4_BLOCK_SIZE != 0: - raise ValueError(f"K dimension ({K}) must be divisible by {_NVFP4_BLOCK_SIZE}") - fp4_output = torch.empty((M, K // 2), dtype=torch.uint8, device=input.device) - block_scales = torch.empty( - (M, K // _NVFP4_BLOCK_SIZE), dtype=torch.uint8, device=input.device - ) - get_fp4_kv_quantization_module().nvfp4_kv_quant( - input, global_scale, fp4_output, block_scales - ) - return fp4_output, block_scales +__all__ = [ + "SfLayout", + "block_scale_interleave", + "nvfp4_block_scale_interleave", + "e2m1_and_ufp8sf_scale_to_float", + "fp4_quantize", + "mxfp4_dequantize_host", + "mxfp4_dequantize", + "mxfp4_quantize", + "nvfp4_quantize", + "nvfp4_batched_quantize", + "shuffle_matrix_a", + "shuffle_matrix_sf_a", + "scaled_fp4_grouped_quantize", + "get_fp4_quantization_module", + "gen_fp4_quantization_module", + "gen_fp4_quantization_sm90_module", + "gen_fp4_quantization_sm100_module", + "gen_fp4_quantization_sm103_module", + "gen_fp4_quantization_sm110_module", + "gen_fp4_quantization_sm120_module", + "gen_fp4_quantization_sm120f_module", + "gen_fp4_quantization_sm121_module", + "nvfp4_kv_dequantize", + "nvfp4_kv_quantize", + "_pad_scale_factors", + "_compute_swizzled_layout_sf_size", +] diff --git a/flashinfer/fp8_quantization.py b/flashinfer/fp8_quantization.py index 1d2cdeea76..421191bd1b 100644 --- a/flashinfer/fp8_quantization.py +++ b/flashinfer/fp8_quantization.py @@ -1,208 +1,25 @@ -import functools -from types import SimpleNamespace -from typing import Optional, Tuple - -import torch - -from .api_logging import flashinfer_api -from .jit.fp8_quantization import gen_mxfp8_quantization_sm100_module -from .utils import ( - device_support_pdl, - register_custom_op, - register_fake_op, +""" +Backwards compatibility stub for flashinfer.fp8_quantization. 
+ +This module re-exports all symbols from flashinfer.quantization.fp8_quantization +to maintain backwards compatibility with existing code that imports from +flashinfer.fp8_quantization. + +New code should import from flashinfer.quantization.fp8_quantization directly. +""" + +# Re-export everything from the new location +from .quantization.fp8_quantization import ( + mxfp8_quantize, + mxfp8_dequantize_host, + get_mxfp8_quantization_sm100_module, + # Private functions for backwards compatibility + _compute_swizzled_layout_sf_size, ) - -def _compute_swizzled_layout_sf_size(total_row, total_column, row_size=128): - padded_row = (total_row + row_size - 1) // row_size * row_size - padded_column = (total_column + 3) // 4 * 4 - return padded_row * padded_column - - -@functools.cache -def get_mxfp8_quantization_sm100_module(): - module = gen_mxfp8_quantization_sm100_module().build_and_load() - - @register_custom_op( - "flashinfer::mxfp8_quantize_sm100", - mutates_args=(""), - ) - def mxfp8_quantize_sm100( - input: torch.Tensor, - is_sf_swizzled_layout: bool = True, - alignment: int = 32, - enable_pdl: Optional[bool] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to MxFP8 format. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - alignment (int, optional): sfVecSize. Defaults to 32. Note that alignment is not used in the host kernel. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. 
- Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - if input.device.type == "cpu": - out_val = torch.empty(input.shape, dtype=torch.uint8, device=input.device) - if is_sf_swizzled_layout: - out_sf_size = _compute_swizzled_layout_sf_size( - input.shape[0], input.shape[1] // 32, 128 - ) - else: - out_sf_size = input.numel() // 32 - out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) - module.mxfp8_quantize_host( - input, - out_val, - out_sf, - is_sf_swizzled_layout, - ) - return out_val, out_sf - else: - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - m = input.numel() // input.shape[-1] - k = input.shape[-1] - padded_k = (k + alignment - 1) // alignment * alignment - out_val = torch.empty( - (*input.shape[:-1], padded_k), - dtype=torch.float8_e4m3fn, - device=input.device, - ) - if is_sf_swizzled_layout: - out_sf_size = _compute_swizzled_layout_sf_size(m, padded_k // 32, 128) - else: - out_sf_size = m * padded_k // 32 - out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) - module.mxfp8_quantize( - input, - out_val, - out_sf, - is_sf_swizzled_layout, - alignment, - enable_pdl, - ) - return out_val, out_sf - - @register_fake_op("flashinfer::mxfp8_quantize_sm100") - def _fake_mxfp8_quantize_sm100( - input: torch.Tensor, - is_sf_swizzled_layout: bool = True, - alignment: int = 32, - ) -> Tuple[torch.Tensor, torch.Tensor]: - m, k = input.shape - return ( - input.new_empty([m, k], dtype=torch.int64), # FLOAT8_E4M3 - input.new_empty([m * k // 32], dtype=torch.int32), # Scale factors - ) - - @register_custom_op( - "flashinfer::mxfp8_dequantize_host_sm100", - mutates_args=("",), - ) - def mxfp8_dequantize_host_sm100( - input: torch.Tensor, - scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: - """Dequantize 
input tensor from MxFP8 format. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype FLOAT8_E4M3. - scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - - Returns: - torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. - """ - out = torch.empty(input.shape, dtype=torch.float32, device=input.device) - module.mxfp8_dequantize_host( - input, - scale_tensor, - out, - is_sf_swizzled_layout, - ) - return out - - @register_fake_op("flashinfer::mxfp8_dequantize_host_sm100") - def _fake_mxfp8_dequantize_host_sm100( - input: torch.Tensor, - scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, - ) -> torch.Tensor: - return input.new_empty([input.shape[0], input.shape[1]], dtype=torch.float32) - - # Register the module - return SimpleNamespace( - mxfp8_quantize_sm100=mxfp8_quantize_sm100, - mxfp8_dequantize_host_sm100=mxfp8_dequantize_host_sm100, - ) - - -@flashinfer_api -def mxfp8_quantize( - input: torch.Tensor, - is_sf_swizzled_layout: bool = True, - alignment: int = 32, - enable_pdl: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize input tensor to MxFP8 format. - - This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format - with associated scale factors. It supports various input data types and scale factor layouts. - - Args: - input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. - is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. - alignment (int, optional): sfVecSize. Defaults to 32. - enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). - If None, automatically detects based on device capability. Defaults to None. 
- Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 - - Scale factors tensor with shape determined by layout and sf_vec_size - """ - sf_vec_size = 32 - - assert input.shape[-1] % sf_vec_size == 0 - if enable_pdl is None: - enable_pdl = device_support_pdl(input.device) - x_q, sf = get_mxfp8_quantization_sm100_module().mxfp8_quantize_sm100( - input, - is_sf_swizzled_layout, - alignment, - enable_pdl, - ) - return x_q, sf - - -@flashinfer_api -def mxfp8_dequantize_host( - input: torch.Tensor, - scale_tensor: torch.Tensor, - is_sf_swizzled_layout: bool = True, -) -> torch.Tensor: - """Dequantize input tensor from MxFP8 format. - - This function performs dequantization by converting a packed FP8 tensor in MxFP8 format - back to float values using the associated scale factors. - - Args: - input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. - scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. - is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. - - Returns: - torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. - - """ - - return get_mxfp8_quantization_sm100_module().mxfp8_dequantize_host_sm100( - input, - scale_tensor, - is_sf_swizzled_layout, - ) +__all__ = [ + "mxfp8_quantize", + "mxfp8_dequantize_host", + "get_mxfp8_quantization_sm100_module", + "_compute_swizzled_layout_sf_size", +] diff --git a/flashinfer/quantization/__init__.py b/flashinfer/quantization/__init__.py new file mode 100644 index 0000000000..55c58c343a --- /dev/null +++ b/flashinfer/quantization/__init__.py @@ -0,0 +1,86 @@ +""" +FlashInfer Quantization Module +============================== + +This module provides quantization functions for various formats: +- FP4 (NVFP4, MXFP4) +- FP8 (MXFP8) +- Packbits utilities + +Copyright (c) 2025 by FlashInfer team. 
+Licensed under the Apache License, Version 2.0. +""" + +# Re-export packbits functions +from .packbits import packbits, segment_packbits + +# Re-export JIT module generator (used by tests and AOT compilation) +from ..jit.quantization import gen_quantization_module + +# Re-export FP8 quantization +from .fp8_quantization import mxfp8_quantize, mxfp8_dequantize_host + +# Re-export FP4 quantization (all public symbols) +from .fp4_quantization import ( + SfLayout, + block_scale_interleave, + nvfp4_block_scale_interleave, + e2m1_and_ufp8sf_scale_to_float, + fp4_quantize, + mxfp4_dequantize_host, + mxfp4_dequantize, + mxfp4_quantize, + nvfp4_quantize, + nvfp4_batched_quantize, + shuffle_matrix_a, + shuffle_matrix_sf_a, + scaled_fp4_grouped_quantize, + get_fp4_quantization_module, # Used by activation.py +) + +# CuTe-DSL kernels (conditionally exported, EXPERIMENTAL) +# Warning: These are experimental APIs and may change without notice. +# Import is guarded to handle environments where cutlass is not installed. 
+_cute_dsl_available = False +try: + from ..cute_dsl import is_cute_dsl_available + + if is_cute_dsl_available(): + from .kernels.mxfp8_quantize import mxfp8_quantize_cute_dsl + from .kernels.mxfp4_quantize import mxfp4_quantize_cute_dsl + + _cute_dsl_available = True +except ImportError: + pass + +__all__ = [ + # Packbits + "packbits", + "segment_packbits", + # JIT module generator + "gen_quantization_module", + # FP8 + "mxfp8_quantize", + "mxfp8_dequantize_host", + # FP4 + "SfLayout", + "block_scale_interleave", + "nvfp4_block_scale_interleave", + "e2m1_and_ufp8sf_scale_to_float", + "fp4_quantize", + "mxfp4_dequantize_host", + "mxfp4_dequantize", + "mxfp4_quantize", + "nvfp4_quantize", + "nvfp4_batched_quantize", + "shuffle_matrix_a", + "shuffle_matrix_sf_a", + "scaled_fp4_grouped_quantize", + "get_fp4_quantization_module", +] + +if _cute_dsl_available: + __all__ += [ + "mxfp8_quantize_cute_dsl", + "mxfp4_quantize_cute_dsl", + ] diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py new file mode 100644 index 0000000000..31f0c5db29 --- /dev/null +++ b/flashinfer/quantization/fp4_quantization.py @@ -0,0 +1,1199 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import functools +from enum import Enum +from types import SimpleNamespace +from typing import List, Optional, Tuple + +import torch + +from ..api_logging import flashinfer_api +from ..jit import JitSpec +from ..jit import env as jit_env +from ..jit import ( + gen_jit_spec, + sm121a_nvcc_flags, + sm120a_nvcc_flags, + sm120f_nvcc_flags, + sm110a_nvcc_flags, + sm103a_nvcc_flags, + sm100a_nvcc_flags, + sm90a_nvcc_flags, +) +from ..jit.cpp_ext import is_cuda_version_at_least +from ..utils import ( + backend_requirement, + device_support_pdl, + get_compute_capability, + get_shuffle_matrix_a_row_indices, + get_shuffle_matrix_sf_a_row_indices, + register_custom_op, + register_fake_op, + supported_compute_capability, + round_up, +) + + +def _compute_swizzled_layout_sf_size(total_row, total_column, row_size=128): + padded_row = round_up(total_row, row_size) + padded_column = round_up(total_column, 4) + return padded_row * padded_column + + +def _pad_scale_factors( + unswizzled_sf: torch.Tensor, m: int, n: int, sf_vec_size: int = 16 +) -> torch.Tensor: + """Pad scale factors tensor to meet alignment requirements. + + Args: + unswizzled_sf (torch.Tensor): Input scale factors tensor with dtype uint8. + m (int): M dimension. + n (int): N dimension. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + + Returns: + torch.Tensor: Padded scale factors tensor. 
+ """ + factor = sf_vec_size * 4 + padded_row = round_up(m, 128) + padded_col = round_up(n, factor) + + # Pad the input tensor to [padded_row, padded_col // scaling_vector_size] + pad_rows = padded_row - m + pad_cols = (padded_col - n) // sf_vec_size + if pad_rows == 0 and pad_cols == 0: + return unswizzled_sf + else: + return torch.nn.functional.pad( + unswizzled_sf, (0, pad_cols, 0, pad_rows), mode="constant", value=0 + ).contiguous() + + +def gen_fp4_quantization_sm100_module() -> JitSpec: + return gen_fp4_quantization_module(sm100a_nvcc_flags, "100") + + +def gen_fp4_quantization_sm103_module() -> JitSpec: + return gen_fp4_quantization_module(sm103a_nvcc_flags, "103") + + +def gen_fp4_quantization_sm90_module() -> JitSpec: + return gen_fp4_quantization_module(sm90a_nvcc_flags, "90") + + +def gen_fp4_quantization_sm110_module() -> JitSpec: + return gen_fp4_quantization_module(sm110a_nvcc_flags, "110") + + +def gen_fp4_quantization_sm120_module() -> JitSpec: + return gen_fp4_quantization_module(sm120a_nvcc_flags, "120") + + +def gen_fp4_quantization_sm120f_module() -> JitSpec: + return gen_fp4_quantization_module(sm120f_nvcc_flags, "120f") + + +def gen_fp4_quantization_sm121_module() -> JitSpec: + return gen_fp4_quantization_module(sm121a_nvcc_flags, "121") + + +def gen_fp4_quantization_module(nvcc_flags: List[str], device_arch: str) -> JitSpec: + return gen_jit_spec( + f"fp4_quantization_{device_arch}", + [ + jit_env.FLASHINFER_CSRC_DIR + / "nv_internal/tensorrt_llm/thop/fp4Quantize.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/tensorrt_llm/thop/fp4Op.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/kernels/quantization.cu", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/envUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/logger.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/stringUtils.cpp", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", + ], + extra_cuda_cflags=nvcc_flags + + [ 
+ "-DENABLE_BF16", + "-DENABLE_FP8", + "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "", + ], + extra_cflags=[ + "-DENABLE_BF16", + "-DENABLE_FP8", + "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "", + ], + extra_include_paths=[ + jit_env.FLASHINFER_CSRC_DIR / "nv_internal", + jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", + ], + ) + + +@functools.cache +def get_fp4_quantization_module(backend: str = "100"): + backend_modules = { + "121": gen_fp4_quantization_sm121_module, + "120f": gen_fp4_quantization_sm120f_module, + "120": gen_fp4_quantization_sm120_module, + "110": gen_fp4_quantization_sm110_module, + "103": gen_fp4_quantization_sm103_module, + "100": gen_fp4_quantization_sm100_module, + "90": gen_fp4_quantization_sm90_module, + } + + # Prefer 'f' (family / feature-set) variant for SM12x when CUDA >= 12.9, + # as it enables native FP4 conversion instructions (cvt.rn.satfinite.e2m1x2.f32). + # sm_120f covers the entire SM12x family (both SM120 and SM121). + # See: https://developer.nvidia.com/blog/nvidia-blackwell-and-nvidia-cuda-12-9-introduce-family-specific-architecture-features/ + if backend in ("120", "121"): + from ..utils import version_at_least + + if version_at_least(torch.version.cuda, "12.9"): + backend = "120f" + + if backend not in backend_modules: + raise ValueError(f"Invalid backend: {backend}") + + module = backend_modules[backend]().build_and_load() + + @register_custom_op( + "flashinfer::fp4_quantize_sm100", + mutates_args=(""), + ) + def fp4_quantize_sm100( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, + is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to FP4 format. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. 
+ global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) + out_val = torch.empty( + (*input.shape[:-1], input.shape[-1] // 2), + dtype=torch.uint8, + device=input.device, + ) + m = input.numel() // input.shape[-1] + k = input.shape[-1] + if is_sf_swizzled_layout: + out_sf_size = _compute_swizzled_layout_sf_size( + m, k // sf_vec_size, 8 if is_sf_8x4_layout else 128 + ) + out_sf_size_padded = out_sf_size + else: + out_sf_size = m * k // sf_vec_size + out_sf_size_padded = round_up(m, 16) * k // sf_vec_size + out_sf = torch.empty( + (out_sf_size_padded,), dtype=torch.uint8, device=input.device + ) + module.fp4_quantize( + input, + global_scale, + out_val, + out_sf, + sf_vec_size, + sf_use_ue8m0, + is_sf_swizzled_layout, + is_sf_8x4_layout, + enable_pdl, + ) + return out_val, out_sf[:out_sf_size] + + @register_fake_op("flashinfer::fp4_quantize_sm100") + def _fake_fp4_quantize_sm100( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m, k 
= input.shape + return ( + input.new_empty([m, k // 2], dtype=torch.int64), # FLOAT4_E2M1X2 + input.new_empty([m * k // sf_vec_size], dtype=torch.int32), # Scale factors + ) + + @register_custom_op( + "flashinfer::mxfp4_dequantize_host", + mutates_args=(""), + ) + def mxfp4_dequantize_host( + weight: torch.Tensor, + scale: torch.Tensor, + group_size: int = 32, + ) -> torch.Tensor: + out = torch.empty( + (weight.shape[0], weight.shape[1] * 2), + dtype=torch.float32, + device=weight.device, + ) + module.mxfp4_dequantize_host( + weight, + scale, + out, + group_size, + ) + return out + + @register_fake_op("flashinfer::mxfp4_dequantize_host") + def _fake_mxfp4_dequantize_host( + weight: torch.Tensor, + scale: torch.Tensor, + group_size: int = 32, + ) -> torch.Tensor: + return weight.new_empty( + [weight.shape[0], weight.shape[1] * 2], dtype=torch.float32 + ) + + @register_custom_op( + "flashinfer::block_scale_interleave_sm100", + mutates_args=("",), + ) + def block_scale_interleave_sm100( + unswizzled_sf: torch.Tensor, + ) -> torch.Tensor: + """Swizzle block scale tensor for FP4 format. + + Args: + unswizzled_sf (torch.Tensor): unswizzled block scale tensor with dtype uint8 or bfloat16. + + Returns: + torch.Tensor: output tensor for swizzled block scale with dtype uint8 or bfloat16. 
+ """ + num_experts = unswizzled_sf.shape[0] if unswizzled_sf.dim() == 3 else 1 + expert_out_size = _compute_swizzled_layout_sf_size( + unswizzled_sf.shape[-2], unswizzled_sf.shape[-1], 128 + ) + out = torch.empty( + (num_experts * expert_out_size,), + dtype=unswizzled_sf.dtype, + device=unswizzled_sf.device, + ) + module.block_scale_interleave_sm100(unswizzled_sf, out) + return out + + @register_fake_op("flashinfer::block_scale_interleave_sm100") + def _fake_block_scale_interleave_sm100( + unswizzled_sf: torch.Tensor, + ) -> torch.Tensor: + return unswizzled_sf.new_empty( + [unswizzled_sf.shape[0] * unswizzled_sf.shape[1] // 16], dtype=torch.uint8 + ) + + @register_custom_op( + "flashinfer::fp4_batched_quantize_sm100", + mutates_args=("",), + ) + def fp4_batched_quantize_sm100( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize a batched tensor to FP4 (E2M1x2) with per-block scale factors. + + This function converts a float/bfloat16 (or FP8-quantized) input tensor into a + packed FP4 tensor using the E2M1 format (two 4-bit values per byte), along with + per-block scale factors. Scale factors are encoded as UE4M3 by default, or UE8M0 + when requested, and an optional global scale can be applied. + + Args: + input (torch.Tensor): Input tensor of shape [B, M, K] with dtype torch.float16, + torch.bfloat16, or an FP8-quantized dtype supported by the kernel. + global_scale (torch.Tensor, optional): Global scale factor of shape [1] and + dtype float32. + sf_vec_size (int, optional): Scale-factor vector size and alignment unit along K. + Supported/expected values: + - 16 (NVFP4 path; supported) + - 32 (MXFP4 path; not supported yet) + Defaults to 16. + sf_use_ue8m0 (bool, optional): Scale-factor encoding type. + False → UE4M3 (default), True → UE8M0. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - self_fp4 (torch.Tensor): Packed FP4 tensor in E2M1x2 format of shape + [B, M, K // 2] with dtype torch.uint8 (two FP4 lanes per byte). + - self_block_scale_factors (torch.Tensor): Block scale factors with dtype + uint8 (UE4M3 or UE8M0), laid out as a flat buffer of shape + [B, ceil(M / 128) * 128 * ceil(K / sf_vec_size / 4) * 4]. + + Notes: + - K must be even (because outputs pack two FP4 values per byte). + - For best performance, K should be a multiple of sf_vec_size; the scale-factor + buffer is aligned to sf_vec_size along K, pads M to multiples of 128, and + rounds (K / sf_vec_size) up to a multiple of 4 for storage. + - The batch dimension B is preserved for both outputs. + """ + b, m, k = input.shape + out_val = torch.empty( + (b, m, k // 2), + dtype=torch.uint8, + device=input.device, + ) + out_sf = torch.empty( + (b, _compute_swizzled_layout_sf_size(m, k // sf_vec_size, 128)), + dtype=torch.uint8, + device=input.device, + ) + module.fp4_batched_quantize( + input, + global_scale, + out_val, + out_sf, + sf_vec_size, + sf_use_ue8m0, + ) + return out_val, out_sf + + @register_fake_op("flashinfer::fp4_batched_quantize_sm100") + def _fake_fp4_batched_quantize_sm100( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + b, m, k = input.shape + return ( + input.new_empty([b, m, k // 2], dtype=torch.uint8), # FLOAT4_E2M1X2 + input.new_empty( + [b, _compute_swizzled_layout_sf_size(m, k // sf_vec_size, 128)], + dtype=torch.uint8, + ), # swizzled SF buffer + ) + + @register_custom_op( + "flashinfer::silu_and_mul_scaled_nvfp4_experts_quantize_sm100", + mutates_args=("",), + ) + def silu_and_mul_scaled_nvfp4_experts_quantize_sm100( + input: torch.Tensor, + mask: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize a silu and matmul with 
masked batched tensor to FP4 (E2M1x2) with per-block scale factors. + + This function first applies silu_and_mul to a float/bfloat16 input tensor, then converts the result + into a packed FP4 tensor using the E2M1 format (two 4-bit values per byte), along with + per-block scale factors. Scale factors are encoded as UE4M3 by default, or UE8M0 + when requested, and an optional global scale can be applied. + + Args: + input (torch.Tensor): Input tensor of shape [B, M, K] with dtype torch.float16, + torch.bfloat16, or an FP8-quantized dtype supported by the kernel. + mask (torch.Tensor): mask tensor of shape [B] with dtype torch.int32. + global_scale (torch.Tensor, optional): Global scale factor of shape [1] and + dtype float32. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - self_fp4 (torch.Tensor): Packed FP4 tensor in E2M1x2 format of shape + [B, M, K // 2] with dtype torch.uint8 (two FP4 lanes per byte). + - self_block_scale_factors (torch.Tensor): Block scale factors with dtype + uint8 (UE4M3 or UE8M0), laid out as a flat buffer of shape + [B, ceil(M / 128) * 128 * ceil(K / sf_vec_size / 4) * 4]. + + Notes: + - K must be even (because outputs pack two FP4 values per byte). + - For best performance, K should be a multiple of sf_vec_size; the scale-factor + buffer is aligned to sf_vec_size along K, pads M to multiples of 128, and + rounds (K / sf_vec_size) up to a multiple of 4 for storage. + - The batch dimension B is preserved for both outputs. + """ + device = input.device + l, m, k_by_2 = input.shape + k = k_by_2 // 2 + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
+ + scale_k = k // sf_vec_size + padded_k = round_up(scale_k, 4) + padded_k_int32 = padded_k // 4 + padded_m = round_up(m, 128) + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + + module.silu_and_mul_scaled_nvfp4_experts_quantize( + output.view(l * m, k // 2), + output_scales.view(l * padded_m, padded_k_int32), + input.view(l * m, k_by_2), + global_scale, + mask, + True, + ) + output = output.permute(1, 2, 0) + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return output, output_scales + + @register_fake_op("flashinfer::silu_and_mul_scaled_nvfp4_experts_quantize_sm100") + def _fake_silu_and_mul_scaled_nvfp4_experts_quantize_sm100( + input: torch.Tensor, + mask: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = input.device + l, m, k_by_2 = input.shape + k = k_by_2 // 2 + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." 
+ + scale_k = k // sf_vec_size + padded_k = round_up(scale_k, 4) + padded_k_int32 = padded_k // 4 + padded_m = round_up(m, 128) + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return (output, output_scales) + + @register_custom_op( + "flashinfer::scaled_fp4_grouped_quant_sm100", + mutates_args=("",), + ) + def scaled_fp4_grouped_quant_sm100( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer). + Args: + input: The input tensor to be quantized to FP4, with shape (l, m, k) + l is number of groups, m is number of tokens per group, k is number of features. + input_global_scale: A scalar scaling factor for the entire tensor, with + shape (l,). + Outputs: + output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical + layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into + an uint8. + output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l) + but the physical layout is (l, rm, rk, 32, 4, 4). + Note: + For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128. + `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are + required by the NVIDIA Blackwell MMA operations. + """ + device = input_tensor.device + l, m, k = input_tensor.shape + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." 
+ + scale_k = k // sf_vec_size + padded_k = round_up(scale_k, 4) + padded_k_int32 = padded_k // 4 + padded_m = round_up(m, 128) + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + + module.silu_and_mul_scaled_nvfp4_experts_quantize( + output.view(l * m, k // 2), + output_scales.view(l * padded_m, padded_k_int32), + input_tensor.view(l * m, k), + input_global_scale, + mask, + False, + ) + # The physical layout of the output is (l, m, k // 2), but we want to return a + # logical layout (m, k // 2, l) required by the flashinfer masked group gemm. + output = output.permute(1, 2, 0) + # The physical layout of the output scales is already swizzled as (l, rm, rk, 32, 4, 4), a + # requirement for the flashinfer masked group gemm, where rm = ceil(m / 128) and rk = ceil(k / 16 / 4). The logical + # layout is (32, 4, rm, 4, rk, l). + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return output, output_scales + + @register_fake_op("flashinfer::scaled_fp4_grouped_quant_sm100") + def _fake_scaled_fp4_grouped_quant_sm100( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = input_tensor.device + l, m, k = input_tensor.shape + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}."
+ + scale_k = k // sf_vec_size + padded_k = round_up(scale_k, 4) + padded_k_int32 = padded_k // 4 + padded_m = round_up(m, 128) + output = torch.empty(l, m, k // 2, device=device, dtype=torch.uint8) + output_scales = torch.empty( + l, padded_m, padded_k_int32, device=device, dtype=torch.int32 + ) + + output = output.permute(1, 2, 0) + output_scales = output_scales.view(torch.float8_e4m3fn).view( + l, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + return output, output_scales + + @register_custom_op( + "flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100", + mutates_args=(""), + ) + def e2m1_and_ufp8sf_scale_to_float_sm100( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, + ) -> torch.Tensor: + """Convert E2M1 format tensor and UFP8 scale factors to float tensor. + + This function performs dequantization by converting a packed FP4 tensor in E2M1 format + back to float values using the associated UFP8 scale factors and global scale. + + Args: + e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. + ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. + global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. + is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. + + Returns: + torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. 
+ """ + out = torch.zeros( + (e2m1_tensor.shape[0], e2m1_tensor.shape[1] * 2), + dtype=torch.float32, + device="cpu", + ) + module.e2m1_and_ufp8sf_scale_to_float_sm100( + e2m1_tensor.cpu(), + ufp8_scale_tensor.cpu().reshape(-1), + global_scale_tensor.cpu(), + out, + sf_vec_size, + ufp8_type, + is_sf_swizzled_layout, + ) + return out + + @register_fake_op("flashinfer::e2m1_and_ufp8sf_scale_to_float_sm100") + def _fake_e2m1_and_ufp8sf_scale_to_float_sm100( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, + ) -> torch.Tensor: + return e2m1_tensor.new_empty( + [e2m1_tensor.shape[0], e2m1_tensor.shape[1] * 2], dtype=torch.float32 + ) + + # Register the module + return SimpleNamespace( + fp4_quantize_sm100=fp4_quantize_sm100, + block_scale_interleave_sm100=block_scale_interleave_sm100, + e2m1_and_ufp8sf_scale_to_float_sm100=e2m1_and_ufp8sf_scale_to_float_sm100, + mxfp4_dequantize_host=mxfp4_dequantize_host, + fp4_batched_quantize_sm100=fp4_batched_quantize_sm100, + silu_and_mul_scaled_nvfp4_experts_quantize_sm100=silu_and_mul_scaled_nvfp4_experts_quantize_sm100, + scaled_fp4_grouped_quant_sm100=scaled_fp4_grouped_quant_sm100, + ) + + +@flashinfer_api +def fp4_quantize( + input: torch.Tensor, + global_scale: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + sf_use_ue8m0: bool = False, + is_sf_swizzled_layout: bool = True, + is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to FP4 format. + + This function implements FP4 quantization that converts input tensors to a compressed FP4 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. 
+ global_scale (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + + Raises: + NotImplementedError: If any of the following features are requested but not implemented: + - BFloat16 input when BFloat16 is not enabled + - FP8 input when FP8 is not enabled + - sf_vec_size other than 16 or 32 + """ + if sf_vec_size != 16 and sf_vec_size != 32: + raise NotImplementedError("sf_vec_size can only be 16 or 32") + + # for column major input, we need to transpose the input + is_column_major = input.stride(-2) == 1 + if is_column_major: + input = input.transpose(-2, -1) + + assert input.shape[-1] % sf_vec_size == 0 + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) + # get input device sm version + major, minor = get_compute_capability(input.device) + x_q, sf = get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100( + input, + global_scale, + sf_vec_size, + sf_use_ue8m0, + is_sf_swizzled_layout, + is_sf_8x4_layout, + enable_pdl, + ) + sf = sf.reshape((-1, input.shape[-1] // sf_vec_size)) + if is_column_major: + x_q = x_q.transpose(-2, -1) + sf = sf.transpose(-2, -1) + + return x_q, sf + + +@flashinfer_api +def block_scale_interleave(unswizzled_sf: 
torch.Tensor) -> torch.Tensor: + """Swizzle block scale tensor for FP4 format. + + This function swizzles the block scale tensor to optimize memory access patterns + for FP4 operations. The output needs to be padded in the m dimension to be a multiple of 128. + + Args: + unswizzled_sf (torch.Tensor): Input tensor with dtype uint8 or bfloat16. + + Returns: + torch.Tensor: Swizzled tensor with the same shape as input. + + Raises: + AssertionError: If input dtype is not uint8 or bfloat16. + """ + # TODO(shuw): check input dtype is uint8 + assert ( + unswizzled_sf.dtype == torch.uint8 or unswizzled_sf.dtype == torch.bfloat16 + ), f"Input dtype must be uint8 or bfloat16, got {unswizzled_sf.dtype}" + + major, minor = get_compute_capability(unswizzled_sf.device) + device_arch = f"{major * 10 + minor}" + + return get_fp4_quantization_module(device_arch).block_scale_interleave_sm100( + unswizzled_sf, + ) + + +# Maintain compatibility with libraries using the old name +nvfp4_block_scale_interleave = block_scale_interleave + + +@flashinfer_api +def e2m1_and_ufp8sf_scale_to_float( + e2m1_tensor: torch.Tensor, + ufp8_scale_tensor: torch.Tensor, + global_scale_tensor: Optional[torch.Tensor] = None, + sf_vec_size: int = 16, + ufp8_type: int = 1, + is_sf_swizzled_layout: bool = True, +) -> torch.Tensor: + """Convert E2M1 format tensor and UFP8 scale factors to float tensor. + + This function performs dequantization by converting a packed FP4 tensor in E2M1 format + back to float values using the associated UFP8 scale factors and global scale. + + Args: + e2m1_tensor (torch.Tensor): Packed FP4 tensor in E2M1 format of shape [M, K/2] with dtype uint8. + ufp8_scale_tensor (torch.Tensor): Scale factors tensor in UFP8 format with dtype uint8. + global_scale_tensor (torch.Tensor, optional): Global scale factor of shape [1] and dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. 
+ ufp8_type (int, optional): UFP8 scale factor type (0 for UE8M0, 1 for E4M3). Defaults to 1. + is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. + + Returns: + torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. + + """ + # NOTE(Zihao): this is another cpu op, should decouple it from cuda ops in the future + major, minor = get_compute_capability( + torch.device("cuda:0") + ) # select any cuda device to get a compute capability + device_arch = f"{major * 10 + minor}" + return get_fp4_quantization_module( + device_arch + ).e2m1_and_ufp8sf_scale_to_float_sm100( + e2m1_tensor, + ufp8_scale_tensor, + global_scale_tensor, + sf_vec_size, + ufp8_type, + is_sf_swizzled_layout, + ) + + +@flashinfer_api +def shuffle_matrix_a(input_tensor: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor: + """ + PyTorch equivalent of trtllm-gen `shuffleMatrixA` + """ + row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) + + return input_tensor[row_indices.to(input_tensor.device)] + + +@flashinfer_api +def shuffle_matrix_sf_a( + input_tensor: torch.Tensor, + epilogue_tile_m: int, + num_elts_per_sf: int = 16, +): + """ + Cuda implementation of trtllm-gen `shuffleMatrixSfA` but with a caveat. + `shuffleMatrixSfA` expects the input to be in 128x4 layout and then + apply the same shuffling in `shuffleMatrixA` and writes out in 128x4 + layout. + This function expects the input to be in linear layout. It's done this + way because the scaling factors in the NVFP4 checkpoints are quantized + and are in linear layout. + This function doesn't add padding. + """ + + row_indices = get_shuffle_matrix_sf_a_row_indices(input_tensor, epilogue_tile_m) + + w_shuffled = input_tensor[row_indices.to(input_tensor.device)] + + # 128x4 + return block_scale_interleave(w_shuffled) + + +class SfLayout(Enum): + """ + Layout of scale factors for NVFP4. 
+ """ + + layout_128x4 = 0 + layout_8x4 = 1 + layout_linear = 2 + + +@flashinfer_api +def nvfp4_quantize( + a, + a_global_sf, + sfLayout=SfLayout.layout_128x4, + do_shuffle=False, + sf_vec_size=16, + enable_pdl=None, +): + """ + Quantize input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. + do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + + if do_shuffle: + # Weights 128x4 + shuffle. 
It is done during the model load and we do not care much about the perf + assert sfLayout == SfLayout.layout_128x4 + a_fp4, a_sf = fp4_quantize( + a.cuda(), + a_global_sf.cuda(), + sf_vec_size, + sf_use_ue8m0=False, + is_sf_swizzled_layout=False, + is_sf_8x4_layout=False, + enable_pdl=enable_pdl, + ) + + epilogue_tile_m = 128 + a_fp4 = shuffle_matrix_a(a_fp4.view(torch.uint8), epilogue_tile_m) + a_sf = shuffle_matrix_sf_a(a_sf.view(torch.uint8), epilogue_tile_m).reshape( + a_sf.shape + ) + else: + # Activations with 8x4 layout for SFs (GEMM with small tileN) + # Activations with 128x4 layout for SFs (GEMM with large tileN) + a_fp4, a_sf = fp4_quantize( + a.cuda(), + a_global_sf.cuda(), + sf_vec_size, + sf_use_ue8m0=False, + is_sf_swizzled_layout=sfLayout != SfLayout.layout_linear, + is_sf_8x4_layout=sfLayout == SfLayout.layout_8x4, + enable_pdl=enable_pdl, + ) + + return a_fp4, a_sf + + +@flashinfer_api +def mxfp4_quantize( + a: torch.Tensor, + backend: str = "cuda", + enable_pdl: Optional[bool] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + backend (str, optional): Backend to use for quantization. + - "cuda": Use CUDA kernel (default, stable) + - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**) + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic + Dependent Launch). Only used when backend="cute-dsl". + If None, automatically detects based on device capability. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + - Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + + Warning: + The "cute-dsl" backend is **experimental** and not part of the stable API. + It may change or be removed in future versions without notice. + Use at your own risk for production workloads. 
+ """ + if backend == "cute-dsl": + from ..cute_dsl import is_cute_dsl_available + + if not is_cute_dsl_available(): + raise RuntimeError( + "CuTe-DSL backend requested but CuTe-DSL is not available. " + "Please install the required dependencies." + ) + from .kernels.mxfp4_quantize import mxfp4_quantize_cute_dsl + + return mxfp4_quantize_cute_dsl(a, enable_pdl=enable_pdl) + elif backend == "cuda": + a_global_sf = (448 * 6) / a.float().abs().nan_to_num().max() + a_fp4, a_sf = fp4_quantize(a.cuda(), a_global_sf.cuda(), 32, True, True) + return a_fp4, a_sf + else: + raise ValueError(f"Unknown backend: {backend}. Must be 'cuda' or 'cute-dsl'.") + + +@flashinfer_api +def mxfp4_dequantize(a_fp4, a_sf): + """ + Dequantize input tensor from MXFP4 format. + + Parameters: + a_fp4 (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + a_sf (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. + """ + return e2m1_and_ufp8sf_scale_to_float( + a_fp4.cpu().view(torch.uint8), + a_sf.cpu().view(torch.uint8).reshape(-1), + torch.tensor([1.0], device=a_fp4.device), + 32, + 0, + True, + ) + + +@flashinfer_api +def mxfp4_dequantize_host( + weight: torch.Tensor, + scale: torch.Tensor, + group_size: int = 32, +) -> torch.Tensor: + """ + Dequantize input tensor from MXFP4 format on host. + + Parameters: + weight (torch.Tensor): Quantized tensor of shape [M, K/2] with dtype uint8 (FLOAT4_E2M1X2) + scale (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size (uint8) + group_size (int, optional): Group size for dequantization. Defaults to 32. + + Returns: + torch.Tensor: Dequantized tensor of shape [M, K] with dtype float. 
+ """ + # NOTE(Zihao): the cpu op should be decouplied from cuda ops because it's device independent, should refactor this in the future + major, minor = get_compute_capability( + torch.device("cuda:0") + ) # use any cuda device to get a compute capability + device_arch = f"{major * 10 + minor}" + return get_fp4_quantization_module(device_arch).mxfp4_dequantize_host( + weight, + scale, + group_size, + ) + + +@flashinfer_api +def nvfp4_batched_quantize( + a, + a_global_sf, + sf_vec_size=16, +): + """ + Quantize batched input tensor to NVFP4 format. + + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" + a_fp4, a_sf = get_fp4_quantization_module(device_arch).fp4_batched_quantize_sm100( + a, + a_global_sf, + sf_vec_size, + False, + ) + return a_fp4, a_sf + + +@flashinfer_api +def scaled_fp4_grouped_quantize( + a, + mask, + a_global_sf, +): + """ + quantize batched input tensor to NVFP4 format with mask. + Parameters: + a (torch.Tensor): Input tensor of shape [B, M, K] with dtype fp16/bf16. + a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. + mask (torch.Tensor): Mask tensor to apply before quantization. 
+ Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [B, M, K/2] with dtype FLOAT4_E2M1X2 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + major, minor = get_compute_capability(a.device) + device_arch = f"{major * 10 + minor}" + a_fp4, a_sf = get_fp4_quantization_module( + device_arch + ).scaled_fp4_grouped_quant_sm100( + a, + a_global_sf, + mask, + ) + return a_fp4, a_sf + + +# --------------------------------------------------------------------------- +# NVFP4 KV cache quant/dequant with linear (non-swizzled) block scale layout +# --------------------------------------------------------------------------- + + +@functools.cache +def get_fp4_kv_dequantization_module(): + from ..jit.fp4_kv_dequantization import gen_fp4_kv_dequantization_module + + module = gen_fp4_kv_dequantization_module().build_and_load() + + @register_custom_op( + "flashinfer::nvfp4_kv_dequant", + mutates_args=("output",), + ) + def nvfp4_kv_dequant( + fp4_data: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + output: torch.Tensor, + ) -> None: + module.nvfp4_kv_dequant(fp4_data, block_scales, global_scale, output) + + @register_fake_op("flashinfer::nvfp4_kv_dequant") + def _fake_nvfp4_kv_dequant( + fp4_data: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + output: torch.Tensor, + ) -> None: + pass + + return SimpleNamespace(nvfp4_kv_dequant=nvfp4_kv_dequant) + + +@functools.cache +def get_fp4_kv_quantization_module(): + from ..jit.fp4_kv_quantization import gen_fp4_kv_quantization_module + + module = gen_fp4_kv_quantization_module().build_and_load() + + @register_custom_op( + "flashinfer::nvfp4_kv_quant", + mutates_args=("fp4_output", "block_scales"), + ) + def nvfp4_kv_quant( + input: torch.Tensor, + global_scale: torch.Tensor, + fp4_output: torch.Tensor, + block_scales: torch.Tensor, + ) -> None: + module.nvfp4_kv_quant(input, global_scale, fp4_output, 
block_scales) + + @register_fake_op("flashinfer::nvfp4_kv_quant") + def _fake_nvfp4_kv_quant( + input: torch.Tensor, + global_scale: torch.Tensor, + fp4_output: torch.Tensor, + block_scales: torch.Tensor, + ) -> None: + pass + + return SimpleNamespace(nvfp4_kv_quant=nvfp4_kv_quant) + + +_NVFP4_BLOCK_SIZE = 16 + + +@supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120, 121]) +def _nvfp4_kv_dequant_check(fp4_data, block_scales, global_scale, output_dtype=None): + return True + + +@backend_requirement({}, common_check=_nvfp4_kv_dequant_check) +@flashinfer_api +def nvfp4_kv_dequantize( + fp4_data: torch.Tensor, + block_scales: torch.Tensor, + global_scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """GPU dequantization of NVFP4 KV cache data with linear block scale layout. + + Requires SM80+. + + Args: + fp4_data (torch.Tensor): Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + block_scales (torch.Tensor): Per-block FP8 E4M3 scales of shape ``[M, K/16]`` + with dtype uint8. + global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as fp4_data. + output_dtype (torch.dtype): Output dtype, either ``torch.bfloat16`` or ``torch.float16``. + + Returns: + torch.Tensor: Dequantized tensor of shape ``[M, K]`` with the specified output dtype. 
+ """ + M = fp4_data.size(0) + K = fp4_data.size(1) * 2 + if K % _NVFP4_BLOCK_SIZE != 0: + raise ValueError(f"K dimension ({K}) must be divisible by {_NVFP4_BLOCK_SIZE}") + output = torch.empty((M, K), dtype=output_dtype, device=fp4_data.device) + get_fp4_kv_dequantization_module().nvfp4_kv_dequant( + fp4_data, block_scales, global_scale, output + ) + return output + + +@supported_compute_capability([100, 103, 110, 120, 121]) +def _nvfp4_kv_quant_check(input, global_scale): + return True + + +@backend_requirement({}, common_check=_nvfp4_kv_quant_check) +@flashinfer_api +def nvfp4_kv_quantize( + input: torch.Tensor, + global_scale: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """GPU quantization to NVFP4 KV cache format with linear block scale layout. + + Requires SM100+ (Blackwell) for the cvt.rn.satfinite.e2m1x2.f32 PTX instruction. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype bf16 or fp16. + K must be divisible by 16. + global_scale (torch.Tensor): Global scale factor of shape ``[1]`` with dtype float32, + on the same CUDA device as input. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - fp4_output: Packed FP4 data of shape ``[M, K/2]`` with dtype uint8. + - block_scales: Per-block FP8 E4M3 scales of shape ``[M, K/16]`` with dtype uint8. 
+ """ + M, K = input.shape + if K % _NVFP4_BLOCK_SIZE != 0: + raise ValueError(f"K dimension ({K}) must be divisible by {_NVFP4_BLOCK_SIZE}") + fp4_output = torch.empty((M, K // 2), dtype=torch.uint8, device=input.device) + block_scales = torch.empty( + (M, K // _NVFP4_BLOCK_SIZE), dtype=torch.uint8, device=input.device + ) + get_fp4_kv_quantization_module().nvfp4_kv_quant( + input, global_scale, fp4_output, block_scales + ) + return fp4_output, block_scales diff --git a/flashinfer/quantization/fp8_quantization.py b/flashinfer/quantization/fp8_quantization.py new file mode 100644 index 0000000000..997d5d2b5f --- /dev/null +++ b/flashinfer/quantization/fp8_quantization.py @@ -0,0 +1,240 @@ +import functools +from types import SimpleNamespace +from typing import Optional, Tuple + +import torch + +from ..api_logging import flashinfer_api +from ..jit.fp8_quantization import gen_mxfp8_quantization_sm100_module +from ..utils import ( + device_support_pdl, + register_custom_op, + register_fake_op, +) + + +def _compute_swizzled_layout_sf_size(total_row, total_column, row_size=128): + padded_row = (total_row + row_size - 1) // row_size * row_size + padded_column = (total_column + 3) // 4 * 4 + return padded_row * padded_column + + +@functools.cache +def get_mxfp8_quantization_sm100_module(): + module = gen_mxfp8_quantization_sm100_module().build_and_load() + + @register_custom_op( + "flashinfer::mxfp8_quantize_sm100", + mutates_args=(""), + ) + def mxfp8_quantize_sm100( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: Optional[bool] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to MxFP8 format. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + alignment (int, optional): sfVecSize. Defaults to 32. 
Note that alignment is not used in the host kernel. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 + - Scale factors tensor with shape determined by layout and sf_vec_size + """ + if input.device.type == "cpu": + out_val = torch.empty(input.shape, dtype=torch.uint8, device=input.device) + if is_sf_swizzled_layout: + out_sf_size = _compute_swizzled_layout_sf_size( + input.shape[0], input.shape[1] // 32, 128 + ) + else: + out_sf_size = input.numel() // 32 + out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) + module.mxfp8_quantize_host( + input, + out_val, + out_sf, + is_sf_swizzled_layout, + ) + return out_val, out_sf + else: + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) + m = input.numel() // input.shape[-1] + k = input.shape[-1] + padded_k = (k + alignment - 1) // alignment * alignment + out_val = torch.empty( + (*input.shape[:-1], padded_k), + dtype=torch.float8_e4m3fn, + device=input.device, + ) + if is_sf_swizzled_layout: + out_sf_size = _compute_swizzled_layout_sf_size(m, padded_k // 32, 128) + else: + out_sf_size = m * padded_k // 32 + out_sf = torch.empty((out_sf_size,), dtype=torch.uint8, device=input.device) + module.mxfp8_quantize( + input, + out_val, + out_sf, + is_sf_swizzled_layout, + alignment, + enable_pdl, + ) + return out_val, out_sf + + @register_fake_op("flashinfer::mxfp8_quantize_sm100") + def _fake_mxfp8_quantize_sm100( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m, k = input.shape + return ( + input.new_empty([m, k], dtype=torch.int64), # FLOAT8_E4M3 + input.new_empty([m * k // 32], dtype=torch.int32), # Scale factors + ) + + @register_custom_op( + 
"flashinfer::mxfp8_dequantize_host_sm100", + mutates_args=("",), + ) + def mxfp8_dequantize_host_sm100( + input: torch.Tensor, + scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True, + ) -> torch.Tensor: + """Dequantize input tensor from MxFP8 format. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype FLOAT8_E4M3. + scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. + + Returns: + torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. + """ + out = torch.empty(input.shape, dtype=torch.float32, device=input.device) + module.mxfp8_dequantize_host( + input, + scale_tensor, + out, + is_sf_swizzled_layout, + ) + return out + + @register_fake_op("flashinfer::mxfp8_dequantize_host_sm100") + def _fake_mxfp8_dequantize_host_sm100( + input: torch.Tensor, + scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True, + ) -> torch.Tensor: + return input.new_empty([input.shape[0], input.shape[1]], dtype=torch.float32) + + # Register the module + return SimpleNamespace( + mxfp8_quantize_sm100=mxfp8_quantize_sm100, + mxfp8_dequantize_host_sm100=mxfp8_dequantize_host_sm100, + ) + + +@flashinfer_api +def mxfp8_quantize( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + enable_pdl: Optional[bool] = None, + backend: str = "cuda", +) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize input tensor to MxFP8 format. + + This function implements MxFP8 quantization that converts input tensors to a compressed MxFP8 format + with associated scale factors. It supports various input data types and scale factor layouts. + + Args: + input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. + is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. 
+ alignment (int, optional): sfVecSize. Defaults to 32. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). Defaults to None. + backend (str, optional): Backend to use for quantization. Options are: + - "cuda": Use JIT-compiled CUDA kernel (default, stable) + - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 + - Scale factors tensor with shape determined by layout and sf_vec_size + + Warning: + The "cute-dsl" backend is **experimental** and not part of the stable API. + It may change or be removed in future versions without notice. + Use at your own risk for production workloads. + """ + sf_vec_size = 32 + + assert input.shape[-1] % sf_vec_size == 0 + assert backend in ("cuda", "cute-dsl"), ( + f"backend must be 'cuda' or 'cute-dsl', got '{backend}'" + ) + + if backend == "cute-dsl": + from ..cute_dsl import is_cute_dsl_available + + if not is_cute_dsl_available(): + raise RuntimeError( + "CuTe-DSL backend requested but CuTe-DSL is not available. " + "Please install nvidia-cutlass-dsl package." + ) + from .kernels.mxfp8_quantize import mxfp8_quantize_cute_dsl + + return mxfp8_quantize_cute_dsl( + input, + is_sf_swizzled_layout=is_sf_swizzled_layout, + alignment=alignment, + enable_pdl=enable_pdl, + ) + else: + # backend == "cuda" + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) + x_q, sf = get_mxfp8_quantization_sm100_module().mxfp8_quantize_sm100( + input, + is_sf_swizzled_layout, + alignment, + enable_pdl, + ) + return x_q, sf + + +@flashinfer_api +def mxfp8_dequantize_host( + input: torch.Tensor, + scale_tensor: torch.Tensor, + is_sf_swizzled_layout: bool = True, +) -> torch.Tensor: + """Dequantize input tensor from MxFP8 format. 
+ + This function performs dequantization by converting a packed FP8 tensor in MxFP8 format + back to float values using the associated scale factors. + + Args: + input (torch.Tensor): Packed FP8 tensor in MxFP8 format of shape [M, K] with dtype FLOAT8_E4M3. + scale_tensor (torch.Tensor): Scale factors tensor with shape determined by layout and sf_vec_size. + is_sf_swizzled_layout (bool, optional): Whether scale factors use swizzled layout. Defaults to True. + + Returns: + torch.Tensor: Dequantized float tensor of shape [M, K] with dtype float32. + + """ + + return get_mxfp8_quantization_sm100_module().mxfp8_dequantize_host_sm100( + input, + scale_tensor, + is_sf_swizzled_layout, + ) diff --git a/flashinfer/quantization/kernels/__init__.py b/flashinfer/quantization/kernels/__init__.py new file mode 100644 index 0000000000..7e99b74a54 --- /dev/null +++ b/flashinfer/quantization/kernels/__init__.py @@ -0,0 +1,45 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +CuTe-DSL Quantization Kernels (EXPERIMENTAL) +============================================ + +.. warning:: + This subpackage is **experimental** and not part of the stable FlashInfer API. + The interfaces, implementations, and behaviors may change or be removed + in future versions without notice. Use at your own risk for production workloads. + +This subpackage contains high-performance CuTe-DSL implementations of +quantization kernels for MXFP4 and MXFP8 formats. 
These kernels require +SM100+ (Blackwell) GPUs and the nvidia-cutlass-dsl package. +""" + +from .mxfp4_quantize import ( + MXFP4QuantizeSwizzledKernel, + mxfp4_quantize_cute_dsl, +) +from .mxfp8_quantize import ( + MXFP8QuantizeLinearKernel, + MXFP8QuantizeSwizzledKernel, + mxfp8_quantize_cute_dsl, +) + +__all__ = [ + "MXFP4QuantizeSwizzledKernel", + "mxfp4_quantize_cute_dsl", + "MXFP8QuantizeLinearKernel", + "MXFP8QuantizeSwizzledKernel", + "mxfp8_quantize_cute_dsl", +] diff --git a/flashinfer/quantization/kernels/mxfp4_quantize.py b/flashinfer/quantization/kernels/mxfp4_quantize.py new file mode 100644 index 0000000000..f56dbf8eda --- /dev/null +++ b/flashinfer/quantization/kernels/mxfp4_quantize.py @@ -0,0 +1,528 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +MXFP4 Quantization using CuTe-DSL +================================= + +MXFP4 quantization kernel using CuTe-DSL. +Supports swizzled (128x4) scale factor layout. 
+ +""" + +import functools +from typing import Callable, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Int32, Uint8 + +from ...api_logging import flashinfer_api +from ...cute_dsl.fp4_common import get_ptr_as_int64, st_global_u64 +from ...cute_dsl.utils import get_num_sm +from ..quantization_cute_dsl_utils import ( + # MXFP4 Constants + MXFP4_SF_VEC_SIZE, + ROW_TILE_SIZE, + # Low-level intrinsics + compute_sf_index_swizzled_128x4_gpu, + # High-level helpers (MXFP4) + process_mxfp4_block_half, + process_mxfp4_block_bfloat, +) + + +# Blocks per SM for occupancy target +_BLOCKS_PER_SM = 4 + +# Maximum threads per block +_MAX_THREADS_PER_BLOCK = 1024 + +# Thread configuration bounds +_MIN_THREADS = 128 # Minimum for reasonable occupancy +_MAX_THREADS = 512 # Maximum to avoid register pressure +_DEFAULT_THREADS = 256 # Default thread count + + +def _compute_optimal_threads_for_k(K: int) -> int: + """ + Compute optimal thread count for 100% thread utilization. + + For MXFP4, each thread processes one SF block (32 elements). + threads_per_row = K / 32 = num_sf_blocks_per_row + + For 100% utilization when processing multiple rows: + threads_per_block % threads_per_row == 0 + + We prefer LARGER thread counts (up to _MAX_THREADS) for better occupancy, + while maintaining 100% thread utilization. 
+ + Args: + K: Number of columns (must be divisible by 32) + + Returns: + Optimal number of threads per block + """ + threads_per_row = K // MXFP4_SF_VEC_SIZE # K / 32 + + # For 100% utilization: threads_per_block % threads_per_row == 0 + # threads_per_block must be a multiple of threads_per_row + + if threads_per_row >= _MAX_THREADS: + # Large K: use max threads, will need column loop + return _MAX_THREADS + + # threads_per_block should be a multiple of threads_per_row + if threads_per_row <= _MAX_THREADS: + # Find largest multiple of threads_per_row <= _MAX_THREADS + threads = (_MAX_THREADS // threads_per_row) * threads_per_row + if threads >= _MIN_THREADS: + return threads + # If largest multiple is below _MIN_THREADS, use the smallest valid one + threads = threads_per_row + while threads < _MIN_THREADS: + threads += threads_per_row + if threads <= _MAX_THREADS: + return threads + + # Fallback to default + return _DEFAULT_THREADS + + +def _compute_swizzled_layout_sf_size( + total_row: int, total_column: int, row_size: int = 128 +) -> int: + """Compute size of swizzled scale factor buffer.""" + padded_row = (total_row + row_size - 1) // row_size * row_size + padded_column = (total_column + 3) // 4 * 4 + return padded_row * padded_column + + +# ============================================================================= +# CuTe-DSL Kernel Class for MXFP4 Swizzled Layout +# ============================================================================= + + +class MXFP4QuantizeSwizzledKernel: + """ + MXFP4 quantization kernel optimized for SWIZZLED layout. 
+ + Key optimizations: + - Multi-row processing: threads process multiple rows per block when K is small + - Dynamic thread count based on K for 100% thread utilization + - Row-based iteration with grid-stride loop + - Padding row fast path - only zero out scale factors + + Thread utilization optimization: + - For small K: Multiple rows processed per block iteration + - For large K: Single row with column loop + + Each thread processes one SF block (32 elements): + - UE8M0 scale factors (unsigned 8-bit exponent-only) + - E2M1 output format (4-bit, 2 values per byte) + + This kernel is M-agnostic: compiled once per (K, dtype, pdl) combination. + M-dependent values (M, padded_M) are passed at runtime. + """ + + def __init__( + self, + dtype: cutlass.Numeric, + K: int, + enable_pdl: bool = False, + ): + self.dtype = dtype + self.K = K + self.is_bfloat16 = dtype == cutlass.BFloat16 + self.enable_pdl = enable_pdl + + assert K % MXFP4_SF_VEC_SIZE == 0 + self.num_sf_blocks_per_row = K // MXFP4_SF_VEC_SIZE + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 + + # Compute optimal thread count for 100% utilization + self.num_threads = _compute_optimal_threads_for_k(K) + + # Multi-row processing constants (compile-time) + # threads_per_row = num_sf_blocks_per_row (1 thread per SF block) + self.threads_per_row = self.num_sf_blocks_per_row + + # Determine if we can process multiple rows or need column loop + if self.threads_per_row <= self.num_threads: + # Small K: multiple rows per block + self.rows_per_block = self.num_threads // self.threads_per_row + self.needs_col_loop = False + else: + # Large K: one row per block with column loop + self.rows_per_block = 1 + self.needs_col_loop = True + + @cute.jit + def __call__( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + padded_M: Int32, + num_blocks: Int32, + stream, + ): + threads_per_block = self.num_threads + + self.kernel(mInput, mOutput, mScales, M, 
padded_M).launch( + grid=[num_blocks, 1, 1], + block=[threads_per_block, 1, 1], + max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1], + min_blocks_per_mp=_BLOCKS_PER_SM, + stream=stream, + use_pdl=self.enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + padded_M: Int32, + ): + """ + MXFP4 quantization kernel with swizzled scale factor layout. + + Dual-path kernel with compile-time selection: + - Small K path: Multi-row processing for improved thread utilization + - Large K path: Single row with column loop + + Each thread processes one SF block (32 elements): + 1. Load 32 bf16/fp16 elements (4 x 128-bit loads) + 2. Compute max absolute value using SIMD reduction + 3. Compute UE8M0 scale: ceil(log2(max / 6.0)) + 127 + 4. Swizzle scale factor to 128x4 layout + 5. Scale elements and convert to E2M1 + 6. Store 16 bytes (32 FP4 values) + + Note: For MXFP4 (UE8M0 scale format), global scale is NOT used in + the scale computation, unlike NVFP4 (E4M3 scale format). The UE8M0 + format directly captures the per-block dynamic range. 
+ """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + grid_dim_x, _, _ = cute.arch.grid_dim() + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + + num_sf_blocks_per_row = self.num_sf_blocks_per_row + padded_sf_cols = self.padded_sf_cols + num_threads = self.num_threads + rows_per_block = self.rows_per_block + threads_per_row = self.threads_per_row + + if cutlass.const_expr(not self.needs_col_loop): + # ===== SMALL K PATH: Multi-row processing ===== + # Each block processes rows_per_block rows simultaneously + # Thread maps to: row_in_block = tidx // threads_per_row + # sf_idx = tidx % threads_per_row + row_in_block = tidx // threads_per_row + sf_idx_in_row = tidx % threads_per_row + + # Grid-stride loop over row batches + row_batch_idx = bidx + total_row_batches = cute.ceil_div(padded_M, rows_per_block) + + while row_batch_idx < total_row_batches: + base_row = row_batch_idx * rows_per_block + row_idx = base_row + row_in_block + + if row_idx < padded_M: + is_padding_row = row_idx >= M + + if is_padding_row: + # Fast path: padding row - only zero out scale factors + # Each participating thread zeros one SF at a time + local_sf_idx = sf_idx_in_row + while local_sf_idx < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, local_sf_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + local_sf_idx = local_sf_idx + threads_per_row + else: + # Normal path: process actual data row + if sf_idx_in_row < num_sf_blocks_per_row: + elem_base = sf_idx_in_row * MXFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + + # Process block: load, compute scale, convert to E2M1 + if cutlass.const_expr(self.is_bfloat16): + ( + _, + scale_ue8m0, + packed64_0, + packed64_1, + ) = process_mxfp4_block_bfloat(row_input, elem_base) + else: + ( + _, + scale_ue8m0, + packed64_0, + packed64_1, + ) = process_mxfp4_block_half(row_input, elem_base) + + # Write swizzled scale factor + sf_offset = 
compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_idx_in_row, padded_sf_cols + ) + mScales[sf_offset] = scale_ue8m0 + + # Store 16 bytes (32 FP4 values = 2 x st.global.u64) + row_output = mOutput[row_idx, None] + out_base = sf_idx_in_row * (MXFP4_SF_VEC_SIZE // 2) + out_ptr0 = get_ptr_as_int64(row_output, out_base) + out_ptr1 = get_ptr_as_int64(row_output, out_base + Int32(8)) + st_global_u64(out_ptr0, packed64_0) + st_global_u64(out_ptr1, packed64_1) + + # Handle padding SF columns (columns beyond actual K) + padding_sf_start = num_sf_blocks_per_row + sf_idx_in_row + while padding_sf_start < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, padding_sf_start, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + padding_sf_start = padding_sf_start + threads_per_row + + row_batch_idx = row_batch_idx + grid_dim_x + + else: + # ===== LARGE K PATH: Single row with column loop ===== + # Grid-stride loop over rows + row_idx = bidx + while row_idx < padded_M: + is_padding_row = row_idx >= M + + # Initialize sf_idx before control flow to satisfy DSL type requirements + sf_idx = Int32(tidx) + + if is_padding_row: + # Fast path: padding row - only zero out scale factors + while sf_idx < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, sf_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + sf_idx = sf_idx + num_threads + else: + # Normal path: process actual data row with column loop + num_sf_iters = ( + num_sf_blocks_per_row + num_threads - 1 + ) // num_threads + + for sf_iter in range(num_sf_iters): + local_sf_idx = sf_iter * num_threads + tidx + + if local_sf_idx < num_sf_blocks_per_row: + elem_base = local_sf_idx * MXFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + + # Process block: load, compute scale, convert to E2M1 + if cutlass.const_expr(self.is_bfloat16): + ( + _, + scale_ue8m0, + packed64_0, + packed64_1, + ) = process_mxfp4_block_bfloat(row_input, elem_base) + else: + ( + _, + scale_ue8m0, + 
packed64_0, + packed64_1, + ) = process_mxfp4_block_half(row_input, elem_base) + + # Write swizzled scale factor + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, local_sf_idx, padded_sf_cols + ) + mScales[sf_offset] = scale_ue8m0 + + # Store 16 bytes (32 FP4 values = 2 x st.global.u64) + row_output = mOutput[row_idx, None] + out_base = local_sf_idx * (MXFP4_SF_VEC_SIZE // 2) + out_ptr0 = get_ptr_as_int64(row_output, out_base) + out_ptr1 = get_ptr_as_int64(row_output, out_base + Int32(8)) + st_global_u64(out_ptr0, packed64_0) + st_global_u64(out_ptr1, packed64_1) + + # Handle padding SF columns (columns beyond actual K) + padding_sf_start = num_sf_blocks_per_row + tidx + while padding_sf_start < padded_sf_cols: + sf_offset = compute_sf_index_swizzled_128x4_gpu( + row_idx, padding_sf_start, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + padding_sf_start = padding_sf_start + num_threads + + row_idx = row_idx + grid_dim_x + + # PDL: Signal that dependent kernels can start early + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# PyTorch Integration with TVM-FFI +# ============================================================================= + + +@functools.cache +def _get_compiled_kernel_mxfp4( + is_bfloat16: bool, + K: int, + enable_pdl: bool = False, +) -> Tuple[Callable, int]: + """ + Get or compile MXFP4 kernel with TVM-FFI. + + Cached by (K, dtype, pdl) - M-agnostic, device-independent compilation. + + Returns: + Tuple of (compiled_kernel, rows_per_block) where rows_per_block + is used by the caller to compute num_blocks at runtime. 
+ """ + cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16 + kernel_obj = MXFP4QuantizeSwizzledKernel(cutlass_dtype, K, enable_pdl) + + # Use symbolic M for dynamic batch sizes + sym_m = cute.sym_int() + + # Create fake tensors for compilation + input_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, (sym_m, K), stride_order=(1, 0), assumed_align=16 + ) + output_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, (sym_m, K // 2), stride_order=(1, 0), assumed_align=16 + ) + sym_scale_size = cute.sym_int() + scales_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, (sym_scale_size,), assumed_align=16 + ) + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + input_fake, + output_fake, + scales_fake, + Int32(1), # Dummy M + Int32(128), # Dummy padded_M + Int32(1), # Dummy num_blocks + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel, kernel_obj.rows_per_block + + +@flashinfer_api +def mxfp4_quantize_cute_dsl( + input: torch.Tensor, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 format using CuTe-DSL kernel. + + This is a GPU implementation matching FlashInfer's mxfp4_quantize() behavior: + - Global scale computed as (448 * 6) / max(|input|) + - UE8M0 scale factors + - E2M1 output format (4-bit, 2 values per byte) + - Swizzled (128x4) scale factor layout + + The kernel is compiled once per (K, dtype, pdl) combination and handles + varying M (batch size) at runtime without recompilation. + + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16 + enable_pdl: Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). 
+ + Returns: + Tuple of: + - fp4_tensor: Quantized tensor of shape [M, K/2] with dtype uint8 + - scale_tensor: Scale factors as uint8 tensor (swizzled layout) + """ + from ...utils import device_support_pdl + + assert input.dtype in (torch.float16, torch.bfloat16), ( + f"Input dtype must be float16 or bfloat16, got {input.dtype}" + ) + assert input.is_cuda, "Input must be on CUDA device" + + # Auto-detect PDL support based on device capability + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) + + if input.dim() > 2: + m = input.numel() // input.shape[-1] + k = input.shape[-1] + input = input.reshape(m, k) + else: + m, k = input.shape + + assert k % MXFP4_SF_VEC_SIZE == 0, ( + f"K ({k}) must be divisible by MXFP4_SF_VEC_SIZE={MXFP4_SF_VEC_SIZE}" + ) + + input = input.contiguous() + is_bfloat16 = input.dtype == torch.bfloat16 + + # Cached device-specific target grid for grid size computation + target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM + + # Compute M-dependent values + num_sf_blocks_per_row = k // MXFP4_SF_VEC_SIZE + padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE + padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + scale_output_size = padded_m * padded_sf_cols + + # Get or compile kernel (device-independent) + kernel_fn, rows_per_block = _get_compiled_kernel_mxfp4(is_bfloat16, k, enable_pdl) + + # Compute grid size in Python (runtime, device-specific) + num_blocks = min((padded_m + rows_per_block - 1) // rows_per_block, target_grid) + + # Allocate outputs + fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) + scale_output = torch.empty( + scale_output_size, dtype=torch.uint8, device=input.device + ) + + # Launch kernel + kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks) + + # Reshape scale output to match CUDA backend format: [padded_total, num_sf_per_row] + scale_output = scale_output.reshape(-1, num_sf_blocks_per_row) + + return fp4_output, scale_output + + 
+__all__ = [ + "MXFP4QuantizeSwizzledKernel", + "mxfp4_quantize_cute_dsl", + "_get_compiled_kernel_mxfp4", +] diff --git a/flashinfer/quantization/kernels/mxfp8_quantize.py b/flashinfer/quantization/kernels/mxfp8_quantize.py new file mode 100644 index 0000000000..445ee81fb7 --- /dev/null +++ b/flashinfer/quantization/kernels/mxfp8_quantize.py @@ -0,0 +1,767 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +MXFP8 Quantization using CuTe-DSL +================================= + +High-performance MXFP8 quantization kernel using CuTe-DSL. +Supports both linear and swizzled (128x4) scale factor layouts. 
+ +Key features: +- Half2/BFloat2 SIMD for max-abs computation +- 4-thread cooperation per scale factor block +- Dual-path optimization: linear layout (SF-block based) and swizzled layout (row-based) +- Vectorized 128-bit global loads/stores +- M-agnostic compilation: kernels are compiled once per K dimension +""" + +import functools +from typing import Callable, Tuple + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import Float32, Int32, Uint8 + +from ...api_logging import flashinfer_api +from ...cute_dsl.fp4_common import ( + ld_global_v4_u32, + st_global_u64, + get_ptr_as_int64, +) +from ...cute_dsl.utils import get_num_sm +from ..quantization_cute_dsl_utils import ( + # Constants + SF_VEC_SIZE, + INV_FLOAT8_E4M3_MAX, + WARP_SIZE, + ELTS_PER_THREAD, + THREADS_PER_SF, + SF_BLOCKS_PER_WARP, + ROW_TILE_SIZE, + # Low-level intrinsics + hmax_reduce_to_f32, + bfloat2_hmax_reduce_to_f32, + float_to_ue8m0_fast, + ue8m0_to_inv_scale_fast, + reduce_max_4threads, + compute_sf_index_swizzled_128x4_gpu, + # High-level helpers + half2_max_abs_4, + bfloat2_max_abs_4, + half2x4_to_fp8x8_packed, + bfloat2x4_to_fp8x8_packed, +) + + +# Blocks per SM for occupancy target +_BLOCKS_PER_SM = 4 + +# Maximum threads per block (all modern NVIDIA GPUs support 1024) +_MAX_THREADS_PER_BLOCK = 1024 + + +# Warp configuration bounds +_MIN_WARPS = 4 # Minimum for reasonable occupancy (128 threads) +_MAX_WARPS = 32 # Maximum to avoid register pressure (1024 threads) +_DEFAULT_WARPS = 16 # Default when no optimization needed + + +def _compute_optimal_warps_for_k(K: int) -> int: + """ + Compute optimal WARPS_PER_BLOCK for 100% thread utilization. + + For the swizzled kernel, we need: + (WARPS × 8) % num_sf_blocks == 0 + + where num_sf_blocks = K / 32. + + This ensures that col_units_per_block is evenly divisible by + num_sf_blocks_per_row, so all threads are utilized. 
+ + We prefer LARGER warp counts (up to _MAX_WARPS) for better occupancy, + while maintaining 100% thread utilization. + + Args: + K: Number of columns (must be divisible by 32) + + Returns: + Optimal number of warps per block + """ + import math + + num_sf_blocks = K // SF_VEC_SIZE # K / 32 + + # For 100% utilization: (WARPS * 8) % num_sf_blocks == 0 + # WARPS must be a multiple of: num_sf_blocks / gcd(num_sf_blocks, 8) + gcd_val = math.gcd(num_sf_blocks, 8) + warp_multiple = num_sf_blocks // gcd_val + + # Find LARGEST valid WARPS in range [_MIN_WARPS, _MAX_WARPS] + # that is a multiple of warp_multiple (for best occupancy) + if warp_multiple <= _MAX_WARPS: + # Find largest multiple of warp_multiple that fits in [_MIN_WARPS, _MAX_WARPS] + warps = (_MAX_WARPS // warp_multiple) * warp_multiple + if warps >= _MIN_WARPS: + return warps + # If largest multiple is below _MIN_WARPS, use the smallest valid one + warps = warp_multiple + while warps < _MIN_WARPS: + warps += warp_multiple + if warps <= _MAX_WARPS: + return warps + + # If warp_multiple is too large, fall back to default + # This shouldn't happen for reasonable K values + return _DEFAULT_WARPS + + +# ============================================================================= +# CuTe-DSL Kernel Class for Linear Layout +# ============================================================================= + + +class MXFP8QuantizeLinearKernel: + """ + MXFP8 quantization kernel optimized for LINEAR layout. + Uses SF-block based iteration for efficient memory access. + + This kernel is M-agnostic: compiled once per (K, dtype, pdl) combination. + M-dependent values (total_sf_blocks, num_blocks) are passed at runtime. 
+    """
+
+    WARPS_PER_BLOCK = 16  # 16 warps = 512 threads per block
+    SF_BLOCKS_PER_TB = WARPS_PER_BLOCK * SF_BLOCKS_PER_WARP
+
+    def __init__(
+        self,
+        dtype: cutlass.Numeric,
+        K: int,
+        enable_pdl: bool = False,
+    ):
+        self.dtype = dtype
+        self.K = K
+        self.is_bfloat16 = dtype == cutlass.BFloat16
+        self.enable_pdl = enable_pdl
+
+        assert K % SF_VEC_SIZE == 0
+        self.num_sf_blocks_per_row = K // SF_VEC_SIZE
+
+    @cute.jit
+    def __call__(
+        self,
+        mInput: cute.Tensor,
+        mOutput: cute.Tensor,
+        mScales: cute.Tensor,
+        total_sf_blocks: Int32,
+        num_blocks: Int32,
+        stream,
+    ):
+        # Host-side launcher: grid size (num_blocks) is computed by the caller
+        # from the runtime M, so this object stays M-agnostic.
+        threads_per_block = self.WARPS_PER_BLOCK * WARP_SIZE
+
+        self.kernel(mInput, mOutput, mScales, total_sf_blocks).launch(
+            grid=[num_blocks, 1, 1],
+            block=[threads_per_block, 1, 1],
+            max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1],
+            min_blocks_per_mp=_BLOCKS_PER_SM,
+            stream=stream,
+            use_pdl=self.enable_pdl,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mInput: cute.Tensor,
+        mOutput: cute.Tensor,
+        mScales: cute.Tensor,
+        total_sf_blocks: Int32,
+    ):
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        grid_dim_x, _, _ = cute.arch.grid_dim()
+
+        if cutlass.const_expr(self.enable_pdl):
+            cute.arch.griddepcontrol_wait()
+
+        warp_idx = tidx // WARP_SIZE
+        lane_idx = tidx % WARP_SIZE
+
+        # Each scale-factor (SF) block of 32 elements is handled by
+        # THREADS_PER_SF (= 4) cooperating lanes, 8 such blocks per warp.
+        sf_idx_in_warp = lane_idx // THREADS_PER_SF
+        thread_in_sf = lane_idx % THREADS_PER_SF
+
+        sf_blocks_per_tb = self.WARPS_PER_BLOCK * SF_BLOCKS_PER_WARP
+        num_sf_blocks_per_row = self.num_sf_blocks_per_row
+
+        sf_idx_base = (
+            bidx * sf_blocks_per_tb + warp_idx * SF_BLOCKS_PER_WARP + sf_idx_in_warp
+        )
+
+        # Stride loop over the flattened (row, sf-block) index space; the
+        # stride is the total number of SF blocks covered per grid pass.
+        sf_idx = sf_idx_base
+        while sf_idx < total_sf_blocks:
+            row_idx = sf_idx // num_sf_blocks_per_row
+            col_idx = sf_idx % num_sf_blocks_per_row
+
+            base_elem = col_idx * SF_VEC_SIZE
+            thread_elem_offset = thread_in_sf * ELTS_PER_THREAD
+            elem_idx = base_elem + thread_elem_offset
+
+            row_input = mInput[row_idx, None]
+            input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx)
+
+            v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64)
+
+            # Compute max absolute value across 8 elements (4 half2/bfloat2)
+            if cutlass.const_expr(self.is_bfloat16):
+                max0123 = bfloat2_max_abs_4(v0, v1, v2, v3)
+                local_max = bfloat2_hmax_reduce_to_f32(max0123)
+            else:
+                max0123 = half2_max_abs_4(v0, v1, v2, v3)
+                local_max = hmax_reduce_to_f32(max0123)
+
+            # 4-thread reduction for this SF block
+            global_max = reduce_max_4threads(local_max)
+
+            # Compute UE8M0 scale factor
+            inv_e4m3_max = Float32(INV_FLOAT8_E4M3_MAX)
+            normalized_max = global_max * inv_e4m3_max
+            scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max)
+            scale_ue8m0 = scale_ue8m0_u32.to(Uint8)
+
+            # Compute inverse scale for quantization
+            inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32)
+
+            # Quantize to FP8 E4M3 and pack for vectorized store
+            if cutlass.const_expr(self.is_bfloat16):
+                fp8_packed = bfloat2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale)
+            else:
+                fp8_packed = half2x4_to_fp8x8_packed(v0, v1, v2, v3, inv_scale)
+
+            row_output = mOutput[row_idx, None]
+            output_ptr_i64 = get_ptr_as_int64(row_output, elem_idx)
+            st_global_u64(output_ptr_i64, fp8_packed)
+
+            # Only the first of the 4 cooperating lanes writes the scale;
+            # in linear layout the storage index is simply sf_idx.
+            if thread_in_sf == Int32(0):
+                mScales[sf_idx] = scale_ue8m0
+
+            sf_idx = sf_idx + grid_dim_x * sf_blocks_per_tb
+
+        # PDL: Signal that dependent kernels can start early
+        if cutlass.const_expr(self.enable_pdl):
+            cute.arch.griddepcontrol_launch_dependents()
+
+
+# =============================================================================
+# CuTe-DSL Kernel Class for Swizzled Layout
+# =============================================================================
+
+
+class MXFP8QuantizeSwizzledKernel:
+    """
+    MXFP8 quantization kernel optimized for SWIZZLED layout.
+
+    Key optimizations:
+    - Multi-row processing: threads process multiple rows per block when K is small
+    - Row-based iteration with grid-stride loop
+    - Padding row fast path - only zero out scale factors
+
+    Thread utilization optimization:
+    - Dynamic WARPS_PER_BLOCK based on K for 100% thread utilization
+    - For small K: Multiple rows processed per block iteration
+    - For large K: Single row with column loop
+
+    This kernel is M-agnostic: compiled once per (K, dtype, pdl) combination.
+    M-dependent values (M, padded_M) are passed at runtime.
+    """
+
+    def __init__(
+        self,
+        dtype: cutlass.Numeric,
+        K: int,
+        enable_pdl: bool = False,
+    ):
+        self.dtype = dtype
+        self.K = K
+        self.is_bfloat16 = dtype == cutlass.BFloat16
+        self.enable_pdl = enable_pdl
+
+        assert K % SF_VEC_SIZE == 0
+        self.num_sf_blocks_per_row = K // SF_VEC_SIZE
+        # Scale-factor columns padded up to a multiple of 4 (128x4 swizzle tile).
+        self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4
+
+        # Compute optimal warps for 100% thread utilization
+        self.warps_per_block = _compute_optimal_warps_for_k(K)
+
+        # Multi-row processing constants (compile-time)
+        threads_per_block = self.warps_per_block * WARP_SIZE
+        col_units_per_block = threads_per_block // THREADS_PER_SF
+
+        # threads_per_row = num_sf_blocks_per_row * THREADS_PER_SF = K / 8
+        self.threads_per_row = self.num_sf_blocks_per_row * THREADS_PER_SF
+
+        # rows_per_block = col_units_per_block // num_sf_blocks_per_row
+        # With optimal warps, this should divide evenly for small K
+        if self.num_sf_blocks_per_row <= col_units_per_block:
+            self.rows_per_block = col_units_per_block // self.num_sf_blocks_per_row
+            self.needs_col_loop = False
+        else:
+            self.rows_per_block = 1
+            self.needs_col_loop = True
+
+    @cute.jit
+    def __call__(
+        self,
+        mInput: cute.Tensor,
+        mOutput: cute.Tensor,
+        mScales: cute.Tensor,
+        M: Int32,
+        padded_M: Int32,
+        num_blocks: Int32,
+        stream,
+    ):
+        # Host-side launcher; M and padded_M are runtime values so one
+        # compiled kernel serves every batch size.
+        threads_per_block = self.warps_per_block * WARP_SIZE
+
+        self.kernel(mInput, mOutput, mScales, M, padded_M).launch(
+            grid=[num_blocks, 1, 1],
+            block=[threads_per_block, 1, 1],
+            max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1],
+            min_blocks_per_mp=_BLOCKS_PER_SM,
+            stream=stream,
+            use_pdl=self.enable_pdl,
+        )
+
+    @cute.kernel
+    def kernel(
+        self,
+        mInput: cute.Tensor,
+        mOutput: cute.Tensor,
+        mScales: cute.Tensor,
+        M: Int32,
+        padded_M: Int32,
+    ):
+        """
+        Multi-row kernel for swizzled layout.
+
+        When K is small: each block processes multiple rows simultaneously.
+        When K is large: each block processes one row with column loop.
+        """
+        tidx, _, _ = cute.arch.thread_idx()
+        bidx, _, _ = cute.arch.block_idx()
+        grid_dim_x, _, _ = cute.arch.grid_dim()
+
+        if cutlass.const_expr(self.enable_pdl):
+            cute.arch.griddepcontrol_wait()
+
+        # Compile-time constants
+        num_sf_blocks_per_row = self.num_sf_blocks_per_row
+        padded_sf_cols = self.padded_sf_cols
+        threads_per_row = self.threads_per_row
+        rows_per_block = self.rows_per_block
+
+        if cutlass.const_expr(self.needs_col_loop):
+            # Large K path: single row per block iteration with column loop
+            # This is the original algorithm for K > 4096
+            col_unit_idx = tidx // THREADS_PER_SF
+            thread_in_unit = tidx % THREADS_PER_SF
+            threads_per_block = self.warps_per_block * WARP_SIZE
+            col_units_per_block = threads_per_block // THREADS_PER_SF
+
+            row_idx = bidx
+            while row_idx < padded_M:
+                is_padding_row = row_idx >= M
+
+                if is_padding_row:
+                    # Fast path: padding row - only zero out scale factors
+                    sf_col_idx = col_unit_idx
+                    while sf_col_idx < padded_sf_cols:
+                        if thread_in_unit == Int32(0):
+                            sf_offset = compute_sf_index_swizzled_128x4_gpu(
+                                row_idx, sf_col_idx, padded_sf_cols
+                            )
+                            mScales[sf_offset] = Uint8(0)
+                        sf_col_idx = sf_col_idx + col_units_per_block
+                else:
+                    # Normal path: process actual data row with column loop
+                    sf_col_idx = col_unit_idx
+                    while sf_col_idx < num_sf_blocks_per_row:
+                        elem_idx = (
+                            sf_col_idx * SF_VEC_SIZE + thread_in_unit * ELTS_PER_THREAD
+                        )
+
+                        row_input = mInput[row_idx, None]
+                        input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx)
+                        v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64)
+
+                        if cutlass.const_expr(self.is_bfloat16):
+                            max0123 = bfloat2_max_abs_4(v0, v1, v2, v3)
+                            local_max = bfloat2_hmax_reduce_to_f32(max0123)
+                        else:
+                            max0123 = half2_max_abs_4(v0, v1, v2, v3)
+                            local_max = hmax_reduce_to_f32(max0123)
+
+                        global_max = reduce_max_4threads(local_max)
+                        inv_e4m3_max = Float32(INV_FLOAT8_E4M3_MAX)
+                        normalized_max = global_max * inv_e4m3_max
+                        scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max)
+                        scale_ue8m0 = scale_ue8m0_u32.to(Uint8)
+                        inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32)
+
+                        if cutlass.const_expr(self.is_bfloat16):
+                            fp8_packed = bfloat2x4_to_fp8x8_packed(
+                                v0, v1, v2, v3, inv_scale
+                            )
+                        else:
+                            fp8_packed = half2x4_to_fp8x8_packed(
+                                v0, v1, v2, v3, inv_scale
+                            )
+
+                        row_output = mOutput[row_idx, None]
+                        output_ptr_i64 = get_ptr_as_int64(row_output, elem_idx)
+                        st_global_u64(output_ptr_i64, fp8_packed)
+
+                        if thread_in_unit == Int32(0):
+                            sf_offset = compute_sf_index_swizzled_128x4_gpu(
+                                row_idx, sf_col_idx, padded_sf_cols
+                            )
+                            mScales[sf_offset] = scale_ue8m0
+
+                        sf_col_idx = sf_col_idx + col_units_per_block
+
+                    # Handle padding columns
+                    sf_col_idx = num_sf_blocks_per_row + col_unit_idx
+                    while sf_col_idx < padded_sf_cols:
+                        if thread_in_unit == Int32(0):
+                            sf_offset = compute_sf_index_swizzled_128x4_gpu(
+                                row_idx, sf_col_idx, padded_sf_cols
+                            )
+                            mScales[sf_offset] = Uint8(0)
+                        sf_col_idx = sf_col_idx + col_units_per_block
+
+                row_idx = row_idx + grid_dim_x
+        else:
+            # Small K path: multi-row processing (K <= 4096)
+            # Each block processes rows_per_block rows simultaneously
+            # Thread mapping: tidx -> (row_in_block, sf_col_idx, thread_in_unit)
+            row_in_block = tidx // threads_per_row
+            local_tidx = tidx % threads_per_row
+            sf_col_idx = local_tidx // THREADS_PER_SF
+            thread_in_unit = local_tidx % THREADS_PER_SF
+
+            # Grid-stride loop over row batches
+            row_batch_idx = bidx
+            # Initialize row_idx before while loop (CuTe DSL requires variables
+            # modified in while loops to be defined before the loop)
+            row_idx = row_batch_idx * rows_per_block + row_in_block
+            while row_batch_idx * rows_per_block < padded_M:
+                # Check if this thread's row is valid
+                if row_idx < padded_M:
+                    is_padding_row = row_idx >= M
+
+                    if is_padding_row:
+                        # Fast path: padding row - zero out scale factors
+                        # Each thread handles one SF column (no column loop needed)
+                        if sf_col_idx < padded_sf_cols and thread_in_unit == Int32(0):
+                            sf_offset = compute_sf_index_swizzled_128x4_gpu(
+                                row_idx, sf_col_idx, padded_sf_cols
+                            )
+                            mScales[sf_offset] = Uint8(0)
+                    else:
+                        # Normal path: process actual data
+                        if sf_col_idx < num_sf_blocks_per_row:
+                            elem_idx = (
+                                sf_col_idx * SF_VEC_SIZE
+                                + thread_in_unit * ELTS_PER_THREAD
+                            )
+
+                            row_input = mInput[row_idx, None]
+                            input_ptr_i64 = get_ptr_as_int64(row_input, elem_idx)
+                            v0, v1, v2, v3 = ld_global_v4_u32(input_ptr_i64)
+
+                            if cutlass.const_expr(self.is_bfloat16):
+                                max0123 = bfloat2_max_abs_4(v0, v1, v2, v3)
+                                local_max = bfloat2_hmax_reduce_to_f32(max0123)
+                            else:
+                                max0123 = half2_max_abs_4(v0, v1, v2, v3)
+                                local_max = hmax_reduce_to_f32(max0123)
+
+                            global_max = reduce_max_4threads(local_max)
+                            inv_e4m3_max = Float32(INV_FLOAT8_E4M3_MAX)
+                            normalized_max = global_max * inv_e4m3_max
+                            scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max)
+                            scale_ue8m0 = scale_ue8m0_u32.to(Uint8)
+                            inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32)
+
+                            if cutlass.const_expr(self.is_bfloat16):
+                                fp8_packed = bfloat2x4_to_fp8x8_packed(
+                                    v0, v1, v2, v3, inv_scale
+                                )
+                            else:
+                                fp8_packed = half2x4_to_fp8x8_packed(
+                                    v0, v1, v2, v3, inv_scale
+                                )
+
+                            row_output = mOutput[row_idx, None]
+                            output_ptr_i64 = get_ptr_as_int64(row_output, elem_idx)
+                            st_global_u64(output_ptr_i64, fp8_packed)
+
+                            if thread_in_unit == Int32(0):
+                                sf_offset = compute_sf_index_swizzled_128x4_gpu(
+                                    row_idx, sf_col_idx, padded_sf_cols
+                                )
+                                mScales[sf_offset] = scale_ue8m0
+
+                    # Handle padding SF columns (for this row)
+                    # Threads with sf_col_idx in [num_sf_blocks_per_row, padded_sf_cols)
+                    # NOTE(review): on padding rows these SF columns were already
+                    # zeroed by the fast path above — the second write is
+                    # redundant but harmless.
+                    if (
+                        sf_col_idx >= num_sf_blocks_per_row
+                        and sf_col_idx < padded_sf_cols
+                        and thread_in_unit == Int32(0)
+                    ):
+                        sf_offset = compute_sf_index_swizzled_128x4_gpu(
+                            row_idx, sf_col_idx, padded_sf_cols
+                        )
+                        mScales[sf_offset] = Uint8(0)
+
+                row_batch_idx = row_batch_idx + grid_dim_x
+                # Update row_idx for next iteration
+                row_idx = row_batch_idx * rows_per_block + row_in_block
+
+        # PDL: Signal that dependent kernels can start early
+        if cutlass.const_expr(self.enable_pdl):
+            cute.arch.griddepcontrol_launch_dependents()
+
+
+# =============================================================================
+# PyTorch Integration with TVM-FFI
+# =============================================================================
+
+
+@functools.cache
+def _get_compiled_kernel_linear(
+    is_bfloat16: bool,
+    K: int,
+    enable_pdl: bool = False,
+) -> Tuple[Callable, int]:
+    """
+    Get or compile LINEAR layout kernel with TVM-FFI.
+
+    Cached by (K, dtype, pdl) - M-agnostic, device-independent compilation.
+
+    Returns:
+        Tuple of (compiled_kernel, sf_blocks_per_tb) where sf_blocks_per_tb
+        is used by the caller to compute num_blocks at runtime.
+    """
+    cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16
+    kernel_obj = MXFP8QuantizeLinearKernel(cutlass_dtype, K, enable_pdl)
+
+    # Use symbolic M for dynamic batch sizes
+    sym_m = cute.sym_int()
+
+    # Create fake tensors for compilation
+    input_fake = cute.runtime.make_fake_compact_tensor(
+        cutlass_dtype, (sym_m, K), stride_order=(1, 0), assumed_align=16
+    )
+    output_fake = cute.runtime.make_fake_compact_tensor(
+        cutlass.Uint8, (sym_m, K), stride_order=(1, 0), assumed_align=16
+    )
+    sym_scale_size = cute.sym_int()
+    scales_fake = cute.runtime.make_fake_compact_tensor(
+        cutlass.Uint8, (sym_scale_size,), assumed_align=16
+    )
+    # Bind the launch to the TVM-FFI environment stream at call time.
+    stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
+
+    compiled_kernel = cute.compile(
+        kernel_obj,
+        input_fake,
+        output_fake,
+        scales_fake,
+        Int32(1),  # Dummy total_sf_blocks
+        Int32(1),  # Dummy num_blocks
+        stream_fake,
+        options="--enable-tvm-ffi",
+    )
+
+    return compiled_kernel, kernel_obj.SF_BLOCKS_PER_TB
+
+
+@functools.cache
+def _get_compiled_kernel_swizzled(
+    is_bfloat16: bool,
+    K: int,
+    enable_pdl: bool = False,
+) -> Tuple[Callable, int]:
+    """
+    Get or compile SWIZZLED layout kernel with TVM-FFI.
+
+    Cached by (K, dtype, pdl) - M-agnostic, device-independent compilation.
+
+    Returns:
+        Tuple of (compiled_kernel, rows_per_block) where rows_per_block
+        is used by the caller to compute num_blocks at runtime.
+    """
+    cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16
+    kernel_obj = MXFP8QuantizeSwizzledKernel(cutlass_dtype, K, enable_pdl)
+
+    # Use symbolic M for dynamic batch sizes
+    sym_m = cute.sym_int()
+
+    # Create fake tensors for compilation
+    input_fake = cute.runtime.make_fake_compact_tensor(
+        cutlass_dtype, (sym_m, K), stride_order=(1, 0), assumed_align=16
+    )
+    output_fake = cute.runtime.make_fake_compact_tensor(
+        cutlass.Uint8, (sym_m, K), stride_order=(1, 0), assumed_align=16
+    )
+    sym_scale_size = cute.sym_int()
+    scales_fake = cute.runtime.make_fake_compact_tensor(
+        cutlass.Uint8, (sym_scale_size,), assumed_align=16
+    )
+    stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True)
+
+    # The dummy scalars only drive tracing; real M/padded_M/num_blocks are
+    # supplied on every call of the compiled kernel.
+    compiled_kernel = cute.compile(
+        kernel_obj,
+        input_fake,
+        output_fake,
+        scales_fake,
+        Int32(1),  # Dummy M
+        Int32(128),  # Dummy padded_M
+        Int32(1),  # Dummy num_blocks
+        stream_fake,
+        options="--enable-tvm-ffi",
+    )
+
+    return compiled_kernel, kernel_obj.rows_per_block
+
+
+@flashinfer_api
+def mxfp8_quantize_cute_dsl(
+    input: torch.Tensor,
+    is_sf_swizzled_layout: bool = True,
+    alignment: int = 32,
+    enable_pdl: bool | None = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to MXFP8 format using CuTe-DSL kernel.
+
+    This is a GPU implementation with dual-path optimization:
+    - LINEAR layout: SF-block based iteration (fast)
+    - SWIZZLED layout: Row-based iteration with padding fast path (optimized)
+
+    The kernel is compiled once per (K, dtype, pdl) combination and handles
+    varying M (batch size) at runtime without recompilation.
+
+    Args:
+        input: Input tensor of shape [M, K] with dtype fp16/bf16
+        is_sf_swizzled_layout: Whether to use 128x4 swizzled layout (True) or linear (False)
+        alignment: Alignment for K dimension (default 32, must be multiple of SF_VEC_SIZE)
+        enable_pdl: Whether to enable PDL (Programmatic Dependent Launch).
+        If None, automatically detects based on device capability (SM >= 9.0).
+
+    Returns:
+        Tuple of:
+        - fp8_tensor: Quantized tensor of shape [M, padded_K] with dtype float8_e4m3fn
+        - scale_tensor: Scale factors as uint8 tensor
+    """
+    from ...utils import device_support_pdl
+
+    assert input.dtype in (torch.float16, torch.bfloat16), (
+        f"Input dtype must be float16 or bfloat16, got {input.dtype}"
+    )
+    assert input.is_cuda, "Input must be on CUDA device"
+    assert alignment % SF_VEC_SIZE == 0, (
+        f"alignment must be divisible by SF_VEC_SIZE={SF_VEC_SIZE}"
+    )
+
+    # Auto-detect PDL support based on device capability
+    if enable_pdl is None:
+        enable_pdl = device_support_pdl(input.device)
+
+    # NOTE(review): a 1-D input falls into the else branch and fails the
+    # (m, k) unpack with a ValueError — callers appear expected to pass
+    # tensors of rank >= 2; confirm.
+    if input.dim() > 2:
+        m = input.numel() // input.shape[-1]
+        k = input.shape[-1]
+        input = input.reshape(m, k)
+    else:
+        m, k = input.shape
+
+    assert k % SF_VEC_SIZE == 0, (
+        f"K ({k}) must be divisible by SF_VEC_SIZE={SF_VEC_SIZE}"
+    )
+
+    padded_k = ((k + alignment - 1) // alignment) * alignment
+
+    if padded_k > k:
+        # Pad input with zeros - padding columns must be zero to produce zero FP8 output
+        input_padded = torch.zeros(m, padded_k, dtype=input.dtype, device=input.device)
+        input_padded[:, :k] = input
+    else:
+        input_padded = input.contiguous()
+
+    is_bfloat16 = input.dtype == torch.bfloat16
+
+    # Cached device-specific target grid for grid size computation
+    target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM
+
+    # Compute M-dependent values outside the cached kernel
+    num_sf_blocks_per_row = padded_k // SF_VEC_SIZE
+
+    if is_sf_swizzled_layout:
+        # Swizzled layout: compute padded_M and scale_output_size
+        padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE
+        padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4
+        scale_output_size = padded_m * padded_sf_cols
+
+        kernel_fn, rows_per_block = _get_compiled_kernel_swizzled(
+            is_bfloat16, padded_k, enable_pdl
+        )
+
+        num_blocks = min((padded_m + rows_per_block - 1) // rows_per_block, target_grid)
+
+        fp8_output = torch.empty(m, padded_k, dtype=torch.uint8, device=input.device)
+        scale_output = torch.empty(
+            scale_output_size, dtype=torch.uint8, device=input.device
+        )
+
+        kernel_fn(input_padded, fp8_output, scale_output, m, padded_m, num_blocks)
+    else:
+        # Linear layout: compute total_sf_blocks
+        total_sf_blocks = m * num_sf_blocks_per_row
+        scale_output_size = total_sf_blocks
+
+        kernel_fn, sf_blocks_per_tb = _get_compiled_kernel_linear(
+            is_bfloat16, padded_k, enable_pdl
+        )
+
+        num_blocks = min(
+            (total_sf_blocks + sf_blocks_per_tb - 1) // sf_blocks_per_tb, target_grid
+        )
+
+        fp8_output = torch.empty(m, padded_k, dtype=torch.uint8, device=input.device)
+        scale_output = torch.empty(
+            scale_output_size, dtype=torch.uint8, device=input.device
+        )
+
+        kernel_fn(input_padded, fp8_output, scale_output, total_sf_blocks, num_blocks)
+
+    # Reinterpret the uint8 buffer as float8_e4m3fn without copying.
+    fp8_tensor = fp8_output.view(torch.float8_e4m3fn)
+
+    return fp8_tensor, scale_output
+
+
+__all__ = [
+    "MXFP8QuantizeLinearKernel",
+    "MXFP8QuantizeSwizzledKernel",
+    "mxfp8_quantize_cute_dsl",
+    "_get_compiled_kernel_linear",
+    "_get_compiled_kernel_swizzled",
+]
diff --git a/flashinfer/quantization.py b/flashinfer/quantization/packbits.py
similarity index 96%
rename from flashinfer/quantization.py
rename to flashinfer/quantization/packbits.py
index 4e279ab5f0..762357778f 100644
--- a/flashinfer/quantization.py
+++ b/flashinfer/quantization/packbits.py
@@ -19,9 +19,9 @@
 
 import torch
 
-from .api_logging import flashinfer_api
-from .jit.quantization import gen_quantization_module
-from .utils import register_custom_op, register_fake_op
+from ..api_logging import flashinfer_api
+from ..jit.quantization import gen_quantization_module
+from ..utils import register_custom_op, register_fake_op
 
 
 @functools.cache
diff --git a/flashinfer/quantization/quantization_cute_dsl_utils.py b/flashinfer/quantization/quantization_cute_dsl_utils.py
new file mode 100644
index 0000000000..b4d3ac1dbf
--- /dev/null
+++ b/flashinfer/quantization/quantization_cute_dsl_utils.py
@@ -0,0 +1,1002 @@
+"""
+Copyright (c) 2025 by FlashInfer team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Common utilities for quantization kernels using CuTe-DSL.
+
+This module contains shared PTX intrinsics and helper functions for MXFP8
+and MXFP4 quantization kernels.
+"""
+
+import cutlass.cute as cute
+from cutlass import Float32, Int32, Uint32, Uint64
+from cutlass._mlir.dialects import llvm
+from cutlass.cutlass_dsl import T, dsl_user_op
+
+from ..cute_dsl.fp4_common import habs2, hmax2, bfloat2_habs2, bfloat2_hmax2
+
+
+# =============================================================================
+# MXFP8 Constants
+# =============================================================================
+
+# Scale factor vector size: each scale factor covers 32 elements
+SF_VEC_SIZE = 32
+
+# Inverse of max representable value in FP8 E4M3 format (1/448)
+INV_FLOAT8_E4M3_MAX = 1.0 / 448.0
+
+# Thread organization constants
+WARP_SIZE = 32
+ELTS_PER_THREAD = 8  # Each thread handles 8 FP16 elements (128 bits)
+THREADS_PER_SF = SF_VEC_SIZE // ELTS_PER_THREAD  # 32 / 8 = 4 threads per SF block
+SF_BLOCKS_PER_WARP = WARP_SIZE // THREADS_PER_SF  # 32 / 4 = 8 SF blocks per warp
+
+# Row tiling for swizzled layout (128x4 pattern)
+ROW_TILE_SIZE = 128
+
+
+# =============================================================================
+# MXFP4 Constants
+# =============================================================================
+
+# Scale factor vector size for MXFP4: each scale factor covers 32 elements
+MXFP4_SF_VEC_SIZE = 32
+
+# Elements per thread for MXFP4: each thread handles 32 elements (one full SF block)
+MXFP4_ELTS_PER_THREAD = 32
+
+# Inverse of max representable value in FP4 E2M1 format (1/6)
+INV_FLOAT4_E2M1_MAX = 1.0 / 6.0
+
+# Global scale factor for MXFP4: 448 * 6 = 2688
+MXFP4_GLOBAL_SCALE_FACTOR = 448.0 * 6.0
+
+
+# =============================================================================
+# Half2 SIMD Intrinsics for Max Reduction
+# =============================================================================
+
+
+@dsl_user_op
+def hmax_reduce_to_f32(x: Uint32, *, loc=None, ip=None) -> Float32:
+    """Extract max of 2 FP16 values in a Half2 as Float32."""
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Uint32(x).ir_value(loc=loc, ip=ip)],
+            """
+            {
+                .reg .b16 h0, h1;
+                .reg .f32 f0, f1;
+                mov.b32 {h0, h1}, $1;
+                cvt.f32.f16 f0, h0;
+                cvt.f32.f16 f1, h1;
+                max.f32 $0, f0, f1;
+            }
+            """,
+            "=f,r",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+@dsl_user_op
+def bfloat2_hmax_reduce_to_f32(x: Uint32, *, loc=None, ip=None) -> Float32:
+    """Extract max of 2 BF16 values as Float32."""
+    # BF16 -> F32 widening is a 16-bit left shift into the f32 high bits.
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Uint32(x).ir_value(loc=loc, ip=ip)],
+            """
+            {
+                .reg .b32 lo, hi;
+                .reg .f32 f0, f1;
+                and.b32 lo, $1, 0xFFFF;
+                shr.b32 hi, $1, 16;
+                shl.b32 lo, lo, 16;
+                shl.b32 hi, hi, 16;
+                mov.b32 f0, lo;
+                mov.b32 f1, hi;
+                max.f32 $0, f0, f1;
+            }
+            """,
+            "=f,r",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+# =============================================================================
+# Fast UE8M0 Conversion
+# =============================================================================
+
+
+@dsl_user_op
+def float_to_ue8m0_fast(value: Float32, *, loc=None, ip=None) -> Uint32:
+    """
+    Convert float to UE8M0 format using fast log2 approximation.
+
+    UE8M0 = ceil(log2(value)) + 127, clamped to [0, 255]
+    """
+    # Non-positive inputs are mapped to 0 via the p_zero predicate.
+    return Uint32(
+        llvm.inline_asm(
+            T.i32(),
+            [Float32(value).ir_value(loc=loc, ip=ip)],
+            """
+            {
+                .reg .pred p_zero, p_neg, p_ovf;
+                .reg .f32 log2_val;
+                .reg .s32 exp_int, result;
+
+                setp.le.f32 p_zero, $1, 0f00000000;
+                lg2.approx.f32 log2_val, $1;
+                cvt.rpi.s32.f32 exp_int, log2_val;
+                add.s32 result, exp_int, 127;
+                setp.lt.s32 p_neg, result, 0;
+                setp.gt.s32 p_ovf, result, 255;
+                selp.s32 result, 0, result, p_neg;
+                selp.s32 result, 255, result, p_ovf;
+                selp.s32 $0, 0, result, p_zero;
+            }
+            """,
+            "=r,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+@dsl_user_op
+def ue8m0_to_inv_scale_fast(ue8m0_val: Uint32, *, loc=None, ip=None) -> Float32:
+    """
+    Convert UE8M0 to inverse scale using fast ex2.approx.
+
+    Inverse scale = 2^(127 - ue8m0)
+    Returns 0 for ue8m0 == 0.
+    """
+    return Float32(
+        llvm.inline_asm(
+            T.f32(),
+            [Uint32(ue8m0_val).ir_value(loc=loc, ip=ip)],
+            """
+            {
+                .reg .pred p_zero;
+                .reg .s32 neg_exp;
+                .reg .f32 neg_exp_f, result;
+
+                setp.eq.u32 p_zero, $1, 0;
+                sub.s32 neg_exp, 127, $1;
+                cvt.rn.f32.s32 neg_exp_f, neg_exp;
+                ex2.approx.f32 result, neg_exp_f;
+                selp.f32 $0, 0f00000000, result, p_zero;
+            }
+            """,
+            "=f,r",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+# =============================================================================
+# FP8 Conversion with Scaling
+# =============================================================================
+
+
+@dsl_user_op
+def half2_to_fp8x2_scaled(
+    h2: Uint32, inv_scale: Float32, *, loc=None, ip=None
+) -> Uint32:
+    """Convert Half2 to 2 FP8 E4M3 values with scaling."""
+    # The two fp8 bytes land in the low 16 bits of the returned u32.
+    return Uint32(
+        llvm.inline_asm(
+            T.i32(),
+            [
+                Uint32(h2).ir_value(loc=loc, ip=ip),
+                Float32(inv_scale).ir_value(loc=loc, ip=ip),
+            ],
+            """
+            {
+                .reg .b16 h0, h1;
+                .reg .f32 f0, f1;
+                .reg .b16 fp8_pair;
+
+                mov.b32 {h0, h1}, $1;
+                cvt.f32.f16 f0, h0;
+                cvt.f32.f16 f1, h1;
+                mul.f32 f0, f0, $2;
+                mul.f32 f1, f1, $2;
+                cvt.rn.satfinite.e4m3x2.f32 fp8_pair, f1, f0;
+                cvt.u32.u16 $0, fp8_pair;
+            }
+            """,
+            "=r,r,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+@dsl_user_op
+def bfloat2_to_fp8x2_scaled(
+    bf2: Uint32, inv_scale: Float32, *, loc=None, ip=None
+) -> Uint32:
+    """Convert BFloat16x2 to 2 FP8 E4M3 values with scaling."""
+    return Uint32(
+        llvm.inline_asm(
+            T.i32(),
+            [
+                Uint32(bf2).ir_value(loc=loc, ip=ip),
+                Float32(inv_scale).ir_value(loc=loc, ip=ip),
+            ],
+            """
+            {
+                .reg .b32 lo, hi;
+                .reg .f32 f0, f1;
+                .reg .b16 fp8_pair;
+
+                and.b32 lo, $1, 0xFFFF;
+                shr.b32 hi, $1, 16;
+                shl.b32 lo, lo, 16;
+                shl.b32 hi, hi, 16;
+                mov.b32 f0, lo;
+                mov.b32 f1, hi;
+                mul.f32 f0, f0, $2;
+                mul.f32 f1, f1, $2;
+                cvt.rn.satfinite.e4m3x2.f32 fp8_pair, f1, f0;
+                cvt.u32.u16 $0, fp8_pair;
+            }
+            """,
+            "=r,r,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+@dsl_user_op
+def pack_fp8x8_to_u64(
+    fp8_01: Uint32, fp8_23: Uint32, fp8_45: Uint32, fp8_67: Uint32, *, loc=None, ip=None
+) -> Uint64:
+    """Pack 8 FP8 values into a 64-bit value for vectorized store."""
+    # Each input holds 2 fp8 bytes in its low 16 bits (see *_to_fp8x2_scaled).
+    return Uint64(
+        llvm.inline_asm(
+            T.i64(),
+            [
+                Uint32(fp8_01).ir_value(loc=loc, ip=ip),
+                Uint32(fp8_23).ir_value(loc=loc, ip=ip),
+                Uint32(fp8_45).ir_value(loc=loc, ip=ip),
+                Uint32(fp8_67).ir_value(loc=loc, ip=ip),
+            ],
+            """
+            {
+                .reg .b32 lo, hi;
+                .reg .b32 t01, t23, t45, t67;
+
+                and.b32 t01, $1, 0xFFFF;
+                and.b32 t23, $2, 0xFFFF;
+                and.b32 t45, $3, 0xFFFF;
+                and.b32 t67, $4, 0xFFFF;
+
+                shl.b32 t23, t23, 16;
+                or.b32 lo, t01, t23;
+
+                shl.b32 t67, t67, 16;
+                or.b32 hi, t45, t67;
+
+                mov.b64 $0, {lo, hi};
+            }
+            """,
+            "=l,r,r,r,r",
+            has_side_effects=False,
+            is_align_stack=False,
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+            loc=loc,
+            ip=ip,
+        )
+    )
+
+
+# =============================================================================
+# E2M1 (FP4) Conversion for MXFP4
+# =============================================================================
+
+
+@dsl_user_op
+def half2_to_float2_scaled(
+    h2: Uint32, scale: Float32, *, loc=None, ip=None
+) -> tuple[Float32, Float32]:
+    """Convert Half2 to Float2 AND multiply by scale."""
+    result = llvm.inline_asm(
+        llvm.StructType.get_literal([T.f32(), T.f32()]),
+        [Uint32(h2).ir_value(loc=loc, ip=ip), Float32(scale).ir_value(loc=loc, ip=ip)],
+        """
+        {
+            .reg .b16 h0, h1;
+            .reg .f32 f0, f1;
+            mov.b32 {h0, h1}, $2;
+            cvt.f32.f16 f0, h0;
+            cvt.f32.f16 f1, h1;
+            mul.f32 $0, f0, $3;
+            mul.f32 $1, f1, $3;
+        }
+        """,
+        "=f,=f,r,f",
+        has_side_effects=False,
+        is_align_stack=False,
+        asm_dialect=llvm.AsmDialect.AD_ATT,
+        loc=loc,
+        ip=ip,
+    )
+
+    f0 = llvm.extractvalue(T.f32(), result, [0], loc=loc, ip=ip)
+    f1 = llvm.extractvalue(T.f32(), result, [1], loc=loc, ip=ip)
+
+    return Float32(f0), Float32(f1)
+
+
+@dsl_user_op
+def bfloat2_to_float2_scaled(
+    bf2: Uint32, scale: Float32, *, loc=None, ip=None
+) -> tuple[Float32, Float32]:
+    """Convert BFloat16x2 to Float2 AND multiply by scale."""
+    result = llvm.inline_asm(
+        llvm.StructType.get_literal([T.f32(), T.f32()]),
+        [Uint32(bf2).ir_value(loc=loc, ip=ip), Float32(scale).ir_value(loc=loc, ip=ip)],
+        """
+        {
+            .reg .b32 lo, hi;
+            .reg .f32 f0, f1;
+            and.b32 lo, $2, 0xFFFF;
+            shr.b32 hi, $2, 16;
+            shl.b32 lo, lo, 16;
+            shl.b32 hi, hi, 16;
+            mov.b32 f0, lo;
+            mov.b32 f1, hi;
+            mul.f32 $0, f0, $3;
+            mul.f32 $1, f1, $3;
+        }
+        """,
+        "=f,=f,r,f",
+        has_side_effects=False,
+        is_align_stack=False,
+        asm_dialect=llvm.AsmDialect.AD_ATT,
+        loc=loc,
+        ip=ip,
+    )
+
+    f0 = llvm.extractvalue(T.f32(), result, [0], loc=loc, ip=ip)
+    f1 = llvm.extractvalue(T.f32(), result, [1], loc=loc, ip=ip)
+
+    return Float32(f0), Float32(f1)
+
+
+@dsl_user_op
+def cvt_e2m1x8_f32(
+    v0: Float32,
+    v1: Float32,
+    v2: Float32,
+    v3: Float32,
+    v4: Float32,
+    v5: Float32,
+    v6: Float32,
+    v7: Float32,
+    *,
+    loc=None,
+    ip=None,
+) -> Uint32:
+    """
+    Convert eight float32 values to eight E2M1 (4-bit) values packed into uint32.
+
+    Uses cvt.rn.satfinite.e2m1x2.f32 PTX instruction to convert pairs of f32
+    to pairs of 4-bit E2M1 values, then packs all 8 values into a single u32.
+    """
+    return Uint32(
+        llvm.inline_asm(
+            T.i32(),
+            [
+                Float32(v0).ir_value(loc=loc, ip=ip),
+                Float32(v1).ir_value(loc=loc, ip=ip),
+                Float32(v2).ir_value(loc=loc, ip=ip),
+                Float32(v3).ir_value(loc=loc, ip=ip),
+                Float32(v4).ir_value(loc=loc, ip=ip),
+                Float32(v5).ir_value(loc=loc, ip=ip),
+                Float32(v6).ir_value(loc=loc, ip=ip),
+                Float32(v7).ir_value(loc=loc, ip=ip),
+            ],
+            """
+            {
+                .reg .b8 byte0, byte1, byte2, byte3;
+                cvt.rn.satfinite.e2m1x2.f32 byte0, $2, $1;
+                cvt.rn.satfinite.e2m1x2.f32 byte1, $4, $3;
+                cvt.rn.satfinite.e2m1x2.f32 byte2, $6, $5;
+                cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $7;
+                mov.b32 $0, {byte0, byte1, byte2, byte3};
+            }
+            """,
+            "=r,f,f,f,f,f,f,f,f",
+            has_side_effects=False,
+            is_align_stack=False,
+            # NOTE(review): unlike every sibling intrinsic in this module,
+            # this inline_asm call omits loc=loc, ip=ip — confirm intentional.
+            asm_dialect=llvm.AsmDialect.AD_ATT,
+        )
+    )
+
+
+# =============================================================================
+# Warp Shuffle for 4-Thread Reduction
+# =============================================================================
+
+
+@cute.jit
+def shuffle_xor_f32(val: Float32, offset: int) -> Float32:
+    """XOR shuffle for float32 values."""
+    return cute.arch.shuffle_sync_bfly(val, offset=offset)
+
+
+@cute.jit
+def reduce_max_4threads(val: Float32) -> Float32:
+    """Reduce max across 4 consecutive threads using 2 XOR shuffles."""
+    from ..cute_dsl.fp4_common import fmax_f32
+
+    other = shuffle_xor_f32(val, 1)
+    val = fmax_f32(val, other)
+    other = shuffle_xor_f32(val, 2)
+    val = fmax_f32(val, other)
+    return val
+
+
+# =============================================================================
+# Swizzled Index Computation (GPU-side)
+# =============================================================================
+
+
+@cute.jit
+def compute_sf_index_swizzled_128x4_gpu(
+    row_idx: Int32,
+    col_idx: Int32,
+    padded_cols: Int32,
+) -> Int32:
+    """Compute swizzled 128x4 scale factor index on GPU."""
+    kColumnGroup0Size = Int32(4)
+    kRowGroup0Size = Int32(32)
+    kRowGroup1Size = Int32(128)
+
+    columnIdxInGroup0 = col_idx % kColumnGroup0Size
+    columnGroupIdx = col_idx // kColumnGroup0Size
+    columnGroupStride = Int32(512)
+
+    rowIdxInGroup0 = row_idx % kRowGroup0Size
+    rowIdxInGroup1 = (row_idx % kRowGroup1Size) // kRowGroup0Size
+    rowGroupIdx = row_idx // kRowGroup1Size
+
+    rowGroup1Stride = Int32(4)
+    rowGroup0Stride = Int32(16)
+    rowGroupStride = kRowGroup1Size * padded_cols
+
+    offset = (
+        columnIdxInGroup0
+        + columnGroupIdx * columnGroupStride
+        + rowIdxInGroup0 * rowGroup0Stride
+        + rowIdxInGroup1 * rowGroup1Stride
+        + rowGroupIdx * rowGroupStride
+    )
+
+    return offset
+
+
+# =============================================================================
+# High-Level Helper Functions for MXFP8 Quantization
+# =============================================================================
+
+
+@cute.jit
+def half2_max_abs_4(v0: Uint32, v1: Uint32, v2: Uint32, v3: Uint32) -> Uint32:
+    """
+    Compute max absolute value across 4 half2 values (8 FP16 elements).
+
+    Uses tree reduction: 4 -> 2 -> 1 half2 values.
+    Returns a half2 containing the max absolute value in both lanes.
+    """
+    abs0 = habs2(v0)
+    abs1 = habs2(v1)
+    abs2 = habs2(v2)
+    abs3 = habs2(v3)
+    max01 = hmax2(abs0, abs1)
+    max23 = hmax2(abs2, abs3)
+    return hmax2(max01, max23)
+
+
+@cute.jit
+def bfloat2_max_abs_4(v0: Uint32, v1: Uint32, v2: Uint32, v3: Uint32) -> Uint32:
+    """
+    Compute max absolute value across 4 bfloat2 values (8 BF16 elements).
+
+    Uses tree reduction: 4 -> 2 -> 1 bfloat2 values.
+    Returns a bfloat2 containing the max absolute value in both lanes.
+ """ + abs0 = bfloat2_habs2(v0) + abs1 = bfloat2_habs2(v1) + abs2 = bfloat2_habs2(v2) + abs3 = bfloat2_habs2(v3) + max01 = bfloat2_hmax2(abs0, abs1) + max23 = bfloat2_hmax2(abs2, abs3) + return bfloat2_hmax2(max01, max23) + + +@cute.jit +def half2x4_to_fp8x8_packed( + v0: Uint32, v1: Uint32, v2: Uint32, v3: Uint32, inv_scale: Float32 +) -> Uint64: + """ + Convert 4 half2 values (8 FP16) to 8 FP8 E4M3 and pack into u64. + + Each half2 is converted to 2 FP8 values using the inverse scale, + then all 8 FP8 values are packed into a single 64-bit value for + efficient vectorized store. + """ + fp8_01 = half2_to_fp8x2_scaled(v0, inv_scale) + fp8_23 = half2_to_fp8x2_scaled(v1, inv_scale) + fp8_45 = half2_to_fp8x2_scaled(v2, inv_scale) + fp8_67 = half2_to_fp8x2_scaled(v3, inv_scale) + return pack_fp8x8_to_u64(fp8_01, fp8_23, fp8_45, fp8_67) + + +@cute.jit +def bfloat2x4_to_fp8x8_packed( + v0: Uint32, v1: Uint32, v2: Uint32, v3: Uint32, inv_scale: Float32 +) -> Uint64: + """ + Convert 4 bfloat2 values (8 BF16) to 8 FP8 E4M3 and pack into u64. + + Each bfloat2 is converted to 2 FP8 values using the inverse scale, + then all 8 FP8 values are packed into a single 64-bit value for + efficient vectorized store. + """ + fp8_01 = bfloat2_to_fp8x2_scaled(v0, inv_scale) + fp8_23 = bfloat2_to_fp8x2_scaled(v1, inv_scale) + fp8_45 = bfloat2_to_fp8x2_scaled(v2, inv_scale) + fp8_67 = bfloat2_to_fp8x2_scaled(v3, inv_scale) + return pack_fp8x8_to_u64(fp8_01, fp8_23, fp8_45, fp8_67) + + +# ============================================================================= +# MXFP4 High-Level Helper Functions +# ============================================================================= + + +@cute.jit +def half2_max_abs_8( + v0: Uint32, + v1: Uint32, + v2: Uint32, + v3: Uint32, + v4: Uint32, + v5: Uint32, + v6: Uint32, + v7: Uint32, +) -> Uint32: + """ + Compute max absolute value across 8 half2 values (16 FP16 elements). + + Uses tree reduction: 8 -> 4 -> 2 -> 1 half2 values. 
+ Returns a half2 containing the max absolute value in both lanes. + """ + abs0 = habs2(v0) + abs1 = habs2(v1) + abs2 = habs2(v2) + abs3 = habs2(v3) + abs4 = habs2(v4) + abs5 = habs2(v5) + abs6 = habs2(v6) + abs7 = habs2(v7) + + max01 = hmax2(abs0, abs1) + max23 = hmax2(abs2, abs3) + max45 = hmax2(abs4, abs5) + max67 = hmax2(abs6, abs7) + + max0123 = hmax2(max01, max23) + max4567 = hmax2(max45, max67) + + return hmax2(max0123, max4567) + + +@cute.jit +def bfloat2_max_abs_8( + v0: Uint32, + v1: Uint32, + v2: Uint32, + v3: Uint32, + v4: Uint32, + v5: Uint32, + v6: Uint32, + v7: Uint32, +) -> Uint32: + """ + Compute max absolute value across 8 bfloat2 values (16 BF16 elements). + + Uses tree reduction: 8 -> 4 -> 2 -> 1 bfloat2 values. + Returns a bfloat2 containing the max absolute value in both lanes. + """ + abs0 = bfloat2_habs2(v0) + abs1 = bfloat2_habs2(v1) + abs2 = bfloat2_habs2(v2) + abs3 = bfloat2_habs2(v3) + abs4 = bfloat2_habs2(v4) + abs5 = bfloat2_habs2(v5) + abs6 = bfloat2_habs2(v6) + abs7 = bfloat2_habs2(v7) + + max01 = bfloat2_hmax2(abs0, abs1) + max23 = bfloat2_hmax2(abs2, abs3) + max45 = bfloat2_hmax2(abs4, abs5) + max67 = bfloat2_hmax2(abs6, abs7) + + max0123 = bfloat2_hmax2(max01, max23) + max4567 = bfloat2_hmax2(max45, max67) + + return bfloat2_hmax2(max0123, max4567) + + +@cute.jit +def process_mxfp4_block_half(row_tensor, elem_base: Int32) -> tuple: + """ + Process a 32-element MXFP4 block for half precision input. + + Loads 32 FP16 elements, computes the UE8M0 scale factor, converts to E2M1, + and packs the result into two u64 values. 
+ + Args: + row_tensor: Row tensor slice (mInput[row_idx, None]) + elem_base: Starting element index + + Returns: + (scale_ue8m0_u32, scale_ue8m0_u8, packed64_0, packed64_1): + - scale_ue8m0_u32: Scale factor as Uint32 (for inv_scale computation) + - scale_ue8m0_u8: Scale factor as Uint8 (for storage) + - packed64_0, packed64_1: Two Uint64 containing 16 E2M1 values each + """ + from cutlass import Uint8 + + from ..cute_dsl.fp4_common import get_ptr_as_int64, hmax2, ld_global_v4_u32 + + # Load 32 elements (4 x 128-bit = 16 half2 values) + ptr0 = get_ptr_as_int64(row_tensor, elem_base) + ptr1 = get_ptr_as_int64(row_tensor, elem_base + Int32(8)) + ptr2 = get_ptr_as_int64(row_tensor, elem_base + Int32(16)) + ptr3 = get_ptr_as_int64(row_tensor, elem_base + Int32(24)) + + h0, h1, h2, h3 = ld_global_v4_u32(ptr0) + h4, h5, h6, h7 = ld_global_v4_u32(ptr1) + h8, h9, h10, h11 = ld_global_v4_u32(ptr2) + h12, h13, h14, h15 = ld_global_v4_u32(ptr3) + + # Compute max absolute value across 32 elements + max_first = half2_max_abs_8(h0, h1, h2, h3, h4, h5, h6, h7) + max_second = half2_max_abs_8(h8, h9, h10, h11, h12, h13, h14, h15) + block_max_h2 = hmax2(max_first, max_second) + block_max = hmax_reduce_to_f32(block_max_h2) + + # Compute UE8M0 scale factor + inv_e2m1_max = Float32(INV_FLOAT4_E2M1_MAX) + normalized_max = block_max * inv_e2m1_max + scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max) + scale_ue8m0_u8 = scale_ue8m0_u32.to(Uint8) + + # Compute inverse scale and convert to E2M1 packed format + inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32) + packed64_0, packed64_1 = half2x16_to_e2m1x32_packed( + h0, + h1, + h2, + h3, + h4, + h5, + h6, + h7, + h8, + h9, + h10, + h11, + h12, + h13, + h14, + h15, + inv_scale, + ) + + return scale_ue8m0_u32, scale_ue8m0_u8, packed64_0, packed64_1 + + +@cute.jit +def process_mxfp4_block_bfloat(row_tensor, elem_base: Int32) -> tuple: + """ + Process a 32-element MXFP4 block for bfloat16 precision input. 
+ + Loads 32 BF16 elements, computes the UE8M0 scale factor, converts to E2M1, + and packs the result into two u64 values. + + Args: + row_tensor: Row tensor slice (mInput[row_idx, None]) + elem_base: Starting element index + + Returns: + (scale_ue8m0_u32, scale_ue8m0_u8, packed64_0, packed64_1): + - scale_ue8m0_u32: Scale factor as Uint32 (for inv_scale computation) + - scale_ue8m0_u8: Scale factor as Uint8 (for storage) + - packed64_0, packed64_1: Two Uint64 containing 16 E2M1 values each + """ + from cutlass import Uint8 + + from ..cute_dsl.fp4_common import bfloat2_hmax2, get_ptr_as_int64, ld_global_v4_u32 + + # Load 32 elements (4 x 128-bit = 16 bfloat2 values) + ptr0 = get_ptr_as_int64(row_tensor, elem_base) + ptr1 = get_ptr_as_int64(row_tensor, elem_base + Int32(8)) + ptr2 = get_ptr_as_int64(row_tensor, elem_base + Int32(16)) + ptr3 = get_ptr_as_int64(row_tensor, elem_base + Int32(24)) + + h0, h1, h2, h3 = ld_global_v4_u32(ptr0) + h4, h5, h6, h7 = ld_global_v4_u32(ptr1) + h8, h9, h10, h11 = ld_global_v4_u32(ptr2) + h12, h13, h14, h15 = ld_global_v4_u32(ptr3) + + # Compute max absolute value across 32 elements + max_first = bfloat2_max_abs_8(h0, h1, h2, h3, h4, h5, h6, h7) + max_second = bfloat2_max_abs_8(h8, h9, h10, h11, h12, h13, h14, h15) + block_max_h2 = bfloat2_hmax2(max_first, max_second) + block_max = bfloat2_hmax_reduce_to_f32(block_max_h2) + + # Compute UE8M0 scale factor + inv_e2m1_max = Float32(INV_FLOAT4_E2M1_MAX) + normalized_max = block_max * inv_e2m1_max + scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max) + scale_ue8m0_u8 = scale_ue8m0_u32.to(Uint8) + + # Compute inverse scale and convert to E2M1 packed format + inv_scale = ue8m0_to_inv_scale_fast(scale_ue8m0_u32) + packed64_0, packed64_1 = bfloat2x16_to_e2m1x32_packed( + h0, + h1, + h2, + h3, + h4, + h5, + h6, + h7, + h8, + h9, + h10, + h11, + h12, + h13, + h14, + h15, + inv_scale, + ) + + return scale_ue8m0_u32, scale_ue8m0_u8, packed64_0, packed64_1 + + +@cute.jit +def 
ld_32_elements(row_tensor, elem_base: Int32) -> tuple: + """ + Load 32 elements (16 half2/bfloat2 values) from a row tensor. + + This loads 4 x 128-bit vectors (4 x v4_u32) starting at elem_base. + + Args: + row_tensor: Row tensor slice (mInput[row_idx, None]) + elem_base: Starting element index + + Returns: + Tuple of 16 Uint32 values (h0-h15), each containing 2 fp16/bf16 elements + """ + from ..cute_dsl.fp4_common import get_ptr_as_int64, ld_global_v4_u32 + + ptr0 = get_ptr_as_int64(row_tensor, elem_base) + ptr1 = get_ptr_as_int64(row_tensor, elem_base + Int32(8)) + ptr2 = get_ptr_as_int64(row_tensor, elem_base + Int32(16)) + ptr3 = get_ptr_as_int64(row_tensor, elem_base + Int32(24)) + + h0, h1, h2, h3 = ld_global_v4_u32(ptr0) # Elements 0-7 + h4, h5, h6, h7 = ld_global_v4_u32(ptr1) # Elements 8-15 + h8, h9, h10, h11 = ld_global_v4_u32(ptr2) # Elements 16-23 + h12, h13, h14, h15 = ld_global_v4_u32(ptr3) # Elements 24-31 + + return h0, h1, h2, h3, h4, h5, h6, h7, h8, h9, h10, h11, h12, h13, h14, h15 + + +@cute.jit +def half2x16_to_e2m1x32_packed( + h0: Uint32, + h1: Uint32, + h2: Uint32, + h3: Uint32, + h4: Uint32, + h5: Uint32, + h6: Uint32, + h7: Uint32, + h8: Uint32, + h9: Uint32, + h10: Uint32, + h11: Uint32, + h12: Uint32, + h13: Uint32, + h14: Uint32, + h15: Uint32, + inv_scale: Float32, +) -> tuple: + """ + Convert 16 half2 values (32 FP16) to 32 E2M1 and pack into two u64. + + Each half2 is converted to 2 float32 values using inv_scale, + then groups of 8 floats are converted to 8 E2M1 values packed into u32, + and finally combined into two u64 values for vectorized store. 
+ + Returns: + (packed64_0, packed64_1): Two Uint64 containing 16 E2M1 values each + """ + # Scale and convert each half2 to 2 float32 + s0, s1 = half2_to_float2_scaled(h0, inv_scale) + s2, s3 = half2_to_float2_scaled(h1, inv_scale) + s4, s5 = half2_to_float2_scaled(h2, inv_scale) + s6, s7 = half2_to_float2_scaled(h3, inv_scale) + s8, s9 = half2_to_float2_scaled(h4, inv_scale) + s10, s11 = half2_to_float2_scaled(h5, inv_scale) + s12, s13 = half2_to_float2_scaled(h6, inv_scale) + s14, s15 = half2_to_float2_scaled(h7, inv_scale) + s16, s17 = half2_to_float2_scaled(h8, inv_scale) + s18, s19 = half2_to_float2_scaled(h9, inv_scale) + s20, s21 = half2_to_float2_scaled(h10, inv_scale) + s22, s23 = half2_to_float2_scaled(h11, inv_scale) + s24, s25 = half2_to_float2_scaled(h12, inv_scale) + s26, s27 = half2_to_float2_scaled(h13, inv_scale) + s28, s29 = half2_to_float2_scaled(h14, inv_scale) + s30, s31 = half2_to_float2_scaled(h15, inv_scale) + + # Convert to E2M1 (4 x 8 floats -> 4 x uint32) + packed0 = cvt_e2m1x8_f32(s0, s1, s2, s3, s4, s5, s6, s7) + packed1 = cvt_e2m1x8_f32(s8, s9, s10, s11, s12, s13, s14, s15) + packed2 = cvt_e2m1x8_f32(s16, s17, s18, s19, s20, s21, s22, s23) + packed3 = cvt_e2m1x8_f32(s24, s25, s26, s27, s28, s29, s30, s31) + + # Pack into 2 x 64-bit values + packed64_0 = (Uint64(packed1) << Uint64(32)) | Uint64(packed0) + packed64_1 = (Uint64(packed3) << Uint64(32)) | Uint64(packed2) + + return packed64_0, packed64_1 + + +@cute.jit +def bfloat2x16_to_e2m1x32_packed( + h0: Uint32, + h1: Uint32, + h2: Uint32, + h3: Uint32, + h4: Uint32, + h5: Uint32, + h6: Uint32, + h7: Uint32, + h8: Uint32, + h9: Uint32, + h10: Uint32, + h11: Uint32, + h12: Uint32, + h13: Uint32, + h14: Uint32, + h15: Uint32, + inv_scale: Float32, +) -> tuple: + """ + Convert 16 bfloat2 values (32 BF16) to 32 E2M1 and pack into two u64. 
+ + Each bfloat2 is converted to 2 float32 values using inv_scale, + then groups of 8 floats are converted to 8 E2M1 values packed into u32, + and finally combined into two u64 values for vectorized store. + + Returns: + (packed64_0, packed64_1): Two Uint64 containing 16 E2M1 values each + """ + # Scale and convert each bfloat2 to 2 float32 + s0, s1 = bfloat2_to_float2_scaled(h0, inv_scale) + s2, s3 = bfloat2_to_float2_scaled(h1, inv_scale) + s4, s5 = bfloat2_to_float2_scaled(h2, inv_scale) + s6, s7 = bfloat2_to_float2_scaled(h3, inv_scale) + s8, s9 = bfloat2_to_float2_scaled(h4, inv_scale) + s10, s11 = bfloat2_to_float2_scaled(h5, inv_scale) + s12, s13 = bfloat2_to_float2_scaled(h6, inv_scale) + s14, s15 = bfloat2_to_float2_scaled(h7, inv_scale) + s16, s17 = bfloat2_to_float2_scaled(h8, inv_scale) + s18, s19 = bfloat2_to_float2_scaled(h9, inv_scale) + s20, s21 = bfloat2_to_float2_scaled(h10, inv_scale) + s22, s23 = bfloat2_to_float2_scaled(h11, inv_scale) + s24, s25 = bfloat2_to_float2_scaled(h12, inv_scale) + s26, s27 = bfloat2_to_float2_scaled(h13, inv_scale) + s28, s29 = bfloat2_to_float2_scaled(h14, inv_scale) + s30, s31 = bfloat2_to_float2_scaled(h15, inv_scale) + + # Convert to E2M1 (4 x 8 floats -> 4 x uint32) + packed0 = cvt_e2m1x8_f32(s0, s1, s2, s3, s4, s5, s6, s7) + packed1 = cvt_e2m1x8_f32(s8, s9, s10, s11, s12, s13, s14, s15) + packed2 = cvt_e2m1x8_f32(s16, s17, s18, s19, s20, s21, s22, s23) + packed3 = cvt_e2m1x8_f32(s24, s25, s26, s27, s28, s29, s30, s31) + + # Pack into 2 x 64-bit values + packed64_0 = (Uint64(packed1) << Uint64(32)) | Uint64(packed0) + packed64_1 = (Uint64(packed3) << Uint64(32)) | Uint64(packed2) + + return packed64_0, packed64_1 + + +__all__ = [ + # MXFP8 Constants + "SF_VEC_SIZE", + "INV_FLOAT8_E4M3_MAX", + "WARP_SIZE", + "ELTS_PER_THREAD", + "THREADS_PER_SF", + "SF_BLOCKS_PER_WARP", + "ROW_TILE_SIZE", + # MXFP4 Constants + "MXFP4_SF_VEC_SIZE", + "MXFP4_ELTS_PER_THREAD", + "INV_FLOAT4_E2M1_MAX", + "MXFP4_GLOBAL_SCALE_FACTOR", + 
# Low-level intrinsics (MXFP8) + "hmax_reduce_to_f32", + "bfloat2_hmax_reduce_to_f32", + "float_to_ue8m0_fast", + "ue8m0_to_inv_scale_fast", + "reduce_max_4threads", + "compute_sf_index_swizzled_128x4_gpu", + # Low-level intrinsics (MXFP4 - E2M1 conversion) + "half2_to_float2_scaled", + "bfloat2_to_float2_scaled", + "cvt_e2m1x8_f32", + # High-level helper functions (MXFP8) + "half2_max_abs_4", + "bfloat2_max_abs_4", + "half2x4_to_fp8x8_packed", + "bfloat2x4_to_fp8x8_packed", + # High-level helper functions (MXFP4) + "half2_max_abs_8", + "bfloat2_max_abs_8", + "process_mxfp4_block_half", + "process_mxfp4_block_bfloat", + "ld_32_elements", + "half2x16_to_e2m1x32_packed", + "bfloat2x16_to_e2m1x32_packed", +] diff --git a/tests/utils/test_fp4_quantize.py b/tests/utils/test_fp4_quantize.py index cf866d2245..b2343dd9d2 100644 --- a/tests/utils/test_fp4_quantize.py +++ b/tests/utils/test_fp4_quantize.py @@ -21,6 +21,16 @@ is_sm12x_supported, ) + +def _is_fp4_supported(device: torch.device) -> bool: + """Check if FP4 quantization is supported on this device.""" + return ( + is_sm100a_supported(device) + or is_sm110a_supported(device) + or is_sm12x_supported(device) + ) + + DTYPES = [torch.float16, torch.bfloat16] # The batch dimension doesn't need to be multiple of 128 SHAPES = [(128, 64), (256, 128), (120, 64), (200, 256), (2048, 2048)] @@ -116,11 +126,7 @@ def test_fp4_quantization( sf_use_ue8m0: bool, is_swizzled: bool, ) -> None: - if not ( - is_sm100a_supported(torch.device(device)) - or is_sm110a_supported(torch.device(device)) - or is_sm12x_supported(torch.device(device)) - ): + if not _is_fp4_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") torch.set_default_device(device) torch.manual_seed(seed) @@ -163,11 +169,7 @@ def test_scale_swizzling( seed: int, device: str, ) -> None: - if not ( - is_sm100a_supported(torch.device("cuda")) - or is_sm110a_supported(torch.device("cuda")) - or 
is_sm12x_supported(torch.device("cuda")) - ): + if not _is_fp4_supported(torch.device("cuda")): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") torch.set_default_device(device) torch.manual_seed(seed) @@ -203,11 +205,7 @@ def test_block_scale_interleave( device: str, ) -> None: """Test the block_scale_interleave function directly.""" - if not ( - is_sm100a_supported(torch.device("cuda")) - or is_sm110a_supported(torch.device("cuda")) - or is_sm12x_supported(torch.device("cuda")) - ): + if not _is_fp4_supported(torch.device("cuda")): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") torch.set_default_device(device) torch.manual_seed(seed) @@ -256,11 +254,7 @@ def test_e2m1_dequantization( sf_use_ue8m0: bool, ) -> None: """Test roundtrip: fp4_quantize -> e2m1_and_ufp8sf_scale_to_float.""" - if not ( - is_sm100a_supported(torch.device("cuda")) - or is_sm110a_supported(torch.device("cuda")) - or is_sm12x_supported(torch.device("cuda")) - ): + if not _is_fp4_supported(torch.device("cuda")): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") torch.set_default_device(device) torch.manual_seed(seed) @@ -323,25 +317,151 @@ def test_e2m1_dequantization( ) +# ============================================================================= +# MXFP4 Quantization Tests (Both Backends) +# ============================================================================= + +MXFP4_SHAPES = [(128, 64), (256, 128), (512, 256), (128, 1024), (1024, 2048)] +MXFP4_BACKENDS = ["cuda", "cute-dsl"] + + +def _is_cute_dsl_available(): + """Check if CuTe-DSL is available.""" + try: + from flashinfer.cute_dsl import is_cute_dsl_available + + return is_cute_dsl_available() + except ImportError: + return False + + +@pytest.mark.parametrize("backend", MXFP4_BACKENDS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", MXFP4_SHAPES) @pytest.mark.parametrize("device", CUDA_DEVICES) -def 
test_mxfp4_quantize_roundtrip(device: str): - if not ( - is_sm100a_supported(torch.device(device)) - or is_sm110a_supported(torch.device(device)) - or is_sm12x_supported(torch.device(device)) - ): +@torch.inference_mode() +def test_mxfp4_quantize_roundtrip( + backend: str, + dtype: torch.dtype, + shape: tuple[int, int], + device: str, +) -> None: + """Test MXFP4 quantization roundtrip for both backends.""" + if not _is_fp4_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") - x = torch.randn((128, 64), device="cuda", dtype=torch.bfloat16) / 10 + if backend == "cute-dsl" and not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") + + torch.set_default_device(device) + torch.manual_seed(42) + + m, n = shape + x = torch.randn((m, n), dtype=dtype) + + # Test specified backend + quant_out, scale_out = mxfp4_quantize(x, backend=backend) + + # Basic shape checks + assert quant_out.shape == (m, n // 2), ( + f"Expected shape ({m}, {n // 2}), got {quant_out.shape}" + ) + assert quant_out.dtype == torch.uint8, f"Expected uint8, got {quant_out.dtype}" + assert scale_out.dtype == torch.uint8, f"Expected uint8, got {scale_out.dtype}" - quant_a, sfs = mxfp4_quantize(x) - dq_a = mxfp4_dequantize(quant_a, sfs) + # Check roundtrip with mxfp4_dequantize + dq_out = mxfp4_dequantize(quant_out, scale_out) + # Verify no NaN/Inf + assert not torch.isnan(dq_out).any(), "Dequantized tensor contains NaN" + assert not torch.isinf(dq_out).any(), "Dequantized tensor contains Inf" + + # Verify roundtrip is reasonably accurate torch.testing.assert_close( - dq_a.cpu().to(torch.float32), + dq_out.cpu().to(torch.float32), x.cpu().to(torch.float32), rtol=0.3, atol=0.5, - msg="Quantize -> dequantize mxfp4 roundtrip failed", + msg=f"{backend} MXFP4 quantize -> dequantize roundtrip failed", + ) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", MXFP4_SHAPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) 
+@torch.inference_mode() +def test_mxfp4_quantize_backend_parity( + dtype: torch.dtype, + shape: tuple[int, int], + device: str, +) -> None: + """Test that CUDA and CuTe-DSL backends produce matching results.""" + if not _is_fp4_supported(torch.device(device)): + pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") + + torch.set_default_device(device) + torch.manual_seed(42) + + m, n = shape + x = torch.randn((m, n), dtype=dtype) + + # Get results from both backends + quant_cuda, scale_cuda = mxfp4_quantize(x, backend="cuda") + quant_cute, scale_cute = mxfp4_quantize(x, backend="cute-dsl") + + # Shape should match + assert quant_cuda.shape == quant_cute.shape, "Quantized output shape mismatch" + assert scale_cuda.shape == scale_cute.shape, "Scale output shape mismatch" + + # Dequantize both and compare + dq_cuda = mxfp4_dequantize(quant_cuda, scale_cuda) + dq_cute = mxfp4_dequantize(quant_cute, scale_cute) + + # Compute detailed error statistics + dq_cuda_f32 = dq_cuda.cpu().to(torch.float32) + dq_cute_f32 = dq_cute.cpu().to(torch.float32) + abs_diff = (dq_cuda_f32 - dq_cute_f32).abs() + rel_diff = abs_diff / (dq_cuda_f32.abs() + 1e-8) + + # Print diagnostic info on failure + max_abs_diff = abs_diff.max().item() + mean_abs_diff = abs_diff.mean().item() + max_rel_diff = rel_diff.max().item() + mean_rel_diff = rel_diff.mean().item() + + # Check quantized data match + quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100 + scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100 + + error_msg = ( + f"CUDA and CuTe-DSL backends differ after dequantization:\n" + f" Shape: {shape}, dtype: {dtype}\n" + f" Quantized match: {quant_match_pct:.1f}%, Scale match: {scale_match_pct:.1f}%\n" + f" Abs diff - max: {max_abs_diff:.6f}, mean: {mean_abs_diff:.6f}\n" + f" Rel diff - max: {max_rel_diff:.6f}, mean: {mean_rel_diff:.6f}\n" + f" CUDA dq range: 
[{dq_cuda_f32.min().item():.4f}, {dq_cuda_f32.max().item():.4f}]\n" + f" CuTe dq range: [{dq_cute_f32.min().item():.4f}, {dq_cute_f32.max().item():.4f}]" + ) + + # Verify high agreement between backends + # For FP4 quantization, we expect >95% exact match due to minor rounding differences + assert quant_match_pct > 95.0, ( + f"Quantized values should match >95%, got {quant_match_pct:.1f}%" + ) + assert scale_match_pct > 95.0, ( + f"Scale factors should match >95%, got {scale_match_pct:.1f}%" + ) + + # Both should roundtrip to similar values + # Note: FP4 (E2M1) has coarse quantization steps (0.25-0.5 between adjacent values), + # so we allow atol=0.5 (one quantization step) for edge-case rounding differences. + torch.testing.assert_close( + dq_cuda_f32, + dq_cute_f32, + rtol=0.2, + atol=0.5, # Allow one FP4 quantization step difference + msg=error_msg, ) @@ -357,11 +477,7 @@ def test_nvfp4_batched_quantize( device: str, ) -> None: """Test nvfp4_batched_quantize function.""" - if not ( - is_sm100a_supported(torch.device(device)) - or is_sm110a_supported(torch.device(device)) - or is_sm12x_supported(torch.device(device)) - ): + if not _is_fp4_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability of 10 or above") torch.set_default_device(device) torch.manual_seed(seed) @@ -404,11 +520,7 @@ def test_scaled_fp4_grouped_quantize( device: str, ) -> None: """Test scaled_fp4_grouped_quantize function.""" - if not ( - is_sm100a_supported(torch.device(device)) - or is_sm110a_supported(torch.device(device)) - or is_sm12x_supported(torch.device(device)) - ): + if not _is_fp4_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability of 10 or above") torch.set_default_device(device) torch.manual_seed(seed) @@ -455,11 +567,7 @@ def test_silu_and_mul_scaled_nvfp4_experts_quantize( device: str, ) -> None: """Test silu_and_mul_nvfp4_batched_quantize function.""" - if not ( - is_sm100a_supported(torch.device(device)) - or 
is_sm110a_supported(torch.device(device)) - or is_sm12x_supported(torch.device(device)) - ): + if not _is_fp4_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability of 10 or above") torch.set_default_device(device) torch.manual_seed(seed) diff --git a/tests/utils/test_fp8_quantize.py b/tests/utils/test_fp8_quantize.py index 1e0b1ec26a..c5da042010 100644 --- a/tests/utils/test_fp8_quantize.py +++ b/tests/utils/test_fp8_quantize.py @@ -5,12 +5,23 @@ from flashinfer.utils import get_compute_capability +def is_cute_dsl_available(): + """Check if CuTe-DSL is available.""" + try: + from flashinfer.cute_dsl import is_cute_dsl_available as _is_available + + return _is_available() + except ImportError: + return False + + @pytest.mark.parametrize("m", [1, 1024]) @pytest.mark.parametrize("k", [1024]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("device", ["cuda", "cpu"]) -def test_mxfp8_quantize_torch(m, k, dtype, is_sf_swizzled_layout, device): +@pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) +def test_mxfp8_quantize_torch(m, k, dtype, is_sf_swizzled_layout, device, backend): if device == "cuda": major, _ = get_compute_capability(torch.device(device)) if major < 10: @@ -18,12 +29,19 @@ def test_mxfp8_quantize_torch(m, k, dtype, is_sf_swizzled_layout, device): "mxfp8 quantization is not supported on compute capability < 10" ) + # Skip cute-dsl backend for CPU or if not available + if backend == "cute-dsl": + if device == "cpu": + pytest.skip("cute-dsl backend only supports CUDA") + if not is_cute_dsl_available(): + pytest.skip("CuTe-DSL is not available") + a = 16 * torch.randn([m, k], dtype=dtype).to(device).contiguous() if device == "cpu": a = a.float() - a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout) + a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend) if device == "cuda": a_fp8 = a_fp8.cpu() @@ 
-97,15 +115,19 @@ def test_mxfp8_quantize_torch_host(m, k, dtype, is_sf_swizzled_layout): @pytest.mark.parametrize("k", [512, 1024]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) -def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout): +@pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) +def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout, backend): major, _ = get_compute_capability(torch.device("cuda:0")) if major < 10: pytest.skip("mxfp8 quantization is not supported on compute capability < 10") + if backend == "cute-dsl" and not is_cute_dsl_available(): + pytest.skip("CuTe-DSL is not available") + torch.random.manual_seed(0) a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous() - a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, 32) + a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, 32, backend=backend) a_pt = mxfp8_dequantize_host( a_fp8.cpu().view(torch.uint8), a_sf.cpu().view(torch.uint8), @@ -123,19 +145,23 @@ def test_mxfp8_quantize_torch_device(m, k, dtype, is_sf_swizzled_layout): @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) @pytest.mark.parametrize("alignment", [64, 128]) +@pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) def test_mxfp8_quantize_alignment_torch_device( - m, k, dtype, is_sf_swizzled_layout, alignment + m, k, dtype, is_sf_swizzled_layout, alignment, backend ): major, _ = get_compute_capability(torch.device("cuda:0")) if major < 10: pytest.skip("mxfp8 quantization is not supported on compute capability < 10") + if backend == "cute-dsl" and not is_cute_dsl_available(): + pytest.skip("CuTe-DSL is not available") + torch.random.manual_seed(0) a = (torch.randn([m, k], dtype=torch.float) * 16).to(dtype).cuda().contiguous() padded_k = ((k + alignment - 1) // alignment) * alignment # Quantize it 
on device. - a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, alignment) + a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, alignment, backend=backend) assert a_fp8.shape[1] == padded_k # Dequantize it on host. @@ -160,7 +186,8 @@ def test_mxfp8_quantize_alignment_torch_device( @pytest.mark.parametrize("k", [1024]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) -def test_mxfp8_quantize_denormal_inputs(m, k, dtype, is_sf_swizzled_layout): +@pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) +def test_mxfp8_quantize_denormal_inputs(m, k, dtype, is_sf_swizzled_layout, backend): """Test that very small denormalized inputs do not produce NaN. This test covers a bug where inputs small enough to cause E8M0 scale factor @@ -170,13 +197,16 @@ def test_mxfp8_quantize_denormal_inputs(m, k, dtype, is_sf_swizzled_layout): if major < 10: pytest.skip("mxfp8 quantization is not supported on compute capability < 10") + if backend == "cute-dsl" and not is_cute_dsl_available(): + pytest.skip("CuTe-DSL is not available") + torch.random.manual_seed(42) # Create very small denormalized values (below float32 normal range ~1.17e-38) # These values caused NaN in the original buggy implementation a = (torch.randn([m, k], dtype=torch.float32) * 1e-38).to(dtype).cuda().contiguous() - a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout) + a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend) # The primary check: no NaN values should be produced nan_count = torch.isnan(a_fp8.float()).sum().item() @@ -189,16 +219,20 @@ def test_mxfp8_quantize_denormal_inputs(m, k, dtype, is_sf_swizzled_layout): @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) -def test_mxfp8_quantize_all_zeros(dtype, is_sf_swizzled_layout): +@pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) +def 
test_mxfp8_quantize_all_zeros(dtype, is_sf_swizzled_layout, backend): """Test that all-zero inputs produce all-zero outputs without NaN.""" major, _ = get_compute_capability(torch.device("cuda:0")) if major < 10: pytest.skip("mxfp8 quantization is not supported on compute capability < 10") + if backend == "cute-dsl" and not is_cute_dsl_available(): + pytest.skip("CuTe-DSL is not available") + m, k = 128, 1024 a = torch.zeros([m, k], dtype=dtype, device="cuda").contiguous() - a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout) + a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend) # No NaN values assert not torch.isnan(a_fp8.float()).any(), "NaN found in output for zero input" @@ -209,7 +243,8 @@ def test_mxfp8_quantize_all_zeros(dtype, is_sf_swizzled_layout): @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) -def test_mxfp8_quantize_mixed_magnitude(dtype, is_sf_swizzled_layout): +@pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) +def test_mxfp8_quantize_mixed_magnitude(dtype, is_sf_swizzled_layout, backend): """Test mixed inputs: some blocks with normal values, some with denormals. 
This mimics real-world scenarios where different regions of a tensor @@ -219,6 +254,9 @@ def test_mxfp8_quantize_mixed_magnitude(dtype, is_sf_swizzled_layout): if major < 10: pytest.skip("mxfp8 quantization is not supported on compute capability < 10") + if backend == "cute-dsl" and not is_cute_dsl_available(): + pytest.skip("CuTe-DSL is not available") + torch.random.manual_seed(123) m, k = 256, 1024 @@ -234,7 +272,7 @@ def test_mxfp8_quantize_mixed_magnitude(dtype, is_sf_swizzled_layout): a = a.to(dtype).cuda().contiguous() - a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout) + a_fp8, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout, backend=backend) # No NaN values should be produced anywhere nan_mask = torch.isnan(a_fp8.float()) @@ -250,7 +288,8 @@ def test_mxfp8_quantize_mixed_magnitude(dtype, is_sf_swizzled_layout): @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("is_sf_swizzled_layout", [True, False]) -def test_mxfp8_quantize_single_denormal_in_block(dtype, is_sf_swizzled_layout): +@pytest.mark.parametrize("backend", ["cuda", "cute-dsl"]) +def test_mxfp8_quantize_single_denormal_in_block(dtype, is_sf_swizzled_layout, backend): """Test a block where most values are normal but one is a tiny denormal. 
# =============================================================================
# CuTe-DSL Compilation Cache Tests
# =============================================================================


@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_cute_dsl_compilation_cache_m_agnostic(is_sf_swizzled_layout):
    """
    Test that the CuTe-DSL compilation cache is M-agnostic.

    Different M values with the same K should reuse the cached kernel,
    meaning no recompilation occurs when only M changes. Verified by
    inspecting the ``functools`` cache counters (``misses``/``hits``) of
    the backend's compiled-kernel factory.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    if not is_cute_dsl_available():
        pytest.skip("CuTe-DSL is not available")

    # Imported lazily so collection does not fail when CuTe-DSL is absent.
    from flashinfer.quantization.kernels.mxfp8_quantize import (
        _get_compiled_kernel_linear,
        _get_compiled_kernel_swizzled,
    )

    # Get the appropriate cache based on layout
    if is_sf_swizzled_layout:
        cache_fn = _get_compiled_kernel_swizzled
    else:
        cache_fn = _get_compiled_kernel_linear

    # Clear the cache to start fresh
    cache_fn.cache_clear()

    # Fixed parameters for this test
    K = 1024
    dtype = torch.float16

    # First call with M=1 - should compile
    a1 = torch.randn([1, K], dtype=dtype, device="cuda")
    mxfp8_quantize(a1, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_m1 = cache_fn.cache_info()
    assert cache_info_after_m1.misses == 1, "First call should be a cache miss"
    assert cache_info_after_m1.hits == 0, "First call should have no hits"

    # Second call with M=16 (different M, same K) - should reuse cached kernel
    a2 = torch.randn([16, K], dtype=dtype, device="cuda")
    mxfp8_quantize(a2, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_m16 = cache_fn.cache_info()
    assert cache_info_after_m16.misses == 1, (
        "Second call with different M should still be 1 miss"
    )
    assert cache_info_after_m16.hits == 1, (
        "Second call should be a cache hit (M-agnostic)"
    )

    # Third call with M=1024 (different M again, same K) - should reuse cached kernel
    a3 = torch.randn([1024, K], dtype=dtype, device="cuda")
    mxfp8_quantize(a3, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_m1024 = cache_fn.cache_info()
    assert cache_info_after_m1024.misses == 1, (
        "Third call with different M should still be 1 miss"
    )
    assert cache_info_after_m1024.hits == 2, (
        "Third call should be a cache hit (M-agnostic)"
    )

    # Clean up
    cache_fn.cache_clear()
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_cute_dsl_compilation_cache_k_specific(is_sf_swizzled_layout):
    """
    Test that the CuTe-DSL compilation cache is K-specific.

    Different K values should create separate cached kernels,
    meaning recompilation occurs when K changes.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    if not is_cute_dsl_available():
        pytest.skip("CuTe-DSL is not available")

    # Imported lazily so collection does not fail when CuTe-DSL is absent.
    from flashinfer.quantization.kernels.mxfp8_quantize import (
        _get_compiled_kernel_linear,
        _get_compiled_kernel_swizzled,
    )

    # Get the appropriate cache based on layout
    if is_sf_swizzled_layout:
        cache_fn = _get_compiled_kernel_swizzled
    else:
        cache_fn = _get_compiled_kernel_linear

    # Clear the cache to start fresh
    cache_fn.cache_clear()

    dtype = torch.float16
    M = 16  # Fixed M

    # First call with K=1024 - should compile
    a1 = torch.randn([M, 1024], dtype=dtype, device="cuda")
    mxfp8_quantize(a1, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_k1024 = cache_fn.cache_info()
    assert cache_info_after_k1024.misses == 1, "First call should be a cache miss"

    # Second call with K=2048 (different K) - should compile new kernel
    a2 = torch.randn([M, 2048], dtype=dtype, device="cuda")
    mxfp8_quantize(a2, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_k2048 = cache_fn.cache_info()
    assert cache_info_after_k2048.misses == 2, (
        "Second call with different K should be a cache miss"
    )

    # Third call with K=1024 again - should hit cache
    a3 = torch.randn([M, 1024], dtype=dtype, device="cuda")
    mxfp8_quantize(a3, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_k1024_again = cache_fn.cache_info()
    assert cache_info_after_k1024_again.misses == 2, (
        "Third call with same K=1024 should not add miss"
    )
    assert cache_info_after_k1024_again.hits >= 1, (
        "Third call with same K=1024 should hit cache"
    )

    # Clean up
    cache_fn.cache_clear()


@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
def test_cute_dsl_compilation_cache_dtype_specific(is_sf_swizzled_layout):
    """
    Test that the CuTe-DSL compilation cache is dtype-specific.

    Different dtypes (fp16 vs bf16) should create separate cached kernels.
    """
    major, _ = get_compute_capability(torch.device("cuda:0"))
    if major < 10:
        pytest.skip("mxfp8 quantization is not supported on compute capability < 10")

    if not is_cute_dsl_available():
        pytest.skip("CuTe-DSL is not available")

    # Imported lazily so collection does not fail when CuTe-DSL is absent.
    from flashinfer.quantization.kernels.mxfp8_quantize import (
        _get_compiled_kernel_linear,
        _get_compiled_kernel_swizzled,
    )

    # Get the appropriate cache based on layout
    if is_sf_swizzled_layout:
        cache_fn = _get_compiled_kernel_swizzled
    else:
        cache_fn = _get_compiled_kernel_linear

    # Clear the cache to start fresh
    cache_fn.cache_clear()

    K = 1024
    M = 16

    # First call with float16 - should compile
    a1 = torch.randn([M, K], dtype=torch.float16, device="cuda")
    mxfp8_quantize(a1, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_fp16 = cache_fn.cache_info()
    assert cache_info_after_fp16.misses == 1, "First call (fp16) should be a cache miss"

    # Second call with bfloat16 (different dtype, same K) - should compile new kernel
    a2 = torch.randn([M, K], dtype=torch.bfloat16, device="cuda")
    mxfp8_quantize(a2, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_bf16 = cache_fn.cache_info()
    assert cache_info_after_bf16.misses == 2, (
        "Second call (bf16) should be a cache miss (dtype-specific)"
    )

    # Third call with float16 again - should hit cache
    a3 = torch.randn([M, K], dtype=torch.float16, device="cuda")
    mxfp8_quantize(a3, is_sf_swizzled_layout, backend="cute-dsl")
    cache_info_after_fp16_again = cache_fn.cache_info()
    assert cache_info_after_fp16_again.misses == 2, (
        "Third call (fp16 again) should not add miss"
    )
    assert cache_info_after_fp16_again.hits >= 1, (
        "Third call (fp16 again) should hit cache"
    )

    # Clean up
    cache_fn.cache_clear()


if __name__ == "__main__":
    pytest.main([__file__, "-v"])