
Commit 1e75bff

Updated decorator to support unspecified default (#2026)
## 📌 Description

Updated the decorator to support an unspecified default. This was causing issues when calling mm_fp4 without a backend specified. Also added SM 110 as a supported compute capability for the cutlass backend (mm_fp4).

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * FP4 Cutlass GEMM now supports the SM110 GPU compute capability.
* **Bug Fixes**
  * Kernels called without an explicit backend now consistently use the default backend.
* **Tests**
  * Added a unit test to verify default backend selection and correct results when the backend is omitted.
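To see the failure mode concretely, here is a minimal sketch (naive_requirement and my_kernel are hypothetical stand-ins, not FlashInfer code): a wrapper that reads `backend` out of `**kwargs` never sees the declared default, because Python only populates `kwargs` with arguments the caller passed explicitly.

# Minimal sketch of the bug; names are illustrative, not FlashInfer's code.
import functools

def naive_requirement(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Only explicitly passed arguments appear in kwargs, so the
        # declared default backend="cudnn" is invisible here.
        backend = kwargs.get("backend")
        if backend is None:
            raise ValueError(f"{func.__name__}: backend not recognized")
        return func(*args, **kwargs)
    return wrapper

@naive_requirement
def my_kernel(x, backend="cudnn"):
    return x * 2

my_kernel(3, backend="cudnn")  # ok
my_kernel(3)                   # raises ValueError, despite the default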
1 parent da01b1b commit 1e75bff

File tree

3 files changed (+62, -11 lines)


flashinfer/gemm.py

Lines changed: 1 addition & 1 deletion
@@ -1834,7 +1834,7 @@ def _trtllm_gemm_fp4_requirement(
     return True


-@supported_compute_capability([100, 103, 120, 121])
+@supported_compute_capability([100, 103, 110, 120, 121])
 def _cutlass_gemm_fp4_requirement(
     a: torch.Tensor,
     b: torch.Tensor,
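For context, the following rough sketch shows what a capability-gating decorator like supported_compute_capability plausibly does; this is an assumption for illustration, not FlashInfer's actual implementation. The allowed list encodes compute capability as major * 10 + minor, so adding 110 enables SM 110.

# Illustrative sketch only (assumed behavior): gate a requirement
# check on the current device's compute capability.
import functools
import torch

def supported_compute_capability_sketch(allowed_ccs):
    def decorator(check):
        @functools.wraps(check)
        def wrapper(*args, **kwargs):
            major, minor = torch.cuda.get_device_capability()
            if major * 10 + minor not in allowed_ccs:
                return False  # e.g. SM 110 passes only if 110 is listed
            return check(*args, **kwargs)
        return wrapper
    return decorator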

flashinfer/utils.py

Lines changed: 25 additions & 10 deletions
@@ -23,6 +23,7 @@
 import torch.version
 from torch.torch_version import TorchVersion
 from torch.torch_version import __version__ as torch_version
+import inspect

 from .jit.spdlog import gen_spdlog_module

@@ -950,6 +951,9 @@ def backend_requirement(
     """

     def decorator(func):
+        # Get the function signature once for reuse
+        sig = inspect.signature(func)
+
         def is_backend_supported(backend, cc=None):
             # Is this backend present?
             if backend not in backend_checks:
@@ -971,7 +975,9 @@ def is_compute_capability_supported(cc):
                 for checker in backend_checks.values()
             )

-        def is_problem_size_supported(*args, **kwargs):
+        # @note: this function does not automatically apply defaults to the arguments.
+        def _is_problem_size_supported(*args, **kwargs):
+            # At this point, kwargs should have defaults applied, so backend should be present
             backend = kwargs.get("backend")
             if backend not in backend_checks:
                 raise BackendSupportedError(
@@ -983,26 +989,34 @@ def _is_problem_size_supported(*args, **kwargs):
             else:
                 return req_checker(*args, **kwargs)

+        # @brief: Wrapper function that calls the original, decorated function, after applying a number of checks.
+        # @note: here we manually apply defaults to the arguments in the wrapper function when doing validation.
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            backend = kwargs.get("backend")
             # skip_check is an optional argument that the decorator adds to any API function.
             # It prevents the performance overhead of checking.
             skip_check = kwargs.pop("skip_check", False)

             if not skip_check:
+                # Apply defaults from the function signature for validation.
+                # This ensures that all parameters (including backend) have their default values
+                # if not explicitly provided by the caller.
+                bound_args = sig.bind(*args, **kwargs)
+                bound_args.apply_defaults()
+                # Convert to kwargs for validation functions
+                kwargs_with_defaults = dict(bound_args.arguments)
+
+                backend = kwargs_with_defaults.get("backend")
+
                 capability = None
                 # Find the first tensor argument.
                 # Assume all tensors are on the same device/capability.
                 # We could consider checking all tensors at a performance cost.
                 tensor_arg = None
-                for arg in args:
-                    if isinstance(arg, torch.Tensor):
-                        tensor_arg = arg
-                if tensor_arg is None:
-                    for value in kwargs.values():
-                        if isinstance(value, torch.Tensor):
-                            tensor_arg = value
+                for value in kwargs_with_defaults.values():
+                    if isinstance(value, torch.Tensor):
+                        tensor_arg = value
+                        break

                 if tensor_arg is not None:
                     # Get compute capability from the first tensor
@@ -1015,10 +1029,11 @@ def wrapper(*args, **kwargs):
                     raise BackendSupportedError(
                         f"{func.__name__} does not support backend '{backend}'{extra}"
                     )
-                if not is_problem_size_supported(*args, **kwargs):
+                if not _is_problem_size_supported(**kwargs_with_defaults):
                     raise ValueError(
                         f"Problem size is not supported for {func.__name__}"
                     )
+
             return func(*args, **kwargs)

         wrapper.is_backend_supported = is_backend_supported
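The fix rests on standard-library behavior that can be shown standalone: inspect.signature(...).bind(...) followed by apply_defaults() fills in every declared default the caller omitted, so validation always sees a concrete backend value. (my_kernel below is a placeholder, not FlashInfer code.)

# Stdlib mechanism the fix builds on; my_kernel is a placeholder.
import inspect

def my_kernel(x, backend="cudnn"):
    return x * 2

sig = inspect.signature(my_kernel)
bound = sig.bind("input")     # caller omits backend
bound.apply_defaults()        # fills in backend="cudnn" from the signature
print(dict(bound.arguments))  # {'x': 'input', 'backend': 'cudnn'}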

tests/utils/test_decorators.py

Lines changed: 36 additions & 0 deletions
@@ -210,3 +210,39 @@ def my_documented_function(x, backend="backend"):
     # Verify that added methods still exist
     assert hasattr(my_documented_function, "is_backend_supported")
     assert hasattr(my_documented_function, "is_compute_capability_supported")
+
+
+def test_backend_default_parameter():
+    """Test that backend_requirement correctly uses the default backend parameter when not specified."""
+    if not torch.cuda.is_available():
+        pytest.skip("Skipping CUDA tests (no GPU available)")
+
+    # Get actual device capability
+    x = torch.randn(1, 1, device="cuda")
+    major, minor = torch.cuda.get_device_capability(x.device)
+    actual_capability = major * 10 + minor
+
+    @supported_compute_capability([80, 86, 89, 90, actual_capability])
+    def _cutlass_check(x, backend):
+        return x.shape[0] > 0
+
+    @supported_compute_capability([75, 80, 86, 89, 90, actual_capability])
+    def _cudnn_check(x, backend):
+        return x.shape[0] > 0
+
+    @backend_requirement({"cutlass": _cutlass_check, "cudnn": _cudnn_check})
+    def my_kernel(x, backend="cudnn"):
+        return x * 2
+
+    x = torch.randn(10, 10, device="cuda")
+
+    # Test that calling without backend argument uses the default "cudnn".
+    # This should work without raising an error.
+    result = my_kernel(x)
+    assert result.shape == x.shape
+    assert torch.allclose(result, x * 2)
+
+    # Test that explicitly passing a different backend also works
+    result2 = my_kernel(x, backend="cutlass")
+    assert result2.shape == x.shape
+    assert torch.allclose(result2, x * 2)