
Conversation

@nvmbreughe
Contributor

@nvmbreughe nvmbreughe commented Sep 29, 2025

📌 Description

This PR adds `is_*supported` checks for backend and compute capability support, implemented through decorators.

  1. This lets callers check support before running a kernel.
  2. The decorator also wraps the original function so that the support check runs before execution.
  3. The wrapped function additionally accepts an optional `skip_check` parameter. A quick measurement shows only minimal overhead (14.51 s without checks vs. 14.58 s with checks for all of `test_mm_fp4`), so we should benchmark further to judge how useful this parameter is.

Example:
*(screenshot of example usage omitted)*
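As a rough sketch of the pattern described above (all names here — `supports_backends`, `_check`, `mm` — are hypothetical illustrations, not the actual FlashInfer API):

```python
import functools

def supports_backends(check_fn):
    # Hypothetical decorator: attach a support check to a kernel entrypoint.
    # check_fn raises on unsupported configurations.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, skip_check=False, **kwargs):
            # Run the support check before the kernel unless the caller
            # explicitly opts out.
            if not skip_check:
                check_fn(*args, **kwargs)
            return fn(*args, **kwargs)
        # Expose the check so callers can probe support without running.
        wrapper.is_supported = check_fn
        return wrapper
    return decorator

def _check(x, backend="cudnn"):
    if backend not in ("cudnn", "trtllm"):
        raise ValueError(f"unsupported backend: {backend}")

@supports_backends(_check)
def mm(x, backend="cudnn"):
    return 2 * x
```

With this sketch, `mm(3)` runs the check and then the kernel, `mm(3, skip_check=True)` skips the check, and `mm(3, backend="cpu")` raises before any work is done.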

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

@nvjullin
Contributor

nvjullin commented Oct 1, 2025

The checks currently live far away from the implementation, and keeping them consistent with each other can eventually become a maintenance problem. The conditional checks are also quite tricky to get right. For example, it's not easy to tell whether the mxfp4 checks are correct.

    if not use_nvfp4 and block_size != 32:
        raise ValueError("mxfp4 supports block_size = 32.")

    if backend != "cudnn" and not use_nvfp4:
        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")

Shouldn't the checks be reordered to avoid confusing error messages?

  1. A user tries trtllm + block_size=16 and is rejected with the "block_size = 32" error.
  2. The user then tries trtllm + block_size=32 and is rejected with the "only cudnn is supported" error.
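One possible reordering, sketched against the quoted checks (variable names follow the quoted code; the wrapper function `check_fp4_args` is invented for illustration):

```python
# Sketch: check backend support for mxfp4 before block_size, so a trtllm
# user sees the backend restriction first rather than a block_size error.
def check_fp4_args(backend, use_nvfp4, block_size):
    if not use_nvfp4 and backend != "cudnn":
        raise ValueError("Only cudnn FP4 GEMM supports mxfp4 quantization.")
    if not use_nvfp4 and block_size != 32:
        raise ValueError("mxfp4 only supports block_size = 32.")
```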

Instead of having one top-level `supports_backends`, perhaps consider a two-level design:

  1. A local `requirement` decorator, with a requirement written for each backend entrypoint
  2. A top-level `backend_requirement` that composes the per-backend requirements

For example:

def cudnn_gemm_fp4_requirement(
    # ...
):
    if (
        not use_nvfp4
        and _match_sm_version(a.device, ["120"])
        and cudnn.backend_version() < 91400
    ):
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
        )

    _check_cudnn_fp4_availability()
    # ...


@requirement(cudnn_gemm_fp4_requirement, capability=["100", "101", "102"])
def execute_cudnn_gemm_fp4_graph(
    # ...
):
    ...


@backend_requirement({
    "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
    "trtllm": ...,  # ...
})
def mm_fp4(
    # ...
):
    ...
This also means that all requirements are enforced to be local to the backend and won't affect each other.
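The two-level design could be realized with a pair of small decorators. The following is a hedged, minimal sketch with invented names (`requirement`, `backend_requirement`, `cudnn_check`, `execute_cudnn`, `mm`), not the actual implementation:

```python
import functools

def requirement(check_fn):
    # Bind a local support check to one backend entrypoint and expose it
    # on the wrapped function for top-level composition.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            check_fn(*args, **kwargs)
            return fn(*args, **kwargs)
        wrapper.requirement = check_fn
        return wrapper
    return decorator

def backend_requirement(checks):
    # Dispatch to the matching per-backend check before calling the kernel.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, backend, **kwargs):
            checks[backend](*args, **kwargs)
            return fn(*args, backend=backend, **kwargs)
        return wrapper
    return decorator

def cudnn_check(x):
    if x < 0:
        raise ValueError("cudnn path requires x >= 0")

@requirement(cudnn_check)
def execute_cudnn(x):
    return x + 1

@backend_requirement({"cudnn": execute_cudnn.requirement})
def mm(x, backend):
    if backend == "cudnn":
        return execute_cudnn(x)
```

Each backend's requirement stays next to its entrypoint, and the top level only composes the exposed `.requirement` attributes.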

@nvmbreughe
Contributor Author

Instead of having one top-level `supports_backends`, perhaps consider a two-level design:

  1. A local `requirement` decorator, with a requirement written for each backend entrypoint
  2. A top-level `backend_requirement` that composes the per-backend requirements

Thank you @nvjullin for the excellent suggestion.
I think these are two discussions:

  1. Separate the support checks
  2. Separate the execution routines

While both are valid points, we are prioritizing separating the checks for now. Not all APIs are as clean to separate for (2) at the moment, and there is a plan for a more OO Backend class (@Anerudhan). That overlaps somewhat with the support checks, as eventually we would be able to do something like `cudnn_backend->check_mmfp4_support()`.

So, as an intermediate step and to get tighter checks in, we could do something like this:

@supported_compute_capability(["100", "101", "102"])
def cudnn_gemm_fp4_requirement(
    # ...
):
    if (
        not use_nvfp4
        and _match_sm_version(a.device, ["120"])
        and cudnn.backend_version() < 91400
    ):
        raise LibraryError(
            "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120 with cuDNN backend version < 9.14.0."
        )

    _check_cudnn_fp4_availability()
    # ...


@backend_requirement(
    {
        "cudnn": execute_cudnn_gemm_fp4_graph.requirement,
        "trtllm": ...,  # ...
    },
    common_check=common_fp4_checks,  # to be called by all backend checks
)
def mm_fp4(
    # ...
):
    if backend == "cudnn":
        ...  # cudnn path
    elif backend == "trtllm":
        ...  # trtllm path
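A minimal sketch of how the compute-capability gate and a shared common check could compose; the `sm` keyword, the check bodies, and the function names here are illustrative assumptions, not the real API:

```python
import functools

def supported_compute_capability(caps):
    # Hypothetical gate: reject devices whose SM version is not listed
    # before the backend-specific requirement body runs.
    def decorator(check_fn):
        @functools.wraps(check_fn)
        def wrapper(*args, sm, **kwargs):
            if sm not in caps:
                raise ValueError(f"requires SM in {caps}, got SM{sm}")
            return check_fn(*args, sm=sm, **kwargs)
        return wrapper
    return decorator

def common_fp4_checks(*, block_size, sm):
    # Shared checks run ahead of every backend-specific check.
    if block_size not in (16, 32):
        raise ValueError("block_size must be 16 or 32")

@supported_compute_capability(["100", "101", "102"])
def cudnn_fp4_requirement(*, block_size, sm):
    common_fp4_checks(block_size=block_size, sm=sm)
    # ... backend-specific checks would follow here
```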

@nvjullin
Contributor

While both are valid points, we are prioritizing separating the checks for now. Not all APIs are as clean to separate for (2) at the moment, and there is a plan for a more OO Backend class (@Anerudhan). That overlaps somewhat with the support checks, as eventually we would be able to do something like `cudnn_backend->check_mmfp4_support()`.

I wasn't aware, thanks for the info. LGTM.

@nvmbreughe
Contributor Author

While both are valid points, we are prioritizing separating the checks for now. Not all APIs are as clean to separate for (2) at the moment, and there is a plan for a more OO Backend class (@Anerudhan). That overlaps somewhat with the support checks, as eventually we would be able to do something like `cudnn_backend->check_mmfp4_support()`.

I wasn't aware, thanks for the info. LGTM.

Thank you for the excellent suggestions, @nvjullin

@nvmbreughe nvmbreughe marked this pull request as ready for review October 13, 2025 17:07
@nvmbreughe nvmbreughe requested a review from sricketts October 13, 2025 17:07
@nvmbreughe
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !79 has been created, and the CI pipeline #36524696 is currently running. I'll report back once the pipeline job completes.

@nvmbreughe nvmbreughe enabled auto-merge (squash) October 13, 2025 22:44
Collaborator

@aleozlx aleozlx left a comment


looks like a good step forward

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #36524696: 13/17 passed

Collaborator

@yzh119 yzh119 left a comment


LGTM

Member

@sricketts sricketts left a comment


Overall LGTM. Added one suggestion.

@nvmbreughe nvmbreughe merged commit d728bcd into flashinfer-ai:main Oct 14, 2025
3 checks passed
@sricketts sricketts mentioned this pull request Oct 14, 2025
@bkryu bkryu mentioned this pull request Oct 30, 2025
yzh119 pushed a commit that referenced this pull request Oct 30, 2025

## 📌 Description

In #1809 we previously added a compute-capability-based support check
for `mm_fp4`.

However, we missed enabling SM121 for backend = `cudnn` and `cutlass`.
Additionally, we marked `trtllm` as supported on SM120 when it is not.

This PR fixes both. Example benchmark and pytest runs on SM121 after the
fix:
```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: 
    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    
  warnings.warn(
[PERF] cudnn          :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass        :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py 
====================================================================================================================== test session starts ======================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items     
...
======================================================================================================================= warnings summary ========================================================================================================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: 
      Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
      Minimum and Maximum cuda capability supported by this version of PyTorch is
      (8.0) - (12.0)
      
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================= 450 passed, 2790 skipped, 1 warning in 8.24s ==========================================================================================================


```


## 🔍 Related Issues


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes




## Summary by CodeRabbit

* **New Features**
* Expanded hardware compatibility by adding support for newer NVIDIA GPU
architectures.
* FP4 quantized operations now available across multiple backends on
supported devices.
