feat: BF16 GEMM using cuDNN backend #2376
Conversation
Signed-off-by: raayandhar <raayan.dhar@gmail.com>
📝 Walkthrough

Adds cuDNN-backed bfloat16 MM/BMM execution paths and runners, integrates "cudnn" and "auto" into backend selection and heuristics, validates optional outputs and bias/PDL constraints, normalizes 2D/3D shapes/strides for cuDNN graphs, and exposes cuDNN graph build/execute helpers alongside the existing Cutlass/TGV flows.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant Selector as Backend Selector
    participant Validator as Input Validator
    participant CuGraph as cuDNN Graph Builder
    participant CuExec as cuDNN Executor
    participant Cutlass as CUTLASS/TGV Executor
    participant Result as Result Tensor
    User->>Selector: call mm_bf16/bmm_bf16(..., backend="auto"/"cudnn"/"cutlass")
    Selector->>Validator: validate inputs (shapes, device, out dtype), bias/PDL constraints
    Validator-->>Selector: validation OK / reject
    alt select cuDNN
        Selector->>CuGraph: build/cache graph (normalize 2D→3D shapes/strides)
        CuGraph->>CuExec: provide graph/plans and workspace
        CuExec->>Result: execute cuDNN BF16 GEMM -> out
    else select CUTLASS/TGV
        Selector->>Cutlass: call CUTLASS/TGV runner
        Cutlass->>Result: execute CUTLASS BF16 GEMM -> out
    end
    Result->>User: return output tensor
```
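The 2D→3D normalization step mentioned in the diagram can be sketched in plain Python. The helper name and tuple conventions below are assumptions for illustration, not the PR's actual code: a 2D MM operand gets a unit batch dimension so it flows through the same 3D cuDNN graph layout as a BMM operand.

```python
def normalize_to_3d(shape, stride):
    """Hypothetical sketch: prepend a unit batch dim so 2D MM and 3D BMM
    operands share one cuDNN graph layout (names are assumptions)."""
    if len(shape) == 2:
        m, n = shape
        sm, sn = stride
        # Batch stride spans the whole matrix; for a row-major (m, n)
        # operand with stride (n, 1) this is m * n.
        return (1, m, n), (m * sm, sm, sn)
    return tuple(shape), tuple(stride)
```

For a contiguous row-major 4x8 matrix this yields shape `(1, 4, 8)` with strides `(32, 8, 1)`, while 3D inputs pass through unchanged.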
Summary of Changes

Hello @raayandhar, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the BF16 batched GEMM functionality by integrating the cuDNN library as a new backend option, providing alternative execution paths for matrix multiplication that may improve performance or broaden compatibility. The changes also include output tensor validation and a flexible backend selection mechanism, making the GEMM operations more reliable and adaptable.
Test results:
Code Review
This pull request adds a cuDNN backend for BF16 batched GEMM, which is a great feature. The changes are well-structured, including the necessary requirement checks, runner implementation, and test coverage.
My review focuses on a few key areas:
- Enabling Autotuning for cuDNN: The new cuDNN runner currently doesn't support autotuning over different tactics. I've provided a set of related suggestions to fully enable this, which will allow the autotuner to find the best-performing cuDNN execution plan.
- Code Simplification: I've identified a redundant logic block for backend selection that can be simplified for better readability and maintainability.
- Code Style in Tests: A minor suggestion to avoid shadowing a Python built-in function in the test file.
Overall, this is a solid contribution. Addressing these points will make the new backend more robust and performant.
I tried getting cuDNN to work with … Sorry that this PR took me a while to open; I was pretty sick. Once we are happy and land this, I will add support for benchmarking the BF16 GEMMs like I mentioned earlier.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
446-459: Guard auto-heuristic against missing cuDNN frontend.
`_heuristic_func_bmm_bf16` can select `"cudnn"` even when the cuDNN Python frontend isn't installed, causing runtime failures when `backend="auto"`. The parallel function `_heuristic_func_bmm_fp8` already gates this with `CUDNN_AVAILABLE`. Apply the same guard:

🛠️ Proposed fix

```diff
-        if "cudnn" in suitable_backends:
-            heuristic_backends.append("cudnn")
+        if CUDNN_AVAILABLE and "cudnn" in suitable_backends:
+            heuristic_backends.append("cudnn")
```
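As a minimal, self-contained sketch of the guarded heuristic (the flag and function names below mirror the review comment but are stand-ins, not the actual `gemm_base.py` code):

```python
CUDNN_AVAILABLE = False  # stand-in; the real flag is set when the cuDNN frontend imports

def heuristic_backends_bf16(suitable_backends):
    # Only offer "cudnn" to the autotuner when the frontend is importable,
    # mirroring the guard proposed above; "cutlass" is always eligible.
    picks = []
    if CUDNN_AVAILABLE and "cudnn" in suitable_backends:
        picks.append("cudnn")
    if "cutlass" in suitable_backends:
        picks.append("cutlass")
    return picks
```

With the frontend absent, `"cudnn"` is silently dropped from the candidate list instead of failing at execution time.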
🤖 Fix all issues with AI agents
In `@tests/gemm/test_bmm_bf16.py`:
- Around line 14-32: In test_bmm_bf16 add an explicit import guard for cuDNN:
when backend == "cudnn" call pytest.importorskip("cudnn") before invoking
bmm_bf16.is_backend_supported so the test is skipped early if the cuDNN frontend
package is not installed; modify the test_bmm_bf16 function to perform this
import-or-skip check using the existing backend variable and then proceed to
call is_backend_supported as before.
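The suggested guard could look like the following sketch; the helper name and structure are assumptions, extracted from the test body for illustration:

```python
import pytest

def cudnn_import_guard(backend: str) -> None:
    # Skip the current test early if the cuDNN Python frontend package
    # is not installed; other backends proceed unaffected.
    if backend == "cudnn":
        pytest.importorskip("cudnn")
```

Calling this at the top of `test_bmm_bf16` before `bmm_bf16.is_backend_supported` turns a hard import failure into a clean test skip.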
Signed-off-by: raayandhar <raayan.dhar@gmail.com>
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tests/gemm/test_mm_bf16.py`:
- Around line 32-38: The test currently only checks compute capability via
mm_bf16.is_backend_supported but doesn't verify library availability; import
CUDNN_AVAILABLE from flashinfer.gemm.gemm_base and add an explicit skip when
backend == "cudnn" and not CUDNN_AVAILABLE to avoid runtime RuntimeError from
_check_cudnn_availability(); keep the existing skip that blocks bias/pdl when
cuDNN is selected.
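A sketch of the requested skip; `CUDNN_AVAILABLE` here is a local stand-in for the flag exported from `flashinfer.gemm.gemm_base`, and the helper name is an assumption:

```python
import pytest

CUDNN_AVAILABLE = False  # stand-in; the real flag lives in flashinfer.gemm.gemm_base

def maybe_skip_cudnn(backend: str) -> None:
    # Skip instead of letting _check_cudnn_availability() raise a RuntimeError
    # when the cuDNN frontend library is not installed.
    if backend == "cudnn" and not CUDNN_AVAILABLE:
        pytest.skip("cuDNN Python frontend not installed")
```

This complements, rather than replaces, the existing compute-capability check via `mm_bf16.is_backend_supported`.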
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
185-186: Remove duplicate compute-capability decorator.
Two identical `@supported_compute_capability([100, 103])` decorators stack on `_cutlass_mm_bf16_requirement`. Keep one for clarity.

♻️ Suggested change

```diff
-@supported_compute_capability([100, 103])
 @supported_compute_capability([100, 103])
 def _cutlass_mm_bf16_requirement(
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@flashinfer/gemm/gemm_base.py`:
- Around lines 185-186: Remove the duplicate decorator so only one `@supported_compute_capability([100, 103])` remains directly above the target definition. Specifically, delete one of the two identical stacked `@supported_compute_capability([100, 103])` lines before the subsequent function declaration, leaving a single decorator instance.
SM103 results:

BMM:

MM:

Have not tested on B200 (SM100) yet with the most recent changes. The prices to rent are super high today.
```python
    )
    if not bmm_bf16.is_backend_supported(backend, compute_capability_number):
        pytest.skip(f"{backend} backend not supported on current compute capability.")
    # cuDNN on SM103 does not support bf16 input -> fp16 output
```
I don't know why, but for some reason on SM103 this test was failing (when we had an fp16 out dtype). This was with my initial changes, which I had previously tested on B200 (SM100) and which were passing (scroll to the top of the PR to see the BMM tests passing). So I'm not sure why this is happening; maybe someone knows.
Did you get a "not supported" error or a result-mismatch error? If so, can we file this as a GitHub issue and mention it in the code as a comment?
IIRC it was a "cuDNN can't find an execution plan" error, but I can re-test later today and find out; it's possible it has to do with my implementation.
I see. Try exporting the env vars `CUDNN_LOGLEVEL_DBG=2 CUDNN_LOGDEST_DBG=somefile.log`; this may capture a more detailed error reason if cuDNN is not able to come up with an execution plan.
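As a sketch, the logging env vars can also be set from Python before cuDNN loads; the log file name and the commented repro command are assumptions:

```python
import os

# Must be set before the cuDNN library is first loaded in the process.
os.environ["CUDNN_LOGLEVEL_DBG"] = "2"             # more verbose error/warning logging
os.environ["CUDNN_LOGDEST_DBG"] = "cudnn_dbg.log"  # write the log to a file

# Then re-run the failing test with the env vars in place, e.g.:
# subprocess.run(["python", "-m", "pytest", "tests/gemm/test_bmm_bf16.py", "-k", "cudnn"])
```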
All the failing errors so far seem to be about "…"
/bot run |
[FAILED] Pipeline #42366483: 9/20 passed
Test results on B200:
There don't seem to be relevant errors in the bot-run pipeline.
cc @dhiraj113 - new gemm op code owner for FI
@sricketts @yzh119 I'm not sure if we discussed this yet: shall we give op owners relatively long-term approval rights (as opposed to rights based on recent check-ins) to help streamline PR processing, in alignment with the staffing adjustment? Thanks. If I look at the "Merging is blocked" list of code owners, they are not the best options, and I hesitate to ping those who may already have so much ownership that this would cause distraction.
Agreed -- created #2417 to track and discuss more.
📌 Description
Adds the cuDNN backend to BF16 GEMM.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks before submitting with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes