diff --git a/flashinfer/gemm/__init__.py b/flashinfer/gemm/__init__.py index d44d03fc27..d7da621fae 100644 --- a/flashinfer/gemm/__init__.py +++ b/flashinfer/gemm/__init__.py @@ -23,10 +23,10 @@ is_cudnn_override_shape_available as is_cudnn_override_shape_available, build_cudnn_gemm_bf16_graph_override_shape as build_cudnn_gemm_bf16_graph_override_shape, execute_cudnn_gemm_bf16_graph_override_shape as execute_cudnn_gemm_bf16_graph_override_shape, - build_cudnn_fp4_gemm_graph_override_shape as build_cudnn_fp4_gemm_graph_override_shape, - execute_cudnn_fp4_gemm_graph_override_shape as execute_cudnn_fp4_gemm_graph_override_shape, - build_cudnn_mxfp8_gemm_graph_override_shape as build_cudnn_mxfp8_gemm_graph_override_shape, - execute_cudnn_mxfp8_gemm_graph_override_shape as execute_cudnn_mxfp8_gemm_graph_override_shape, + build_cudnn_gemm_fp4_graph_override_shape as build_cudnn_gemm_fp4_graph_override_shape, + execute_cudnn_gemm_fp4_graph_override_shape as execute_cudnn_gemm_fp4_graph_override_shape, + build_cudnn_gemm_mxfp8_graph_override_shape as build_cudnn_gemm_mxfp8_graph_override_shape, + execute_cudnn_gemm_mxfp8_graph_override_shape as execute_cudnn_gemm_mxfp8_graph_override_shape, build_cudnn_gemm_with_per_tensor_q_graph_override_shape as build_cudnn_gemm_with_per_tensor_q_graph_override_shape, execute_cudnn_gemm_with_per_tensor_q_graph_override_shape as execute_cudnn_gemm_with_per_tensor_q_graph_override_shape, ) @@ -81,10 +81,10 @@ "is_cudnn_override_shape_available", "build_cudnn_gemm_bf16_graph_override_shape", "execute_cudnn_gemm_bf16_graph_override_shape", - "build_cudnn_fp4_gemm_graph_override_shape", - "execute_cudnn_fp4_gemm_graph_override_shape", - "build_cudnn_mxfp8_gemm_graph_override_shape", - "execute_cudnn_mxfp8_gemm_graph_override_shape", + "build_cudnn_gemm_fp4_graph_override_shape", + "execute_cudnn_gemm_fp4_graph_override_shape", + "build_cudnn_gemm_mxfp8_graph_override_shape", + "execute_cudnn_gemm_mxfp8_graph_override_shape", 
"build_cudnn_gemm_with_per_tensor_q_graph_override_shape", "execute_cudnn_gemm_with_per_tensor_q_graph_override_shape", ] + _cute_dsl_kernels diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 9e3b7e0b3a..41b07dc51d 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1950,13 +1950,9 @@ def execute_cudnn_gemm_fp4_graph( # 8192 covers typical LLM inference shapes and set as default value. _OVERRIDE_SHAPE_CACHE_M = 8192 -# --------------------------------------------------------------------------- -# FP4 GEMM with override_shape (dynamic M dimension) -# --------------------------------------------------------------------------- - @functools.cache -def build_cudnn_fp4_gemm_graph_override_shape( +def build_cudnn_gemm_fp4_graph_override_shape( batch, n, k, @@ -1967,6 +1963,7 @@ def build_cudnn_fp4_gemm_graph_override_shape( alpha_is_not_none, use_nvfp4, cache_m: int = _OVERRIDE_SHAPE_CACHE_M, + policy=None, ): """Build a cuDNN FP4 GEMM graph with override-shape support. @@ -1977,7 +1974,10 @@ def build_cudnn_fp4_gemm_graph_override_shape( Caching key contains ``(batch, n, k, ...)`` but **not** M. 
""" + _check_cudnn_override_shape_availability() + if policy is None: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE scale_type = cudnn.data_type.FP8_E4M3 if use_nvfp4 else cudnn.data_type.FP8_E8M0 @@ -2082,12 +2082,12 @@ def build_cudnn_fp4_gemm_graph_override_shape( graph.deselect_engines(["eng0"]) graph.check_support() - graph.build_plans() + graph.build_plans(policy) return graph -def execute_cudnn_fp4_gemm_graph_override_shape( +def execute_cudnn_gemm_fp4_graph_override_shape( graph, a, b, @@ -2100,7 +2100,23 @@ def execute_cudnn_fp4_gemm_graph_override_shape( ): """Execute FP4 GEMM cuDNN graph with dynamic-shape overrides.""" - assert a.stride()[2] == 1 and b.stride()[1] == 1, "a and b must be k-major" + real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) + + c_shape, c_stride = _get_bf16_3d_shape_stride(c_final) + + if real_a_stride[2] != 1 or real_b_stride[1] != 1: + raise ValueError( + f"a and b must be k-major (contiguous along the K dimension), " + f"got a stride={tuple(real_a_stride)}, b stride={tuple(real_b_stride)}" + ) variant_pack = { UIDs.A_UID.value: a, @@ -2121,20 +2137,25 @@ def execute_cudnn_fp4_gemm_graph_override_shape( UIDs.O_UID.value, ] override_shapes = [ - [a.shape[0], a.shape[1], a.shape[2] * 2], - [b.shape[0], b.shape[1] * 2, b.shape[2]], - a_descale.shape, - b_descale.shape, - c_final.shape, + real_a_shape, + real_b_shape, + expanded_a_descale_shape, + expanded_b_descale_shape, + c_shape, ] override_strides = [ - [a.stride()[0], a.stride()[1] * 2, a.stride()[2]], - [b.stride()[0], b.stride()[1], b.stride()[2] * 2], - a_descale.stride(), - b_descale.stride(), - c_final.stride(), + 
real_a_stride, + real_b_stride, + expanded_a_descale_stride, + expanded_b_descale_stride, + c_stride, ] + if workspace_buffer.numel() < graph.get_workspace_size(): + workspace_buffer = torch.empty( + graph.get_workspace_size(), device=a.device, dtype=torch.uint8 + ) + stream = torch.cuda.current_stream(a.device) graph.execute_plan_at_index( @@ -2148,55 +2169,6 @@ def execute_cudnn_fp4_gemm_graph_override_shape( ) -def _get_cudnn_fp4_gemm_graph_override_shape( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - b_descale: torch.Tensor, - alpha: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - block_size: int = 16, - use_nvfp4: bool = True, -): - """Get (or build) a cached FP4 GEMM graph with override-shape support. - - The cache key excludes M so the same compiled plan is reused for all M - values. - """ - real_a_shape, _ = _get_real_fp4_shape_from_packed_uint8(a) - real_b_shape, _ = _get_real_fp4_shape_from_packed_uint8(b) - batch = real_a_shape[0] - n = real_b_shape[2] - k = real_a_shape[2] - - expanded_a_descale_shape, _ = _expand_block_scale_tensor_shape(a_descale, batch) - expanded_b_descale_shape, _ = _expand_block_scale_tensor_shape(b_descale, batch) - - # Scale dimension sizes that are independent of M - a_descale_k_dim = expanded_a_descale_shape[2] - b_descale_k_dim = expanded_b_descale_shape[1] - b_descale_n_dim = expanded_b_descale_shape[2] - # a_descale N-dimension (dim[1]) depends on M, so we pass it separately - a_descale_n_dim = expanded_a_descale_shape[1] - - return build_cudnn_fp4_gemm_graph_override_shape( - batch=batch, - n=n, - k=k, - a_descale_n_dim=a_descale_n_dim, - a_descale_k_dim=a_descale_k_dim, - b_descale_k_dim=b_descale_k_dim, - b_descale_n_dim=b_descale_n_dim, - ab_type=cudnn.data_type.FP4_E2M1, - o_type=_torch_data_type_to_cudnn_data_type(out_dtype), - block_size=block_size, - device=a.device, - alpha_is_not_none=alpha is not None, - use_nvfp4=use_nvfp4, - 
) - - def execute_cudnn_gemm_mxfp8_graph( graph, a, @@ -2237,13 +2209,8 @@ def execute_cudnn_gemm_mxfp8_graph( ) -# --------------------------------------------------------------------------- -# MXFP8 GEMM with override_shape (dynamic M dimension) -# --------------------------------------------------------------------------- - - @functools.cache -def build_cudnn_mxfp8_gemm_graph_override_shape( +def build_cudnn_gemm_mxfp8_graph_override_shape( batch, n, k, @@ -2253,6 +2220,7 @@ def build_cudnn_mxfp8_gemm_graph_override_shape( block_size, device, cache_m: int = _OVERRIDE_SHAPE_CACHE_M, + policy=None, ): """Build a cuDNN MXFP8 GEMM graph with override-shape support. @@ -2260,6 +2228,8 @@ def build_cudnn_mxfp8_gemm_graph_override_shape( provided through ``override_shapes`` / ``override_strides``. """ _check_cudnn_override_shape_availability() + if policy is None: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]: raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got {a_type}") @@ -2347,12 +2317,12 @@ def build_cudnn_mxfp8_gemm_graph_override_shape( graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B]) graph.check_support() - graph.build_plans() + graph.build_plans(policy) return graph -def execute_cudnn_mxfp8_gemm_graph_override_shape( +def execute_cudnn_gemm_mxfp8_graph_override_shape( graph, a, b, @@ -2393,7 +2363,13 @@ def execute_cudnn_mxfp8_gemm_graph_override_shape( list(c_final.stride()), ] + if workspace_buffer.numel() < graph.get_workspace_size(): + workspace_buffer = torch.empty( + graph.get_workspace_size(), device=a.device, dtype=torch.uint8 + ) + stream = torch.cuda.current_stream(a.device) + graph.execute_plan_at_index( variant_pack, workspace_buffer, @@ -2405,64 +2381,6 @@ def execute_cudnn_mxfp8_gemm_graph_override_shape( ) -def _cudnn_gemm_mxfp8_override_shape( - a: torch.Tensor, - b: torch.Tensor, - a_descale: torch.Tensor, - 
b_descale: torch.Tensor, - out_dtype: torch.dtype = torch.bfloat16, - out: Optional[torch.Tensor] = None, - workspace_buffer: torch.Tensor = None, - tactic: int = 0, -): - """MXFP8 GEMM via cuDNN using override-shape for dynamic M dimension.""" - block_size = 32 - - a_3d = a if a.ndim == 3 else a.unsqueeze(0) - b_3d = b if b.ndim == 3 else b.unsqueeze(0) - out_3d = out if out.ndim == 3 else out.unsqueeze(0) - - batch = a_3d.shape[0] - k = a_3d.shape[2] - n = b_3d.shape[2] - - block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( - _calculate_block_scale_dims(a_3d.shape[1], n, k, block_size) - ) - - if a_descale.ndim == 2: - a_descale_3d = a_descale.view(batch, block_scale_dim_m, block_scale_dim_k) - else: - a_descale_3d = a_descale - - if b_descale.ndim == 2: - b_descale_3d = b_descale.view(batch, block_scale_dim_k, block_scale_dim_n) - else: - b_descale_3d = b_descale - - graph = build_cudnn_mxfp8_gemm_graph_override_shape( - batch=batch, - n=n, - k=k, - a_type=_torch_data_type_to_cudnn_data_type(a.dtype), - b_type=_torch_data_type_to_cudnn_data_type(b.dtype), - o_type=_torch_data_type_to_cudnn_data_type(out_dtype), - block_size=block_size, - device=a.device, - ) - - execute_cudnn_mxfp8_gemm_graph_override_shape( - graph=graph, - a=a_3d, - b=b_3d, - a_descale=a_descale_3d, - b_descale=b_descale_3d, - c_final=out_3d, - workspace_buffer=workspace_buffer, - tactic=tactic, - ) - - @functools.lru_cache(maxsize=1024) def build_cudnn_gemm_with_per_tensor_q_graph( a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type, device @@ -2845,6 +2763,7 @@ def build_cudnn_gemm_bf16_graph_override_shape( cache_m: int = _OVERRIDE_SHAPE_CACHE_M, is_a_k_major: bool = True, is_b_k_major: bool = True, + policy=None, ): """Build a cuDNN BF16 GEMM graph with override-shape support. @@ -2866,6 +2785,8 @@ def build_cudnn_gemm_bf16_graph_override_shape( If False, B is row-major with K-contiguous layout (stride along K is 1). 
""" _check_cudnn_override_shape_availability() + if policy is None: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE a_shape = (batch, cache_m, k) a_stride = (cache_m * k, k, 1) if is_a_k_major else (cache_m * k, 1, cache_m) @@ -2909,7 +2830,7 @@ def build_cudnn_gemm_bf16_graph_override_shape( graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() - graph.build_plans() + graph.build_plans(policy) return graph @@ -2929,9 +2850,13 @@ def execute_cudnn_gemm_bf16_graph_override_shape( UIDs.O_UID.value: c_final, } + a_shape, a_stride = _get_bf16_3d_shape_stride(a) + b_shape, b_stride = _get_bf16_3d_shape_stride(b) + c_shape, c_stride = _get_bf16_3d_shape_stride(c_final) + override_uids = [UIDs.A_UID.value, UIDs.B_UID.value, UIDs.O_UID.value] - override_shapes = [list(a.shape), list(b.shape), list(c_final.shape)] - override_strides = [list(a.stride()), list(b.stride()), list(c_final.stride())] + override_shapes = [list(a_shape), list(b_shape), list(c_shape)] + override_strides = [list(a_stride), list(b_stride), list(c_stride)] stream = torch.cuda.current_stream(a.device) cudnn_handle = _get_cudnn_handle(a.device, stream) @@ -2986,6 +2911,36 @@ def _cudnn_gemm_bf16( def _cudnn_gemm_bf16_runner(): class CudnnBf16GemmRunner(TunableRunner): + @staticmethod + def _get_override_graph(a, b, out): + a_shape, a_stride = _get_bf16_3d_shape_stride(a) + b_shape, b_stride = _get_bf16_3d_shape_stride(b) + + batch = a_shape[0] + actual_m = a_shape[-2] + k = a_shape[-1] + n = b_shape[-1] + o_type = _torch_data_type_to_cudnn_data_type(out.dtype) + + # Ceiling power-of-2 ensures cache_m >= actual_M. 
+ cache_m = last_positive_power_of_2(actual_m) + + is_a_k_major = a_stride[-1] == 1 + is_b_k_major = b_stride[-2] == 1 + + graph = build_cudnn_gemm_bf16_graph_override_shape( + batch=batch, + n=n, + k=k, + o_type=o_type, + device=a.device, + cache_m=cache_m, + is_a_k_major=is_a_k_major, + is_b_k_major=is_b_k_major, + policy=cudnn.build_plan_policy.ALL, + ) + return graph + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: # inputs layout: a, b, bias, pdl, out, workspace_buffer # out.dtype distinguishes bfloat16 / float16 / float32 output graphs @@ -2998,18 +2953,22 @@ def get_valid_tactics( profile: OptimizationProfile, ) -> List[int]: a, b, _, _, out, _ = inputs - a_shape, a_stride = _get_bf16_3d_shape_stride(a) - b_shape, b_stride = _get_bf16_3d_shape_stride(b) - graph = build_cudnn_gemm_bf16_graph( - a_shape, - a_stride, - b_shape, - b_stride, - _torch_data_type_to_cudnn_data_type(out.dtype), - a.device, - policy=cudnn.build_plan_policy.ALL, - ) + if is_cudnn_override_shape_available(): + graph = self._get_override_graph(a, b, out) + else: + a_shape, a_stride = _get_bf16_3d_shape_stride(a) + b_shape, b_stride = _get_bf16_3d_shape_stride(b) + + graph = build_cudnn_gemm_bf16_graph( + a_shape, + a_stride, + b_shape, + b_stride, + _torch_data_type_to_cudnn_data_type(out.dtype), + a.device, + policy=cudnn.build_plan_policy.ALL, + ) return list(range(graph.get_execution_plan_count())) @@ -3027,7 +2986,19 @@ def forward( if pdl: raise ValueError("cudnn bf16 gemm does not support pdl.") - _cudnn_gemm_bf16(workspace_buffer, a, b, out, tactic=tactic) + if is_cudnn_override_shape_available(): + graph = self._get_override_graph(a, b, out) + + execute_cudnn_gemm_bf16_graph_override_shape( + graph, + a, + b, + out, + workspace_buffer, + tactic=max(tactic, 0), + ) + else: + _cudnn_gemm_bf16(workspace_buffer, a, b, out, tactic=tactic) return out @@ -4092,6 +4063,34 @@ def _cudnn_gemm_fp4( def _cudnn_gemm_fp4_runner(): class CudnnFp4GemmRunner(TunableRunner): 
+ @staticmethod + def _get_override_graph(a, b, alpha, out_dtype, block_size, use_nvfp4): + real_a_shape, _ = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, _ = _get_real_fp4_shape_from_packed_uint8(b) + + batch = real_a_shape[0] + actual_m = real_a_shape[1] + k = real_a_shape[2] + n = real_b_shape[2] + + # Ceiling power-of-2 ensures cache_m >= actual_m. + cache_m = last_positive_power_of_2(actual_m) + + graph = build_cudnn_gemm_fp4_graph_override_shape( + batch=batch, + n=n, + k=k, + ab_type=cudnn.data_type.FP4_E2M1, + o_type=_torch_data_type_to_cudnn_data_type(out_dtype), + block_size=block_size, + device=a.device, + alpha_is_not_none=alpha is not None, + use_nvfp4=use_nvfp4, + cache_m=cache_m, + policy=cudnn.build_plan_policy.ALL, + ) + return graph + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: # inputs layout: a, b, a_descale, b_descale, alpha, out_dtype, # out, block_size, use_nvfp4, workspace_buffer @@ -4104,7 +4103,6 @@ def get_valid_tactics( inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - # cudnn has heuristic for fp4 gemm, so we only need to use the default tactic ( a, b, @@ -4118,36 +4116,43 @@ def get_valid_tactics( workspace_buffer, ) = inputs - # the fp4 cudnn graph will be shared for both mm and bmm, so - # here we need to get the 3d shape and stride including the - # batch dimension for both input and block scale tensors. 
- real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) - real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) - batch = real_a_shape[0] - expanded_a_descale_shape, expanded_a_descale_stride = ( - _expand_block_scale_tensor_shape(a_descale, batch) - ) - expanded_b_descale_shape, expanded_b_descale_stride = ( - _expand_block_scale_tensor_shape(b_descale, batch) - ) + # currently cudnn backend does not support alpha for dynamic-shape + # remove this restriction once cudnn supports it + if is_cudnn_override_shape_available(): + graph = self._get_override_graph( + a, b, alpha, out_dtype, block_size, use_nvfp4 + ) + else: + # the fp4 cudnn graph will be shared for both mm and bmm, so + # here we need to get the 3d shape and stride including the + # batch dimension for both input and block scale tensors. + real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) - graph = build_cudnn_gemm_fp4_graph( - real_a_shape, - real_a_stride, - real_b_shape, - real_b_stride, - expanded_a_descale_shape, - expanded_a_descale_stride, - expanded_b_descale_shape, - expanded_b_descale_stride, - cudnn.data_type.FP4_E2M1, - _torch_data_type_to_cudnn_data_type(out_dtype), - block_size, - a.device, - alpha is not None, - use_nvfp4, - policy=cudnn.build_plan_policy.ALL, - ) + graph = build_cudnn_gemm_fp4_graph( + real_a_shape, + real_a_stride, + real_b_shape, + real_b_stride, + expanded_a_descale_shape, + expanded_a_descale_stride, + expanded_b_descale_shape, + expanded_b_descale_stride, + cudnn.data_type.FP4_E2M1, + _torch_data_type_to_cudnn_data_type(out_dtype), + block_size, + a.device, + alpha is not None, + 
use_nvfp4, + policy=cudnn.build_plan_policy.ALL, + ) return list(range(graph.get_execution_plan_count())) @@ -4171,19 +4176,38 @@ def forward( workspace_buffer, ) = inputs - _cudnn_gemm_fp4( - a, - b, - a_descale, - b_descale, - alpha, - out_dtype, - out, - block_size, - use_nvfp4, - workspace_buffer, - tactic=tactic, - ) + # currently cudnn backend does not support alpha for dynamic-shape + # remove this restriction once cudnn supports it + if is_cudnn_override_shape_available(): + graph = self._get_override_graph( + a, b, alpha, out_dtype, block_size, use_nvfp4 + ) + + execute_cudnn_gemm_fp4_graph_override_shape( + graph, + a, + b, + a_descale, + b_descale, + alpha, + out, + workspace_buffer, + tactic=max(tactic, 0), + ) + else: + _cudnn_gemm_fp4( + a, + b, + a_descale, + b_descale, + alpha, + out_dtype, + out, + block_size, + use_nvfp4, + workspace_buffer, + tactic=tactic, + ) return out diff --git a/tests/gemm/test_cudnn_override_shape.py b/tests/gemm/test_cudnn_override_shape.py index e3e3b94aa1..9684a477c8 100644 --- a/tests/gemm/test_cudnn_override_shape.py +++ b/tests/gemm/test_cudnn_override_shape.py @@ -18,15 +18,15 @@ CUDNN_AVAILABLE, build_cudnn_gemm_bf16_graph_override_shape, execute_cudnn_gemm_bf16_graph_override_shape, - build_cudnn_fp4_gemm_graph_override_shape, - execute_cudnn_fp4_gemm_graph_override_shape, - build_cudnn_mxfp8_gemm_graph_override_shape, - execute_cudnn_mxfp8_gemm_graph_override_shape, + build_cudnn_gemm_fp4_graph_override_shape, + execute_cudnn_gemm_fp4_graph_override_shape, + build_cudnn_gemm_mxfp8_graph_override_shape, + execute_cudnn_gemm_mxfp8_graph_override_shape, is_cudnn_override_shape_available, _calculate_block_scale_dims, ) from flashinfer.utils import get_compute_capability -from flashinfer.fp4_quantization import nvfp4_quantize +from flashinfer.fp4_quantization import fp4_quantize from flashinfer.fp8_quantization import mxfp8_quantize @@ -159,7 +159,7 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, 
n, k): ) # Build graph once with cache_m - graph = build_cudnn_fp4_gemm_graph_override_shape( + graph = build_cudnn_gemm_fp4_graph_override_shape( batch=1, n=n, k=k, @@ -184,7 +184,7 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): b_bf16 = torch.empty([1, n, k], device="cuda", dtype=torch.bfloat16).uniform_( -5.0, 5.0 ) - b_packed, b_scale = nvfp4_quantize(b_bf16, global_sf, True) + b_packed, b_scale = fp4_quantize(b_bf16, global_sf) b_bf16 = b_bf16.transpose(1, 2) b_packed = b_packed.transpose(1, 2) @@ -196,13 +196,13 @@ def test_nvfp4_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): a_bf16 = torch.empty( [1, m, k], device="cuda", dtype=torch.bfloat16 ).uniform_(-5.0, 5.0) - a_packed, a_scale = nvfp4_quantize(a_bf16, global_sf, True) + a_packed, a_scale = fp4_quantize(a_bf16, global_sf) a_scale = a_scale.unsqueeze(0) # Execute with cached graph (override_shape) out = torch.empty(1, m, n, dtype=out_dtype, device=device) - execute_cudnn_fp4_gemm_graph_override_shape( + execute_cudnn_gemm_fp4_graph_override_shape( graph, a_packed, b_packed, @@ -263,7 +263,7 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): ) # Build graph once with cache_m - graph = build_cudnn_mxfp8_gemm_graph_override_shape( + graph = build_cudnn_gemm_mxfp8_graph_override_shape( batch=1, n=n, k=k, @@ -305,7 +305,7 @@ def test_mxfp8_override_shape_dynamic_m(self, cache_m, dynamic_ms, n, k): # Execute with cached graph (override_shape) out = torch.empty(1, m, n, dtype=out_dtype, device=device) - execute_cudnn_mxfp8_gemm_graph_override_shape( + execute_cudnn_gemm_mxfp8_graph_override_shape( graph, a, b,