
[fmha-v2] Remove H2D transfer for bmm2 scale, enabling cuda graphs#2866

Open
akhilg-nv wants to merge 2 commits into flashinfer-ai:main from akhilg-nv:fmha_cuda_graph

Conversation

@akhilg-nv (Contributor) commented Mar 23, 2026

📌 Description

Removes a host-to-device transfer for setting the bmm2 scale, enabling CUDA graph capture. Potentially, the API could also accept a tensor for this scale, so that it can be computed elsewhere and passed in as a constant.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Chores

    • Internal performance optimizations for tensor creation; no user-visible changes.
  • Refactor

    • Simplified and standardized tensor construction logic to improve reliability and consistency across numeric types, reducing internal complexity without impacting functionality.

@gemini-code-assist bot commented:

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request optimizes the flashinfer library by removing a host-to-device data transfer bottleneck for the bmm2_scale parameter. This change is essential for enabling the use of CUDA graphs, which allows for more efficient execution of computational graphs on the GPU, leading to improved overall performance.

Highlights

  • Performance Optimization: Eliminated a host-to-device transfer for the bmm2_scale parameter, which is crucial for enabling CUDA graphs and improving performance.


@coderabbitai bot commented Mar 23, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 33e3423d-04ef-4eed-95b1-53760c407ca1

📥 Commits

Reviewing files that changed from the base of the PR and between 6b800a0 and bdf2911.

📒 Files selected for processing (1)
  • flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/prefill.py

📝 Walkthrough

Refactored tensor construction in flashinfer/prefill.py _create_scale_bmm2_d_tensor: replaced scalar/tensor creation using torch.tensor and manual view/write with torch.full((1,), ...) for FP16/BF16 and FP32 paths, retaining device, dtype, and final .view(...).to(...) reinterpretation.

Changes

Changes by cohort:

  • Tensor creation changes — flashinfer/prefill.py: Replaced scalar/tensor creation with torch.full((1,), scale_bmm2, ...) for both the FP16/BF16 and non-FP16/BF16 branches; removed the manual zero-then-assign pattern and preserved the subsequent .view(...).to(...) reinterpretation.
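The before/after pattern described in the walkthrough can be sketched on CPU as follows (illustrative only; the real helper targets a CUDA device, where the difference matters for graph capture):

```python
import torch

# Before: allocate zeros, then write the scalar through a reinterpreted view.
# When `result` lives on the GPU, this scalar item-assignment records a
# host-to-device copy, which is incompatible with CUDA graph capture.
result = torch.zeros(1, dtype=torch.int32)
result.view(torch.float32)[0] = 0.5

# After: torch.full materializes the value directly in the target dtype and
# the bits are then reinterpreted, with no separate scalar write.
replacement = torch.full((1,), 0.5, dtype=torch.float32).view(torch.int32)

print(torch.equal(result, replacement))  # True: identical bit patterns
```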

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~2 minutes

Suggested reviewers

  • aleozlx
  • cyx-6
  • yzh119

Poem

A rabbit hops through code so bright, 🐰
Swapping tensors with a gentle pull,
One element kept, placed just right,
torch.full sings soft and full,
Small change, same outcome — tidy and cool. ✨

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check — ✅ Passed: The title clearly and specifically describes the main change: removing an H2D transfer for bmm2 scale to enable CUDA graphs, which aligns with the code modifications.
  • Description check — ✅ Passed: The description includes a clear explanation of the changes (H2D transfer removal for bmm2 scale and CUDA graphs enablement), completed pre-commit checks, and passing tests confirmation, but the 'Tests have been added or updated' checkbox is unchecked.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.


@gemini-code-assist bot left a comment

Code Review

This pull request addresses a host-to-device (H2D) transfer by replacing torch.tensor([scalar]) with torch.full() to improve performance and CUDA graph compatibility.

However, the same function, _create_scale_bmm2_d_tensor, still contains H2D transfers in the branches for float16 and bfloat16 data types. The item assignment result.view(...)[0] = scale_bmm2 on a device tensor triggers an H2D copy, which can also break CUDA graphs.

To make the function fully free of H2D transfers, I recommend updating the other branches as well.

For float16 (lines 114-116):

# Before
result = torch.zeros(1, dtype=torch.int32, device=device)
result.view(torch.float16)[0] = scale_bmm2
return result
# Suggested change
return (
    torch.full((1,), scale_bmm2, dtype=torch.float16, device=device)
    .view(torch.uint16)
    .to(torch.int32)
)

For bfloat16 (lines 119-121):

# Before
result = torch.zeros(1, dtype=torch.int32, device=device)
result.view(torch.bfloat16)[0] = scale_bmm2
return result
# Suggested change
return (
    torch.full((1,), scale_bmm2, dtype=torch.bfloat16, device=device)
    .view(torch.uint16)
    .to(torch.int32)
)

Applying these suggestions would make the function fully robust for CUDA graphing, completing the goal of this pull request.

Comment thread on flashinfer/prefill.py (suggested-change context):

result.view(torch.float16)[0] = scale_bmm2
return result
return (
torch.full((1,), scale_bmm2, dtype=torch.float16, device=device)

A contributor commented:

I think this still causes a tiny Fill kernel.

If we want to eliminate this kernel as well, the solution would be to accept bmm1_scale and bmm2_scale as torch.Tensor, so that the framework (like SGLang) can provide the scales as device tensors directly (and cache them across decoding steps).

https://github.com/akhilg-nv/flashinfer/blob/bdf29115facde5097b050c5ffdf60f0eae9826f9/flashinfer/prefill.py#L4088-L4089

See this as an example: https://github.com/akhilg-nv/flashinfer/blob/bdf29115facde5097b050c5ffdf60f0eae9826f9/flashinfer/prefill.py#L3725-L3726
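A hypothetical caller-side sketch of the suggestion above: the framework materializes each scale on the device once and reuses it across decode steps. The ScaleCache class and its methods are illustrative assumptions, not part of the FlashInfer API, and the CPU device is used so the sketch runs anywhere:

```python
import torch

class ScaleCache:
    """Cache one-element scale tensors so each distinct value is
    materialized on the device only once, outside graph capture."""

    def __init__(self, device: str = "cpu"):  # "cuda" in real use
        self.device = device
        self._cache: dict[float, torch.Tensor] = {}

    def get(self, scale: float) -> torch.Tensor:
        if scale not in self._cache:
            # The single H2D copy happens here, once per distinct value,
            # before any CUDA graph is captured.
            self._cache[scale] = torch.tensor(
                [scale], dtype=torch.float32, device=self.device
            )
        return self._cache[scale]

cache = ScaleCache()
first = cache.get(0.125)
second = cache.get(0.125)
print(first is second)  # True: the same tensor is reused across steps
```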

@akhilg-nv (Contributor, author) replied:

Correct, Jimmy has a draft PR up which allows the scale to be Union[float, torch.Tensor]. I think we will still want to keep the logic in this PR for the case where the input is a float, but perhaps it would be better to force it to be a tensor?
