diff --git a/benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py b/benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py index 6c12f707fb..5325eb6a30 100644 --- a/benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py +++ b/benchmarks/bench_cute_dsl_add_rmsnorm_fp4quant.py @@ -73,7 +73,9 @@ def compute_bandwidth_gb_s( return total_bytes / time_s / 1e9 -def bench_fused_cute_dsl(batch_size, hidden_size, dtype, block_size=16): +def bench_fused_cute_dsl( + batch_size, hidden_size, dtype, block_size=16, global_scale=None +): """Benchmark fused CuTe-DSL kernel.""" from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant @@ -82,7 +84,9 @@ def bench_fused_cute_dsl(batch_size, hidden_size, dtype, block_size=16): x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) - y_fp4 = torch.empty(batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8) + y_fp4 = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) if block_size == 32: block_scale = torch.empty( @@ -105,6 +109,7 @@ def bench_fused_cute_dsl(batch_size, hidden_size, dtype, block_size=16): weight, y_fp4, block_scale, + global_scale=global_scale, eps=eps, block_size=block_size, scale_format=scale_format, @@ -119,10 +124,10 @@ def bench_fused_cute_dsl(batch_size, hidden_size, dtype, block_size=16): return np.median(times) -def bench_fully_separate(batch_size, hidden_size, dtype, block_size=16): - """Benchmark fully separate operations: torch.add + rmsnorm + fp4_quantize. +def bench_unfused(batch_size, hidden_size, dtype, block_size=16, global_scale=None): + """Benchmark unfused operations: torch.add + rmsnorm + fp4_quantize. - Returns tuple of (add_time_ms, rmsnorm_time_ms, fp4_quant_time_ms, total_time_ms) + Returns total time in ms for the combined unfused operation. """ from flashinfer.norm import rmsnorm from flashinfer.fp4_quantization import fp4_quantize @@ -137,125 +142,34 @@ def bench_fully_separate(batch_size, hidden_size, dtype, block_size=16): h = torch.empty_like(x) y_normed = torch.empty_like(x) - # Compute global_scale for fp4_quantize (required when sf_use_ue8m0 is false) - global_scale = torch.tensor([1.0], device="cuda", dtype=torch.float32) - - # Benchmark torch.add alone - times_add = bench_gpu_time( - lambda: torch.add(x, r, out=h), - cold_l2_cache=True, - enable_cupti=True, - use_cuda_graph=False, - dry_run_iters=10, - repeat_iters=100, - ) - t_add = np.median(times_add) - - # Run add once to get h for rmsnorm - torch.add(x, r, out=h) - - # Benchmark rmsnorm alone - times_rmsnorm = bench_gpu_time( - lambda: rmsnorm(h, weight, eps=eps, out=y_normed), - cold_l2_cache=True, - enable_cupti=True, - use_cuda_graph=False, - dry_run_iters=10, - repeat_iters=100, - ) - t_rmsnorm = np.median(times_rmsnorm) - - # Run rmsnorm once to get y_normed for fp4_quantize - rmsnorm(h, weight, eps=eps, out=y_normed) - - # Benchmark fp4_quantize alone - times_fp4 = bench_gpu_time( - lambda: fp4_quantize( + def unfused_operation(): + # Step 1: Add + torch.add(x, r, out=h) + # Step 2: RMSNorm + rmsnorm(h, weight, eps=eps, out=y_normed) + # Step 3: FP4 quantize (with global_scale for NVFP4) + fp4_quantize( y_normed, - global_scale=None if block_size == 32 else global_scale, + global_scale=global_scale, sf_vec_size=block_size, sf_use_ue8m0=(block_size == 32), is_sf_swizzled_layout=False, - ), - cold_l2_cache=True, - enable_cupti=True, - use_cuda_graph=False, - dry_run_iters=10, - repeat_iters=100, - ) - t_fp4 = np.median(times_fp4) - - return t_add, t_rmsnorm, t_fp4, t_add + t_rmsnorm + t_fp4 - - -def bench_partial_separate(batch_size, hidden_size, dtype, block_size=16): - """Benchmark partial separate: torch.add + fused rmsnorm_fp4quant. - - Returns tuple of (add_time_ms, rmsnorm_fp4quant_time_ms, total_time_ms) - """ - from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant_cute_dsl - - eps = 1e-6 - - x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) - r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) - weight = torch.randn(hidden_size, device="cuda", dtype=dtype) - - # Pre-allocate tensors - h = torch.empty_like(x) - y_fp4 = torch.empty(batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8) - - if block_size == 32: - block_scale = torch.empty( - batch_size, hidden_size // block_size, device="cuda", dtype=torch.uint8 - ) - scale_format = "ue8m0" - else: - block_scale = torch.empty( - batch_size, - hidden_size // block_size, - device="cuda", - dtype=torch.float8_e4m3fn, ) - scale_format = "e4m3" - - # Benchmark torch.add alone - times_add = bench_gpu_time( - lambda: torch.add(x, r, out=h), - cold_l2_cache=True, - enable_cupti=True, - use_cuda_graph=False, - dry_run_iters=10, - repeat_iters=100, - ) - t_add = np.median(times_add) - # Run add once to get h for rmsnorm_fp4quant - torch.add(x, r, out=h) - - # Benchmark fused rmsnorm_fp4quant - times_rmsnorm_fp4 = bench_gpu_time( - lambda: rmsnorm_fp4quant_cute_dsl( - h, - weight, - y_fp4, - block_scale, - eps=eps, - block_size=block_size, - scale_format=scale_format, - ), + # Benchmark combined unfused operation + times = bench_gpu_time( + unfused_operation, cold_l2_cache=True, enable_cupti=True, use_cuda_graph=False, dry_run_iters=10, repeat_iters=100, ) - t_rmsnorm_fp4 = np.median(times_rmsnorm_fp4) - return t_add, t_rmsnorm_fp4, t_add + t_rmsnorm_fp4 + return np.median(times) -def sanity_check_outputs(dtype=torch.float16, block_size=16): +def sanity_check_outputs(dtype=torch.float16, block_size=16, global_scale=None): """Verify CuTe-DSL output matches separate torch.add + RMSNorm + fp4_quantize.""" from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant from flashinfer.norm import rmsnorm @@ -279,7 +193,7 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): # CuTe-DSL fused path y_fp4_fused = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) if block_size == 32: block_scale_fused = torch.empty( @@ -301,6 +215,7 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): weight, y_fp4_fused, block_scale_fused, + global_scale=global_scale, eps=eps, block_size=block_size, scale_format=scale_format, @@ -311,12 +226,9 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): y_normed = torch.empty_like(x) rmsnorm(h, weight, eps=eps, out=y_normed) - global_scale = torch.tensor( - [y_normed.abs().max().item() / 6.0], device="cuda", dtype=torch.float32 - ) y_fp4_sep, block_scale_sep = fp4_quantize( y_normed, - global_scale=None if block_size == 32 else global_scale, + global_scale=global_scale, sf_vec_size=block_size, sf_use_ue8m0=(block_size == 32), is_sf_swizzled_layout=False, @@ -326,7 +238,10 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): # 1. FP4 is very low precision (4 bits), small float differences can flip values # 2. Different scale factor computation between fused and separate paths # 3. Different floating-point operation ordering - match_count = (y_fp4_fused == y_fp4_sep).sum().item() + # View as uint8 for comparison (float4_e2m1fn_x2 doesn't support == operator) + match_count = ( + (y_fp4_fused.view(torch.uint8) == y_fp4_sep.view(torch.uint8)).sum().item() + ) total_count = y_fp4_fused.numel() match_pct = match_count / total_count * 100 @@ -338,16 +253,16 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): f"FP4 match: {match_pct:.1f}% (expected >= 70%)" ) else: - print(f" OK: ({batch_size}, {hidden_size}) - FP4 match") + print(f" OK: ({batch_size}, {hidden_size}) - FP4 match {match_pct:.1f}%") return all_passed def run_benchmark(): """Run full benchmark suite.""" - print("=" * 120) + print("=" * 80) print("Fused Add + RMSNorm + FP4 Quantization Benchmark") - print("=" * 120) + print("=" * 80) cc = get_cc() print(f"GPU Compute Capability: SM{cc}") @@ -356,15 +271,25 @@ def run_benchmark(): raise RuntimeError("Blackwell GPU (SM100+) required for FP4 quantization") dtype = torch.float16 - block_size = 16 + block_size = 16 # NVFP4 + + # For benchmarking, use a fixed global_scale value + # In production, this would be computed from model calibration + FLOAT4_E2M1_MAX = 6.0 + FLOAT8_E4M3_MAX = float(torch.finfo(torch.float8_e4m3fn).max) + global_scale = torch.tensor( + [FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / 3.0], # Assume typical amax ~3.0 + device="cuda", + dtype=torch.float32, + ) # Sanity check: verify CuTe-DSL output matches separate operations print() print("Running sanity check...") - if sanity_check_outputs(dtype, block_size): + if sanity_check_outputs(dtype, block_size, global_scale): print( "✓ Confirmed: CuTe-DSL output is equivalent to " - "torch.add + RMSNorm + fp4_quantization" + "torch.add + RMSNorm + fp4_quantize" ) else: print("✗ Warning: Some outputs did not match closely") @@ -385,18 +310,12 @@ def run_benchmark(): for hidden_size in hidden_sizes ] - print() - print("Legend:") - print(" Fully Sep = torch.add + RMSNorm + FP4 Quantization (3 kernels)") - print(" Partial Sep = torch.add + fused RMSNorm-FP4Quant (2 kernels)") print() header = ( f"{'Batch':<8} {'Hidden':<8} " - f"{'Fused (µs)':<11} {'BW (GB/s)':<10} " - f"{'Add (µs)':<10} {'RMSNorm (µs)':<13} {'FP4Q (µs)':<10} " - f"{'RN+FP4 (µs)':<12} " - f"{'Full Sep':<10} {'Part Sep':<10} " - f"{'vs Full':<9} {'vs Part':<9}" + f"{'Fused (µs)':<12} {'BW (GB/s)':<10} " + f"{'Unfused (µs)':<14} " + f"{'Speedup':<10}" ) print(header) print("-" * len(header)) @@ -406,7 +325,9 @@ def run_benchmark(): for batch_size, hidden_size in configs: # Fused CuTe-DSL kernel timing (add + rmsnorm + fp4quant all in one) try: - t_fused = bench_fused_cute_dsl(batch_size, hidden_size, dtype, block_size) + t_fused = bench_fused_cute_dsl( + batch_size, hidden_size, dtype, block_size, global_scale + ) t_fused_us = t_fused * 1e3 # ms to µs bw_fused = compute_bandwidth_gb_s( batch_size, hidden_size, block_size, t_fused @@ -415,61 +336,27 @@ def run_benchmark(): print(f"{batch_size:<8} {hidden_size:<8} FUSED ERROR: {e}") continue - # Fully separate: torch.add + rmsnorm + fp4_quantize - try: - t_add, t_rmsnorm, t_fp4, t_full_sep = bench_fully_separate( - batch_size, hidden_size, dtype, block_size - ) - t_add_us = t_add * 1e3 # ms to µs - t_rmsnorm_us = t_rmsnorm * 1e3 # ms to µs - t_fp4_us = t_fp4 * 1e3 # ms to µs - t_full_sep_us = t_full_sep * 1e3 # ms to µs - speedup_full = t_full_sep / t_fused if t_fused > 0 else 0 - add_str = f"{t_add_us:.1f}" - rmsnorm_str = f"{t_rmsnorm_us:.1f}" - fp4_str = f"{t_fp4_us:.1f}" - full_sep_str = f"{t_full_sep_us:.1f}" - speedup_full_str = f"{speedup_full:.2f}x" - except Exception as e: - print(f"{batch_size:<8} {hidden_size:<8} FULLY SEPARATE ERROR: {e}") - t_add_us = None - t_rmsnorm_us = None - t_fp4_us = None - t_full_sep_us = None - add_str = "N/A" - rmsnorm_str = "N/A" - fp4_str = "N/A" - full_sep_str = "N/A" - speedup_full_str = "N/A" - speedup_full = None - - # Partial separate: torch.add + fused rmsnorm_fp4quant + # Unfused: torch.add + rmsnorm + fp4_quantize try: - t_add_p, t_rn_fp4, t_part_sep = bench_partial_separate( - batch_size, hidden_size, dtype, block_size + t_unfused = bench_unfused( + batch_size, hidden_size, dtype, block_size, global_scale ) - t_rn_fp4_us = t_rn_fp4 * 1e3 # ms to µs - t_part_sep_us = t_part_sep * 1e3 # ms to µs - speedup_part = t_part_sep / t_fused if t_fused > 0 else 0 - rn_fp4_str = f"{t_rn_fp4_us:.1f}" - part_sep_str = f"{t_part_sep_us:.1f}" - speedup_part_str = f"{speedup_part:.2f}x" + t_unfused_us = t_unfused * 1e3 # ms to µs + speedup = t_unfused / t_fused if t_fused > 0 else 0 + unfused_str = f"{t_unfused_us:.1f}" + speedup_str = f"{speedup:.2f}x" except Exception as e: - print(f"{batch_size:<8} {hidden_size:<8} PARTIAL SEPARATE ERROR: {e}") - t_rn_fp4_us = None - t_part_sep_us = None - rn_fp4_str = "N/A" - part_sep_str = "N/A" - speedup_part_str = "N/A" - speedup_part = None + print(f"{batch_size:<8} {hidden_size:<8} UNFUSED ERROR: {e}") + t_unfused_us = None + unfused_str = "N/A" + speedup_str = "N/A" + speedup = None print( f"{batch_size:<8} {hidden_size:<8} " - f"{t_fused_us:<11.1f} {bw_fused:<10.1f} " - f"{add_str:<10} {rmsnorm_str:<13} {fp4_str:<10} " - f"{rn_fp4_str:<12} " - f"{full_sep_str:<10} {part_sep_str:<10} " - f"{speedup_full_str:<9} {speedup_part_str:<9}" + f"{t_fused_us:<12.1f} {bw_fused:<10.1f} " + f"{unfused_str:<14} " + f"{speedup_str:<10}" ) result = { @@ -477,42 +364,26 @@ def run_benchmark(): "hidden_size": hidden_size, "fused_us": t_fused_us, "fused_bw_gb_s": bw_fused, - "add_us": t_add_us, - "rmsnorm_us": t_rmsnorm_us, - "fp4_quant_us": t_fp4_us, - "rmsnorm_fp4quant_us": t_rn_fp4_us, - "fully_separate_us": t_full_sep_us, - "partial_separate_us": t_part_sep_us, - "speedup_vs_fully_separate": speedup_full, - "speedup_vs_partial_separate": speedup_part, + "unfused_us": t_unfused_us, + "speedup": speedup, } results.append(result) print() - print("=" * 120) + print("=" * 80) - # Calculate and print geomean speedups - speedups_full = [ - r["speedup_vs_fully_separate"] - for r in results - if r["speedup_vs_fully_separate"] is not None - ] - speedups_part = [ - r["speedup_vs_partial_separate"] - for r in results - if r["speedup_vs_partial_separate"] is not None - ] + # Calculate and print geomean speedup + speedups = [r["speedup"] for r in results if r["speedup"] is not None] - if speedups_full: - geomean_full = gmean(speedups_full) - print(f"Geomean speedup vs Fully Separate (3 kernels): {geomean_full:.2f}x") - if speedups_part: - geomean_part = gmean(speedups_part) - print(f"Geomean speedup vs Partial Separate (2 kernels): {geomean_part:.2f}x") + if speedups: + geomean_speedup = gmean(speedups) + print( + f"Geomean speedup vs Unfused (add + rmsnorm + fp4_quantize): {geomean_speedup:.2f}x" + ) - print("=" * 120) + print("=" * 80) print("Benchmark Complete") - print("=" * 120) + print("=" * 80) if __name__ == "__main__": diff --git a/benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py b/benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py index 5adb2c20f9..d40cd8b9ba 100644 --- a/benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py +++ b/benchmarks/bench_cute_dsl_rmsnorm_fp4quant.py @@ -84,7 +84,7 @@ def compute_bandwidth_gb_s( return total_bytes / time_s / 1e9 -def bench_cute_dsl(batch_size, hidden_size, dtype, block_size=16): +def bench_cute_dsl(batch_size, hidden_size, dtype, block_size=16, global_scale=None): """Benchmark CuTe-DSL backend.""" from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant @@ -92,7 +92,9 @@ def bench_cute_dsl(batch_size, hidden_size, dtype, block_size=16): x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) - y_fp4 = torch.empty(batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8) + y_fp4 = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) # Scale factor dtype depends on format if block_size == 32: @@ -116,6 +118,7 @@ def bench_cute_dsl(batch_size, hidden_size, dtype, block_size=16): weight, y_fp4, block_scale, + global_scale=global_scale, eps=eps, block_size=block_size, scale_format=scale_format, @@ -131,10 +134,15 @@ def bench_cute_dsl(batch_size, hidden_size, dtype, block_size=16): return np.median(times) -def bench_separate_flashinfer(batch_size, hidden_size, dtype, block_size=16): +def bench_separate_flashinfer( + batch_size, hidden_size, dtype, block_size=16, global_scale=None +): """Benchmark separate FlashInfer operations: rmsnorm + fp4_quantize. - Returns tuple of (rmsnorm_time_ms, fp4_quant_time_ms, total_time_ms) + For NVFP4 (block_size=16), if global_scale is provided, we include it in the timing. + For MXFP4 (block_size=32), global_scale is not used. + + Returns the total unfused time in ms. """ from flashinfer.norm import rmsnorm from flashinfer.fp4_quantization import fp4_quantize @@ -145,45 +153,32 @@ def bench_separate_flashinfer(batch_size, hidden_size, dtype, block_size=16): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_normed = torch.empty_like(x) - # Compute global_scale for fp4_quantize (required when sf_use_ue8m0 is false) - # Use a fixed scale for benchmarking consistency - global_scale = torch.tensor([1.0], device="cuda", dtype=torch.float32) - - # Benchmark rmsnorm alone - times_rmsnorm = bench_gpu_time( - lambda: rmsnorm(x, weight, eps=eps, out=y_normed), - cold_l2_cache=True, - enable_cupti=True, - use_cuda_graph=False, - dry_run_iters=10, - repeat_iters=100, - ) - t_rmsnorm = np.median(times_rmsnorm) - - # Run rmsnorm once to get y_normed for fp4_quantize - rmsnorm(x, weight, eps=eps, out=y_normed) - - # Benchmark fp4_quantize alone - times_fp4 = bench_gpu_time( - lambda: fp4_quantize( + def unfused_operation(): + # Step 1: RMSNorm + rmsnorm(x, weight, eps=eps, out=y_normed) + # Step 2: FP4 quantize (with global_scale for NVFP4) + fp4_quantize( y_normed, - global_scale=None if block_size == 32 else global_scale, + global_scale=global_scale, sf_vec_size=block_size, sf_use_ue8m0=(block_size == 32), is_sf_swizzled_layout=False, - ), + ) + + # Benchmark combined unfused operation + times = bench_gpu_time( + unfused_operation, cold_l2_cache=True, enable_cupti=True, use_cuda_graph=False, dry_run_iters=10, repeat_iters=100, ) - t_fp4 = np.median(times_fp4) - return t_rmsnorm, t_fp4, t_rmsnorm + t_fp4 + return np.median(times) -def sanity_check_outputs(dtype=torch.float16, block_size=16): +def sanity_check_outputs(dtype=torch.float16, block_size=16, global_scale=None): """Verify CuTe-DSL output matches separate RMSNorm + fp4_quantize operations.""" from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant from flashinfer.norm import rmsnorm @@ -206,7 +201,7 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): # CuTe-DSL path y_fp4_cute = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) if block_size == 32: block_scale_cute = torch.empty( @@ -227,6 +222,7 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): weight, y_fp4_cute, block_scale_cute, + global_scale=global_scale, eps=eps, block_size=block_size, scale_format=scale_format, @@ -236,12 +232,9 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): y_normed = torch.empty_like(x) rmsnorm(x, weight, eps=eps, out=y_normed) - global_scale = torch.tensor( - [y_normed.abs().max().item() / 6.0], device="cuda", dtype=torch.float32 - ) y_fp4_sep, block_scale_sep = fp4_quantize( y_normed, - global_scale=None if block_size == 32 else global_scale, + global_scale=global_scale, sf_vec_size=block_size, sf_use_ue8m0=(block_size == 32), is_sf_swizzled_layout=False, @@ -251,7 +244,10 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): # 1. FP4 is very low precision (4 bits), small float differences can flip values # 2. Different scale factor computation between fused and separate paths # 3. Different floating-point operation ordering - match_count = (y_fp4_cute == y_fp4_sep).sum().item() + # View as uint8 for comparison (float4_e2m1fn_x2 doesn't support == operator) + match_count = ( + (y_fp4_cute.view(torch.uint8) == y_fp4_sep.view(torch.uint8)).sum().item() + ) total_count = y_fp4_cute.numel() match_pct = match_count / total_count * 100 @@ -264,7 +260,7 @@ def sanity_check_outputs(dtype=torch.float16, block_size=16): f"FP4 match: {match_pct:.1f}% (expected >= 70%)" ) else: - print(f" OK: ({batch_size}, {hidden_size}) - FP4 match") + print(f" OK: ({batch_size}, {hidden_size}) - FP4 match {match_pct:.1f}%") return all_passed @@ -282,15 +278,23 @@ def run_benchmark(): raise RuntimeError("Blackwell GPU (SM100+) required for FP4 quantization") dtype = torch.float16 - block_size = 16 + block_size = 16 # NVFP4 + + # For benchmarking, use a fixed global_scale value + # In production, this would be computed from model calibration + FLOAT4_E2M1_MAX = 6.0 + FLOAT8_E4M3_MAX = float(torch.finfo(torch.float8_e4m3fn).max) + global_scale = torch.tensor( + [FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / 3.0], # Assume typical amax ~3.0 + device="cuda", + dtype=torch.float32, + ) # Sanity check: verify CuTe-DSL output matches separate operations print() print("Running sanity check...") - if sanity_check_outputs(dtype, block_size): - print( - "✓ Confirmed: CuTe-DSL output is equivalent to RMSNorm + fp4_quantization" - ) + if sanity_check_outputs(dtype, block_size, global_scale): + print("✓ Confirmed: CuTe-DSL output is equivalent to RMSNorm + fp4_quantize") else: print("✗ Warning: Some outputs did not match closely") print() @@ -316,9 +320,9 @@ def run_benchmark(): print() header = ( f"{'Batch':<8} {'Hidden':<8} " - f"{'CuTe-DSL (µs)':<14} {'BW (GB/s)':<10} " - f"{'RMSNorm (µs)':<13} {'FP4Q (µs)':<11} {'Separate (µs)':<14} " - f"{'vs Separate':<12}" + f"{'Fused (µs)':<12} {'BW (GB/s)':<10} " + f"{'Unfused (µs)':<14} " + f"{'Speedup':<10}" ) print(header) print("-" * len(header)) @@ -326,56 +330,48 @@ def run_benchmark(): results = [] for batch_size, hidden_size in configs: - # CuTe-DSL timing + # CuTe-DSL fused timing try: - t_cute = bench_cute_dsl(batch_size, hidden_size, dtype, block_size) - t_cute_us = t_cute * 1e3 # ms to µs - bw_cute = compute_bandwidth_gb_s( - batch_size, hidden_size, block_size, t_cute + t_fused = bench_cute_dsl( + batch_size, hidden_size, dtype, block_size, global_scale + ) + t_fused_us = t_fused * 1e3 # ms to µs + bw_fused = compute_bandwidth_gb_s( + batch_size, hidden_size, block_size, t_fused ) except Exception as e: - print(f"{batch_size:<8} {hidden_size:<8} ERROR: {e}") + print(f"{batch_size:<8} {hidden_size:<8} FUSED ERROR: {e}") continue - # Separate FlashInfer timing + # Unfused (separate) FlashInfer timing try: - t_rmsnorm, t_fp4, t_separate = bench_separate_flashinfer( - batch_size, hidden_size, dtype, block_size + t_unfused = bench_separate_flashinfer( + batch_size, hidden_size, dtype, block_size, global_scale ) - t_rmsnorm_us = t_rmsnorm * 1e3 # ms to µs - t_fp4_us = t_fp4 * 1e3 # ms to µs - t_separate_us = t_separate * 1e3 # ms to µs - speedup_sep = t_separate / t_cute if t_cute > 0 else 0 - rmsnorm_str = f"{t_rmsnorm_us:.1f}" - fp4_str = f"{t_fp4_us:.1f}" - separate_str = f"{t_separate_us:.1f}" - speedup_sep_str = f"{speedup_sep:.2f}x" + t_unfused_us = t_unfused * 1e3 # ms to µs + speedup = t_unfused / t_fused if t_fused > 0 else 0 + unfused_str = f"{t_unfused_us:.1f}" + speedup_str = f"{speedup:.2f}x" except Exception: - t_rmsnorm_us = None - t_fp4_us = None - t_separate_us = None - rmsnorm_str = "N/A" - fp4_str = "N/A" - separate_str = "N/A" - speedup_sep_str = "N/A" - speedup_sep = None + t_unfused_us = None + unfused_str = "N/A" + speedup_str = "N/A" + speedup = None print( f"{batch_size:<8} {hidden_size:<8} " - f"{t_cute_us:<14.1f} {bw_cute:<10.1f} " - f"{rmsnorm_str:<13} {fp4_str:<11} {separate_str:<14} " - f"{speedup_sep_str:<12}" + f"{t_fused_us:<12.1f} {bw_fused:<10.1f} " + f"{unfused_str:<14} " + f"{speedup_str:<10}" ) result = { "batch_size": batch_size, "hidden_size": hidden_size, - "cute_dsl_us": t_cute_us, - "cute_dsl_bw_gb_s": bw_cute, - "rmsnorm_us": t_rmsnorm_us, - "fp4_quant_us": t_fp4_us, - "separate_us": t_separate_us, - "speedup_vs_separate": speedup_sep, + "fused_us": t_fused_us, + "fused_bw_gb_s": bw_fused, + "unfused_us": t_unfused_us, + "speedup": speedup, } results.append(result) @@ -383,15 +379,13 @@ def run_benchmark(): print("=" * 80) # Calculate and print geomean speedup - speedups = [ - r["speedup_vs_separate"] - for r in results - if r["speedup_vs_separate"] is not None - ] + speedups = [r["speedup"] for r in results if r["speedup"] is not None] if speedups: geomean_speedup = gmean(speedups) - print(f"Geomean speedup vs Separate (2 kernels): {geomean_speedup:.2f}x") + print( + f"Geomean speedup vs Unfused (rmsnorm + fp4_quantize): {geomean_speedup:.2f}x" + ) print("=" * 80) print("Benchmark Complete") diff --git a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py index 369406a67c..b56b87d95c 100644 --- a/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py +++ b/flashinfer/cute_dsl/add_rmsnorm_fp4quant.py @@ -1013,11 +1013,18 @@ def __call__( w_ptr: cute.Pointer, y_ptr: cute.Pointer, s_ptr: cute.Pointer, + global_scale_ptr: cute.Pointer, M: Int32, eps: Float32, stream: cuda.CUstream, ): - """Host function to launch the kernel.""" + """Host function to launch the kernel. + + Args: + global_scale_ptr: Pointer to global scale tensor (shape [1], float32). + The kernel reads this value and computes 1/global_scale to apply: + y = rmsnorm(h) / global_scale. Use tensor with value 1.0 for no scaling. + """ H = self.H mX = cute.make_tensor( @@ -1058,6 +1065,12 @@ def __call__( ), ) + # Create global scale tensor (scalar) + mGlobalScale = cute.make_tensor( + global_scale_ptr, + layout=cute.make_layout((1,)), + ) + tv_shape, tv_stride = self._make_tv_layout( self.threads_per_row, self.rows_per_block, @@ -1067,7 +1080,9 @@ def __call__( tv_layout = cute.make_layout(tv_shape, stride=tv_stride) tiler_mn = (self.rows_per_block, self.cols_per_tile) - self.kernel(mX, mR, mW, mY, mS, M, eps, tv_layout, tiler_mn).launch( + self.kernel( + mX, mR, mW, mY, mS, mGlobalScale, M, eps, tv_layout, tiler_mn + ).launch( grid=[cute.ceil_div(M, self.rows_per_block), self.cluster_n, 1], block=[self.num_threads, 1, 1], cluster=[1, self.cluster_n, 1] @@ -1085,12 +1100,18 @@ def kernel( mW: cute.Tensor, mY: cute.Tensor, mS: cute.Tensor, + mGlobalScale: cute.Tensor, M: Int32, eps: Float32, tv_layout: cute.Layout, tiler_mn: cute.Shape, ): - """Device kernel with cluster sync and Half2 SIMD.""" + """Device kernel with cluster sync and Half2 SIMD. + + mGlobalScale contains the global scale value. The kernel reads it and + computes 1/global_scale, which is multiplied with rstd to apply: + y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale + """ tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() @@ -1250,6 +1271,10 @@ def kernel( mean_sq = sum_sq / H rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + # Read global_scale from device memory (CUDA graph compatible) + # Note: global_scale is incorporated into the block scale, NOT applied to input + global_scale_val = mGlobalScale[0] + if cutlass.const_expr(cluster_n > 1): cute.arch.cluster_arrive_relaxed() cute.arch.cluster_wait() @@ -1341,13 +1366,19 @@ def kernel( max_abs = fmax_f32(max_abs, fabs_f32(y14)) max_abs = fmax_f32(max_abs, fabs_f32(y15)) - scale_float = max_abs * fp4_max_rcp + # E4M3: global_scale is incorporated into block scale + # Formula: scale = global_scale * max_abs / FP4_MAX + scale_float = global_scale_val * max_abs * fp4_max_rcp scale_float = fmin_f32( scale_float, Float32(FLOAT8_E4M3_MAX) ) scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) scale_fp8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) - inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + # inv_scale = global_scale / scale_float to cancel global_scale + inv_scale = ( + fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + * global_scale_val + ) if cutlass.const_expr(self.output_swizzled): inner_k_idx = sf_idx % Int32(4) @@ -1517,13 +1548,19 @@ def kernel( y12, y13 = bfloat2_to_float2_scaled(hw6, rstd) y14, y15 = bfloat2_to_float2_scaled(hw7, rstd) - scale_float = max_abs * fp4_max_rcp + # E4M3: global_scale is incorporated into block scale + # Formula: scale = global_scale * max_abs / FP4_MAX + scale_float = global_scale_val * max_abs * fp4_max_rcp scale_float = fmin_f32( scale_float, Float32(FLOAT8_E4M3_MAX) ) scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) scale_fp8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) - inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + # inv_scale = global_scale / scale_float to cancel global_scale + inv_scale = ( + fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + * global_scale_val + ) if cutlass.const_expr(self.output_swizzled): inner_k_idx = sf_idx % Int32(4) @@ -1712,19 +1749,27 @@ def kernel( max_abs = fmax_f32(max_abs, fabs_f32(y30)) max_abs = fmax_f32(max_abs, fabs_f32(y31)) + # Compute scale factor (E4M3 or UE8M0 based on scale_format) + # For E4M3: global_scale is incorporated into block scale + # For UE8M0 (MXFP4): global_scale is not used if cutlass.const_expr(self.scale_format == "ue8m0"): scale_float = max_abs * fp4_max_rcp scale_ue8m0 = cvt_f32_to_ue8m0(scale_float) scale_u8 = Uint8(scale_ue8m0 & Uint32(0xFF)) inv_scale = ue8m0_to_output_scale(scale_ue8m0) else: - scale_float = max_abs * fp4_max_rcp + # E4M3: scale = global_scale * max_abs / FP4_MAX + scale_float = global_scale_val * max_abs * fp4_max_rcp scale_float = fmin_f32( scale_float, Float32(FLOAT8_E4M3_MAX) ) scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) scale_u8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) - inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + # inv_scale = global_scale / scale_float to cancel global_scale + inv_scale = ( + fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + * global_scale_val + ) if cutlass.const_expr(self.output_swizzled): inner_k_idx = sf_idx % Int32(4) @@ -2028,19 +2073,27 @@ def kernel( y28, y29 = bfloat2_to_float2_scaled(hw14, rstd) y30, y31 = bfloat2_to_float2_scaled(hw15, rstd) + # Compute scale factor (E4M3 or UE8M0 based on scale_format) + # For E4M3: global_scale is incorporated into block scale + # For UE8M0 (MXFP4): global_scale is not used if cutlass.const_expr(self.scale_format == "ue8m0"): scale_float = max_abs * fp4_max_rcp scale_ue8m0 = cvt_f32_to_ue8m0(scale_float) scale_u8 = Uint8(scale_ue8m0 & Uint32(0xFF)) inv_scale = ue8m0_to_output_scale(scale_ue8m0) else: - scale_float = max_abs * fp4_max_rcp + # E4M3: scale = global_scale * max_abs / FP4_MAX + scale_float = global_scale_val * max_abs * fp4_max_rcp scale_float = fmin_f32( scale_float, Float32(FLOAT8_E4M3_MAX) ) scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) scale_u8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) - inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + # inv_scale = global_scale / scale_float to cancel global_scale + inv_scale = ( + fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + * global_scale_val + ) if cutlass.const_expr(self.output_swizzled): inner_k_idx = sf_idx % Int32(4) @@ -2156,8 +2209,11 @@ def get_cute_pointers(tensors): make_ptr( cutlass.Uint8, 16, cute.AddressSpace.gmem, assumed_align=16 ), # s + make_ptr( + cutlass.Float32, 16, cute.AddressSpace.gmem, assumed_align=4 + ), # global_scale ] - x, r, w, y, s = tensors + x, r, w, y, s, global_scale = tensors return [ make_ptr( cutlass_dtype, x.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 @@ -2174,6 +2230,12 @@ def get_cute_pointers(tensors): make_ptr( cutlass.Uint8, s.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ), + make_ptr( + cutlass.Float32, + global_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ), ] kernel_obj = AddRMSNormFP4QuantKernel( @@ -2200,13 +2262,14 @@ def tensor_api( w: torch.Tensor, y: torch.Tensor, s: torch.Tensor, + global_scale: torch.Tensor, M: int, eps: float, ) -> None: """Runtime API that converts tensors to pointers and calls the kernel.""" nonlocal compiled_kernel compiled_kernel( - *get_cute_pointers([x, r, w, y, s]), + *get_cute_pointers([x, r, w, y, s, global_scale]), Int32(M), Float32(eps), cutlass_torch.current_stream(), @@ -2220,17 +2283,19 @@ def add_rmsnorm_fp4quant( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, - y_fp4: torch.Tensor, - block_scale: torch.Tensor, + y_fp4: torch.Tensor | None = None, + block_scale: torch.Tensor | None = None, + global_scale: torch.Tensor | None = None, eps: float = 1e-6, block_size: int = 16, scale_format: str | None = None, is_sf_swizzled_layout: bool = False, -) -> None: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Fused Add + RMS normalization + FP4 quantization using CuTe-DSL. Computes: ``h = input + residual``, then ``y = RMSNorm(h) * weight``, + optionally applies global scaling (``y = y / global_scale``), and finally quantizes ``y`` to FP4. Parameters @@ -2243,11 +2308,12 @@ def add_rmsnorm_fp4quant( weight : torch.Tensor Weight tensor for RMSNorm, shape ``(hidden_size,)``. Must have the same dtype as input. - y_fp4 : torch.Tensor - Output tensor for quantized values in FP4_E2M1 format, packed as uint8. - Two FP4 values are packed into each uint8 byte. + y_fp4 : torch.Tensor, optional + Output tensor for quantized values in FP4_E2M1 format with dtype + ``torch.float4_e2m1fn_x2``. Shape must be ``(batch_size, hidden_size // 2)`` or matching 3D input. - block_scale : torch.Tensor + If ``None``, will be allocated automatically. + block_scale : torch.Tensor, optional Output tensor for per-block scale factors. - If ``is_sf_swizzled_layout=False`` (default): row-major layout with shape @@ -2258,7 +2324,14 @@ def add_rmsnorm_fp4quant( ``[m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)]``. Dtype should be ``torch.float8_e4m3fn`` for E4M3 format or ``torch.uint8`` - for UE8M0 format. + for UE8M0 format. If ``None``, will be allocated automatically. + global_scale : torch.Tensor, optional + Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``. + If provided, the RMSNorm output is divided by this value before quantization: + ``y = rmsnorm(h, w) / global_scale`` where ``h = input + residual``. + This is used for NVFP4 format where a pre-computed global scale lifts + per-block scales into optimal dynamic range. + If ``None``, no global scaling is applied (equivalent to global_scale=1.0). eps : float Epsilon for numerical stability in RMSNorm. Default is ``1e-6``. block_size : int @@ -2277,6 +2350,14 @@ def add_rmsnorm_fp4quant( where ``outer_m = row % 32``, ``inner_m = (row % 128) // 32``, etc. Default is ``False`` (row-major layout). + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple of ``(y_fp4, block_scale)``: + + - ``y_fp4``: Quantized FP4 values packed as uint8. + - ``block_scale``: Per-block scale factors. + Notes ----- - Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics. @@ -2287,15 +2368,13 @@ def add_rmsnorm_fp4quant( is_3d = input.dim() == 3 if is_3d: B, S, H = input.shape - input = input.view(B * S, H).contiguous() - residual = residual.view(B * S, H).contiguous() - y_fp4_2d = y_fp4.view(B * S, -1) - block_scale_2d = block_scale.view(B * S, -1) + input_2d = input.view(B * S, H).contiguous() + residual_2d = residual.view(B * S, H).contiguous() else: - y_fp4_2d = y_fp4 - block_scale_2d = block_scale + input_2d = input + residual_2d = residual - batch_size, hidden_size = input.shape + batch_size, hidden_size = input_2d.shape dtype = input.dtype assert hidden_size % block_size == 0, "hidden_size must be divisible by block_size" @@ -2308,6 +2387,65 @@ def add_rmsnorm_fp4quant( ) sm_version = get_sm_version(input.device) + # Allocate output tensors if not provided + if y_fp4 is None: + if is_3d: + y_fp4 = torch.empty( + (B, S, hidden_size // 2), + dtype=torch.float4_e2m1fn_x2, + device=input.device, + ) + else: + y_fp4 = torch.empty( + (batch_size, hidden_size // 2), + dtype=torch.float4_e2m1fn_x2, + device=input.device, + ) + + if block_scale is None: + # Determine scale dtype based on format + scale_dtype = ( + torch.uint8 if actual_scale_format == "ue8m0" else torch.float8_e4m3fn + ) + num_sf_blocks_per_row = hidden_size // block_size + + if is_sf_swizzled_layout: + # Swizzled layout: flattened with 128x4 tile pattern + num_m_tiles = (batch_size + 127) // 128 + num_k_tiles = (num_sf_blocks_per_row + 3) // 4 + k_tile_stride = 512 + swizzled_size = num_m_tiles * num_k_tiles * k_tile_stride + block_scale = torch.empty( + (swizzled_size,), dtype=scale_dtype, device=input.device + ) + else: + if is_3d: + block_scale = torch.empty( + (B, S, num_sf_blocks_per_row), + dtype=scale_dtype, + device=input.device, + ) + else: + block_scale = torch.empty( + (batch_size, num_sf_blocks_per_row), + dtype=scale_dtype, + device=input.device, + ) + + # Get 2D views for kernel + if is_3d: + y_fp4_2d = y_fp4.view(B * S, -1) + block_scale_2d = ( + block_scale.view(B * S, -1) if not is_sf_swizzled_layout else block_scale + ) + else: + y_fp4_2d = y_fp4 + block_scale_2d = block_scale + + # Create global_scale tensor if not provided (1.0 = no scaling) + if global_scale is None: + global_scale = torch.ones(1, dtype=torch.float32, device=input.device) + tensor_api = _get_compiled_kernel( hidden_size, block_size, @@ -2317,15 +2455,18 @@ def add_rmsnorm_fp4quant( is_sf_swizzled_layout, ) tensor_api( - input.contiguous(), - residual.contiguous(), + input_2d.contiguous(), + residual_2d.contiguous(), weight.contiguous(), y_fp4_2d, block_scale_2d.view(torch.uint8), + global_scale.contiguous(), batch_size, eps, ) + return y_fp4, block_scale + __all__ = [ "AddRMSNormFP4QuantKernel", diff --git a/flashinfer/cute_dsl/rmsnorm_fp4quant.py b/flashinfer/cute_dsl/rmsnorm_fp4quant.py index 71bd527278..bbb55b7651 100644 --- a/flashinfer/cute_dsl/rmsnorm_fp4quant.py +++ b/flashinfer/cute_dsl/rmsnorm_fp4quant.py @@ -1007,6 +1007,7 @@ def __call__( w_ptr: cute.Pointer, y_ptr: cute.Pointer, s_ptr: cute.Pointer, + global_scale_ptr: cute.Pointer, M: Int32, eps: Float32, stream: cuda.CUstream, @@ -1015,6 +1016,11 @@ def __call__( Takes raw pointers and batch size M, creates tensors internally. This avoids the overhead of from_dlpack() at runtime. + + Args: + global_scale_ptr: Pointer to global scale tensor (shape [1], float32). + The kernel reads this value and computes 1/global_scale to apply: + y = rmsnorm(x) / global_scale. Use tensor with value 1.0 for no scaling. """ H = self.H @@ -1054,6 +1060,12 @@ def __call__( ), ) + # Create global scale tensor (scalar) + mGlobalScale = cute.make_tensor( + global_scale_ptr, + layout=cute.make_layout((1,)), + ) + # Create TV layout tv_shape, tv_stride = self._make_tv_layout( self.threads_per_row, @@ -1065,7 +1077,7 @@ def __call__( tiler_mn = (self.rows_per_block, self.cols_per_tile) # Launch with cluster support - self.kernel(mX, mW, mY, mS, M, eps, tv_layout, tiler_mn).launch( + self.kernel(mX, mW, mY, mS, mGlobalScale, M, eps, tv_layout, tiler_mn).launch( grid=[cute.ceil_div(M, self.rows_per_block), self.cluster_n, 1], block=[self.num_threads, 1, 1], cluster=[1, self.cluster_n, 1] @@ -1082,12 +1094,18 @@ def kernel( mW: cute.Tensor, mY: cute.Tensor, mS: cute.Tensor, + mGlobalScale: cute.Tensor, M: Int32, eps: Float32, tv_layout: cute.Layout, tiler_mn: cute.Shape, ): - """Device kernel with cluster synchronization for large H.""" + """Device kernel with cluster synchronization for large H. + + mGlobalScale contains the global scale value. The kernel reads it and + computes 1/global_scale, which is multiplied with rstd to apply: + y = x * rstd * w / global_scale = rmsnorm(x, w) / global_scale + """ tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() @@ -1231,6 +1249,10 @@ def kernel( mean_sq = sum_sq / H # Use full H, not H_per_cta rstd = cute.math.rsqrt(mean_sq + eps, fastmath=True) + # Read global_scale from device memory (CUDA graph compatible) + # Note: global_scale is incorporated into the block scale, NOT applied to input + global_scale_val = mGlobalScale[0] + # Sync after reduction if cutlass.const_expr(cluster_n > 1): cute.arch.cluster_arrive_relaxed() @@ -1362,14 +1384,20 @@ def kernel( # ======================================================= # Compute scale factor (FP8 E4M3) - Branchless clamping + # global_scale is incorporated into block scale (not input) + # Formula: scale = global_scale * max_abs / FP4_MAX # ======================================================= - scale_float = max_abs * fp4_max_rcp + scale_float = global_scale_val * max_abs * fp4_max_rcp scale_float = fmin_f32(scale_float, Float32(FLOAT8_E4M3_MAX)) scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) scale_fp8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) - inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + # inv_scale = global_scale / scale_float + # This cancels the global_scale in scale_float, giving ~6/max_abs + inv_scale = ( + fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) * global_scale_val + ) # ======================================================= # Store scale factor @@ -1604,19 +1632,27 @@ def kernel( y14_c1, y15_c1 = bfloat2_to_float2_scaled(xw7_c1, rstd) # Compute scale factor (E4M3 or UE8M0 based on scale_format) + # For E4M3: global_scale is incorporated into block scale + # For UE8M0 (MXFP4): global_scale is not used if cutlass.const_expr(self.scale_format == "ue8m0"): scale_float = max_abs * fp4_max_rcp scale_ue8m0 = cvt_f32_to_ue8m0(scale_float) scale_u8 = Uint8(scale_ue8m0 & Uint32(0xFF)) inv_scale = ue8m0_to_output_scale(scale_ue8m0) else: - scale_float = max_abs * fp4_max_rcp + # E4M3: scale = global_scale * max_abs / FP4_MAX + scale_float = global_scale_val * max_abs * fp4_max_rcp scale_float = fmin_f32( scale_float, Float32(FLOAT8_E4M3_MAX) ) scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) scale_u8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) - inv_scale = fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + # inv_scale = global_scale / scale_float + # This cancels the global_scale in scale_float + inv_scale = ( + fp8_e4m3_to_f32_and_rcp(scale_fp8_u32) + * global_scale_val + ) if cutlass.const_expr(self.output_swizzled): inner_k_idx = sf_idx % Int32(4) @@ -1729,8 +1765,11 @@ def get_cute_pointers(tensors): make_ptr( cutlass.Uint8, 16, cute.AddressSpace.gmem, assumed_align=16 ), # s + make_ptr( + cutlass.Float32, 16, cute.AddressSpace.gmem, assumed_align=4 + ), # global_scale ] - x, w, y, s = tensors + x, w, y, s, global_scale = tensors return [ make_ptr( cutlass_dtype, x.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 @@ -1744,6 +1783,12 @@ def get_cute_pointers(tensors): make_ptr( cutlass.Uint8, s.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ), + make_ptr( + cutlass.Float32, + global_scale.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=4, + ), ] # Create kernel instance @@ -1771,13 +1816,14 @@ def tensor_api( w: torch.Tensor, y: torch.Tensor, s: torch.Tensor, + global_scale: torch.Tensor, M: int, eps: float, ) -> None: """Runtime API that converts tensors to pointers and calls the kernel.""" nonlocal compiled_kernel compiled_kernel( - *get_cute_pointers([x, w, y, s]), + *get_cute_pointers([x, w, y, s, global_scale]), Int32(M), Float32(eps), cutlass_torch.current_stream(), @@ -1790,17 +1836,19 @@ def tensor_api( def rmsnorm_fp4quant( input: torch.Tensor, weight: torch.Tensor, - y_fp4: torch.Tensor, - block_scale: torch.Tensor, + y_fp4: torch.Tensor | None = None, + block_scale: torch.Tensor | None = None, + global_scale: torch.Tensor | None = None, eps: float = 1e-6, block_size: int = 16, scale_format: str | None = None, is_sf_swizzled_layout: bool = False, -) -> None: +) -> Tuple[torch.Tensor, torch.Tensor]: """ Fused RMS normalization with FP4 quantization using CuTe-DSL. - Computes: ``y = RMSNorm(input) * weight``, then quantizes ``y`` to FP4. + Computes: ``y = RMSNorm(input) * weight``, optionally applies global scaling + (``y = y / global_scale``), then quantizes ``y`` to FP4. Parameters ---------- @@ -1810,11 +1858,12 @@ def rmsnorm_fp4quant( weight : torch.Tensor Weight tensor for RMSNorm, shape ``(hidden_size,)``. Must have the same dtype as input. - y_fp4 : torch.Tensor - Output tensor for quantized values in FP4_E2M1 format, packed as uint8. - Two FP4 values are packed into each uint8 byte. + y_fp4 : torch.Tensor, optional + Output tensor for quantized values in FP4_E2M1 format with dtype + ``torch.float4_e2m1fn_x2``. Shape must be ``(batch_size, hidden_size // 2)`` or matching 3D input. - block_scale : torch.Tensor + If ``None``, will be allocated automatically. + block_scale : torch.Tensor, optional Output tensor for per-block scale factors. - If ``is_sf_swizzled_layout=False`` (default): row-major layout with shape @@ -1825,7 +1874,13 @@ def rmsnorm_fp4quant( ``[m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)]``. Dtype should be ``torch.float8_e4m3fn`` for E4M3 format or ``torch.uint8`` - for UE8M0 format. + for UE8M0 format. If ``None``, will be allocated automatically. + global_scale : torch.Tensor, optional + Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``. + If provided, the RMSNorm output is divided by this value before quantization: + ``y = rmsnorm(x, w) / global_scale``. This is used for NVFP4 format where + a pre-computed global scale lifts per-block scales into optimal dynamic range. + If ``None``, no global scaling is applied (equivalent to global_scale=1.0). eps : float Epsilon for numerical stability in RMSNorm. Default is ``1e-6``. block_size : int @@ -1844,6 +1899,14 @@ def rmsnorm_fp4quant( where ``outer_m = row % 32``, ``inner_m = (row % 128) // 32``, etc. Default is ``False`` (row-major layout). + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + A tuple of ``(y_fp4, block_scale)``: + + - ``y_fp4``: Quantized FP4 values packed as uint8. + - ``block_scale``: Per-block scale factors. + Notes ----- - Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics. @@ -1855,14 +1918,11 @@ def rmsnorm_fp4quant( is_3d = input.dim() == 3 if is_3d: B, S, H = input.shape - input = input.view(B * S, H).contiguous() - y_fp4_2d = y_fp4.view(B * S, -1) - block_scale_2d = block_scale.view(B * S, -1) + input_2d = input.view(B * S, H).contiguous() else: - y_fp4_2d = y_fp4 - block_scale_2d = block_scale + input_2d = input - batch_size, hidden_size = input.shape + batch_size, hidden_size = input_2d.shape dtype = input.dtype assert hidden_size % block_size == 0, "hidden_size must be divisible by block_size" @@ -1876,6 +1936,65 @@ def rmsnorm_fp4quant( ) sm_version = get_sm_version(input.device) + # Allocate output tensors if not provided + if y_fp4 is None: + if is_3d: + y_fp4 = torch.empty( + (B, S, hidden_size // 2), + dtype=torch.float4_e2m1fn_x2, + device=input.device, + ) + else: + y_fp4 = torch.empty( + (batch_size, hidden_size // 2), + dtype=torch.float4_e2m1fn_x2, + device=input.device, + ) + + if block_scale is None: + # Determine scale dtype based on format + scale_dtype = ( + torch.uint8 if actual_scale_format == "ue8m0" else torch.float8_e4m3fn + ) + num_sf_blocks_per_row = hidden_size // block_size + + if is_sf_swizzled_layout: + # Swizzled layout: flattened with 128x4 tile pattern + num_m_tiles = (batch_size + 127) // 128 + num_k_tiles = (num_sf_blocks_per_row + 3) // 4 + k_tile_stride = 512 + swizzled_size = num_m_tiles * num_k_tiles * k_tile_stride + block_scale = torch.empty( + (swizzled_size,), dtype=scale_dtype, device=input.device + ) + else: + if is_3d: + block_scale = torch.empty( + (B, S, num_sf_blocks_per_row), + dtype=scale_dtype, + device=input.device, + ) + else: + block_scale = torch.empty( + (batch_size, num_sf_blocks_per_row), + dtype=scale_dtype, + device=input.device, + ) + + # Get 2D views for kernel + if is_3d: + y_fp4_2d = y_fp4.view(B * S, -1) + block_scale_2d = ( + block_scale.view(B * S, -1) if not is_sf_swizzled_layout else block_scale + ) + else: + y_fp4_2d = y_fp4 + block_scale_2d = block_scale + + # Create global_scale tensor if not provided (1.0 = no scaling) + if global_scale is None: + global_scale = torch.ones(1, dtype=torch.float32, device=input.device) + # Get cached tensor_api and call it directly tensor_api = _get_compiled_kernel( hidden_size, @@ -1886,14 +2005,17 @@ def rmsnorm_fp4quant( is_sf_swizzled_layout, ) tensor_api( - input.contiguous(), + input_2d.contiguous(), weight.contiguous(), y_fp4_2d, block_scale_2d.view(torch.uint8), + global_scale.contiguous(), batch_size, eps, ) + return y_fp4, block_scale + __all__ = [ "RMSNormFP4QuantKernel", diff --git a/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py b/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py index 7e97e6fd86..579a8ef0e0 100644 --- a/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py +++ b/tests/norm/test_add_rmsnorm_fp4_quant_cute_dsl.py @@ -40,10 +40,20 @@ def llama_rms_norm(x, w, eps=1e-6): def dequantize_fp4_output( - y_fp4: torch.Tensor, block_scale: torch.Tensor, block_size: int + y_fp4: torch.Tensor, + block_scale: torch.Tensor, + block_size: int, + global_scale: torch.Tensor | None = None, ): - """Dequantize packed FP4 tensor using the associated block scales.""" - y_fp4_float = cast_from_fp4(y_fp4) + """ + Dequantize packed FP4 tensor using the associated block scales. + + If global_scale is provided, the dequantized values are divided by global_scale + to reverse the scaling applied during quantization. + """ + # View as uint8 for bitwise operations in cast_from_fp4 + # (float4_e2m1fn_x2 and uint8 have the same memory layout) + y_fp4_float = cast_from_fp4(y_fp4.view(torch.uint8)) if y_fp4_float.dim() == 2: b, hidden_size = y_fp4_float.shape assert hidden_size % block_size == 0 @@ -52,7 +62,7 @@ def dequantize_fp4_output( scales = torch.pow(2.0, block_scale.int() - 127).unsqueeze(-1) else: scales = block_scale.float().unsqueeze(-1) - return (y_fp4_float * scales).reshape(b, hidden_size) + result = (y_fp4_float * scales).reshape(b, hidden_size) elif y_fp4_float.dim() == 3: b, s, hidden_size = y_fp4_float.shape assert hidden_size % block_size == 0 @@ -61,10 +71,82 @@ def dequantize_fp4_output( scales = torch.pow(2.0, block_scale.int() - 127).unsqueeze(-1) else: scales = block_scale.float().unsqueeze(-1) - return (y_fp4_float * scales).reshape(b, s, hidden_size) + result = (y_fp4_float * scales).reshape(b, s, hidden_size) else: raise ValueError(f"Unsupported FP4 output rank: {y_fp4_float.dim()}") + # Reverse global scale if it was applied during quantization + if global_scale is not None: + result = result / global_scale.item() + + return result + + +def compute_global_scale( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Compute global scale for NVFP4 quantization of add+rmsnorm output. + + global_scale = (FP8_E4M3_MAX * FP4_E2M1_MAX) / max_abs(rmsnorm(x + residual, weight)) + + This ensures the dynamic range of the output fits within the FP4 range. + """ + FLOAT4_E2M1_MAX = 6.0 + FLOAT8_E4M3_MAX = float(torch.finfo(torch.float8_e4m3fn).max) + + # Compute reference add+RMSNorm output + h = x + residual + ref_output = llama_rms_norm(h, weight, eps=eps) + tensor_amax = torch.abs(ref_output).max().to(torch.float32) + global_scale = torch.tensor( + [FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax.item()], + dtype=torch.float32, + device=x.device, + ) + return global_scale + + +def assert_close_with_tiered_tolerance( + actual: torch.Tensor, + expected: torch.Tensor, + tight_rtol: float = 0.1, + tight_atol: float = 0.1, + loose_rtol: float = 0.5, + loose_atol: float = 2.0, + tight_pct: float = 0.99, + msg: str = "", +): + """ + Two-tiered tolerance check for quantized outputs. + + - tight_pct (e.g., 99%) of elements must be within tight tolerance + - 100% of elements must be within loose tolerance + + This handles the expected quantization noise where most elements match closely + but a few outliers may differ more due to rounding boundary effects. + """ + diff = (actual - expected).abs() + rel_diff = diff / (expected.abs() + 1e-8) + + # Check 1: tight_pct of elements within tight tolerance + within_tight = (diff <= tight_atol) | (rel_diff <= tight_rtol) + tight_pct_actual = within_tight.float().mean().item() + assert tight_pct_actual >= tight_pct, ( + f"{msg}: Only {tight_pct_actual * 100:.1f}% of elements within tight tolerance " + f"(rtol={tight_rtol}, atol={tight_atol}), expected {tight_pct * 100:.0f}%" + ) + + # Check 2: 100% of elements within loose tolerance + within_loose = (diff <= loose_atol) | (rel_diff <= loose_rtol) + if not within_loose.all(): + max_diff = diff.max().item() + max_rel = rel_diff.max().item() + raise AssertionError( + f"{msg}: Max diff {max_diff:.4f} (rel: {max_rel:.4f}) exceeds loose tolerance " + f"(rtol={loose_rtol}, atol={loose_atol})" + ) + def requires_cute_dsl(): """Check if CuTe-DSL is available.""" @@ -113,7 +195,7 @@ def test_add_rmsnorm_fp4quant_2d(self, batch_size, hidden_size, dtype, eps): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, @@ -129,7 +211,7 @@ def test_add_rmsnorm_fp4quant_2d(self, batch_size, hidden_size, dtype, eps): # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.float8_e4m3fn # Reference computation: h = x + r, then RMSNorm(h) @@ -139,11 +221,14 @@ def test_add_rmsnorm_fp4quant_2d(self, batch_size, hidden_size, dtype, eps): # Dequantize FP4 output for value-level comparison # Tolerance based on separate FP4 roundtrip test (rtol=0.3, atol=0.5) y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) - torch.testing.assert_close( + assert_close_with_tiered_tolerance( y_dequant, ref_rmsnorm.float(), - rtol=0.3, - atol=0.5, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, ) @pytest.mark.parametrize("batch_size", [1, 4, 3, 7, 128]) @@ -165,7 +250,11 @@ def test_add_rmsnorm_fp4quant_3d(self, batch_size, seq_len, hidden_size, dtype): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, seq_len, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, + seq_len, + hidden_size // 2, + device="cuda", + dtype=torch.float4_e2m1fn_x2, ) block_scale = torch.empty( batch_size, @@ -182,7 +271,7 @@ def test_add_rmsnorm_fp4quant_3d(self, batch_size, seq_len, hidden_size, dtype): # Verify output shapes assert y_fp4.shape == (batch_size, seq_len, hidden_size // 2) assert block_scale.shape == (batch_size, seq_len, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.float8_e4m3fn # Reference computation @@ -191,11 +280,14 @@ def test_add_rmsnorm_fp4quant_3d(self, batch_size, seq_len, hidden_size, dtype): # Tolerance based on separate FP4 roundtrip test (rtol=0.3, atol=0.5) y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) - torch.testing.assert_close( + assert_close_with_tiered_tolerance( y_dequant, ref_rmsnorm.float(), - rtol=0.3, - atol=0.5, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, ) @pytest.mark.parametrize( @@ -221,7 +313,7 @@ def test_large_batch(self, batch_size, hidden_size, dtype): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, @@ -270,7 +362,7 @@ def test_mxfp4_basic(self, batch_size, hidden_size, dtype): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) # UE8M0 scale factors are returned as uint8 block_scale = torch.empty( @@ -291,7 +383,7 @@ def test_mxfp4_basic(self, batch_size, hidden_size, dtype): # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.uint8 # Reference computation @@ -339,7 +431,7 @@ def test_fused_vs_separate(self, batch_size, hidden_size, dtype): # Fused kernel y_fp4_fused = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_fused = torch.empty( batch_size, @@ -366,14 +458,293 @@ def test_fused_vs_separate(self, batch_size, hidden_size, dtype): ) # Value-level comparison against reference RMSNorm output - torch.testing.assert_close( + assert_close_with_tiered_tolerance( y_fused_dequant, y_ref.float(), - rtol=0.3, - atol=0.5, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, ) +@cute_dsl_available +@blackwell_required +class TestFusedVsSeparateFP4Quantize: + """ + Tests comparing fused Add+RMSNorm+FP4Quant against separate add + RMSNorm + fp4_quantize. + + This validates that the fused kernel applies global_scale identically to the + standalone fp4_quantize function. + """ + + @pytest.mark.parametrize("batch_size", [1, 4, 16, 128]) + @pytest.mark.parametrize("hidden_size", [64, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_nvfp4_fused_matches_separate(self, batch_size, hidden_size, dtype): + """ + Compare fused kernel against separate add + RMSNorm + fp4_quantize for NVFP4. + + This test verifies that the fused kernel applies global_scale identically + to the standalone fp4_quantize function, by comparing: + 1. The packed FP4 output bytes + 2. The block scale factors + """ + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + from flashinfer import fp4_quantize + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, r, weight, eps=eps) + + # === Fused kernel path === + y_fp4_fused = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_fused = torch.empty( + batch_size, + hidden_size // block_size, + device="cuda", + dtype=torch.float8_e4m3fn, + ) + add_rmsnorm_fp4quant( + x, + r, + weight, + y_fp4_fused, + block_scale_fused, + global_scale=global_scale, + eps=eps, + block_size=block_size, + is_sf_swizzled_layout=False, # Use unswizzled for easier comparison + ) + + # === Separate path: add + RMSNorm + fp4_quantize === + h = x + r + y_rmsnorm = llama_rms_norm(h, weight, eps=eps) + y_fp4_separate, block_scale_separate = fp4_quantize( + y_rmsnorm, + global_scale, + sf_vec_size=block_size, + sf_use_ue8m0=False, # E4M3 for NVFP4 + is_sf_swizzled_layout=False, + ) + + # === Compare FP4 packed outputs === + # View as uint8 for comparison (float4_e2m1fn_x2 doesn't support == operator) + fp4_match = ( + (y_fp4_fused.view(torch.uint8) == y_fp4_separate.view(torch.uint8)) + .float() + .mean() + .item() + ) + assert fp4_match > 0.95, ( + f"FP4 output mismatch: only {fp4_match * 100:.1f}% of bytes match" + ) + + # === Compare block scales === + scale_fused = block_scale_fused.to(torch.float32) + scale_separate = ( + block_scale_separate.view(torch.float8_e4m3fn) + .view(batch_size, -1) + .to(torch.float32) + ) + + scale_match = (scale_fused == scale_separate).float().mean().item() + assert scale_match > 0.95, ( + f"Block scale mismatch: only {scale_match * 100:.1f}% of scales match" + ) + + # === Also verify dequantized values are close === + y_fused_dequant = dequantize_fp4_output( + y_fp4_fused, block_scale_fused, block_size, global_scale + ) + y_separate_dequant = dequantize_fp4_output( + y_fp4_separate, + block_scale_separate.view(torch.float8_e4m3fn).view(batch_size, -1), + block_size, + global_scale, + ) + + # Two-tiered tolerance: 99% within tight tolerance, 100% within loose tolerance + assert_close_with_tiered_tolerance( + y_fused_dequant, + y_separate_dequant, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, + msg="Dequantized outputs from fused and separate paths should match closely", + ) + + @pytest.mark.parametrize("batch_size", [1, 4, 16, 128]) + @pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_mxfp4_fused_matches_separate(self, batch_size, hidden_size, dtype): + """ + Compare fused kernel against separate add + RMSNorm + fp4_quantize for MXFP4. + + MXFP4 uses block_size=32, UE8M0 scales, and no global_scale (global_scale=1.0). + """ + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + from flashinfer import fp4_quantize + + torch.manual_seed(42) + block_size = 32 # MXFP4 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # MXFP4 uses global_scale=1.0 + global_scale_val = torch.tensor(1.0, dtype=torch.float32, device="cuda") + + # === Fused kernel path === + y_fp4_fused = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_fused = torch.empty( + batch_size, hidden_size // block_size, device="cuda", dtype=torch.uint8 + ) + add_rmsnorm_fp4quant( + x, + r, + weight, + y_fp4_fused, + block_scale_fused, + eps=eps, + block_size=block_size, + scale_format="ue8m0", + is_sf_swizzled_layout=False, + ) + + # === Separate path: add + RMSNorm + fp4_quantize === + h = x + r + y_rmsnorm = llama_rms_norm(h, weight, eps=eps) + y_fp4_separate, block_scale_separate = fp4_quantize( + y_rmsnorm, + global_scale_val, + sf_vec_size=block_size, + sf_use_ue8m0=True, # UE8M0 for MXFP4 + is_sf_swizzled_layout=False, + ) + + # === Compare FP4 packed outputs === + # View as uint8 for comparison (float4_e2m1fn_x2 doesn't support == operator) + fp4_match = ( + (y_fp4_fused.view(torch.uint8) == y_fp4_separate.view(torch.uint8)) + .float() + .mean() + .item() + ) + assert fp4_match > 0.95, ( + f"FP4 output mismatch: only {fp4_match * 100:.1f}% of bytes match" + ) + + # === Compare block scales === + scale_fused = block_scale_fused + scale_separate = block_scale_separate.view(batch_size, -1) + + scale_match = (scale_fused == scale_separate).float().mean().item() + assert scale_match > 0.95, ( + f"Block scale mismatch: only {scale_match * 100:.1f}% of scales match" + ) + + # === Also verify dequantized values are close === + # MXFP4 has larger errors due to power-of-2 scale constraints + y_fused_dequant = dequantize_fp4_output( + y_fp4_fused, block_scale_fused, block_size + ) + y_separate_dequant = dequantize_fp4_output( + y_fp4_separate, scale_separate, block_size + ) + + # Two-tiered tolerance: 99% within tight tolerance, 100% within loose tolerance + assert_close_with_tiered_tolerance( + y_fused_dequant, + y_separate_dequant, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, + msg="Dequantized outputs from fused and separate paths should match closely", + ) + + @pytest.mark.parametrize("batch_size", [1, 16, 64]) + @pytest.mark.parametrize("hidden_size", [256, 1024, 4096]) + def test_global_scale_value_consistency(self, batch_size, hidden_size): + """ + Verify that the global_scale value correctly scales the block scales. + + When global_scale is applied: + - block_scale_with_gs = global_scale * max_abs / FP4_MAX + - This should be approximately global_scale times larger than without global_scale + """ + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + dtype = torch.float16 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Run with computed global_scale + global_scale = compute_global_scale(x, r, weight, eps=eps) + + y_fp4_gs, block_scale_gs = add_rmsnorm_fp4quant( + x, + r, + weight, + global_scale=global_scale, + eps=eps, + block_size=block_size, + is_sf_swizzled_layout=False, + ) + + # Run without global_scale (global_scale=1.0) + global_scale_one = torch.tensor([1.0], dtype=torch.float32, device="cuda") + + y_fp4_no_gs, block_scale_no_gs = add_rmsnorm_fp4quant( + x, + r, + weight, + global_scale=global_scale_one, + eps=eps, + block_size=block_size, + is_sf_swizzled_layout=False, + ) + + # The block scales with global_scale should be approximately global_scale times + # larger than without (since block_scale = global_scale * max_abs / FP4_MAX) + scale_gs = block_scale_gs.to(torch.float32) + scale_no_gs = block_scale_no_gs.to(torch.float32) + + # Compute ratio where both are non-zero + non_zero_mask = (scale_no_gs > 0) & (scale_gs > 0) + if non_zero_mask.sum() > 0: + ratio = (scale_gs[non_zero_mask] / scale_no_gs[non_zero_mask]).mean().item() + expected_ratio = global_scale.item() + + # Allow some tolerance due to FP8 quantization + assert abs(ratio - expected_ratio) / expected_ratio < 0.2, ( + f"Block scale ratio {ratio:.2f} doesn't match expected global_scale {expected_ratio:.2f}" + ) + + @cute_dsl_available @blackwell_required class TestLargeHiddenSize: @@ -402,7 +773,7 @@ def test_large_hidden_nvfp4(self, batch_size, hidden_size, dtype): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, @@ -419,7 +790,7 @@ def test_large_hidden_nvfp4(self, batch_size, hidden_size, dtype): # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.float8_e4m3fn # Sample first few rows for value comparison (full dequant is slow) @@ -455,7 +826,7 @@ def test_large_hidden_mxfp4(self, batch_size, hidden_size, dtype): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, hidden_size // block_size, device="cuda", dtype=torch.uint8 @@ -476,7 +847,7 @@ def test_large_hidden_mxfp4(self, batch_size, hidden_size, dtype): # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.uint8 # Sample first few rows for value comparison (full dequant is slow) @@ -555,7 +926,7 @@ def test_nvfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): # Non-swizzled output y_fp4_ref = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_ref = torch.empty( batch_size, @@ -570,7 +941,7 @@ def test_nvfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): num_k_tiles = (hidden_size + factor - 1) // factor swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4 # 128x4 tile pattern y_fp4_swizzled = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_swizzled = torch.empty( swizzled_size, device="cuda", dtype=torch.float8_e4m3fn @@ -601,8 +972,10 @@ def test_nvfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size ).view(torch.float8_e4m3fn) - # FP4 values should be identical - torch.testing.assert_close(y_fp4_swizzled, y_fp4_ref) + # FP4 values should be identical (view as uint8 for comparison) + torch.testing.assert_close( + y_fp4_swizzled.view(torch.uint8), y_fp4_ref.view(torch.uint8) + ) # Scale factors should match after unswizzling torch.testing.assert_close( @@ -628,7 +1001,7 @@ def test_mxfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): # Non-swizzled output y_fp4_ref = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_ref = torch.empty( batch_size, hidden_size // block_size, device="cuda", dtype=torch.uint8 @@ -640,7 +1013,7 @@ def test_mxfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): num_k_tiles = (hidden_size + factor - 1) // factor swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4 # 128x4 tile pattern y_fp4_swizzled = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_swizzled = torch.empty( swizzled_size, device="cuda", dtype=torch.uint8 @@ -671,12 +1044,244 @@ def test_mxfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): block_scale_swizzled, batch_size, hidden_size, block_size ) - # FP4 values should be identical - torch.testing.assert_close(y_fp4_swizzled, y_fp4_ref) + # FP4 values should be identical (view as uint8 for comparison) + torch.testing.assert_close( + y_fp4_swizzled.view(torch.uint8), y_fp4_ref.view(torch.uint8) + ) # Scale factors should match after unswizzling torch.testing.assert_close(block_scale_unswizzled, block_scale_ref) +@cute_dsl_available +@blackwell_required +class TestAutoAllocation: + """Tests for automatic output tensor allocation when y_fp4 and block_scale are None.""" + + @pytest.mark.parametrize("batch_size", [1, 16, 128]) + @pytest.mark.parametrize("hidden_size", [256, 1024, 4096]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_auto_allocation_2d_nvfp4(self, batch_size, hidden_size, dtype): + """Test auto-allocation with 2D input and NVFP4 format.""" + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Call without providing y_fp4 and block_scale + y_fp4, block_scale = add_rmsnorm_fp4quant( + x, r, weight, eps=eps, block_size=block_size + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, hidden_size // 2) + assert block_scale.shape == (batch_size, hidden_size // block_size) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.float8_e4m3fn + + # Reference computation + h = x + r + ref_rmsnorm = llama_rms_norm(h, weight, eps=eps) + + # Dequantize and verify values + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) + torch.testing.assert_close( + y_dequant, + ref_rmsnorm.float(), + rtol=0.3, + atol=0.5, + ) + + @pytest.mark.parametrize("batch_size", [1, 4, 16]) + @pytest.mark.parametrize("seq_len", [16, 64]) + @pytest.mark.parametrize("hidden_size", [256, 1024]) + @pytest.mark.parametrize("dtype", [torch.float16]) + def test_auto_allocation_3d_nvfp4(self, batch_size, seq_len, hidden_size, dtype): + """Test auto-allocation with 3D input and NVFP4 format.""" + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 + eps = 1e-6 + + x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Call without providing y_fp4 and block_scale + y_fp4, block_scale = add_rmsnorm_fp4quant( + x, r, weight, eps=eps, block_size=block_size + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, seq_len, hidden_size // 2) + assert block_scale.shape == (batch_size, seq_len, hidden_size // block_size) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.float8_e4m3fn + + # Reference computation + h = x + r + ref_rmsnorm = llama_rms_norm(h, weight, eps=eps) + + # Dequantize and verify values + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) + torch.testing.assert_close( + y_dequant, + ref_rmsnorm.float(), + rtol=0.3, + atol=0.5, + ) + + @pytest.mark.parametrize("batch_size", [1, 16, 128]) + @pytest.mark.parametrize("hidden_size", [256, 1024]) + @pytest.mark.parametrize("dtype", [torch.float16]) + def test_auto_allocation_mxfp4(self, batch_size, hidden_size, dtype): + """Test auto-allocation with MXFP4 format (block_size=32, UE8M0 scales).""" + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 32 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Call without providing y_fp4 and block_scale + y_fp4, block_scale = add_rmsnorm_fp4quant( + x, r, weight, eps=eps, block_size=block_size, scale_format="ue8m0" + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, hidden_size // 2) + assert block_scale.shape == (batch_size, hidden_size // block_size) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.uint8 # UE8M0 uses uint8 + + # Reference computation + h = x + r + ref_rmsnorm = llama_rms_norm(h, weight, eps=eps) + + # Dequantize and verify values + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) + torch.testing.assert_close( + y_dequant, + ref_rmsnorm.float(), + rtol=0.3, + atol=0.7, + ) + + @pytest.mark.parametrize("batch_size", [16, 128]) + @pytest.mark.parametrize("hidden_size", [512, 1024]) + @pytest.mark.parametrize("dtype", [torch.float16]) + def test_auto_allocation_swizzled(self, batch_size, hidden_size, dtype): + """Test auto-allocation with swizzled scale factor layout.""" + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Call without providing y_fp4 and block_scale, with swizzled layout + y_fp4, block_scale = add_rmsnorm_fp4quant( + x, r, weight, eps=eps, block_size=block_size, is_sf_swizzled_layout=True + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, hidden_size // 2) + # Swizzled layout has different shape + factor = block_size * 4 + num_m_tiles = (batch_size + 127) // 128 + num_k_tiles = (hidden_size + factor - 1) // factor + expected_swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4 + assert block_scale.shape == (expected_swizzled_size,) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.float8_e4m3fn + + # Unswizzle and compare with non-swizzled version + y_fp4_ref = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_ref = torch.empty( + batch_size, + hidden_size // block_size, + device="cuda", + dtype=torch.float8_e4m3fn, + ) + add_rmsnorm_fp4quant( + x, r, weight, y_fp4_ref, block_scale_ref, eps=eps, block_size=block_size + ) + + # FP4 values should be identical (view as uint8 for comparison) + torch.testing.assert_close(y_fp4.view(torch.uint8), y_fp4_ref.view(torch.uint8)) + + # Unswizzle and compare scales + block_scale_unswizzled = unswizzle_sf( + block_scale.view(torch.uint8), batch_size, hidden_size, block_size + ).view(torch.float8_e4m3fn) + torch.testing.assert_close( + block_scale_unswizzled.view(torch.uint8), block_scale_ref.view(torch.uint8) + ) + + @pytest.mark.parametrize("batch_size", [16, 128]) + @pytest.mark.parametrize("hidden_size", [512, 1024]) + def test_auto_allocation_matches_preallocated(self, batch_size, hidden_size): + """Test that auto-allocation produces same results as pre-allocated tensors.""" + from flashinfer.cute_dsl.add_rmsnorm_fp4quant import add_rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 + eps = 1e-6 + dtype = torch.float16 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + r = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Pre-allocated version + y_fp4_pre = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_pre = torch.empty( + batch_size, + hidden_size // block_size, + device="cuda", + dtype=torch.float8_e4m3fn, + ) + add_rmsnorm_fp4quant( + x, r, weight, y_fp4_pre, block_scale_pre, eps=eps, block_size=block_size + ) + + # Auto-allocated version + y_fp4_auto, block_scale_auto = add_rmsnorm_fp4quant( + x, r, weight, eps=eps, block_size=block_size + ) + + # Results should be identical (view as uint8 for comparison) + torch.testing.assert_close( + y_fp4_auto.view(torch.uint8), y_fp4_pre.view(torch.uint8) + ) + torch.testing.assert_close( + block_scale_auto.view(torch.uint8), block_scale_pre.view(torch.uint8) + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py b/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py index e196098f56..4a8b439d49 100644 --- a/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py +++ b/tests/norm/test_rmsnorm_fp4_quant_cute_dsl.py @@ -40,14 +40,22 @@ def llama_rms_norm(x, w, eps=1e-6): def dequantize_fp4_output( - y_fp4: torch.Tensor, block_scale: torch.Tensor, block_size: int + y_fp4: torch.Tensor, + block_scale: torch.Tensor, + block_size: int, + global_scale: torch.Tensor | None = None, ): """ Dequantize packed FP4 tensor using the associated block scales. Handles both 2D inputs shaped [B, H/2] and 3D inputs shaped [B, S, H/2]. + + If global_scale is provided, the dequantized values are divided by global_scale + to reverse the scaling applied during quantization. """ - y_fp4_float = cast_from_fp4(y_fp4) + # View as uint8 for bitwise operations in cast_from_fp4 + # (float4_e2m1fn_x2 and uint8 have the same memory layout) + y_fp4_float = cast_from_fp4(y_fp4.view(torch.uint8)) if y_fp4_float.dim() == 2: b, hidden_size = y_fp4_float.shape assert hidden_size % block_size == 0 @@ -58,7 +66,7 @@ def dequantize_fp4_output( scales = torch.pow(2.0, block_scale.int() - 127).unsqueeze(-1) else: scales = block_scale.float().unsqueeze(-1) - return (y_fp4_float * scales).reshape(b, hidden_size) + result = (y_fp4_float * scales).reshape(b, hidden_size) elif y_fp4_float.dim() == 3: b, s, hidden_size = y_fp4_float.shape assert hidden_size % block_size == 0 @@ -67,10 +75,84 @@ def dequantize_fp4_output( scales = torch.pow(2.0, block_scale.int() - 127).unsqueeze(-1) else: scales = block_scale.float().unsqueeze(-1) - return (y_fp4_float * scales).reshape(b, s, hidden_size) + result = (y_fp4_float * scales).reshape(b, s, hidden_size) else: raise ValueError(f"Unsupported FP4 output rank: {y_fp4_float.dim()}") + # Reverse global scale if it was applied during quantization + # During quantization: block_scale includes global_scale + # block_scale = global_scale * max_abs / FP4_MAX + # During dequantization: y = (fp4_value * block_scale) / global_scale + if global_scale is not None: + result = result / global_scale.item() + + return result + + +def compute_global_scale( + x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """ + Compute global scale for NVFP4 quantization. + + global_scale = (FP8_E4M3_MAX * FP4_E2M1_MAX) / max_abs(rmsnorm_output) + + This ensures the dynamic range of the RMSNorm output fits within the FP4 range. + """ + FLOAT4_E2M1_MAX = 6.0 + FLOAT8_E4M3_MAX = float(torch.finfo(torch.float8_e4m3fn).max) + + # Compute reference RMSNorm output + ref_rmsnorm = llama_rms_norm(x, weight, eps=eps) + tensor_amax = torch.abs(ref_rmsnorm).max().to(torch.float32) + global_scale = torch.tensor( + [FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax.item()], + dtype=torch.float32, + device=x.device, + ) + return global_scale + + +def assert_close_with_tiered_tolerance( + actual: torch.Tensor, + expected: torch.Tensor, + tight_rtol: float = 0.1, + tight_atol: float = 0.1, + loose_rtol: float = 0.5, + loose_atol: float = 2.0, + tight_pct: float = 0.99, + msg: str = "", +): + """ + Two-tiered tolerance check for quantized outputs. + + - tight_pct (e.g., 99%) of elements must be within tight tolerance + - 100% of elements must be within loose tolerance + + This handles the expected quantization noise where most elements match closely + but a few outliers may differ more due to rounding boundary effects. + """ + diff = (actual - expected).abs() + rel_diff = diff / (expected.abs() + 1e-8) + + # Check 1: tight_pct of elements within tight tolerance + within_tight = (diff <= tight_atol) | (rel_diff <= tight_rtol) + tight_pct_actual = within_tight.float().mean().item() + assert tight_pct_actual >= tight_pct, ( + f"{msg}: Only {tight_pct_actual * 100:.1f}% of elements within tight tolerance " + f"(rtol={tight_rtol}, atol={tight_atol}), expected {tight_pct * 100:.0f}%" + ) + + # Check 2: 100% of elements within loose tolerance + within_loose = (diff <= loose_atol) | (rel_diff <= loose_rtol) + if not within_loose.all(): + max_diff = diff.max().item() + max_rel = rel_diff.max().item() + raise AssertionError( + f"{msg}: Max diff {max_diff:.4f} (rel: {max_rel:.4f}) exceeds loose tolerance " + f"(rtol={loose_rtol}, atol={loose_atol})" + ) + def requires_cute_dsl(): """Check if CuTe-DSL is available.""" @@ -117,19 +199,22 @@ class TestRMSNormFP4QuantCuteDSL: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) def test_rmsnorm_fp4quant_2d(self, batch_size, hidden_size, dtype, eps): - """Test fused RMSNorm + FP4 quantization with 2D input.""" + """Test fused RMSNorm + FP4 quantization with 2D input (NVFP4 with global_scale).""" from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant torch.manual_seed(42) - block_size = 16 + block_size = 16 # NVFP4 # Create input tensors x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + # Allocate output tensors y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, @@ -138,28 +223,39 @@ def test_rmsnorm_fp4quant_2d(self, batch_size, hidden_size, dtype, eps): dtype=torch.float8_e4m3fn, ) - # Run fused kernel - rmsnorm_fp4quant(x, weight, y_fp4, block_scale, eps=eps, block_size=block_size) + # Run fused kernel with global_scale (NVFP4) + rmsnorm_fp4quant( + x, + weight, + y_fp4, + block_scale, + global_scale=global_scale, + eps=eps, + block_size=block_size, + ) # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) # Verify output dtypes - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.float8_e4m3fn # Reference computation ref_rmsnorm = llama_rms_norm(x, weight, eps=eps) # Dequantize FP4 output for value-level comparison - # Tolerance based on separate FP4 roundtrip test (rtol=0.3, atol=0.5) - y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) - torch.testing.assert_close( + # Pass global_scale to reverse the scaling applied during quantization + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size, global_scale) + assert_close_with_tiered_tolerance( y_dequant, ref_rmsnorm.float(), - rtol=0.3, - atol=0.5, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, ) @pytest.mark.parametrize("batch_size", [1, 4, 3, 7, 128]) @@ -167,20 +263,27 @@ def test_rmsnorm_fp4quant_2d(self, batch_size, hidden_size, dtype, eps): @pytest.mark.parametrize("hidden_size", [128, 256, 1536, 4096, 8192]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_rmsnorm_fp4quant_3d(self, batch_size, seq_len, hidden_size, dtype): - """Test fused RMSNorm + FP4 quantization with 3D input.""" + """Test fused RMSNorm + FP4 quantization with 3D input (NVFP4 with global_scale).""" from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant torch.manual_seed(42) - block_size = 16 + block_size = 16 # NVFP4 eps = 1e-5 # Create input tensors x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + # Allocate output tensors y_fp4 = torch.empty( - batch_size, seq_len, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, + seq_len, + hidden_size // 2, + device="cuda", + dtype=torch.float4_e2m1fn_x2, ) block_scale = torch.empty( batch_size, @@ -190,28 +293,39 @@ def test_rmsnorm_fp4quant_3d(self, batch_size, seq_len, hidden_size, dtype): dtype=torch.float8_e4m3fn, ) - # Run fused kernel - rmsnorm_fp4quant(x, weight, y_fp4, block_scale, eps=eps, block_size=block_size) + # Run fused kernel with global_scale (NVFP4) + rmsnorm_fp4quant( + x, + weight, + y_fp4, + block_scale, + global_scale=global_scale, + eps=eps, + block_size=block_size, + ) # Verify output shapes assert y_fp4.shape == (batch_size, seq_len, hidden_size // 2) assert block_scale.shape == (batch_size, seq_len, hidden_size // block_size) # Verify output dtypes - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.float8_e4m3fn # Reference computation ref_rmsnorm = llama_rms_norm(x, weight, eps=eps) # Dequantize FP4 output for value-level comparison - # Tolerance based on separate FP4 roundtrip test (rtol=0.3, atol=0.5) - y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) - torch.testing.assert_close( + # Pass global_scale to reverse the scaling applied during quantization + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size, global_scale) + assert_close_with_tiered_tolerance( y_dequant, ref_rmsnorm.float(), - rtol=0.3, - atol=0.5, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, ) @pytest.mark.parametrize( @@ -223,18 +337,21 @@ def test_rmsnorm_fp4quant_3d(self, batch_size, seq_len, hidden_size, dtype): ) @pytest.mark.parametrize("dtype", [torch.float16]) def test_large_batch(self, batch_size, hidden_size, dtype): - """Test with large batch sizes.""" + """Test with large batch sizes (NVFP4 with global_scale).""" from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant torch.manual_seed(42) - block_size = 16 + block_size = 16 # NVFP4 eps = 1e-6 x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, @@ -244,11 +361,21 @@ def test_large_batch(self, batch_size, hidden_size, dtype): ) # Should complete without error - rmsnorm_fp4quant(x, weight, y_fp4, block_scale, eps=eps, block_size=block_size) + rmsnorm_fp4quant( + x, + weight, + y_fp4, + block_scale, + global_scale=global_scale, + eps=eps, + block_size=block_size, + ) # Reference computation (sample first 10 rows for speed) ref_rmsnorm = llama_rms_norm(x[:10], weight, eps=eps) - y_dequant = dequantize_fp4_output(y_fp4[:10], block_scale[:10], block_size) + y_dequant = dequantize_fp4_output( + y_fp4[:10], block_scale[:10], block_size, global_scale + ) torch.testing.assert_close( y_dequant, @@ -278,7 +405,7 @@ def test_mxfp4_basic(self, batch_size, hidden_size, dtype): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) # UE8M0 scale factors are returned as uint8 block_scale = torch.empty( @@ -298,7 +425,7 @@ def test_mxfp4_basic(self, batch_size, hidden_size, dtype): # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.uint8 # Reference computation @@ -325,7 +452,7 @@ class TestSeparateFlashInferComparison: @pytest.mark.parametrize("dtype", [torch.float16]) def test_fused_vs_separate(self, batch_size, hidden_size, dtype): """ - Compare CuTe-DSL fused output with reference RMSNorm. + Compare CuTe-DSL fused output with reference RMSNorm (NVFP4 with global_scale). We compare the dequantized output against the reference RMSNorm, rather than comparing bitwise with separate fp4_quantize (which uses @@ -335,15 +462,18 @@ def test_fused_vs_separate(self, batch_size, hidden_size, dtype): from flashinfer.norm import rmsnorm torch.manual_seed(42) - block_size = 16 + block_size = 16 # NVFP4 eps = 1e-6 x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + # Fused CuTe-DSL kernel y_fp4_fused = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_fused = torch.empty( batch_size, @@ -352,7 +482,13 @@ def test_fused_vs_separate(self, batch_size, hidden_size, dtype): dtype=torch.float8_e4m3fn, ) rmsnorm_fp4quant( - x, weight, y_fp4_fused, block_scale_fused, eps=eps, block_size=block_size + x, + weight, + y_fp4_fused, + block_scale_fused, + global_scale=global_scale, + eps=eps, + block_size=block_size, ) # Reference: RMSNorm only @@ -364,17 +500,291 @@ def test_fused_vs_separate(self, batch_size, hidden_size, dtype): # Dequantize fused output and compare to reference y_fused_dequant = dequantize_fp4_output( - y_fp4_fused, block_scale_fused, block_size + y_fp4_fused, block_scale_fused, block_size, global_scale ) # Value-level comparison against reference RMSNorm output - torch.testing.assert_close( + assert_close_with_tiered_tolerance( y_fused_dequant, y_ref.float(), - rtol=0.3, - atol=0.5, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, + ) + + +@cute_dsl_available +@blackwell_required +class TestFusedVsSeparateFP4Quantize: + """ + Tests comparing fused RMSNorm+FP4Quant against separate RMSNorm + fp4_quantize. + + This validates that the fused kernel applies global_scale identically to the + standalone fp4_quantize function. + """ + + @pytest.mark.parametrize("batch_size", [1, 4, 16, 128]) + @pytest.mark.parametrize("hidden_size", [64, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_nvfp4_fused_matches_separate(self, batch_size, hidden_size, dtype): + """ + Compare fused kernel against separate RMSNorm + fp4_quantize for NVFP4. + + This test verifies that the fused kernel applies global_scale identically + to the standalone fp4_quantize function, by comparing: + 1. The packed FP4 output bytes + 2. The block scale factors + """ + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + from flashinfer import fp4_quantize + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + + # === Fused kernel path === + y_fp4_fused = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_fused = torch.empty( + batch_size, + hidden_size // block_size, + device="cuda", + dtype=torch.float8_e4m3fn, + ) + rmsnorm_fp4quant( + x, + weight, + y_fp4_fused, + block_scale_fused, + global_scale=global_scale, + eps=eps, + block_size=block_size, + is_sf_swizzled_layout=False, # Use unswizzled for easier comparison + ) + + # === Separate path: RMSNorm + fp4_quantize === + y_rmsnorm = llama_rms_norm(x, weight, eps=eps) + y_fp4_separate, block_scale_separate = fp4_quantize( + y_rmsnorm, + global_scale, + sf_vec_size=block_size, + sf_use_ue8m0=False, # E4M3 for NVFP4 + is_sf_swizzled_layout=False, + ) + + # === Compare FP4 packed outputs === + # They should match exactly since the quantization logic is the same + # View as uint8 for comparison (float4_e2m1fn_x2 doesn't support == operator) + fp4_match = ( + (y_fp4_fused.view(torch.uint8) == y_fp4_separate.view(torch.uint8)) + .float() + .mean() + .item() + ) + assert fp4_match > 0.95, ( + f"FP4 output mismatch: only {fp4_match * 100:.1f}% of bytes match" + ) + + # === Compare block scales === + # Cast to float for comparison + scale_fused = block_scale_fused.to(torch.float32) + scale_separate = ( + block_scale_separate.view(torch.float8_e4m3fn) + .view(batch_size, -1) + .to(torch.float32) + ) + + scale_match = (scale_fused == scale_separate).float().mean().item() + assert scale_match > 0.90, ( + f"Block scale mismatch: only {scale_match * 100:.1f}% of scales match" + ) + + # === Also verify dequantized values are close === + y_fused_dequant = dequantize_fp4_output( + y_fp4_fused, block_scale_fused, block_size, global_scale + ) + y_separate_dequant = dequantize_fp4_output( + y_fp4_separate, + block_scale_separate.view(torch.float8_e4m3fn).view(batch_size, -1), + block_size, + global_scale, + ) + + # Two-tiered tolerance: 99% within tight tolerance, 100% within loose tolerance + assert_close_with_tiered_tolerance( + y_fused_dequant, + y_separate_dequant, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, + msg="Dequantized outputs from fused and separate paths should match closely", + ) + + @pytest.mark.parametrize("batch_size", [1, 4, 16, 128]) + @pytest.mark.parametrize("hidden_size", [128, 256, 512, 1024, 2048, 4096]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_mxfp4_fused_matches_separate(self, batch_size, hidden_size, dtype): + """ + Compare fused kernel against separate RMSNorm + fp4_quantize for MXFP4. + + MXFP4 uses block_size=32, UE8M0 scales, and no global_scale (global_scale=1.0). + """ + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + from flashinfer import fp4_quantize + + torch.manual_seed(42) + block_size = 32 # MXFP4 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # MXFP4 uses global_scale=1.0 + global_scale_val = torch.tensor(1.0, dtype=torch.float32, device="cuda") + + # === Fused kernel path === + y_fp4_fused = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_fused = torch.empty( + batch_size, hidden_size // block_size, device="cuda", dtype=torch.uint8 + ) + rmsnorm_fp4quant( + x, + weight, + y_fp4_fused, + block_scale_fused, + eps=eps, + block_size=block_size, + scale_format="ue8m0", + is_sf_swizzled_layout=False, + ) + + # === Separate path: RMSNorm + fp4_quantize === + y_rmsnorm = llama_rms_norm(x, weight, eps=eps) + y_fp4_separate, block_scale_separate = fp4_quantize( + y_rmsnorm, + global_scale_val, + sf_vec_size=block_size, + sf_use_ue8m0=True, # UE8M0 for MXFP4 + is_sf_swizzled_layout=False, + ) + + # === Compare FP4 packed outputs === + # View as uint8 for comparison (float4_e2m1fn_x2 doesn't support == operator) + fp4_match = ( + (y_fp4_fused.view(torch.uint8) == y_fp4_separate.view(torch.uint8)) + .float() + .mean() + .item() + ) + assert fp4_match > 0.95, ( + f"FP4 output mismatch: only {fp4_match * 100:.1f}% of bytes match" + ) + + # === Compare block scales === + scale_fused = block_scale_fused + scale_separate = block_scale_separate.view(batch_size, -1) + + scale_match = (scale_fused == scale_separate).float().mean().item() + assert scale_match > 0.90, ( + f"Block scale mismatch: only {scale_match * 100:.1f}% of scales match" + ) + + # === Also verify dequantized values are close === + # MXFP4 has larger errors due to power-of-2 scale constraints + # A few outlier values can differ by up to 2.0 due to quantization noise + y_fused_dequant = dequantize_fp4_output( + y_fp4_fused, block_scale_fused, block_size + ) + y_separate_dequant = dequantize_fp4_output( + y_fp4_separate, scale_separate, block_size + ) + + # Two-tiered tolerance: 99% within tight tolerance, 100% within loose tolerance + assert_close_with_tiered_tolerance( + y_fused_dequant, + y_separate_dequant, + tight_rtol=0.3, + tight_atol=0.5, + loose_rtol=0.5, + loose_atol=2.0, + tight_pct=0.99, + msg="Dequantized outputs from fused and separate paths should match closely", + ) + + @pytest.mark.parametrize("batch_size", [1, 16, 64]) + @pytest.mark.parametrize("hidden_size", [256, 1024, 4096]) + def test_global_scale_value_consistency(self, batch_size, hidden_size): + """ + Verify that the global_scale value correctly scales the block scales. + + When global_scale is applied: + - block_scale_with_gs = global_scale * max_abs / FP4_MAX + - This should be approximately global_scale times larger than without global_scale + (when using the same input data) + """ + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + dtype = torch.float16 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Run with computed global_scale + global_scale = compute_global_scale(x, weight, eps=eps) + + y_fp4_gs, block_scale_gs = rmsnorm_fp4quant( + x, + weight, + global_scale=global_scale, + eps=eps, + block_size=block_size, + is_sf_swizzled_layout=False, + ) + + # Run without global_scale (global_scale=1.0) + global_scale_one = torch.tensor([1.0], dtype=torch.float32, device="cuda") + + y_fp4_no_gs, block_scale_no_gs = rmsnorm_fp4quant( + x, + weight, + global_scale=global_scale_one, + eps=eps, + block_size=block_size, + is_sf_swizzled_layout=False, ) + # The block scales with global_scale should be approximately global_scale times + # larger than without (since block_scale = global_scale * max_abs / FP4_MAX) + scale_gs = block_scale_gs.to(torch.float32) + scale_no_gs = block_scale_no_gs.to(torch.float32) + + # Compute ratio where both are non-zero + non_zero_mask = (scale_no_gs > 0) & (scale_gs > 0) + if non_zero_mask.sum() > 0: + ratio = (scale_gs[non_zero_mask] / scale_no_gs[non_zero_mask]).mean().item() + expected_ratio = global_scale.item() + + # Allow some tolerance due to FP8 quantization + assert abs(ratio - expected_ratio) / expected_ratio < 0.2, ( + f"Block scale ratio {ratio:.2f} doesn't match expected global_scale {expected_ratio:.2f}" + ) + @cute_dsl_available @blackwell_required @@ -390,18 +800,21 @@ class TestLargeHiddenSize: @pytest.mark.parametrize("hidden_size", [16384, 32768]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_large_hidden_nvfp4(self, batch_size, hidden_size, dtype): - """Test NVFP4 format with large hidden sizes (cluster sync path).""" + """Test NVFP4 format with large hidden sizes (cluster sync path, with global_scale).""" from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant torch.manual_seed(42) - block_size = 16 + block_size = 16 # NVFP4 eps = 1e-6 x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, @@ -410,20 +823,28 @@ def test_large_hidden_nvfp4(self, batch_size, hidden_size, dtype): dtype=torch.float8_e4m3fn, ) - # Run kernel - rmsnorm_fp4quant(x, weight, y_fp4, block_scale, eps=eps, block_size=block_size) + # Run kernel with global_scale + rmsnorm_fp4quant( + x, + weight, + y_fp4, + block_scale, + global_scale=global_scale, + eps=eps, + block_size=block_size, + ) # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.float8_e4m3fn # Sample first few rows for value comparison (full dequant is slow) num_check = min(10, batch_size) ref_rmsnorm = llama_rms_norm(x[:num_check], weight, eps=eps) y_dequant = dequantize_fp4_output( - y_fp4[:num_check], block_scale[:num_check], block_size + y_fp4[:num_check], block_scale[:num_check], block_size, global_scale ) torch.testing.assert_close( @@ -448,7 +869,7 @@ def test_large_hidden_mxfp4(self, batch_size, hidden_size, dtype): weight = torch.randn(hidden_size, device="cuda", dtype=dtype) y_fp4 = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale = torch.empty( batch_size, hidden_size // block_size, device="cuda", dtype=torch.uint8 @@ -468,7 +889,7 @@ def test_large_hidden_mxfp4(self, batch_size, hidden_size, dtype): # Verify output shapes assert y_fp4.shape == (batch_size, hidden_size // 2) assert block_scale.shape == (batch_size, hidden_size // block_size) - assert y_fp4.dtype == torch.uint8 + assert y_fp4.dtype == torch.float4_e2m1fn_x2 assert block_scale.dtype == torch.uint8 # Sample first few rows for value comparison (full dequant is slow) @@ -533,19 +954,23 @@ class TestSwizzledScaleFactors: def test_nvfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): """ Test that swizzled output, when unswizzled, matches the non-swizzled output. - Uses NVFP4 format (block_size=16, E4M3 scales). + Uses NVFP4 format (block_size=16, E4M3 scales) with global_scale. """ from flashinfer.cute_dsl import rmsnorm_fp4quant - block_size = 16 + block_size = 16 # NVFP4 + eps = 1e-6 torch.manual_seed(42) x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + # Non-swizzled output y_fp4_ref = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_ref = torch.empty( batch_size, @@ -560,18 +985,19 @@ def test_nvfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): num_k_tiles = (hidden_size + factor - 1) // factor swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4 # 128x4 tile pattern y_fp4_swizzled = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_swizzled = torch.empty( swizzled_size, device="cuda", dtype=torch.float8_e4m3fn ) - # Run kernels + # Run kernels with global_scale rmsnorm_fp4quant( x, weight, y_fp4_ref, block_scale_ref, + global_scale=global_scale, block_size=block_size, is_sf_swizzled_layout=False, ) @@ -580,6 +1006,7 @@ def test_nvfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): weight, y_fp4_swizzled, block_scale_swizzled, + global_scale=global_scale, block_size=block_size, is_sf_swizzled_layout=True, ) @@ -589,8 +1016,10 @@ def test_nvfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): block_scale_swizzled.view(torch.uint8), batch_size, hidden_size, block_size ).view(torch.float8_e4m3fn) - # FP4 values should be identical - torch.testing.assert_close(y_fp4_swizzled, y_fp4_ref) + # FP4 values should be identical (view as uint8 for comparison) + torch.testing.assert_close( + y_fp4_swizzled.view(torch.uint8), y_fp4_ref.view(torch.uint8) + ) # Scale factors should match after unswizzling torch.testing.assert_close( @@ -615,7 +1044,7 @@ def test_mxfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): # Non-swizzled output y_fp4_ref = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_ref = torch.empty( batch_size, hidden_size // block_size, device="cuda", dtype=torch.uint8 @@ -627,7 +1056,7 @@ def test_mxfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): num_k_tiles = (hidden_size + factor - 1) // factor swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4 # 128x4 tile pattern y_fp4_swizzled = torch.empty( - batch_size, hidden_size // 2, device="cuda", dtype=torch.uint8 + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 ) block_scale_swizzled = torch.empty( swizzled_size, device="cuda", dtype=torch.uint8 @@ -656,12 +1085,265 @@ def test_mxfp4_swizzled_vs_unswizzled(self, batch_size, hidden_size, dtype): block_scale_swizzled, batch_size, hidden_size, block_size ) - # FP4 values should be identical - torch.testing.assert_close(y_fp4_swizzled, y_fp4_ref) + # FP4 values should be identical (view as uint8 for comparison) + torch.testing.assert_close( + y_fp4_swizzled.view(torch.uint8), y_fp4_ref.view(torch.uint8) + ) # Scale factors should match after unswizzling torch.testing.assert_close(block_scale_unswizzled, block_scale_ref) +@cute_dsl_available +@blackwell_required +class TestAutoAllocation: + """Tests for automatic output tensor allocation when y_fp4 and block_scale are None.""" + + @pytest.mark.parametrize("batch_size", [1, 16, 128]) + @pytest.mark.parametrize("hidden_size", [256, 1024, 4096]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_auto_allocation_2d_nvfp4(self, batch_size, hidden_size, dtype): + """Test auto-allocation with 2D input and NVFP4 format (with global_scale).""" + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + + # Call without providing y_fp4 and block_scale, but with global_scale + y_fp4, block_scale = rmsnorm_fp4quant( + x, weight, global_scale=global_scale, eps=eps, block_size=block_size + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, hidden_size // 2) + assert block_scale.shape == (batch_size, hidden_size // block_size) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.float8_e4m3fn + + # Reference computation + ref_rmsnorm = llama_rms_norm(x, weight, eps=eps) + + # Dequantize and verify values + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size, global_scale) + torch.testing.assert_close( + y_dequant, + ref_rmsnorm.float(), + rtol=0.3, + atol=0.5, + ) + + @pytest.mark.parametrize("batch_size", [1, 4, 16]) + @pytest.mark.parametrize("seq_len", [16, 64]) + @pytest.mark.parametrize("hidden_size", [256, 1024]) + @pytest.mark.parametrize("dtype", [torch.float16]) + def test_auto_allocation_3d_nvfp4(self, batch_size, seq_len, hidden_size, dtype): + """Test auto-allocation with 3D input and NVFP4 format (with global_scale).""" + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + + x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + + # Call without providing y_fp4 and block_scale, but with global_scale + y_fp4, block_scale = rmsnorm_fp4quant( + x, weight, global_scale=global_scale, eps=eps, block_size=block_size + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, seq_len, hidden_size // 2) + assert block_scale.shape == (batch_size, seq_len, hidden_size // block_size) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.float8_e4m3fn + + # Reference computation + ref_rmsnorm = llama_rms_norm(x, weight, eps=eps) + + # Dequantize and verify values + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size, global_scale) + torch.testing.assert_close( + y_dequant, + ref_rmsnorm.float(), + rtol=0.3, + atol=0.5, + ) + + @pytest.mark.parametrize("batch_size", [1, 16, 128]) + @pytest.mark.parametrize("hidden_size", [256, 1024]) + @pytest.mark.parametrize("dtype", [torch.float16]) + def test_auto_allocation_mxfp4(self, batch_size, hidden_size, dtype): + """Test auto-allocation with MXFP4 format (block_size=32, UE8M0 scales).""" + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 32 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Call without providing y_fp4 and block_scale + y_fp4, block_scale = rmsnorm_fp4quant( + x, weight, eps=eps, block_size=block_size, scale_format="ue8m0" + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, hidden_size // 2) + assert block_scale.shape == (batch_size, hidden_size // block_size) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.uint8 # UE8M0 uses uint8 + + # Reference computation + ref_rmsnorm = llama_rms_norm(x, weight, eps=eps) + + # Dequantize and verify values + y_dequant = dequantize_fp4_output(y_fp4, block_scale, block_size) + torch.testing.assert_close( + y_dequant, + ref_rmsnorm.float(), + rtol=0.3, + atol=0.7, + ) + + @pytest.mark.parametrize("batch_size", [16, 128]) + @pytest.mark.parametrize("hidden_size", [512, 1024]) + @pytest.mark.parametrize("dtype", [torch.float16]) + def test_auto_allocation_swizzled(self, batch_size, hidden_size, dtype): + """Test auto-allocation with swizzled scale factor layout (NVFP4 with global_scale).""" + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + + # Call without providing y_fp4 and block_scale, with swizzled layout + y_fp4, block_scale = rmsnorm_fp4quant( + x, + weight, + global_scale=global_scale, + eps=eps, + block_size=block_size, + is_sf_swizzled_layout=True, + ) + + # Verify output shapes + assert y_fp4.shape == (batch_size, hidden_size // 2) + # Swizzled layout has different shape + factor = block_size * 4 + num_m_tiles = (batch_size + 127) // 128 + num_k_tiles = (hidden_size + factor - 1) // factor + expected_swizzled_size = num_m_tiles * num_k_tiles * 32 * 4 * 4 + assert block_scale.shape == (expected_swizzled_size,) + + # Verify output dtypes + assert y_fp4.dtype == torch.float4_e2m1fn_x2 + assert block_scale.dtype == torch.float8_e4m3fn + + # Unswizzle and compare with non-swizzled version + y_fp4_ref = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_ref = torch.empty( + batch_size, + hidden_size // block_size, + device="cuda", + dtype=torch.float8_e4m3fn, + ) + rmsnorm_fp4quant( + x, + weight, + y_fp4_ref, + block_scale_ref, + global_scale=global_scale, + eps=eps, + block_size=block_size, + ) + + # FP4 values should be identical (view as uint8 for comparison) + torch.testing.assert_close(y_fp4.view(torch.uint8), y_fp4_ref.view(torch.uint8)) + + # Unswizzle and compare scales + block_scale_unswizzled = unswizzle_sf( + block_scale.view(torch.uint8), batch_size, hidden_size, block_size + ).view(torch.float8_e4m3fn) + torch.testing.assert_close( + block_scale_unswizzled.view(torch.uint8), block_scale_ref.view(torch.uint8) + ) + + @pytest.mark.parametrize("batch_size", [16, 128]) + @pytest.mark.parametrize("hidden_size", [512, 1024]) + def test_auto_allocation_matches_preallocated(self, batch_size, hidden_size): + """Test that auto-allocation produces same results as pre-allocated tensors (NVFP4).""" + from flashinfer.cute_dsl.rmsnorm_fp4quant import rmsnorm_fp4quant + + torch.manual_seed(42) + block_size = 16 # NVFP4 + eps = 1e-6 + dtype = torch.float16 + + x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) + weight = torch.randn(hidden_size, device="cuda", dtype=dtype) + + # Compute global_scale for NVFP4 + global_scale = compute_global_scale(x, weight, eps=eps) + + # Pre-allocated version + y_fp4_pre = torch.empty( + batch_size, hidden_size // 2, device="cuda", dtype=torch.float4_e2m1fn_x2 + ) + block_scale_pre = torch.empty( + batch_size, + hidden_size // block_size, + device="cuda", + dtype=torch.float8_e4m3fn, + ) + rmsnorm_fp4quant( + x, + weight, + y_fp4_pre, + block_scale_pre, + global_scale=global_scale, + eps=eps, + block_size=block_size, + ) + + # Auto-allocated version + y_fp4_auto, block_scale_auto = rmsnorm_fp4quant( + x, weight, global_scale=global_scale, eps=eps, block_size=block_size + ) + + # Results should be identical (view as uint8 for comparison) + torch.testing.assert_close( + y_fp4_auto.view(torch.uint8), y_fp4_pre.view(torch.uint8) + ) + torch.testing.assert_close( + block_scale_auto.view(torch.uint8), block_scale_pre.view(torch.uint8) + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"])