
Commit d728bcd

Support checks PoC (#1809)
<!-- .github/pull_request_template.md -->

## 📌 Description

This PR adds `is_backend_supported` and `is_compute_capability_supported` checks through decorators.

1. This allows us to check support before running.
2. It also wraps the original function so that the support check runs again before the function executes.
3. The wrapped function also adds an optional parameter `skip_check`.

A quick measurement shows only minimal impact (14.51 s without checks vs. 14.58 s with checks for all of `test_mm_fp4`), so we should benchmark the usefulness of this feature further.

Example:

<img width="680" height="641" alt="Screenshot 2025-10-12 at 9 39 06 PM" src="https://github.com/user-attachments/assets/79eb2eb3-f7b3-49a8-a7a2-2694e1ea3937" />

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent bea5949 commit d728bcd
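The description above explains the idea (decorator-attached support checks, re-run before dispatch, with an opt-out `skip_check` flag) without showing the mechanism. Below is a minimal, self-contained sketch of how a `backend_requirement`-style decorator *could* be built; it is an illustration under assumed semantics, not FlashInfer's actual implementation, and the internal helper `_run_checks` and the exact signatures are invented for the example.

```python
import functools
from typing import Callable, Dict, Optional


def backend_requirement(
    backend_checks: Dict[str, Callable[..., bool]],
    common_check: Optional[Callable[..., bool]] = None,
):
    """Illustrative sketch: attach support checks to a kernel entry point.

    Each check either returns True or raises with a descriptive error,
    mirroring the requirement functions added to gemm.py in this PR.
    """

    def decorator(func):
        def _run_checks(*args, **kwargs):
            # Default to the first registered backend if none is given.
            backend = kwargs.get("backend", next(iter(backend_checks)))
            if backend not in backend_checks:
                raise ValueError(f"Unsupported backend: {backend}")
            if common_check is not None:
                common_check(*args, **kwargs)        # checks shared by all backends
            backend_checks[backend](*args, **kwargs)  # backend-specific requirements
            return True

        @functools.wraps(func)
        def wrapper(*args, skip_check: bool = False, **kwargs):
            if not skip_check:
                _run_checks(*args, **kwargs)          # validate before running the kernel
            return func(*args, **kwargs)

        def is_backend_supported(*args, **kwargs) -> bool:
            # Probe support without executing the wrapped function.
            try:
                return _run_checks(*args, **kwargs)
            except Exception:
                return False

        wrapper.is_backend_supported = is_backend_supported
        return wrapper

    return decorator
```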

File tree

4 files changed: +695 −75 lines changed


flashinfer/gemm.py

Lines changed: 212 additions & 65 deletions
@@ -41,6 +41,8 @@
     is_sm120a_supported,
     is_sm121a_supported,
     LibraryError,
+    backend_requirement,
+    supported_compute_capability,
 )
 from .jit.gemm import gen_gemm_sm90_module
 from .jit.gemm import gen_gemm_module
@@ -81,6 +83,9 @@
 
 DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
 
+# Error messages
+CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
+
 
 def _match_sm_version(device: torch.device, sm_version: list[str]):
     major, minor = get_compute_capability(device)
@@ -1182,7 +1187,7 @@ def _validate_fp8_output_dtype(dtype: torch.dtype):
 
 
 @functools.cache
-def build_cudnn_gemm_block_scale_dequantize_graph(
+def create_cudnn_execution_plans_fp4_gemm(
     a_shape,
     a_stride,
     b_shape,
@@ -1279,12 +1284,49 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
     # in older cuDNN versions, so we deselect it.
     if (alpha_is_not_none) and (not _is_cublas_fp4_available_in_cudnn()):
         graph.deselect_engines(["eng0"])
-    graph.check_support()
-    graph.build_plans()
 
     return graph
 
 
+@functools.cache
+def build_plans_cudnn_fp4_gemm_graph(
+    a_shape,
+    a_stride,
+    b_shape,
+    b_stride,
+    a_descale_shape,
+    a_descale_stride,
+    b_descale_shape,
+    b_descale_stride,
+    ab_type,
+    o_type,
+    block_size,
+    device,
+    alpha,
+    use_nvfp4,
+):
+    graph = create_cudnn_execution_plans_fp4_gemm(
+        a_shape,
+        a_stride,
+        b_shape,
+        b_stride,
+        a_descale_shape,
+        a_descale_stride,
+        b_descale_shape,
+        b_descale_stride,
+        ab_type,
+        o_type,
+        block_size,
+        device,
+        alpha,
+        use_nvfp4,
+    )
+
+    graph.check_support()
+    graph.build_plans()
+    return graph
+
+
 def execute_cudnn_gemm_fp4_graph(
     graph,
     a,
@@ -1647,6 +1689,172 @@ def mm_fp8(
     return out
 
 
+def _check_mm_fp4_problem_size(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    # Generic checks
+    ## pre-check the input tensor, block scale tensor and alpha tensor
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}"
+        )
+    if a.dtype not in {torch.uint8, get_native_fp4_dtype()} or b.dtype not in {
+        torch.uint8,
+        get_native_fp4_dtype(),
+    }:
+        raise ValueError(
+            f"a and b must have float4_e2m1fn_x2 packed into uint8. "
+            f"Got {a.dtype} and {b.dtype}."
+        )
+    if a_descale.dtype not in {
+        torch.float8_e4m3fn,
+        torch.uint8,
+    } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}:
+        raise ValueError(
+            f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. "
+            f"Got {a_descale.dtype} and {b_descale.dtype}."
+        )
+    if alpha is not None and alpha.dtype != torch.float:
+        raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}")
+    if alpha is not None and alpha.numel() != 1:
+        raise ValueError(f"alpha must be a scalar, got {alpha.numel()}")
+
+    if out_dtype not in (torch.bfloat16, torch.float16):
+        raise ValueError(
+            f"Unsupported output dtype: {out_dtype}. "
+            f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations."
+        )
+
+    if backend != "trtllm" and use_8x4_sf_layout:
+        raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
+    if backend != "cudnn" and not use_nvfp4:
+        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
+
+    if use_nvfp4 and block_size != 16:
+        raise ValueError("nvfp4 only supports block_size = 16.")
+    if not use_nvfp4 and block_size != 32:
+        raise ValueError("mxfp4 only supports block_size = 32.")
+
+    return True
+
+
+@supported_compute_capability([100, 103, 110, 120])
+def _cudnn_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    if (
+        not use_nvfp4
+        and _match_sm_version(a.device, ["120"])
+        and cudnn.backend_version() < 91400
+    ):
+        raise LibraryError(CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR)
+
+    _check_cudnn_fp4_availability()
+
+    # the fp4 cudnn graph will be shared for both mm and bmm, so
+    # here we need to get the 3d shape and stride including the
+    # batch dimension for both input and block scale tensors.
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+    batch = real_a_shape[0]
+    expanded_a_descale_shape, expanded_a_descale_stride = (
+        _expand_block_scale_tensor_shape(a_descale, batch)
+    )
+    expanded_b_descale_shape, expanded_b_descale_stride = (
+        _expand_block_scale_tensor_shape(b_descale, batch)
+    )
+
+    # build the fp4 cudnn graph
+    graph = create_cudnn_execution_plans_fp4_gemm(
+        real_a_shape,
+        real_a_stride,
+        real_b_shape,
+        real_b_stride,
+        expanded_a_descale_shape,
+        expanded_a_descale_stride,
+        expanded_b_descale_shape,
+        expanded_b_descale_stride,
+        cudnn.data_type.FP4_E2M1,
+        _torch_data_type_to_cudnn_data_type(out_dtype),
+        block_size,
+        a.device,
+        alpha,
+        use_nvfp4,
+    )
+    graph.check_support()
+
+    return True
+
+
+@supported_compute_capability([100, 103, 120])
+def _trtllm_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    if out_dtype != torch.bfloat16:
+        raise ValueError(
+            f"Unsupported output dtype: {out_dtype}. "
+            f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
+        )
+    return True
+
+
+@supported_compute_capability([100, 103, 120])
+def _cutlass_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    return True
+
+
+@backend_requirement(
+    {
+        "cudnn": _cudnn_gemm_fp4_requirement,  # Each backend has its own requirement function
+        "trtllm": _trtllm_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+    },
+    common_check=_check_mm_fp4_problem_size,  # Shape checks common to all backends
+)
 def mm_fp4(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -1721,59 +1929,6 @@ def mm_fp4(
     >>> out.shape
     torch.Size([48, 256])
     """
-    # pre-check the input tensor, block scale tensor and alpha tensor
-    if a.ndim != 2 or b.ndim != 2:
-        raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
-    if a.shape[1] != b.shape[0]:
-        raise ValueError(
-            f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}"
-        )
-    if a.dtype not in {torch.uint8, get_native_fp4_dtype()} or b.dtype not in {
-        torch.uint8,
-        get_native_fp4_dtype(),
-    }:
-        raise ValueError(
-            f"a and b must have float4_e2m1fn_x2 packed into uint8. "
-            f"Got {a.dtype} and {b.dtype}."
-        )
-    if a_descale.dtype not in {
-        torch.float8_e4m3fn,
-        torch.uint8,
-    } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}:
-        raise ValueError(
-            f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. "
-            f"Got {a_descale.dtype} and {b_descale.dtype}."
-        )
-    if alpha is not None and alpha.dtype != torch.float:
-        raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}")
-    if alpha is not None and alpha.numel() != 1:
-        raise ValueError(f"alpha must be a scalar, got {alpha.numel()}")
-
-    if out_dtype not in (torch.bfloat16, torch.float16):
-        raise ValueError(
-            f"Unsupported output dtype: {out_dtype}. "
-            f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations."
-        )
-
-    if use_nvfp4 and block_size != 16:
-        raise ValueError("nvfp4 only supports block_size = 16.")
-    if not use_nvfp4 and block_size != 32:
-        raise ValueError("mxfp4 supports block_size = 32.")
-    if backend != "trtllm" and use_8x4_sf_layout:
-        raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
-    if backend == "trtllm" and _match_sm_version(a.device, ["110"]):
-        raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.")
-    if backend != "cudnn" and not use_nvfp4:
-        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
-    if (
-        backend == "cudnn"
-        and not use_nvfp4
-        and _match_sm_version(a.device, ["120"])
-        and cudnn.backend_version() < 91400
-    ):
-        raise LibraryError(
-            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
-        )
 
     # allocate the output tensor if not provided
     if out is None:
@@ -1788,8 +1943,6 @@ def mm_fp4(
         )
 
     if backend == "cudnn":
-        _check_cudnn_fp4_availability()
-
         # the fp4 cudnn graph will be shared for both mm and bmm, so
         # here we need to get the 3d shape and stride including the
         # batch dimension for both input and block scale tensors.
@@ -1804,7 +1957,7 @@ def mm_fp4(
         )
 
         # build the fp4 cudnn graph
-        graph = build_cudnn_gemm_block_scale_dequantize_graph(
+        graph = build_plans_cudnn_fp4_gemm_graph(
             real_a_shape,
             real_a_stride,
             real_b_shape,
@@ -1826,12 +1979,6 @@ def mm_fp4(
             graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
         )
     elif backend == "trtllm":
-        if out_dtype != torch.bfloat16:
-            raise ValueError(
-                f"Unsupported output dtype: {out_dtype}. "
-                f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
-            )
-
         get_trtllm_fp4_gemm_module().trtllm_fp4_gemm(
             a,
             b.T,
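The diff registers `mm_fp4` with a common shape check plus one requirement function per backend. Reusing the hedged `backend_requirement` sketch shown earlier, the snippet below mimics that registration pattern with toy checks and a toy kernel so it runs anywhere; the backend names, check logic, and the exact `skip_check` / `is_backend_supported` behavior are assumptions drawn from the PR description, not the verified FlashInfer API.

```python
# Assumes the backend_requirement sketch from above is in scope.

def _toy_common_check(x, backend="fast"):
    if x < 0:
        raise ValueError("x must be non-negative")  # shape-style check shared by all backends
    return True

def _toy_fast_requirement(x, backend="fast"):
    return True  # always supported

def _toy_strict_requirement(x, backend="strict"):
    if x > 100:
        raise ValueError("strict backend only handles x <= 100")
    return True

@backend_requirement(
    {"fast": _toy_fast_requirement, "strict": _toy_strict_requirement},
    common_check=_toy_common_check,
)
def toy_kernel(x, backend="fast"):
    return x * 2

print(toy_kernel(5, backend="strict"))                          # 10: checks pass, then the kernel runs
print(toy_kernel.is_backend_supported(500, backend="strict"))   # False: reported without running the kernel
print(toy_kernel(500, backend="strict", skip_check=True))       # 1000: caller explicitly bypasses the checks
```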
