96 changes: 95 additions & 1 deletion python/tvm/relax/backend/cuda/flashinfer.py
@@ -116,7 +116,7 @@ def get_object_file_path(src: Path) -> Path:

     # Determine compute version
     compute_version = "".join(tvm.contrib.nvcc.get_target_compute_version(target).split("."))
-    if compute_version in ["90"]:
+    if compute_version in ["90", "100"]:
         compute_version += "a"
     cuda_cflags += [
         "-gencode",
@@ -488,3 +488,97 @@ def gen_sampling_module(target: Target, num_threads: int = 8):
     object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
     modules = _load_flashinfer_modules(object_files)
     return modules
+
+
+def gen_grouped_gemm_module(
+    dtype_a: str,
+    dtype_b: str,
+    dtype_out: str,
+    scale_granularity_m: int,
+    scale_granularity_n: int,
+    scale_granularity_k: int,
+    scale_major_mode: str,
+    mma_sm: int,
+    target: Target,
+    num_threads: int = 8,
+) -> List[tvm.runtime.Module]:
"""Generate a FlashInfer module for FP8 grouped GEMM.

Parameters
----------
dtype_a : str
The data type of matrix A (e.g., "float8_e4m3fn").
dtype_b : str
The data type of matrix B (e.g., "float8_e4m3fn").
dtype_out : str
The data type of the output matrix (e.g., "bfloat16").
scale_granularity_m : int
The scaling granularity in the M dimension.
scale_granularity_n : int
The scaling granularity in the N dimension.
scale_granularity_k : int
The scaling granularity in the K dimension.
scale_major_mode : str
The scale storage mode ("K" or "MN").
mma_sm : int
The MMA scheduling mode (1 or 2).
target : Target
The target device to compile for.
num_threads : int
The number of threads to use for compilation.

Returns
-------
List[tvm.runtime.Module]
A list of compiled static library modules for FlashInfer FP8 grouped GEMM kernels.

+    Notes
+    -----
+    When applying grouped GEMM with A: (total_m, k), B: (batch_size, n, k), and
+    m_indptr: (batch_size,), each group's m in m_indptr must be a multiple of 4.
+    """
+    try:
+        from flashinfer.jit import (  # pylint: disable=import-outside-toplevel
+            gen_grouped_gemm_fp8_tvm_binding,
+            get_grouped_gemm_fp8_uri,
+        )
+    except ImportError as err:
+        raise ImportError(
+            "FlashInfer is not installed. Please follow the instructions "
+            "at https://docs.flashinfer.ai to install FlashInfer."
+        ) from err
+    try:
+        import torch  # pylint: disable=import-outside-toplevel
+    except ImportError as err:
+        raise ImportError("PyTorch is not installed. Please install PyTorch to use FlashInfer.") from err

+    torch_dtype_a = getattr(torch, dtype_a)
+    torch_dtype_b = getattr(torch, dtype_b)
+    torch_dtype_out = getattr(torch, dtype_out)
+
+    uri = get_grouped_gemm_fp8_uri(
+        dtype_a=torch_dtype_a,
+        dtype_b=torch_dtype_b,
+        dtype_out=torch_dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+    )
+
+    uri, source_paths = gen_grouped_gemm_fp8_tvm_binding(
+        uri=uri,
+        dtype_a=torch_dtype_a,
+        dtype_b=torch_dtype_b,
+        dtype_out=torch_dtype_out,
+        scale_granularity_m=scale_granularity_m,
+        scale_granularity_n=scale_granularity_n,
+        scale_granularity_k=scale_granularity_k,
+        scale_major_mode=scale_major_mode,
+        mma_sm=mma_sm,
+    )
+
+    object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
+    modules = _load_flashinfer_modules(object_files)
+    return modules
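
A minimal usage sketch (not part of this PR): the function derives a cache URI from the kernel configuration, generates the TVM binding sources via FlashInfer's JIT, then compiles and loads them. The values below are illustrative assumptions, not something this change prescribes; the 1x128x128 scale granularity mirrors common FP8 blockwise-scaling setups. It assumes FlashInfer and PyTorch are installed and that the CUDA target resolves to SM90 or SM100 hardware.

    import tvm
    from tvm.relax.backend.cuda.flashinfer import gen_grouped_gemm_module

    target = tvm.target.Target("cuda")  # assumed to resolve to an SM90/SM100 GPU
    modules = gen_grouped_gemm_module(
        dtype_a="float8_e4m3fn",
        dtype_b="float8_e4m3fn",
        dtype_out="bfloat16",
        scale_granularity_m=1,
        scale_granularity_n=128,
        scale_granularity_k=128,
        scale_major_mode="K",
        mma_sm=1,
        target=target,
    )
    # Each returned tvm.runtime.Module wraps a compiled FlashInfer kernel
    # library. Per the docstring Notes, each group's m encoded in m_indptr
    # must be a multiple of 4 at call time.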