Support checks PoC #1809
Conversation
The checks currently live very far away from the implementation, and keeping them consistent with each other can become a maintenance problem over time. The conditional checks are also quite tricky to get correct. For example, it's not easy to tell whether the mxfp4 checks are correct:

```python
if not use_nvfp4 and block_size != 32:
    raise ValueError("mxfp4 supports block_size = 32.")
if backend != "cudnn" and not use_nvfp4:
    raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
```

Shouldn't the checks be reordered to avoid confusing error messages?
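To make the ordering concern concrete, here is a hypothetical standalone sketch (the function name and the reordering are my own; the parameter names follow the snippet above). Checking the backend restriction before the block size means an mxfp4 call on a non-cudnn backend gets the relevant error rather than a block-size complaint:

```python
def validate_fp4_args(backend: str, use_nvfp4: bool, block_size: int) -> None:
    """Hypothetical reordered checks; not the actual FlashInfer code."""
    if not use_nvfp4:  # mxfp4 path
        # Backend restriction first: an mxfp4 call on a non-cudnn backend
        # should complain about the backend, not about block_size.
        if backend != "cudnn":
            raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
        if block_size != 32:
            raise ValueError("mxfp4 supports block_size = 32.")
```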
Instead of having one top-level check, the requirements could live next to each backend implementation. For example:

```python
def cudnn_gemm_fp4_requirement(
    # ...
):
    if (
        not use_nvfp4
        and _match_sm_version(a.device, ["120"])
        and cudnn.backend_version() < 91400
    ):
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
        )
    _check_cudnn_fp4_availability()
    # ...

@requirement(cudnn_gemm_fp4_requirement, capability=["100", "101", "102"])
def execute_cudnn_gemm_fp4_graph(
    # ...

@backend_requirement({
    "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
    "trtllm": ...,  # ...
})
def mm_fp4(
    # ...
```

This also means that all requirements are enforced to be local to the backend and won't affect each other.
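As a minimal sketch of what such a `requirement` decorator might look like (this is an assumption about the proposed mechanics, not the PR's actual implementation; `toy_check` and `toy_kernel` are illustrative names):

```python
import functools

def requirement(check_fn, capability=None):
    """Hypothetical decorator: run check_fn before the kernel and
    expose it as fn.requirement so a dispatcher can reuse it."""
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            check_fn(*args, **kwargs)  # raise early if unsupported
            return fn(*args, **kwargs)
        wrapper.requirement = check_fn
        wrapper.capability = capability
        return wrapper
    return decorate

def toy_check(x):
    if x < 0:
        raise ValueError("x must be non-negative")

@requirement(toy_check, capability=["100"])
def toy_kernel(x):
    return x * 2
```

Because the check is attached as an attribute, a higher-level dispatcher can pick it up (e.g. `toy_kernel.requirement`) without importing the backend's internals.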
Thank you @nvjullin for the excellent suggestion.
While both are valid points, we prioritize separating the checks for now. Not all APIs are as clean to separate (2) at the moment, and there is a plan for a more OO Backend class @Anerudhan. That does overlap somewhat with the support checks, as eventually we would be able to do something like `cudnn_backend->check_mmfp4_support()`. So I think as an intermediary step, and to get tighter checks in, we could do something like this:

```python
@supported_compute_capability(["100", "101", "102"])
def cudnn_gemm_fp4_requirement(
    # ...
):
    if (
        not use_nvfp4
        and _match_sm_version(a.device, ["120"])
        and cudnn.backend_version() < 91400
    ):
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
        )
    _check_cudnn_fp4_availability()
    # ...

@backend_requirement(
    {
        "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
        "trtllm": ...,  # ...
    },
    common_check=common_fp4_checks,  # to be called by all backend checks
)
def mm_fp4(
    # ...
):
    if backend == "cudnn":
        # cudnn path
    elif backend == "trtllm":
        # trtllm path
```
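A self-contained sketch of how a `backend_requirement` decorator with a `common_check` could work (an assumption about the proposed mechanics, not FlashInfer's actual code; `common_check`, `cudnn_check`, and `toy_mm` are illustrative names):

```python
import functools

def backend_requirement(backend_checks, common_check=None):
    """Hypothetical dispatch-level decorator: run the shared check,
    then the selected backend's own check, before calling fn."""
    def decorate(fn):
        @functools.wraps(fn)
        def wrapper(*args, backend, **kwargs):
            if common_check is not None:
                common_check(*args, **kwargs)
            if backend not in backend_checks:
                raise ValueError(f"Unsupported backend: {backend}")
            backend_checks[backend](*args, **kwargs)
            return fn(*args, backend=backend, **kwargs)
        return wrapper
    return decorate

def common_check(x):
    if x is None:
        raise ValueError("x is required")

def cudnn_check(x):
    if x < 0:
        raise ValueError("cudnn backend requires non-negative x")

@backend_requirement({"cudnn": cudnn_check}, common_check=common_check)
def toy_mm(x, backend):
    return f"{backend}:{x}"
```

The shared check runs once regardless of backend, so per-backend requirement functions stay local and do not need to duplicate it.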
I wasn't aware, thanks for the info. LGTM.
Thank you for the excellent suggestions, @nvjullin |
Force-pushed from 3c9f687 to 151fc7e
/bot run
aleozlx
left a comment
looks like a good step forward
[SUCCESS] Pipeline #36524696: 13/17 passed
yzh119
left a comment
LGTM
sricketts
left a comment
Overall LGTM. Added one suggestion.
## 📌 Description

In #1809 we previously added a compute-capability-based support check for `mm_fp4`. However, we missed enabling SM121 for backend = `cudnn` and `cutlass`. Additionally, we marked `trtllm` as supported on SM120 when it is not. Current PR fixes it.

Example benchmark and pytest command on SM121 after the fix:

```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
  warnings.warn(
[PERF] cudnn   :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py
=========================== test session starts ===========================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
============================ warnings summary =============================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
    warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============= 450 passed, 2790 skipped, 1 warning in 8.24s ================
```

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

## Summary by CodeRabbit

* **New Features**
  * Expanded hardware compatibility by adding support for newer NVIDIA GPU architectures.
  * FP4 quantized operations now available across multiple backends on supported devices.
📌 Description
This PR adds `is_*supported` checks for backend and compute capability, through decorators.
Example:
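(The example itself did not survive the export. As a hypothetical illustration of what a decorator-driven support query could look like, with the `is_supported` attribute and capability handling being assumptions rather than the PR's actual API:)

```python
def supported_compute_capability(capabilities):
    """Hypothetical sketch: record the supported SM versions on a
    check function and expose a non-raising is_supported() query."""
    def decorate(fn):
        # Assumed query interface; the real API may differ.
        fn.is_supported = lambda sm: sm in capabilities
        return fn
    return decorate

@supported_compute_capability(["100", "101", "102"])
def cudnn_gemm_fp4_requirement(*args, **kwargs):
    pass  # actual argument checks elided in this sketch
```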

🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes