     is_sm120a_supported,
     is_sm121a_supported,
     LibraryError,
+    backend_requirement,
+    supported_compute_capability,
 )
 from .jit.gemm import gen_gemm_sm90_module
 from .jit.gemm import gen_gemm_module
@@ -81,6 +83,9 @@
 
 DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024
 
+# Error messages
+CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
+
 
 def _match_sm_version(device: torch.device, sm_version: list[str]):
     major, minor = get_compute_capability(device)
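
Annotation: `_match_sm_version`, shown in context above, gates several checks in this diff. A minimal standalone sketch of its apparent semantics — assuming `get_compute_capability(device)` returns a `(major, minor)` pair such as `(12, 0)` on SM120 (an assumption; the helper's body is not shown here):

```python
# Hedged sketch of the SM-version match used by the checks below; assumes
# get_compute_capability(device) returns (major, minor), e.g. (12, 0) on SM120.
def match_sm_version(major: int, minor: int, sm_versions: list[str]) -> bool:
    return f"{major}{minor}" in sm_versions

assert match_sm_version(12, 0, ["120"])       # SM120 matches "120"
assert not match_sm_version(10, 0, ["120"])   # SM100 does not
```
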
@@ -1182,7 +1187,7 @@ def _validate_fp8_output_dtype(dtype: torch.dtype):
 
 
 @functools.cache
-def build_cudnn_gemm_block_scale_dequantize_graph(
+def create_cudnn_execution_plans_fp4_gemm(
     a_shape,
     a_stride,
     b_shape,
@@ -1279,12 +1284,49 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
     # in older cuDNN versions, so we deselect it.
     if (alpha_is_not_none) and (not _is_cublas_fp4_available_in_cudnn()):
         graph.deselect_engines(["eng0"])
-    graph.check_support()
-    graph.build_plans()
 
     return graph
 
 
+@functools.cache
+def build_plans_cudnn_fp4_gemm_graph(
+    a_shape,
+    a_stride,
+    b_shape,
+    b_stride,
+    a_descale_shape,
+    a_descale_stride,
+    b_descale_shape,
+    b_descale_stride,
+    ab_type,
+    o_type,
+    block_size,
+    device,
+    alpha,
+    use_nvfp4,
+):
+    graph = create_cudnn_execution_plans_fp4_gemm(
+        a_shape,
+        a_stride,
+        b_shape,
+        b_stride,
+        a_descale_shape,
+        a_descale_stride,
+        b_descale_shape,
+        b_descale_stride,
+        ab_type,
+        o_type,
+        block_size,
+        device,
+        alpha,
+        use_nvfp4,
+    )
+
+    graph.check_support()
+    graph.build_plans()
+    return graph
+
+
 def execute_cudnn_gemm_fp4_graph(
     graph,
     a,
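
Annotation: this hunk splits the old `build_cudnn_gemm_block_scale_dequantize_graph` into two cached layers. `create_cudnn_execution_plans_fp4_gemm` only constructs the graph, so the cuDNN requirement check added below can call `graph.check_support()` cheaply; `build_plans_cudnn_fp4_gemm_graph` layers the expensive `build_plans()` on top for the execution path. A minimal sketch of the pattern with a hypothetical `FakeGraph` stand-in (not the real cudnn-frontend API):

```python
import functools

# Hypothetical stand-in for the cuDNN graph and its two-phase setup.
class FakeGraph:
    def __init__(self, key):
        self.key = key
        self.plans_built = False

    def check_support(self):
        # cheap validation only
        pass

    def build_plans(self):
        # expensive plan compilation
        self.plans_built = True


@functools.cache
def create_graph(key):
    # phase 1: cached construction, shared by the requirement check and execution
    return FakeGraph(key)


@functools.cache
def build_plans_graph(key):
    # phase 2: cached plan building, only paid on the execution path
    graph = create_graph(key)
    graph.check_support()
    graph.build_plans()
    return graph


g = create_graph("m16_n16_k32")
g.check_support()                       # requirement check: no plans built yet
assert not g.plans_built
assert build_plans_graph("m16_n16_k32") is g
assert g.plans_built                    # plans built once, same graph reused
```
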
@@ -1647,6 +1689,172 @@ def mm_fp8(
     return out
 
 
+def _check_mm_fp4_problem_size(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    # Generic checks: pre-validate the input tensors, block scale tensors, and alpha
+    if a.ndim != 2 or b.ndim != 2:
+        raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
+    if a.shape[1] != b.shape[0]:
+        raise ValueError(
+            f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}"
+        )
+    if a.dtype not in {torch.uint8, get_native_fp4_dtype()} or b.dtype not in {
+        torch.uint8,
+        get_native_fp4_dtype(),
+    }:
+        raise ValueError(
+            f"a and b must have float4_e2m1fn_x2 packed into uint8. "
+            f"Got {a.dtype} and {b.dtype}."
+        )
+    if a_descale.dtype not in {
+        torch.float8_e4m3fn,
+        torch.uint8,
+    } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}:
+        raise ValueError(
+            f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. "
+            f"Got {a_descale.dtype} and {b_descale.dtype}."
+        )
+    if alpha is not None and alpha.dtype != torch.float:
+        raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}")
+    if alpha is not None and alpha.numel() != 1:
+        raise ValueError(f"alpha must be a scalar, got {alpha.numel()}")
+
+    if out_dtype not in (torch.bfloat16, torch.float16):
+        raise ValueError(
+            f"Unsupported output dtype: {out_dtype}. "
+            f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations."
+        )
+
+    if backend != "trtllm" and use_8x4_sf_layout:
+        raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
+    if backend != "cudnn" and not use_nvfp4:
+        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
+
+    if use_nvfp4 and block_size != 16:
+        raise ValueError("nvfp4 only supports block_size = 16.")
+    if not use_nvfp4 and block_size != 32:
+        raise ValueError("mxfp4 only supports block_size = 32.")
+
+    return True
+
+
+@supported_compute_capability([100, 103, 110, 120])
+def _cudnn_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    if (
+        not use_nvfp4
+        and _match_sm_version(a.device, ["120"])
+        and cudnn.backend_version() < 91400
+    ):
+        raise LibraryError(CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR)
+
+    _check_cudnn_fp4_availability()
+
+    # the fp4 cudnn graph will be shared for both mm and bmm, so
+    # here we need to get the 3d shape and stride including the
+    # batch dimension for both input and block scale tensors.
+    real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a)
+    real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b)
+    batch = real_a_shape[0]
+    expanded_a_descale_shape, expanded_a_descale_stride = (
+        _expand_block_scale_tensor_shape(a_descale, batch)
+    )
+    expanded_b_descale_shape, expanded_b_descale_stride = (
+        _expand_block_scale_tensor_shape(b_descale, batch)
+    )
+
+    # build the fp4 cudnn graph and verify support; plans are built
+    # only later, on the execution path
+    graph = create_cudnn_execution_plans_fp4_gemm(
+        real_a_shape,
+        real_a_stride,
+        real_b_shape,
+        real_b_stride,
+        expanded_a_descale_shape,
+        expanded_a_descale_stride,
+        expanded_b_descale_shape,
+        expanded_b_descale_stride,
+        cudnn.data_type.FP4_E2M1,
+        _torch_data_type_to_cudnn_data_type(out_dtype),
+        block_size,
+        a.device,
+        alpha,
+        use_nvfp4,
+    )
+    graph.check_support()
+
+    return True
+
+
+@supported_compute_capability([100, 103, 120])
+def _trtllm_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    if out_dtype != torch.bfloat16:
+        raise ValueError(
+            f"Unsupported output dtype: {out_dtype}. "
+            f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
+        )
+    return True
+
+
+@supported_compute_capability([100, 103, 120])
+def _cutlass_gemm_fp4_requirement(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_descale: torch.Tensor,
+    b_descale: torch.Tensor,
+    alpha: Optional[torch.Tensor] = None,
+    out_dtype: torch.dtype = torch.bfloat16,
+    out: Optional[torch.Tensor] = None,
+    block_size: int = 16,
+    use_8x4_sf_layout: bool = False,
+    backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn",
+    use_nvfp4: bool = True,
+):
+    return True
+
+
+@backend_requirement(
+    {
+        "cudnn": _cudnn_gemm_fp4_requirement,  # Each backend has its own requirement function
+        "trtllm": _trtllm_gemm_fp4_requirement,
+        "cutlass": _cutlass_gemm_fp4_requirement,
+    },
+    common_check=_check_mm_fp4_problem_size,  # Shape checks common to all backends
+)
 def mm_fp4(
     a: torch.Tensor,
     b: torch.Tensor,
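
Annotation: with this change `mm_fp4` no longer validates inputs inline; `backend_requirement` (imported at the top of this diff) wires the shared `_check_mm_fp4_problem_size` and the per-backend requirement functions in front of dispatch. A rough sketch of how such a decorator could chain the checks — hypothetical, the real implementation lives in the library's utils and may differ:

```python
from typing import Callable, Optional

# Hedged sketch of a backend_requirement-style decorator; names mirror the
# call site above, not the library's actual code.
def backend_requirement_sketch(
    backend_checks: dict[str, Callable[..., bool]],
    common_check: Optional[Callable[..., bool]] = None,
):
    def decorator(fn):
        def wrapper(*args, backend: str = "cudnn", **kwargs):
            if common_check is not None:
                # shape/dtype checks shared by every backend; raises on failure
                common_check(*args, backend=backend, **kwargs)
            # per-backend check (compute capability, library versions, dtypes)
            backend_checks[backend](*args, backend=backend, **kwargs)
            return fn(*args, backend=backend, **kwargs)

        return wrapper

    return decorator
```
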
@@ -1721,59 +1929,6 @@ def mm_fp4(
     >>> out.shape
     torch.Size([48, 256])
     """
-    # pre-check the input tensor, block scale tensor and alpha tensor
-    if a.ndim != 2 or b.ndim != 2:
-        raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}")
-    if a.shape[1] != b.shape[0]:
-        raise ValueError(
-            f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}"
-        )
-    if a.dtype not in {torch.uint8, get_native_fp4_dtype()} or b.dtype not in {
-        torch.uint8,
-        get_native_fp4_dtype(),
-    }:
-        raise ValueError(
-            f"a and b must have float4_e2m1fn_x2 packed into uint8. "
-            f"Got {a.dtype} and {b.dtype}."
-        )
-    if a_descale.dtype not in {
-        torch.float8_e4m3fn,
-        torch.uint8,
-    } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}:
-        raise ValueError(
-            f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. "
-            f"Got {a_descale.dtype} and {b_descale.dtype}."
-        )
-    if alpha is not None and alpha.dtype != torch.float:
-        raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}")
-    if alpha is not None and alpha.numel() != 1:
-        raise ValueError(f"alpha must be a scalar, got {alpha.numel()}")
-
-    if out_dtype not in (torch.bfloat16, torch.float16):
-        raise ValueError(
-            f"Unsupported output dtype: {out_dtype}. "
-            f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations."
-        )
-
-    if use_nvfp4 and block_size != 16:
-        raise ValueError("nvfp4 only supports block_size = 16.")
-    if not use_nvfp4 and block_size != 32:
-        raise ValueError("mxfp4 supports block_size = 32.")
-    if backend != "trtllm" and use_8x4_sf_layout:
-        raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.")
-    if backend == "trtllm" and _match_sm_version(a.device, ["110"]):
-        raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.")
-    if backend != "cudnn" and not use_nvfp4:
-        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
-    if (
-        backend == "cudnn"
-        and not use_nvfp4
-        and _match_sm_version(a.device, ["120"])
-        and cudnn.backend_version() < 91400
-    ):
-        raise LibraryError(
-            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
-        )
 
     # allocate the output tensor if not provided
     if out is None:
@@ -1788,8 +1943,6 @@ def mm_fp4(
         )
 
     if backend == "cudnn":
-        _check_cudnn_fp4_availability()
-
         # the fp4 cudnn graph will be shared for both mm and bmm, so
         # here we need to get the 3d shape and stride including the
         # batch dimension for both input and block scale tensors.
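
Annotation: the comment above (and its twin in `_cudnn_gemm_fp4_requirement`) refers to the packed-FP4 convention: each `uint8` holds two `float4_e2m1` values, and the helpers prepend a batch dimension so `mm` and `bmm` share one cuDNN graph. A sketch of the shape math under that assumption — the real `_get_real_fp4_shape_from_packed_uint8` also computes strides and is not shown in this diff:

```python
import torch

# Hedged sketch of the packed-uint8 FP4 shape convention: a (M, K/2) uint8
# tensor stores a logical (M, K) FP4 operand (two e2m1 values per byte), and
# a leading batch dim of 1 is added so mm can reuse the bmm graph.
def real_fp4_shape(packed: torch.Tensor) -> tuple[int, int, int]:
    m, k_packed = packed.shape
    return (1, m, 2 * k_packed)

a_packed = torch.zeros(48, 64, dtype=torch.uint8)   # logical (48, 128) FP4
assert real_fp4_shape(a_packed) == (1, 48, 128)
```
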
@@ -1804,7 +1957,7 @@ def mm_fp4(
         )
 
         # build the fp4 cudnn graph
-        graph = build_cudnn_gemm_block_scale_dequantize_graph(
+        graph = build_plans_cudnn_fp4_gemm_graph(
             real_a_shape,
             real_a_stride,
             real_b_shape,
@@ -1826,12 +1979,6 @@ def mm_fp4(
             graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer
         )
     elif backend == "trtllm":
-        if out_dtype != torch.bfloat16:
-            raise ValueError(
-                f"Unsupported output dtype: {out_dtype}. "
-                f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations."
-            )
-
         get_trtllm_fp4_gemm_module().trtllm_fp4_gemm(
             a,
             b.T,