
Perf: support scale_a/scale_b instead of combined scale in cutlass bmm_fp8 #1491

Merged: yzh119 merged 1 commit into flashinfer-ai:main from ttyio:dev/vincenth/support-gemm-separate_scale on Aug 17, 2025

Conversation

ttyio (Collaborator) commented on Aug 14, 2025

The previous CUTLASS implementation required combining scale_a and scale_b in Python, so an extra ATen kernel was used, which could cost ~1 us. We now support separate alpha_a and alpha_b in the epilogue, avoiding this extra ATen kernel.
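
A minimal PyTorch sketch of the change (illustrative only; these function names are not the actual FlashInfer API, and the GEMM epilogue is simulated with plain tensor ops). Both paths compute the same result, but the old one launches an extra elementwise kernel to pre-combine the scales, while the new one forwards both scales so the CUTLASS epilogue can apply them in-kernel:

```python
import torch

def bmm_fp8_combined_scale(A, B, scale_a, scale_b):
    # Old path: combine the scales on the host first. This tiny elementwise
    # multiply is the extra ATen kernel launch (~1 us) before the GEMM.
    scale = scale_a * scale_b
    return (A.float() @ B.float()) * scale

def bmm_fp8_separate_scales(A, B, scale_a, scale_b):
    # New path: forward both scales; the (here simulated) GEMM epilogue
    # multiplies by scale_a and scale_b directly, so no pre-combine kernel.
    return (A.float() @ B.float()) * scale_a * scale_b

A = torch.randn(4, 16, 32).to(torch.float8_e4m3fn)
B = torch.randn(4, 32, 8).to(torch.float8_e4m3fn)
scale_a, scale_b = torch.tensor(0.5), torch.tensor(2.0)
torch.testing.assert_close(bmm_fp8_combined_scale(A, B, scale_a, scale_b),
                           bmm_fp8_separate_scales(A, B, scale_a, scale_b))
```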

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).


gemini-code-assist bot (Contributor) left a comment


Summary of Changes

Hello @ttyio, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a performance optimization for FP8 Batch Matrix Multiplication (BMM) within the Cutlass library. Previously, the scale_a and scale_b parameters had to be combined in Python using an extra ATen kernel, incurring a 1-microsecond overhead. This change modifies the Cutlass implementation to directly support separate scale_a and scale_b values within the epilogue, thereby eliminating the need for the external kernel and improving overall performance of FP8 GEMM operations.

Highlights

  • Direct Scale Handling: The Cutlass FP8 GEMM implementation now directly accepts separate scale_a and scale_b tensors, eliminating the need for pre-combining them in Python.
  • Performance Improvement: This change avoids an extra ATen kernel call that previously added approximately 1 microsecond of overhead, leading to performance gains for FP8 BMM operations.
  • Epilogue Fusion Update: The Cutlass epilogue fusion logic has been updated to incorporate the separate scale_a and scale_b values, performing the scaling operation directly within the GPU kernel (see the note after this list).
  • API and Internal Changes: The runGemm and fp8_bmm_impl functions in C++ and cutlass_fp8_gemm in Python have been refactored to accommodate the new scale_a and scale_b parameters.
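
In terms of the computed quantity (notation ours, not from the PR), the two approaches are mathematically identical by associativity; only where the scalar product is evaluated, and hence the number of kernel launches, changes:

$$
\text{before: } s = s_a \cdot s_b \text{ (extra host-launched kernel)}, \quad D = s\,(AB)
\qquad
\text{after: } D = s_a\, s_b\,(AB) \text{ (both scales applied in the GEMM epilogue)}
$$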

gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces a performance optimization by avoiding the host-side multiplication of scale_a and scale_b for FP8 GEMM operations. Instead, the scales are passed separately to the CUTLASS kernel and handled within the epilogue, which eliminates an unnecessary kernel launch. The changes are consistently applied across both the Python and C++ codebases and appear to be functionally correct. My feedback primarily focuses on improving code readability and maintainability by reformatting some long lines in the C++ template files.

ttyio changed the title from "Perf: no more need combine scale_a/scale_b in cutlass bmm_fp8" to "Perf: support scale_a/scale_b instead of combined scale in cutlass bmm_fp8" on Aug 14, 2025
yzh119 (Collaborator) left a comment


LGTM

yzh119 (Collaborator) commented on Aug 15, 2025

There are some conflicts in flashinfer/gemm.py; please rebase.

ttyio force-pushed the dev/vincenth/support-gemm-separate_scale branch 2 times, most recently from 828e3eb to 001005b on August 15, 2025 at 15:22
ttyio (Collaborator, Author) commented on Aug 15, 2025

> There are some conflicts in flashinfer/gemm.py; please rebase.

Rebased, thank you!

The previous CUTLASS implementation required combining scale_a and scale_b in Python, so an extra ATen kernel was used, which could cost ~1 us. We now support separate alpha_a and alpha_b in the epilogue to avoid this extra ATen kernel.

Signed-off-by: Vincent Huang <vincenth@nvidia.com>
ttyio force-pushed the dev/vincenth/support-gemm-separate_scale branch from 001005b to 06747c3 on August 15, 2025 at 15:26
yzh119 (Collaborator) left a comment


LGTM

yzh119 merged commit 6518ce4 into flashinfer-ai:main on Aug 17, 2025
2 checks passed