diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index ffb2e4825f..2ef535f97c 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -41,6 +41,8 @@ is_sm120a_supported, is_sm121a_supported, LibraryError, + backend_requirement, + supported_compute_capability, ) from .jit.gemm import gen_gemm_sm90_module from .jit.gemm import gen_gemm_module @@ -81,6 +83,9 @@ DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024 +# Error messages +CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0." + def _match_sm_version(device: torch.device, sm_version: list[str]): major, minor = get_compute_capability(device) @@ -1182,7 +1187,7 @@ def _validate_fp8_output_dtype(dtype: torch.dtype): @functools.cache -def build_cudnn_gemm_block_scale_dequantize_graph( +def create_cudnn_execution_plans_fp4_gemm( a_shape, a_stride, b_shape, @@ -1279,12 +1284,49 @@ def build_cudnn_gemm_block_scale_dequantize_graph( # in older cuDNN versions, so we deselect it. if (alpha_is_not_none) and (not _is_cublas_fp4_available_in_cudnn()): graph.deselect_engines(["eng0"]) - graph.check_support() - graph.build_plans() return graph +@functools.cache +def build_plans_cudnn_fp4_gemm_graph( + a_shape, + a_stride, + b_shape, + b_stride, + a_descale_shape, + a_descale_stride, + b_descale_shape, + b_descale_stride, + ab_type, + o_type, + block_size, + device, + alpha, + use_nvfp4, +): + graph = create_cudnn_execution_plans_fp4_gemm( + a_shape, + a_stride, + b_shape, + b_stride, + a_descale_shape, + a_descale_stride, + b_descale_shape, + b_descale_stride, + ab_type, + o_type, + block_size, + device, + alpha, + use_nvfp4, + ) + + graph.check_support() + graph.build_plans() + return graph + + def execute_cudnn_gemm_fp4_graph( graph, a, @@ -1647,6 +1689,172 @@ def mm_fp8( return out +def _check_mm_fp4_problem_size( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + use_nvfp4: bool = True, +): + # Generic checks + ## pre-check the input tensor, block scale tensor and alpha tensor + if a.ndim != 2 or b.ndim != 2: + raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}") + if a.shape[1] != b.shape[0]: + raise ValueError( + f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}" + ) + if a.dtype not in {torch.uint8, get_native_fp4_dtype()} or b.dtype not in { + torch.uint8, + get_native_fp4_dtype(), + }: + raise ValueError( + f"a and b must have float4_e2m1fn_x2 packed into uint8. " + f"Got {a.dtype} and {b.dtype}." + ) + if a_descale.dtype not in { + torch.float8_e4m3fn, + torch.uint8, + } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}: + raise ValueError( + f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. " + f"Got {a_descale.dtype} and {b_descale.dtype}." + ) + if alpha is not None and alpha.dtype != torch.float: + raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}") + if alpha is not None and alpha.numel() != 1: + raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") + + if out_dtype not in (torch.bfloat16, torch.float16): + raise ValueError( + f"Unsupported output dtype: {out_dtype}. 
" + f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." + ) + + if backend != "trtllm" and use_8x4_sf_layout: + raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") + if backend != "cudnn" and not use_nvfp4: + raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") + + if use_nvfp4 and block_size != 16: + raise ValueError("nvfp4 only supports block_size = 16.") + if not use_nvfp4 and block_size != 32: + raise ValueError("mxfp4 only supports block_size = 32.") + + return True + + +@supported_compute_capability([100, 103, 110, 120]) +def _cudnn_gemm_fp4_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + use_nvfp4: bool = True, +): + if ( + not use_nvfp4 + and _match_sm_version(a.device, ["120"]) + and cudnn.backend_version() < 91400 + ): + raise LibraryError(CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR) + + _check_cudnn_fp4_availability() + + # the fp4 cudnn graph will be shared for both mm and bmm, so + # here we need to get the 3d shape and stride including the + # batch dimension for both input and block scale tensors. + real_a_shape, real_a_stride = _get_real_fp4_shape_from_packed_uint8(a) + real_b_shape, real_b_stride = _get_real_fp4_shape_from_packed_uint8(b) + batch = real_a_shape[0] + expanded_a_descale_shape, expanded_a_descale_stride = ( + _expand_block_scale_tensor_shape(a_descale, batch) + ) + expanded_b_descale_shape, expanded_b_descale_stride = ( + _expand_block_scale_tensor_shape(b_descale, batch) + ) + + # build the fp4 cudnn graph + graph = create_cudnn_execution_plans_fp4_gemm( + real_a_shape, + real_a_stride, + real_b_shape, + real_b_stride, + expanded_a_descale_shape, + expanded_a_descale_stride, + expanded_b_descale_shape, + expanded_b_descale_stride, + cudnn.data_type.FP4_E2M1, + _torch_data_type_to_cudnn_data_type(out_dtype), + block_size, + a.device, + alpha, + use_nvfp4, + ) + graph.check_support() + + return True + + +@supported_compute_capability([100, 103, 120]) +def _trtllm_gemm_fp4_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + use_nvfp4: bool = True, +): + if out_dtype != torch.bfloat16: + raise ValueError( + f"Unsupported output dtype: {out_dtype}. " + f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations." 
+ ) + return True + + +@supported_compute_capability([100, 103, 120]) +def _cutlass_gemm_fp4_requirement( + a: torch.Tensor, + b: torch.Tensor, + a_descale: torch.Tensor, + b_descale: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 16, + use_8x4_sf_layout: bool = False, + backend: Literal["cudnn", "trtllm", "cutlass"] = "cudnn", + use_nvfp4: bool = True, +): + return True + + +@backend_requirement( + { + "cudnn": _cudnn_gemm_fp4_requirement, # Each backend has its own requirement function + "trtllm": _trtllm_gemm_fp4_requirement, + "cutlass": _cutlass_gemm_fp4_requirement, + }, + common_check=_check_mm_fp4_problem_size, # Shape checks common to all backends +) def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -1721,59 +1929,6 @@ def mm_fp4( >>> out.shape torch.Size([48, 256]) """ - # pre-check the input tensor, block scale tensor and alpha tensor - if a.ndim != 2 or b.ndim != 2: - raise ValueError(f"mm_fp4 accepts 2d tensors, got {a.shape} and {b.shape}") - if a.shape[1] != b.shape[0]: - raise ValueError( - f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}" - ) - if a.dtype not in {torch.uint8, get_native_fp4_dtype()} or b.dtype not in { - torch.uint8, - get_native_fp4_dtype(), - }: - raise ValueError( - f"a and b must have float4_e2m1fn_x2 packed into uint8. " - f"Got {a.dtype} and {b.dtype}." - ) - if a_descale.dtype not in { - torch.float8_e4m3fn, - torch.uint8, - } or b_descale.dtype not in {torch.float8_e4m3fn, torch.uint8}: - raise ValueError( - f"a_descale and b_descale must have float8_e4m3fnx2 packed into uint8. " - f"Got {a_descale.dtype} and {b_descale.dtype}." - ) - if alpha is not None and alpha.dtype != torch.float: - raise ValueError(f"alpha must be a float tensor, got {alpha.dtype}") - if alpha is not None and alpha.numel() != 1: - raise ValueError(f"alpha must be a scalar, got {alpha.numel()}") - - if out_dtype not in (torch.bfloat16, torch.float16): - raise ValueError( - f"Unsupported output dtype: {out_dtype}. " - f"Only torch.bfloat16 and torch.float16 are supported for FP4 GEMM operations." - ) - - if use_nvfp4 and block_size != 16: - raise ValueError("nvfp4 only supports block_size = 16.") - if not use_nvfp4 and block_size != 32: - raise ValueError("mxfp4 supports block_size = 32.") - if backend != "trtllm" and use_8x4_sf_layout: - raise ValueError("Only TRTLLM FP4 GEMM supports 8x4 scale factor layout.") - if backend == "trtllm" and _match_sm_version(a.device, ["110"]): - raise ValueError("TRTLLM FP4 GEMM is not supported on SM110.") - if backend != "cudnn" and not use_nvfp4: - raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.") - if ( - backend == "cudnn" - and not use_nvfp4 - and _match_sm_version(a.device, ["120"]) - and cudnn.backend_version() < 91400 - ): - raise LibraryError( - "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0." - ) # allocate the output tensor if not provided if out is None: @@ -1788,8 +1943,6 @@ def mm_fp4( ) if backend == "cudnn": - _check_cudnn_fp4_availability() - # the fp4 cudnn graph will be shared for both mm and bmm, so # here we need to get the 3d shape and stride including the # batch dimension for both input and block scale tensors. 
@@ -1804,7 +1957,7 @@ def mm_fp4( ) # build the fp4 cudnn graph - graph = build_cudnn_gemm_block_scale_dequantize_graph( + graph = build_plans_cudnn_fp4_gemm_graph( real_a_shape, real_a_stride, real_b_shape, @@ -1826,12 +1979,6 @@ def mm_fp4( graph, a, b, a_descale, b_descale, alpha, out, workspace_buffer ) elif backend == "trtllm": - if out_dtype != torch.bfloat16: - raise ValueError( - f"Unsupported output dtype: {out_dtype}. " - f"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations." - ) - get_trtllm_fp4_gemm_module().trtllm_fp4_gemm( a, b.T, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index e015010c83..936d08380c 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -60,6 +60,12 @@ class LibraryError(Exception): pass +class BackendSupportedError(Exception): + """Custom exception for backend-related errors.""" + + pass + + def _expand_5d(x: torch.Tensor, kv_layout: str) -> torch.Tensor: if x.ndim not in [4, 5]: raise ValueError("x must be 4D or 5D") @@ -761,3 +767,262 @@ def get_native_fp4_dtype(): return torch.float4_e2m1fn_x2 else: return torch.uint8 + + +def supported_compute_capability(supported_ccs: Iterable[int]) -> Callable: + """ + Decorator to mark functions with their supported CUDA compute capabilities. + + This decorator annotates a function with metadata about which CUDA compute + capabilities (CC) it supports. It adds a `_supported_ccs` attribute containing + the set of supported compute capabilities and an `is_compute_capability_supported` + method to check if a specific compute capability is supported. + + Parameters + ---------- + supported_ccs : list or iterable of int + A list of supported CUDA compute capability versions as integers + (e.g., [75, 80, 86, 89, 90, 100, 103, 110, 120]). + These are computed as major * 10 + minor (e.g., SM 8.0 = 80, SM 9.0 = 90). + + Returns + ------- + decorator : callable + A decorator function that adds compute capability metadata to the decorated function. + + Attributes Added to Decorated Function + --------------------------------------- + _supported_ccs : set of int + A set of integers representing the supported compute capabilities. + is_compute_capability_supported : callable + A method that takes a compute capability (int) and returns True if it's + supported, False otherwise. + + Examples + -------- + >>> @supported_compute_capability([80, 86, 89, 90]) + ... def my_kernel_function(): + ... pass + ... + >>> my_kernel_function._supported_ccs + {80, 86, 89, 90} + >>> my_kernel_function.is_compute_capability_supported(80) + True + >>> my_kernel_function.is_compute_capability_supported(75) + False + + Notes + ----- + This decorator is useful in conjunction with the backend_requirement decorator to mark functions with their supported CUDA compute capabilities. + + Raises + ------ + TypeError + If supported_ccs is not iterable or contains non-integer values. 
+ """ + # Validate that supported_ccs is iterable + try: + ccs_list = list(supported_ccs) + except TypeError: + raise TypeError( + f"supported_ccs must be an iterable, got {type(supported_ccs).__name__}" + ) from None + + # Validate and convert all elements to integers + validated_ccs = [] + for i, cc in enumerate(ccs_list): + if isinstance(cc, bool): + # Reject booleans (which are technically ints in Python) + raise TypeError(f"supported_ccs[{i}] must be an integer, got bool: {cc}") + if not isinstance(cc, int): + raise TypeError( + f"supported_ccs[{i}] must be an integer, got {type(cc).__name__}: {cc}" + ) + validated_ccs.append(cc) + + def decorator(func): + func._supported_ccs = set(validated_ccs) + + def is_cc_supported(cc): + return cc in func._supported_ccs + + func.is_compute_capability_supported = is_cc_supported + return func + + return decorator + + +def backend_requirement( + backend_checks: Dict[str, Callable], common_check: Optional[Callable] = None +) -> Callable: + """ + Decorator to enforce backend and problem size requirements for kernel functions. + + This decorator validates that a function is called with a supported backend and + compute capability, and optionally validates problem size constraints. It performs + runtime checks before executing the function and raises appropriate errors if + requirements are not met. If checking overheads are a concern, you can pass a + `skip_check` keyword argument to the function to bypass the validation. + + Parameters + ---------- + backend_checks : dict + A dictionary mapping backend names (str) to requirement checker functions. + Each checker function should accept the same arguments as the decorated function + and return True if the problem size is supported, False otherwise. + Checkers can be decorated with @supported_compute_capability to specify + which compute capabilities they support. + common_check : callable, optional + An optional function that performs additional validation checks common to all + backends. Should accept the same arguments as the decorated function and return + True if requirements are met, False otherwise. + + Returns + ------- + decorator : callable + A decorator function that wraps the target function with validation logic, and inserts + the "skip_check" keyword argument to the function. + + Attributes Added to Decorated Function + --------------------------------------- + is_backend_supported : callable + Method with signature `is_backend_supported(backend, cc=None)` that returns + True if the specified backend is supported, optionally for a specific compute + capability (cc). + is_compute_capability_supported : callable + Method with signature `is_compute_capability_supported(cc)` that returns True + if any backend supports the given compute capability. + + Keyword Arguments Added to Decorated Function + --------------------------------------------- + skip_check : bool + (Defaults to False) + If True, the function will not be validated. This is useful for performance-critical code paths. + + Raises + ------ + BackendSupportedError + If the function is called with an unsupported backend or compute capability. + ValueError + If the problem size is not supported for the given backend. + + Examples + -------- + >>> @supported_compute_capability([80, 86, 89, 90]) + ... def _cutlass_check(q, k, v, backend): + ... # Validate problem size constraints for CUTLASS backend + ... return q.shape[-1] <= 256 + ... + >>> @supported_compute_capability([75, 80, 86, 89, 90]) + ... 
def _cudnn_check(q, k, v, backend): + ... # Validate problem size constraints for cuDNN backend + ... return True + ... + >>> @backend_requirement({ + ... "cutlass": _cutlass_check, + ... "cudnn": _cudnn_check + ... }) + ... def my_attention_kernel(q, k, v, backend="cutlass"): + ... # Backend invocation + ... pass + ... + >>> # Check if backend is supported + >>> my_attention_kernel.is_backend_supported("cutlass") + True + >>> # Check if backend supports specific compute capability + >>> my_attention_kernel.is_backend_supported("cutlass", 75) + False + >>> my_attention_kernel.is_backend_supported("cutlass", 80) + True + >>> # Check if any backend supports a compute capability + >>> my_attention_kernel.is_compute_capability_supported(75) + True + + Notes + ----- + - The decorator automatically extracts compute capability from tensor arguments + by finding the first torch.Tensor in args or kwargs. + - A `skip_check=True` keyword argument can be passed to bypass validation for + performance-critical code paths. + - All validation is performed before the wrapped function executes. + - Works in conjunction with the @supported_compute_capability decorator to + provide fine-grained control over backend and architecture support. + """ + + def decorator(func): + def is_backend_supported(backend, cc=None): + # Is this backend present? + if backend not in backend_checks: + return False + req_checker = backend_checks[backend] + # If user just wants to check if the backend is supported (regardless of compute capability), return True + if cc is None: + return True + # Check compute capability support via attribute on requirement function + elif hasattr(req_checker, "is_compute_capability_supported"): + return req_checker.is_compute_capability_supported(cc) + return False + + def is_compute_capability_supported(cc): + # True if any backend requirement supports this cc + return any( + hasattr(checker, "is_compute_capability_supported") + and checker.is_compute_capability_supported(cc) + for checker in backend_checks.values() + ) + + def is_problem_size_supported(*args, **kwargs): + backend = kwargs.get("backend") + if backend not in backend_checks: + raise BackendSupportedError( + f"Backend '{backend}' is not supported for {func.__name__}" + ) + req_checker = backend_checks[backend] + if common_check is not None: + return common_check(*args, **kwargs) and req_checker(*args, **kwargs) + else: + return req_checker(*args, **kwargs) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + backend = kwargs.get("backend") + # skip_check is an optional argument that the decorator adds to any API function. + # It prevents the performance overhead of checking. + skip_check = kwargs.pop("skip_check", False) + + if not skip_check: + capability = None + # Find the first tensor argument. + # Assume all tensors are on the same device/capability. + # We could consider check all tensors at a performance cost. 
+ tensor_arg = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor_arg = arg + if tensor_arg is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor_arg = value + + if tensor_arg is not None: + # Get compute capability from the first tensor + # Assume all tensors are on the same device/capability + major, minor = get_compute_capability(tensor_arg.device) + capability = major * 10 + minor + + if not is_backend_supported(backend, capability): + extra = f" with capability {capability}" if capability else "" + raise BackendSupportedError( + f"{func.__name__} does not support backend '{backend}'{extra}" + ) + if not is_problem_size_supported(*args, **kwargs): + raise ValueError( + f"Problem size is not supported for {func.__name__}" + ) + return func(*args, **kwargs) + + wrapper.is_backend_supported = is_backend_supported + wrapper.is_compute_capability_supported = is_compute_capability_supported + return wrapper + + return decorator diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index ef2d3b9659..95975f6031 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -9,6 +9,7 @@ mxfp4_quantize, ) from flashinfer.utils import get_compute_capability, LibraryError +from flashinfer.gemm import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR # TODO: Consdier splitting this function up for the various backends @@ -85,22 +86,17 @@ def test_mm_fp4( use_8x4_sf_layout=not use_128x4_sf_layout, backend=backend, use_nvfp4=use_nvfp4, + skip_check=False, ) cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) assert cos_sim > 0.97 - except LibraryError: + except LibraryError as e: # TODO: Remove this check once cuDNN backend version is updated to 9.14.0 - if ( - backend == "cudnn" - and not use_nvfp4 - and (compute_capability[0] == 12 and compute_capability[1] == 0) - ): - pytest.xfail( - "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0." 
- ) + if str(e) == CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR: + pytest.xfail(str(e)) else: - pytest.fail("Unexpected LibraryError") + pytest.fail(str(e)) if __name__ == "__main__": diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py new file mode 100644 index 0000000000..e0520b1d43 --- /dev/null +++ b/tests/utils/test_decorators.py @@ -0,0 +1,212 @@ +import pytest +import torch + +from flashinfer.utils import ( + supported_compute_capability, + backend_requirement, + BackendSupportedError, +) + + +def test_supported_compute_capability(): + """Test the supported_compute_capability decorator.""" + + @supported_compute_capability([80, 86, 89, 90]) + def my_function(x, y): + return x + y + + # Check attributes + assert hasattr(my_function, "_supported_ccs"), "Missing _supported_ccs attribute" + assert my_function._supported_ccs == {80, 86, 89, 90}, "Incorrect _supported_ccs" + + # Check method + assert hasattr(my_function, "is_compute_capability_supported"), "Missing method" + assert my_function.is_compute_capability_supported(80) is True + assert my_function.is_compute_capability_supported(75) is False + + # Check function still works + result = my_function(5, 10) + assert result == 15, "Function doesn't work correctly" + + +def test_input_validation(): + """Test that the decorator validates input correctly.""" + + # Test rejection of non-iterable + with pytest.raises(TypeError, match="must be an iterable"): + + @supported_compute_capability(80) + def func1(): + pass + + # Test rejection of string values + with pytest.raises(TypeError, match="must be an integer"): + + @supported_compute_capability(["80", "86"]) + def func2(): + pass + + # Test rejection of float values + with pytest.raises(TypeError, match="must be an integer"): + + @supported_compute_capability([80.0, 86]) + def func3(): + pass + + # Test rejection of bool values + with pytest.raises(TypeError, match="got bool"): + + @supported_compute_capability([True, False]) + def func4(): + pass + + # Test acceptance of valid integers + @supported_compute_capability([75, 80, 86, 89, 90, 100, 103, 110, 120]) + def func5(): + pass + + assert func5._supported_ccs == {75, 80, 86, 89, 90, 100, 103, 110, 120} + + +def test_backend_requirement_support_checks(): + """Test the backend_requirement decorator support checks.""" + + @supported_compute_capability([80, 86, 89, 90]) + def _cudnn_check_my_kernel(x, backend): + return True + + @supported_compute_capability([75, 80, 86, 89, 90]) + def _cutlass_check_my_kernel(x, backend): + return True + + def _common_check(x, backend): + # Common requirement: must be 2D + return x.dim() == 2 + + @backend_requirement( + {"cudnn": _cudnn_check_my_kernel, "cutlass": _cutlass_check_my_kernel}, + common_check=_common_check, + ) + def my_kernel(x, backend="cudnn"): + return x * 2 + + # Check methods added + assert hasattr(my_kernel, "is_backend_supported"), "Missing is_backend_supported" + assert hasattr(my_kernel, "is_compute_capability_supported"), ( + "Missing is_compute_capability_supported" + ) + + # Check backend support + assert my_kernel.is_backend_supported("cutlass") is True + assert my_kernel.is_backend_supported("cudnn") is True + assert my_kernel.is_backend_supported("trtllm") is False + + # Check compute capability support + assert my_kernel.is_backend_supported("cutlass", 80) is True + assert my_kernel.is_backend_supported("cutlass", 75) is True # cutlass supports 75 + assert ( + my_kernel.is_backend_supported("cudnn", 75) is False + ) # cudnn does NOT support 75 + assert 
my_kernel.is_backend_supported("cudnn", 80) is True + + # Check cross-backend compute capability + assert my_kernel.is_compute_capability_supported(75) is True # cutlass has it + assert my_kernel.is_compute_capability_supported(80) is True # both have it + assert my_kernel.is_compute_capability_supported(70) is False # neither has it + + +def test_backend_requirement_wrapped_function(): + """Test the backend_requirement decorator's wrapped function.""" + if not torch.cuda.is_available(): + pytest.skip("Skipping CUDA tests (no GPU available)") + + # Get actual device capability + x = torch.randn(1, 1, device="cuda") + major, minor = torch.cuda.get_device_capability(x.device) + actual_capability = major * 10 + minor + + @supported_compute_capability([80, 86, 89, 90, actual_capability]) + def _cutlass_check(x, backend): + return x.shape[0] > 0 + + @supported_compute_capability([75, 80, 86, 89, 90, actual_capability]) + def _cudnn_check(x, backend): + return x.shape[0] > 0 + + @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check}) + def my_kernel(x, backend="cutlass"): + return x * 2 + + x = torch.randn(10, 10, device="cuda") + + # Test unsupported backend raises error + # The error message may include capability info, so use a flexible pattern + with pytest.raises( + BackendSupportedError, match="does not support backend 'trtllm'" + ): + my_kernel(x, backend="trtllm") + + # Test supported backend works + result = my_kernel(x, backend="cutlass") + assert result.shape == x.shape + + +def test_common_check(): + """Test common_check parameter.""" + if not torch.cuda.is_available(): + pytest.skip("Skipping CUDA tests (no GPU available)") + + x = torch.randn(1, 1, device="cuda") + major, minor = torch.cuda.get_device_capability(x.device) + actual_capability = major * 10 + minor + + @supported_compute_capability([80, 86, 89, 90, actual_capability]) + def _cudnn_check_my_kernel(x, backend): + return True + + @supported_compute_capability([75, 80, 86, 89, 90, actual_capability]) + def _cutlass_check_my_kernel(x, backend): + return True + + def _common_check(x, backend): + # Common requirement: must be 2D + return x.dim() == 2 + + @backend_requirement( + {"cudnn": _cudnn_check_my_kernel, "cutlass": _cutlass_check_my_kernel}, + common_check=_common_check, + ) + def my_kernel(x, backend="cudnn"): + return x * 2 + + x_2d = torch.randn(10, 10, device="cuda") + x_3d = torch.randn(10, 10, 10, device="cuda") + + # 2D should work with skip_check + result = my_kernel(x_2d, backend="cudnn", skip_check=True) + assert result.shape == x_2d.shape + + # 3D should fail validation + with pytest.raises(ValueError, match="Problem size is not supported"): + my_kernel(x_3d, backend="cudnn") + + +def test_functools_wraps_preserves_metadata(): + """Test that backend_requirement preserves function metadata with functools.wraps.""" + + @supported_compute_capability([80, 86, 89, 90]) + def _check(x, backend): + return True + + @backend_requirement({"backend": _check}) + def my_documented_function(x, backend="backend"): + """This is my function's docstring.""" + return x * 2 + + # Verify that function metadata is preserved + assert my_documented_function.__name__ == "my_documented_function" + assert my_documented_function.__doc__ == "This is my function's docstring." + + # Verify that added methods still exist + assert hasattr(my_documented_function, "is_backend_supported") + assert hasattr(my_documented_function, "is_compute_capability_supported")
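For reference, a minimal usage sketch of the two decorators introduced by this patch, outside the diff itself. The kernel `my_scaled_matmul`, the checker functions, and the compute-capability lists below are hypothetical examples; only the decorator names, the `skip_check` keyword, the `BackendSupportedError` exception, and the `is_backend_supported` / `is_compute_capability_supported` introspection methods come from the change above. The runtime portion assumes a CUDA device is available.

import torch
from flashinfer.utils import (
    backend_requirement,
    supported_compute_capability,
    BackendSupportedError,
)


@supported_compute_capability([89, 90, 100, 103, 120])
def _cutlass_check(a, b, backend="cutlass"):
    # Per-backend problem-size constraint (example only): this path handles K <= 4096.
    return a.shape[1] <= 4096


@supported_compute_capability([75, 80, 86, 89, 90, 100, 103, 110, 120])
def _cudnn_check(a, b, backend="cudnn"):
    return True


def _common_check(a, b, backend="cudnn"):
    # Shape checks shared by every backend.
    return a.ndim == 2 and b.ndim == 2 and a.shape[1] == b.shape[0]


@backend_requirement(
    {"cutlass": _cutlass_check, "cudnn": _cudnn_check},
    common_check=_common_check,
)
def my_scaled_matmul(a, b, backend="cudnn"):
    return a @ b


if __name__ == "__main__":
    # Introspection works without touching the GPU.
    assert my_scaled_matmul.is_backend_supported("cudnn")
    assert not my_scaled_matmul.is_backend_supported("trtllm")
    assert my_scaled_matmul.is_compute_capability_supported(75)

    if torch.cuda.is_available():
        a = torch.randn(32, 64, device="cuda")
        b = torch.randn(64, 128, device="cuda")
        major, minor = torch.cuda.get_device_capability(a.device)

        # `backend` must be passed as a keyword so the wrapper can see it.
        if my_scaled_matmul.is_backend_supported("cudnn", major * 10 + minor):
            out = my_scaled_matmul(a, b, backend="cudnn")

        # `skip_check=True` bypasses all validation on hot paths.
        out = my_scaled_matmul(a, b, backend="cudnn", skip_check=True)

        # An unregistered backend raises BackendSupportedError.
        try:
            my_scaled_matmul(a, b, backend="trtllm")
        except BackendSupportedError as e:
            print(e)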