
Convert wvSplitKQ to 16x16 MFMA in prep for mi4xx. #34100

Merged: gshtras merged 5 commits into vllm-project:main from amd-hhashemi:wvSplitKfp16x16 on Feb 24, 2026
Conversation

@amd-hhashemi
Contributor

@amd-hhashemi amd-hhashemi commented Feb 8, 2026

Reduces register pressure and improves wvSplitKQ performance as well.
Only the fp8 skinny GEMMs are changed.
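As background on the register-pressure claim: on CDNA, an MFMA fp32 accumulator tile of M×N elements is spread across the 64 lanes of a wave, so accumulator VGPR count scales with tile area. A quick sketch (assuming wave64 and one fp32 value per lane per VGPR, the usual CDNA accumulator layout):

```python
# Accumulator VGPRs per MFMA output tile on a 64-lane wave,
# assuming fp32 accumulation (one fp32 value per lane per VGPR).
def acc_vgprs(m, n, wave_size=64):
    return (m * n) // wave_size

print(acc_vgprs(32, 32))  # 16 VGPRs per 32x32 accumulator tile
print(acc_vgprs(16, 16))  # 4 VGPRs per 16x16 accumulator tile
```

Under that assumption, each 16x16 accumulator ties up a quarter of the registers of a 32x32 one, which is the lever this conversion pulls.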

CI validation of test_rocm_skinny_gemms.py was run after #34013.
MI350 tests all PASS; before/after logs for this PR (run with 5x more seed variations) are attached.

test_rocm_skinny_gemms_MI355_after.log
test_rocm_skinny_gemms_MI355_befor.log
test_rocm_skinny_gemms_mi325_after34100.log

MI355 Performance:

vllm bench serve
model=amd/Llama-3.1-70B-Instruct-FP8-KV
max_concurrency=4
num_prompts=400
in=1024
out=1024
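Assuming the parameters above map to the standard `vllm bench serve` flags, the invocation would look roughly like this (a sketch, not the exact command used):

```shell
vllm bench serve \
  --model amd/Llama-3.1-70B-Instruct-FP8-KV \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --num-prompts 400 \
  --max-concurrency 4
```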

| Metric | before | after | % change |
|---|---|---|---|
| Total token throughput (tok/s) | 604.9 | 823.32 | 36.11% |
| Mean TTFT (ms) | 122.25 | 120.06 | -1.79% |
| Mean ITL (ms) | 13.11 | 9.6 | -26.77% |
| Mean E2EL (ms) | 13535.86 | 9944.91 | -26.53% |
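The % change column can be re-derived from the before/after columns; a quick sanity check in Python (numbers copied from the table above):

```python
# Recompute the % change column of the MI355 benchmark table.
rows = {
    "Total token throughput (tok/s)": (604.9, 823.32),
    "Mean TTFT (ms)": (122.25, 120.06),
    "Mean ITL (ms)": (13.11, 9.6),
    "Mean E2EL (ms)": (13535.86, 9944.91),
}
for metric, (before, after) in rows.items():
    pct = (after - before) / before * 100
    print(f"{metric}: {pct:+.2f}%")  # matches the table's % change column
```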

Accuracy

branch-main:
local-completions ({'model': 'deepseek-r1-FP8-Dynamic', 'base_url': 'http://127.0.0.1:8000/v1/completions'}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 4

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9500 | ± 0.006 |
| | | strict-match | 5 | exact_match | 0.9492 | ± 0.006 |

branch-wvSplitKfp16x16
local-completions ({'model': 'deepseek-r1-FP8-Dynamic', 'base_url': 'http://127.0.0.1:8000/v1/completions'}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 4

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9500 | ± 0.006 |
| | | strict-match | 5 | exact_match | 0.9492 | ± 0.006 |

branch-wvSplitKfp16x16+FULL_AND_PIECEWISE
local-completions ({'model': 'deepseek-r1-FP8-Dynamic', 'base_url': 'http://127.0.0.1:8000/v1/completions'}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 4

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9507 | ± 0.006 |
| | | strict-match | 5 | exact_match | 0.9500 | ± 0.006 |


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Reduces reg pressure and improves perf too.

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Feb 8, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 8, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request successfully converts the wvSplitKQ kernel to use 16x16 MFMA instructions, which should improve performance and reduce register pressure as intended. The changes in csrc/rocm/skinny_gemms.cu reflect this conversion, including type changes and updated MFMA intrinsics. The test file tests/kernels/quantization/test_rocm_skinny_gemms.py has increased test coverage by adding more seeds, which is a positive change. However, there are a couple of issues that need to be addressed.

Comment on lines 112 to 115
"""
@pytest.mark.parametrize("xnorm", [False, True])
@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC)
Contributor


critical

The test_rocm_wvsplitkrc_kernel test function has been commented out. If this kernel (wvSplitKrc) is still a functional part of the codebase, its tests should remain active to prevent regressions. If the wvSplitKrc kernel is no longer needed or is being deprecated, please remove it and its associated tests entirely. Commenting out tests without a clear explanation in the PR description can lead to undetected issues for the affected functionality.

Comment on lines +237 to +238
"""

Contributor


critical

The closing """ for commenting out the test_rocm_wvsplitkrc_kernel function is present here. Please refer to the comment on lines 112-115 regarding the status of this test.

Comment on lines 1955 to 1957
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
                                          1);  // row_shl2
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf,
                                          1);  // row_shl1
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
Contributor


high

The comments for the __builtin_amdgcn_mov_dpp operations are inconsistent. While sum[n][y][1] uses 0x101 (row_shl1), sum[n][y][2] uses 0x102 and sum[n][y][3] uses 0x103. The comment // row_shl1 is repeated for all three, which is misleading. Please update the comments to accurately reflect the shift values (e.g., // row_shl2, // row_shl3).

                                          1);  // row_shl1
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
                                          1);  // row_shl2
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
                                          1);  // row_shl3
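As a toy model of what these DPP operations do (an illustration, not the ISA spec: assume `row_shl:N` with `bound_ctrl=1` means lane i reads lane i−N within its 16-lane row, and lanes without an in-row source contribute 0):

```python
# Toy model of DPP row_shl:N with row_mask=0xf, bank_mask=0xf, bound_ctrl=1:
# within each 16-lane row, lane i reads the value of lane i - n;
# lanes with no valid source in the row contribute 0.
def row_shl(lanes, n, row=16):
    out = []
    for i in range(len(lanes)):
        src = i - n
        in_same_row = src >= 0 and (src // row) == (i // row)
        out.append(lanes[src] if in_same_row else 0)
    return out

vals = list(range(16))   # one 16-lane row of per-lane partial sums
acc = vals[:]            # accumulate shifted copies, as the kernel does
for n in (1, 2, 3):      # with the 0x101..0x103 DPP controls above
    acc = [a + s for a, s in zip(acc, row_shl(vals, n))]
print(acc[3])  # lane 3 now holds 0 + 1 + 2 + 3 = 6
```

Under that model, summing the row_shl1..row_shl3 copies into a lane's accumulator performs a partial cross-lane reduction, which is why the misleading `// row_shl1` comments flagged above matter for readability.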

Comment on lines +2102 to +2104
                                          1);  // row_shl1
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][9], 0x101, 0xf, 0xf, 1);
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
                                          1);  // row_shl2
        accm16 += __builtin_amdgcn_mov_dpp(sum[n][y][10], 0x102, 0xf, 0xf,
                                          1);  // row_shl1
Contributor


high

Similar to the previous comment, the comments for the __builtin_amdgcn_mov_dpp operations here are inconsistent. The comment // row_shl1 is repeated for different shift values (0x101, 0x102, 0x103). Please update the comments to accurately reflect the shift values (e.g., // row_shl2, // row_shl3).

                                          1);  // row_shl1
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][2], 0x102, 0xf, 0xf,
                                          1);  // row_shl2
        accm0 += __builtin_amdgcn_mov_dpp(sum[n][y][3], 0x103, 0xf, 0xf,
                                          1);  // row_shl3

amd-hhashemi and others added 3 commits February 8, 2026 19:14
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
@tjtanaa
Collaborator

tjtanaa commented Feb 14, 2026

@amd-hhashemi Will these changes affect performance on gfx942 and gfx950? I do not see any code path logic for different gfx targets.

Can you share some perf results for gfx942 and gfx950, especially end to end? For pull requests that change an algorithm, we would like to understand the impact of the changes.

Please also provide some e2e accuracy evaluation (gsm8k will do) to check whether the optimization impacted accuracy. As the algorithm changes, the error-accumulation pattern might change.

Do you know if we need to keep multiple versions of wvSplitK? The condition for triggering the kernel keeps changing with each optimization approach, which potentially makes wvSplitK more and more restrictive.

@amd-hhashemi
Contributor Author

> @amd-hhashemi Will these changes affect performance on gfx942 and gfx950? I do not see any code path logic for different gfx targets.
>
> Can you share some perf results for gfx942 and gfx950, especially end to end? For pull requests that change an algorithm, we would like to understand the impact of the changes.
>
> Please also provide some e2e accuracy evaluation (gsm8k will do) to check whether the optimization impacted accuracy. As the algorithm changes, the error-accumulation pattern might change.
>
> Do you know if we need to keep multiple versions of wvSplitK? The condition for triggering the kernel keeps changing with each optimization approach, which potentially makes wvSplitK more and more restrictive.

I just added the MI355 e2e Llama-3.1-70B-Instruct-FP8 perf numbers to the description. This fp8 path should be solid now; it has bias and padding support. We might still tweak the dispatch config conditions for important models.

It would be good if we could combine the different versions. The problem is that the different solutions diverge in ways that introduce overhead when we try to merge them.

What is the best way to report e2e accuracy? I usually just run prompt tests and make sure the before/after output has not changed.

@tjtanaa
Collaborator

tjtanaa commented Feb 14, 2026

> What is the best way to report e2e accuracy? I usually just run prompt tests and make sure the before/after output has not changed.

Checking prompts is not enough, since you can usually only send one prompt at a time. We also need to check the boundary-condition handling of the algorithm, so we have to validate at large batch sizes.

How to evaluate accuracy depends on the case.
Since wvSplitKQ is only used in skinny-GEMM situations, the gsm8k dataset (short prompts, short outputs) is suitable.
You can evaluate e2e accuracy using lm_eval with the gsm8k dataset at a proper batch size (batch size > 1). You can follow this: https://github.com/ROCm/vllm/blob/4ab5453eac67e6fa659fccccb3e35fac5366aa24/evaluation/README.md?plain=1#L161
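For concreteness, a sketch of the lm_eval invocation matching the local-completions config strings shown in the accuracy results above (model name and base_url are copied from those logs; flags per the lm-evaluation-harness CLI):

```shell
lm_eval --model local-completions \
  --model_args model=deepseek-r1-FP8-Dynamic,base_url=http://127.0.0.1:8000/v1/completions \
  --tasks gsm8k \
  --num_fewshot 5 \
  --batch_size 4
```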

@tjtanaa
Collaborator

tjtanaa commented Feb 14, 2026

> It would be good if we could combine the different versions. The problem is that the different solutions diverge in ways that introduce overhead when we try to merge them.

If that is the case, rather than modifying the existing code, can we keep them as two separate kernel implementations and dispatch between them based on condition? We still want to maintain support for older architectures while optimizing for newer ones.

@amd-hhashemi
Contributor Author

amd-hhashemi commented Feb 14, 2026

> If that is the case, rather than modifying the existing code, can we keep them as two separate kernel implementations and dispatch between them based on condition? We still want to maintain support for older architectures while optimizing for newer ones.

@tjtanaa Sorry, to clarify: the different kernel solutions are not targeting different products; they target different GEMM scenarios that are currently difficult for hipBlasLt to handle well.
There are basically three scenarios we are trying to address with these small-tensor GEMMs:

  1. very small N at fp16/bf16 (wvSplitK)
  2. very small N at fp8 (wvSplitKQ)
  3. tensors so small that extreme K-splitting is needed to fill the machine (wvSplitKrc)

I was just saying that addressing these three in one kernel without introducing overhead is difficult. There is no intention of having different kernel paths for different products.

@amd-hhashemi
Contributor Author

@tjtanaa Accuracy and perf numbers added to description.

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 24, 2026
@gshtras gshtras enabled auto-merge (squash) February 24, 2026 23:14
@gshtras gshtras merged commit a0e50a4 into vllm-project:main Feb 24, 2026
17 of 18 checks passed
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 24, 2026
tom-zju pushed a commit to tom-zju/vllm that referenced this pull request Feb 26, 2026
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done


3 participants