From 4fe1a6f351d2b88a05a130fd73a656059a9b057b Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Fri, 13 Mar 2026 21:24:44 -0700 Subject: [PATCH 01/11] add-cudnn-override-shape-support --- flashinfer/gemm/__init__.py | 12 + flashinfer/gemm/gemm_base.py | 894 +++++++++++++++++++++++++++++++++++ 2 files changed, 906 insertions(+) diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index d0fd71476f..2d99f12e6a 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -18,6 +18,18 @@ from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise from .gemm_base import fp8_blockscale_gemm_sm90 as fp8_blockscale_gemm_sm90 +from .gemm_base import ( + is_cudnn_override_shape_available as is_cudnn_override_shape_available, + CUDNN_MIN_VERSION_OVERRIDE_SHAPE as CUDNN_MIN_VERSION_OVERRIDE_SHAPE, + build_cudnn_gemm_bf16_graph_override_shape as build_cudnn_gemm_bf16_graph_override_shape, + execute_cudnn_gemm_bf16_graph_override_shape as execute_cudnn_gemm_bf16_graph_override_shape, + build_cudnn_fp4_gemm_graph_override_shape as build_cudnn_fp4_gemm_graph_override_shape, + execute_cudnn_fp4_gemm_graph_override_shape as execute_cudnn_fp4_gemm_graph_override_shape, + build_cudnn_mxfp8_gemm_graph_override_shape as build_cudnn_mxfp8_gemm_graph_override_shape, + execute_cudnn_mxfp8_gemm_graph_override_shape as execute_cudnn_mxfp8_gemm_graph_override_shape, + build_cudnn_gemm_with_per_tensor_q_graph_override_shape as build_cudnn_gemm_with_per_tensor_q_graph_override_shape, + execute_cudnn_gemm_with_per_tensor_q_graph_override_shape as execute_cudnn_gemm_with_per_tensor_q_graph_override_shape, +) from .routergemm import ( mm_M1_16_K7168_N128 as mm_M1_16_K7168_N128, diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 60bc5eb76f..243849d769 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1720,6 +1720,33 
# Minimum cuDNN backend version required for is_override_shape_enabled support.
CUDNN_MIN_VERSION_OVERRIDE_SHAPE = 92100


def _check_cudnn_override_shape_availability():
    """Raise if the installed cuDNN backend does not support is_override_shape_enabled."""
    _check_cudnn_availability()
    installed = cudnn.backend_version()
    if installed >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE:
        return
    raise RuntimeError(
        f"cuDNN override-shape GEMM requires backend version >= "
        f"{CUDNN_MIN_VERSION_OVERRIDE_SHAPE} (9.21.0), "
        f"found {installed}. "
        f"Please upgrade cuDNN: pip install --upgrade nvidia-cudnn-cu12 nvidia-cudnn-frontend"
    )


def is_cudnn_override_shape_available() -> bool:
    """Return True if the installed cuDNN backend supports is_override_shape_enabled."""
    if not CUDNN_AVAILABLE:
        return False
    try:
        # Any failure probing the backend is treated as "not available".
        supported = cudnn.backend_version() >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE
    except Exception:
        return False
    return supported


# ---------------------------------------------------------------------------
# override_shape shared constant
# ---------------------------------------------------------------------------
# Sentinel value used as "cache M" when building override-shape graphs.
# Any sufficiently large M will work; 8192 covers typical LLM inference shapes.
_OVERRIDE_SHAPE_CACHE_M = 8192

# ---------------------------------------------------------------------------
# FP4 GEMM with override_shape (dynamic M dimension)
# ---------------------------------------------------------------------------


@functools.cache
def build_cudnn_fp4_gemm_graph_override_shape(
    batch,
    n,
    k,
    a_descale_n_dim,
    a_descale_k_dim,
    b_descale_k_dim,
    b_descale_n_dim,
    ab_type,
    o_type,
    block_size,
    device,
    alpha_is_not_none,
    use_nvfp4,
    cache_m: int = _OVERRIDE_SHAPE_CACHE_M,
):
    """Build a cuDNN FP4 GEMM graph with override-shape support.

    The graph is compiled once using ``cache_m`` as the M dimension. Block
    scale dimensions are derived from ``cache_m`` at compile time; at
    execution time the real M (and corresponding scale dims) are passed via
    ``override_shapes`` / ``override_strides``.

    NOTE: this function is memoized with ``functools.cache``, so *every*
    argument is part of the cache key. For the "one plan for all M" property
    to hold, callers must only pass M-independent values — in particular
    ``a_descale_n_dim`` must be derived from ``cache_m``, not from the
    runtime M (see ``_get_cudnn_fp4_gemm_graph_override_shape``).
    ``a_descale_n_dim`` is otherwise unused inside this function: the
    compile-time scale M-dim is recomputed from ``cache_m`` below.
    """
    _check_cudnn_override_shape_availability()

    # NVFP4 uses E4M3 block scales; MXFP4 uses E8M0.
    scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0

    # Block-scale dims for the compile-time (cache) M.
    block_scale_dim_m, _, block_scale_dim_k = _calculate_block_scale_dims(
        cache_m, n, k, block_size
    )

    # cuDNN sees the logical (unpacked) FP4 element counts, not the packed
    # uint8 storage: k fp4 elements along K, not k/2 bytes.
    a_fp4_shape = [batch, cache_m, k]
    a_fp4_stride = [cache_m * k, k, 1]
    b_fp4_shape = [batch, k, n]
    b_fp4_stride = [k * n, 1, k]

    a_descale_shape = [batch, block_scale_dim_m, a_descale_k_dim]
    a_descale_stride = [block_scale_dim_m * a_descale_k_dim, a_descale_k_dim, 1]
    b_descale_shape = [batch, b_descale_k_dim, b_descale_n_dim]
    b_descale_stride = [b_descale_n_dim * b_descale_k_dim, 1, b_descale_k_dim]

    stream = torch.cuda.current_stream(device)
    graph = cudnn.pygraph(
        io_data_type=cudnn.data_type.FLOAT,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
        handle=_get_cudnn_handle(stream),
        is_override_shape_enabled=True,
    )

    a_cudnn_tensor = graph.tensor(
        name="a", dim=a_fp4_shape, stride=a_fp4_stride, data_type=ab_type
    )
    b_cudnn_tensor = graph.tensor(
        name="b", dim=b_fp4_shape, stride=b_fp4_stride, data_type=ab_type
    )
    block_descale_a_cudnn_tensor = graph.tensor(
        name="block_descale_a",
        dim=a_descale_shape,
        stride=a_descale_stride,
        data_type=scale_type,
        reordering_type=cudnn.tensor_reordering.F8_128x4,
    )
    block_descale_b_cudnn_tensor = graph.tensor(
        name="block_descale_b",
        dim=b_descale_shape,
        stride=b_descale_stride,
        data_type=scale_type,
        reordering_type=cudnn.tensor_reordering.F8_128x4,
    )

    # Dequantize A along K (row blocks) and B along K (column blocks).
    dequant_a_tensor = graph.block_scale_dequantize(
        a_cudnn_tensor,
        block_descale_a_cudnn_tensor,
        block_size=[1, block_size],
        name="dequant_a",
    )
    dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT)
    dequant_b_tensor = graph.block_scale_dequantize(
        b_cudnn_tensor,
        block_descale_b_cudnn_tensor,
        block_size=[block_size, 1],
        name="dequant_b",
    )
    dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT)
    c_tensor = graph.matmul(
        dequant_a_tensor,
        dequant_b_tensor,
        compute_data_type=cudnn.data_type.FLOAT,
        name="gemm",
    )
    c_tensor.set_data_type(cudnn.data_type.FLOAT)

    c_final_cudnn_tensor = c_tensor

    if alpha_is_not_none:
        # Optional global (per-tensor) output scale.
        global_scale_cudnn_tensor = graph.tensor(
            name="global_scale",
            dim=[1, 1, 1],
            stride=[1, 1, 1],
            data_type=cudnn.data_type.FLOAT,
        )
        c_final_cudnn_tensor = graph.mul(
            name="scale_mul",
            a=c_tensor,
            b=global_scale_cudnn_tensor,
            compute_data_type=cudnn.data_type.FLOAT,
        )
        global_scale_cudnn_tensor.set_uid(UIDs.ALPHA_UID.value)

    c_final_cudnn_tensor.set_name("c_final").set_output(True).set_data_type(o_type)

    a_cudnn_tensor.set_uid(UIDs.A_UID.value)
    b_cudnn_tensor.set_uid(UIDs.B_UID.value)
    block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value)
    block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value)
    c_final_cudnn_tensor.set_uid(UIDs.O_UID.value)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B])

    if alpha_is_not_none and not _is_cublas_fp4_available_in_cudnn():
        graph.deselect_engines(["eng0"])

    graph.check_support()
    graph.build_plans()

    return graph


def execute_cudnn_fp4_gemm_graph_override_shape(
    graph,
    a,
    b,
    a_descale,
    b_descale,
    alpha,
    c_final,
    workspace_buffer,
    tactic: int = 0,
):
    """Execute FP4 GEMM cuDNN graph with dynamic-shape overrides.

    ``a`` / ``b`` are packed-uint8 FP4 tensors (two FP4 values per byte along
    the K dimension). Shapes/strides handed to cuDNN are converted from
    uint8-storage units to FP4-element units.

    ``workspace_buffer`` may be ``None`` or undersized; it is (re)allocated
    on demand.
    """
    # K must be the packed (unit-stride) dimension of each operand.
    assert a.stride()[2] == 1 and b.stride()[1] == 1

    variant_pack = {
        UIDs.A_UID.value: a,
        UIDs.B_UID.value: b,
        UIDs.BLOCK_DESCALE_A_UID.value: a_descale,
        UIDs.BLOCK_DESCALE_B_UID.value: b_descale,
        UIDs.O_UID.value: c_final,
    }

    if alpha is not None:
        variant_pack[UIDs.ALPHA_UID.value] = alpha.view(torch.float)

    override_uids = [
        UIDs.A_UID.value,
        UIDs.B_UID.value,
        UIDs.BLOCK_DESCALE_A_UID.value,
        UIDs.BLOCK_DESCALE_B_UID.value,
        UIDs.O_UID.value,
    ]
    # uint8 -> fp4 geometry: the packed K dimension doubles in extent.
    override_shapes = [
        [a.shape[0], a.shape[1], a.shape[2] * 2],
        [b.shape[0], b.shape[1] * 2, b.shape[2]],
        list(a_descale.shape),
        list(b_descale.shape),
        list(c_final.shape),
    ]
    # uint8 -> fp4 strides: every stride of a *non-packed* dimension doubles
    # (each uint8 holds two fp4 elements); the packed dimension keeps stride 1.
    # BUGFIX: the batch stride must be doubled as well — the original code
    # only doubled the M/N strides, which mis-addressed batches > 1.
    override_strides = [
        [a.stride()[0] * 2, a.stride()[1] * 2, a.stride()[2]],
        [b.stride()[0] * 2, b.stride()[1], b.stride()[2] * 2],
        list(a_descale.stride()),
        list(b_descale.stride()),
        list(c_final.stride()),
    ]

    required = graph.get_workspace_size()
    # BUGFIX: tolerate workspace_buffer=None (the wrapper's default) instead
    # of crashing on .numel().
    if workspace_buffer is None or workspace_buffer.numel() < required:
        workspace_buffer = torch.empty(required, device=a.device, dtype=torch.uint8)

    stream = torch.cuda.current_stream(a.device)

    graph.execute_plan_at_index(
        variant_pack,
        workspace_buffer,
        tactic,
        handle=_get_cudnn_handle(stream),
        override_uids=override_uids,
        override_shapes=override_shapes,
        override_strides=override_strides,
    )


def _get_cudnn_fp4_gemm_graph_override_shape(
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    alpha: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.bfloat16,
    out: Optional[torch.Tensor] = None,
    block_size: int = 16,
    use_nvfp4: bool = True,
):
    """Get (or build) a cached FP4 GEMM graph with override-shape support.

    The cache key excludes M so the same compiled plan is reused for all M
    values.

    BUGFIX: the original implementation passed ``a_descale.shape[1]`` (which
    depends on the runtime M) as ``a_descale_n_dim``; since the builder is
    memoized on all of its arguments, that re-built the graph for every new
    M, defeating the purpose of override-shape. The value is now derived
    from ``_OVERRIDE_SHAPE_CACHE_M`` so the cache key is M-independent.
    """
    real_a_shape, _ = _get_real_fp4_shape_from_packed_uint8(a)
    real_b_shape, _ = _get_real_fp4_shape_from_packed_uint8(b)
    batch = real_a_shape[0]
    n = real_b_shape[2]
    k = real_a_shape[2]

    expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch)
    expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch)

    # Scale dimension sizes that are independent of M.
    a_descale_k_dim = expanded_a_descale_shape[2]
    b_descale_k_dim = expanded_b_descale_shape[1]
    b_descale_n_dim = expanded_b_descale_shape[2]

    # M-dependent scale dim, expressed at the compile-time cache M so the
    # memoization key stays constant across runtime M values.
    cache_block_scale_dim_m, _, _ = _calculate_block_scale_dims(
        _OVERRIDE_SHAPE_CACHE_M, n, k, block_size
    )

    return build_cudnn_fp4_gemm_graph_override_shape(
        batch=batch,
        n=n,
        k=k,
        a_descale_n_dim=cache_block_scale_dim_m,
        a_descale_k_dim=a_descale_k_dim,
        b_descale_k_dim=b_descale_k_dim,
        b_descale_n_dim=b_descale_n_dim,
        ab_type=cudnn.data_type.FP4_E2M1,
        o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
        block_size=block_size,
        device=a.device,
        alpha_is_not_none=alpha is not None,
        use_nvfp4=use_nvfp4,
    )


def _cudnn_gemm_fp4_override_shape(
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    alpha: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.bfloat16,
    out: Optional[torch.Tensor] = None,
    block_size: int = 16,
    use_nvfp4: bool = True,
    workspace_buffer: Optional[torch.Tensor] = None,
    tactic: int = 0,
):
    """FP4 GEMM via cuDNN using override-shape for dynamic M dimension.

    Returns the output tensor. BUGFIX: ``out=None`` previously crashed on
    ``out.ndim``; the output is now allocated on demand (3-D, batch-first)
    and returned, matching ``_cudnn_gemm_fp8_override_shape``.
    """
    graph = _get_cudnn_fp4_gemm_graph_override_shape(
        a=a,
        b=b,
        a_descale=a_descale,
        b_descale=b_descale,
        alpha=alpha,
        out_dtype=out_dtype,
        out=out,
        block_size=block_size,
        use_nvfp4=use_nvfp4,
    )

    real_a_shape, _ = _get_real_fp4_shape_from_packed_uint8(a)
    real_b_shape, _ = _get_real_fp4_shape_from_packed_uint8(b)
    batch = real_a_shape[0]
    expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch)
    expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch)

    a_3d = a.view(real_a_shape) if a.ndim == 2 else a
    b_3d = b.view(real_b_shape) if b.ndim == 2 else b
    a_descale_3d = (
        a_descale.view(expanded_a_descale_shape) if a_descale.ndim == 2 else a_descale
    )
    b_descale_3d = (
        b_descale.view(expanded_b_descale_shape) if b_descale.ndim == 2 else b_descale
    )

    if out is None:
        out = torch.empty(
            batch, real_a_shape[1], real_b_shape[2], dtype=out_dtype, device=a.device
        )
    out_3d = out.unsqueeze(0) if out.ndim == 2 else out

    execute_cudnn_fp4_gemm_graph_override_shape(
        graph,
        a_3d,
        b_3d,
        a_descale_3d,
        b_descale_3d,
        alpha,
        out_3d,
        workspace_buffer,
        tactic=tactic,
    )
    return out
+ """ + _check_cudnn_override_shape_availability() + + if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]: + raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got {a_type}") + if b_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]: + raise ValueError(f"B type must be FP8_E4M3 or FP8_E5M2, got {b_type}") + if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]: + raise ValueError(f"Output type must be BF16 or FP16, got {o_type}") + + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( + cache_m, n, k, block_size + ) + + scale_type = cudnn.data_type.FP8_E8M0 + + a_shape = [batch, cache_m, k] + a_stride = [cache_m * k, k, 1] + b_shape = [batch, k, n] + b_stride = [k * n, 1, k] + a_descale_shape = [batch, block_scale_dim_m, block_scale_dim_k] + a_descale_stride = [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1] + b_descale_shape = [batch, block_scale_dim_k, block_scale_dim_n] + b_descale_stride = [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k] + + stream = torch.cuda.current_stream(device) + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.FLOAT, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_handle(stream), + is_override_shape_enabled=True, + ) + + a_cudnn_tensor = graph.tensor( + name="a", dim=a_shape, stride=a_stride, data_type=a_type + ) + b_cudnn_tensor = graph.tensor( + name="b", dim=b_shape, stride=b_stride, data_type=b_type + ) + block_descale_a_cudnn_tensor = graph.tensor( + name="block_descale_a", + dim=a_descale_shape, + stride=a_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + block_descale_b_cudnn_tensor = graph.tensor( + name="block_descale_b", + dim=b_descale_shape, + stride=b_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + + dequant_a_tensor = graph.block_scale_dequantize( + 
a_cudnn_tensor, + block_descale_a_cudnn_tensor, + block_size=[1, block_size], + name="dequant_a", + ) + dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT) + dequant_b_tensor = graph.block_scale_dequantize( + b_cudnn_tensor, + block_descale_b_cudnn_tensor, + block_size=[block_size, 1], + name="dequant_b", + ) + dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT) + + c_tensor = graph.matmul( + dequant_a_tensor, + dequant_b_tensor, + compute_data_type=cudnn.data_type.FLOAT, + name="gemm", + ) + c_tensor.set_data_type(cudnn.data_type.FLOAT) + c_tensor.set_output(True).set_data_type(o_type) + + a_cudnn_tensor.set_uid(UIDs.A_UID.value) + b_cudnn_tensor.set_uid(UIDs.B_UID.value) + block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value) + block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value) + c_tensor.set_uid(UIDs.O_UID.value) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B]) + graph.check_support() + graph.build_plans() + + return graph + + +def execute_cudnn_mxfp8_gemm_graph_override_shape( + graph, + a, + b, + a_descale, + b_descale, + c_final, + workspace_buffer, + tactic: int = 0, +): + """Execute MXFP8 GEMM cuDNN graph with dynamic-shape overrides.""" + variant_pack = { + UIDs.A_UID.value: a, + UIDs.B_UID.value: b, + UIDs.BLOCK_DESCALE_A_UID.value: a_descale, + UIDs.BLOCK_DESCALE_B_UID.value: b_descale, + UIDs.O_UID.value: c_final, + } + + override_uids = [ + UIDs.A_UID.value, + UIDs.B_UID.value, + UIDs.BLOCK_DESCALE_A_UID.value, + UIDs.BLOCK_DESCALE_B_UID.value, + UIDs.O_UID.value, + ] + override_shapes = [ + list(a.shape), + list(b.shape), + list(a_descale.shape), + list(b_descale.shape), + list(c_final.shape), + ] + override_strides = [ + list(a.stride()), + list(b.stride()), + list(a_descale.stride()), + list(b_descale.stride()), + list(c_final.stride()), + ] + + workspace_size = graph.get_workspace_size() + if workspace_buffer.numel() < workspace_size: + 
workspace_buffer = torch.empty( + workspace_size, device=a.device, dtype=torch.uint8 + ) + + stream = torch.cuda.current_stream(a.device) + graph.execute_plan_at_index( + variant_pack, + workspace_buffer, + tactic, + handle=_get_cudnn_handle(stream), + override_uids=override_uids, + override_shapes=override_shapes, + override_strides=override_strides, + ) + + +def _cudnn_gemm_mxfp8_override_shape( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + workspace_buffer: torch.Tensor = None, + tactic: int = 0, +): + """MXFP8 GEMM via cuDNN using override-shape for dynamic M dimension.""" + block_size = 32 + + a_3d = a if a.ndim == 3 else a.unsqueeze(0) + b_3d = b if b.ndim == 3 else b.unsqueeze(0) + out_3d = out if out.ndim == 3 else out.unsqueeze(0) + + batch = a_3d.shape[0] + k = a_3d.shape[2] + n = b_3d.shape[2] + + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( + a_3d.shape[1], n, k, block_size + ) + + if a_descale.ndim == 2: + a_descale_3d = a_descale.view(batch, block_scale_dim_m, block_scale_dim_k) + else: + a_descale_3d = a_descale + + if b_descale.ndim == 2: + b_descale_3d = b_descale.view(batch, block_scale_dim_k, block_scale_dim_n) + else: + b_descale_3d = b_descale + + graph = build_cudnn_mxfp8_gemm_graph_override_shape( + batch=batch, + n=n, + k=k, + a_type=_torch_data_type_to_cudnn_data_type(a.dtype), + b_type=_torch_data_type_to_cudnn_data_type(b.dtype), + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + block_size=block_size, + device=a.device, + ) + + execute_cudnn_mxfp8_gemm_graph_override_shape( + graph=graph, + a=a_3d, + b=b_3d, + a_descale=a_descale_3d, + b_descale=b_descale_3d, + c_final=out_3d, + workspace_buffer=workspace_buffer, + tactic=tactic, + ) + + @functools.cache def build_cudnn_gemm_with_per_tensor_q_graph( a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type, 
# ---------------------------------------------------------------------------
# FP8 per-tensor GEMM with override_shape (dynamic M dimension)
# ---------------------------------------------------------------------------


@functools.cache
def build_cudnn_gemm_with_per_tensor_q_graph_override_shape(
    batch, n, k, a_type, b_type, o_type, device,
    cache_m: int = _OVERRIDE_SHAPE_CACHE_M
):
    """Build an FP8 per-tensor-quantized GEMM cuDNN graph with override-shape.

    Compiled once with ``cache_m`` as a placeholder M; the real M is supplied
    at execution time via ``override_shapes`` / ``override_strides``, so the
    memoization key never contains the runtime M.
    """
    _check_cudnn_override_shape_availability()

    # Compile-time geometry (row-major A/C, K-major B), with cache_m as M.
    lhs_dim, lhs_stride = [batch, cache_m, k], [cache_m * k, k, 1]
    rhs_dim, rhs_stride = [batch, k, n], [k * n, 1, k]
    _out_dim, _out_stride = [batch, cache_m, n], [cache_m * n, n, 1]

    graph = cudnn.pygraph(
        io_data_type=cudnn.data_type.FLOAT,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
        handle=_get_cudnn_handle(torch.cuda.current_stream(device)),
        is_override_shape_enabled=True,
    )

    lhs_t = graph.tensor(name="a", dim=lhs_dim, stride=lhs_stride, data_type=a_type)
    rhs_t = graph.tensor(name="b", dim=rhs_dim, stride=rhs_stride, data_type=b_type)
    lhs_scale_t = graph.tensor(
        name="a_scale",
        dim=[1, 1, 1],
        stride=[1, 1, 1],
        data_type=cudnn.data_type.FLOAT,
    )
    rhs_scale_t = graph.tensor(
        name="b_scale",
        dim=[1, 1, 1],
        stride=[1, 1, 1],
        data_type=cudnn.data_type.FLOAT,
    )

    mm_t = graph.matmul(
        name="matmul",
        A=lhs_t,
        B=rhs_t,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    mm_t.set_name("c").set_data_type(cudnn.data_type.FLOAT)
    # Apply the two per-tensor dequantization scales sequentially.
    scaled_once_t = graph.mul(
        name="scale_mul_a",
        a=mm_t,
        b=lhs_scale_t,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    final_t = graph.mul(
        name="scale_mul_b",
        a=scaled_once_t,
        b=rhs_scale_t,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    final_t.set_name("c_final").set_output(True).set_data_type(o_type)

    for cudnn_tensor, uid in (
        (lhs_t, UIDs.A_UID.value),
        (rhs_t, UIDs.B_UID.value),
        (lhs_scale_t, UIDs.A_SCALE_UID.value),
        (rhs_scale_t, UIDs.B_SCALE_UID.value),
        (final_t, UIDs.O_UID.value),
    ):
        cudnn_tensor.set_uid(uid)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    return graph


def execute_cudnn_gemm_with_per_tensor_q_graph_override_shape(
    graph, a, b, a_scale, b_scale, c_final, workspace, tactic: int = 0
):
    """Execute FP8 per-tensor GEMM graph with dynamic-shape overrides.

    Only A/B/O carry shape overrides; the scale tensors are always [1, 1, 1].
    """
    variant_pack = {
        UIDs.A_UID.value: a,
        UIDs.B_UID.value: b,
        UIDs.A_SCALE_UID.value: a_scale,
        UIDs.B_SCALE_UID.value: b_scale,
        UIDs.O_UID.value: c_final,
    }

    dynamic = [
        (UIDs.A_UID.value, a),
        (UIDs.B_UID.value, b),
        (UIDs.O_UID.value, c_final),
    ]
    override_uids = [uid for uid, _ in dynamic]
    override_shapes = [list(t.shape) for _, t in dynamic]
    override_strides = [list(t.stride()) for _, t in dynamic]

    required = graph.get_workspace_size()
    if workspace.numel() < required:
        workspace = torch.empty(required, device=a.device, dtype=torch.uint8)

    cudnn_handle = _get_cudnn_handle(torch.cuda.current_stream(a.device))
    graph.execute_plan_at_index(
        variant_pack,
        workspace,
        tactic,
        handle=cudnn_handle,
        override_uids=override_uids,
        override_shapes=override_shapes,
        override_strides=override_strides,
    )


def _cudnn_gemm_fp8_override_shape(
    workspace: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    out: Optional[torch.Tensor],
    torch_out_dtype: torch.dtype,
    tactic: int = 0,
):
    """FP8 per-tensor GEMM via cuDNN using override-shape for dynamic M."""
    _check_cudnn_availability()

    # Pad 2-D (mm) inputs to the 3-D (bmm) layout the graph expects.
    a_shape3, _ = _get_bf16_3d_shape_stride(a)
    b_shape3, _ = _get_bf16_3d_shape_stride(b)
    out_shape3, _ = _get_bf16_3d_shape_stride(out)

    graph = build_cudnn_gemm_with_per_tensor_q_graph_override_shape(
        a_shape3[0],  # batch
        b_shape3[2],  # n
        a_shape3[2],  # k
        _torch_data_type_to_cudnn_data_type(a.dtype),
        _torch_data_type_to_cudnn_data_type(b.dtype),
        _torch_data_type_to_cudnn_data_type(torch_out_dtype),
        a.device,
    )

    a3 = a if a.ndim != 2 else a.view(a_shape3)
    b3 = b if b.ndim != 2 else b.view(b_shape3)
    out3 = out if out.ndim != 2 else out.view(out_shape3)

    execute_cudnn_gemm_with_per_tensor_q_graph_override_shape(
        graph, a3, b3, a_scale, b_scale, out3, workspace, tactic=tactic
    )
    return out


# ---------------------------------------------------------------------------
# BF16 GEMM with override_shape (dynamic M dimension)
# ---------------------------------------------------------------------------


@functools.cache
def build_cudnn_gemm_bf16_graph_override_shape(
    batch, n, k, o_type, device, cache_m: int = _OVERRIDE_SHAPE_CACHE_M
):
    """Build a cuDNN BF16 GEMM graph with override-shape support.

    Compiled once with ``cache_m`` as a placeholder M;
    ``execute_cudnn_gemm_bf16_graph_override_shape`` later supplies the real
    M through ``override_shapes`` / ``override_strides``, so a single plan
    serves every M without a rebuild. The memoization key is
    ``(batch, n, k, o_type, device, cache_m)`` — M is deliberately absent.
    """
    _check_cudnn_override_shape_availability()

    lhs_dim, lhs_stride = [batch, cache_m, k], [cache_m * k, k, 1]
    rhs_dim, rhs_stride = [batch, k, n], [k * n, 1, k]

    graph = cudnn.pygraph(
        io_data_type=cudnn.data_type.BFLOAT16,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
        handle=_get_cudnn_handle(torch.cuda.current_stream(device)),
        is_override_shape_enabled=True,
    )

    lhs_t = graph.tensor(
        name="a",
        dim=lhs_dim,
        stride=lhs_stride,
        data_type=cudnn.data_type.BFLOAT16,
    )
    rhs_t = graph.tensor(
        name="b",
        dim=rhs_dim,
        stride=rhs_stride,
        data_type=cudnn.data_type.BFLOAT16,
    )
    out_t = graph.matmul(
        name="matmul",
        A=lhs_t,
        B=rhs_t,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    out_t.set_name("c").set_output(True).set_data_type(o_type)

    for cudnn_tensor, uid in (
        (lhs_t, UIDs.A_UID.value),
        (rhs_t, UIDs.B_UID.value),
        (out_t, UIDs.O_UID.value),
    ):
        cudnn_tensor.set_uid(uid)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    return graph


def execute_cudnn_gemm_bf16_graph_override_shape(
    graph, a, b, c_final, workspace, tactic: int = 0
):
    """Execute a BF16 GEMM cuDNN graph built with override-shape enabled.

    The actual shapes/strides of ``a``/``b``/``c_final`` are forwarded as
    ``override_shapes`` / ``override_strides`` so one compiled plan handles
    any M dimension without rebuilding.
    """
    variant_pack = {
        UIDs.A_UID.value: a,
        UIDs.B_UID.value: b,
        UIDs.O_UID.value: c_final,
    }

    dynamic = [
        (UIDs.A_UID.value, a),
        (UIDs.B_UID.value, b),
        (UIDs.O_UID.value, c_final),
    ]
    override_uids = [uid for uid, _ in dynamic]
    override_shapes = [list(t.shape) for _, t in dynamic]
    override_strides = [list(t.stride()) for _, t in dynamic]

    required = graph.get_workspace_size()
    if workspace.numel() < required:
        workspace = torch.empty(required, device=a.device, dtype=torch.uint8)

    cudnn_handle = _get_cudnn_handle(torch.cuda.current_stream(a.device))
    graph.execute_plan_at_index(
        variant_pack,
        workspace,
        tactic,
        handle=cudnn_handle,
        override_uids=override_uids,
        override_shapes=override_shapes,
        override_strides=override_strides,
    )


def _cudnn_gemm_bf16_override_shape(
    workspace: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    out: torch.Tensor,
    tactic: int = 0,
):
    """BF16 GEMM via cuDNN using override-shape for dynamic M dimension.

    A single plan compiled with ``_OVERRIDE_SHAPE_CACHE_M`` as M is reused
    for all M values without triggering a graph rebuild. Both mm (2-D) and
    bmm (3-D) inputs are supported via ``_get_bf16_3d_shape_stride``, which
    pads 2-D inputs to 3-D.
    """
    _check_cudnn_availability()

    a_shape3, _ = _get_bf16_3d_shape_stride(a)
    b_shape3, _ = _get_bf16_3d_shape_stride(b)
    out_shape3, _ = _get_bf16_3d_shape_stride(out)

    graph = build_cudnn_gemm_bf16_graph_override_shape(
        a_shape3[0],  # batch
        b_shape3[2],  # n
        a_shape3[2],  # k
        _torch_data_type_to_cudnn_data_type(out.dtype),
        a.device,
    )

    a3 = a if a.ndim != 2 else a.view(a_shape3)
    b3 = b if b.ndim != 2 else b.view(b_shape3)
    out3 = out if out.ndim != 2 else out.view(out_shape3)

    execute_cudnn_gemm_bf16_graph_override_shape(
        graph, a3, b3, out3, workspace, tactic=tactic
    )
    return out
"""
Tests for cuDNN GEMM operations using is_override_shape_enabled API.

A single cuDNN graph is compiled once with a "cache shape" (large M).
At execution time, the actual M is passed via override_shapes / override_strides,
so no graph rebuild is triggered for varying M values.

Requires:
    - CUDA compute capability == SM100 / SM103
    - cuDNN frontend >= 1.20 / backend_version >= 92100
"""

import pytest
import torch

from flashinfer.gemm.gemm_base import (
    CUDNN_AVAILABLE,
    _OVERRIDE_SHAPE_CACHE_M,
    _get_bf16_3d_shape_stride,
    _cudnn_gemm_bf16_override_shape,
    build_cudnn_gemm_bf16_graph_override_shape,
    execute_cudnn_gemm_bf16_graph_override_shape,
    build_cudnn_fp4_gemm_graph_override_shape,
    execute_cudnn_fp4_gemm_graph_override_shape,
    build_cudnn_mxfp8_gemm_graph_override_shape,
    execute_cudnn_mxfp8_gemm_graph_override_shape,
    is_cudnn_override_shape_available,
    CUDNN_MIN_VERSION_OVERRIDE_SHAPE,
    _calculate_block_scale_dims,
    _get_real_fp4_shape_from_packed_uint8,
    _expand_block_scale_tensor_shape,
)
from flashinfer.utils import get_compute_capability


def _skip_if_no_cudnn():
    """Skip the calling test when the cuDNN frontend is not importable."""
    if not CUDNN_AVAILABLE:
        pytest.skip("cuDNN not available")


def _skip_if_override_shape_not_supported():
    """Skip when the installed backend lacks is_override_shape_enabled."""
    if not CUDNN_AVAILABLE:
        pytest.skip("cuDNN not available")
    if not is_cudnn_override_shape_available():
        import cudnn

        pytest.skip(
            f"cuDNN override-shape requires backend >= {CUDNN_MIN_VERSION_OVERRIDE_SHAPE} "
            f"(9.21.0), found {cudnn.backend_version()}"
        )


def _skip_if_not_sm100():
    """Skip on pre-Blackwell GPUs (compute capability below 10.0)."""
    major, minor = get_compute_capability(torch.device("cuda"))
    if major * 10 + minor < 100:
        pytest.skip("override-shape GEMM requires SM100+ (Blackwell)")


# ============================================================================
# BF16 GEMM with override_shape
# ============================================================================


class TestCudnnBf16OverrideShape:
    """Single compiled plan handles multiple M dimensions for BF16 GEMM."""

    @pytest.mark.parametrize(
        "cache_m,dynamic_ms",
        [
            (2048, [1, 4, 16, 32, 64, 128, 512, 1024, 2048]),
            (4096, [1, 8, 64, 256, 1024, 4096]),
        ],
    )
    @pytest.mark.parametrize("n", [1024, 2048])
    @pytest.mark.parametrize("k", [1024, 2048])
    def test_bf16_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k):
        _skip_if_no_cudnn()
        _skip_if_override_shape_not_supported()
        _skip_if_not_sm100()

        import cudnn
        from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type

        device = torch.device("cuda")
        in_dtype = torch.bfloat16
        out_dtype = torch.bfloat16

        # Compile the plan once, with cache_m standing in for M.
        graph = build_cudnn_gemm_bf16_graph_override_shape(
            batch=1,
            n=n,
            k=k,
            o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
            device=device,
            cache_m=cache_m,
        )
        workspace = torch.empty(
            graph.get_workspace_size(), dtype=torch.uint8, device=device
        )

        # B stays fixed while M varies.
        rhs = torch.randn(1, k, n, dtype=in_dtype, device=device)

        for m in dynamic_ms:
            lhs = torch.randn(1, m, k, dtype=in_dtype, device=device)
            result = torch.empty(1, m, n, dtype=out_dtype, device=device)
            expected = torch.bmm(lhs.float(), rhs.float()).to(out_dtype)

            # Execute the single cached plan at the runtime M.
            execute_cudnn_gemm_bf16_graph_override_shape(
                graph, lhs, rhs, result, workspace, tactic=0
            )
            torch.cuda.synchronize()

            ref, out = expected, result
            assert torch.allclose(ref, out, rtol=5e-2, atol=5e-2), (
                f"BF16 override_shape failed for m={m}, n={n}, k={k}: "
                f"max_abs_err={( ref - out).abs().max().item():.4f}, "
                f"max_rel_err={(( ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}"
            )


# ============================================================================
# NVFP4 GEMM with override_shape
# ============================================================================


class TestCudnnNVFp4OverrideShape:
    """Single compiled plan handles multiple M dimensions for NVFP4 GEMM."""

    @pytest.mark.skipif(
        not CUDNN_AVAILABLE,
        reason="cuDNN not available",
    )
    @pytest.mark.parametrize(
        "cache_m,dynamic_ms",
        [
            (2048, [1, 4, 16, 32, 64, 128, 512, 1024, 2048]),
            (4096, [1, 8, 64, 256, 1024, 4096]),
        ],
    )
    @pytest.mark.parametrize("n", [1024, 2048])
    @pytest.mark.parametrize("k", [1024, 2048])
    def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k):
        _skip_if_no_cudnn()
        _skip_if_override_shape_not_supported()
        _skip_if_not_sm100()

        import cudnn

        if cudnn.backend_version() < 91002:
            pytest.skip("FP4 requires cuDNN backend >= 91002")
        from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type

        device = torch.device("cuda")
        block_size = 16
        out_dtype = torch.bfloat16

        # Scale dims computed at the compile-time cache M (M-independent key).
        cache_scale_dim_m, scale_dim_n, scale_dim_k = _calculate_block_scale_dims(
            cache_m, n, k, block_size
        )

        # Compile the plan once, with cache_m standing in for M.
        graph = build_cudnn_fp4_gemm_graph_override_shape(
            batch=1,
            n=n,
            k=k,
            a_descale_n_dim=cache_scale_dim_m,
            a_descale_k_dim=scale_dim_k,
            b_descale_k_dim=scale_dim_k,
            b_descale_n_dim=scale_dim_n,
            ab_type=cudnn.data_type.FP4_E2M1,
            o_type=_torch_data_type_to_cudnn_data_type(out_dtype),
            block_size=block_size,
            device=device,
            alpha_is_not_none=False,
            use_nvfp4=True,
            cache_m=cache_m,
        )

        workspace = torch.empty(
            graph.get_workspace_size(), dtype=torch.uint8, device=device,
        )

        # B (and its scales) stay fixed while M varies.
        rhs_packed = torch.randint(
            0, 256, (1, n, k // 2), dtype=torch.uint8, device=device
        ).transpose(1, 2)
        rhs_descale = torch.ones(
            1, scale_dim_n, scale_dim_k, dtype=torch.float8_e4m3fn, device=device
        ).transpose(1, 2)

        for m in dynamic_ms:
            scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size)

            lhs_packed = torch.randint(
                0, 256, (1, m, k // 2), dtype=torch.uint8, device=device
            )
            lhs_descale = torch.ones(
                1, scale_dim_m, scale_dim_k, dtype=torch.float8_e4m3fn, device=device
            )

            # Execute with the cached graph (override_shape path).
            result = torch.empty(1, m, n, dtype=out_dtype, device=device)
            execute_cudnn_fp4_gemm_graph_override_shape(
                graph,
                lhs_packed,
                rhs_packed,
                lhs_descale,
                rhs_descale,
                alpha=None,
                c_final=result,
                workspace_buffer=workspace,
                tactic=0,
            )
            torch.cuda.synchronize()
============================================================================ +# MXFP8 GEMM with override_shape +# ============================================================================ + + +class TestCudnnMXFp8OverrideShape: + """Single compiled plan handles multiple M dimensions for MXFP8 GEMM.""" + + @pytest.mark.skipif( + not CUDNN_AVAILABLE, + reason="cuDNN not available", + ) + @pytest.mark.parametrize( + "cache_m,dynamic_ms", + [ + (2048, [1, 4, 16, 32, 64, 128, 512, 1024, 2048]), + (4096, [1, 8, 64, 256, 1024, 4096]), + ], + ) + @pytest.mark.parametrize("n", [1024, 2048]) + @pytest.mark.parametrize("k", [1024, 2048]) + def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): + _skip_if_no_cudnn() + _skip_if_override_shape_not_supported() + _skip_if_not_sm100() + + import cudnn + from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type + + device = torch.device("cuda") + block_size = 32 + out_dtype = torch.bfloat16 + + # Compute block scale dims using cache_m + block_scale_dim_m_cache, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(cache_m, n, k, block_size) + ) + + # Build graph once with cache_m + graph = build_cudnn_mxfp8_gemm_graph_override_shape( + batch=1, + n=n, + k=k, + a_type=cudnn.data_type.FP8_E4M3, + b_type=cudnn.data_type.FP8_E4M3, + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + block_size=block_size, + device=device, + cache_m=cache_m, + ) + + workspace = torch.empty( + graph.get_workspace_size(), dtype=torch.uint8, device=device, + ) + + b = torch.randint(0, 256, (1, n, k), dtype=torch.uint8, device=device).transpose(1, 2) + b_descale = torch.ones( + 1, block_scale_dim_n, block_scale_dim_k, dtype=torch.float8_e8m0fnu, device=device + ).transpose(1, 2) + + for m in dynamic_ms: + block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) + + a = torch.randint(0, 256, (1, m, k), dtype=torch.uint8, device=device) + a_descale = torch.ones( + 1, 
block_scale_dim_m, block_scale_dim_k, dtype=torch.float8_e8m0fnu, device=device + ) + + # Execute with cached graph (override_shape) + out = torch.empty(1, m, n, dtype=out_dtype, device=device) + execute_cudnn_mxfp8_gemm_graph_override_shape( + graph, a, b, a_descale, b_descale, + c_final=out, workspace_buffer=workspace, tactic=0, + ) + torch.cuda.synchronize() From daac1dc2ee97f89f99abaab6b5e87817d80a17f8 Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Fri, 13 Mar 2026 21:40:07 -0700 Subject: [PATCH 03/11] bug-fix --- tests/gemm/test_cudnn_override_shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index 0a2cd4dcc4..c3676625b2 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -98,7 +98,7 @@ def test_bf16_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): graph.get_workspace_size(), dtype=torch.uint8, device=device ) - b = torch.randn(1, k, n, dtype=in_dtype, device=device) + b = torch.randn(1, n, k, dtype=in_dtype, device=device).transpose(1, 2) for m in dynamic_ms: a = torch.randn(1, m, k, dtype=in_dtype, device=device) From d3f08d5b09403b55d3c758d701558a11a5e1b2df Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Sat, 14 Mar 2026 09:33:17 -0700 Subject: [PATCH 04/11] add-non-k-major-layout-support --- flashinfer/gemm/gemm_base.py | 15 ++++++++++++--- tests/gemm/test_cudnn_override_shape.py | 2 ++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 243849d769..da45021224 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -2960,7 +2960,8 @@ def execute_cudnn_gemm_bf16_graph(graph, a, b, c_final, workspace, tactic: int = @functools.cache def build_cudnn_gemm_bf16_graph_override_shape( - batch, n, k, o_type, device, cache_m: int = _OVERRIDE_SHAPE_CACHE_M + batch, n, k, o_type, device, cache_m: int = 
_OVERRIDE_SHAPE_CACHE_M, + is_a_k_major: bool = True, is_b_k_major: bool = True, ): """Build a cuDNN BF16 GEMM graph with override-shape support. @@ -2972,13 +2973,21 @@ def build_cudnn_gemm_bf16_graph_override_shape( Caching key is ``(batch, n, k, o_type, device, cache_m)`` — M is **not** part of the key. + + Args: + is_a_k_major: If True, A has shape (batch, M, K) with row-major strides + (K is the contiguous dimension). If False, A has shape (batch, M, K) + with column-major strides (M is the contiguous dimension). + is_b_k_major: If True, B has shape (batch, K, N) with K-contiguous + strides (stride along K is 1, i.e. each N column is K-contiguous). + If False, B is row-major with N-contiguous layout (stride along N is 1). """ _check_cudnn_override_shape_availability() a_shape = (batch, cache_m, k) - a_stride = (cache_m * k, k, 1) + a_stride = (cache_m * k, k, 1) if is_a_k_major else (cache_m * k, 1, cache_m) b_shape = (batch, k, n) - b_stride = (k * n, 1, k) + b_stride = (k * n, 1, k) if is_b_k_major else (k * n, n, 1) stream = torch.cuda.current_stream(device) graph = cudnn.pygraph( diff --git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index c3676625b2..71d8c90151 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -92,6 +92,8 @@ def test_bf16_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): o_type=_torch_data_type_to_cudnn_data_type(out_dtype), device=device, cache_m=cache_m, + is_a_k_major=True, + is_b_k_major=True, ) workspace = torch.empty( From 1803c4e2613199e544f93f443d3948d762249187 Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Sun, 15 Mar 2026 13:33:51 -0700 Subject: [PATCH 05/11] pre-commit-fixes --- flashinfer/gemm/gemm_base.py | 53 +++++++++-------- tests/gemm/test_cudnn_override_shape.py | 77 ++++++++++++++++++------- 2 files changed, 85 insertions(+), 45 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py 
b/flashinfer/gemm/gemm_base.py index da45021224..7a0f4bc7ed 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -2011,20 +2011,14 @@ def build_cudnn_fp4_gemm_graph_override_shape( a_shape = [batch, cache_m, k * 2] # FP4 packed: K dimension stores k*2 uint8 values a_stride = [cache_m * k * 2, k * 2, 1] - # cuDNN expects the real FP4 shape (unpacked); use *2 only when viewed as fp4 dtype - # For the cudnn tensor we use the actual fp4 element count - a_fp4_shape = [batch, cache_m, k] - a_fp4_stride = [cache_m * k, k, 1] - b_fp4_shape = [batch, k, n] - b_fp4_stride = [k * n, 1, k] + b_shape = [batch, k * 2, n] + b_stride = [k * n * 2, 1, k * 2] a_descale_shape = [batch, block_scale_dim_m, a_descale_k_dim] a_descale_stride = [block_scale_dim_m * a_descale_k_dim, a_descale_k_dim, 1] b_descale_shape = [batch, b_descale_k_dim, b_descale_n_dim] b_descale_stride = [b_descale_n_dim * b_descale_k_dim, 1, b_descale_k_dim] - c_shape = [batch, cache_m, n] - c_stride = [cache_m * n, n, 1] stream = torch.cuda.current_stream(device) graph = cudnn.pygraph( @@ -2036,10 +2030,10 @@ def build_cudnn_fp4_gemm_graph_override_shape( ) a_cudnn_tensor = graph.tensor( - name="a", dim=a_fp4_shape, stride=a_fp4_stride, data_type=ab_type + name="a", dim=a_shape, stride=a_stride, data_type=ab_type ) b_cudnn_tensor = graph.tensor( - name="b", dim=b_fp4_shape, stride=b_fp4_stride, data_type=ab_type + name="b", dim=b_shape, stride=b_stride, data_type=ab_type ) block_descale_a_cudnn_tensor = graph.tensor( name="block_descale_a", @@ -2269,8 +2263,12 @@ def _cudnn_gemm_fp4_override_shape( a_3d = a.view(real_a_shape) if a.ndim == 2 else a b_3d = b.view(real_b_shape) if b.ndim == 2 else b - a_descale_3d = a_descale.view(expanded_a_descale_shape) if a_descale.ndim == 2 else a_descale - b_descale_3d = b_descale.view(expanded_b_descale_shape) if b_descale.ndim == 2 else b_descale + a_descale_3d = ( + a_descale.view(expanded_a_descale_shape) if a_descale.ndim == 2 else a_descale + 
) + b_descale_3d = ( + b_descale.view(expanded_b_descale_shape) if b_descale.ndim == 2 else b_descale + ) out_3d = out.unsqueeze(0) if out.ndim == 2 else out execute_cudnn_fp4_gemm_graph_override_shape( @@ -2352,8 +2350,8 @@ def build_cudnn_mxfp8_gemm_graph_override_shape( if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]: raise ValueError(f"Output type must be BF16 or FP16, got {o_type}") - block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( - cache_m, n, k, block_size + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(cache_m, n, k, block_size) ) scale_type = cudnn.data_type.FP8_E8M0 @@ -2516,8 +2514,8 @@ def _cudnn_gemm_mxfp8_override_shape( k = a_3d.shape[2] n = b_3d.shape[2] - block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( - a_3d.shape[1], n, k, block_size + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(a_3d.shape[1], n, k, block_size) ) if a_descale.ndim == 2: @@ -2663,8 +2661,7 @@ def execute_cudnn_gemm_with_per_tensor_q_graph( @functools.cache def build_cudnn_gemm_with_per_tensor_q_graph_override_shape( - batch, n, k, a_type, b_type, o_type, device, - cache_m: int = _OVERRIDE_SHAPE_CACHE_M + batch, n, k, a_type, b_type, o_type, device, cache_m: int = _OVERRIDE_SHAPE_CACHE_M ): """Build an FP8 per-tensor-quantized GEMM cuDNN graph with override-shape. 
@@ -2677,8 +2674,6 @@ def build_cudnn_gemm_with_per_tensor_q_graph_override_shape( a_stride = [cache_m * k, k, 1] b_shape = [batch, k, n] b_stride = [k * n, 1, k] - c_shape = [batch, cache_m, n] - c_stride = [cache_m * n, n, 1] stream = torch.cuda.current_stream(device) graph = cudnn.pygraph( @@ -2960,8 +2955,14 @@ def execute_cudnn_gemm_bf16_graph(graph, a, b, c_final, workspace, tactic: int = @functools.cache def build_cudnn_gemm_bf16_graph_override_shape( - batch, n, k, o_type, device, cache_m: int = _OVERRIDE_SHAPE_CACHE_M, - is_a_k_major: bool = True, is_b_k_major: bool = True, + batch, + n, + k, + o_type, + device, + cache_m: int = _OVERRIDE_SHAPE_CACHE_M, + is_a_k_major: bool = True, + is_b_k_major: bool = True, ): """Build a cuDNN BF16 GEMM graph with override-shape support. @@ -2999,11 +3000,15 @@ def build_cudnn_gemm_bf16_graph_override_shape( ) a_cudnn_tensor = graph.tensor( - name="a", dim=list(a_shape), stride=list(a_stride), + name="a", + dim=list(a_shape), + stride=list(a_stride), data_type=cudnn.data_type.BFLOAT16, ) b_cudnn_tensor = graph.tensor( - name="b", dim=list(b_shape), stride=list(b_stride), + name="b", + dim=list(b_shape), + stride=list(b_stride), data_type=cudnn.data_type.BFLOAT16, ) c_cudnn_tensor = graph.matmul( diff --git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index 71d8c90151..2f313c9c3e 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -15,9 +15,6 @@ from flashinfer.gemm.gemm_base import ( CUDNN_AVAILABLE, - _OVERRIDE_SHAPE_CACHE_M, - _get_bf16_3d_shape_stride, - _cudnn_gemm_bf16_override_shape, build_cudnn_gemm_bf16_graph_override_shape, execute_cudnn_gemm_bf16_graph_override_shape, build_cudnn_fp4_gemm_graph_override_shape, @@ -27,8 +24,6 @@ is_cudnn_override_shape_available, CUDNN_MIN_VERSION_OVERRIDE_SHAPE, _calculate_block_scale_dims, - _get_real_fp4_shape_from_packed_uint8, - _expand_block_scale_tensor_shape, ) from flashinfer.utils 
import get_compute_capability @@ -43,6 +38,7 @@ def _skip_if_override_shape_not_supported(): pytest.skip("cuDNN not available") if not is_cudnn_override_shape_available(): import cudnn + pytest.skip( f"cuDNN override-shape requires backend >= {CUDNN_MIN_VERSION_OVERRIDE_SHAPE} " f"(9.21.0), found {cudnn.backend_version()}" @@ -77,7 +73,6 @@ def test_bf16_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): _skip_if_override_shape_not_supported() _skip_if_not_sm100() - import cudnn from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type device = torch.device("cuda") @@ -114,8 +109,8 @@ def test_bf16_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): assert torch.allclose(ref, out, rtol=5e-2, atol=5e-2), ( f"BF16 override_shape failed for m={m}, n={n}, k={k}: " - f"max_abs_err={( ref - out).abs().max().item():.4f}, " - f"max_rel_err={(( ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}" + f"max_abs_err={(ref - out).abs().max().item():.4f}, " + f"max_rel_err={((ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}" ) @@ -146,6 +141,7 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): _skip_if_not_sm100() import cudnn + if cudnn.backend_version() < 91002: pytest.skip("FP4 requires cuDNN backend >= 91002") from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type @@ -178,28 +174,49 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): ) workspace = torch.empty( - graph.get_workspace_size(), dtype=torch.uint8, device=device, + graph.get_workspace_size(), + dtype=torch.uint8, + device=device, ) # B is fixed across all dynamic_ms - b_packed = torch.randint(0, 256, (1, n, k // 2), dtype=torch.uint8, device=device).transpose(1, 2) + b_packed = torch.randint( + 0, 256, (1, n, k // 2), dtype=torch.uint8, device=device + ).transpose(1, 2) b_descale = torch.ones( - 1, block_scale_dim_n, block_scale_dim_k, dtype=torch.float8_e4m3fn, device=device + 1, + block_scale_dim_n, + 
block_scale_dim_k, + dtype=torch.float8_e4m3fn, + device=device, ).transpose(1, 2) for m in dynamic_ms: block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) - a_packed = torch.randint(0, 256, (1, m, k // 2), dtype=torch.uint8, device=device) + a_packed = torch.randint( + 0, 256, (1, m, k // 2), dtype=torch.uint8, device=device + ) a_descale = torch.ones( - 1, block_scale_dim_m, block_scale_dim_k, dtype=torch.float8_e4m3fn, device=device + 1, + block_scale_dim_m, + block_scale_dim_k, + dtype=torch.float8_e4m3fn, + device=device, ) # Execute with cached graph (override_shape) out = torch.empty(1, m, n, dtype=out_dtype, device=device) execute_cudnn_fp4_gemm_graph_override_shape( - graph, a_packed, b_packed, a_descale, b_descale, - alpha=None, c_final=out, workspace_buffer=workspace, tactic=0, + graph, + a_packed, + b_packed, + a_descale, + b_descale, + alpha=None, + c_final=out, + workspace_buffer=workspace, + tactic=0, ) torch.cuda.synchronize() @@ -256,12 +273,20 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): ) workspace = torch.empty( - graph.get_workspace_size(), dtype=torch.uint8, device=device, + graph.get_workspace_size(), + dtype=torch.uint8, + device=device, ) - b = torch.randint(0, 256, (1, n, k), dtype=torch.uint8, device=device).transpose(1, 2) + b = torch.randint( + 0, 256, (1, n, k), dtype=torch.uint8, device=device + ).transpose(1, 2) b_descale = torch.ones( - 1, block_scale_dim_n, block_scale_dim_k, dtype=torch.float8_e8m0fnu, device=device + 1, + block_scale_dim_n, + block_scale_dim_k, + dtype=torch.float8_e8m0fnu, + device=device, ).transpose(1, 2) for m in dynamic_ms: @@ -269,13 +294,23 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): a = torch.randint(0, 256, (1, m, k), dtype=torch.uint8, device=device) a_descale = torch.ones( - 1, block_scale_dim_m, block_scale_dim_k, dtype=torch.float8_e8m0fnu, device=device + 1, + block_scale_dim_m, + block_scale_dim_k, + 
dtype=torch.float8_e8m0fnu, + device=device, ) # Execute with cached graph (override_shape) out = torch.empty(1, m, n, dtype=out_dtype, device=device) execute_cudnn_mxfp8_gemm_graph_override_shape( - graph, a, b, a_descale, b_descale, - c_final=out, workspace_buffer=workspace, tactic=0, + graph, + a, + b, + a_descale, + b_descale, + c_final=out, + workspace_buffer=workspace, + tactic=0, ) torch.cuda.synchronize() From 9c21ddf835351fc542ea07893e147bf13dfdd496 Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Mon, 16 Mar 2026 09:57:07 -0700 Subject: [PATCH 06/11] fix-some-review-comments --- flashinfer/gemm/__init__.py | 10 ++++++++++ flashinfer/gemm/gemm_base.py | 19 ++++--------------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index 2d99f12e6a..f1348520d9 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -77,4 +77,14 @@ "mm_M1_16_K7168_N128", "mm_M1_16_K7168_N256", "tinygemm_bf16", + "is_cudnn_override_shape_available", + "CUDNN_MIN_VERSION_OVERRIDE_SHAPE", + "build_cudnn_gemm_bf16_graph_override_shape", + "execute_cudnn_gemm_bf16_graph_override_shape", + "build_cudnn_fp4_gemm_graph_override_shape", + "execute_cudnn_fp4_gemm_graph_override_shape", + "build_cudnn_mxfp8_gemm_graph_override_shape", + "execute_cudnn_mxfp8_gemm_graph_override_shape", + "build_cudnn_gemm_with_per_tensor_q_graph_override_shape", + "execute_cudnn_gemm_with_per_tensor_q_graph_override_shape", ] + _cute_dsl_kernels diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 7a0f4bc7ed..52b794221a 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -2009,11 +2009,11 @@ def build_cudnn_fp4_gemm_graph_override_shape( cache_m, n, k, block_size ) - a_shape = [batch, cache_m, k * 2] # FP4 packed: K dimension stores k*2 uint8 values - a_stride = [cache_m * k * 2, k * 2, 1] + a_shape = [batch, cache_m, k] # FP4 packed: K dimension stores k*2 uint8 
values + a_stride = [cache_m * k, k, 1] - b_shape = [batch, k * 2, n] - b_stride = [k * n * 2, 1, k * 2] + b_shape = [batch, k, n] + b_stride = [k * n, 1, k] a_descale_shape = [batch, block_scale_dim_m, a_descale_k_dim] a_descale_stride = [block_scale_dim_m * a_descale_k_dim, a_descale_k_dim, 1] @@ -2158,11 +2158,6 @@ def execute_cudnn_fp4_gemm_graph_override_shape( c_final.stride(), ] - if workspace_buffer.numel() < graph.get_workspace_size(): - workspace_buffer = torch.empty( - graph.get_workspace_size(), device=a.device, dtype=torch.uint8 - ) - stream = torch.cuda.current_stream(a.device) graph.execute_plan_at_index( @@ -2475,12 +2470,6 @@ def execute_cudnn_mxfp8_gemm_graph_override_shape( list(c_final.stride()), ] - workspace_size = graph.get_workspace_size() - if workspace_buffer.numel() < workspace_size: - workspace_buffer = torch.empty( - workspace_size, device=a.device, dtype=torch.uint8 - ) - stream = torch.cuda.current_stream(a.device) graph.execute_plan_at_index( variant_pack, From 5ff1bebd8618dca7599d55c8a990e7b75b77fbe6 Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Mon, 16 Mar 2026 11:21:16 -0700 Subject: [PATCH 07/11] optimize-the-api-address-comments --- flashinfer/gemm/__init__.py | 2 -- flashinfer/gemm/gemm_base.py | 46 +++++++++++++++---------- tests/gemm/test_cudnn_override_shape.py | 14 ++------ 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index f1348520d9..3d4555c28e 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -20,7 +20,6 @@ from .gemm_base import fp8_blockscale_gemm_sm90 as fp8_blockscale_gemm_sm90 from .gemm_base import ( is_cudnn_override_shape_available as is_cudnn_override_shape_available, - CUDNN_MIN_VERSION_OVERRIDE_SHAPE as CUDNN_MIN_VERSION_OVERRIDE_SHAPE, build_cudnn_gemm_bf16_graph_override_shape as build_cudnn_gemm_bf16_graph_override_shape, execute_cudnn_gemm_bf16_graph_override_shape as 
execute_cudnn_gemm_bf16_graph_override_shape, build_cudnn_fp4_gemm_graph_override_shape as build_cudnn_fp4_gemm_graph_override_shape, @@ -78,7 +77,6 @@ "mm_M1_16_K7168_N256", "tinygemm_bf16", "is_cudnn_override_shape_available", - "CUDNN_MIN_VERSION_OVERRIDE_SHAPE", "build_cudnn_gemm_bf16_graph_override_shape", "execute_cudnn_gemm_bf16_graph_override_shape", "build_cudnn_fp4_gemm_graph_override_shape", diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 52b794221a..de74872480 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1720,21 +1720,29 @@ def _is_cublas_fp4_available_in_cudnn(): ) -# Minimum cuDNN backend version required for is_override_shape_enabled support. -CUDNN_MIN_VERSION_OVERRIDE_SHAPE = 92100 - - def _check_cudnn_override_shape_availability(): """Raise if the installed cuDNN backend does not support is_override_shape_enabled.""" _check_cudnn_availability() backend_version = cudnn.backend_version() - if backend_version < CUDNN_MIN_VERSION_OVERRIDE_SHAPE: + if backend_version < 92100: raise RuntimeError( - f"cuDNN override-shape GEMM requires backend version >= " - f"{CUDNN_MIN_VERSION_OVERRIDE_SHAPE} (9.21.0), " + f"cuDNN override-shape GEMM requires backend version >= 92100 (9.21.0), " f"found {backend_version}. " f"Please upgrade cuDNN: pip install --upgrade nvidia-cudnn-cu12 nvidia-cudnn-frontend" ) + try: + version_str = cudnn.__version__ + major, minor = map(int, version_str.split(".")[:2]) + if (major, minor) < (1, 20): + raise RuntimeError( + f"cuDNN override-shape GEMM requires cudnn-frontend version >= 1.20, found {version_str}. " + f"Please upgrade: pip install --upgrade nvidia-cudnn-frontend" + ) + except (AttributeError, ValueError, IndexError) as e: + raise RuntimeError( + "Unable to determine cudnn-frontend version. 
" + "Override-shape GEMM requires cudnn-frontend >= 1.20" + ) from e def is_cudnn_override_shape_available() -> bool: @@ -1742,7 +1750,11 @@ def is_cudnn_override_shape_available() -> bool: if not CUDNN_AVAILABLE: return False try: - return cudnn.backend_version() >= CUDNN_MIN_VERSION_OVERRIDE_SHAPE + if cudnn.backend_version() < 92100: + return False + version_str = cudnn.__version__ + major, minor = map(int, version_str.split(".")[:2]) + return (major, minor) >= (1, 20) except Exception: return False @@ -1979,10 +1991,6 @@ def build_cudnn_fp4_gemm_graph_override_shape( batch, n, k, - a_descale_n_dim, - a_descale_k_dim, - b_descale_k_dim, - b_descale_n_dim, ab_type, o_type, block_size, @@ -2005,20 +2013,20 @@ def build_cudnn_fp4_gemm_graph_override_shape( scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0 # Build shapes / strides using cache_m - block_scale_dim_m, _, block_scale_dim_k = _calculate_block_scale_dims( - cache_m, n, k, block_size + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(cache_m, n, k, block_size) ) - a_shape = [batch, cache_m, k] # FP4 packed: K dimension stores k*2 uint8 values + a_shape = [batch, cache_m, k] a_stride = [cache_m * k, k, 1] b_shape = [batch, k, n] b_stride = [k * n, 1, k] - a_descale_shape = [batch, block_scale_dim_m, a_descale_k_dim] - a_descale_stride = [block_scale_dim_m * a_descale_k_dim, a_descale_k_dim, 1] - b_descale_shape = [batch, b_descale_k_dim, b_descale_n_dim] - b_descale_stride = [b_descale_n_dim * b_descale_k_dim, 1, b_descale_k_dim] + a_descale_shape = [batch, block_scale_dim_m, block_scale_dim_k] + a_descale_stride = [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1] + b_descale_shape = [batch, block_scale_dim_k, block_scale_dim_n] + b_descale_stride = [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k] stream = torch.cuda.current_stream(device) graph = cudnn.pygraph( diff --git 
a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index 2f313c9c3e..8aeb40bd92 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -22,7 +22,6 @@ build_cudnn_mxfp8_gemm_graph_override_shape, execute_cudnn_mxfp8_gemm_graph_override_shape, is_cudnn_override_shape_available, - CUDNN_MIN_VERSION_OVERRIDE_SHAPE, _calculate_block_scale_dims, ) from flashinfer.utils import get_compute_capability @@ -37,11 +36,8 @@ def _skip_if_override_shape_not_supported(): if not CUDNN_AVAILABLE: pytest.skip("cuDNN not available") if not is_cudnn_override_shape_available(): - import cudnn - pytest.skip( - f"cuDNN override-shape requires backend >= {CUDNN_MIN_VERSION_OVERRIDE_SHAPE} " - f"(9.21.0), found {cudnn.backend_version()}" + "cuDNN override-shape requires higher version of cuDNN backend and frontend" ) @@ -151,8 +147,8 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): out_dtype = torch.bfloat16 # Compute block scale dims using cache_m - block_scale_dim_m_cache, block_scale_dim_n, block_scale_dim_k = ( - _calculate_block_scale_dims(cache_m, n, k, block_size) + _, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( + cache_m, n, k, block_size ) # Build graph once with cache_m @@ -160,10 +156,6 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): batch=1, n=n, k=k, - a_descale_n_dim=block_scale_dim_m_cache, - a_descale_k_dim=block_scale_dim_k, - b_descale_k_dim=block_scale_dim_k, - b_descale_n_dim=block_scale_dim_n, ab_type=cudnn.data_type.FP4_E2M1, o_type=_torch_data_type_to_cudnn_data_type(out_dtype), block_size=block_size, From b4c92c62906fe3f304dc1802b0d5f73104fd7bfd Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Mon, 16 Mar 2026 11:40:30 -0700 Subject: [PATCH 08/11] add-correctness-check-for-fp4-fp8 --- tests/gemm/test_cudnn_override_shape.py | 66 ++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff 
--git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index 8aeb40bd92..630295e897 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -171,6 +171,31 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): device=device, ) + # FP4 E2M1 lookup table: index is the 4-bit pattern (0–15) + # Encoding: sign(1) | exp(2) | mantissa(1), bias=1 + FP4_E2M1_LUT = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=device, + ) + # B is fixed across all dynamic_ms b_packed = torch.randint( 0, 256, (1, n, k // 2), dtype=torch.uint8, device=device @@ -212,6 +237,25 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): ) torch.cuda.synchronize() + # Correctness check: dequantize FP4 E2M1 via LUT and compare with + # FP32 bmm reference. Descales are all 1.0, so no scaling needed. 
+ # A packing: a_packed (1, m, k//2), low nibble = even k, high = odd k + a_fp32 = torch.empty(1, m, k, dtype=torch.float32, device=device) + a_fp32[:, :, 0::2] = FP4_E2M1_LUT[(a_packed & 0x0F).long()] + a_fp32[:, :, 1::2] = FP4_E2M1_LUT[((a_packed >> 4) & 0x0F).long()] + # B packing: b_packed (1, k//2, n), low nibble = even k, high = odd k + b_fp32 = torch.empty(1, k, n, dtype=torch.float32, device=device) + b_fp32[:, 0::2, :] = FP4_E2M1_LUT[(b_packed & 0x0F).long()] + b_fp32[:, 1::2, :] = FP4_E2M1_LUT[((b_packed >> 4) & 0x0F).long()] + ref = torch.bmm(a_fp32, b_fp32).to(out_dtype) + + assert torch.allclose(ref, out, rtol=1e-1, atol=1.0), ( + f"NVFP4 override_shape failed for m={m}, n={n}, k={k}: " + f"max_abs_err={(ref - out).abs().max().item():.4f}, " + f"max_rel_err=" + f"{((ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}" + ) + # ============================================================================ # MXFP8 GEMM with override_shape @@ -270,8 +314,9 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): device=device, ) + # Use values 0–126 to avoid NaN FP8_E4M3 bit patterns (0x7F, 0xFF). b = torch.randint( - 0, 256, (1, n, k), dtype=torch.uint8, device=device + 0, 127, (1, n, k), dtype=torch.uint8, device=device ).transpose(1, 2) b_descale = torch.ones( 1, @@ -284,7 +329,8 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): for m in dynamic_ms: block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) - a = torch.randint(0, 256, (1, m, k), dtype=torch.uint8, device=device) + # Use values 0–126 to avoid NaN FP8_E4M3 bit patterns (0x7F, 0xFF). + a = torch.randint(0, 127, (1, m, k), dtype=torch.uint8, device=device) a_descale = torch.ones( 1, block_scale_dim_m, @@ -306,3 +352,19 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): tactic=0, ) torch.cuda.synchronize() + + # Correctness check: reinterpret uint8 as FP8_E4M3, compute FP32 + # bmm reference. 
Descales are all 1.0 (2^0), so no scaling needed. + # A: (1, m, k) contiguous uint8 → float8_e4m3fn → float32 + a_fp32 = a.view(torch.float8_e4m3fn).to(torch.float32) + # B logical shape is (1, k, n) with stride [n*k, 1, k]; make + # contiguous before view so dtype reinterpretation is valid. + b_fp32 = b.contiguous().view(torch.float8_e4m3fn).to(torch.float32) + ref = torch.bmm(a_fp32, b_fp32).to(out_dtype) + + assert torch.allclose(ref, out, rtol=5e-2, atol=5e-2), ( + f"MXFP8 override_shape failed for m={m}, n={n}, k={k}: " + f"max_abs_err={(ref - out).abs().max().item():.4f}, " + f"max_rel_err=" + f"{((ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}" + ) From 4c8b0cead88c8ad7a6214d9362343a3634d68932 Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Tue, 17 Mar 2026 09:48:24 -0700 Subject: [PATCH 09/11] address-some-comments --- flashinfer/gemm/gemm_base.py | 149 +----------------------- tests/gemm/test_cudnn_override_shape.py | 33 +++--- 2 files changed, 22 insertions(+), 160 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index de74872480..5668248c4e 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1978,7 +1978,8 @@ def execute_cudnn_gemm_fp4_graph( # --------------------------------------------------------------------------- # Sentinel value used as "cache M" when building override-shape graphs. -# Any sufficiently large M will work; 8192 covers typical LLM inference shapes. +# Any M value will work in general. +# 8192 covers typical LLM inference shapes and set as default value. 
_OVERRIDE_SHAPE_CACHE_M = 8192 # --------------------------------------------------------------------------- @@ -2131,7 +2132,7 @@ def execute_cudnn_fp4_gemm_graph_override_shape( ): """Execute FP4 GEMM cuDNN graph with dynamic-shape overrides.""" - assert a.stride()[2] == 1 and b.stride()[1] == 1 + assert a.stride()[2] == 1 and b.stride()[1] == 1, "a and b must be k-major" variant_pack = { UIDs.A_UID.value: a, @@ -2228,65 +2229,6 @@ def _get_cudnn_fp4_gemm_graph_override_shape( ) -def _cudnn_gemm_fp4_override_shape( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - block_size: int = 16, - use_nvfp4: bool = True, - workspace_buffer: torch.Tensor = None, - tactic: int = 0, -): - """FP4 GEMM via cuDNN using override-shape for dynamic M dimension.""" - graph = _get_cudnn_fp4_gemm_graph_override_shape( - a=a, - b=b, - a_descale=a_descale, - b_descale=b_descale, - alpha=alpha, - out_dtype=out_dtype, - out=out, - block_size=block_size, - use_nvfp4=use_nvfp4, - ) - - real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) - real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) - batch = real_a_shape[0] - expanded_a_descale_shape, expanded_a_descale_stride = ( - _expand_block_scale_tensor_shape(a_descale, batch) - ) - expanded_b_descale_shape, expanded_b_descale_stride = ( - _expand_block_scale_tensor_shape(b_descale, batch) - ) - - a_3d = a.view(real_a_shape) if a.ndim == 2 else a - b_3d = b.view(real_b_shape) if b.ndim == 2 else b - a_descale_3d = ( - a_descale.view(expanded_a_descale_shape) if a_descale.ndim == 2 else a_descale - ) - b_descale_3d = ( - b_descale.view(expanded_b_descale_shape) if b_descale.ndim == 2 else b_descale - ) - out_3d = out.unsqueeze(0) if out.ndim == 2 else out - - execute_cudnn_fp4_gemm_graph_override_shape( - graph, - a_3d, - b_3d, - a_descale_3d, - 
b_descale_3d, - alpha, - out_3d, - workspace_buffer, - tactic=tactic, - ) - - def execute_cudnn_gemm_mxfp8_graph( graph, a, @@ -2770,48 +2712,6 @@ def execute_cudnn_gemm_with_per_tensor_q_graph_override_shape( ) -def _cudnn_gemm_fp8_override_shape( - workspace: torch.Tensor, - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - out: Optional[torch.Tensor], - torch_out_dtype: torch.dtype, - tactic: int = 0, -): - """FP8 per-tensor GEMM via cuDNN using override-shape for dynamic M.""" - _check_cudnn_availability() - - # Expand 2-D tensors to 3-D for cuDNN - a_3d_shape, _ = _get_bf16_3d_shape_stride(a) - b_3d_shape, _ = _get_bf16_3d_shape_stride(b) - out_3d_shape, _ = _get_bf16_3d_shape_stride(out) - - batch = a_3d_shape[0] - n = b_3d_shape[2] - k = a_3d_shape[2] - - a_3d = a.view(a_3d_shape) if a.ndim == 2 else a - b_3d = b.view(b_3d_shape) if b.ndim == 2 else b - out_3d = out.view(out_3d_shape) if out.ndim == 2 else out - - graph = build_cudnn_gemm_with_per_tensor_q_graph_override_shape( - batch, - n, - k, - _torch_data_type_to_cudnn_data_type(a.dtype), - _torch_data_type_to_cudnn_data_type(b.dtype), - _torch_data_type_to_cudnn_data_type(torch_out_dtype), - a.device, - ) - - execute_cudnn_gemm_with_per_tensor_q_graph_override_shape( - graph, a_3d, b_3d, a_scale, b_scale, out_3d, workspace, tactic=tactic - ) - return out - - def _torch_data_type_to_cudnn_data_type(dtype: torch.dtype): if dtype == torch.bfloat16: return cudnn.data_type.BFLOAT16 @@ -3067,49 +2967,6 @@ def execute_cudnn_gemm_bf16_graph_override_shape( ) -def _cudnn_gemm_bf16_override_shape( - workspace: torch.Tensor, - a: torch.Tensor, - b: torch.Tensor, - out: torch.Tensor, - tactic: int = 0, -): - """BF16 GEMM via cuDNN using override-shape for dynamic M dimension. - - A single plan compiled with ``_OVERRIDE_SHAPE_CACHE_M`` as M is reused - for all M values without triggering a graph rebuild. 
- """ - _check_cudnn_availability() - - # Both mm (2-D) and bmm (3-D) are supported via the existing - # _get_bf16_3d_shape_stride helper which pads 2-D inputs to 3-D. - a_3d_shape, _ = _get_bf16_3d_shape_stride(a) - b_3d_shape, _ = _get_bf16_3d_shape_stride(b) - out_3d_shape, _ = _get_bf16_3d_shape_stride(out) - - batch = a_3d_shape[0] - n = b_3d_shape[2] - k = a_3d_shape[2] - - # Ensure 3-D contiguous views for the cuDNN call - a_3d = a.view(a_3d_shape) if a.ndim == 2 else a - b_3d = b.view(b_3d_shape) if b.ndim == 2 else b - out_3d = out.view(out_3d_shape) if out.ndim == 2 else out - - graph = build_cudnn_gemm_bf16_graph_override_shape( - batch, - n, - k, - _torch_data_type_to_cudnn_data_type(out.dtype), - a.device, - ) - - execute_cudnn_gemm_bf16_graph_override_shape( - graph, a_3d, b_3d, out_3d, workspace, tactic=tactic - ) - return out - - def _cudnn_gemm_bf16( workspace: torch.Tensor, a: torch.Tensor, diff --git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index 630295e897..bba1dbded0 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -41,10 +41,10 @@ def _skip_if_override_shape_not_supported(): ) -def _skip_if_not_sm100(): +def _skip_if_not_sm100_or_sm103(): major, minor = get_compute_capability(torch.device("cuda")) - if major * 10 + minor < 100: - pytest.skip("override-shape GEMM requires SM100+ (Blackwell)") + if major * 10 + minor not in [100, 103]: + pytest.skip("override-shape GEMM requires SM100 or SM103") # ============================================================================ @@ -55,6 +55,10 @@ def _skip_if_not_sm100(): class TestCudnnBf16OverrideShape: """Single compiled plan handles multiple M dimensions for BF16 GEMM.""" + @pytest.mark.skipif( + not CUDNN_AVAILABLE, + reason="cuDNN not available", + ) @pytest.mark.parametrize( "cache_m,dynamic_ms", [ @@ -67,7 +71,7 @@ class TestCudnnBf16OverrideShape: def test_bf16_override_shape_dynamic_m(self, 
cache_m, dynamic_ms, n, k): _skip_if_no_cudnn() _skip_if_override_shape_not_supported() - _skip_if_not_sm100() + _skip_if_not_sm100_or_sm103() from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type @@ -134,7 +138,7 @@ class TestCudnnNVFp4OverrideShape: def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): _skip_if_no_cudnn() _skip_if_override_shape_not_supported() - _skip_if_not_sm100() + _skip_if_not_sm100_or_sm103() import cudnn @@ -281,7 +285,7 @@ class TestCudnnMXFp8OverrideShape: def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): _skip_if_no_cudnn() _skip_if_override_shape_not_supported() - _skip_if_not_sm100() + _skip_if_not_sm100_or_sm103() import cudnn from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type @@ -291,8 +295,8 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): out_dtype = torch.bfloat16 # Compute block scale dims using cache_m - block_scale_dim_m_cache, block_scale_dim_n, block_scale_dim_k = ( - _calculate_block_scale_dims(cache_m, n, k, block_size) + _, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( + cache_m, n, k, block_size ) # Build graph once with cache_m @@ -314,10 +318,10 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): device=device, ) - # Use values 0–126 to avoid NaN FP8_E4M3 bit patterns (0x7F, 0xFF). - b = torch.randint( - 0, 127, (1, n, k), dtype=torch.uint8, device=device - ).transpose(1, 2) + # Use all possible values 0–255, but reset NaN FP8_E4M3 bit patterns (0x7F, 0xFF) to 0. 
+ b = torch.randint(0, 256, (1, n, k), dtype=torch.uint8, device=device) + b[(b == 0x7F) | (b == 0xFF)] = 0 + b = b.transpose(1, 2) b_descale = torch.ones( 1, block_scale_dim_n, @@ -329,8 +333,9 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): for m in dynamic_ms: block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) - # Use values 0–126 to avoid NaN FP8_E4M3 bit patterns (0x7F, 0xFF). - a = torch.randint(0, 127, (1, m, k), dtype=torch.uint8, device=device) + # Use all possible values 0–255, but reset NaN FP8_E4M3 bit patterns (0x7F, 0xFF) to 0. + a = torch.randint(0, 256, (1, m, k), dtype=torch.uint8, device=device) + a[(a == 0x7F) | (a == 0xFF)] = 0 a_descale = torch.ones( 1, block_scale_dim_m, From ffb9f95245c491a22ae7088bff28bb6bebb42c5b Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Tue, 17 Mar 2026 20:57:02 -0700 Subject: [PATCH 10/11] add-random-nums-for_bs-tests --- tests/gemm/test_cudnn_override_shape.py | 147 ++++++++---------------- 1 file changed, 49 insertions(+), 98 deletions(-) diff --git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index bba1dbded0..e3e3b94aa1 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -12,6 +12,7 @@ import pytest import torch +import torch.nn.functional as F from flashinfer.gemm.gemm_base import ( CUDNN_AVAILABLE, @@ -25,6 +26,8 @@ _calculate_block_scale_dims, ) from flashinfer.utils import get_compute_capability +from flashinfer.fp4_quantization import nvfp4_quantize +from flashinfer.fp8_quantization import mxfp8_quantize def _skip_if_no_cudnn(): @@ -175,56 +178,27 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): device=device, ) - # FP4 E2M1 lookup table: index is the 4-bit pattern (0–15) - # Encoding: sign(1) | exp(2) | mantissa(1), bias=1 - FP4_E2M1_LUT = torch.tensor( - [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, - -0.0, - -0.5, - -1.0, - 
-1.5, - -2.0, - -3.0, - -4.0, - -6.0, - ], - dtype=torch.float32, - device=device, - ) + global_sf = torch.tensor(1.0, dtype=torch.float32, device=device) # B is fixed across all dynamic_ms - b_packed = torch.randint( - 0, 256, (1, n, k // 2), dtype=torch.uint8, device=device - ).transpose(1, 2) - b_descale = torch.ones( - 1, - block_scale_dim_n, - block_scale_dim_k, - dtype=torch.float8_e4m3fn, - device=device, - ).transpose(1, 2) + b_bf16 = torch.empty([1, n, k], device="cuda", dtype=torch.bfloat16).uniform_( + -5.0, 5.0 + ) + b_packed, b_scale = nvfp4_quantize(b_bf16, global_sf, True) + + b_bf16 = b_bf16.transpose(1, 2) + b_packed = b_packed.transpose(1, 2) + b_scale = b_scale.unsqueeze(0).transpose(1, 2) for m in dynamic_ms: block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) - a_packed = torch.randint( - 0, 256, (1, m, k // 2), dtype=torch.uint8, device=device - ) - a_descale = torch.ones( - 1, - block_scale_dim_m, - block_scale_dim_k, - dtype=torch.float8_e4m3fn, - device=device, - ) + a_bf16 = torch.empty( + [1, m, k], device="cuda", dtype=torch.bfloat16 + ).uniform_(-5.0, 5.0) + a_packed, a_scale = nvfp4_quantize(a_bf16, global_sf, True) + + a_scale = a_scale.unsqueeze(0) # Execute with cached graph (override_shape) out = torch.empty(1, m, n, dtype=out_dtype, device=device) @@ -232,8 +206,8 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): graph, a_packed, b_packed, - a_descale, - b_descale, + a_scale, + b_scale, alpha=None, c_final=out, workspace_buffer=workspace, @@ -241,23 +215,12 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): ) torch.cuda.synchronize() - # Correctness check: dequantize FP4 E2M1 via LUT and compare with - # FP32 bmm reference. Descales are all 1.0, so no scaling needed. 
- # A packing: a_packed (1, m, k//2), low nibble = even k, high = odd k - a_fp32 = torch.empty(1, m, k, dtype=torch.float32, device=device) - a_fp32[:, :, 0::2] = FP4_E2M1_LUT[(a_packed & 0x0F).long()] - a_fp32[:, :, 1::2] = FP4_E2M1_LUT[((a_packed >> 4) & 0x0F).long()] - # B packing: b_packed (1, k//2, n), low nibble = even k, high = odd k - b_fp32 = torch.empty(1, k, n, dtype=torch.float32, device=device) - b_fp32[:, 0::2, :] = FP4_E2M1_LUT[(b_packed & 0x0F).long()] - b_fp32[:, 1::2, :] = FP4_E2M1_LUT[((b_packed >> 4) & 0x0F).long()] - ref = torch.bmm(a_fp32, b_fp32).to(out_dtype) - - assert torch.allclose(ref, out, rtol=1e-1, atol=1.0), ( - f"NVFP4 override_shape failed for m={m}, n={n}, k={k}: " - f"max_abs_err={(ref - out).abs().max().item():.4f}, " - f"max_rel_err=" - f"{((ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}" + ref = torch.bmm(a_bf16, b_bf16).to(out_dtype) + + min_cos_sim = 0.9 + cos_sim = F.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0) + assert cos_sim > min_cos_sim, ( + f"Cosine similarity {cos_sim:.4f} is too low (expected > {min_cos_sim})" ) @@ -318,31 +281,27 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): device=device, ) - # Use all possible values 0–255, but reset NaN FP8_E4M3 bit patterns (0x7F, 0xFF) to 0. 
- b = torch.randint(0, 256, (1, n, k), dtype=torch.uint8, device=device) - b[(b == 0x7F) | (b == 0xFF)] = 0 + # B is fixed across all dynamic_ms + b_bf16 = torch.empty([1, n, k], device="cuda", dtype=torch.bfloat16).uniform_( + -5.0, 5.0 + ) + b, b_scale = mxfp8_quantize(b_bf16, True) + + b_bf16 = b_bf16.transpose(1, 2) b = b.transpose(1, 2) - b_descale = torch.ones( - 1, - block_scale_dim_n, - block_scale_dim_k, - dtype=torch.float8_e8m0fnu, - device=device, - ).transpose(1, 2) + b_scale = b_scale.reshape((-1, block_scale_dim_n, block_scale_dim_k)).transpose( + 1, 2 + ) for m in dynamic_ms: block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) - # Use all possible values 0–255, but reset NaN FP8_E4M3 bit patterns (0x7F, 0xFF) to 0. - a = torch.randint(0, 256, (1, m, k), dtype=torch.uint8, device=device) - a[(a == 0x7F) | (a == 0xFF)] = 0 - a_descale = torch.ones( - 1, - block_scale_dim_m, - block_scale_dim_k, - dtype=torch.float8_e8m0fnu, - device=device, - ) + a_bf16 = torch.empty( + [1, m, k], device="cuda", dtype=torch.bfloat16 + ).uniform_(-5.0, 5.0) + a, a_scale = mxfp8_quantize(a_bf16, True) + + a_scale = a_scale.reshape((-1, block_scale_dim_m, block_scale_dim_k)) # Execute with cached graph (override_shape) out = torch.empty(1, m, n, dtype=out_dtype, device=device) @@ -350,26 +309,18 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): graph, a, b, - a_descale, - b_descale, + a_scale, + b_scale, c_final=out, workspace_buffer=workspace, tactic=0, ) torch.cuda.synchronize() - # Correctness check: reinterpret uint8 as FP8_E4M3, compute FP32 - # bmm reference. Descales are all 1.0 (2^0), so no scaling needed. - # A: (1, m, k) contiguous uint8 → float8_e4m3fn → float32 - a_fp32 = a.view(torch.float8_e4m3fn).to(torch.float32) - # B logical shape is (1, k, n) with stride [n*k, 1, k]; make - # contiguous before view so dtype reinterpretation is valid. 
- b_fp32 = b.contiguous().view(torch.float8_e4m3fn).to(torch.float32) - ref = torch.bmm(a_fp32, b_fp32).to(out_dtype) + ref = torch.bmm(a_bf16, b_bf16).to(out_dtype) - assert torch.allclose(ref, out, rtol=5e-2, atol=5e-2), ( - f"MXFP8 override_shape failed for m={m}, n={n}, k={k}: " - f"max_abs_err={(ref - out).abs().max().item():.4f}, " - f"max_rel_err=" - f"{((ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}" + min_cos_sim = 0.9 + cos_sim = F.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0) + assert cos_sim > min_cos_sim, ( + f"Cosine similarity {cos_sim:.4f} is too low (expected > {min_cos_sim})" ) From 1fa66febb21644dca3e72b01b1c1b82983f1d091 Mon Sep 17 00:00:00 2001 From: Yanqin Zhai Date: Wed, 18 Mar 2026 12:26:37 -0700 Subject: [PATCH 11/11] fix-pre-commit-conflict --- flashinfer/gemm/gemm_base.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index fcfd443db2..22a358c6da 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -2045,7 +2045,7 @@ def build_cudnn_fp4_gemm_graph_override_shape( io_data_type=cudnn.data_type.FLOAT, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, - handle=_get_cudnn_handle(stream), + handle=_get_cudnn_handle(device, stream), is_override_shape_enabled=True, ) @@ -2184,7 +2184,7 @@ def execute_cudnn_fp4_gemm_graph_override_shape( variant_pack, workspace_buffer, tactic, - handle=_get_cudnn_handle(stream), + handle=_get_cudnn_handle(a.device, stream), override_uids=override_uids, override_shapes=override_shapes, override_strides=override_strides, @@ -2331,7 +2331,7 @@ def build_cudnn_mxfp8_gemm_graph_override_shape( io_data_type=cudnn.data_type.FLOAT, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, - handle=_get_cudnn_handle(stream), + handle=_get_cudnn_handle(device, stream), is_override_shape_enabled=True, ) @@ -2441,7 
+2441,7 @@ def execute_cudnn_mxfp8_gemm_graph_override_shape( variant_pack, workspace_buffer, tactic, - handle=_get_cudnn_handle(stream), + handle=_get_cudnn_handle(a.device, stream), override_uids=override_uids, override_shapes=override_shapes, override_strides=override_strides, @@ -2635,7 +2635,7 @@ def build_cudnn_gemm_with_per_tensor_q_graph_override_shape( io_data_type=cudnn.data_type.FLOAT, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, - handle=_get_cudnn_handle(stream), + handle=_get_cudnn_handle(device, stream), is_override_shape_enabled=True, ) @@ -2710,7 +2710,7 @@ def execute_cudnn_gemm_with_per_tensor_q_graph_override_shape( override_strides = [list(a.stride()), list(b.stride()), list(c_final.stride())] stream = torch.cuda.current_stream(a.device) - cudnn_handle = _get_cudnn_handle(stream) + cudnn_handle = _get_cudnn_handle(a.device, stream) if workspace.numel() < graph.get_workspace_size(): workspace = torch.empty( @@ -2910,7 +2910,7 @@ def build_cudnn_gemm_bf16_graph_override_shape( io_data_type=cudnn.data_type.BFLOAT16, intermediate_data_type=cudnn.data_type.FLOAT, compute_data_type=cudnn.data_type.FLOAT, - handle=_get_cudnn_handle(stream), + handle=_get_cudnn_handle(device, stream), is_override_shape_enabled=True, ) @@ -2967,7 +2967,7 @@ def execute_cudnn_gemm_bf16_graph_override_shape( override_strides = [list(a.stride()), list(b.stride()), list(c_final.stride())] stream = torch.cuda.current_stream(a.device) - cudnn_handle = _get_cudnn_handle(stream) + cudnn_handle = _get_cudnn_handle(a.device, stream) if workspace.numel() < graph.get_workspace_size(): workspace = torch.empty(