
feat: BF16 GEMM using cuDNN backend#2376

Merged
aleozlx merged 10 commits into flashinfer-ai:main from raayandhar:user/rdhar/cudnn_bf16_gemm
Jan 27, 2026

Conversation

@raayandhar
Contributor

@raayandhar raayandhar commented Jan 20, 2026

📌 Description

Adds the cuDNN backend to BF16 GEMM.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added cuDNN backend and “auto” selection for BF16 MM/BMM, including cuDNN graph execution for 2D/3D and a helper to normalize 3D shapes/strides. Introduced a public runner to leverage cuDNN when available.
  • Bug Fixes
    • Strengthened validation for output tensors and problem-size checks; clearer messages for backend limits (cuDNN excludes bias/PDL; fp16 output restricted on some GPUs). Improved backend availability detection.
  • Tests
    • Parameterized tests across CUTLASS/cuDNN; skip unsupported compute capabilities and incompatible scenarios (cuDNN with bias/PDL, SM103 with fp16 output).


Signed-off-by: raayandhar <raayan.dhar@gmail.com>
@coderabbitai
Contributor

coderabbitai bot commented Jan 20, 2026

📝 Walkthrough

Walkthrough

Adds cuDNN-backed bfloat16 MM/BMM execution paths and runners, integrates "cudnn" and "auto" into backend selection and heuristics, validates optional outputs and bias/PDL constraints, normalizes 2D/3D shapes/strides for cuDNN graphs, and exposes cuDNN graph build/execute helpers alongside existing Cutlass/TGV flows.

Changes

Cohort / File(s) Summary
BF16 GEMM/BMM Core
flashinfer/gemm/gemm_base.py
Added cuDNN BF16 requirement checks (_cudnn_mm_bf16_requirement, _cudnn_bmm_bf16_requirement), expanded backend literals to include "cudnn" and "auto", updated problem-size checks and heuristics, added _get_bf16_3d_shape_stride, cuDNN graph helpers (build_cudnn_gemm_bf16_graph, execute_cudnn_gemm_bf16_graph), _cudnn_gemm_bf16 and _cudnn_gemm_bf16_runner, and wired cuDNN into mm_bf16/bmm_bf16 selection.
Tests
tests/gemm/test_bmm_bf16.py, tests/gemm/test_mm_bf16.py
Parameterized tests to include "cudnn", added runtime skips for unsupported compute capabilities and cuDNN limitations (bias/PDL and SM103 fp16 output), and forwarded backend argument to bmm/mm BF16 test calls.
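The 2D→3D normalization mentioned in the walkthrough (lifting a plain matrix to the batched shape/stride form a cuDNN graph tensor expects) can be sketched as follows. The function name mirrors `_get_bf16_3d_shape_stride` from the summary, but the body and stride convention here are assumptions, not the PR's actual code:

```python
def get_bf16_3d_shape_stride(shape, stride):
    """Normalize a 2D (m, n) or 3D (b, m, n) tensor description to 3D.

    cuDNN graph tensors carry explicit shapes and strides, so a plain
    matmul operand is lifted to a batch of one. Illustrative sketch only.
    """
    if len(shape) == 2:
        m, n = shape
        sm, sn = stride
        # A single matrix becomes a batch of one; the batch stride spans
        # the whole matrix, whichever of the two layouts is the major one.
        batch_stride = m * sm if sm >= sn else n * sn
        return (1, m, n), (batch_stride, sm, sn)
    if len(shape) == 3:
        return tuple(shape), tuple(stride)
    raise ValueError(f"expected 2D or 3D shape, got {len(shape)}D")
```

For example, a row-major 4x8 matrix with strides (8, 1) normalizes to shape (1, 4, 8) with strides (32, 8, 1).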

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Selector as Backend Selector
    participant Validator as Input Validator
    participant CuGraph as cuDNN Graph Builder
    participant CuExec as cuDNN Executor
    participant Cutlass as CUTLASS/TGV Executor
    participant Result as Result Tensor

    User->>Selector: call mm_bf16/bmm_bf16(..., backend="auto"/"cudnn"/"cutlass")
    Selector->>Validator: validate inputs (shapes, device, out dtype), bias/PDL constraints
    Validator-->>Selector: validation OK / reject
    alt select cuDNN
        Selector->>CuGraph: build/cache graph (normalize 2D→3D shapes/strides)
        CuGraph->>CuExec: provide graph/plans and workspace
        CuExec->>Result: execute cuDNN BF16 GEMM -> out
    else select CUTLASS/TGV
        Selector->>Cutlass: call CUTLASS/TGV runner
        Cutlass->>Result: execute CUTLASS BF16 GEMM -> out
    end
    Result->>User: return output tensor

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • bkryu
  • nvmbreughe
  • yongwww
  • jimmyzho

Poem

🐰 I hopped through shapes and strides with glee,
I built small graphs for cuDNN and me.
Cutlass or cudnn, a path to choose,
I route the bfloat16s and mind the rules. ✨

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 13.04%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title clearly and concisely summarizes the main change: adding cuDNN backend support for BF16 GEMM operations.
  • Description check ✅ Passed: the description follows the template structure with a brief explanation of changes, completed pre-commit checks, and confirmed test updates, though the Related Issues section is empty and the Reviewer Notes are blank.



@gemini-code-assist
Contributor

Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the BF16 batched GEMM functionality by integrating the cuDNN library as a new backend option. This addition aims to provide more optimized or alternative execution paths for matrix multiplication operations, potentially leading to improved performance or broader compatibility. The changes also include robust output tensor validation and a flexible backend selection mechanism, making the GEMM operations more reliable and adaptable.

Highlights

  • cuDNN Backend Integration: Added support for the cuDNN backend for BF16 batched General Matrix Multiply (GEMM) operations, providing an alternative to the existing CUTLASS backend.
  • Output Tensor Validation: Introduced comprehensive validation checks for output tensors in both mm_bf16 and bmm_bf16 functions, ensuring correct shape, device, and data type before computation.
  • Backend Selection Mechanism: Enhanced the bmm_bf16 function to allow explicit selection of 'cudnn', 'cutlass', or 'auto' backends, with 'auto' enabling dynamic selection based on available optimizations.
  • cuDNN Graph Implementation: Implemented a cuDNN graph-based approach for BF16 GEMM, including functions for building, executing, and managing cuDNN tensors and operations.
  • Expanded Test Coverage: Updated the test_bmm_bf16 suite to include tests for both 'cutlass' and 'cudnn' backends, ensuring the new cuDNN implementation functions correctly.
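The output-tensor validation highlighted above boils down to comparing the provided `out` tensor's shape, device, and dtype against what the GEMM will produce. A torch-free sketch of that kind of check (the helper and its parameter names are illustrative, not the PR's code, which would inspect `torch.Tensor` attributes directly):

```python
def validate_out_tensor(out_shape, out_device, out_dtype,
                        expected_shape, input_device, allowed_dtypes):
    """Illustrative stand-in for the out-tensor checks described above."""
    if tuple(out_shape) != tuple(expected_shape):
        raise ValueError(f"out has shape {out_shape}, expected {expected_shape}")
    if out_device != input_device:
        raise ValueError(f"out is on {out_device}, inputs are on {input_device}")
    if out_dtype not in allowed_dtypes:
        raise ValueError(f"out dtype {out_dtype} not in {sorted(allowed_dtypes)}")
```

Failing any of the three checks raises before any kernel is launched, which gives the clearer error messages the summary mentions.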



@raayandhar
Contributor Author

Test results:

(flashinfer) root@pig-of-classic-anger:~/flashinfer# pytest tests/gemm/test_bmm_bf16.py
========================================== test session starts ==========================================
platform linux -- Python 3.10.12, pytest-9.0.2, pluggy-1.6.0
rootdir: /root/flashinfer
configfile: pytest.ini
collected 64 items

tests/gemm/test_bmm_bf16.py ................................................................      [100%]

=========================================== warnings summary ============================================
.venv/lib/python3.10/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174
  /root/flashinfer/.venv/lib/python3.10/site-packages/tvm_ffi/_optional_torch_c_dlpack.py:174: UserWarning: Failed to JIT torch c dlpack extension, EnvTensorAllocator will not be enabled.
  We recommend installing via `pip install torch-c-dlpack-ext`
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=============================== 64 passed, 1 warning in 75.64s (0:01:15) ================================
(flashinfer) root@pig-of-classic-anger:~/flashinfer#

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds a cuDNN backend for BF16 batched GEMM, which is a great feature. The changes are well-structured, including the necessary requirement checks, runner implementation, and test coverage.
My review focuses on a few key areas:

  1. Enabling Autotuning for cuDNN: The new cuDNN runner currently doesn't support autotuning over different tactics. I've provided a set of related suggestions to fully enable this, which will allow the autotuner to find the best-performing cuDNN execution plan.
  2. Code Simplification: I've identified a redundant logic block for backend selection that can be simplified for better readability and maintainability.
  3. Code Style in Tests: A minor suggestion to avoid shadowing a Python built-in function in the test file.

Overall, this is a solid contribution. Addressing these points will make the new backend more robust and performant.

@raayandhar
Copy link
Copy Markdown
Contributor Author

I tried getting cuDNN to work with mm_bf16, but it seems like it wasn't able to find a valid plan. I'm looking into it, but if there are any cuDNN experts who know what I'm doing wrong, I'd appreciate it. I also noticed that mm_fp8 did not have a cuDNN backend but bmm_fp8 did. Not sure if it comes from the same problem (if it is a problem at all).

Sorry that this PR took me a while to open, I was pretty sick.

Once we are happy and land this, I will add support for benchmarking the BF16 GEMMs like I mentioned earlier.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

446-459: Guard auto-heuristic against missing cuDNN frontend.

_heuristic_func_bmm_bf16 can select "cudnn" even when the cuDNN Python frontend isn't installed, causing runtime failures when backend="auto". The parallel function _heuristic_func_bmm_fp8 already gates this with CUDNN_AVAILABLE. Apply the same guard:

🛠️ Proposed fix
-    if "cudnn" in suitable_backends:
-        heuristic_backends.append("cudnn")
+    if CUDNN_AVAILABLE and "cudnn" in suitable_backends:
+        heuristic_backends.append("cudnn")
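The guarded selection the fix describes can be sketched in isolation; the function below is illustrative (the real code lives in `_heuristic_func_bmm_bf16`), with the `CUDNN_AVAILABLE` flag passed in as a parameter for testability:

```python
def order_heuristic_backends(suitable_backends, cudnn_available):
    """Return the backends the heuristic may pick, skipping cuDNN when
    the cuDNN Python frontend is not installed (sketch of the proposed fix)."""
    heuristic_backends = []
    # Without this guard, backend="auto" could route to cuDNN and fail
    # at runtime on machines where the frontend package is missing.
    if cudnn_available and "cudnn" in suitable_backends:
        heuristic_backends.append("cudnn")
    if "cutlass" in suitable_backends:
        heuristic_backends.append("cutlass")
    return heuristic_backends
```

With the guard in place, an environment without the cuDNN frontend simply falls back to CUTLASS instead of raising.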
🤖 Fix all issues with AI agents
In `@tests/gemm/test_bmm_bf16.py`:
- Around line 14-32: In test_bmm_bf16 add an explicit import guard for cuDNN:
when backend == "cudnn" call pytest.importorskip("cudnn") before invoking
bmm_bf16.is_backend_supported so the test is skipped early if the cuDNN frontend
package is not installed; modify the test_bmm_bf16 function to perform this
import-or-skip check using the existing backend variable and then proceed to
call is_backend_supported as before.

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tests/gemm/test_mm_bf16.py`:
- Around line 32-38: The test currently only checks compute capability via
mm_bf16.is_backend_supported but doesn't verify library availability; import
CUDNN_AVAILABLE from flashinfer.gemm.gemm_base and add an explicit skip when
backend == "cudnn" and not CUDNN_AVAILABLE to avoid runtime RuntimeError from
_check_cudnn_availability(); keep the existing skip that blocks bias/pdl when
cuDNN is selected.
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

185-186: Remove duplicate compute-capability decorator.
Two identical @supported_compute_capability([100, 103]) decorators stack on _cutlass_mm_bf16_requirement. Keep one for clarity.

♻️ Suggested change
-@supported_compute_capability([100, 103])
 @supported_compute_capability([100, 103])
 def _cutlass_mm_bf16_requirement(

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around line 185-186: Remove the duplicate decorator so only one
@supported_compute_capability([100, 103]) remains directly above the target
definition; specifically, delete one of the two identical
@supported_compute_capability([100, 103]) lines that are stacked before the
subsequent function/class declaration to leave a single decorator instance.

Signed-off-by: raayandhar <raayan.dhar@gmail.com>
@raayandhar
Contributor Author

raayandhar commented Jan 23, 2026

SM103 results:

BMM:

(flashinfer) root@c532599d0ce4:/sgl-workspace/sglang/flashinfer# pytest tests/gemm/test_bmm_bf16.py
=================================================================== test session starts ===================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
collected 64 items

tests/gemm/test_bmm_bf16.py ................................................ssssssssssssssss                                                        [100%]

============================================================= 48 passed, 16 skipped in 2.91s ==============================================================
(flashinfer) root@c532599d0ce4:/sgl-workspace/sglang/flashinfer#

MM:

(flashinfer) root@c532599d0ce4:/sgl-workspace/sglang/flashinfer# pytest tests/gemm/test_mm_bf16.py -k "cudnn"
=================================================================== test session starts ===================================================================
platform linux -- Python 3.12.3, pytest-9.0.2, pluggy-1.6.0
rootdir: /sgl-workspace/sglang/flashinfer
configfile: pytest.ini
collected 1080 items / 720 deselected / 360 selected

tests/gemm/test_mm_bf16.py ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 33%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 74%]
sss.............................................sssssssssssssssssssssssssssssssssssssssssssss                                                       [100%]

===================================================== 45 passed, 315 skipped, 720 deselected in 1.31s =====================================================
(flashinfer) root@c532599d0ce4:/sgl-workspace/sglang/flashinfer#

Have not tested on B200 (SM100) yet with most recent changes. The prices to rent are super high today.

Quoted test context (tests/gemm/test_bmm_bf16.py):

    if not bmm_bf16.is_backend_supported(backend, compute_capability_number):
        pytest.skip(f"{backend} backend not supported on current compute capability.")
    # cuDNN on SM103 does not support bf16 input -> fp16 output
Contributor Author


I don't know why, but for some reason on SM103 this test was failing (when we had fp16 out dtype). This was with my initial changes, which I had previously tested on B200 (SM100) and which were passing (scroll to the top of the PR to see the BMM tests passing). Not sure why this is happening; maybe someone knows.

Collaborator

@aleozlx aleozlx Jan 26, 2026


adding @dhiraj113 for vis

Collaborator


Did you get a "not supported" error or a result-mismatch error? If so, can we file this as a GitHub issue and mention it in the code as a comment?

Contributor Author

@raayandhar raayandhar Jan 26, 2026


IIRC it was a cuDNN "can't find an execution plan" error, but I can re-test later today and find out; it's possible it has to do with my implementation.

Collaborator


I see; try exporting the env vars CUDNN_LOGLEVEL_DBG=2 and CUDNN_LOGDEST_DBG=somefile.log.

This may capture a more detailed error reason if cuDNN is not able to come up with an execution plan.
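For reference, a minimal repro recipe using those variables (the log filename and the follow-up commands are arbitrary choices, not part of the suggestion above):

```shell
# Enable verbose cuDNN logging before re-running the failing test;
# the log should record why no execution plan could be finalized.
export CUDNN_LOGLEVEL_DBG=2
export CUDNN_LOGDEST_DBG=cudnn_debug.log
# then re-run the failing case and inspect the log, e.g.:
#   pytest tests/gemm/test_mm_bf16.py -k "cudnn"
#   grep -i "plan" cudnn_debug.log
```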

@raayandhar
Contributor Author

cc: @aleozlx @bkryu

@raayandhar
Contributor Author

All the failing errors so far seem to be about "docker: Error response from daemon: error from registry: You have reached your unauthenticated pull rate limit."

@aleozlx aleozlx self-assigned this Jan 23, 2026
@aleozlx
Collaborator

aleozlx commented Jan 23, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !261 has been created, and the CI pipeline #42366483 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42366483: 9/20 passed

@raayandhar
Contributor Author

Test results on B200:

(flashinfer) root@38b552430284:~/flashinfer# pytest tests/gemm/test_bmm_bf16.py
========================================== test session starts ==========================================
platform linux -- Python 3.11.11, pytest-9.0.2, pluggy-1.6.0
rootdir: /root/flashinfer
configfile: pytest.ini
collected 64 items

tests/gemm/test_bmm_bf16.py ................................................................      [100%]

===================================== 64 passed in 89.73s (0:01:29) =====================================
(flashinfer) root@38b552430284:~/flashinfer# pytest tests/gemm/test_mm_bf16.py -k "cudnn"
========================================== test session starts ==========================================
platform linux -- Python 3.11.11, pytest-9.0.2, pluggy-1.6.0
rootdir: /root/flashinfer
configfile: pytest.ini
collected 1080 items / 720 deselected / 360 selected

tests/gemm/test_mm_bf16.py ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 19%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 46%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 73%]
ssssss..........................................................................................  [100%]

=========================== 90 passed, 270 skipped, 720 deselected in 15.57s ============================
(flashinfer) root@38b552430284:~/flashinfer#

Collaborator

@aleozlx aleozlx left a comment


lgtm

@aleozlx
Collaborator

aleozlx commented Jan 26, 2026

There don't seem to be relevant errors in the bot-run pipeline, so this seems good to go.
@jimmyzho would you mind helping approve it? thx

@aleozlx
Collaborator

aleozlx commented Jan 26, 2026

cc @dhiraj113 - new gemm op code owner for FI

@aleozlx
Collaborator

aleozlx commented Jan 26, 2026

@sricketts @yzh119 i'm not sure if we discussed this yet, shall we give op owners relatively long-term (as opposed to based on recent check-ins) approval rights to help streamline PR processing, in alignment with staffing adjustment? thx

if i look at "Merging is blocked" list of code owners, not the best options and i hesitate to ping those who may already have so much ownership already and may cause distraction

@sricketts
Member

i'm not sure if we discussed this yet, shall we give op owners relatively long-term (as opposed to based on recent check-ins) approval rights to help streamline PR processing, in alignment with staffing adjustment? thx

if i look at "Merging is blocked" list of code owners, not the best options and i hesitate to ping those who may already have so much ownership already and may cause distraction

Agreed -- created #2417 to track and discuss more.

Contributor

@jimmyzho jimmyzho left a comment


lgtm!

@aleozlx
Collaborator

aleozlx commented Jan 27, 2026

did git merge main to restart checks that were blocking the merge
