Commit 0aee7af
feat: Add backend='auto' to mm_fp4 and enable autotune for backend='cudnn' (#1979)
## 📌 Description
This PR:
* Introduces an `auto` backend to `mm_fp4` that can be autotuned. **It
replaces `cudnn` as the default.**
* The implementation matches the auto-backend support of `bmm_fp8`.
* Allows the `cudnn` backend to be autotuned.
* Adds unit test cases for `backend='auto'`.
Behavior of the `auto` backend:
* Examines the CUDA and cuDNN versions and calls either the `cutlass` or
`cudnn` kernel backend. The `trtllm` kernel is not considered because its
interface is not interchangeable with the other backends.
* The `auto` backend therefore only supports inputs runnable by `cutlass`
and/or `cudnn`.
* Non-autotuned behavior:
  * Constructs an ordered list of backends, (cudnn, cutlass) or (cutlass,
cudnn), where the ordering is based on previous microbenchmark study results:
    * CUDA 12 --> cutlass comes first.
    * CUDA 13 and cuDNN version < 9.15 --> cutlass comes first.
    * CUDA 13 and cuDNN version >= 9.15 --> cudnn comes first.
  * If a kernel fails its support check, it is removed from the list.
* Autotune behavior:
  * If a backend is explicitly provided --> autotunes within that backend.
Same as the previous behavior, except autotuning is now also supported for
cudnn.
  * If `backend='auto'` --> autotunes within and across backends (cudnn &
cutlass) and chooses the best config of the best backend. The `trtllm`
kernel is not considered.
* Several `mm_fp4` helper functions were refactored to enable cross-backend
autotuning, using the cross-backend autotune-enabled `bmm_fp8`
implementation as a reference.
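The non-autotuned ordering heuristic above can be sketched as follows. This is a minimal illustration only; the function names and the version encoding are assumptions for the sketch, not the actual flashinfer helpers:

```python
def order_fp4_backends(cuda_major: int, cudnn_version: tuple) -> list:
    """Hypothetical helper: order candidate mm_fp4 backends per the
    microbenchmark-derived heuristic described in this PR.

    trtllm is never considered because its interface is not
    interchangeable with the other backends.
    """
    if cuda_major >= 13 and cudnn_version >= (9, 15):
        # CUDA 13 with cuDNN >= 9.15: cudnn comes first
        return ["cudnn", "cutlass"]
    # CUDA 12, or CUDA 13 with an older cuDNN: cutlass comes first
    return ["cutlass", "cudnn"]


def pick_backend(ordered: list, supported: set) -> str:
    """Drop backends that fail the support check; use the first remaining."""
    candidates = [b for b in ordered if b in supported]
    if not candidates:
        raise ValueError("no supported FP4 GEMM backend for this input")
    return candidates[0]


# Example: CUDA 13 + cuDNN 9.15 prefers cudnn, but falls back to cutlass
# if cudnn's support check fails for the given input.
print(pick_backend(order_fp4_backends(13, (9, 15)), {"cutlass"}))
```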
### Pytest outputs
`pytest tests/gemm/test_mm_fp4.py`
* SM100 (B200) CUDA 13 & cuDNN 9.15: `900 passed, 2532 skipped in
125.19s (0:02:05)`
* SM100 (B200) CUDA 12 & cuDNN 9.15: `900 passed, 2532 skipped in
125.67s (0:02:05)`
* SM120 (RTX 5090) CUDA 13 & cuDNN 9.15: `720 passed, 2712 skipped in
76.50s (0:01:16)`
### Example microbenchmark outputs:
On SM100 (B200) CUDA 13 & cuDNN 9.15
```
flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[PERF] cudnn :: median time 0.018 ms; std 0.000 ms; achieved tflops 3797.932 TFLOPs/sec; achieved tb_per_sec 1.884 TB/sec
[PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3440.640 TFLOPs/sec; achieved tb_per_sec 1.707 TB/sec
[PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec
[PERF] auto :: median time 0.018 ms; std 0.000 ms; achieved tflops 3840.714 TFLOPs/sec; achieved tb_per_sec 1.905 TB/sec
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3237.753 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
## Autotune
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune
2025-11-11 23:43:23,715 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:25,789 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:25,790 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,251 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:26,251 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,327 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:26,327 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:26,335 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4129.171 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec
[PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3513.845 TFLOPs/sec; achieved tb_per_sec 1.743 TB/sec
[PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2613.338 TFLOPs/sec; achieved tb_per_sec 1.296 TB/sec
[PERF] auto_autotune :: median time 0.016 ms; std 0.000 ms; achieved tflops 4128.768 TFLOPs/sec; achieved tb_per_sec 2.048 TB/sec
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
2025-11-11 23:43:37,942 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:43,116 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:43:43,116 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:43:43,124 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.154 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
[PERF] auto_autotune :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.692 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
```
On SM100 (B200) CUDA 12 & cuDNN 9.15
```
flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[PERF] cudnn :: median time 0.023 ms; std 0.001 ms; achieved tflops 2975.898 TFLOPs/sec; achieved tb_per_sec 1.476 TB/sec
[PERF] cutlass :: median time 0.020 ms; std 0.000 ms; achieved tflops 3370.423 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
[PERF] trtllm :: median time 0.031 ms; std 0.000 ms; achieved tflops 2187.427 TFLOPs/sec; achieved tb_per_sec 1.085 TB/sec
[PERF] auto :: median time 0.020 ms; std 0.000 ms; achieved tflops 3371.229 TFLOPs/sec; achieved tb_per_sec 1.672 TB/sec
(py312) root@84ef83abb1b5:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[PERF] cudnn :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
## Autotune
/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck --autotune
2025-11-11 23:42:43,378 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,451 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,451 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,910 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,910 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,986 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:45,986 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:45,993 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3190.355 TFLOPs/sec; achieved tb_per_sec 1.583 TB/sec
[PERF] cutlass_autotun:: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.330 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec
[PERF] trtllm_autotune:: median time 0.026 ms; std 0.000 ms; achieved tflops 2621.440 TFLOPs/sec; achieved tb_per_sec 1.300 TB/sec
[PERF] auto_autotune :: median time 0.019 ms; std 0.000 ms; achieved tflops 3551.628 TFLOPs/sec; achieved tb_per_sec 1.762 TB/sec
flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck --autotune
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
2025-11-11 23:42:55,176 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:58,600 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
2025-11-11 23:42:58,601 - INFO - autotuner.py:256 - flashinfer.jit: [Autotuner]: Autotuning process starts ...
2025-11-11 23:42:58,608 - INFO - autotuner.py:262 - flashinfer.jit: [Autotuner]: Autotuning process ends
[PERF] cudnn_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
[PERF] auto_autotune :: median time 0.021 ms; std 0.000 ms; achieved tflops 3238.249 TFLOPs/sec; achieved tb_per_sec 1.606 TB/sec
```
On SM120 (RTX 5090) CUDA 13 & cuDNN 9.15
```
/flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --use_nvfp4 --refcheck
[INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120
[PERF] cudnn :: median time 0.058 ms; std 0.000 ms; achieved tflops 1167.143 TFLOPs/sec; achieved tb_per_sec 0.579 TB/sec
[PERF] cutlass :: median time 0.060 ms; std 0.000 ms; achieved tflops 1135.056 TFLOPs/sec; achieved tb_per_sec 0.563 TB/sec
[PERF] auto :: median time 0.058 ms; std 0.000 ms; achieved tflops 1158.952 TFLOPs/sec; achieved tb_per_sec 0.575 TB/sec
/flashinfer/benchmarks$ python3 flashinfer_benchmark.py --routine mm_fp4 --m 1024 --n 7168 --k 4608 --out_dtype bfloat16 --backends cudnn cutlass trtllm auto --use_128x4_sf_layout --refcheck
[INFO] cutlass backend does not support this configuration: ValueError: Only cudnn and auto FP4 GEMM supports mxfp4 quantization.
[INFO] trtllm backend does not support this configuration: BackendSupportedError: mm_fp4 does not support backend 'trtllm' with capability 120
[PERF] cudnn :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec
[PERF] auto :: median time 0.054 ms; std 0.000 ms; achieved tflops 1241.735 TFLOPs/sec; achieved tb_per_sec 0.616 TB/sec
```
## 🔍 Related Issues
#1722
## 🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.
### ✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.
> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).
## 🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
## Reviewer Notes
## Summary by CodeRabbit
* **New Features**
* "auto" backend selection for FP4 ops to choose backend at runtime
* cuDNN, CUTLASS and TRTLLM selectable as FP4 GEMM backends
* CUDA/cuDNN version awareness to guide auto-backend heuristics
* **Improvements**
* Runtime capability checks replace static backend lists; unsupported
backends are removed dynamically
* Heuristic-driven auto-backend selection required for automatic mode
* Expanded autotuning/warmup across backends and relaxed FP4 validation
tolerance
* **Tests**
* Tests updated and added to exercise auto-backend scenarios and relaxed
constraints
### File tree
6 files changed (+543, −381 lines): benchmarks/routines, flashinfer/gemm, tests/gemm, and tests/utils.