Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,9 @@ def dtype_str_to_torch_dtype(dtype_str):
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cuda"],
"10.3": ["cuda"],
"12.0": ["cuda"],
"10.0": ["cuda", "cute-dsl"],
"10.3": ["cuda", "cute-dsl"],
"12.0": ["cuda", "cute-dsl"],
},
"nvfp4_batched_quantize": {
"7.5": [],
Expand Down
20 changes: 9 additions & 11 deletions benchmarks/routines/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,17 +628,15 @@ def testNvfp4Quantize(args):
print(f"[VVERBOSE] {enable_pdl = }")

def run_backend(backend, input_tensor, global_sf_tensor):
    """Run one nvfp4 quantization call with the selected backend.

    The backend string is forwarded directly to flashinfer.nvfp4_quantize,
    which performs its own backend validation/dispatch, so no local
    if/else over backends is needed here.

    Args:
        backend: Backend identifier (e.g. "cuda" or "cute-dsl").
        input_tensor: Tensor to quantize.
        global_sf_tensor: Global scale-factor tensor.

    Returns:
        Whatever flashinfer.nvfp4_quantize returns for these inputs.
    """
    # sf_layout / do_shuffle / sf_vec_size / enable_pdl are closed over
    # from the enclosing benchmark routine.
    return flashinfer.nvfp4_quantize(
        input_tensor,
        global_sf_tensor,
        sfLayout=sf_layout,
        do_shuffle=do_shuffle,
        sf_vec_size=sf_vec_size,
        enable_pdl=enable_pdl,
        backend=backend,
    )

# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
Expand Down
226 changes: 225 additions & 1 deletion flashinfer/cute_dsl/fp4_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,31 @@ def ld_global_v4_u32(
return Uint32(v0), Uint32(v1), Uint32(v2), Uint32(v3)


@dsl_user_op
def ld_v4_u32(
    base_ptr: Int64, *, loc=None, ip=None
) -> Tuple[Uint32, Uint32, Uint32, Uint32]:
    """Load 128 bits (4 x uint32) using generic addressing (works for GMEM and SMEM).

    Args:
        base_ptr: 64-bit generic address of the first byte to load.
            NOTE(review): PTX vector loads require the address to be aligned
            to the full access width (16 bytes here) — confirm callers
            guarantee this. For SMEM, this must be a true generic pointer,
            not a raw SMEM offset (see the warning on get_ptr_as_int64).
        loc: Optional MLIR source location, threaded to the IR builder.
        ip: Optional MLIR insertion point, threaded to the IR builder.

    Returns:
        The four consecutive 32-bit words starting at base_ptr, lowest
        address first.
    """
    # Constraint "=r,=r,=r,=r,l": four 32-bit register outputs ($0-$3)
    # plus one 64-bit ("l") address input ($4). The asm result comes back
    # as an LLVM struct of four i32 fields.
    result = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]),
        [Int64(base_ptr).ir_value(loc=loc, ip=ip)],
        "ld.v4.u32 {$0, $1, $2, $3}, [$4];",
        "=r,=r,=r,=r,l",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )

    # Unpack the struct fields into individual 32-bit scalars.
    v0 = llvm.extractvalue(T.i32(), result, [0], loc=loc, ip=ip)
    v1 = llvm.extractvalue(T.i32(), result, [1], loc=loc, ip=ip)
    v2 = llvm.extractvalue(T.i32(), result, [2], loc=loc, ip=ip)
    v3 = llvm.extractvalue(T.i32(), result, [3], loc=loc, ip=ip)

    return Uint32(v0), Uint32(v1), Uint32(v2), Uint32(v3)


@dsl_user_op
def st_global_u64(base_ptr: Int64, value: Uint64, *, loc=None, ip=None):
"""Store 64 bits to global memory."""
Expand All @@ -173,12 +198,87 @@ def st_global_u64(base_ptr: Int64, value: Uint64, *, loc=None, ip=None):

@dsl_user_op
def get_ptr_as_int64(tensor: cute.Tensor, offset: Int32, *, loc=None, ip=None) -> Int64:
    """Return the address of tensor[offset] as a 64-bit integer.

    WARNING: ptrtoint drops the pointer's address space. For an SMEM tensor
    the result is therefore a raw SMEM offset, which generic-addressing
    loads (ld.v4.u32) cannot consume. Restrict use to global-memory tensors
    or to instructions with an explicit address space (ld.global.*).
    """
    element = tensor.iterator + Int32(offset)
    as_i64 = llvm.ptrtoint(T.i64(), element.llvm_ptr, loc=loc, ip=ip)
    return Int64(as_i64)


@dsl_user_op
def get_smem_ptr_as_int32(
    tensor: cute.Tensor, offset: Int32, *, loc=None, ip=None
) -> Int32:
    """Return the shared-memory byte address of tensor[offset] as Int32.

    Unlike get_ptr_as_int64 (which goes through ptrtoint), Pointer.toint()
    preserves the SMEM address space (addrspace 3), so the returned 32-bit
    address can be fed to ld.shared.* instructions.
    """
    element = tensor.iterator + Int32(offset)
    smem_addr = element.toint(loc=loc, ip=ip)
    return smem_addr


@dsl_user_op
def ld_shared_v4_u32(
    smem_addr: Int32, *, loc=None, ip=None
) -> Tuple[Uint32, Uint32, Uint32, Uint32]:
    """Load 128 bits (4 x uint32) from shared memory via ld.shared.v4.u32.

    Args:
        smem_addr: 32-bit shared memory address (from get_smem_ptr_as_int32).
            NOTE(review): the 16-byte vector access presumably requires a
            16-byte-aligned address — confirm callers guarantee this.
        loc: Optional MLIR source location, threaded to the IR builder.
        ip: Optional MLIR insertion point, threaded to the IR builder.

    Returns:
        4 Uint32 values (16 bytes total, e.g. 8 packed fp16 elements).
    """
    # Constraint "=r,=r,=r,=r,r": four 32-bit register outputs ($0-$3) and
    # one 32-bit ("r") SMEM address input ($4); result is a struct of 4 x i32.
    result = llvm.inline_asm(
        llvm.StructType.get_literal([T.i32(), T.i32(), T.i32(), T.i32()]),
        [Int32(smem_addr).ir_value(loc=loc, ip=ip)],
        "ld.shared.v4.u32 {$0, $1, $2, $3}, [$4];",
        "=r,=r,=r,=r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )
    # Unpack the struct fields into individual 32-bit scalars.
    v0 = llvm.extractvalue(T.i32(), result, [0], loc=loc, ip=ip)
    v1 = llvm.extractvalue(T.i32(), result, [1], loc=loc, ip=ip)
    v2 = llvm.extractvalue(T.i32(), result, [2], loc=loc, ip=ip)
    v3 = llvm.extractvalue(T.i32(), result, [3], loc=loc, ip=ip)
    return Uint32(v0), Uint32(v1), Uint32(v2), Uint32(v3)


@dsl_user_op
def pack_16bit_to_u32(lo, hi, *, loc=None, ip=None) -> Uint32:
    """Pack two 16-bit scalar values (fp16 or bf16) into one Uint32 (half2/bfloat2).

    Uses PTX mov.b32 to bitwise-pack two 16-bit register values into a single
    32-bit register, suitable for half2/bfloat2 SIMD operations.

    Args:
        lo: 16-bit DSL scalar placed in the low half of the result.
        hi: 16-bit DSL scalar placed in the high half of the result.
        loc: Optional MLIR source location, threaded to the IR builder.
        ip: Optional MLIR insertion point, threaded to the IR builder.
    """
    lo_ir = lo.ir_value(loc=loc, ip=ip)
    hi_ir = hi.ir_value(loc=loc, ip=ip)
    # Reinterpret the 16-bit values as i16 so they satisfy the "h" (16-bit
    # register) asm constraint below; this is a bit-level cast, not a
    # numeric conversion.
    lo_i16 = llvm.bitcast(T.i16(), lo_ir, loc=loc, ip=ip)
    hi_i16 = llvm.bitcast(T.i16(), hi_ir, loc=loc, ip=ip)
    # "=r,h,h": one 32-bit output, two 16-bit inputs. mov.b32 concatenates
    # {$1, $2} as the {low, high} halves of $0.
    return Uint32(
        llvm.inline_asm(
            T.i32(),
            [lo_i16, hi_i16],
            "mov.b32 $0, {$1, $2};",
            "=r,h,h",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
            loc=loc,
            ip=ip,
        )
    )


# =============================================================================
# PTX Intrinsics - Math Operations
# =============================================================================
Expand Down Expand Up @@ -537,6 +637,81 @@ def cvt_f32_to_e4m3(a: Float32, *, loc=None, ip=None) -> Uint32:
)


@dsl_user_op
def cvt_e4m3x4_to_f32x4(
    packed: Uint32, *, loc=None, ip=None
) -> tuple[Float32, Float32, Float32, Float32]:
    """Convert 4 packed E4M3 bytes (in a uint32) to 4 float32 values.

    Uses e4m3x2 -> f16x2 -> f32 conversion path (SM89+/PTX ISA 7.8+).
    Input: uint32 containing bytes [b0, b1, b2, b3] (low to high).
    Output: (f0, f1, f2, f3) as Float32.

    Args:
        packed: Four E4M3 bytes packed into one 32-bit word.
        loc: Optional MLIR source location, threaded to the IR builder.
        ip: Optional MLIR insertion point, threaded to the IR builder.
    """
    # Constraint "=f,=f,=f,=f,r": four f32 outputs ($0-$3), one u32 input
    # ($4). The PTX block: split $4 into two 16-bit byte-pairs, widen each
    # pair to f16x2 via cvt.rn.f16x2.e4m3x2, split the f16x2 words into
    # individual halves, then widen each half to f32 (lossless for E4M3).
    result = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32(), T.f32(), T.f32()]),
        [Uint32(packed).ir_value(loc=loc, ip=ip)],
        """
        {
            .reg .b16 pair_lo, pair_hi;
            .reg .b32 h2_lo, h2_hi;
            .reg .b16 h0, h1, h2, h3;
            mov.b32 {pair_lo, pair_hi}, $4;
            cvt.rn.f16x2.e4m3x2 h2_lo, pair_lo;
            cvt.rn.f16x2.e4m3x2 h2_hi, pair_hi;
            mov.b32 {h0, h1}, h2_lo;
            mov.b32 {h2, h3}, h2_hi;
            cvt.f32.f16 $0, h0;
            cvt.f32.f16 $1, h1;
            cvt.f32.f16 $2, h2;
            cvt.f32.f16 $3, h3;
        }
        """,
        "=f,=f,=f,=f,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
        loc=loc,
        ip=ip,
    )

    # Unpack the struct fields into individual f32 scalars.
    f0 = llvm.extractvalue(T.f32(), result, [0], loc=loc, ip=ip)
    f1 = llvm.extractvalue(T.f32(), result, [1], loc=loc, ip=ip)
    f2 = llvm.extractvalue(T.f32(), result, [2], loc=loc, ip=ip)
    f3 = llvm.extractvalue(T.f32(), result, [3], loc=loc, ip=ip)

    return Float32(f0), Float32(f1), Float32(f2), Float32(f3)


@dsl_user_op
def cvt_f32x2_to_half2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32:
    """Pack two float32 values into a half2 (uint32 containing two fp16 values).

    Uses cvt.rn.f16.f32 for each value, then packs into a single uint32
    with `a` in the low 16 bits and `b` in the high 16 bits.
    Matches __float22half2_rn() behavior in CUDA.

    Args:
        a: First float32 value (low half of the result).
        b: Second float32 value (high half of the result).
        loc: Optional MLIR source location, threaded to the IR builder.
        ip: Optional MLIR insertion point, threaded to the IR builder.
    """
    # Fix: forward loc/ip to inline_asm, matching every other intrinsic in
    # this file, so MLIR diagnostics and insertion points track the caller.
    # Constraint "=r,f,f": one 32-bit output, two f32 inputs.
    return Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
            ],
            """
            {
                .reg .b16 h0, h1;
                cvt.rn.f16.f32 h0, $1;
                cvt.rn.f16.f32 h1, $2;
                mov.b32 $0, {h0, h1};
            }
            """,
            "=r,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
            loc=loc,
            ip=ip,
        )
    )


@dsl_user_op
def fp8_e4m3_to_f32_and_rcp(fp8_val: Uint32, *, loc=None, ip=None) -> Float32:
"""Convert FP8 E4M3 to float32 AND compute reciprocal."""
Expand Down Expand Up @@ -573,6 +748,55 @@ def fp8_e4m3_to_f32_and_rcp(fp8_val: Uint32, *, loc=None, ip=None) -> Float32:
)


@dsl_user_op
def nvfp4_compute_output_scale(
    fp8_val: Uint32, global_scale: Float32, *, loc=None, ip=None
) -> Float32:
    """Compute NVFP4 output_scale matching the CUDA kernel exactly.

    Converts E4M3 scale factor to float via hardware f16x2 path, then computes
    rcp(float_scale * rcp(global_scale)). Returns 0 when scale is zero.

    This matches quantization_utils.cuh:
        SFValue = static_cast<float>(tmp);
        outputScale = rcp_approx(SFValue * rcp_approx(SFScaleVal));

    Args:
        fp8_val: E4M3 scale-factor byte in the low bits of a uint32.
        global_scale: Global scale value (SFScaleVal).
        loc: Optional MLIR source location, threaded to the IR builder.
        ip: Optional MLIR insertion point, threaded to the IR builder.
    """
    # Fix: forward loc/ip to inline_asm, matching every other intrinsic in
    # this file, so MLIR diagnostics and insertion points track the caller.
    # PTX block: widen the E4M3 byte to f32 via cvt.rn.f16x2.e4m3x2 (low
    # lane only), compute rcp(scale * rcp(global_scale)) with approximate
    # reciprocals, then select 0 when the decoded scale is exactly 0
    # (setp.eq.f32 / selp.f32 guard against rcp(0) = inf propagating).
    return Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Uint32(fp8_val).ir_value(loc=loc, ip=ip),
                Float32(global_scale).ir_value(loc=loc, ip=ip),
            ],
            """
            {
                .reg .pred p_zero;
                .reg .b16 fp8_pair;
                .reg .b32 h2_32;
                .reg .b16 h_lo, h_hi;
                .reg .f32 scale_f32, rcp_gs, product, result;

                cvt.u16.u32 fp8_pair, $1;
                cvt.rn.f16x2.e4m3x2 h2_32, fp8_pair;
                mov.b32 {h_lo, h_hi}, h2_32;
                cvt.f32.f16 scale_f32, h_lo;

                rcp.approx.ftz.f32 rcp_gs, $2;
                mul.f32 product, scale_f32, rcp_gs;
                rcp.approx.ftz.f32 result, product;

                setp.eq.f32 p_zero, scale_f32, 0f00000000;
                selp.f32 $0, 0f00000000, result, p_zero;
            }
            """,
            "=f,r,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
            loc=loc,
            ip=ip,
        )
    )


# =============================================================================
# UE8M0 Intrinsics (for MXFP4)
# =============================================================================
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
if is_cute_dsl_available():
from .kernels.mxfp8_quantize import mxfp8_quantize_cute_dsl
from .kernels.mxfp4_quantize import mxfp4_quantize_cute_dsl
from .kernels.nvfp4_quantize import nvfp4_quantize_cute_dsl

_cute_dsl_available = True
except ImportError:
Expand Down Expand Up @@ -83,4 +84,5 @@
__all__ += [
"mxfp8_quantize_cute_dsl",
"mxfp4_quantize_cute_dsl",
"nvfp4_quantize_cute_dsl",
]
Loading
Loading