From 5a3ebc3340394631841376dc1c2c16a97f31dfeb Mon Sep 17 00:00:00 2001 From: Maximilien Breughe Date: Mon, 3 Nov 2025 10:20:09 -0800 Subject: [PATCH 1/3] Updated decorator to support unspecified default --- flashinfer/gemm.py | 2 +- flashinfer/utils.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index b561a67862..9f00cc6e25 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -1834,7 +1834,7 @@ def _trtllm_gemm_fp4_requirement( return True -@supported_compute_capability([100, 103, 120, 121]) +@supported_compute_capability([100, 103, 110, 120, 121]) def _cutlass_gemm_fp4_requirement( a: torch.Tensor, b: torch.Tensor, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 936d08380c..e695665b9a 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -23,6 +23,7 @@ import torch.version from torch.torch_version import TorchVersion from torch.torch_version import __version__ as torch_version +import inspect from .jit.spdlog import gen_spdlog_module @@ -950,6 +951,19 @@ def backend_requirement( """ def decorator(func): + def get_backend(args, kwargs): + # backend may not be specified, but could have a default value + sig = inspect.signature(func) + backend_parameter = sig.parameters.get("backend") + if ( + backend_parameter + and backend_parameter.default != inspect.Parameter.empty + ): + backend = kwargs.get("backend", backend_parameter.default) + else: + backend = kwargs.get("backend") + return backend + def is_backend_supported(backend, cc=None): # Is this backend present? if backend not in backend_checks: @@ -972,7 +986,7 @@ def is_compute_capability_supported(cc): ) def is_problem_size_supported(*args, **kwargs): - backend = kwargs.get("backend") + backend = get_backend(args, kwargs) if backend not in backend_checks: raise BackendSupportedError( f"Backend '{backend}' is not supported for {func.__name__}" @@ -985,12 +999,13 @@ def is_problem_size_supported(*args, **kwargs): @functools.wraps(func) def wrapper(*args, **kwargs): - backend = kwargs.get("backend") # skip_check is an optional argument that the decorator adds to any API function. # It prevents the performance overhead of checking. skip_check = kwargs.pop("skip_check", False) if not skip_check: + backend = get_backend(args, kwargs) + capability = None # Find the first tensor argument. # Assume all tensors are on the same device/capability. 
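PATCH 1/3 resolves the backend by inspecting the wrapped function's signature: when the caller omits the backend keyword but the function declares a default, get_backend falls back to that declared default. A minimal standalone sketch of the same inspect-based lookup (resolve_backend and my_kernel are illustrative names, not part of the patch):

    import inspect

    def resolve_backend(func, kwargs):
        # Same idea as the patch's get_backend helper: prefer an explicit
        # "backend" kwarg, otherwise fall back to the default declared in
        # the function signature (if any).
        param = inspect.signature(func).parameters.get("backend")
        if param is not None and param.default is not inspect.Parameter.empty:
            return kwargs.get("backend", param.default)
        return kwargs.get("backend")

    def my_kernel(x, backend="cudnn"):
        return x

    print(resolve_backend(my_kernel, {}))                      # cudnn
    print(resolve_backend(my_kernel, {"backend": "cutlass"}))  # cutlass

Note that a kwargs-only lookup still misses a backend passed positionally; PATCH 2/3 closes that gap with Signature.bind.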
From 56d4c133957e6a2478a63c2ddc772df519b8781f Mon Sep 17 00:00:00 2001
From: Maximilien Breughe
Date: Mon, 3 Nov 2025 11:52:09 -0800
Subject: [PATCH 2/3] Applied defaults from the function signature in the wrapper

---
 flashinfer/utils.py            | 48 +++++++++++++++++-----------------
 tests/utils/test_decorators.py | 36 +++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 24 deletions(-)

diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index e695665b9a..32cb64ea6f 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -951,18 +951,8 @@ def backend_requirement(
     """

     def decorator(func):
-        def get_backend(args, kwargs):
-            # backend may not be specified, but could have a default value
-            sig = inspect.signature(func)
-            backend_parameter = sig.parameters.get("backend")
-            if (
-                backend_parameter
-                and backend_parameter.default != inspect.Parameter.empty
-            ):
-                backend = kwargs.get("backend", backend_parameter.default)
-            else:
-                backend = kwargs.get("backend")
-            return backend
+        # Get the function signature once for reuse
+        sig = inspect.signature(func)

         def is_backend_supported(backend, cc=None):
             # Is this backend present?
@@ -985,8 +975,10 @@ def is_compute_capability_supported(cc):
                 for checker in backend_checks.values()
             )

-        def is_problem_size_supported(*args, **kwargs):
-            backend = get_backend(args, kwargs)
+        # @note: this function does not automatically apply defaults to the arguments.
+        def _is_problem_size_supported(*args, **kwargs):
+            # At this point, kwargs should have defaults applied, so backend should be present
+            backend = kwargs.get("backend")
             if backend not in backend_checks:
                 raise BackendSupportedError(
                     f"Backend '{backend}' is not supported for {func.__name__}"
@@ -997,27 +989,35 @@ def is_problem_size_supported(*args, **kwargs):
             else:
                 return req_checker(*args, **kwargs)

+        # @brief: Wrapper function that calls the original decorated function after applying a number of checks.
+        # @note that here we manually apply defaults to the arguments in the wrapper function.
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
             # skip_check is an optional argument that the decorator adds to any API function.
             # It prevents the performance overhead of checking.
             skip_check = kwargs.pop("skip_check", False)

+            # Apply defaults from the function signature
+            # This ensures that all parameters (including backend) have their default values
+            # if not explicitly provided by the caller
+            bound_args = sig.bind(*args, **kwargs)
+            bound_args.apply_defaults()
+            # Convert back to args and kwargs for consistency with the rest of the code
+            kwargs = dict(bound_args.arguments)
+            args = ()  # All arguments are now in kwargs after binding
+
             if not skip_check:
-                backend = get_backend(args, kwargs)
+                backend = kwargs.get("backend")

                 capability = None
                 # Find the first tensor argument.
                 # Assume all tensors are on the same device/capability.
                 # We could consider checking all tensors at a performance cost.
                tensor_arg = None
-                for arg in args:
-                    if isinstance(arg, torch.Tensor):
-                        tensor_arg = arg
-                if tensor_arg is None:
-                    for value in kwargs.values():
-                        if isinstance(value, torch.Tensor):
-                            tensor_arg = value
+                for value in kwargs.values():
+                    if isinstance(value, torch.Tensor):
+                        tensor_arg = value
+                        break

                 if tensor_arg is not None:
                     # Get compute capability from the first tensor
@@ -1030,11 +1030,11 @@ def wrapper(*args, **kwargs):
                     raise BackendSupportedError(
                         f"{func.__name__} does not support backend '{backend}'{extra}"
                     )
-                if not is_problem_size_supported(*args, **kwargs):
+                if not _is_problem_size_supported(**kwargs):
                     raise ValueError(
                         f"Problem size is not supported for {func.__name__}"
                     )
-            return func(*args, **kwargs)
+            return func(**kwargs)

         wrapper.is_backend_supported = is_backend_supported
         wrapper.is_compute_capability_supported = is_compute_capability_supported
diff --git a/tests/utils/test_decorators.py b/tests/utils/test_decorators.py
index e0520b1d43..e0528cfd60 100644
--- a/tests/utils/test_decorators.py
+++ b/tests/utils/test_decorators.py
@@ -210,3 +210,39 @@ def my_documented_function(x, backend="backend"):
     # Verify that added methods still exist
     assert hasattr(my_documented_function, "is_backend_supported")
     assert hasattr(my_documented_function, "is_compute_capability_supported")
+
+
+def test_backend_default_parameter():
+    """Test that backend_requirement correctly uses the default backend parameter when none is specified."""
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA tests (no GPU available)")
+
+    # Get actual device capability
+    x = torch.randn(1, 1, device="cuda")
+    major, minor = torch.cuda.get_device_capability(x.device)
+    actual_capability = major * 10 + minor
+
+    @supported_compute_capability([80, 86, 89, 90, actual_capability])
+    def _cutlass_check(x, backend):
+        return x.shape[0] > 0
+
+    @supported_compute_capability([75, 80, 86, 89, 90, actual_capability])
+    def _cudnn_check(x, backend):
+        return x.shape[0] > 0
+
+    @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check})
+    def my_kernel(x, backend="cudnn"):
+        return x * 2
+
+    x = torch.randn(10, 10, device="cuda")
+
+    # Test that calling without a backend argument uses the default "cudnn"
+    # This should work without raising an error
+    result = my_kernel(x)
+    assert result.shape == x.shape
+    assert torch.allclose(result, x * 2)
+
+    # Test that explicitly passing a different backend also works
+    result2 = my_kernel(x, backend="cutlass")
+    assert result2.shape == x.shape
+    assert torch.allclose(result2, x * 2)

From 7401a6bdb6809b10b2caf124126881f18585dca7 Mon Sep 17 00:00:00 2001
From: Maximilien Breughe
Date: Mon, 3 Nov 2025 11:59:06 -0800
Subject: [PATCH 3/3] Avoid binding args when skip_check

---
 flashinfer/utils.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index 32cb64ea6f..eb42e1291e 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -990,31 +990,30 @@ def _is_problem_size_supported(*args, **kwargs):
                 return req_checker(*args, **kwargs)

         # @brief: Wrapper function that calls the original decorated function after applying a number of checks.
-        # @note that here we manually apply defaults to the arguments in the wrapper function.
+        # @note that here we manually apply defaults to the arguments in the wrapper function when doing validation.
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
             # skip_check is an optional argument that the decorator adds to any API function.
             # It prevents the performance overhead of checking.
             skip_check = kwargs.pop("skip_check", False)

-            # Apply defaults from the function signature
-            # This ensures that all parameters (including backend) have their default values
-            # if not explicitly provided by the caller
-            bound_args = sig.bind(*args, **kwargs)
-            bound_args.apply_defaults()
-            # Convert back to args and kwargs for consistency with the rest of the code
-            kwargs = dict(bound_args.arguments)
-            args = ()  # All arguments are now in kwargs after binding
-
-            if not skip_check:
-                backend = kwargs.get("backend")
+            if not skip_check:
+                # Apply defaults from the function signature for validation
+                # This ensures that all parameters (including backend) have their default values
+                # if not explicitly provided by the caller
+                bound_args = sig.bind(*args, **kwargs)
+                bound_args.apply_defaults()
+                # Convert to kwargs for validation functions
+                kwargs_with_defaults = dict(bound_args.arguments)
+
+                backend = kwargs_with_defaults.get("backend")

                 capability = None
                 # Find the first tensor argument.
                 # Assume all tensors are on the same device/capability.
                 # We could consider checking all tensors at a performance cost.
                 tensor_arg = None
-                for value in kwargs.values():
+                for value in kwargs_with_defaults.values():
                     if isinstance(value, torch.Tensor):
                         tensor_arg = value
                         break
@@ -1030,11 +1029,12 @@ def wrapper(*args, **kwargs):
                     raise BackendSupportedError(
                         f"{func.__name__} does not support backend '{backend}'{extra}"
                     )
-                if not _is_problem_size_supported(**kwargs):
+                if not _is_problem_size_supported(**kwargs_with_defaults):
                     raise ValueError(
                         f"Problem size is not supported for {func.__name__}"
                     )
-            return func(**kwargs)
+
+            return func(*args, **kwargs)

         wrapper.is_backend_supported = is_backend_supported
         wrapper.is_compute_capability_supported = is_compute_capability_supported
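The series ultimately validates against a fully bound view of the arguments while forwarding the caller's original *args/**kwargs to func untouched. The mechanism it relies on is inspect.Signature.bind plus BoundArguments.apply_defaults, sketched below under illustrative names (my_kernel is not part of the patch):

    import inspect

    def my_kernel(x, scale=1.0, backend="cudnn"):
        return x

    sig = inspect.signature(my_kernel)

    # bind() accepts args/kwargs exactly as the wrapper receives them;
    # apply_defaults() then fills in every parameter the caller omitted.
    bound = sig.bind("x_placeholder")
    bound.apply_defaults()
    print(dict(bound.arguments))
    # {'x': 'x_placeholder', 'scale': 1.0, 'backend': 'cudnn'}

    # A positionally passed backend is captured too, which the earlier
    # kwargs.get("backend") lookup from PATCH 1/3 would have missed.
    bound = sig.bind("x_placeholder", 2.0, "cutlass")
    bound.apply_defaults()
    print(bound.arguments["backend"])  # cutlass

Deferring the bind/apply_defaults work until after the skip_check test, as PATCH 3/3 does, keeps the skip_check=True fast path free of that overhead.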