diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 38ca0090ac..63c8f53dd5 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -517,9 +517,9 @@ def dtype_str_to_torch_dtype(dtype_str): "8.6": [], "8.9": [], "9.0": [], - "10.0": ["cuda"], - "10.3": ["cuda"], - "12.0": ["cuda"], + "10.0": ["cuda", "cute-dsl"], + "10.3": ["cuda", "cute-dsl"], + "12.0": ["cuda", "cute-dsl"], }, "nvfp4_batched_quantize": { "7.5": [], diff --git a/benchmarks/routines/quantization.py b/benchmarks/routines/quantization.py index 46a0ce2822..d75fd7b583 100644 --- a/benchmarks/routines/quantization.py +++ b/benchmarks/routines/quantization.py @@ -628,17 +628,15 @@ def testNvfp4Quantize(args): print(f"[VVERBOSE] {enable_pdl = }") def run_backend(backend, input_tensor, global_sf_tensor): - if backend == "cuda": - return flashinfer.nvfp4_quantize( - input_tensor, - global_sf_tensor, - sfLayout=sf_layout, - do_shuffle=do_shuffle, - sf_vec_size=sf_vec_size, - enable_pdl=enable_pdl, - ) - else: - raise ValueError(f"Unsupported backend: {backend}") + return flashinfer.nvfp4_quantize( + input_tensor, + global_sf_tensor, + sfLayout=sf_layout, + do_shuffle=do_shuffle, + sf_vec_size=sf_vec_size, + enable_pdl=enable_pdl, + backend=backend, + ) # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} diff --git a/flashinfer/cute_dsl/fp4_common.py b/flashinfer/cute_dsl/fp4_common.py index 33307004b9..150658822c 100644 --- a/flashinfer/cute_dsl/fp4_common.py +++ b/flashinfer/cute_dsl/fp4_common.py @@ -154,6 +154,31 @@ def ld_global_v4_u32( return Uint32(v0), Uint32(v1), Uint32(v2), Uint32(v3) +@dsl_user_op +def ld_v4_u32( + base_ptr: Int64, *, loc=None, ip=None +) -> Tuple[Uint32, Uint32, Uint32, Uint32]: + """Load 128 bits (4 x uint32) using generic addressing (works for GMEM and SMEM).""" + result = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [Int64(base_ptr).ir_value(loc=loc, ip=ip)], + "ld.v4.u32 {$0, $1, $2, $3}, [$4];", + "=r,=r,=r,=r,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + v0 = llvm.extractvalue(T.i32(), result, [0], loc=loc, ip=ip) + v1 = llvm.extractvalue(T.i32(), result, [1], loc=loc, ip=ip) + v2 = llvm.extractvalue(T.i32(), result, [2], loc=loc, ip=ip) + v3 = llvm.extractvalue(T.i32(), result, [3], loc=loc, ip=ip) + + return Uint32(v0), Uint32(v1), Uint32(v2), Uint32(v3) + + @dsl_user_op def st_global_u64(base_ptr: Int64, value: Uint64, *, loc=None, ip=None): """Store 64 bits to global memory.""" @@ -173,12 +198,87 @@ def st_global_u64(base_ptr: Int64, value: Uint64, *, loc=None, ip=None): @dsl_user_op def get_ptr_as_int64(tensor: cute.Tensor, offset: Int32, *, loc=None, ip=None) -> Int64: - """Get the memory address of tensor[offset] as Int64.""" + """Get the memory address of tensor[offset] as Int64. + + WARNING: This uses ptrtoint which strips address space information. + For SMEM tensors, the resulting Int64 is a raw SMEM offset that does NOT + work with generic-addressing loads (ld.v4.u32). Use only with explicit + address-space loads (ld.global.*) or for global memory tensors. + """ elem_ptr = tensor.iterator + Int32(offset) ptr_int = llvm.ptrtoint(T.i64(), elem_ptr.llvm_ptr, loc=loc, ip=ip) return Int64(ptr_int) +@dsl_user_op +def get_smem_ptr_as_int32( + tensor: cute.Tensor, offset: Int32, *, loc=None, ip=None +) -> Int32: + """Get the shared-memory byte address of tensor[offset] as Int32. + + Uses Pointer.toint() which preserves the SMEM address space (addrspace 3), + returning a 32-bit SMEM address suitable for ld.shared.* instructions. + """ + elem_ptr = tensor.iterator + Int32(offset) + return elem_ptr.toint(loc=loc, ip=ip) + + +@dsl_user_op +def ld_shared_v4_u32( + smem_addr: Int32, *, loc=None, ip=None +) -> Tuple[Uint32, Uint32, Uint32, Uint32]: + """Load 128 bits (4 x uint32) from shared memory via ld.shared.v4.u32. + + Args: + smem_addr: 32-bit shared memory address (from get_smem_ptr_as_int32). + + Returns: + 4 Uint32 values (16 bytes total, e.g. 8 packed fp16 elements). + """ + result = llvm.inline_asm( + llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]), + [Int32(smem_addr).ir_value(loc=loc, ip=ip)], + "ld.shared.v4.u32 {$0, $1, $2, $3}, [$4];", + "=r,=r,=r,=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + v0 = llvm.extractvalue(T.i32(), result, [0], loc=loc, ip=ip) + v1 = llvm.extractvalue(T.i32(), result, [1], loc=loc, ip=ip) + v2 = llvm.extractvalue(T.i32(), result, [2], loc=loc, ip=ip) + v3 = llvm.extractvalue(T.i32(), result, [3], loc=loc, ip=ip) + return Uint32(v0), Uint32(v1), Uint32(v2), Uint32(v3) + + +@dsl_user_op +def pack_16bit_to_u32(lo, hi, *, loc=None, ip=None) -> Uint32: + """Pack two 16-bit scalar values (fp16 or bf16) into one Uint32 (half2/bfloat2). + + Uses PTX mov.b32 to bitwise-pack two 16-bit register values into a single + 32-bit register, suitable for half2/bfloat2 SIMD operations. + """ + lo_ir = lo.ir_value(loc=loc, ip=ip) + hi_ir = hi.ir_value(loc=loc, ip=ip) + lo_i16 = llvm.bitcast(T.i16(), lo_ir, loc=loc, ip=ip) + hi_i16 = llvm.bitcast(T.i16(), hi_ir, loc=loc, ip=ip) + return Uint32( + llvm.inline_asm( + T.i32(), + [lo_i16, hi_i16], + "mov.b32 $0, {$1, $2};", + "=r,h,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + ) + + # ============================================================================= # PTX Intrinsics - Math Operations # ============================================================================= @@ -537,6 +637,81 @@ def cvt_f32_to_e4m3(a: Float32, *, loc=None, ip=None) -> Uint32: ) +@dsl_user_op +def cvt_e4m3x4_to_f32x4( + packed: Uint32, *, loc=None, ip=None +) -> tuple[Float32, Float32, Float32, Float32]: + """Convert 4 packed E4M3 bytes (in a uint32) to 4 float32 values. + + Uses e4m3x2 → f16x2 → f32 conversion path (SM89+/PTX ISA 7.8+). + Input: uint32 containing bytes [b0, b1, b2, b3] (low to high). + Output: (f0, f1, f2, f3) as Float32. + """ + result = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]), + [Uint32(packed).ir_value(loc=loc, ip=ip)], + """ + { + .reg .b16 pair_lo, pair_hi; + .reg .b32 h2_lo, h2_hi; + .reg .b16 h0, h1, h2, h3; + mov.b32 {pair_lo, pair_hi}, $4; + cvt.rn.f16x2.e4m3x2 h2_lo, pair_lo; + cvt.rn.f16x2.e4m3x2 h2_hi, pair_hi; + mov.b32 {h0, h1}, h2_lo; + mov.b32 {h2, h3}, h2_hi; + cvt.f32.f16 $0, h0; + cvt.f32.f16 $1, h1; + cvt.f32.f16 $2, h2; + cvt.f32.f16 $3, h3; + } + """, + "=f,=f,=f,=f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + + f0 = llvm.extractvalue(T.f32(), result, [0], loc=loc, ip=ip) + f1 = llvm.extractvalue(T.f32(), result, [1], loc=loc, ip=ip) + f2 = llvm.extractvalue(T.f32(), result, [2], loc=loc, ip=ip) + f3 = llvm.extractvalue(T.f32(), result, [3], loc=loc, ip=ip) + + return Float32(f0), Float32(f1), Float32(f2), Float32(f3) + + +@dsl_user_op +def cvt_f32x2_to_half2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: + """Pack two float32 values into a half2 (uint32 containing two fp16 values). + + Uses cvt.rn.f16.f32 for each value, then packs into a single uint32. + Matches __float22half2_rn() behavior in CUDA. + """ + return Uint32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + ], + """ + { + .reg .b16 h0, h1; + cvt.rn.f16.f32 h0, $1; + cvt.rn.f16.f32 h1, $2; + mov.b32 $0, {h0, h1}; + } + """, + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op def fp8_e4m3_to_f32_and_rcp(fp8_val: Uint32, *, loc=None, ip=None) -> Float32: """Convert FP8 E4M3 to float32 AND compute reciprocal.""" @@ -573,6 +748,55 @@ def fp8_e4m3_to_f32_and_rcp(fp8_val: Uint32, *, loc=None, ip=None) -> Float32: ) +@dsl_user_op +def nvfp4_compute_output_scale( + fp8_val: Uint32, global_scale: Float32, *, loc=None, ip=None +) -> Float32: + """Compute NVFP4 output_scale matching the CUDA kernel exactly. + + Converts E4M3 scale factor to float via hardware f16x2 path, then computes + rcp(float_scale * rcp(global_scale)). Returns 0 when scale is zero. + + This matches quantization_utils.cuh: + SFValue = static_cast(tmp); + outputScale = rcp_approx(SFValue * rcp_approx(SFScaleVal)); + """ + return Float32( + llvm.inline_asm( + T.f32(), + [ + Uint32(fp8_val).ir_value(loc=loc, ip=ip), + Float32(global_scale).ir_value(loc=loc, ip=ip), + ], + """ + { + .reg .pred p_zero; + .reg .b16 fp8_pair; + .reg .b32 h2_32; + .reg .b16 h_lo, h_hi; + .reg .f32 scale_f32, rcp_gs, product, result; + + cvt.u16.u32 fp8_pair, $1; + cvt.rn.f16x2.e4m3x2 h2_32, fp8_pair; + mov.b32 {h_lo, h_hi}, h2_32; + cvt.f32.f16 scale_f32, h_lo; + + rcp.approx.ftz.f32 rcp_gs, $2; + mul.f32 product, scale_f32, rcp_gs; + rcp.approx.ftz.f32 result, product; + + setp.eq.f32 p_zero, scale_f32, 0f00000000; + selp.f32 $0, 0f00000000, result, p_zero; + } + """, + "=f,r,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + # ============================================================================= # UE8M0 Intrinsics (for MXFP4) # ============================================================================= diff --git a/flashinfer/quantization/__init__.py b/flashinfer/quantization/__init__.py index 55c58c343a..bebe8424bc 100644 --- a/flashinfer/quantization/__init__.py +++ b/flashinfer/quantization/__init__.py @@ -48,6 +48,7 @@ if is_cute_dsl_available(): from .kernels.mxfp8_quantize import mxfp8_quantize_cute_dsl from .kernels.mxfp4_quantize import mxfp4_quantize_cute_dsl + from .kernels.nvfp4_quantize import nvfp4_quantize_cute_dsl _cute_dsl_available = True except ImportError: @@ -83,4 +84,5 @@ __all__ += [ "mxfp8_quantize_cute_dsl", "mxfp4_quantize_cute_dsl", + "nvfp4_quantize_cute_dsl", ] diff --git a/flashinfer/quantization/fp4_quantization.py b/flashinfer/quantization/fp4_quantization.py index 627dcdc3e2..3a69751ace 100644 --- a/flashinfer/quantization/fp4_quantization.py +++ b/flashinfer/quantization/fp4_quantization.py @@ -657,6 +657,7 @@ def fp4_quantize( is_sf_swizzled_layout: bool = True, is_sf_8x4_layout: bool = False, enable_pdl: Optional[bool] = None, + backend: str = "cuda", ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to FP4 format. @@ -672,6 +673,12 @@ def fp4_quantize( is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False. enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). If None, automatically detects based on device capability. Defaults to None. + backend (str, optional): Backend to use for quantization. + - "cuda": Use CUDA kernel (default, stable). + - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**). + Supported combinations: + * sf_vec_size=16, sf_use_ue8m0=False: all layouts, fp16/bf16/fp8 (NVFP4) + * sf_vec_size=32, sf_use_ue8m0=True: 128x4 swizzled and linear, fp16/bf16 (MXFP4) Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: @@ -683,10 +690,28 @@ def fp4_quantize( - BFloat16 input when BFloat16 is not enabled - FP8 input when FP8 is not enabled - sf_vec_size other than 16 or 32 + ValueError: If the "cute-dsl" backend is requested for an unsupported parameter combination. + + Warning: + The "cute-dsl" backend is **experimental** and not part of the stable API. + It may change or be removed in future versions without notice. """ if sf_vec_size != 16 and sf_vec_size != 32: raise NotImplementedError("sf_vec_size can only be 16 or 32") + if backend == "cute-dsl": + return _fp4_quantize_cute_dsl( + input, + global_scale, + sf_vec_size, + sf_use_ue8m0, + is_sf_swizzled_layout, + is_sf_8x4_layout, + enable_pdl, + ) + elif backend != "cuda": + raise ValueError(f"Unknown backend: {backend}. Must be 'cuda' or 'cute-dsl'.") + # for column major input, we need to transpose the input is_column_major = input.stride(-2) == 1 if is_column_major: @@ -714,6 +739,71 @@ def fp4_quantize( return x_q, sf +def _fp4_quantize_cute_dsl( + input: torch.Tensor, + global_scale: Optional[torch.Tensor], + sf_vec_size: int, + sf_use_ue8m0: bool, + is_sf_swizzled_layout: bool, + is_sf_8x4_layout: bool, + enable_pdl: Optional[bool], +) -> Tuple[torch.Tensor, torch.Tensor]: + """CuTe-DSL dispatch for fp4_quantize. Maps parameters to the appropriate kernel.""" + from ..cute_dsl import is_cute_dsl_available + + if not is_cute_dsl_available(): + raise RuntimeError( + "CuTe-DSL backend requested but CuTe-DSL is not available. " + "Please install the required dependencies." + ) + + if sf_vec_size == 16 and not sf_use_ue8m0: + # NVFP4 path: E4M3 scale factors, sf_vec_size=16, all layouts + from .kernels.nvfp4_quantize import ( + SF_LAYOUT_128x4, + SF_LAYOUT_8x4, + SF_LAYOUT_LINEAR, + nvfp4_quantize_cute_dsl, + ) + + if not is_sf_swizzled_layout: + sf_layout = SF_LAYOUT_LINEAR + elif is_sf_8x4_layout: + sf_layout = SF_LAYOUT_8x4 + else: + sf_layout = SF_LAYOUT_128x4 + + return nvfp4_quantize_cute_dsl( + input, global_scale, sf_layout=sf_layout, enable_pdl=enable_pdl + ) + + elif sf_vec_size == 32 and sf_use_ue8m0: + # MXFP4 path: UE8M0 scale factors, sf_vec_size=32 + if is_sf_8x4_layout: + raise ValueError( + "CuTe-DSL MXFP4 kernel does not support 8x4 layout. " + "Supported: swizzled 128x4 and linear." + ) + from .kernels.mxfp4_quantize import ( + SF_LAYOUT_128x4, + SF_LAYOUT_LINEAR, + mxfp4_quantize_cute_dsl, + ) + + sf_layout = SF_LAYOUT_128x4 if is_sf_swizzled_layout else SF_LAYOUT_LINEAR + return mxfp4_quantize_cute_dsl( + input, sf_layout=sf_layout, enable_pdl=enable_pdl + ) + + else: + raise ValueError( + f"CuTe-DSL backend does not support sf_vec_size={sf_vec_size} with " + f"sf_use_ue8m0={sf_use_ue8m0}. Supported: " + f"(sf_vec_size=16, sf_use_ue8m0=False) for NVFP4, " + f"(sf_vec_size=32, sf_use_ue8m0=True) for MXFP4." + ) + + @flashinfer_api def block_scale_interleave(unswizzled_sf: torch.Tensor) -> torch.Tensor: """Swizzle block scale tensor for FP4 format. @@ -833,55 +923,95 @@ def nvfp4_quantize( do_shuffle=False, sf_vec_size=16, enable_pdl=None, + backend: str = "cuda", ): """ Quantize input tensor to NVFP4 format. Parameters: - a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16. + a (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/float8_e4m3fn. a_global_sf (torch.Tensor): Global scale factor of shape [1] with dtype float32. sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). If None, automatically detects based on device capability. Defaults to None. + backend (str, optional): Backend to use for quantization. + - "cuda": Use CUDA kernel (default, stable) + - "cute-dsl": Use CuTe-DSL kernel (requires SM100+, **experimental**). + Supports all sfLayout values (layout_128x4, layout_8x4, layout_linear). + Supports input dtypes: fp16, bf16, float8_e4m3fn. + Only supports sf_vec_size=16. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 - Scale factors tensor with shape determined by layout and sf_vec_size + + Warning: + The "cute-dsl" backend is **experimental** and not part of the stable API. + It may change or be removed in future versions without notice. """ + if backend == "cuda": + if do_shuffle: + assert sfLayout == SfLayout.layout_128x4 + is_sf_swizzled_layout = False + is_sf_8x4_layout = False + else: + is_sf_swizzled_layout = sfLayout != SfLayout.layout_linear + is_sf_8x4_layout = sfLayout == SfLayout.layout_8x4 - if do_shuffle: - # Weights 128x4 + shuffle. It is done during the model load and we do not care much about the perf - assert sfLayout == SfLayout.layout_128x4 a_fp4, a_sf = fp4_quantize( a.cuda(), a_global_sf.cuda(), sf_vec_size, sf_use_ue8m0=False, - is_sf_swizzled_layout=False, - is_sf_8x4_layout=False, + is_sf_swizzled_layout=is_sf_swizzled_layout, + is_sf_8x4_layout=is_sf_8x4_layout, enable_pdl=enable_pdl, ) + elif backend == "cute-dsl": + from ..cute_dsl import is_cute_dsl_available + if not is_cute_dsl_available(): + raise RuntimeError( + "CuTe-DSL backend requested but CuTe-DSL is not available. " + "Please install the required dependencies." + ) + if sf_vec_size != 16: + raise ValueError( + f"CuTe-DSL backend only supports sf_vec_size=16, got {sf_vec_size}" + ) + from .kernels.nvfp4_quantize import ( + SF_LAYOUT_128x4, + SF_LAYOUT_8x4, + SF_LAYOUT_LINEAR, + nvfp4_quantize_cute_dsl, + ) + + _sf_layout_map = { + SfLayout.layout_128x4: SF_LAYOUT_128x4, + SfLayout.layout_8x4: SF_LAYOUT_8x4, + SfLayout.layout_linear: SF_LAYOUT_LINEAR, + } + if do_shuffle: + assert sfLayout == SfLayout.layout_128x4 + sf_layout_int = SF_LAYOUT_LINEAR + else: + sf_layout_int = _sf_layout_map[sfLayout] + + a_fp4, a_sf = nvfp4_quantize_cute_dsl( + a.cuda(), a_global_sf.cuda(), sf_layout=sf_layout_int, enable_pdl=enable_pdl + ) + else: + raise ValueError(f"Unknown backend: {backend}. Must be 'cuda' or 'cute-dsl'.") + + if do_shuffle: epilogue_tile_m = 128 a_fp4 = shuffle_matrix_a(a_fp4.view(torch.uint8), epilogue_tile_m) a_sf = shuffle_matrix_sf_a(a_sf.view(torch.uint8), epilogue_tile_m).reshape( a_sf.shape ) - else: - # Activations with 8x4 layout for SFs (GEMM with small tileN) - # Activations with 128x4 layout for SFs (GEMM with large tileN) - a_fp4, a_sf = fp4_quantize( - a.cuda(), - a_global_sf.cuda(), - sf_vec_size, - sf_use_ue8m0=False, - is_sf_swizzled_layout=sfLayout != SfLayout.layout_linear, - is_sf_8x4_layout=sfLayout == SfLayout.layout_8x4, - enable_pdl=enable_pdl, - ) return a_fp4, a_sf diff --git a/flashinfer/quantization/kernels/__init__.py b/flashinfer/quantization/kernels/__init__.py index 7e99b74a54..5df0f078d3 100644 --- a/flashinfer/quantization/kernels/__init__.py +++ b/flashinfer/quantization/kernels/__init__.py @@ -27,7 +27,7 @@ """ from .mxfp4_quantize import ( - MXFP4QuantizeSwizzledKernel, + MXFP4QuantizeKernel, mxfp4_quantize_cute_dsl, ) from .mxfp8_quantize import ( @@ -35,11 +35,17 @@ MXFP8QuantizeSwizzledKernel, mxfp8_quantize_cute_dsl, ) +from .nvfp4_quantize import ( + NVFP4QuantizeSwizzledKernel, + nvfp4_quantize_cute_dsl, +) __all__ = [ - "MXFP4QuantizeSwizzledKernel", + "MXFP4QuantizeKernel", "mxfp4_quantize_cute_dsl", "MXFP8QuantizeLinearKernel", "MXFP8QuantizeSwizzledKernel", "mxfp8_quantize_cute_dsl", + "NVFP4QuantizeSwizzledKernel", + "nvfp4_quantize_cute_dsl", ] diff --git a/flashinfer/quantization/kernels/mxfp4_quantize.py b/flashinfer/quantization/kernels/mxfp4_quantize.py index f56dbf8eda..88fa3837c8 100644 --- a/flashinfer/quantization/kernels/mxfp4_quantize.py +++ b/flashinfer/quantization/kernels/mxfp4_quantize.py @@ -17,7 +17,7 @@ ================================= MXFP4 quantization kernel using CuTe-DSL. -Supports swizzled (128x4) scale factor layout. +Supports multiple scale factor layouts: swizzled 128x4 and linear. """ @@ -38,11 +38,15 @@ ROW_TILE_SIZE, # Low-level intrinsics compute_sf_index_swizzled_128x4_gpu, + compute_sf_index_linear_gpu, # High-level helpers (MXFP4) process_mxfp4_block_half, process_mxfp4_block_bfloat, ) +SF_LAYOUT_128x4 = 0 +SF_LAYOUT_LINEAR = 2 + # Blocks per SM for occupancy target _BLOCKS_PER_SM = 4 @@ -115,60 +119,70 @@ def _compute_swizzled_layout_sf_size( # ============================================================================= -class MXFP4QuantizeSwizzledKernel: +class MXFP4QuantizeKernel: """ - MXFP4 quantization kernel optimized for SWIZZLED layout. - - Key optimizations: - - Multi-row processing: threads process multiple rows per block when K is small - - Dynamic thread count based on K for 100% thread utilization - - Row-based iteration with grid-stride loop - - Padding row fast path - only zero out scale factors + MXFP4 quantization kernel supporting multiple scale factor layouts. - Thread utilization optimization: - - For small K: Multiple rows processed per block iteration - - For large K: Single row with column loop + Supported layouts: + - 128x4 (swizzled): Optimized for GEMM with large tileN + - linear: Simple row-major layout, no swizzling - Each thread processes one SF block (32 elements): + Key features: - UE8M0 scale factors (unsigned 8-bit exponent-only) - - E2M1 output format (4-bit, 2 values per byte) + - sf_vec_size=32 (each thread processes 32 elements) + - Multi-row processing when K is small, column loop when K is large + - Row-based iteration with grid-stride loop + - Padding row fast path for zeroing scale factors - This kernel is M-agnostic: compiled once per (K, dtype, pdl) combination. - M-dependent values (M, padded_M) are passed at runtime. + This kernel is M-agnostic: compiled once per (K, dtype, sf_layout, pdl) + combination. M-dependent values (M, padded_M) are passed at runtime. """ def __init__( self, dtype: cutlass.Numeric, K: int, + sf_layout: int = SF_LAYOUT_128x4, enable_pdl: bool = False, ): self.dtype = dtype self.K = K self.is_bfloat16 = dtype == cutlass.BFloat16 self.enable_pdl = enable_pdl + self.sf_layout = sf_layout + self.sf_is_128x4 = sf_layout == SF_LAYOUT_128x4 assert K % MXFP4_SF_VEC_SIZE == 0 self.num_sf_blocks_per_row = K // MXFP4_SF_VEC_SIZE - self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 - # Compute optimal thread count for 100% utilization + if sf_layout == SF_LAYOUT_LINEAR: + self.padded_sf_cols = self.num_sf_blocks_per_row + self.row_tile_size = 1 + else: + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 + self.row_tile_size = ROW_TILE_SIZE # 128 + self.num_threads = _compute_optimal_threads_for_k(K) - # Multi-row processing constants (compile-time) - # threads_per_row = num_sf_blocks_per_row (1 thread per SF block) self.threads_per_row = self.num_sf_blocks_per_row - # Determine if we can process multiple rows or need column loop if self.threads_per_row <= self.num_threads: - # Small K: multiple rows per block self.rows_per_block = self.num_threads // self.threads_per_row self.needs_col_loop = False else: - # Large K: one row per block with column loop self.rows_per_block = 1 self.needs_col_loop = True + @cute.jit + def _compute_sf_offset( + self, row_idx: Int32, col_idx: Int32, padded_cols: Int32 + ) -> Int32: + """Compute scale factor offset based on layout (compile-time dispatch).""" + if cutlass.const_expr(self.sf_is_128x4): + return compute_sf_index_swizzled_128x4_gpu(row_idx, col_idx, padded_cols) + else: + return compute_sf_index_linear_gpu(row_idx, col_idx, padded_cols) + @cute.jit def __call__( self, @@ -201,7 +215,7 @@ def kernel( padded_M: Int32, ): """ - MXFP4 quantization kernel with swizzled scale factor layout. + MXFP4 quantization kernel with configurable scale factor layout. Dual-path kernel with compile-time selection: - Small K path: Multi-row processing for improved thread utilization @@ -211,7 +225,7 @@ def kernel( 1. Load 32 bf16/fp16 elements (4 x 128-bit loads) 2. Compute max absolute value using SIMD reduction 3. Compute UE8M0 scale: ceil(log2(max / 6.0)) + 127 - 4. Swizzle scale factor to 128x4 layout + 4. Store scale factor using layout-specific indexing 5. Scale elements and convert to E2M1 6. Store 16 bytes (32 FP4 values) @@ -252,11 +266,9 @@ def kernel( is_padding_row = row_idx >= M if is_padding_row: - # Fast path: padding row - only zero out scale factors - # Each participating thread zeros one SF at a time local_sf_idx = sf_idx_in_row while local_sf_idx < padded_sf_cols: - sf_offset = compute_sf_index_swizzled_128x4_gpu( + sf_offset = self._compute_sf_offset( row_idx, local_sf_idx, padded_sf_cols ) mScales[sf_offset] = Uint8(0) @@ -283,13 +295,11 @@ def kernel( packed64_1, ) = process_mxfp4_block_half(row_input, elem_base) - # Write swizzled scale factor - sf_offset = compute_sf_index_swizzled_128x4_gpu( + sf_offset = self._compute_sf_offset( row_idx, sf_idx_in_row, padded_sf_cols ) mScales[sf_offset] = scale_ue8m0 - # Store 16 bytes (32 FP4 values = 2 x st.global.u64) row_output = mOutput[row_idx, None] out_base = sf_idx_in_row * (MXFP4_SF_VEC_SIZE // 2) out_ptr0 = get_ptr_as_int64(row_output, out_base) @@ -297,10 +307,9 @@ def kernel( st_global_u64(out_ptr0, packed64_0) st_global_u64(out_ptr1, packed64_1) - # Handle padding SF columns (columns beyond actual K) padding_sf_start = num_sf_blocks_per_row + sf_idx_in_row while padding_sf_start < padded_sf_cols: - sf_offset = compute_sf_index_swizzled_128x4_gpu( + sf_offset = self._compute_sf_offset( row_idx, padding_sf_start, padded_sf_cols ) mScales[sf_offset] = Uint8(0) @@ -310,24 +319,20 @@ def kernel( else: # ===== LARGE K PATH: Single row with column loop ===== - # Grid-stride loop over rows row_idx = bidx while row_idx < padded_M: is_padding_row = row_idx >= M - # Initialize sf_idx before control flow to satisfy DSL type requirements sf_idx = Int32(tidx) if is_padding_row: - # Fast path: padding row - only zero out scale factors while sf_idx < padded_sf_cols: - sf_offset = compute_sf_index_swizzled_128x4_gpu( + sf_offset = self._compute_sf_offset( row_idx, sf_idx, padded_sf_cols ) mScales[sf_offset] = Uint8(0) sf_idx = sf_idx + num_threads else: - # Normal path: process actual data row with column loop num_sf_iters = ( num_sf_blocks_per_row + num_threads - 1 ) // num_threads @@ -339,7 +344,6 @@ def kernel( elem_base = local_sf_idx * MXFP4_SF_VEC_SIZE row_input = mInput[row_idx, None] - # Process block: load, compute scale, convert to E2M1 if cutlass.const_expr(self.is_bfloat16): ( _, @@ -355,13 +359,11 @@ def kernel( packed64_1, ) = process_mxfp4_block_half(row_input, elem_base) - # Write swizzled scale factor - sf_offset = compute_sf_index_swizzled_128x4_gpu( + sf_offset = self._compute_sf_offset( row_idx, local_sf_idx, padded_sf_cols ) mScales[sf_offset] = scale_ue8m0 - # Store 16 bytes (32 FP4 values = 2 x st.global.u64) row_output = mOutput[row_idx, None] out_base = local_sf_idx * (MXFP4_SF_VEC_SIZE // 2) out_ptr0 = get_ptr_as_int64(row_output, out_base) @@ -369,10 +371,9 @@ def kernel( st_global_u64(out_ptr0, packed64_0) st_global_u64(out_ptr1, packed64_1) - # Handle padding SF columns (columns beyond actual K) padding_sf_start = num_sf_blocks_per_row + tidx while padding_sf_start < padded_sf_cols: - sf_offset = compute_sf_index_swizzled_128x4_gpu( + sf_offset = self._compute_sf_offset( row_idx, padding_sf_start, padded_sf_cols ) mScales[sf_offset] = Uint8(0) @@ -394,19 +395,21 @@ def kernel( def _get_compiled_kernel_mxfp4( is_bfloat16: bool, K: int, + sf_layout: int = SF_LAYOUT_128x4, enable_pdl: bool = False, ) -> Tuple[Callable, int]: """ Get or compile MXFP4 kernel with TVM-FFI. - Cached by (K, dtype, pdl) - M-agnostic, device-independent compilation. + Cached by (K, dtype, sf_layout, pdl) - M-agnostic, device-independent + compilation. Returns: Tuple of (compiled_kernel, rows_per_block) where rows_per_block is used by the caller to compute num_blocks at runtime. """ cutlass_dtype = cutlass.BFloat16 if is_bfloat16 else cutlass.Float16 - kernel_obj = MXFP4QuantizeSwizzledKernel(cutlass_dtype, K, enable_pdl) + kernel_obj = MXFP4QuantizeKernel(cutlass_dtype, K, sf_layout, enable_pdl) # Use symbolic M for dynamic batch sizes sym_m = cute.sym_int() @@ -442,6 +445,7 @@ def _get_compiled_kernel_mxfp4( @flashinfer_api def mxfp4_quantize_cute_dsl( input: torch.Tensor, + sf_layout: int = SF_LAYOUT_128x4, enable_pdl: bool | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -451,29 +455,34 @@ def mxfp4_quantize_cute_dsl( - Global scale computed as (448 * 6) / max(|input|) - UE8M0 scale factors - E2M1 output format (4-bit, 2 values per byte) - - Swizzled (128x4) scale factor layout + - Supports 128x4 (swizzled) and linear scale factor layouts - The kernel is compiled once per (K, dtype, pdl) combination and handles - varying M (batch size) at runtime without recompilation. + The kernel is compiled once per (K, dtype, sf_layout, pdl) combination and + handles varying M (batch size) at runtime without recompilation. Args: input: Input tensor of shape [M, K] with dtype fp16/bf16 + sf_layout: Scale factor layout (0=128x4, 2=linear). enable_pdl: Whether to enable PDL (Programmatic Dependent Launch). If None, automatically detects based on device capability (SM >= 9.0). Returns: Tuple of: - fp4_tensor: Quantized tensor of shape [M, K/2] with dtype uint8 - - scale_tensor: Scale factors as uint8 tensor (swizzled layout) + - scale_tensor: Scale factors as uint8 tensor + reshaped to [padded_rows, K/32] """ from ...utils import device_support_pdl + _valid_sf_layouts = (SF_LAYOUT_128x4, SF_LAYOUT_LINEAR) + assert sf_layout in _valid_sf_layouts, ( + f"sf_layout must be one of {_valid_sf_layouts}, got {sf_layout}" + ) assert input.dtype in (torch.float16, torch.bfloat16), ( f"Input dtype must be float16 or bfloat16, got {input.dtype}" ) assert input.is_cuda, "Input must be on CUDA device" - # Auto-detect PDL support based on device capability if enable_pdl is None: enable_pdl = device_support_pdl(input.device) @@ -491,38 +500,47 @@ def mxfp4_quantize_cute_dsl( input = input.contiguous() is_bfloat16 = input.dtype == torch.bfloat16 - # Cached device-specific target grid for grid size computation target_grid = get_num_sm(input.device) * _BLOCKS_PER_SM - # Compute M-dependent values num_sf_blocks_per_row = k // MXFP4_SF_VEC_SIZE - padded_m = ((m + ROW_TILE_SIZE - 1) // ROW_TILE_SIZE) * ROW_TILE_SIZE - padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + + if sf_layout == SF_LAYOUT_LINEAR: + row_tile_size = 1 + # NOTE: When adding a TMA-based kernel, padded_m must be rounded up to the + # TMA tile row dimension (e.g. round_up(m, tma_tile_rows)) and scale_output + # must be trimmed to m * num_sf_blocks_per_row before returning. + # See PR f4d10d9 for the analogous CUDA fix. + padded_m = m + padded_sf_cols = num_sf_blocks_per_row + else: + row_tile_size = ROW_TILE_SIZE # 128 + padded_m = ((m + row_tile_size - 1) // row_tile_size) * row_tile_size + padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + scale_output_size = padded_m * padded_sf_cols - # Get or compile kernel (device-independent) - kernel_fn, rows_per_block = _get_compiled_kernel_mxfp4(is_bfloat16, k, enable_pdl) + kernel_fn, rows_per_block = _get_compiled_kernel_mxfp4( + is_bfloat16, k, sf_layout, enable_pdl + ) - # Compute grid size in Python (runtime, device-specific) num_blocks = min((padded_m + rows_per_block - 1) // rows_per_block, target_grid) - # Allocate outputs fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) scale_output = torch.empty( scale_output_size, dtype=torch.uint8, device=input.device ) - # Launch kernel kernel_fn(input, fp4_output, scale_output, m, padded_m, num_blocks) - # Reshape scale output to match CUDA backend format: [padded_total, num_sf_per_row] scale_output = scale_output.reshape(-1, num_sf_blocks_per_row) return fp4_output, scale_output __all__ = [ - "MXFP4QuantizeSwizzledKernel", + "SF_LAYOUT_128x4", + "SF_LAYOUT_LINEAR", + "MXFP4QuantizeKernel", "mxfp4_quantize_cute_dsl", "_get_compiled_kernel_mxfp4", ] diff --git a/flashinfer/quantization/kernels/nvfp4_quantize.py b/flashinfer/quantization/kernels/nvfp4_quantize.py new file mode 100644 index 0000000000..52a4aa6cfa --- /dev/null +++ b/flashinfer/quantization/kernels/nvfp4_quantize.py @@ -0,0 +1,1224 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +NVFP4 Quantization using CuTe-DSL +================================= + +NVFP4 quantization kernel using CuTe-DSL. +Supports multiple scale factor layouts: swizzled 128x4, swizzled 8x4, and linear. + +Key differences from MXFP4: +- sf_vec_size=16 (vs 32 for MXFP4) +- E4M3 scale factors (vs UE8M0 for MXFP4) +- User-provided global_scale (vs auto-computed for MXFP4) +""" + +import functools +from typing import Callable, Tuple + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.cpasync as cpasync +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import torch +from cutlass import Float32, Int32, Uint8 + +from ...api_logging import flashinfer_api +from ...cute_dsl.fp4_common import get_ptr_as_int64, st_global_u64 +from ...cute_dsl.utils import get_num_sm +from ..quantization_cute_dsl_utils import ( + NVFP4_SF_VEC_SIZE, + ROW_TILE_SIZE, + compute_sf_index_swizzled_128x4_gpu, + compute_sf_index_swizzled_8x4_gpu, + compute_sf_index_linear_gpu, + half2_max_abs_8 as half2_max_abs_8_fn, + bfloat2_max_abs_8 as bfloat2_max_abs_8_fn, + hmax_reduce_to_f32, + bfloat2_hmax_reduce_to_f32, + half2x8_to_e2m1x16_packed, + bfloat2x8_to_e2m1x16_packed, + process_nvfp4_block_half, + process_nvfp4_block_bfloat, + process_nvfp4_block_fp8, +) + +SF_LAYOUT_128x4 = 0 +SF_LAYOUT_8x4 = 1 +SF_LAYOUT_LINEAR = 2 + +_BLOCKS_PER_SM = 4 +_MAX_THREADS_PER_BLOCK = 1024 +_MIN_THREADS = 128 +_MAX_THREADS = 512 +_DEFAULT_THREADS = 256 + + +def _compute_optimal_threads_for_k(K: int) -> int: + """ + Compute optimal thread count for 100% thread utilization. + + For NVFP4, each thread processes one SF block (16 elements). + threads_per_row = K / 16 = num_sf_blocks_per_row + + We prefer LARGER thread counts (up to _MAX_THREADS) for better occupancy, + while maintaining 100% thread utilization. + """ + threads_per_row = K // NVFP4_SF_VEC_SIZE + + if threads_per_row >= _MAX_THREADS: + return _MAX_THREADS + + if threads_per_row <= _MAX_THREADS: + threads = (_MAX_THREADS // threads_per_row) * threads_per_row + if threads >= _MIN_THREADS: + return threads + threads = threads_per_row + while threads < _MIN_THREADS: + threads += threads_per_row + if threads <= _MAX_THREADS: + return threads + + return _DEFAULT_THREADS + + +def _compute_swizzled_layout_sf_size( + total_row: int, total_column: int, row_size: int = 128 +) -> int: + """Compute size of swizzled scale factor buffer.""" + padded_row = (total_row + row_size - 1) // row_size * row_size + padded_column = (total_column + 3) // 4 * 4 + return padded_row * padded_column + + +# ============================================================================= +# CuTe-DSL Kernel Class for NVFP4 Swizzled Layout +# ============================================================================= + + +class NVFP4QuantizeSwizzledKernel: + """ + NVFP4 quantization kernel supporting multiple scale factor layouts. + + Supported layouts: + - 128x4 (swizzled): Optimized for GEMM with large tileN + - 8x4 (swizzled): Optimized for GEMM with small tileN + - linear: Simple row-major layout, no swizzling + + Key features: + - E4M3 scale factors (FP8 format) with user-provided global_scale + - sf_vec_size=16 (each thread processes 16 elements) + - Multi-row processing when K is small, column loop when K is large + - Row-based iteration with grid-stride loop + - Padding row fast path for zeroing scale factors + + This kernel is M-agnostic: compiled once per (K, dtype, sf_layout, pdl) + combination. M-dependent values (M, padded_M) and global_scale are passed + at runtime. + """ + + def __init__( + self, + dtype: cutlass.Numeric, + K: int, + sf_layout: int = SF_LAYOUT_128x4, + enable_pdl: bool = False, + ): + self.dtype = dtype + self.K = K + self.is_bfloat16 = dtype == cutlass.BFloat16 + self.is_fp8 = dtype == cutlass.Float8E4M3FN + self.enable_pdl = enable_pdl + self.sf_layout = sf_layout + self.sf_is_128x4 = sf_layout == SF_LAYOUT_128x4 + self.sf_is_8x4 = sf_layout == SF_LAYOUT_8x4 + + assert K % NVFP4_SF_VEC_SIZE == 0 + self.num_sf_blocks_per_row = K // NVFP4_SF_VEC_SIZE + + if sf_layout == SF_LAYOUT_LINEAR: + self.padded_sf_cols = self.num_sf_blocks_per_row + self.row_tile_size = 1 + elif sf_layout == SF_LAYOUT_8x4: + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 + self.row_tile_size = 8 + else: + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 + self.row_tile_size = ROW_TILE_SIZE # 128 + + self.num_threads = _compute_optimal_threads_for_k(K) + + self.threads_per_row = self.num_sf_blocks_per_row + + if self.threads_per_row <= self.num_threads: + self.rows_per_block = self.num_threads // self.threads_per_row + self.needs_col_loop = False + else: + self.rows_per_block = 1 + self.needs_col_loop = True + + @cute.jit + def _compute_sf_offset( + self, row_idx: Int32, col_idx: Int32, padded_cols: Int32 + ) -> Int32: + """Compute scale factor offset based on layout (compile-time dispatch).""" + if cutlass.const_expr(self.sf_is_128x4): + return compute_sf_index_swizzled_128x4_gpu(row_idx, col_idx, padded_cols) + else: + if cutlass.const_expr(self.sf_is_8x4): + return compute_sf_index_swizzled_8x4_gpu(row_idx, col_idx, padded_cols) + else: + return compute_sf_index_linear_gpu(row_idx, col_idx, padded_cols) + + @cute.jit + def __call__( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + padded_M: Int32, + num_blocks: Int32, + mGlobalScale: cute.Tensor, + stream, + ): + threads_per_block = self.num_threads + + self.kernel(mInput, mOutput, mScales, M, padded_M, mGlobalScale).launch( + grid=[num_blocks, 1, 1], + block=[threads_per_block, 1, 1], + max_number_threads=[_MAX_THREADS_PER_BLOCK, 1, 1], + min_blocks_per_mp=_BLOCKS_PER_SM, + stream=stream, + use_pdl=self.enable_pdl, + ) + + @cute.kernel + def kernel( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + padded_M: Int32, + mGlobalScale: cute.Tensor, + ): + """ + NVFP4 quantization kernel with swizzled scale factor layout. + + Dual-path kernel with compile-time selection: + - Small K path: Multi-row processing for improved thread utilization + - Large K path: Single row with column loop + + Each thread processes one SF block (16 elements): + 1. Load 16 elements (2 x 128-bit for fp16/bf16, 1 x 128-bit for fp8) + 2. Compute max absolute value using SIMD reduction + 3. Compute E4M3 scale: cvt_f32_to_e4m3(global_scale * max / 6.0) + 4. Store scale factor using layout-specific indexing + 5. Back-convert E4M3, compute output_scale = global_scale / scale_back + 6. Scale elements and convert to E2M1 + 7. Store 8 bytes (16 FP4 values) + """ + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + grid_dim_x, _, _ = cute.arch.grid_dim() + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + + # Read global_scale from device memory (avoids CPU-GPU sync at launch) + global_scale = Float32(mGlobalScale[Int32(0)]) + + num_sf_blocks_per_row = self.num_sf_blocks_per_row + padded_sf_cols = self.padded_sf_cols + num_threads = self.num_threads + rows_per_block = self.rows_per_block + threads_per_row = self.threads_per_row + + if cutlass.const_expr(not self.needs_col_loop): + # ===== SMALL K PATH: Multi-row processing ===== + row_in_block = tidx // threads_per_row + sf_idx_in_row = tidx % threads_per_row + + row_batch_idx = bidx + total_row_batches = cute.ceil_div(padded_M, rows_per_block) + + while row_batch_idx < total_row_batches: + base_row = row_batch_idx * rows_per_block + row_idx = base_row + row_in_block + + if row_idx < padded_M: + is_padding_row = row_idx >= M + + if is_padding_row: + local_sf_idx = sf_idx_in_row + while local_sf_idx < padded_sf_cols: + sf_offset = self._compute_sf_offset( + row_idx, local_sf_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + local_sf_idx = local_sf_idx + threads_per_row + else: + if sf_idx_in_row < num_sf_blocks_per_row: + elem_base = sf_idx_in_row * NVFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + + if cutlass.const_expr(self.is_fp8): + scale_fp8, packed64 = process_nvfp4_block_fp8( + row_input, elem_base, global_scale + ) + elif cutlass.const_expr(self.is_bfloat16): + scale_fp8, packed64 = process_nvfp4_block_bfloat( + row_input, elem_base, global_scale + ) + else: + scale_fp8, packed64 = process_nvfp4_block_half( + row_input, elem_base, global_scale + ) + + sf_offset = self._compute_sf_offset( + row_idx, sf_idx_in_row, padded_sf_cols + ) + mScales[sf_offset] = scale_fp8 + + row_output = mOutput[row_idx, None] + out_base = sf_idx_in_row * (NVFP4_SF_VEC_SIZE // 2) + out_ptr = get_ptr_as_int64(row_output, out_base) + st_global_u64(out_ptr, packed64) + + padding_sf_start = num_sf_blocks_per_row + sf_idx_in_row + while padding_sf_start < padded_sf_cols: + sf_offset = self._compute_sf_offset( + row_idx, padding_sf_start, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + padding_sf_start = padding_sf_start + threads_per_row + + row_batch_idx = row_batch_idx + grid_dim_x + + else: + # ===== LARGE K PATH: Single row with column loop ===== + row_idx = bidx + while row_idx < padded_M: + is_padding_row = row_idx >= M + + sf_idx = Int32(tidx) + + if is_padding_row: + while sf_idx < padded_sf_cols: + sf_offset = self._compute_sf_offset( + row_idx, sf_idx, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + sf_idx = sf_idx + num_threads + else: + num_sf_iters = ( + num_sf_blocks_per_row + num_threads - 1 + ) // num_threads + + for sf_iter in range(num_sf_iters): + local_sf_idx = sf_iter * num_threads + tidx + + if local_sf_idx < num_sf_blocks_per_row: + elem_base = local_sf_idx * NVFP4_SF_VEC_SIZE + row_input = mInput[row_idx, None] + + if cutlass.const_expr(self.is_fp8): + scale_fp8, packed64 = process_nvfp4_block_fp8( + row_input, elem_base, global_scale + ) + elif cutlass.const_expr(self.is_bfloat16): + scale_fp8, packed64 = process_nvfp4_block_bfloat( + row_input, elem_base, global_scale + ) + else: + scale_fp8, packed64 = process_nvfp4_block_half( + row_input, elem_base, global_scale + ) + + sf_offset = self._compute_sf_offset( + row_idx, local_sf_idx, padded_sf_cols + ) + mScales[sf_offset] = scale_fp8 + + row_output = mOutput[row_idx, None] + out_base = local_sf_idx * (NVFP4_SF_VEC_SIZE // 2) + out_ptr = get_ptr_as_int64(row_output, out_base) + st_global_u64(out_ptr, packed64) + + padding_sf_start = num_sf_blocks_per_row + tidx + while padding_sf_start < padded_sf_cols: + sf_offset = self._compute_sf_offset( + row_idx, padding_sf_start, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + padding_sf_start = padding_sf_start + num_threads + + row_idx = row_idx + grid_dim_x + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# CuTe-DSL TMA Kernel Class for NVFP4 +# ============================================================================= + +_TMA_ROW_TILE = 16 +_TMA_COL_TILE = 64 # Per-warp column tile +_TMA_NUM_CONSUMER_WARPS = 8 +_TMA_NUM_STAGES = 4 +_TMA_COLS_PER_STAGE = _TMA_NUM_CONSUMER_WARPS * _TMA_COL_TILE # 512 + + +def _round_up(x: int, d: int) -> int: + return ((x + d - 1) // d) * d + + +class NVFP4QuantizeTMAKernel: + """ + TMA-based NVFP4 quantization kernel with pipelined producer-consumer + warp specialization, matching the CUDA TMA kernel architecture. + + Architecture (matches csrc/nv_internal/.../quantization.cuh): + - 1 producer warp (warp 0) issues TMA G2S loads into staged SMEM buffers + - 8 consumer warps (warps 1-8) read from SMEM, quantize, write to GMEM + - PipelineTmaAsync manages multi-stage buffering (4 stages) + - Each TMA tile: [16, 512] = 16 rows x 8 warps x 64 cols per warp + - Each consumer warp: 4 threads/row x 8 rows/warp, 2 row iterations + - Each thread: 16 elements (1 SF block) via 2 x ld.shared.v4.u32 + - Grid-stride loop over row tiles, inner loop over K/512 col chunks + + Effective when M >= 1024 and K is a multiple of 512. + """ + + def __init__( + self, + dtype: cutlass.Numeric, + K: int, + sf_layout: int = SF_LAYOUT_128x4, + enable_pdl: bool = False, + ): + self.dtype = dtype + self.K = K + self.is_bfloat16 = dtype == cutlass.BFloat16 + self.is_fp8 = dtype == cutlass.Float8E4M3FN + self.enable_pdl = enable_pdl + self.sf_layout = sf_layout + self.sf_is_128x4 = sf_layout == SF_LAYOUT_128x4 + self.sf_is_8x4 = sf_layout == SF_LAYOUT_8x4 + + assert not self.is_fp8, "FP8 input not yet supported for TMA kernel" + assert K % _TMA_COLS_PER_STAGE == 0, ( + f"K ({K}) must be a multiple of {_TMA_COLS_PER_STAGE} for TMA kernel" + ) + + self.num_sf_blocks_per_row = K // NVFP4_SF_VEC_SIZE + self.num_col_chunks = K // _TMA_COLS_PER_STAGE + + if sf_layout == SF_LAYOUT_LINEAR: + self.padded_sf_cols = self.num_sf_blocks_per_row + self.row_tile_size = 1 + elif sf_layout == SF_LAYOUT_8x4: + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 + self.row_tile_size = 8 + else: + self.padded_sf_cols = ((self.num_sf_blocks_per_row + 3) // 4) * 4 + self.row_tile_size = ROW_TILE_SIZE + + self.num_consumer_warps = _TMA_NUM_CONSUMER_WARPS # 8 + self.num_stages = _TMA_NUM_STAGES + self.producer_warp_id = 0 # Warp 0 is producer (matches CUDA kernel) + self.threads_per_cta = 32 * (self.num_consumer_warps + 1) # 288 + self.rows_per_block = _TMA_ROW_TILE + self.buffer_align_bytes = 1024 + self.cluster_shape_mn = (1, 1) + self.elems_per_stage = _TMA_ROW_TILE * _TMA_COLS_PER_STAGE # 8192 + + # Thread indexing constants (matches CUDA TmaKernelTraitsTwoBytes) + self.THREADS_PER_ROW = 4 # laneIdx % 4 + self.ROWS_PER_WARP = 8 # 32 / 4 + self.ROW_ITERATIONS = _TMA_ROW_TILE // self.ROWS_PER_WARP # 2 + self.ELTS_PER_THREAD = NVFP4_SF_VEC_SIZE # 16 + + @cute.jit + def _compute_sf_offset( + self, row_idx: Int32, col_idx: Int32, padded_cols: Int32 + ) -> Int32: + if cutlass.const_expr(self.sf_is_128x4): + return compute_sf_index_swizzled_128x4_gpu(row_idx, col_idx, padded_cols) + else: + if cutlass.const_expr(self.sf_is_8x4): + return compute_sf_index_swizzled_8x4_gpu(row_idx, col_idx, padded_cols) + else: + return compute_sf_index_linear_gpu(row_idx, col_idx, padded_cols) + + @cute.jit + def _quantize_sf_block( + self, + h0: cutlass.Uint32, + h1: cutlass.Uint32, + h2: cutlass.Uint32, + h3: cutlass.Uint32, + h4: cutlass.Uint32, + h5: cutlass.Uint32, + h6: cutlass.Uint32, + h7: cutlass.Uint32, + global_row: Int32, + sf_col: Int32, + global_scale: Float32, + M: Int32, + padded_M: Int32, + padded_sf_cols: Int32, + mOutput: cute.Tensor, + mScales: cute.Tensor, + ): + """Quantize one 16-element SF block and write results to GMEM.""" + from ...cute_dsl.fp4_common import ( + cvt_f32_to_e4m3, + nvfp4_compute_output_scale, + rcp_approx_ftz, + ) + + if global_row < padded_M: + is_padding_row = global_row >= M + + if is_padding_row: + sf_offset = self._compute_sf_offset(global_row, sf_col, padded_sf_cols) + mScales[sf_offset] = Uint8(0) + else: + if cutlass.const_expr(self.is_bfloat16): + block_max_h2 = bfloat2_max_abs_8_fn(h0, h1, h2, h3, h4, h5, h6, h7) + block_max = bfloat2_hmax_reduce_to_f32(block_max_h2) + else: + block_max_h2 = half2_max_abs_8_fn(h0, h1, h2, h3, h4, h5, h6, h7) + block_max = hmax_reduce_to_f32(block_max_h2) + + fp4_max_rcp = rcp_approx_ftz(Float32(6.0)) + scale_float = global_scale * (block_max * fp4_max_rcp) + scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) + scale_fp8 = Uint8(scale_fp8_u32 & cutlass.Uint32(0xFF)) + + output_scale = nvfp4_compute_output_scale(scale_fp8_u32, global_scale) + + if cutlass.const_expr(self.is_bfloat16): + packed64 = bfloat2x8_to_e2m1x16_packed( + h0, h1, h2, h3, h4, h5, h6, h7, output_scale + ) + else: + packed64 = half2x8_to_e2m1x16_packed( + h0, h1, h2, h3, h4, h5, h6, h7, output_scale + ) + + sf_offset = self._compute_sf_offset(global_row, sf_col, padded_sf_cols) + mScales[sf_offset] = scale_fp8 + + row_output = mOutput[global_row, None] + out_base = sf_col * Int32(NVFP4_SF_VEC_SIZE // 2) + out_ptr = get_ptr_as_int64(row_output, out_base) + st_global_u64(out_ptr, packed64) + + @cute.jit + def __call__( + self, + mInput: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + padded_M: Int32, + num_blocks: Int32, + mGlobalScale: cute.Tensor, + stream, + ): + # 3D global tensor: [padded_M, K/64, 64] so each warp's 64-col + # stripe is the contiguous innermost dimension, matching the CUDA + # TMA kernel's 3D tensor map. + gInput = cute.make_tensor( + mInput.iterator, + cute.make_layout( + (padded_M, self.K // _TMA_COL_TILE, _TMA_COL_TILE), + stride=(self.K, _TMA_COL_TILE, 1), + ), + ) + + # SMEM layout per stage: [rows=16, warps=8, cols_per_warp=64] + # with SWIZZLE_128B applied. Within each warp's [16, 64] tile the + # row stride is 64 elems = 128 bytes, putting row bits in the S=3 + # range of the swizzle so different rows map to different banks. + smem_swizzle = cute.make_swizzle(3, 4, 3) # SWIZZLE_128B for 2B types + smem_outer_single = cute.make_layout( + (_TMA_ROW_TILE, _TMA_NUM_CONSUMER_WARPS, _TMA_COL_TILE), + stride=(_TMA_COL_TILE, _TMA_ROW_TILE * _TMA_COL_TILE, 1), + ) + smem_single_composed = cute.make_composed_layout( + smem_swizzle, 0, smem_outer_single + ) + + cta_tiler = (_TMA_ROW_TILE, _TMA_NUM_CONSUMER_WARPS, _TMA_COL_TILE) + tma_atom, tma_tensor = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + gInput, + smem_single_composed, + cta_tiler, + ) + + total_smem_elems = self.elems_per_stage * self.num_stages + # Staged outer layout (no swizzle — swizzle passed separately) + smem_outer_staged = cute.make_layout( + (_TMA_ROW_TILE, _TMA_NUM_CONSUMER_WARPS, _TMA_COL_TILE, self.num_stages), + stride=( + _TMA_COL_TILE, + _TMA_ROW_TILE * _TMA_COL_TILE, + 1, + self.elems_per_stage, + ), + ) + # Flat layout for manual-swizzle consumer reads + smem_layout_flat = cute.make_layout((total_smem_elems,)) + + self.num_tma_load_bytes = cute.size_in_bytes(self.dtype, smem_outer_single) + + @cute.struct + class SharedStorage: + load_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_stages] + load_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_stages] + smem_data: cute.struct.Align[ + cute.struct.MemRange[self.dtype, total_smem_elems], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.kernel( + tma_atom, + tma_tensor, + mOutput, + mScales, + M, + padded_M, + mGlobalScale, + smem_outer_staged, + smem_swizzle, + smem_layout_flat, + ).launch( + grid=[num_blocks, 1, 1], + block=[self.threads_per_cta, 1, 1], + max_number_threads=[ + self.threads_per_cta, + 1, + 1, + ], # __launch_bounds__(288, 2) + min_blocks_per_mp=2, + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + use_pdl=self.enable_pdl, + ) + + @cute.kernel + def kernel( + self, + tma_atom: cute.CopyAtom, + gInput_tma: cute.Tensor, + mOutput: cute.Tensor, + mScales: cute.Tensor, + M: Int32, + padded_M: Int32, + mGlobalScale: cute.Tensor, + smem_outer_staged: cute.Layout, + smem_swizzle: cute.Swizzle, + smem_layout_flat: cute.Layout, + ): + from ...cute_dsl.fp4_common import ( + get_smem_ptr_as_int32, + ld_shared_v4_u32, + ) + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + grid_dim_x, _, _ = cute.arch.grid_dim() + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = tidx % 32 + + global_scale = Float32(mGlobalScale[Int32(0)]) + padded_sf_cols = self.padded_sf_cols + num_sf_blocks_per_row = self.num_sf_blocks_per_row + num_col_chunks = self.num_col_chunks + elems_per_stage = self.elems_per_stage + + # ---- SMEM allocation ---- + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_mbar_ptr = storage.load_full_mbar.data_ptr() + # Swizzled tensor for TMA partition (address-space-correct writes) + sData_staged = storage.smem_data.get_tensor( + smem_outer_staged, swizzle=smem_swizzle + ) + # Flat tensor for manual-swizzle consumer reads + sData_flat = storage.smem_data.get_tensor(smem_layout_flat) + + # ---- Pipeline setup ---- + load_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=load_mbar_ptr, + num_stages=self.num_stages, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_consumer_warps + ), + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), (1,) + ), + defer_sync=True, + ) + + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # ---- TMA partition (3D: rows × warps × cols_per_warp) ---- + gSrc_tiled = cute.local_tile( + gInput_tma, + (_TMA_ROW_TILE, _TMA_NUM_CONSUMER_WARPS, _TMA_COL_TILE), + (None, None, None), + ) + tAsA, tAgA = cpasync.tma_partition( + tma_atom, + 0, + cute.make_layout(1), + cute.group_modes(sData_staged, 0, 3), # Group 3 tile modes + cute.group_modes(gSrc_tiled, 0, 3), + ) + + num_row_tiles = cute.ceil_div(padded_M, _TMA_ROW_TILE) + + # ---- Consumer thread indexing (matches CUDA TmaKernelTraitsTwoBytes) ---- + # 4 threads per row, 8 rows per warp, 2 row iterations per stage + col_idx_local = lane_idx % Int32(self.THREADS_PER_ROW) + row_idx_local = lane_idx // Int32(self.THREADS_PER_ROW) + + # ======== Producer warp (warp 0) ======== + if warp_idx == self.producer_warp_id: + producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_stages + ) + + row_tile_idx = bidx + while row_tile_idx < num_row_tiles: + col_chunk = Int32(0) + while col_chunk < num_col_chunks: + load_pipeline.producer_acquire(producer_state) + + cute.copy( + tma_atom, + tAgA[(None, row_tile_idx, col_chunk, 0)], + tAsA[(None, producer_state.index)], + tma_bar_ptr=load_pipeline.producer_get_barrier(producer_state), + ) + + producer_state.advance() + col_chunk = col_chunk + Int32(1) + + row_tile_idx = row_tile_idx + grid_dim_x + + load_pipeline.producer_tail(producer_state) + + # ======== Consumer warps (warps 1-8) ======== + if warp_idx > 0: + consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_stages + ) + + # 0-indexed consumer warp id + consumer_warp_idx = warp_idx - Int32(1) + + # Pre-compute warp tile base offset (constant per warp) + # SMEM 3D layout: [rows=16, warps=8, cols=64] per stage + # stride: [64, 1024, 1] → warp tile base = warp * 1024 + warp_tile_elems = _TMA_ROW_TILE * _TMA_COL_TILE # 1024 + warp_tile_base = consumer_warp_idx * Int32(warp_tile_elems) + + # Float4 base position for this thread (0, 2, 4, 6) + f4_base = col_idx_local * Int32(2) + + # Global column offset for SF index: warp's column within K + base_col_in_stage = consumer_warp_idx * Int32( + _TMA_COL_TILE + ) + col_idx_local * Int32(self.ELTS_PER_THREAD) + + row_tile_idx = bidx + while row_tile_idx < num_row_tiles: + base_row = row_tile_idx * Int32(_TMA_ROW_TILE) + + col_chunk = Int32(0) + while col_chunk < num_col_chunks: + load_pipeline.consumer_wait(consumer_state) + stage = consumer_state.index + + # ---- Read ALL SMEM data with SWIZZLE_128B addressing ---- + # Within each warp's [16,64] tile, the XOR pattern matches + # CUDA's load_input_vec: float4_idx ^= row & 7 + # Physical elem offset in warp tile for (row, float4 f): + # row * 64 + (f ^ (row & 7)) * 8 + stage_base = stage * Int32(elems_per_stage) + + # Row iteration 0 (row_idx_local = 0..7) + r0_xor = row_idx_local & Int32(7) + r0_f4_0 = f4_base ^ r0_xor + r0_f4_1 = (f4_base + Int32(1)) ^ r0_xor + r0_row_base = ( + stage_base + + warp_tile_base + + row_idx_local * Int32(_TMA_COL_TILE) + ) + r0_addr_0 = get_smem_ptr_as_int32( + sData_flat, r0_row_base + r0_f4_0 * Int32(8) + ) + r0_addr_1 = get_smem_ptr_as_int32( + sData_flat, r0_row_base + r0_f4_1 * Int32(8) + ) + r0_h0, r0_h1, r0_h2, r0_h3 = ld_shared_v4_u32(r0_addr_0) + r0_h4, r0_h5, r0_h6, r0_h7 = ld_shared_v4_u32(r0_addr_1) + + # Row iteration 1 (row = row_idx_local + 8) + r1_row = row_idx_local + Int32(self.ROWS_PER_WARP) + r1_xor = r1_row & Int32(7) + r1_f4_0 = f4_base ^ r1_xor + r1_f4_1 = (f4_base + Int32(1)) ^ r1_xor + r1_row_base = ( + stage_base + warp_tile_base + r1_row * Int32(_TMA_COL_TILE) + ) + r1_addr_0 = get_smem_ptr_as_int32( + sData_flat, r1_row_base + r1_f4_0 * Int32(8) + ) + r1_addr_1 = get_smem_ptr_as_int32( + sData_flat, r1_row_base + r1_f4_1 * Int32(8) + ) + r1_h0, r1_h1, r1_h2, r1_h3 = ld_shared_v4_u32(r1_addr_0) + r1_h4, r1_h5, r1_h6, r1_h7 = ld_shared_v4_u32(r1_addr_1) + + # ---- Quantize and write: both row iterations ---- + # Global column base for SF index computation + global_col_base = col_chunk * Int32(_TMA_COLS_PER_STAGE) + sf_col = (global_col_base + base_col_in_stage) // Int32( + NVFP4_SF_VEC_SIZE + ) + + # Row iteration 0 + global_row_0 = base_row + row_idx_local + self._quantize_sf_block( + r0_h0, + r0_h1, + r0_h2, + r0_h3, + r0_h4, + r0_h5, + r0_h6, + r0_h7, + global_row_0, + sf_col, + global_scale, + M, + padded_M, + padded_sf_cols, + mOutput, + mScales, + ) + + # Row iteration 1 + global_row_1 = base_row + row_idx_local + Int32(self.ROWS_PER_WARP) + self._quantize_sf_block( + r1_h0, + r1_h1, + r1_h2, + r1_h3, + r1_h4, + r1_h5, + r1_h6, + r1_h7, + global_row_1, + sf_col, + global_scale, + M, + padded_M, + padded_sf_cols, + mOutput, + mScales, + ) + + # ---- Release pipeline after all work (matches CUDA pattern) ---- + load_pipeline.consumer_release(consumer_state) + consumer_state.advance() + + col_chunk = col_chunk + Int32(1) + + # Zero padding SF columns for swizzled layouts + if cutlass.const_expr(self.sf_layout != SF_LAYOUT_LINEAR): + consumer_tid = (warp_idx - Int32(1)) * Int32(32) + lane_idx + if consumer_tid < _TMA_ROW_TILE: + pad_row_idx = base_row + consumer_tid + if pad_row_idx < padded_M: + padding_sf = Int32(num_sf_blocks_per_row) + while padding_sf < padded_sf_cols: + sf_offset = self._compute_sf_offset( + pad_row_idx, padding_sf, padded_sf_cols + ) + mScales[sf_offset] = Uint8(0) + padding_sf = padding_sf + Int32(1) + + row_tile_idx = row_tile_idx + grid_dim_x + + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() + + +# ============================================================================= +# PyTorch Integration with TVM-FFI +# ============================================================================= + + +@functools.cache +def _get_compiled_kernel_nvfp4( + dtype_key: str, + K: int, + sf_layout: int = SF_LAYOUT_128x4, + enable_pdl: bool = False, +) -> Tuple[Callable, int]: + """ + Get or compile NVFP4 kernel with TVM-FFI. + + Cached by (K, dtype_key, sf_layout, pdl) - M-agnostic, device-independent + compilation. + + Args: + dtype_key: One of "float16", "bfloat16", "float8_e4m3fn". + + Returns: + Tuple of (compiled_kernel, rows_per_block) where rows_per_block + is used by the caller to compute num_blocks at runtime. + """ + _dtype_map = { + "float16": cutlass.Float16, + "bfloat16": cutlass.BFloat16, + "float8_e4m3fn": cutlass.Float8E4M3FN, + } + cutlass_dtype = _dtype_map[dtype_key] + kernel_obj = NVFP4QuantizeSwizzledKernel( + cutlass_dtype, K, sf_layout=sf_layout, enable_pdl=enable_pdl + ) + + sym_m = cute.sym_int() + + input_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, (sym_m, K), stride_order=(1, 0), assumed_align=16 + ) + output_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, (sym_m, K // 2), stride_order=(1, 0), assumed_align=16 + ) + sym_scale_size = cute.sym_int() + scales_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, (sym_scale_size,), assumed_align=16 + ) + global_scale_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, (1,), assumed_align=4 + ) + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + input_fake, + output_fake, + scales_fake, + Int32(1), # Dummy M + Int32(128), # Dummy padded_M + Int32(1), # Dummy num_blocks + global_scale_fake, + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel, kernel_obj.rows_per_block + + +_TMA_MIN_M = 1024 +# TMA wins when the total problem is large enough to amortize pipeline overhead. +# Empirically, floor(log2(M)) + floor(log2(K)) >= 25 is the crossover where TMA +# outperforms the default vectorized-load kernel, validated on B200 and SM120. +# We use bit_length()-1 (i.e., floor(log2)) rather than m*k to keep the boundary +# aligned with the power-of-2 grid it was tuned on. +_TMA_LOG2_MK_THRESHOLD = 25 + + +def _should_use_tma(m: int, k: int, dtype: torch.dtype) -> bool: + """Determine if TMA kernel should be used based on problem dimensions.""" + if dtype == torch.float8_e4m3fn: + return False + if k % _TMA_COLS_PER_STAGE != 0: + return False + if m < _TMA_MIN_M: + return False + # Use log2(M) + log2(K) threshold for the crossover point + return m.bit_length() - 1 + k.bit_length() - 1 >= _TMA_LOG2_MK_THRESHOLD + + +@functools.cache +def _get_compiled_kernel_nvfp4_tma( + dtype_key: str, + K: int, + sf_layout: int = SF_LAYOUT_128x4, + enable_pdl: bool = False, +) -> Tuple[Callable, int]: + """ + Get or compile TMA-based NVFP4 kernel with TVM-FFI. + + Cached by (K, dtype_key, sf_layout, pdl). + """ + _dtype_map = { + "float16": cutlass.Float16, + "bfloat16": cutlass.BFloat16, + } + cutlass_dtype = _dtype_map[dtype_key] + kernel_obj = NVFP4QuantizeTMAKernel( + cutlass_dtype, K, sf_layout=sf_layout, enable_pdl=enable_pdl + ) + + sym_m = cute.sym_int() + sym_padded_m = cute.sym_int() + + input_fake = cute.runtime.make_fake_compact_tensor( + cutlass_dtype, (sym_padded_m, K), stride_order=(1, 0), assumed_align=16 + ) + output_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, (sym_m, K // 2), stride_order=(1, 0), assumed_align=16 + ) + sym_scale_size = cute.sym_int() + scales_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Uint8, (sym_scale_size,), assumed_align=16 + ) + global_scale_fake = cute.runtime.make_fake_compact_tensor( + cutlass.Float32, (1,), assumed_align=4 + ) + stream_fake = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_kernel = cute.compile( + kernel_obj, + input_fake, + output_fake, + scales_fake, + Int32(1), # Dummy M + Int32(1024), # Dummy padded_M + Int32(1), # Dummy num_blocks + global_scale_fake, + stream_fake, + options="--enable-tvm-ffi", + ) + + return compiled_kernel, kernel_obj.rows_per_block + + +@flashinfer_api +def nvfp4_quantize_cute_dsl( + input: torch.Tensor, + global_scale: torch.Tensor, + sf_layout: int = SF_LAYOUT_128x4, + enable_pdl: bool | None = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to NVFP4 format using CuTe-DSL kernel. + + This is a GPU implementation matching FlashInfer's nvfp4_quantize() behavior: + - E4M3 scale factors (FP8) + - E2M1 output format (4-bit, 2 values per byte) + - Supports 128x4, 8x4, and linear scale factor layouts + - sf_vec_size=16 + + The kernel is compiled once per (K, dtype, sf_layout, pdl) combination and + handles varying M (batch size) at runtime without recompilation. + + Args: + input: Input tensor of shape [M, K] with dtype fp16/bf16/float8_e4m3fn + global_scale: Scalar tensor (float32) for NVFP4 global scale factor + sf_layout: Scale factor layout (0=128x4, 1=8x4, 2=linear). + enable_pdl: Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability (SM >= 9.0). + + Returns: + Tuple of: + - fp4_tensor: Quantized tensor of shape [M, K/2] with dtype uint8 + - scale_tensor: E4M3 scale factors as uint8 tensor + reshaped to [padded_rows, K/16] + """ + from ...utils import device_support_pdl + + _valid_sf_layouts = (SF_LAYOUT_128x4, SF_LAYOUT_8x4, SF_LAYOUT_LINEAR) + assert sf_layout in _valid_sf_layouts, ( + f"sf_layout must be one of {_valid_sf_layouts}, got {sf_layout}" + ) + _supported_dtypes = (torch.float16, torch.bfloat16, torch.float8_e4m3fn) + assert input.dtype in _supported_dtypes, ( + f"Input dtype must be one of {_supported_dtypes}, got {input.dtype}" + ) + assert input.is_cuda, "Input must be on CUDA device" + + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) + + if input.dim() > 2: + m = input.numel() // input.shape[-1] + k = input.shape[-1] + input = input.reshape(m, k) + else: + m, k = input.shape + + assert k % NVFP4_SF_VEC_SIZE == 0, ( + f"K ({k}) must be divisible by NVFP4_SF_VEC_SIZE={NVFP4_SF_VEC_SIZE}" + ) + + input = input.contiguous() + + _torch_to_dtype_key = { + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float8_e4m3fn: "float8_e4m3fn", + } + dtype_key = _torch_to_dtype_key[input.dtype] + + if isinstance(global_scale, torch.Tensor): + global_scale_tensor = ( + global_scale.float().reshape(1).contiguous().to(input.device) + ) + else: + global_scale_tensor = torch.tensor( + [float(global_scale)], dtype=torch.float32, device=input.device + ) + + num_sm = get_num_sm(input.device) + + num_sf_blocks_per_row = k // NVFP4_SF_VEC_SIZE + + use_tma = _should_use_tma(m, k, input.dtype) + + if use_tma: + tma_row_tile = _TMA_ROW_TILE + if sf_layout == SF_LAYOUT_LINEAR: + padded_m = _round_up(m, tma_row_tile) + padded_sf_cols = num_sf_blocks_per_row + elif sf_layout == SF_LAYOUT_8x4: + padded_m = _round_up(m, max(tma_row_tile, 8)) + padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + else: + padded_m = _round_up(m, max(tma_row_tile, ROW_TILE_SIZE)) + padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + + scale_output_size = padded_m * padded_sf_cols + + kernel_fn, rows_per_block = _get_compiled_kernel_nvfp4_tma( + dtype_key, k, sf_layout, enable_pdl + ) + + # Match CUDA TMA kernel: grid = min(row_tiles, SM_count * 2) + tma_target_grid = num_sm * 2 + num_blocks = min( + (padded_m + rows_per_block - 1) // rows_per_block, tma_target_grid + ) + + input_padded = input + if padded_m > m: + input_padded = torch.zeros( + padded_m, k, dtype=input.dtype, device=input.device + ) + input_padded[:m, :] = input + + fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) + scale_output = torch.empty( + scale_output_size, dtype=torch.uint8, device=input.device + ) + + kernel_fn( + input_padded, + fp4_output, + scale_output, + m, + padded_m, + num_blocks, + global_scale_tensor, + ) + + if sf_layout == SF_LAYOUT_LINEAR: + scale_output = scale_output[: m * num_sf_blocks_per_row] + + # Reshape using padded_sf_cols for swizzled layouts (the buffer is + # physically padded and stores data in swizzled order). For linear + # layout the padding is already trimmed above. + scale_output = scale_output.reshape(-1, padded_sf_cols) + + return fp4_output, scale_output + + # Non-TMA path + if sf_layout == SF_LAYOUT_LINEAR: + row_tile_size = 1 + padded_m = m + padded_sf_cols = num_sf_blocks_per_row + elif sf_layout == SF_LAYOUT_8x4: + row_tile_size = 8 + padded_m = ((m + row_tile_size - 1) // row_tile_size) * row_tile_size + padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + else: + row_tile_size = ROW_TILE_SIZE # 128 + padded_m = ((m + row_tile_size - 1) // row_tile_size) * row_tile_size + padded_sf_cols = ((num_sf_blocks_per_row + 3) // 4) * 4 + + scale_output_size = padded_m * padded_sf_cols + + kernel_fn, rows_per_block = _get_compiled_kernel_nvfp4( + dtype_key, k, sf_layout, enable_pdl + ) + + default_target_grid = num_sm * _BLOCKS_PER_SM + num_blocks = min( + (padded_m + rows_per_block - 1) // rows_per_block, default_target_grid + ) + + fp4_output = torch.empty(m, k // 2, dtype=torch.uint8, device=input.device) + scale_output = torch.empty( + scale_output_size, dtype=torch.uint8, device=input.device + ) + + kernel_fn( + input, fp4_output, scale_output, m, padded_m, num_blocks, global_scale_tensor + ) + + # Reshape using padded_sf_cols: for swizzled layouts the buffer includes + # column padding; for linear layout padded_sf_cols == num_sf_blocks_per_row. + scale_output = scale_output.reshape(-1, padded_sf_cols) + + return fp4_output, scale_output + + +__all__ = [ + "SF_LAYOUT_128x4", + "SF_LAYOUT_8x4", + "SF_LAYOUT_LINEAR", + "NVFP4QuantizeSwizzledKernel", + "NVFP4QuantizeTMAKernel", + "nvfp4_quantize_cute_dsl", + "_get_compiled_kernel_nvfp4", + "_get_compiled_kernel_nvfp4_tma", +] diff --git a/flashinfer/quantization/quantization_cute_dsl_utils.py b/flashinfer/quantization/quantization_cute_dsl_utils.py index b4d3ac1dbf..8de4acdcf9 100644 --- a/flashinfer/quantization/quantization_cute_dsl_utils.py +++ b/flashinfer/quantization/quantization_cute_dsl_utils.py @@ -47,6 +47,13 @@ ROW_TILE_SIZE = 128 +# ============================================================================= +# NVFP4 Constants +# ============================================================================= + +# Scale factor vector size for NVFP4: each scale factor covers 16 elements +NVFP4_SF_VEC_SIZE = 16 + # ============================================================================= # MXFP4 Constants # ============================================================================= @@ -134,9 +141,13 @@ def bfloat2_hmax_reduce_to_f32(x: Uint32, *, loc=None, ip=None) -> Float32: @dsl_user_op def float_to_ue8m0_fast(value: Float32, *, loc=None, ip=None) -> Uint32: """ - Convert float to UE8M0 format using fast log2 approximation. + Convert float to UE8M0 format using exact IEEE 754 bit manipulation. - UE8M0 = ceil(log2(value)) + 127, clamped to [0, 255] + Matches the hardware __nv_cvt_float_to_e8m0(value, __NV_SATFINITE, cudaRoundPosInf): + - Extract biased exponent from IEEE 754 float + - If mantissa is nonzero, add 1 (round towards +inf / ceil behavior) + - Clamp to [0, 254] (255 = NaN in E8M0) + - Return 0 for zero/negative input """ return Uint32( llvm.inline_asm( @@ -144,19 +155,23 @@ def float_to_ue8m0_fast(value: Float32, *, loc=None, ip=None) -> Uint32: [Float32(value).ir_value(loc=loc, ip=ip)], """ { - .reg .pred p_zero, p_neg, p_ovf; - .reg .f32 log2_val; - .reg .s32 exp_int, result; + .reg .pred p_zero, p_has_mant, p_ovf; + .reg .u32 bits, exp_biased, mantissa, bump, result; setp.le.f32 p_zero, $1, 0f00000000; - lg2.approx.f32 log2_val, $1; - cvt.rpi.s32.f32 exp_int, log2_val; - add.s32 result, exp_int, 127; - setp.lt.s32 p_neg, result, 0; - setp.gt.s32 p_ovf, result, 255; - selp.s32 result, 0, result, p_neg; - selp.s32 result, 255, result, p_ovf; - selp.s32 $0, 0, result, p_zero; + + mov.b32 bits, $1; + shr.b32 exp_biased, bits, 23; + and.b32 exp_biased, exp_biased, 255; + and.b32 mantissa, bits, 0x7FFFFF; + + setp.ne.u32 p_has_mant, mantissa, 0; + selp.u32 bump, 1, 0, p_has_mant; + add.u32 result, exp_biased, bump; + + setp.gt.u32 p_ovf, result, 254; + selp.u32 result, 254, result, p_ovf; + selp.u32 $0, 0, result, p_zero; } """, "=r,f", @@ -516,6 +531,48 @@ def compute_sf_index_swizzled_128x4_gpu( return offset +@cute.jit +def compute_sf_index_swizzled_8x4_gpu( + row_idx: Int32, + col_idx: Int32, + padded_cols: Int32, +) -> Int32: + """Compute swizzled 8x4 scale factor index on GPU. + + Layout: [numMTiles, numKTiles, 8 (mTile), 4 (kTile)] + Tile size: 32 elements (8 rows x 4 cols). + """ + kMTileSize = Int32(8) + kKTileSize = Int32(4) + kTileElements = Int32(32) + + innerKIdx = col_idx % kKTileSize + innerMIdx = row_idx % kMTileSize + kTileIdx = col_idx // kKTileSize + mTileIdx = row_idx // kMTileSize + + numKTiles = (padded_cols + kKTileSize - Int32(1)) // kKTileSize + + offset = ( + mTileIdx * (numKTiles * kTileElements) + + kTileIdx * kTileElements + + innerMIdx * kKTileSize + + innerKIdx + ) + + return offset + + +@cute.jit +def compute_sf_index_linear_gpu( + row_idx: Int32, + col_idx: Int32, + num_cols: Int32, +) -> Int32: + """Compute linear (row-major) scale factor index on GPU.""" + return row_idx * num_cols + col_idx + + # ============================================================================= # High-Level Helper Functions for MXFP8 Quantization # ============================================================================= @@ -690,7 +747,12 @@ def process_mxfp4_block_half(row_tensor, elem_base: Int32) -> tuple: """ from cutlass import Uint8 - from ..cute_dsl.fp4_common import get_ptr_as_int64, hmax2, ld_global_v4_u32 + from ..cute_dsl.fp4_common import ( + get_ptr_as_int64, + hmax2, + ld_global_v4_u32, + rcp_approx_ftz, + ) # Load 32 elements (4 x 128-bit = 16 half2 values) ptr0 = get_ptr_as_int64(row_tensor, elem_base) @@ -709,9 +771,8 @@ def process_mxfp4_block_half(row_tensor, elem_base: Int32) -> tuple: block_max_h2 = hmax2(max_first, max_second) block_max = hmax_reduce_to_f32(block_max_h2) - # Compute UE8M0 scale factor - inv_e2m1_max = Float32(INV_FLOAT4_E2M1_MAX) - normalized_max = block_max * inv_e2m1_max + # Compute UE8M0 scale factor (rcp_approx matches CUDA's rcp.approx.ftz(6.0f)) + normalized_max = block_max * rcp_approx_ftz(Float32(6.0)) scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max) scale_ue8m0_u8 = scale_ue8m0_u32.to(Uint8) @@ -760,7 +821,12 @@ def process_mxfp4_block_bfloat(row_tensor, elem_base: Int32) -> tuple: """ from cutlass import Uint8 - from ..cute_dsl.fp4_common import bfloat2_hmax2, get_ptr_as_int64, ld_global_v4_u32 + from ..cute_dsl.fp4_common import ( + bfloat2_hmax2, + get_ptr_as_int64, + ld_global_v4_u32, + rcp_approx_ftz, + ) # Load 32 elements (4 x 128-bit = 16 bfloat2 values) ptr0 = get_ptr_as_int64(row_tensor, elem_base) @@ -779,9 +845,8 @@ def process_mxfp4_block_bfloat(row_tensor, elem_base: Int32) -> tuple: block_max_h2 = bfloat2_hmax2(max_first, max_second) block_max = bfloat2_hmax_reduce_to_f32(block_max_h2) - # Compute UE8M0 scale factor - inv_e2m1_max = Float32(INV_FLOAT4_E2M1_MAX) - normalized_max = block_max * inv_e2m1_max + # Compute UE8M0 scale factor (rcp_approx matches CUDA's rcp.approx.ftz(6.0f)) + normalized_max = block_max * rcp_approx_ftz(Float32(6.0)) scale_ue8m0_u32 = float_to_ue8m0_fast(normalized_max) scale_ue8m0_u8 = scale_ue8m0_u32.to(Uint8) @@ -961,7 +1026,361 @@ def bfloat2x16_to_e2m1x32_packed( return packed64_0, packed64_1 +# ============================================================================= +# NVFP4 High-Level Helper Functions (sf_vec_size=16, E4M3 scale factors) +# ============================================================================= + + +@cute.jit +def half2x8_to_e2m1x16_packed( + h0: Uint32, + h1: Uint32, + h2: Uint32, + h3: Uint32, + h4: Uint32, + h5: Uint32, + h6: Uint32, + h7: Uint32, + inv_scale: Float32, +) -> Uint64: + """ + Convert 8 half2 values (16 FP16) to 16 E2M1 and pack into u64. + + Returns: + Uint64 containing 16 E2M1 values (8 bytes) + """ + s0, s1 = half2_to_float2_scaled(h0, inv_scale) + s2, s3 = half2_to_float2_scaled(h1, inv_scale) + s4, s5 = half2_to_float2_scaled(h2, inv_scale) + s6, s7 = half2_to_float2_scaled(h3, inv_scale) + s8, s9 = half2_to_float2_scaled(h4, inv_scale) + s10, s11 = half2_to_float2_scaled(h5, inv_scale) + s12, s13 = half2_to_float2_scaled(h6, inv_scale) + s14, s15 = half2_to_float2_scaled(h7, inv_scale) + + packed_lo = cvt_e2m1x8_f32(s0, s1, s2, s3, s4, s5, s6, s7) + packed_hi = cvt_e2m1x8_f32(s8, s9, s10, s11, s12, s13, s14, s15) + + return (Uint64(packed_hi) << Uint64(32)) | Uint64(packed_lo) + + +@cute.jit +def bfloat2x8_to_e2m1x16_packed( + h0: Uint32, + h1: Uint32, + h2: Uint32, + h3: Uint32, + h4: Uint32, + h5: Uint32, + h6: Uint32, + h7: Uint32, + inv_scale: Float32, +) -> Uint64: + """ + Convert 8 bfloat2 values (16 BF16) to 16 E2M1 and pack into u64. + + Returns: + Uint64 containing 16 E2M1 values (8 bytes) + """ + s0, s1 = bfloat2_to_float2_scaled(h0, inv_scale) + s2, s3 = bfloat2_to_float2_scaled(h1, inv_scale) + s4, s5 = bfloat2_to_float2_scaled(h2, inv_scale) + s6, s7 = bfloat2_to_float2_scaled(h3, inv_scale) + s8, s9 = bfloat2_to_float2_scaled(h4, inv_scale) + s10, s11 = bfloat2_to_float2_scaled(h5, inv_scale) + s12, s13 = bfloat2_to_float2_scaled(h6, inv_scale) + s14, s15 = bfloat2_to_float2_scaled(h7, inv_scale) + + packed_lo = cvt_e2m1x8_f32(s0, s1, s2, s3, s4, s5, s6, s7) + packed_hi = cvt_e2m1x8_f32(s8, s9, s10, s11, s12, s13, s14, s15) + + return (Uint64(packed_hi) << Uint64(32)) | Uint64(packed_lo) + + +@cute.jit +def process_nvfp4_block_half( + row_tensor, elem_base: Int32, global_scale: Float32 +) -> tuple: + """ + Process a 16-element NVFP4 block for half precision input. + + Loads 16 FP16 elements, computes the E4M3 scale factor using global_scale, + converts to E2M1, and packs the result into a u64 value. + + Args: + row_tensor: Row tensor slice (mInput[row_idx, None]) + elem_base: Starting element index + global_scale: User-provided global scale factor + + Returns: + (scale_e4m3_u8, packed64): + - scale_e4m3_u8: E4M3 scale factor as Uint8 + - packed64: Uint64 containing 16 E2M1 values + """ + from cutlass import Uint8 + + from ..cute_dsl.fp4_common import ( + cvt_f32_to_e4m3, + get_ptr_as_int64, + ld_global_v4_u32, + nvfp4_compute_output_scale, + rcp_approx_ftz, + ) + + # Load 16 elements (2 x 128-bit = 8 half2 values) + ptr0 = get_ptr_as_int64(row_tensor, elem_base) + ptr1 = get_ptr_as_int64(row_tensor, elem_base + Int32(8)) + + h0, h1, h2, h3 = ld_global_v4_u32(ptr0) + h4, h5, h6, h7 = ld_global_v4_u32(ptr1) + + # Compute max absolute value across 16 elements + block_max_h2 = half2_max_abs_8(h0, h1, h2, h3, h4, h5, h6, h7) + block_max = hmax_reduce_to_f32(block_max_h2) + + # E4M3 scale factor computation + fp4_max_rcp = rcp_approx_ftz(Float32(6.0)) + scale_float = global_scale * (block_max * fp4_max_rcp) + scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) + scale_fp8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) + + # output_scale = rcp(float(E4M3(scale)) * rcp(global_scale)), matching CUDA + output_scale = nvfp4_compute_output_scale(scale_fp8_u32, global_scale) + + # Convert to E2M1 and pack + packed64 = half2x8_to_e2m1x16_packed(h0, h1, h2, h3, h4, h5, h6, h7, output_scale) + + return scale_fp8, packed64 + + +@cute.jit +def process_nvfp4_block_bfloat( + row_tensor, elem_base: Int32, global_scale: Float32 +) -> tuple: + """ + Process a 16-element NVFP4 block for bfloat16 precision input. + + Loads 16 BF16 elements, computes the E4M3 scale factor using global_scale, + converts to E2M1, and packs the result into a u64 value. + + Args: + row_tensor: Row tensor slice (mInput[row_idx, None]) + elem_base: Starting element index + global_scale: User-provided global scale factor + + Returns: + (scale_e4m3_u8, packed64): + - scale_e4m3_u8: E4M3 scale factor as Uint8 + - packed64: Uint64 containing 16 E2M1 values + """ + from cutlass import Uint8 + + from ..cute_dsl.fp4_common import ( + cvt_f32_to_e4m3, + get_ptr_as_int64, + ld_global_v4_u32, + nvfp4_compute_output_scale, + rcp_approx_ftz, + ) + + # Load 16 elements (2 x 128-bit = 8 bfloat2 values) + ptr0 = get_ptr_as_int64(row_tensor, elem_base) + ptr1 = get_ptr_as_int64(row_tensor, elem_base + Int32(8)) + + h0, h1, h2, h3 = ld_global_v4_u32(ptr0) + h4, h5, h6, h7 = ld_global_v4_u32(ptr1) + + # Compute max absolute value across 16 elements + block_max_h2 = bfloat2_max_abs_8(h0, h1, h2, h3, h4, h5, h6, h7) + block_max = bfloat2_hmax_reduce_to_f32(block_max_h2) + + # E4M3 scale factor computation + fp4_max_rcp = rcp_approx_ftz(Float32(6.0)) + scale_float = global_scale * (block_max * fp4_max_rcp) + scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) + scale_fp8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) + + # output_scale = rcp(float(E4M3(scale)) * rcp(global_scale)), matching CUDA + output_scale = nvfp4_compute_output_scale(scale_fp8_u32, global_scale) + + # Convert to E2M1 and pack + packed64 = bfloat2x8_to_e2m1x16_packed(h0, h1, h2, h3, h4, h5, h6, h7, output_scale) + + return scale_fp8, packed64 + + +@cute.jit +def fp8x16_to_e2m1x16_packed( + w0: Uint32, + w1: Uint32, + w2: Uint32, + w3: Uint32, + output_scale: Float32, +) -> Uint64: + """Convert 16 packed FP8 E4M3 values (4 x uint32) to 16 E2M1 values packed as Uint64. + + Each uint32 contains 4 E4M3 bytes. Output is 16 E2M1 nibbles packed into 8 bytes. + """ + from ..cute_dsl.fp4_common import cvt_e4m3x4_to_f32x4 + + f0, f1, f2, f3 = cvt_e4m3x4_to_f32x4(w0) + f4, f5, f6, f7 = cvt_e4m3x4_to_f32x4(w1) + f8, f9, f10, f11 = cvt_e4m3x4_to_f32x4(w2) + f12, f13, f14, f15 = cvt_e4m3x4_to_f32x4(w3) + + s0 = f0 * output_scale + s1 = f1 * output_scale + s2 = f2 * output_scale + s3 = f3 * output_scale + s4 = f4 * output_scale + s5 = f5 * output_scale + s6 = f6 * output_scale + s7 = f7 * output_scale + s8 = f8 * output_scale + s9 = f9 * output_scale + s10 = f10 * output_scale + s11 = f11 * output_scale + s12 = f12 * output_scale + s13 = f13 * output_scale + s14 = f14 * output_scale + s15 = f15 * output_scale + + packed_lo = cvt_e2m1x8_f32(s0, s1, s2, s3, s4, s5, s6, s7) + packed_hi = cvt_e2m1x8_f32(s8, s9, s10, s11, s12, s13, s14, s15) + + return (Uint64(packed_hi) << Uint64(32)) | Uint64(packed_lo) + + +@cute.jit +def fp8_max_abs_16(w0: Uint32, w1: Uint32, w2: Uint32, w3: Uint32) -> Float32: + """Compute max absolute value across 16 FP8 E4M3 values (4 x uint32). + + Converts all 16 values to float32, takes abs, and reduces to a single max. + """ + from ..cute_dsl.fp4_common import cvt_e4m3x4_to_f32x4 + + f0, f1, f2, f3 = cvt_e4m3x4_to_f32x4(w0) + f4, f5, f6, f7 = cvt_e4m3x4_to_f32x4(w1) + f8, f9, f10, f11 = cvt_e4m3x4_to_f32x4(w2) + f12, f13, f14, f15 = cvt_e4m3x4_to_f32x4(w3) + + from ..cute_dsl.fp4_common import fabs_f32, fmax_f32 + + a0 = fabs_f32(f0) + a1 = fabs_f32(f1) + a2 = fabs_f32(f2) + a3 = fabs_f32(f3) + a4 = fabs_f32(f4) + a5 = fabs_f32(f5) + a6 = fabs_f32(f6) + a7 = fabs_f32(f7) + a8 = fabs_f32(f8) + a9 = fabs_f32(f9) + a10 = fabs_f32(f10) + a11 = fabs_f32(f11) + a12 = fabs_f32(f12) + a13 = fabs_f32(f13) + a14 = fabs_f32(f14) + a15 = fabs_f32(f15) + + m01 = fmax_f32(a0, a1) + m23 = fmax_f32(a2, a3) + m45 = fmax_f32(a4, a5) + m67 = fmax_f32(a6, a7) + m89 = fmax_f32(a8, a9) + m1011 = fmax_f32(a10, a11) + m1213 = fmax_f32(a12, a13) + m1415 = fmax_f32(a14, a15) + + m0123 = fmax_f32(m01, m23) + m4567 = fmax_f32(m45, m67) + m891011 = fmax_f32(m89, m1011) + m12131415 = fmax_f32(m1213, m1415) + + m_lo = fmax_f32(m0123, m4567) + m_hi = fmax_f32(m891011, m12131415) + + return fmax_f32(m_lo, m_hi) + + +@cute.jit +def process_nvfp4_block_fp8( + row_tensor, elem_base: Int32, global_scale: Float32 +) -> tuple: + """ + Process a 16-element NVFP4 block for FP8 E4M3 input. + + Matches the CUDA cvt_warp_fp8_to_fp4 behavior: FP8 values are first converted + to float32, pre-scaled by 6/global_scale, and converted to half2. From there, + the standard half2 pipeline is used for max-abs reduction, scale factor + computation, and E2M1 conversion. + + Args: + row_tensor: Row tensor slice (mInput[row_idx, None]) + elem_base: Starting element index + global_scale: User-provided global scale factor + + Returns: + (scale_e4m3_u8, packed64): + - scale_e4m3_u8: E4M3 scale factor as Uint8 + - packed64: Uint64 containing 16 E2M1 values + """ + from cutlass import Uint8 + + from ..cute_dsl.fp4_common import ( + cvt_e4m3x4_to_f32x4, + cvt_f32_to_e4m3, + cvt_f32x2_to_half2, + get_ptr_as_int64, + ld_global_v4_u32, + nvfp4_compute_output_scale, + rcp_approx_ftz, + ) + + # Load 16 FP8 elements (1 x 128-bit = 4 x uint32 = 16 bytes) + ptr = get_ptr_as_int64(row_tensor, elem_base) + w0, w1, w2, w3 = ld_global_v4_u32(ptr) + + # Convert FP8 to float32 and pre-scale by 6/global_scale (matching CUDA) + prescale = Float32(6.0) * rcp_approx_ftz(global_scale) + + f0, f1, f2, f3 = cvt_e4m3x4_to_f32x4(w0) + f4, f5, f6, f7 = cvt_e4m3x4_to_f32x4(w1) + f8, f9, f10, f11 = cvt_e4m3x4_to_f32x4(w2) + f12, f13, f14, f15 = cvt_e4m3x4_to_f32x4(w3) + + # Pack pre-scaled float pairs into half2 (matching __float22half2_rn in CUDA) + h0 = cvt_f32x2_to_half2(f0 * prescale, f1 * prescale) + h1 = cvt_f32x2_to_half2(f2 * prescale, f3 * prescale) + h2 = cvt_f32x2_to_half2(f4 * prescale, f5 * prescale) + h3 = cvt_f32x2_to_half2(f6 * prescale, f7 * prescale) + h4 = cvt_f32x2_to_half2(f8 * prescale, f9 * prescale) + h5 = cvt_f32x2_to_half2(f10 * prescale, f11 * prescale) + h6 = cvt_f32x2_to_half2(f12 * prescale, f13 * prescale) + h7 = cvt_f32x2_to_half2(f14 * prescale, f15 * prescale) + + # From here, use the same half2 pipeline as process_nvfp4_block_half + block_max_h2 = half2_max_abs_8(h0, h1, h2, h3, h4, h5, h6, h7) + block_max = hmax_reduce_to_f32(block_max_h2) + + # E4M3 scale factor computation + fp4_max_rcp = rcp_approx_ftz(Float32(6.0)) + scale_float = global_scale * (block_max * fp4_max_rcp) + scale_fp8_u32 = cvt_f32_to_e4m3(scale_float) + scale_fp8 = Uint8(scale_fp8_u32 & Uint32(0xFF)) + + # output_scale = rcp(float(E4M3(scale)) * rcp(global_scale)), matching CUDA + output_scale = nvfp4_compute_output_scale(scale_fp8_u32, global_scale) + + # Convert pre-scaled half2 values to E2M1 and pack + packed64 = half2x8_to_e2m1x16_packed(h0, h1, h2, h3, h4, h5, h6, h7, output_scale) + + return scale_fp8, packed64 + + __all__ = [ + # NVFP4 Constants + "NVFP4_SF_VEC_SIZE", # MXFP8 Constants "SF_VEC_SIZE", "INV_FLOAT8_E4M3_MAX", @@ -982,6 +1401,8 @@ def bfloat2x16_to_e2m1x32_packed( "ue8m0_to_inv_scale_fast", "reduce_max_4threads", "compute_sf_index_swizzled_128x4_gpu", + "compute_sf_index_swizzled_8x4_gpu", + "compute_sf_index_linear_gpu", # Low-level intrinsics (MXFP4 - E2M1 conversion) "half2_to_float2_scaled", "bfloat2_to_float2_scaled", @@ -999,4 +1420,13 @@ def bfloat2x16_to_e2m1x32_packed( "ld_32_elements", "half2x16_to_e2m1x32_packed", "bfloat2x16_to_e2m1x32_packed", + # High-level helper functions (NVFP4) + "half2x8_to_e2m1x16_packed", + "bfloat2x8_to_e2m1x16_packed", + "process_nvfp4_block_half", + "process_nvfp4_block_bfloat", + # High-level helper functions (NVFP4 - FP8 input) + "fp8x16_to_e2m1x16_packed", + "fp8_max_abs_16", + "process_nvfp4_block_fp8", ] diff --git a/tests/utils/test_fp4_quantize.py b/tests/utils/test_fp4_quantize.py index b2343dd9d2..292cacbd44 100644 --- a/tests/utils/test_fp4_quantize.py +++ b/tests/utils/test_fp4_quantize.py @@ -10,10 +10,12 @@ fp4_quantize, mxfp4_quantize, mxfp4_dequantize, + nvfp4_quantize, nvfp4_batched_quantize, scaled_fp4_grouped_quantize, silu_and_mul_scaled_nvfp4_experts_quantize, silu_and_mul, + SfLayout, ) from flashinfer.utils import ( is_sm100a_supported, @@ -48,6 +50,7 @@ def _is_fp4_supported(device: torch.device) -> bool: FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max BLOCK_SIZE = 16 +FP4_BACKENDS = ["cuda", "cute-dsl"] def swizzle_sf( @@ -111,6 +114,7 @@ def unswizzle_sf( return sf_unswizzle_sliced.contiguous() +@pytest.mark.parametrize("backend", FP4_BACKENDS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @@ -119,6 +123,7 @@ def unswizzle_sf( @pytest.mark.parametrize("is_swizzled", [False, True]) @torch.inference_mode() def test_fp4_quantization( + backend: str, dtype: torch.dtype, shape: tuple[int, int], seed: int, @@ -128,6 +133,9 @@ def test_fp4_quantization( ) -> None: if not _is_fp4_supported(torch.device(device)): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if backend == "cute-dsl": + if not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") torch.set_default_device(device) torch.manual_seed(seed) m, n = shape @@ -140,7 +148,7 @@ def test_fp4_quantization( global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax out_ref, scale_ref = ref_fp4_quant(x, global_scale, sf_vec_size, sf_use_ue8m0) out, out_scale = fp4_quantize( - x, global_scale, sf_vec_size, sf_use_ue8m0, is_swizzled + x, global_scale, sf_vec_size, sf_use_ue8m0, is_swizzled, backend=backend ) assert n % sf_vec_size == 0, f"cols needs to be {sf_vec_size} divisible" if sf_use_ue8m0: @@ -158,12 +166,14 @@ def test_fp4_quantization( torch.testing.assert_close(scale_ans, scale_ref, rtol=1e-1, atol=1e-1) +@pytest.mark.parametrize("backend", FP4_BACKENDS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_scale_swizzling( + backend: str, dtype: torch.dtype, shape: tuple[int, int], seed: int, @@ -171,6 +181,8 @@ def test_scale_swizzling( ) -> None: if not _is_fp4_supported(torch.device("cuda")): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if backend == "cute-dsl" and not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") torch.set_default_device(device) torch.manual_seed(seed) m, n = shape @@ -178,8 +190,12 @@ def test_scale_swizzling( tensor_amax = torch.abs(x).max().to(torch.float32) global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax - _, unswizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, False) - _, swizzled_scale = fp4_quantize(x, global_scale, BLOCK_SIZE, False, True) + _, unswizzled_scale = fp4_quantize( + x, global_scale, BLOCK_SIZE, False, False, backend=backend + ) + _, swizzled_scale = fp4_quantize( + x, global_scale, BLOCK_SIZE, False, True, backend=backend + ) assert n % BLOCK_SIZE == 0, f"cols needs to be {BLOCK_SIZE} divisible" recovered_unswizzled_scale = unswizzle_sf( swizzle_sf(unswizzled_scale, m, n), @@ -242,12 +258,14 @@ def test_block_scale_interleave( assert_equal(swizzled_sf.reshape(expected_shape), ref_swizzled_sf) +@pytest.mark.parametrize("backend", FP4_BACKENDS) @pytest.mark.parametrize("shape", SHAPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("sf_use_ue8m0", [True, False]) @torch.inference_mode() def test_e2m1_dequantization( + backend: str, shape: tuple[int, int], seed: int, device: str, @@ -256,6 +274,8 @@ def test_e2m1_dequantization( """Test roundtrip: fp4_quantize -> e2m1_and_ufp8sf_scale_to_float.""" if not _is_fp4_supported(torch.device("cuda")): pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if backend == "cute-dsl" and not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") torch.set_default_device(device) torch.manual_seed(seed) @@ -273,7 +293,12 @@ def test_e2m1_dequantization( # Step 1: Quantize with fp4_quantize quantized_tensor, scale_factors = fp4_quantize( - x, global_scale, block_size, sf_use_ue8m0, is_sf_swizzled_layout + x, + global_scale, + block_size, + sf_use_ue8m0, + is_sf_swizzled_layout, + backend=backend, ) # Step 2: Dequantize with e2m1_and_ufp8sf_scale_to_float @@ -453,18 +478,393 @@ def test_mxfp4_quantize_backend_parity( f"Scale factors should match >95%, got {scale_match_pct:.1f}%" ) - # Both should roundtrip to similar values - # Note: FP4 (E2M1) has coarse quantization steps (0.25-0.5 between adjacent values), - # so we allow atol=0.5 (one quantization step) for edge-case rounding differences. torch.testing.assert_close( dq_cuda_f32, dq_cute_f32, - rtol=0.2, - atol=0.5, # Allow one FP4 quantization step difference + rtol=0, + atol=0, msg=error_msg, ) +# ============================================================================= +# NVFP4 Quantization Tests (Both Backends) +# ============================================================================= + +NVFP4_SHAPES = [(128, 64), (256, 128), (512, 256), (128, 1024), (1024, 2048)] +NVFP4_BACKENDS = ["cuda", "cute-dsl"] +NVFP4_SF_LAYOUTS = [SfLayout.layout_128x4, SfLayout.layout_8x4, SfLayout.layout_linear] +# Roundtrip test only for layouts the dequantizer supports (128x4 and linear) +NVFP4_ROUNDTRIP_SF_LAYOUTS = [SfLayout.layout_128x4, SfLayout.layout_linear] + + +@pytest.mark.parametrize("backend", NVFP4_BACKENDS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", NVFP4_SHAPES) +@pytest.mark.parametrize("sf_layout", NVFP4_ROUNDTRIP_SF_LAYOUTS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_nvfp4_quantize_roundtrip( + backend: str, + dtype: torch.dtype, + shape: tuple[int, int], + sf_layout: SfLayout, + device: str, +) -> None: + """Test NVFP4 quantization roundtrip for both backends and layouts.""" + if not _is_fp4_supported(torch.device(device)): + pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if backend == "cute-dsl" and not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") + + torch.set_default_device(device) + torch.manual_seed(42) + + m, n = shape + x = torch.randn((m, n), dtype=dtype) + + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + quant_out, scale_out = nvfp4_quantize( + x, global_scale, sfLayout=sf_layout, backend=backend + ) + + # Basic shape checks + assert quant_out.shape == (m, n // 2), ( + f"Expected shape ({m}, {n // 2}), got {quant_out.shape}" + ) + assert quant_out.dtype == torch.uint8, f"Expected uint8, got {quant_out.dtype}" + assert scale_out.dtype == torch.uint8, f"Expected uint8, got {scale_out.dtype}" + + is_swizzled = sf_layout != SfLayout.layout_linear + + # Dequantize round-trip + dq_out = e2m1_and_ufp8sf_scale_to_float( + quant_out, + scale_out, + 1 / global_scale, + sf_vec_size=16, + ufp8_type=1, + is_sf_swizzled_layout=is_swizzled, + ) + dq_out = dq_out.to(device) + + # Verify no NaN/Inf + assert not torch.isnan(dq_out).any(), "Dequantized tensor contains NaN" + assert not torch.isinf(dq_out).any(), "Dequantized tensor contains Inf" + + # Verify roundtrip is reasonably accurate + torch.testing.assert_close( + dq_out.to(torch.float32), + x.to(torch.float32), + rtol=0.3, + atol=0.5, + msg=f"{backend} {sf_layout.name} NVFP4 quantize -> dequantize roundtrip failed", + ) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", NVFP4_SHAPES) +@pytest.mark.parametrize("sf_layout", NVFP4_SF_LAYOUTS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_nvfp4_quantize_backend_parity( + dtype: torch.dtype, + shape: tuple[int, int], + sf_layout: SfLayout, + device: str, +) -> None: + """Test that CUDA and CuTe-DSL backends produce matching results for NVFP4.""" + if not _is_fp4_supported(torch.device(device)): + pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") + + torch.set_default_device(device) + torch.manual_seed(42) + + m, n = shape + x = torch.randn((m, n), dtype=dtype) + + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + # Get results from both backends + quant_cuda, scale_cuda = nvfp4_quantize( + x, global_scale, sfLayout=sf_layout, backend="cuda" + ) + quant_cute, scale_cute = nvfp4_quantize( + x, global_scale, sfLayout=sf_layout, backend="cute-dsl" + ) + + # Shape should match + assert quant_cuda.shape == quant_cute.shape, ( + f"Quantized output shape mismatch for {sf_layout.name}" + ) + assert scale_cuda.shape == scale_cute.shape, ( + f"Scale output shape mismatch for {sf_layout.name}" + ) + + # Quantized FP4 values should match exactly (layout-independent) + quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100 + assert quant_match_pct > 95.0, ( + f"Quantized values should match >95%, got {quant_match_pct:.1f}% " + f"(layout={sf_layout.name})" + ) + + # Scale factors should match exactly (layout-specific indexing) + scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100 + assert scale_match_pct > 95.0, ( + f"Scale factors should match >95%, got {scale_match_pct:.1f}% " + f"(layout={sf_layout.name})" + ) + + # For layouts that support dequantization, also compare dequantized values + is_swizzled = sf_layout != SfLayout.layout_linear + can_dequantize = sf_layout in (SfLayout.layout_128x4, SfLayout.layout_linear) + + if can_dequantize: + dq_cuda = ( + e2m1_and_ufp8sf_scale_to_float( + quant_cuda, + scale_cuda, + 1 / global_scale, + sf_vec_size=16, + ufp8_type=1, + is_sf_swizzled_layout=is_swizzled, + ) + .to(device) + .to(torch.float32) + ) + dq_cute = ( + e2m1_and_ufp8sf_scale_to_float( + quant_cute, + scale_cute, + 1 / global_scale, + sf_vec_size=16, + ufp8_type=1, + is_sf_swizzled_layout=is_swizzled, + ) + .to(device) + .to(torch.float32) + ) + + abs_diff = (dq_cuda - dq_cute).abs() + rel_diff = abs_diff / (dq_cuda.abs() + 1e-8) + + error_msg = ( + f"CUDA and CuTe-DSL backends differ after dequantization:\n" + f" Shape: {shape}, dtype: {dtype}, layout: {sf_layout.name}\n" + f" Quantized match: {quant_match_pct:.1f}%, Scale match: {scale_match_pct:.1f}%\n" + f" Abs diff - max: {abs_diff.max().item():.6f}, mean: {abs_diff.mean().item():.6f}\n" + f" Rel diff - max: {rel_diff.max().item():.6f}, mean: {rel_diff.mean().item():.6f}\n" + f" CUDA dq range: [{dq_cuda.min().item():.4f}, {dq_cuda.max().item():.4f}]\n" + f" CuTe dq range: [{dq_cute.min().item():.4f}, {dq_cute.max().item():.4f}]" + ) + + torch.testing.assert_close( + dq_cuda, + dq_cute, + rtol=0, + atol=0, + msg=error_msg, + ) + + +NVFP4_FP8_SHAPES = [(128, 64), (256, 128), (512, 256), (128, 1024)] + + +@pytest.mark.parametrize("shape", NVFP4_FP8_SHAPES) +@pytest.mark.parametrize("sf_layout", NVFP4_SF_LAYOUTS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_nvfp4_quantize_fp8_input_cute_dsl( + shape: tuple[int, int], + sf_layout: SfLayout, + device: str, +) -> None: + """Test CuTe-DSL NVFP4 quantization with FP8 E4M3 input.""" + if not _is_fp4_supported(torch.device(device)): + pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") + + torch.set_default_device(device) + torch.manual_seed(42) + + m, n = shape + x_fp32 = torch.randn((m, n), dtype=torch.float32) + x_fp8 = x_fp32.to(torch.float8_e4m3fn) + + tensor_amax = torch.abs(x_fp8.float()).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + quant_out, scale_out = nvfp4_quantize( + x_fp8, global_scale, sfLayout=sf_layout, backend="cute-dsl" + ) + + assert quant_out.shape == (m, n // 2), ( + f"Expected shape ({m}, {n // 2}), got {quant_out.shape}" + ) + assert quant_out.dtype == torch.uint8, f"Expected uint8, got {quant_out.dtype}" + assert scale_out.dtype == torch.uint8, f"Expected uint8, got {scale_out.dtype}" + + assert not torch.all(quant_out == 0), "All quantized values are zero" + assert not torch.all(scale_out == 0), "All scale factors are zero" + + is_swizzled = sf_layout != SfLayout.layout_linear + can_dequantize = sf_layout in (SfLayout.layout_128x4, SfLayout.layout_linear) + + if can_dequantize: + dq_out = ( + e2m1_and_ufp8sf_scale_to_float( + quant_out, + scale_out, + 1 / global_scale, + sf_vec_size=16, + ufp8_type=1, + is_sf_swizzled_layout=is_swizzled, + ) + .to(device) + .to(torch.float32) + ) + assert not torch.isnan(dq_out).any(), "Dequantized tensor contains NaN" + assert not torch.isinf(dq_out).any(), "Dequantized tensor contains Inf" + + # The FP8→FP4 path (matching CUDA cvt_warp_fp8_to_fp4) pre-scales input + # by 6/global_scale before quantization. Standard dequant (e2m1 * sf / gs) + # therefore reconstructs x_fp8 * 6/gs, not x_fp8. + expected = x_fp8.float() * (6.0 / global_scale.item()) + torch.testing.assert_close( + dq_out, + expected, + rtol=0.3, + atol=0.5, + msg=f"CuTe-DSL FP8 input NVFP4 roundtrip failed (layout={sf_layout.name})", + ) + + +@pytest.mark.parametrize("shape", NVFP4_FP8_SHAPES) +@pytest.mark.parametrize("sf_layout", NVFP4_SF_LAYOUTS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_nvfp4_quantize_fp8_backend_parity( + shape: tuple[int, int], + sf_layout: SfLayout, + device: str, +) -> None: + """Test CUDA and CuTe-DSL backends produce matching results for FP8 input.""" + if not _is_fp4_supported(torch.device(device)): + pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") + + torch.set_default_device(device) + torch.manual_seed(42) + + m, n = shape + x_fp32 = torch.randn((m, n), dtype=torch.float32) + x_fp8 = x_fp32.to(torch.float8_e4m3fn) + + tensor_amax = torch.abs(x_fp8.float()).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + quant_cuda, scale_cuda = nvfp4_quantize( + x_fp8, global_scale, sfLayout=sf_layout, backend="cuda" + ) + quant_cute, scale_cute = nvfp4_quantize( + x_fp8, global_scale, sfLayout=sf_layout, backend="cute-dsl" + ) + + assert quant_cuda.shape == quant_cute.shape, ( + f"Quantized output shape mismatch for FP8 input, {sf_layout.name}" + ) + assert scale_cuda.shape == scale_cute.shape, ( + f"Scale output shape mismatch for FP8 input, {sf_layout.name}" + ) + + quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100 + assert quant_match_pct > 95.0, ( + f"FP8 quantized values should match >95%, got {quant_match_pct:.1f}% " + f"(layout={sf_layout.name})" + ) + + scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100 + assert scale_match_pct > 95.0, ( + f"FP8 scale factors should match >95%, got {scale_match_pct:.1f}% " + f"(layout={sf_layout.name})" + ) + + +# ============================================================================= +# NVFP4 TMA Kernel Tests +# ============================================================================= + +NVFP4_TMA_SHAPES = [ + # Shapes that trigger TMA: log2(M)+log2(K) >= 25 and K % 512 == 0 + (4096, 8192), # log2sum=25, smallest TMA case + (8192, 4096), # log2sum=25 + (16384, 2048), # log2sum=25 + (32768, 1024), # log2sum=25 + (16384, 4096), # log2sum=26 + (8192, 8192), # log2sum=26 +] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("shape", NVFP4_TMA_SHAPES) +@pytest.mark.parametrize("sf_layout", NVFP4_SF_LAYOUTS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_nvfp4_quantize_tma_backend_parity( + dtype: torch.dtype, + shape: tuple[int, int], + sf_layout: SfLayout, + device: str, +) -> None: + """Test that TMA-based CuTe-DSL kernel matches the CUDA backend for large problems.""" + if not _is_fp4_supported(torch.device(device)): + pytest.skip("Nvfp4 Requires compute capability >= 10 and CUDA >= 12.8") + if not _is_cute_dsl_available(): + pytest.skip("CuTe-DSL not available") + + torch.set_default_device(device) + torch.manual_seed(42) + + m, n = shape + x = torch.randn((m, n), dtype=dtype) + + tensor_amax = torch.abs(x).max().to(torch.float32) + global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax + + quant_cuda, scale_cuda = nvfp4_quantize( + x, global_scale, sfLayout=sf_layout, backend="cuda" + ) + quant_cute, scale_cute = nvfp4_quantize( + x, global_scale, sfLayout=sf_layout, backend="cute-dsl" + ) + + assert quant_cuda.shape == quant_cute.shape, ( + f"TMA quantized output shape mismatch for {sf_layout.name}" + ) + assert scale_cuda.shape == scale_cute.shape, ( + f"TMA scale output shape mismatch for {sf_layout.name}" + ) + + quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100 + assert quant_match_pct > 95.0, ( + f"TMA quantized values should match >95%, got {quant_match_pct:.1f}% " + f"(shape={shape}, layout={sf_layout.name})" + ) + + scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100 + assert scale_match_pct > 95.0, ( + f"TMA scale factors should match >95%, got {scale_match_pct:.1f}% " + f"(shape={shape}, layout={sf_layout.name})" + ) + + @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("batch_shape", BATCH_SHAPES) @pytest.mark.parametrize("seed", SEEDS)