diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 64f8f8da66..c7708b736b 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -190,7 +190,7 @@ def _cutlass_mm_bf16_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", ): if bias is not None: raise ValueError( @@ -206,6 +206,31 @@ def _cutlass_mm_bf16_requirement( return True +@supported_compute_capability([100, 103]) +def _cudnn_mm_bf16_requirement( + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + bias: Optional[torch.Tensor] = None, + pdl: bool = False, + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", +): + if bias is not None: + raise ValueError( + "You cannot use the cuDNN backend with a bias. Use the TGV backend instead." + ) + if pdl: + raise ValueError( + "The cuDNN backend does not support PDL. Use the TGV backend instead." + ) + + _validate_bf16_output_dtype(out_dtype) + _check_cudnn_availability() + + return True + + @supported_compute_capability([100, 103]) def _tgv_gemm_requirement( a: torch.Tensor, @@ -214,7 +239,7 @@ def _tgv_gemm_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", ): if out_dtype != torch.bfloat16: raise ValueError( @@ -230,7 +255,7 @@ def _check_mm_bf16_problem_size( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", ): if a.dtype != torch.bfloat16: raise ValueError( @@ -246,6 +271,20 @@ def _check_mm_bf16_problem_size( f"Bias tensor has unsupported dtype {bias.dtype}. Only bfloat16 is supported." ) + if out is not None: + if out.shape != (a.shape[0], b.shape[1]): + raise ValueError( + f"Output shape mismatch. Expected {(a.shape[0], b.shape[1])}, got {out.shape}." + ) + if out.device != a.device: + raise ValueError( + f"Output device mismatch. Expected {a.device}, got {out.device}." + ) + if out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}." + ) + return True @@ -257,13 +296,16 @@ def _heuristic_func_mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", ): heuristic_backends = [] if bias is not None or pdl: + # cuDNN and CUTLASS don't support bias/pdl, only TGV does if "tgv" in suitable_backends: heuristic_backends.append("tgv") else: + if "cudnn" in suitable_backends: + heuristic_backends.append("cudnn") if "cutlass" in suitable_backends: heuristic_backends.append("cutlass") if "tgv" in suitable_backends: @@ -273,6 +315,7 @@ def _heuristic_func_mm_bf16( @backend_requirement( { + "cudnn": _cudnn_mm_bf16_requirement, "cutlass": _cutlass_mm_bf16_requirement, "tgv": _tgv_gemm_requirement, }, @@ -287,7 +330,7 @@ def mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", ) -> torch.Tensor: r"""MM BF16 @@ -309,10 +352,13 @@ def mm_bf16( Out tensor, shape (m, n), bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to ``None``. out_dtype: torch.dtype - Output dtype, bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to ``torch.bfloat16``. + Output dtype, bf16 or fp16. Can be used with the CUTLASS or cuDNN backends. Defaults to ``torch.bfloat16``. - backend: Literal["cutlass", "tgv", "auto"] + backend: Literal["cudnn", "cutlass", "tgv", "auto"] The backend to use for the operation. Defaults to ``"tgv"``. + ``"cudnn"`` uses the cuDNN backend (no bias/pdl support). + ``"cutlass"`` uses the CUTLASS backend (no bias/pdl support). + ``"tgv"`` uses the TGV backend (supports bias/pdl, bf16 output only). ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. Returns @@ -348,25 +394,16 @@ def mm_bf16( device=a.device, dtype=out_dtype, ) - else: - if out.shape != (a.shape[0], b.shape[1]): - raise ValueError( - f"Output shape mismatch. Expected {(a.shape[0], b.shape[1])}, got {out.shape}." - ) - if out.device != a.device: - raise ValueError( - f"Output device mismatch. Expected {a.device}, got {out.device}." - ) - if out.dtype != out_dtype: - raise ValueError( - f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}." - ) workspace_buffer = _get_cache_buf( "mm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, a.device ) if backend == "auto": backends = mm_bf16.suitable_auto_backends + elif backend == "cudnn": + backends = _heuristic_func_mm_bf16( + ["cudnn"], a, b, None, False, out, out_dtype, backend + ) elif backend == "cutlass": backends = _heuristic_func_mm_bf16( ["cutlass"], a, b, None, False, out, out_dtype, backend @@ -388,19 +425,32 @@ def _cutlass_bmm_bf16_requirement( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", ): _validate_bf16_output_dtype(out_dtype) return True +@supported_compute_capability([100, 103]) +def _cudnn_bmm_bf16_requirement( + A: torch.Tensor, + B: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", +): + _validate_bf16_output_dtype(out_dtype) + _check_cudnn_availability() + return True + + def _check_bmm_bf16_problem_size( A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", ): if A.dtype != torch.bfloat16: raise ValueError( @@ -411,6 +461,21 @@ def _check_bmm_bf16_problem_size( f"Second tensor has unsupported dtype {B.dtype}. Only bfloat16 is supported." ) + if out is not None: + expected_shape = (A.shape[0], A.shape[1], B.shape[2]) + if out.shape != expected_shape: + raise ValueError( + f"Output shape mismatch. Expected {expected_shape}, got {out.shape}." + ) + if out.device != A.device: + raise ValueError( + f"Output device mismatch. Expected {A.device}, got {out.device}." + ) + if out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}." + ) + return True @@ -420,9 +485,11 @@ def _heuristic_func_bmm_bf16( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", ): heuristic_backends = [] + if "cudnn" in suitable_backends: + heuristic_backends.append("cudnn") if "cutlass" in suitable_backends: heuristic_backends.append("cutlass") return heuristic_backends @@ -431,6 +498,7 @@ def _heuristic_func_bmm_bf16( @backend_requirement( { "cutlass": _cutlass_bmm_bf16_requirement, + "cudnn": _cudnn_bmm_bf16_requirement, }, common_check=_check_bmm_bf16_problem_size, heuristic_func=_heuristic_func_bmm_bf16, @@ -441,7 +509,7 @@ def bmm_bf16( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cutlass"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", ) -> torch.Tensor: r"""BMM BF16 @@ -459,8 +527,8 @@ def bmm_bf16( out_dtype: torch.dtype Output dtype, bf16 (default) or fp16. - backend: Literal["cutlass"] - Backend to use, defaults to "cutlass". + backend: Literal["cudnn", "cutlass", "auto"] + Backend to use, defaults to "cutlass". ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. Returns ------- @@ -487,24 +555,17 @@ def bmm_bf16( device=A.device, dtype=out_dtype, ) - else: - if out.shape != expected_shape: - raise ValueError( - f"Output shape mismatch. Expected {expected_shape}, got {out.shape}." - ) - if out.device != A.device: - raise ValueError( - f"Output device mismatch. Expected {A.device}, got {out.device}." - ) - if out.dtype != out_dtype: - raise ValueError( - f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}." - ) workspace_buffer = _get_cache_buf( "bmm_bf16_workspace", DEFAULT_WORKSPACE_SIZE, A.device ) - bf16_gemm_sm100(A, B, None, False, out, workspace_buffer, ["cutlass"]) + + if backend == "auto": + backends = bmm_bf16.suitable_auto_backends + else: + backends = [backend] + + bf16_gemm_sm100(A, B, None, False, out, workspace_buffer, backends) return out @@ -824,6 +885,8 @@ def bf16_gemm_sm100( ) -> None: runners = [] use_sm_100f = is_sm100f_supported(a.device) + if "cudnn" in runner_names: + runners.append(_cudnn_gemm_bf16_runner()) if "cutlass" in runner_names: runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner()) if "tgv" in runner_names: @@ -2041,6 +2104,138 @@ def forward( return CudnnFp8GemmRunner() +def _get_bf16_3d_shape_stride(tensor: torch.Tensor): + """Expand 2d tensor to 3d tensor for cuDNN""" + shape = list(tensor.shape) + stride = list(tensor.stride()) + + if len(shape) == 2: + shape.insert(0, 1) + stride.insert(0, tensor.numel()) + + return (tuple(shape), tuple(stride)) + + +@functools.cache +def build_cudnn_gemm_bf16_graph(a_shape, a_stride, b_shape, b_stride, o_type, device): + _check_cudnn_availability() + + stream = torch.cuda.current_stream(device) + with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _): + a_cudnn_tensor = graph.tensor( + name="a", dim=a_shape, stride=a_stride, data_type=cudnn.data_type.BFLOAT16 + ) + b_cudnn_tensor = graph.tensor( + name="b", dim=b_shape, stride=b_stride, data_type=cudnn.data_type.BFLOAT16 + ) + c_cudnn_tensor = graph.matmul( + name="matmul", + A=a_cudnn_tensor, + B=b_cudnn_tensor, + compute_data_type=cudnn.data_type.FLOAT, + ) + c_cudnn_tensor.set_name("c").set_output(True).set_data_type(o_type) + + a_cudnn_tensor.set_uid(UIDs.A_UID.value) + b_cudnn_tensor.set_uid(UIDs.B_UID.value) + c_cudnn_tensor.set_uid(UIDs.O_UID.value) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + return graph + + +def execute_cudnn_gemm_bf16_graph(graph, a, b, c_final, workspace, tactic: int = -1): + variant_pack = { + UIDs.A_UID.value: a, + UIDs.B_UID.value: b, + UIDs.O_UID.value: c_final, + } + + stream = torch.cuda.current_stream(a.device) + cudnn_handle = _get_cudnn_handle(stream) + + if workspace.numel() < graph.get_workspace_size(): + workspace = torch.empty( + graph.get_workspace_size(), device=a.device, dtype=torch.uint8 + ) + + if tactic == -1: + graph.execute(variant_pack, workspace, handle=cudnn_handle) + else: + graph.execute_plan_at_index( + variant_pack, workspace, tactic, handle=cudnn_handle + ) + + +def _cudnn_gemm_bf16( + workspace: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + out: torch.Tensor, + tactic: int = -1, +): + _check_cudnn_availability() + + # This allows the same graph to work for both mm (2D) and bmm (3D) + a_shape, a_stride = _get_bf16_3d_shape_stride(a) + b_shape, b_stride = _get_bf16_3d_shape_stride(b) + + graph = build_cudnn_gemm_bf16_graph( + a_shape, + a_stride, + b_shape, + b_stride, + _torch_data_type_to_cudnn_data_type(out.dtype), + a.device, + ) + execute_cudnn_gemm_bf16_graph(graph, a, b, out, workspace, tactic=tactic) + return out + + +def _cudnn_gemm_bf16_runner(): + class CudnnBf16GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + a, b, _, _, out, _ = inputs + a_shape, a_stride = _get_bf16_3d_shape_stride(a) + b_shape, b_stride = _get_bf16_3d_shape_stride(b) + + graph = build_cudnn_gemm_bf16_graph( + a_shape, + a_stride, + b_shape, + b_stride, + _torch_data_type_to_cudnn_data_type(out.dtype), + a.device, + ) + return list(range(graph.get_execution_plan_count())) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + a, b, bias, pdl, out, workspace_buffer = inputs + if bias is not None: + raise ValueError("cudnn bf16 gemm does not support bias.") + if pdl: + raise ValueError("cudnn bf16 gemm does not support pdl.") + _cudnn_gemm_bf16(workspace_buffer, a, b, out, tactic=tactic) + return out + + return CudnnBf16GemmRunner() + + def _get_real_fp4_shape_from_packed_uint8(packed_fp4_tensor): # the FP4 data are packed into uint8, we need to expand the shape and stride information to get the real shape and stride to be used in the cuDNN graph. is_column_major = packed_fp4_tensor.stride(-2) == 1 diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py index 5cd0ac8337..24541f329f 100644 --- a/tests/gemm/test_bmm_bf16.py +++ b/tests/gemm/test_bmm_bf16.py @@ -11,7 +11,8 @@ @pytest.mark.parametrize("n", [80, 64]) @pytest.mark.parametrize("k", [64, 256]) @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) -def test_bmm_bf16(b, m, n, k, res_dtype): +@pytest.mark.parametrize("backend", ["cutlass", "cudnn"]) +def test_bmm_bf16(b, m, n, k, res_dtype, backend): compute_capability = get_compute_capability(torch.device(device="cuda")) compute_capability_number = compute_capability[0] * 10 + compute_capability[1] if not bmm_bf16.is_compute_capability_supported(compute_capability_number): @@ -19,6 +20,15 @@ def test_bmm_bf16(b, m, n, k, res_dtype): f"bmm_bf16 not supported on current compute capability." f"Detected sm{compute_capability_number}." ) + if not bmm_bf16.is_backend_supported(backend, compute_capability_number): + pytest.skip(f"{backend} backend not supported on current compute capability.") + # cuDNN on SM103 does not support bf16 input -> fp16 output + if ( + backend == "cudnn" + and compute_capability_number == 103 + and res_dtype == torch.float16 + ): + pytest.skip("cuDNN bf16 GEMM with fp16 output not supported on SM103.") torch.manual_seed(7) input = torch.randn([b, m, k], device="cuda", dtype=torch.bfloat16) mat2 = torch.randn([b, n, k], device="cuda", dtype=torch.bfloat16).transpose(-2, -1) @@ -26,7 +36,7 @@ def test_bmm_bf16(b, m, n, k, res_dtype): out = torch.empty([b, m, n], device="cuda", dtype=res_dtype) with autotune(): - bmm_bf16(input, mat2, out=out, out_dtype=res_dtype) + bmm_bf16(input, mat2, out=out, out_dtype=res_dtype, backend=backend) cos_sim = F.cosine_similarity(reference.reshape(-1), out.reshape(-1), dim=0) assert cos_sim > 0.99 diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index ef6151e26b..384d3a52e6 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize("pdl", [True, False]) -@pytest.mark.parametrize("backend", ["cutlass", "tgv"]) +@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv"]) def test_mm_bf16( m: int, n: int, @@ -29,7 +29,13 @@ def test_mm_bf16( f"mm_bf16 not supported on current compute capability." f"Detected sm{compute_capability_number}." ) + if not mm_bf16.is_backend_supported(backend, compute_capability_number): + pytest.skip(f"{backend} backend not supported on current compute capability.") + if backend == "cudnn" and (enable_bias or pdl): + pytest.skip( + "mm_bf16 with cuDNN backend does not support bias or pdl arguments." + ) if backend == "cutlass" and (enable_bias or pdl): pytest.skip( "mm_bf16 with CUTLASS backend does not support bias or pdl arguments." @@ -38,6 +44,13 @@ def test_mm_bf16( pytest.skip( "mm_bf16 with TGV backend does not support specifying non-bfloat16 result dtypes." ) + # cuDNN on SM103 does not support bf16 input -> fp16 output + if ( + backend == "cudnn" + and compute_capability_number == 103 + and res_dtype == torch.float16 + ): + pytest.skip("cuDNN bf16 GEMM with fp16 output not supported on SM103.") torch.manual_seed(42) input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)