Skip to content

[Kernel][Hardware][AMD] Bf16 mfma opt for ROCm skinny GEMMs#17071

Merged
vllm-bot merged 8 commits intovllm-project:mainfrom
hshmhashemi:bf16_mfma_opt
May 8, 2025
Merged

[Kernel][Hardware][AMD] Bf16 mfma opt for ROCm skinny GEMMs#17071
vllm-bot merged 8 commits intovllm-project:mainfrom
hshmhashemi:bf16_mfma_opt

Conversation

@amd-hhashemi
Copy link
Contributor

@amd-hhashemi amd-hhashemi commented Apr 23, 2025

Bf16 mfma opt for ROCm skinny GEMMs

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
@tjtanaa
Copy link
Collaborator

tjtanaa commented Apr 30, 2025

@amd-hhashemi hi.
Thank you for optimizing bf16. How much perf gain could we expect on mi300?

tjtanaavllm added a commit to ROCm/vllm that referenced this pull request Apr 30, 2025
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
@amd-hhashemi
Copy link
Contributor Author

amd-hhashemi commented Apr 30, 2025

@amd-hhashemi hi. Thank you for optimizing bf16. How much perf gain could we expect on mi300?

Hey, this optimization shows 25% speedup on llama3 bf16 batch-1 on MI300. The prior solution does expensive bf16->float conversion followed by FMA ops. This optimization avoids that by using MFMAs instead, which is much more efficient.

@SageMoore
Copy link
Contributor

Hi, @amd-hhashemi. Thanks for the contribution! Could you just run a quick serving benchmark to make sure there are no obvious perf regressions? I'm somewhat fuzzy on the exact cases that skinny gemm is enabled but I assume that it will be used in llama 3.1 8B.

Additionally, can you post the benchmark you ran that is giving you 25% speedup?

Serving commands:
vllm serve meta-llama/Llama-3.1-8B-Instruct --port 4444 --disable-log-requests
followed by:
python benchmarks/benchmark_serving.py --model meta-llama/Llama-3.1-8B-Instruct --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --ignore-eos --port 4444

@tjtanaa
Copy link
Collaborator

tjtanaa commented May 1, 2025

@amd-hhashemi
Copy link
Contributor Author

amd-hhashemi commented May 1, 2025

You mean this PR?
ROCm#520

There is no difference, I wrote that first, then the upstream version was merged to ROCm. So the ROCm one can be dropped and we'll later merge with upstream.

@amd-hhashemi
Copy link
Contributor Author

Oh sorry I didn't realize you were point to Aiter.
I didn't know it had been pulled into Aiter.
Although it seems to be the original version, before fp8 or bf16 support was added.

@amd-hhashemi
Copy link
Contributor Author

amd-hhashemi commented May 1, 2025

Hi SageMoore, I will run serving benchmark.
This is what I ran:
python benchmarks/benchmark_latency.py --model /data
/Meta-Llama-3-8B-Instruct --batch-size 1 --dtype bfloat16

[https://github.com/amd-hhashemi/vllm/blob/main/benchmarks/benchmark_latency.py]

Original reported latency: ~1.07sec
After this optimization: ~0.84sec
(it's actually more like ~22% speedup)
[Note: these numbers were actually on a downsized version of MI300, but since it's a compute bottleneck, it should be same on full MI300. I will verify that too]
The skinny gemms get most heavily used with low batch sizes.

@amd-hhashemi
Copy link
Contributor Author

amd-hhashemi commented May 1, 2025

[corrected, with warmup runs]

I ran the server benchmark before and after the change. There isn't any change on server throughput test (this is expected, skinny GEMMs only show up in low batch count):

Before:
image
After this code change:
image

Signed-off-by: charlifu <charlifu@amd.com>
@tjtanaa
Copy link
Collaborator

tjtanaa commented May 5, 2025

Oh sorry I didn't realize you were point to Aiter. I didn't know it had been pulled into Aiter. Although it seems to be the original version, before fp8 or bf16 support was added.

Will there be plans to integrate this updated kernel into AITER?

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) May 6, 2025 17:04
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label May 6, 2025
charlifu added 2 commits May 7, 2025 14:48
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
auto-merge was automatically disabled May 7, 2025 14:51

Head branch was pushed to by a user without write access

Signed-off-by: charlifu <charlifu@amd.com>
@vllm-bot vllm-bot merged commit 5a499e7 into vllm-project:main May 8, 2025
76 of 80 checks passed
princepride pushed a commit to princepride/vllm that referenced this pull request May 10, 2025
…ject#17071)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…ject#17071)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
…ject#17071)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…ject#17071)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants