diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 7f799fa7f6..1726a34195 100755 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -86,38 +86,32 @@ def gen_gemm_module() -> JitSpec: def get_gemm_module(): module = gen_gemm_module().build_and_load() - # torch library for bmm_fp8 - - @register_custom_op("flashinfer::bmm_fp8", mutates_args=("workspace_buffer", "D")) - def bmm_fp8( - workspace_buffer: torch.Tensor, - A: torch.Tensor, - B: torch.Tensor, - D: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, - ) -> None: - cublas_handle = torch.cuda.current_blas_handle() - module.bmm_fp8.default( - A, - B, - D, - A_scale, - B_scale, - workspace_buffer, - cublas_handle, - ) + # auto-tuned cublas fp8 gemm runner + def cublas_fp8_gemm_runner(): + class CublasFp8GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + # cublas has heuristic for fp8 gemm, so we only need to use the default tactic + return [0] + + def forward( + self, + inputs: List[torch.Tensor], + *, + tactic: int = -1, + do_preparation: bool = False, + ) -> torch.Tensor: + cublas_handle = torch.cuda.current_blas_handle() + a, b, scale_a, scale_b, out, workspace_buffer = inputs + module.bmm_fp8.default( + a, b, out, scale_a, scale_b, workspace_buffer, cublas_handle + ) + return out - @register_fake_op("flashinfer::bmm_fp8") - def _fake_bmm_fp8( - workspace_buffer: torch.Tensor, - A: torch.Tensor, - B: torch.Tensor, - D: torch.Tensor, - A_scale: torch.Tensor, - B_scale: torch.Tensor, - ) -> None: - pass + return CublasFp8GemmRunner() # torch library for cutlass_segment_gemm @@ -166,7 +160,7 @@ def _fake_cutlass_segment_gemm( # Register the module _gemm_module = SimpleNamespace( - bmm_fp8=bmm_fp8, + cublas_fp8_gemm_runner=cublas_fp8_gemm_runner, cutlass_segment_gemm=cutlass_segment_gemm, ) @@ -392,77 +386,89 @@ def get_trtllm_gemm_module(): def get_gemm_sm100_module_cutlass_fp8(): module = gen_gemm_sm100_module_cutlass_fp8().build_and_load() - class CutlassFp8GemmRunner(TunableRunner): - def __init__(self): - self._fp8_gemm_runner = module.fp8_gemm - - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: - return list(range(module.fp8_gemm_tactic_num())) + def cutlass_fp8_gemm_runner(): + class CutlassFp8GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + return list(range(module.fp8_gemm_tactic_num())) + + def forward( + self, + inputs: List[torch.Tensor], + *, + tactic: int = -1, + do_preparation: bool = False, + ) -> torch.Tensor: + a, b, scale_a, scale_b, out, workspace_buffer = inputs + module.fp8_gemm.default( + a, + b.transpose(-2, -1), + scale_a * scale_b, + out, + workspace_buffer, + tactic, + ) + return out - def forward( - self, - inputs: List[torch.Tensor], - *, - tactic: int = -1, - do_preparation: bool = False, - ): - a, b, alpha, out, workspace_buffer = inputs - module.fp8_gemm.default(a, b, alpha, out, workspace_buffer, tactic) - return out + return CutlassFp8GemmRunner() - @register_custom_op( - "flashinfer::cutlass_fp8_gemm", - mutates_args=(""), + # Register the module + return SimpleNamespace( + cutlass_fp8_gemm_runner=cutlass_fp8_gemm_runner, ) - def cutlass_fp8_gemm( - a: torch.Tensor, - b: torch.Tensor, - alpha: torch.Tensor, - out: torch.Tensor, - workspace_buffer: torch.Tensor, - ): - tuner = AutoTuner.get() - a_tensor_index = 0 - out_tensor_index = 3 - tuning_config = TuningConfig( - dynamic_tensor_specs=( - DynamicTensorSpec( - a_tensor_index, - -2, - get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2, - ), +def fp8_gemm_sm100( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out: torch.Tensor, + workspace_buffer: torch.Tensor, + runner_names: List[str], +) -> None: + runners = [] + # No e5m2 for cutlass + is_e5m2 = a.dtype == torch.float8_e5m2 or b.dtype == torch.float8_e5m2 + if "cutlass" in runner_names and not is_e5m2: + runners.append(get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm_runner()) + if "cublas" in runner_names: + runners.append(get_gemm_module().cublas_fp8_gemm_runner()) + if CUDNN_AVAILABLE and "cudnn" in runner_names: + runners.append(_cudnn_gemm_fp8_runner()) + + tuner = AutoTuner.get() + a_tensor_index = 0 + out_tensor_index = 4 + tuning_config = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + a_tensor_index, + -2, + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, ), - constraint_specs=( - ConstraintSpec( - out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] - ), + ), + constraint_specs=( + ConstraintSpec( + out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] ), - ) - - fp8_runner = CutlassFp8GemmRunner() - - inputs = [a, b, alpha, out, workspace_buffer] - _, tactic = tuner.choose_one( - "cutlass_fp8_gemm", - [fp8_runner], - tuning_config, - inputs, - ) - - fp8_runner(inputs=inputs, tactic=tactic) + ), + ) - # Register the module - return SimpleNamespace( - cutlass_fp8_gemm=cutlass_fp8_gemm, + inputs = [a, b, scale_a, scale_b, out, workspace_buffer] + runner, tactic = tuner.choose_one( + "fp8_gemm", + runners, + tuning_config, + inputs, ) + runner(inputs=inputs, tactic=tactic) + @functools.cache def get_gemm_sm100_module_cutlass_fp4(): @@ -1401,6 +1407,30 @@ def _cudnn_gemm_fp8( return out +def _cudnn_gemm_fp8_runner(): + class CudnnFp8GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + # cudnn has heuristic for fp8 gemm, so we only need to use the default tactic + return [0] + + def forward( + self, + inputs: List[torch.Tensor], + *, + tactic: int = -1, + do_preparation: bool = False, + ) -> torch.Tensor: + a, b, scale_a, scale_b, out, workspace_buffer = inputs + _cudnn_gemm_fp8(workspace_buffer, a, b, scale_a, scale_b, out, out.dtype) + return out + + return CudnnFp8GemmRunner() + + def _get_real_fp4_shape_from_packed_uint8(packed_fp4_tensor): # the FP4 data are packed into uint8, we need to expand the shape and stride information to get the real shape and stride to be used in the cuDNN graph. is_column_major = packed_fp4_tensor.stride(-2) == 1 @@ -1647,7 +1677,7 @@ def bmm_fp8( B_scale: torch.Tensor, dtype: torch.dtype, out: Optional[torch.Tensor] = None, - backend: Literal["cudnn", "cublas", "cutlass"] = "cublas", + backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas", ) -> torch.Tensor: r"""BMM FP8 @@ -1671,8 +1701,9 @@ def bmm_fp8( out: Optional[torch.Tensor] Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``. - backend: Literal["cudnn", "cublas", "cutlass"] + backend: Literal["cudnn", "cublas", "cutlass", "auto"] The backend to use for the operation. Defaults to ``"cublas"``. + ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. Returns ------- @@ -1715,17 +1746,21 @@ def bmm_fp8( workspace_buffer = _get_cache_buf( "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device ) + if backend == "cudnn": - return _cudnn_gemm_fp8(workspace_buffer, A, B, A_scale, B_scale, out, dtype) + backends = ["cudnn"] elif backend == "cublas": - get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale) + backends = ["cublas"] elif backend == "cutlass": if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2: raise ValueError("e5m2 is not supported for cutlass backend") + backends = ["cutlass"] + elif backend == "auto": + backends = ["cutlass", "cublas", "cudnn"] + else: + raise ValueError(f"Unsupported backend: {backend}") - get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm( - A, B.transpose(-2, -1), A_scale * B_scale, out, workspace_buffer - ) + fp8_gemm_sm100(A, B, A_scale, B_scale, out, workspace_buffer, backends) return out diff --git a/tests/test_bmm_fp8.py b/tests/test_bmm_fp8.py index 4da991e704..35d45150bb 100644 --- a/tests/test_bmm_fp8.py +++ b/tests/test_bmm_fp8.py @@ -21,7 +21,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn): @pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass"]) +@pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass", "auto"]) @pytest.mark.parametrize("auto_tuning", [True, False]) def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_tuning): if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: