Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flashinfer/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1834,7 +1834,7 @@ def _trtllm_gemm_fp4_requirement(
return True


@supported_compute_capability([100, 103, 120, 121])
@supported_compute_capability([100, 103, 110, 120, 121])
def _cutlass_gemm_fp4_requirement(
a: torch.Tensor,
b: torch.Tensor,
Expand Down
19 changes: 17 additions & 2 deletions flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch.version
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version
import inspect

from .jit.spdlog import gen_spdlog_module

Expand Down Expand Up @@ -950,6 +951,19 @@ def backend_requirement(
"""

def decorator(func):
def get_backend(args, kwargs):
# backend may not be specified, but could have a default value
sig = inspect.signature(func)
backend_parameter = sig.parameters.get("backend")
if (
backend_parameter
and backend_parameter.default != inspect.Parameter.empty
):
backend = kwargs.get("backend", backend_parameter.default)
else:
backend = kwargs.get("backend")
return backend
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new get_backend function is a good improvement to handle default values for the backend parameter. However, it only considers keyword arguments (kwargs) and ignores positional arguments (args), which are passed to it but not used. This can lead to incorrect behavior when backend is provided as a positional argument.

A more robust approach is to use inspect.signature().bind() to correctly resolve the backend argument from both positional and keyword arguments, while also handling default values gracefully. This ensures the correct backend is always identified.

Suggested change
def get_backend(args, kwargs):
# backend may not be specified, but could have a default value
sig = inspect.signature(func)
backend_parameter = sig.parameters.get("backend")
if (
backend_parameter
and backend_parameter.default != inspect.Parameter.empty
):
backend = kwargs.get("backend", backend_parameter.default)
else:
backend = kwargs.get("backend")
return backend
def get_backend(args, kwargs):
# backend may not be specified, but could have a default value
sig = inspect.signature(func)
try:
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
return bound_args.arguments.get("backend")
except TypeError:
# Fallback for safety, though it's unlikely to be needed if the call is valid.
return kwargs.get("backend")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the function signature does not indicate a default backend, get_backend will return None. We should handle this edge case in the rest of the decorators, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fixed with apply_defaults


def is_backend_supported(backend, cc=None):
# Is this backend present?
if backend not in backend_checks:
Expand All @@ -972,7 +986,7 @@ def is_compute_capability_supported(cc):
)

def is_problem_size_supported(*args, **kwargs):
backend = kwargs.get("backend")
backend = get_backend(args, kwargs)
if backend not in backend_checks:
raise BackendSupportedError(
f"Backend '{backend}' is not supported for {func.__name__}"
Expand All @@ -985,12 +999,13 @@ def is_problem_size_supported(*args, **kwargs):

@functools.wraps(func)
def wrapper(*args, **kwargs):
backend = kwargs.get("backend")
# skip_check is an optional argument that the decorator adds to any API function.
# It prevents the performance overhead of checking.
skip_check = kwargs.pop("skip_check", False)

if not skip_check:
backend = get_backend(args, kwargs)

capability = None
# Find the first tensor argument.
# Assume all tensors are on the same device/capability.
Expand Down