diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index d0fd71476f..3d4555c28e 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -18,6 +18,17 @@ from .gemm_base import gemm_fp8_nt_groupwise as gemm_fp8_nt_groupwise from .gemm_base import group_gemm_fp8_nt_groupwise as group_gemm_fp8_nt_groupwise from .gemm_base import fp8_blockscale_gemm_sm90 as fp8_blockscale_gemm_sm90 +from .gemm_base import ( + is_cudnn_override_shape_available as is_cudnn_override_shape_available, + build_cudnn_gemm_bf16_graph_override_shape as build_cudnn_gemm_bf16_graph_override_shape, + execute_cudnn_gemm_bf16_graph_override_shape as execute_cudnn_gemm_bf16_graph_override_shape, + build_cudnn_fp4_gemm_graph_override_shape as build_cudnn_fp4_gemm_graph_override_shape, + execute_cudnn_fp4_gemm_graph_override_shape as execute_cudnn_fp4_gemm_graph_override_shape, + build_cudnn_mxfp8_gemm_graph_override_shape as build_cudnn_mxfp8_gemm_graph_override_shape, + execute_cudnn_mxfp8_gemm_graph_override_shape as execute_cudnn_mxfp8_gemm_graph_override_shape, + build_cudnn_gemm_with_per_tensor_q_graph_override_shape as build_cudnn_gemm_with_per_tensor_q_graph_override_shape, + execute_cudnn_gemm_with_per_tensor_q_graph_override_shape as execute_cudnn_gemm_with_per_tensor_q_graph_override_shape, +) from .routergemm import ( mm_M1_16_K7168_N128 as mm_M1_16_K7168_N128, @@ -65,4 +76,13 @@ "mm_M1_16_K7168_N128", "mm_M1_16_K7168_N256", "tinygemm_bf16", + "is_cudnn_override_shape_available", + "build_cudnn_gemm_bf16_graph_override_shape", + "execute_cudnn_gemm_bf16_graph_override_shape", + "build_cudnn_fp4_gemm_graph_override_shape", + "execute_cudnn_fp4_gemm_graph_override_shape", + "build_cudnn_mxfp8_gemm_graph_override_shape", + "execute_cudnn_mxfp8_gemm_graph_override_shape", + "build_cudnn_gemm_with_per_tensor_q_graph_override_shape", + "execute_cudnn_gemm_with_per_tensor_q_graph_override_shape", ] + _cute_dsl_kernels diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 30ef414daf..22a358c6da 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1722,6 +1722,45 @@ def _is_cublas_fp4_available_in_cudnn(): ) +def _check_cudnn_override_shape_availability(): + """Raise if the installed cuDNN backend does not support is_override_shape_enabled.""" + _check_cudnn_availability() + backend_version = cudnn.backend_version() + if backend_version < 92100: + raise RuntimeError( + f"cuDNN override-shape GEMM requires backend version >= 92100 (9.21.0), " + f"found {backend_version}. " + f"Please upgrade cuDNN: pip install --upgrade nvidia-cudnn-cu12 nvidia-cudnn-frontend" + ) + try: + version_str = cudnn.__version__ + major, minor = map(int, version_str.split(".")[:2]) + if (major, minor) < (1, 20): + raise RuntimeError( + f"cuDNN override-shape GEMM requires cudnn-frontend version >= 1.20, found {version_str}. " + f"Please upgrade: pip install --upgrade nvidia-cudnn-frontend" + ) + except (AttributeError, ValueError, IndexError) as e: + raise RuntimeError( + "Unable to determine cudnn-frontend version. " + "Override-shape GEMM requires cudnn-frontend >= 1.20" + ) from e + + +def is_cudnn_override_shape_available() -> bool: + """Return True if the installed cuDNN backend supports is_override_shape_enabled.""" + if not CUDNN_AVAILABLE: + return False + try: + if cudnn.backend_version() < 92100: + return False + version_str = cudnn.__version__ + major, minor = map(int, version_str.split(".")[:2]) + return (major, minor) >= (1, 20) + except Exception: + return False + + # One cudnn handle per each GPU _cudnn_handles: dict[int, int] = {} @@ -1945,6 +1984,262 @@ def execute_cudnn_gemm_fp4_graph( ) +# --------------------------------------------------------------------------- +# override_shape shared constant +# --------------------------------------------------------------------------- + +# Sentinel value used as "cache M" when building override-shape graphs. +# Any M value will work in general. +# 8192 covers typical LLM inference shapes and set as default value. +_OVERRIDE_SHAPE_CACHE_M = 8192 + +# --------------------------------------------------------------------------- +# FP4 GEMM with override_shape (dynamic M dimension) +# --------------------------------------------------------------------------- + + +@functools.cache +def build_cudnn_fp4_gemm_graph_override_shape( + batch, + n, + k, + ab_type, + o_type, + block_size, + device, + alpha_is_not_none, + use_nvfp4, + cache_m: int = _OVERRIDE_SHAPE_CACHE_M, +): + """Build a cuDNN FP4 GEMM graph with override-shape support. + + The graph is compiled once using ``cache_m`` as the M dimension. Block + scale dimensions are derived from ``cache_m`` at compile time; at execution + time the real M (and corresponding scale dims) are passed via + ``override_shapes`` / ``override_strides``. + + Caching key contains ``(batch, n, k, ...)`` but **not** M. + """ + _check_cudnn_override_shape_availability() + + scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0 + + # Build shapes / strides using cache_m + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(cache_m, n, k, block_size) + ) + + a_shape = [batch, cache_m, k] + a_stride = [cache_m * k, k, 1] + + b_shape = [batch, k, n] + b_stride = [k * n, 1, k] + + a_descale_shape = [batch, block_scale_dim_m, block_scale_dim_k] + a_descale_stride = [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1] + b_descale_shape = [batch, block_scale_dim_k, block_scale_dim_n] + b_descale_stride = [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k] + + stream = torch.cuda.current_stream(device) + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.FLOAT, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_handle(device, stream), + is_override_shape_enabled=True, + ) + + a_cudnn_tensor = graph.tensor( + name="a", dim=a_shape, stride=a_stride, data_type=ab_type + ) + b_cudnn_tensor = graph.tensor( + name="b", dim=b_shape, stride=b_stride, data_type=ab_type + ) + block_descale_a_cudnn_tensor = graph.tensor( + name="block_descale_a", + dim=a_descale_shape, + stride=a_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + block_descale_b_cudnn_tensor = graph.tensor( + name="block_descale_b", + dim=b_descale_shape, + stride=b_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + + dequant_a_tensor = graph.block_scale_dequantize( + a_cudnn_tensor, + block_descale_a_cudnn_tensor, + block_size=[1, block_size], + name="dequant_a", + ) + dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT) + dequant_b_tensor = graph.block_scale_dequantize( + b_cudnn_tensor, + block_descale_b_cudnn_tensor, + block_size=[block_size, 1], + name="dequant_b", + ) + dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT) + c_tensor = graph.matmul( + dequant_a_tensor, + dequant_b_tensor, + compute_data_type=cudnn.data_type.FLOAT, + name="gemm", + ) + c_tensor.set_data_type(cudnn.data_type.FLOAT) + + c_final_cudnn_tensor = c_tensor + + if alpha_is_not_none: + global_scale_cudnn_tensor = graph.tensor( + name="global_scale", + dim=[1, 1, 1], + stride=[1, 1, 1], + data_type=cudnn.data_type.FLOAT, + ) + c_final_cudnn_tensor = graph.mul( + name="scale_mul", + a=c_tensor, + b=global_scale_cudnn_tensor, + compute_data_type=cudnn.data_type.FLOAT, + ) + global_scale_cudnn_tensor.set_uid(UIDs.ALPHA_UID.value) + + c_final_cudnn_tensor.set_name("c_final").set_output(True).set_data_type(o_type) + + a_cudnn_tensor.set_uid(UIDs.A_UID.value) + b_cudnn_tensor.set_uid(UIDs.B_UID.value) + block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value) + block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value) + c_final_cudnn_tensor.set_uid(UIDs.O_UID.value) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B]) + + if alpha_is_not_none and not _is_cublas_fp4_available_in_cudnn(): + graph.deselect_engines(["eng0"]) + + graph.check_support() + graph.build_plans() + + return graph + + +def execute_cudnn_fp4_gemm_graph_override_shape( + graph, + a, + b, + a_descale, + b_descale, + alpha, + c_final, + workspace_buffer, + tactic: int = 0, +): + """Execute FP4 GEMM cuDNN graph with dynamic-shape overrides.""" + + assert a.stride()[2] == 1 and b.stride()[1] == 1, "a and b must be k-major" + + variant_pack = { + UIDs.A_UID.value: a, + UIDs.B_UID.value: b, + UIDs.BLOCK_DESCALE_A_UID.value: a_descale, + UIDs.BLOCK_DESCALE_B_UID.value: b_descale, + UIDs.O_UID.value: c_final, + } + + if alpha is not None: + variant_pack[UIDs.ALPHA_UID.value] = alpha.view(torch.float) + + override_uids = [ + UIDs.A_UID.value, + UIDs.B_UID.value, + UIDs.BLOCK_DESCALE_A_UID.value, + UIDs.BLOCK_DESCALE_B_UID.value, + UIDs.O_UID.value, + ] + override_shapes = [ + [a.shape[0], a.shape[1], a.shape[2] * 2], + [b.shape[0], b.shape[1] * 2, b.shape[2]], + a_descale.shape, + b_descale.shape, + c_final.shape, + ] + override_strides = [ + [a.stride()[0], a.stride()[1] * 2, a.stride()[2]], + [b.stride()[0], b.stride()[1], b.stride()[2] * 2], + a_descale.stride(), + b_descale.stride(), + c_final.stride(), + ] + + stream = torch.cuda.current_stream(a.device) + + graph.execute_plan_at_index( + variant_pack, + workspace_buffer, + tactic, + handle=_get_cudnn_handle(a.device, stream), + override_uids=override_uids, + override_shapes=override_shapes, + override_strides=override_strides, + ) + + +def _get_cudnn_fp4_gemm_graph_override_shape( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_nvfp4: bool = True, +): + """Get (or build) a cached FP4 GEMM graph with override-shape support. + + The cache key excludes M so the same compiled plan is reused for all M + values. + """ + real_a_shape, _ = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, _ = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + n = real_b_shape[2] + k = real_a_shape[2] + + expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch) + expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch) + + # Scale dimension sizes that are independent of M + a_descale_k_dim = expanded_a_descale_shape[2] + b_descale_k_dim = expanded_b_descale_shape[1] + b_descale_n_dim = expanded_b_descale_shape[2] + # a_descale N-dimension (dim[1]) depends on M, so we pass it separately + a_descale_n_dim = expanded_a_descale_shape[1] + + return build_cudnn_fp4_gemm_graph_override_shape( + batch=batch, + n=n, + k=k, + a_descale_n_dim=a_descale_n_dim, + a_descale_k_dim=a_descale_k_dim, + b_descale_k_dim=b_descale_k_dim, + b_descale_n_dim=b_descale_n_dim, + ab_type=cudnn.data_type.FP4_E2M1, + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + block_size=block_size, + device=a.device, + alpha_is_not_none=alpha is not None, + use_nvfp4=use_nvfp4, + ) + + def execute_cudnn_gemm_mxfp8_graph( graph, a, @@ -1985,6 +2280,232 @@ def execute_cudnn_gemm_mxfp8_graph( ) +# --------------------------------------------------------------------------- +# MXFP8 GEMM with override_shape (dynamic M dimension) +# --------------------------------------------------------------------------- + + +@functools.cache +def build_cudnn_mxfp8_gemm_graph_override_shape( + batch, + n, + k, + a_type, + b_type, + o_type, + block_size, + device, + cache_m: int = _OVERRIDE_SHAPE_CACHE_M, +): + """Build a cuDNN MXFP8 GEMM graph with override-shape support. + + Compiled once using ``cache_m`` as M; at execution time the actual M is + provided through ``override_shapes`` / ``override_strides``. + """ + _check_cudnn_override_shape_availability() + + if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]: + raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got {a_type}") + if b_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]: + raise ValueError(f"B type must be FP8_E4M3 or FP8_E5M2, got {b_type}") + if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]: + raise ValueError(f"Output type must be BF16 or FP16, got {o_type}") + + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(cache_m, n, k, block_size) + ) + + scale_type = cudnn.data_type.FP8_E8M0 + + a_shape = [batch, cache_m, k] + a_stride = [cache_m * k, k, 1] + b_shape = [batch, k, n] + b_stride = [k * n, 1, k] + a_descale_shape = [batch, block_scale_dim_m, block_scale_dim_k] + a_descale_stride = [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1] + b_descale_shape = [batch, block_scale_dim_k, block_scale_dim_n] + b_descale_stride = [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k] + + stream = torch.cuda.current_stream(device) + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.FLOAT, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_handle(device, stream), + is_override_shape_enabled=True, + ) + + a_cudnn_tensor = graph.tensor( + name="a", dim=a_shape, stride=a_stride, data_type=a_type + ) + b_cudnn_tensor = graph.tensor( + name="b", dim=b_shape, stride=b_stride, data_type=b_type + ) + block_descale_a_cudnn_tensor = graph.tensor( + name="block_descale_a", + dim=a_descale_shape, + stride=a_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + block_descale_b_cudnn_tensor = graph.tensor( + name="block_descale_b", + dim=b_descale_shape, + stride=b_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + + dequant_a_tensor = graph.block_scale_dequantize( + a_cudnn_tensor, + block_descale_a_cudnn_tensor, + block_size=[1, block_size], + name="dequant_a", + ) + dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT) + dequant_b_tensor = graph.block_scale_dequantize( + b_cudnn_tensor, + block_descale_b_cudnn_tensor, + block_size=[block_size, 1], + name="dequant_b", + ) + dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT) + + c_tensor = graph.matmul( + dequant_a_tensor, + dequant_b_tensor, + compute_data_type=cudnn.data_type.FLOAT, + name="gemm", + ) + c_tensor.set_data_type(cudnn.data_type.FLOAT) + c_tensor.set_output(True).set_data_type(o_type) + + a_cudnn_tensor.set_uid(UIDs.A_UID.value) + b_cudnn_tensor.set_uid(UIDs.B_UID.value) + block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value) + block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value) + c_tensor.set_uid(UIDs.O_UID.value) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B]) + graph.check_support() + graph.build_plans() + + return graph + + +def execute_cudnn_mxfp8_gemm_graph_override_shape( + graph, + a, + b, + a_descale, + b_descale, + c_final, + workspace_buffer, + tactic: int = 0, +): + """Execute MXFP8 GEMM cuDNN graph with dynamic-shape overrides.""" + variant_pack = { + UIDs.A_UID.value: a, + UIDs.B_UID.value: b, + UIDs.BLOCK_DESCALE_A_UID.value: a_descale, + UIDs.BLOCK_DESCALE_B_UID.value: b_descale, + UIDs.O_UID.value: c_final, + } + + override_uids = [ + UIDs.A_UID.value, + UIDs.B_UID.value, + UIDs.BLOCK_DESCALE_A_UID.value, + UIDs.BLOCK_DESCALE_B_UID.value, + UIDs.O_UID.value, + ] + override_shapes = [ + list(a.shape), + list(b.shape), + list(a_descale.shape), + list(b_descale.shape), + list(c_final.shape), + ] + override_strides = [ + list(a.stride()), + list(b.stride()), + list(a_descale.stride()), + list(b_descale.stride()), + list(c_final.stride()), + ] + + stream = torch.cuda.current_stream(a.device) + graph.execute_plan_at_index( + variant_pack, + workspace_buffer, + tactic, + handle=_get_cudnn_handle(a.device, stream), + override_uids=override_uids, + override_shapes=override_shapes, + override_strides=override_strides, + ) + + +def _cudnn_gemm_mxfp8_override_shape( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + workspace_buffer: torch.Tensor = None, + tactic: int = 0, +): + """MXFP8 GEMM via cuDNN using override-shape for dynamic M dimension.""" + block_size = 32 + + a_3d = a if a.ndim == 3 else a.unsqueeze(0) + b_3d = b if b.ndim == 3 else b.unsqueeze(0) + out_3d = out if out.ndim == 3 else out.unsqueeze(0) + + batch = a_3d.shape[0] + k = a_3d.shape[2] + n = b_3d.shape[2] + + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(a_3d.shape[1], n, k, block_size) + ) + + if a_descale.ndim == 2: + a_descale_3d = a_descale.view(batch, block_scale_dim_m, block_scale_dim_k) + else: + a_descale_3d = a_descale + + if b_descale.ndim == 2: + b_descale_3d = b_descale.view(batch, block_scale_dim_k, block_scale_dim_n) + else: + b_descale_3d = b_descale + + graph = build_cudnn_mxfp8_gemm_graph_override_shape( + batch=batch, + n=n, + k=k, + a_type=_torch_data_type_to_cudnn_data_type(a.dtype), + b_type=_torch_data_type_to_cudnn_data_type(b.dtype), + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + block_size=block_size, + device=a.device, + ) + + execute_cudnn_mxfp8_gemm_graph_override_shape( + graph=graph, + a=a_3d, + b=b_3d, + a_descale=a_descale_3d, + b_descale=b_descale_3d, + c_final=out_3d, + workspace_buffer=workspace_buffer, + tactic=tactic, + ) + + @functools.lru_cache(maxsize=1024) def build_cudnn_gemm_with_per_tensor_q_graph( a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type, device @@ -2088,6 +2609,125 @@ def execute_cudnn_gemm_with_per_tensor_q_graph( graph.execute(variant_pack, workspace, handle=cudnn_handle) +# --------------------------------------------------------------------------- +# FP8 per-tensor GEMM with override_shape (dynamic M dimension) +# --------------------------------------------------------------------------- + + +@functools.cache +def build_cudnn_gemm_with_per_tensor_q_graph_override_shape( + batch, n, k, a_type, b_type, o_type, device, cache_m: int = _OVERRIDE_SHAPE_CACHE_M +): + """Build an FP8 per-tensor-quantized GEMM cuDNN graph with override-shape. + + Compiled once with ``cache_m`` as M; at execution time the actual M is + supplied through ``override_shapes`` / ``override_strides``. + """ + _check_cudnn_override_shape_availability() + + a_shape = [batch, cache_m, k] + a_stride = [cache_m * k, k, 1] + b_shape = [batch, k, n] + b_stride = [k * n, 1, k] + + stream = torch.cuda.current_stream(device) + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.FLOAT, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_handle(device, stream), + is_override_shape_enabled=True, + ) + + a_cudnn_tensor = graph.tensor( + name="a", dim=a_shape, stride=a_stride, data_type=a_type + ) + b_cudnn_tensor = graph.tensor( + name="b", dim=b_shape, stride=b_stride, data_type=b_type + ) + a_scale_cudnn_tensor = graph.tensor( + name="a_scale", + dim=[1, 1, 1], + stride=[1, 1, 1], + data_type=cudnn.data_type.FLOAT, + ) + b_scale_cudnn_tensor = graph.tensor( + name="b_scale", + dim=[1, 1, 1], + stride=[1, 1, 1], + data_type=cudnn.data_type.FLOAT, + ) + c_cudnn_tensor = graph.matmul( + name="matmul", + A=a_cudnn_tensor, + B=b_cudnn_tensor, + compute_data_type=cudnn.data_type.FLOAT, + ) + c_cudnn_tensor.set_name("c").set_data_type(cudnn.data_type.FLOAT) + c_after_scale_a = graph.mul( + name="scale_mul_a", + a=c_cudnn_tensor, + b=a_scale_cudnn_tensor, + compute_data_type=cudnn.data_type.FLOAT, + ) + c_after_scale_b = graph.mul( + name="scale_mul_b", + a=c_after_scale_a, + b=b_scale_cudnn_tensor, + compute_data_type=cudnn.data_type.FLOAT, + ) + c_after_scale_b.set_name("c_final").set_output(True).set_data_type(o_type) + + a_cudnn_tensor.set_uid(UIDs.A_UID.value) + b_cudnn_tensor.set_uid(UIDs.B_UID.value) + a_scale_cudnn_tensor.set_uid(UIDs.A_SCALE_UID.value) + b_scale_cudnn_tensor.set_uid(UIDs.B_SCALE_UID.value) + c_after_scale_b.set_uid(UIDs.O_UID.value) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + return graph + + +def execute_cudnn_gemm_with_per_tensor_q_graph_override_shape( + graph, a, b, a_scale, b_scale, c_final, workspace, tactic: int = 0 +): + """Execute FP8 per-tensor GEMM graph with dynamic-shape overrides.""" + variant_pack = { + UIDs.A_UID.value: a, + UIDs.B_UID.value: b, + UIDs.A_SCALE_UID.value: a_scale, + UIDs.B_SCALE_UID.value: b_scale, + UIDs.O_UID.value: c_final, + } + + override_uids = [UIDs.A_UID.value, UIDs.B_UID.value, UIDs.O_UID.value] + override_shapes = [list(a.shape), list(b.shape), list(c_final.shape)] + override_strides = [list(a.stride()), list(b.stride()), list(c_final.stride())] + + stream = torch.cuda.current_stream(a.device) + cudnn_handle = _get_cudnn_handle(a.device, stream) + + if workspace.numel() < graph.get_workspace_size(): + workspace = torch.empty( + graph.get_workspace_size(), device=a.device, dtype=torch.uint8 + ) + + graph.execute_plan_at_index( + variant_pack, + workspace, + tactic, + handle=cudnn_handle, + override_uids=override_uids, + override_shapes=override_shapes, + override_strides=override_strides, + ) + + def _torch_data_type_to_cudnn_data_type(dtype: torch.dtype): if dtype == torch.bfloat16: return cudnn.data_type.BFLOAT16 @@ -2223,6 +2863,128 @@ def execute_cudnn_gemm_bf16_graph(graph, a, b, c_final, workspace, tactic: int = ) +# --------------------------------------------------------------------------- +# BF16 GEMM with override_shape (dynamic M dimension) +# --------------------------------------------------------------------------- + + +@functools.cache +def build_cudnn_gemm_bf16_graph_override_shape( + batch, + n, + k, + o_type, + device, + cache_m: int = _OVERRIDE_SHAPE_CACHE_M, + is_a_k_major: bool = True, + is_b_k_major: bool = True, +): + """Build a cuDNN BF16 GEMM graph with override-shape support. + + The graph is compiled once with ``cache_m`` as the M dimension. At + execution time the caller supplies the *actual* M via + ``execute_cudnn_gemm_bf16_graph_override_shape``, which calls + ``execute_plan_at_index`` with ``override_shapes`` / ``override_strides`` + so no rebuild is needed for different M values. + + Caching key is ``(batch, n, k, o_type, device, cache_m)`` — M is **not** + part of the key. + + Args: + is_a_k_major: If True, A has shape (batch, M, K) with row-major strides + (K is the contiguous dimension). If False, A has shape (batch, M, K) + with column-major strides (M is the contiguous dimension). + is_b_k_major: If True, B has shape (batch, K, N) where K is the leading + dimension (stride along N is 1, i.e. N-contiguous within each K row). + If False, B is row-major with K-contiguous layout (stride along K is 1). + """ + _check_cudnn_override_shape_availability() + + a_shape = (batch, cache_m, k) + a_stride = (cache_m * k, k, 1) if is_a_k_major else (cache_m * k, 1, cache_m) + b_shape = (batch, k, n) + b_stride = (k * n, 1, k) if is_b_k_major else (k * n, n, 1) + + stream = torch.cuda.current_stream(device) + graph = cudnn.pygraph( + io_data_type=cudnn.data_type.BFLOAT16, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + handle=_get_cudnn_handle(device, stream), + is_override_shape_enabled=True, + ) + + a_cudnn_tensor = graph.tensor( + name="a", + dim=list(a_shape), + stride=list(a_stride), + data_type=cudnn.data_type.BFLOAT16, + ) + b_cudnn_tensor = graph.tensor( + name="b", + dim=list(b_shape), + stride=list(b_stride), + data_type=cudnn.data_type.BFLOAT16, + ) + c_cudnn_tensor = graph.matmul( + name="matmul", + A=a_cudnn_tensor, + B=b_cudnn_tensor, + compute_data_type=cudnn.data_type.FLOAT, + ) + c_cudnn_tensor.set_name("c").set_output(True).set_data_type(o_type) + + a_cudnn_tensor.set_uid(UIDs.A_UID.value) + b_cudnn_tensor.set_uid(UIDs.B_UID.value) + c_cudnn_tensor.set_uid(UIDs.O_UID.value) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + return graph + + +def execute_cudnn_gemm_bf16_graph_override_shape( + graph, a, b, c_final, workspace, tactic: int = 0 +): + """Execute a BF16 GEMM cuDNN graph built with override-shape enabled. + + Passes the actual shapes/strides of *a*, *b*, and *c_final* as + ``override_shapes`` / ``override_strides`` so a single compiled plan + handles any M dimension without rebuilding. + """ + variant_pack = { + UIDs.A_UID.value: a, + UIDs.B_UID.value: b, + UIDs.O_UID.value: c_final, + } + + override_uids = [UIDs.A_UID.value, UIDs.B_UID.value, UIDs.O_UID.value] + override_shapes = [list(a.shape), list(b.shape), list(c_final.shape)] + override_strides = [list(a.stride()), list(b.stride()), list(c_final.stride())] + + stream = torch.cuda.current_stream(a.device) + cudnn_handle = _get_cudnn_handle(a.device, stream) + + if workspace.numel() < graph.get_workspace_size(): + workspace = torch.empty( + graph.get_workspace_size(), device=a.device, dtype=torch.uint8 + ) + + graph.execute_plan_at_index( + variant_pack, + workspace, + tactic, + handle=cudnn_handle, + override_uids=override_uids, + override_shapes=override_shapes, + override_strides=override_strides, + ) + + def _cudnn_gemm_bf16( workspace: torch.Tensor, a: torch.Tensor, diff --git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py new file mode 100644 index 0000000000..e3e3b94aa1 --- /dev/null +++ b/tests/gemm/test_cudnn_override_shape.py @@ -0,0 +1,326 @@ +""" +Tests for cuDNN GEMM operations using is_override_shape_enabled API. + +A single cuDNN graph is compiled once with a "cache shape" (large M). +At execution time, the actual M is passed via override_shapes / override_strides, +so no graph rebuild is triggered for varying M values. + +Requires: + - CUDA compute capability == SM100 / SM103 + - cuDNN frontend >= 1.20 / backend_version >= 92100 +""" + +import pytest +import torch +import torch.nn.functional as F + +from flashinfer.gemm.gemm_base import ( + CUDNN_AVAILABLE, + build_cudnn_gemm_bf16_graph_override_shape, + execute_cudnn_gemm_bf16_graph_override_shape, + build_cudnn_fp4_gemm_graph_override_shape, + execute_cudnn_fp4_gemm_graph_override_shape, + build_cudnn_mxfp8_gemm_graph_override_shape, + execute_cudnn_mxfp8_gemm_graph_override_shape, + is_cudnn_override_shape_available, + _calculate_block_scale_dims, +) +from flashinfer.utils import get_compute_capability +from flashinfer.fp4_quantization import nvfp4_quantize +from flashinfer.fp8_quantization import mxfp8_quantize + + +def _skip_if_no_cudnn(): + if not CUDNN_AVAILABLE: + pytest.skip("cuDNN not available") + + +def _skip_if_override_shape_not_supported(): + if not CUDNN_AVAILABLE: + pytest.skip("cuDNN not available") + if not is_cudnn_override_shape_available(): + pytest.skip( + "cuDNN override-shape requires higher version of cuDNN backend and frontend" + ) + + +def _skip_if_not_sm100_or_sm103(): + major, minor = get_compute_capability(torch.device("cuda")) + if major * 10 + minor not in [100, 103]: + pytest.skip("override-shape GEMM requires SM100 or SM103") + + +# ============================================================================ +# BF16 GEMM with override_shape +# ============================================================================ + + +class TestCudnnBf16OverrideShape: + """Single compiled plan handles multiple M dimensions for BF16 GEMM.""" + + @pytest.mark.skipif( + not CUDNN_AVAILABLE, + reason="cuDNN not available", + ) + @pytest.mark.parametrize( + "cache_m,dynamic_ms", + [ + (2048, [1, 4, 16, 32, 64, 128, 512, 1024, 2048]), + (4096, [1, 8, 64, 256, 1024, 4096]), + ], + ) + @pytest.mark.parametrize("n", [1024, 2048]) + @pytest.mark.parametrize("k", [1024, 2048]) + def test_bf16_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): + _skip_if_no_cudnn() + _skip_if_override_shape_not_supported() + _skip_if_not_sm100_or_sm103() + + from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type + + device = torch.device("cuda") + in_dtype = torch.bfloat16 + out_dtype = torch.bfloat16 + + # Build graph once with cache_m + graph = build_cudnn_gemm_bf16_graph_override_shape( + batch=1, + n=n, + k=k, + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + device=device, + cache_m=cache_m, + is_a_k_major=True, + is_b_k_major=True, + ) + + workspace = torch.empty( + graph.get_workspace_size(), dtype=torch.uint8, device=device + ) + + b = torch.randn(1, n, k, dtype=in_dtype, device=device).transpose(1, 2) + + for m in dynamic_ms: + a = torch.randn(1, m, k, dtype=in_dtype, device=device) + out = torch.empty(1, m, n, dtype=out_dtype, device=device) + ref = torch.bmm(a.float(), b.float()).to(out_dtype) + + execute_cudnn_gemm_bf16_graph_override_shape( + graph, a, b, out, workspace, tactic=0 + ) + torch.cuda.synchronize() + + assert torch.allclose(ref, out, rtol=5e-2, atol=5e-2), ( + f"BF16 override_shape failed for m={m}, n={n}, k={k}: " + f"max_abs_err={(ref - out).abs().max().item():.4f}, " + f"max_rel_err={((ref - out).abs() / (ref.abs() + 1e-8)).max().item():.4f}" + ) + + +# ============================================================================ +# NVFP4 GEMM with override_shape +# ============================================================================ + + +class TestCudnnNVFp4OverrideShape: + """Single compiled plan handles multiple M dimensions for NVFP4 GEMM.""" + + @pytest.mark.skipif( + not CUDNN_AVAILABLE, + reason="cuDNN not available", + ) + @pytest.mark.parametrize( + "cache_m,dynamic_ms", + [ + (2048, [1, 4, 16, 32, 64, 128, 512, 1024, 2048]), + (4096, [1, 8, 64, 256, 1024, 4096]), + ], + ) + @pytest.mark.parametrize("n", [1024, 2048]) + @pytest.mark.parametrize("k", [1024, 2048]) + def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): + _skip_if_no_cudnn() + _skip_if_override_shape_not_supported() + _skip_if_not_sm100_or_sm103() + + import cudnn + + if cudnn.backend_version() < 91002: + pytest.skip("FP4 requires cuDNN backend >= 91002") + from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type + + device = torch.device("cuda") + block_size = 16 + out_dtype = torch.bfloat16 + + # Compute block scale dims using cache_m + _, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( + cache_m, n, k, block_size + ) + + # Build graph once with cache_m + graph = build_cudnn_fp4_gemm_graph_override_shape( + batch=1, + n=n, + k=k, + ab_type=cudnn.data_type.FP4_E2M1, + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + block_size=block_size, + device=device, + alpha_is_not_none=False, + use_nvfp4=True, + cache_m=cache_m, + ) + + workspace = torch.empty( + graph.get_workspace_size(), + dtype=torch.uint8, + device=device, + ) + + global_sf = torch.tensor(1.0, dtype=torch.float32, device=device) + + # B is fixed across all dynamic_ms + b_bf16 = torch.empty([1, n, k], device="cuda", dtype=torch.bfloat16).uniform_( + -5.0, 5.0 + ) + b_packed, b_scale = nvfp4_quantize(b_bf16, global_sf, True) + + b_bf16 = b_bf16.transpose(1, 2) + b_packed = b_packed.transpose(1, 2) + b_scale = b_scale.unsqueeze(0).transpose(1, 2) + + for m in dynamic_ms: + block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) + + a_bf16 = torch.empty( + [1, m, k], device="cuda", dtype=torch.bfloat16 + ).uniform_(-5.0, 5.0) + a_packed, a_scale = nvfp4_quantize(a_bf16, global_sf, True) + + a_scale = a_scale.unsqueeze(0) + + # Execute with cached graph (override_shape) + out = torch.empty(1, m, n, dtype=out_dtype, device=device) + execute_cudnn_fp4_gemm_graph_override_shape( + graph, + a_packed, + b_packed, + a_scale, + b_scale, + alpha=None, + c_final=out, + workspace_buffer=workspace, + tactic=0, + ) + torch.cuda.synchronize() + + ref = torch.bmm(a_bf16, b_bf16).to(out_dtype) + + min_cos_sim = 0.9 + cos_sim = F.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0) + assert cos_sim > min_cos_sim, ( + f"Cosine similarity {cos_sim:.4f} is too low (expected > {min_cos_sim})" + ) + + +# ============================================================================ +# MXFP8 GEMM with override_shape +# ============================================================================ + + +class TestCudnnMXFp8OverrideShape: + """Single compiled plan handles multiple M dimensions for MXFP8 GEMM.""" + + @pytest.mark.skipif( + not CUDNN_AVAILABLE, + reason="cuDNN not available", + ) + @pytest.mark.parametrize( + "cache_m,dynamic_ms", + [ + (2048, [1, 4, 16, 32, 64, 128, 512, 1024, 2048]), + (4096, [1, 8, 64, 256, 1024, 4096]), + ], + ) + @pytest.mark.parametrize("n", [1024, 2048]) + @pytest.mark.parametrize("k", [1024, 2048]) + def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): + _skip_if_no_cudnn() + _skip_if_override_shape_not_supported() + _skip_if_not_sm100_or_sm103() + + import cudnn + from flashinfer.gemm.gemm_base import _torch_data_type_to_cudnn_data_type + + device = torch.device("cuda") + block_size = 32 + out_dtype = torch.bfloat16 + + # Compute block scale dims using cache_m + _, block_scale_dim_n, block_scale_dim_k = _calculate_block_scale_dims( + cache_m, n, k, block_size + ) + + # Build graph once with cache_m + graph = build_cudnn_mxfp8_gemm_graph_override_shape( + batch=1, + n=n, + k=k, + a_type=cudnn.data_type.FP8_E4M3, + b_type=cudnn.data_type.FP8_E4M3, + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + block_size=block_size, + device=device, + cache_m=cache_m, + ) + + workspace = torch.empty( + graph.get_workspace_size(), + dtype=torch.uint8, + device=device, + ) + + # B is fixed across all dynamic_ms + b_bf16 = torch.empty([1, n, k], device="cuda", dtype=torch.bfloat16).uniform_( + -5.0, 5.0 + ) + b, b_scale = mxfp8_quantize(b_bf16, True) + + b_bf16 = b_bf16.transpose(1, 2) + b = b.transpose(1, 2) + b_scale = b_scale.reshape((-1, block_scale_dim_n, block_scale_dim_k)).transpose( + 1, 2 + ) + + for m in dynamic_ms: + block_scale_dim_m, _, _ = _calculate_block_scale_dims(m, n, k, block_size) + + a_bf16 = torch.empty( + [1, m, k], device="cuda", dtype=torch.bfloat16 + ).uniform_(-5.0, 5.0) + a, a_scale = mxfp8_quantize(a_bf16, True) + + a_scale = a_scale.reshape((-1, block_scale_dim_m, block_scale_dim_k)) + + # Execute with cached graph (override_shape) + out = torch.empty(1, m, n, dtype=out_dtype, device=device) + execute_cudnn_mxfp8_gemm_graph_override_shape( + graph, + a, + b, + a_scale, + b_scale, + c_final=out, + workspace_buffer=workspace, + tactic=0, + ) + torch.cuda.synchronize() + + ref = torch.bmm(a_bf16, b_bf16).to(out_dtype) + + min_cos_sim = 0.9 + cos_sim = F.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0) + assert cos_sim > min_cos_sim, ( + f"Cosine similarity {cos_sim:.4f} is too low (expected > {min_cos_sim})" + )