[fmha-v2] Remove H2D transfer for bmm2 scale, enabling cuda graphs #2866

akhilg-nv wants to merge 2 commits into flashinfer-ai:main
Conversation
Code Review
This pull request addresses a host-to-device (H2D) transfer by replacing `torch.tensor([scalar])` with `torch.full()` to improve performance and CUDA graph compatibility.

However, the same function, `_create_scale_bmm2_d_tensor`, still contains H2D transfers in the branches for `float16` and `bfloat16` data types. The item assignment `result.view(...)[0] = scale_bmm2` on a device tensor triggers an H2D copy, which can also break CUDA graphs.

To make the function fully free of H2D transfers, I recommend updating the other branches as well.
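As a minimal illustration of the difference (using CPU as a stand-in device here; the construction pattern is the same when `device` is a GPU):

```python
import torch

device = "cpu"  # stands in for "cuda"; the construction pattern is identical
scale_bmm2 = 0.125

# Pattern that triggers an H2D copy when the device is a GPU: the Python
# float is staged on the host, then copied into the device tensor by the
# scalar item assignment.
host_then_copy = torch.zeros(1, dtype=torch.float32, device=device)
host_then_copy[0] = scale_bmm2

# Graph-friendly pattern: torch.full writes the value on the target
# device directly, with no host-side staging copy.
device_fill = torch.full((1,), scale_bmm2, dtype=torch.float32, device=device)

assert torch.equal(host_then_copy, device_fill)
```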
For `float16` (lines 114-116):

```python
# Before
result = torch.zeros(1, dtype=torch.int32, device=device)
result.view(torch.float16)[0] = scale_bmm2
return result
```

```python
# Suggested change
return (
    torch.full((1,), scale_bmm2, dtype=torch.float16, device=device)
    .view(torch.uint16)
    .to(torch.int32)
)
```

For `bfloat16` (lines 119-121):

```python
# Before
result = torch.zeros(1, dtype=torch.int32, device=device)
result.view(torch.bfloat16)[0] = scale_bmm2
return result
```

```python
# Suggested change
return (
    torch.full((1,), scale_bmm2, dtype=torch.bfloat16, device=device)
    .view(torch.uint16)
    .to(torch.int32)
)
```

Applying these suggestions would make the function fully robust for CUDA graphing, completing the goal of this pull request.
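For intuition about what the `.view(torch.uint16).to(torch.int32)` reinterpretation produces, here is a dependency-free sketch using the standard-library `struct` module; the helper names are illustrative, not part of FlashInfer:

```python
import struct

def float16_bits(scale: float) -> int:
    """Raw IEEE-754 half-precision bit pattern of `scale`, zero-extended
    to a Python int -- what a uint16 view widened to int32 would hold."""
    (bits,) = struct.unpack("<H", struct.pack("<e", scale))
    return bits

def bfloat16_bits(scale: float) -> int:
    """bfloat16 is the top 16 bits of the float32 encoding (truncated
    here; PyTorch's cast rounds, which can differ in the last bit)."""
    (f32_bits,) = struct.unpack("<I", struct.pack("<f", scale))
    return f32_bits >> 16

assert float16_bits(1.0) == 0x3C00   # 1.0 in float16
assert bfloat16_bits(1.0) == 0x3F80  # 1.0 in bfloat16
```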
Force-pushed from 6b800a0 to bdf2911
```diff
- result.view(torch.float16)[0] = scale_bmm2
- return result
+ return (
+     torch.full((1,), scale_bmm2, dtype=torch.float16, device=device)
```
I think this still causes a tiny fill kernel.

If we want to eliminate this kernel as well, the solution would be to accept `bmm1_scale` and `bmm2_scale` as a `torch.Tensor`, so that the framework (like SGLang) can provide the scales as device tensors directly (and the framework can cache them across decoding steps).
See this as an example: https://github.com/akhilg-nv/flashinfer/blob/bdf29115facde5097b050c5ffdf60f0eae9826f9/flashinfer/prefill.py#L3725-L3726
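A hypothetical sketch of that caller-side caching; names like `ScaleCache` are illustrative and not part of the SGLang or FlashInfer APIs:

```python
import torch

class ScaleCache:
    """Build each scale tensor once and reuse it across decode steps,
    so no fill kernel or H2D copy happens inside the captured graph."""

    def __init__(self):
        self._cache = {}

    def get(self, scale: float, dtype=torch.float32, device="cpu"):
        key = (scale, dtype, str(device))
        if key not in self._cache:
            # One-time construction; subsequent calls return the same tensor.
            self._cache[key] = torch.full((1,), scale, dtype=dtype, device=device)
        return self._cache[key]

cache = ScaleCache()
t1 = cache.get(0.5)
t2 = cache.get(0.5)
assert t1 is t2  # same tensor object reused across steps
```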
Correct, Jimmy has a draft PR up that allows the scale to be `Union[float, torch.Tensor]`. I think we will still want to keep the logic in this PR for the case where the input is a float, but perhaps it would be better to force it to be a tensor?
📌 Description
Removes a host-to-device transfer when setting the bmm2 scale, enabling CUDA graphs. Potentially, we could also extend the API to accept a tensor for this scale, so that it can be computed elsewhere and passed in as a constant.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- Installed the hooks with `pre-commit install`.
- Ran `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests pass (`unittest`, etc.).

Reviewer Notes