235 changes: 135 additions & 100 deletions flashinfer/gemm.py
@@ -86,38 +86,32 @@ def gen_gemm_module() -> JitSpec:
def get_gemm_module():
module = gen_gemm_module().build_and_load()

# torch library for bmm_fp8

@register_custom_op("flashinfer::bmm_fp8", mutates_args=("workspace_buffer", "D"))
def bmm_fp8(
workspace_buffer: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
D: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
) -> None:
cublas_handle = torch.cuda.current_blas_handle()
module.bmm_fp8.default(
A,
B,
D,
A_scale,
B_scale,
workspace_buffer,
cublas_handle,
)
# auto-tuned cublas fp8 gemm runner
def cublas_fp8_gemm_runner():
class CublasFp8GemmRunner(TunableRunner):
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
# cuBLAS has its own heuristics for fp8 gemm, so we only need the default tactic
return [0]

def forward(
self,
inputs: List[torch.Tensor],
*,
tactic: int = -1,
do_preparation: bool = False,
) -> torch.Tensor:
cublas_handle = torch.cuda.current_blas_handle()
a, b, scale_a, scale_b, out, workspace_buffer = inputs
module.bmm_fp8.default(
a, b, out, scale_a, scale_b, workspace_buffer, cublas_handle
)
return out

@register_fake_op("flashinfer::bmm_fp8")
def _fake_bmm_fp8(
workspace_buffer: torch.Tensor,
A: torch.Tensor,
B: torch.Tensor,
D: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
) -> None:
pass
return CublasFp8GemmRunner()
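
The runner above takes over from the `flashinfer::bmm_fp8` custom op that this diff removes. A minimal sketch of driving it directly, assuming fp8 tensors already laid out in the shared input order used throughout this file (not part of the diff):

```python
# Hedged sketch: A, B, A_scale, B_scale, out, workspace_buffer are assumed
# tensors in the shared layout [A, B, A_scale, B_scale, out, workspace_buffer];
# tactic 0 simply defers to cuBLAS's own heuristics.
runner = get_gemm_module().cublas_fp8_gemm_runner()
out = runner.forward([A, B, A_scale, B_scale, out, workspace_buffer], tactic=0)
```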

# torch library for cutlass_segment_gemm

@@ -166,7 +160,7 @@ def _fake_cutlass_segment_gemm(

# Register the module
_gemm_module = SimpleNamespace(
bmm_fp8=bmm_fp8,
cublas_fp8_gemm_runner=cublas_fp8_gemm_runner,
cutlass_segment_gemm=cutlass_segment_gemm,
)

@@ -392,77 +386,89 @@ def get_trtllm_gemm_module():
def get_gemm_sm100_module_cutlass_fp8():
module = gen_gemm_sm100_module_cutlass_fp8().build_and_load()

class CutlassFp8GemmRunner(TunableRunner):
def __init__(self):
self._fp8_gemm_runner = module.fp8_gemm

def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return list(range(module.fp8_gemm_tactic_num()))
def cutlass_fp8_gemm_runner():
class CutlassFp8GemmRunner(TunableRunner):
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return list(range(module.fp8_gemm_tactic_num()))

def forward(
self,
inputs: List[torch.Tensor],
*,
tactic: int = -1,
do_preparation: bool = False,
) -> torch.Tensor:
a, b, scale_a, scale_b, out, workspace_buffer = inputs
module.fp8_gemm.default(
a,
b.transpose(-2, -1),
scale_a * scale_b,
out,
workspace_buffer,
tactic,
)
return out

def forward(
self,
inputs: List[torch.Tensor],
*,
tactic: int = -1,
do_preparation: bool = False,
):
a, b, alpha, out, workspace_buffer = inputs
module.fp8_gemm.default(a, b, alpha, out, workspace_buffer, tactic)
return out
return CutlassFp8GemmRunner()
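
The cutlass runner adapts the shared input layout on the fly: it folds the two per-tensor scales into a single alpha and hands the kernel a transposed view of B. All three runners compute the same scaled batched matmul, so a dequantized reference looks like this (hedged sketch, not part of the diff, assuming fp8 A of shape (b, m, k) and B of shape (b, k, n)):

```python
# Every runner computes D = (A @ B) * (A_scale * B_scale) on dequantized operands.
import torch

ref = torch.bmm(A.to(torch.float32), B.to(torch.float32)) * (A_scale * B_scale)
torch.testing.assert_close(out.to(torch.float32), ref, rtol=1e-1, atol=1e-1)
```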

@register_custom_op(
"flashinfer::cutlass_fp8_gemm",
mutates_args=(""),
# Register the module
return SimpleNamespace(
cutlass_fp8_gemm_runner=cutlass_fp8_gemm_runner,
)
def cutlass_fp8_gemm(
a: torch.Tensor,
b: torch.Tensor,
alpha: torch.Tensor,
out: torch.Tensor,
workspace_buffer: torch.Tensor,
):
tuner = AutoTuner.get()

a_tensor_index = 0
out_tensor_index = 3

tuning_config = TuningConfig(
dynamic_tensor_specs=(
DynamicTensorSpec(
a_tensor_index,
-2,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
),
def fp8_gemm_sm100(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out: torch.Tensor,
workspace_buffer: torch.Tensor,
runner_names: List[str],
) -> None:
runners = []
# No e5m2 for cutlass
is_e5m2 = a.dtype == torch.float8_e5m2 or b.dtype == torch.float8_e5m2
if "cutlass" in runner_names and not is_e5m2:
runners.append(get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm_runner())
if "cublas" in runner_names:
runners.append(get_gemm_module().cublas_fp8_gemm_runner())
if CUDNN_AVAILABLE and "cudnn" in runner_names:
runners.append(_cudnn_gemm_fp8_runner())

tuner = AutoTuner.get()
a_tensor_index = 0
out_tensor_index = 4
tuning_config = TuningConfig(
dynamic_tensor_specs=(
DynamicTensorSpec(
a_tensor_index,
-2,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2,
),
constraint_specs=(
ConstraintSpec(
out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2]
),
),
constraint_specs=(
ConstraintSpec(
out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2]
),
)

fp8_runner = CutlassFp8GemmRunner()

inputs = [a, b, alpha, out, workspace_buffer]
_, tactic = tuner.choose_one(
"cutlass_fp8_gemm",
[fp8_runner],
tuning_config,
inputs,
)

fp8_runner(inputs=inputs, tactic=tactic)
),
)

# Register the module
return SimpleNamespace(
cutlass_fp8_gemm=cutlass_fp8_gemm,
inputs = [a, b, scale_a, scale_b, out, workspace_buffer]
runner, tactic = tuner.choose_one(
"fp8_gemm",
runners,
tuning_config,
inputs,
)

runner(inputs=inputs, tactic=tactic)
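
A hedged usage sketch of the dispatcher above (tensor shapes and dtypes are illustrative assumptions): with autotuning enabled, the tuner profiles every runner/tactic pair over the power-of-two-bucketed M dimension and caches the winner; otherwise it falls back to a default choice.

```python
# Hedged sketch, not part of the diff: A is (b, m, k) fp8, B is (b, k, n) fp8,
# out is (b, m, n) bf16, workspace_buffer comes from _get_cache_buf as in
# bmm_fp8 below.
fp8_gemm_sm100(
    A, B, A_scale, B_scale, out, workspace_buffer,
    runner_names=["cutlass", "cublas", "cudnn"],
)
```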


@functools.cache
def get_gemm_sm100_module_cutlass_fp4():
@@ -1401,6 +1407,30 @@ def _cudnn_gemm_fp8(
return out


def _cudnn_gemm_fp8_runner():
class CudnnFp8GemmRunner(TunableRunner):
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
# cuDNN has its own heuristics for fp8 gemm, so we only need the default tactic
return [0]

def forward(
self,
inputs: List[torch.Tensor],
*,
tactic: int = -1,
do_preparation: bool = False,
) -> torch.Tensor:
a, b, scale_a, scale_b, out, workspace_buffer = inputs
_cudnn_gemm_fp8(workspace_buffer, a, b, scale_a, scale_b, out, out.dtype)
return out

return CudnnFp8GemmRunner()


def _get_real_fp4_shape_from_packed_uint8(packed_fp4_tensor):
# The FP4 data are packed into uint8, so we need to expand the shape and stride information to recover the real shape and stride used in the cuDNN graph.
is_column_major = packed_fp4_tensor.stride(-2) == 1
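
For context on the packing the comment describes, a hedged illustration (not part of the diff, assuming row-major packing along the last dimension): two FP4 values occupy one uint8, so the logical inner dimension is twice the packed one.

```python
import torch

# A packed uint8 tensor of shape (..., m, k // 2) represents a logical
# FP4 tensor of shape (..., m, k).
packed = torch.empty(16, 64, dtype=torch.uint8)           # 16 x 128 FP4 values
real_shape = (*packed.shape[:-1], packed.shape[-1] * 2)    # (16, 128)
```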
@@ -1647,7 +1677,7 @@ def bmm_fp8(
B_scale: torch.Tensor,
dtype: torch.dtype,
out: Optional[torch.Tensor] = None,
backend: Literal["cudnn", "cublas", "cutlass"] = "cublas",
backend: Literal["cudnn", "cublas", "cutlass", "auto"] = "cublas",
) -> torch.Tensor:
r"""BMM FP8

@@ -1671,8 +1701,9 @@ def bmm_fp8(
out: Optional[torch.Tensor]
Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``.

backend: Literal["cudnn", "cublas", "cutlass"]
backend: Literal["cudnn", "cublas", "cutlass", "auto"]
The backend to use for the operation. Defaults to ``"cublas"``.
``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled.

Returns
-------
@@ -1715,17 +1746,21 @@ def bmm_fp8(
workspace_buffer = _get_cache_buf(
"bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
)

if backend == "cudnn":
return _cudnn_gemm_fp8(workspace_buffer, A, B, A_scale, B_scale, out, dtype)
backends = ["cudnn"]
elif backend == "cublas":
get_gemm_module().bmm_fp8(workspace_buffer, A, B, out, A_scale, B_scale)
backends = ["cublas"]
elif backend == "cutlass":
if A.dtype == torch.float8_e5m2 or B.dtype == torch.float8_e5m2:
raise ValueError("e5m2 is not supported for cutlass backend")
backends = ["cutlass"]
elif backend == "auto":
backends = ["cutlass", "cublas", "cudnn"]
else:
raise ValueError(f"Unsupported backend: {backend}")

get_gemm_sm100_module_cutlass_fp8().cutlass_fp8_gemm(
A, B.transpose(-2, -1), A_scale * B_scale, out, workspace_buffer
)
fp8_gemm_sm100(A, B, A_scale, B_scale, out, workspace_buffer, backends)
return out
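
A hedged end-to-end sketch of the new ``"auto"`` backend (the autotune import path, scale values, and B layout are assumptions, not confirmed by this diff):

```python
import torch
import flashinfer
from flashinfer.autotuner import autotune  # assumed import path

b, m, n, k = 2, 64, 128, 256
A = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
# bmm_fp8 expects B column-major: build (b, n, k) and transpose the last two dims.
B = torch.randn(b, n, k, device="cuda").transpose(-2, -1).to(torch.float8_e4m3fn)
A_scale = torch.tensor(1.0, device="cuda")
B_scale = torch.tensor(1.0, device="cuda")

with autotune():  # profile the cutlass/cublas/cudnn runners once, then reuse the winner
    out = flashinfer.bmm_fp8(A, B, A_scale, B_scale, torch.bfloat16, backend="auto")
```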


2 changes: 1 addition & 1 deletion tests/test_bmm_fp8.py
@@ -21,7 +21,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass"])
@pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass", "auto"])
@pytest.mark.parametrize("auto_tuning", [True, False])
def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_tuning):
if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: