[ROCm][DSv4][Perf] Optimized HIP kernel for sparse mla #43306

Open
heachary wants to merge 15 commits into vllm-project:main from heachary:heachary/dsrv4/sparse_mla_PR

Conversation

@heachary

@heachary heachary commented May 21, 2026

Contributor

Summary

This PR replaces the Triton-based rocm_sparse_attn_decode implementation with a hand-written HIP kernel using gfx950 MFMA (mfma_f32_16x16x32_bf16) instructions. On non-gfx950 hardware (e.g. gfx942/MI300), the existing Triton implementation is preserved as a fallback.

Key changes:

  • New HIP kernel (sparse_mla_decode_kernel) for gfx950: performs fused gather→dequant(FP8)→QK→online-softmax→PV in a single launch using 256 threads (4 waves) per workgroup with cooperative LDS usage.
  • Split-K support for long sequences: a partial kernel + reduce kernel pair saturates ~256 CUs via a workload-aware heuristic, avoiding the single-WG bottleneck.
  • 4-wave PV distribution: Wave 0 computes QK scores via MFMA; all 4 waves independently do softmax and each handles 8 of 32 output N-tiles, reducing per-wave register pressure from 128 to 32 VGPRs.
  • JIT compilation via torch.utils.cpp_extension.load_inline, cached to VLLM_SPARSE_MLA_HIP_CACHE_DIR.
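The split-K bullet above relies on the standard online-softmax merge: each split produces a running max, a sum of exponentials, and an unnormalised accumulator, and a reduce step rescales them to a common max. The following is a minimal pure-Python sketch of that idea for a single query row, not the PR's HIP code. The function names and the toy data are hypothetical; the (m, l, acc) triple only mirrors the names of the scratch buffers (scratch_m, scratch_l, scratch_acc) quoted later in the review.

```python
import math

def partial_attention(scores, values):
    """One split: return (m, l, acc) for a chunk of scores and value rows.

    m   = max score in this chunk
    l   = sum of exp(score - m)
    acc = sum of exp(score - m) * value_row (unnormalised output)
    """
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    dim = len(values[0])
    acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return m, l, acc

def reduce_partials(partials):
    """Merge per-split (m, l, acc) triples into the final softmax-weighted row."""
    m_all = max(m for m, _, _ in partials)
    l_all = 0.0
    out = [0.0] * len(partials[0][2])
    for m, l, acc in partials:
        scale = math.exp(m - m_all)  # rescale this split to the global max
        l_all += scale * l
        out = [o + scale * a for o, a in zip(out, acc)]
    return [o / l_all for o in out]

def reference_attention(scores, values):
    """Unsplit softmax-weighted sum, used only to check the merge."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    dim = len(values[0])
    return [sum(wi * v[d] for wi, v in zip(w, values)) / total for d in range(dim)]

# Toy data (made up): 6 key positions, value dimension 2, split_k = 3.
scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.25, 0.75], [3.0, -3.0]]
split_k = 3
chunk = len(scores) // split_k
partials = [
    partial_attention(scores[i:i + chunk], values[i:i + chunk])
    for i in range(0, len(scores), chunk)
]
merged = reduce_partials(partials)
expected = reference_attention(scores, values)
```

The merged row matches the unsplit reference up to floating-point rounding, which is why the partial and reduce kernels can be launched separately.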

Perf

1k1k, conc4/64

| | Model | Served Model | Hardware | Framework | Precision | ISL | OSL | TP | EP | DP Attention | Conc | Total Token Throughput (tok/s) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| baseline | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 1024 | 1024 | 8 | 1 | false | 4 | 154 |
| sparse_mla_optimized | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 1024 | 1024 | 8 | 1 | false | 4 | 168 |
| baseline | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 1024 | 1024 | 8 | 1 | false | 64 | 1702 |
| sparse_mla_optimized | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 1024 | 1024 | 8 | 1 | false | 64 | 1725 |

8k1k, conc4/64

| | Model | Served Model | Hardware | Framework | Precision | ISL | OSL | TP | EP | DP Attention | Conc | Total Token Throughput (tok/s) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| baseline | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 8192 | 1024 | 8 | 1 | false | 4 | 618 |
| sparse_mla_optimized | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 8192 | 1024 | 8 | 1 | false | 4 | 710 |
| baseline | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 8192 | 1024 | 8 | 1 | false | 64 | 5331 |
| sparse_mla_optimized | dsrv4 | deepseek-ai/DeepSeek-V4-Pro | MI355X-DOCKER | VLLM | FP8 | 8192 | 1024 | 8 | 1 | false | 64 | 5636 |

The new kernel improves total token throughput in all four scenarios, from about 1% (1k1k, conc 64) to about 15% (8k1k, conc 4).
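The per-scenario gains can be recomputed from the two tables above. This small sketch only redoes that arithmetic; the scenario labels and the helper name `pct_gain` are made up for illustration.

```python
# (scenario, baseline tok/s, optimized tok/s), copied from the tables above
results = [
    ("1k1k conc4", 154, 168),
    ("1k1k conc64", 1702, 1725),
    ("8k1k conc4", 618, 710),
    ("8k1k conc64", 5331, 5636),
]

def pct_gain(baseline, optimized):
    """Relative throughput gain in percent."""
    return (optimized - baseline) / baseline * 100.0

gains = {name: pct_gain(base, opt) for name, base, opt in results}
for name, gain in gains.items():
    print(f"{name}: +{gain:.1f}%")
```

The gain is largest at low concurrency and long input, and smallest at 1k1k with concurrency 64.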

Accuracy

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 8 | exact_match | 0.9538 | ± 0.0058 |
| | | strict-match | 8 | exact_match | 0.9545 | ± 0.0057 |

Signed-off-by: Hemanth Acharya <heachary@amd.com>
@mergify mergify Bot added the rocm (Related to AMD ROCm) and v1 labels May 21, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD May 21, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a HIP MFMA kernel implementation for sparse-MLA decode, specifically targeting the gfx950 architecture. The changes include the addition of C++ source for the kernels, JIT compilation logic using load_inline, and a Python wrapper to manage execution and split-K logic. Feedback highlights several critical issues: the lack of fallback logic for non-gfx950 hardware which will cause runtime failures, performance bottlenecks caused by synchronous GPU-to-CPU copies (.item()) and frequent memory allocations for scratch buffers, potential numerical precision loss due to using bf16 for accumulation, and the problematic global modification of environment variables during JIT compilation.

Comment thread vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Comment on lines +2633 to +2636
max_main_len = int(swa_lens.max().item()) if swa_lens is not None else 0
max_extra_len = 0
if has_extra and topk_lens is not None:
    max_extra_len = int(topk_lens.max().item())

Contributor

high

Calling .item() on GPU tensors (swa_lens and topk_lens) triggers synchronous host-device copies and blocks the CPU. In the decode path, which is extremely latency-sensitive, these synchronizations can significantly degrade performance. These values should be retrieved from the CPU-side metadata (e.g., from the scheduler) instead of being synchronized from the GPU here.
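One way to follow this suggestion is to carry a host-side copy of the lengths in the attention metadata and take the maxima there. The sketch below assumes such a copy exists; `DecodeMetadataCPU` and `max_lens` are hypothetical names, not part of vLLM, and plain Python sequences stand in for whatever CPU-side structure the scheduler provides.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class DecodeMetadataCPU:
    """Hypothetical host-side copy of the per-query lengths."""
    swa_lens: Optional[Sequence[int]] = None
    topk_lens: Optional[Sequence[int]] = None

def max_lens(meta: DecodeMetadataCPU, has_extra: bool) -> tuple[int, int]:
    """Return (max_main_len, max_extra_len) without touching any GPU tensor."""
    max_main_len = max(meta.swa_lens, default=0) if meta.swa_lens is not None else 0
    max_extra_len = 0
    if has_extra and meta.topk_lens is not None:
        max_extra_len = max(meta.topk_lens, default=0)
    return max_main_len, max_extra_len

# Toy lengths (made up) for three queries.
example = DecodeMetadataCPU(swa_lens=[3, 7, 5], topk_lens=[2, 9, 4])
main_len, extra_len = max_lens(example, has_extra=True)
```

Because everything here lives on the host, no device synchronization is needed to size the launch.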

Comment on lines +2520 to +2530
scratch_m = torch.empty(
    num_queries * num_head_blocks * split_k * BLOCK_H,
    device=q.device,
    dtype=torch.float32,
)
scratch_l = torch.empty_like(scratch_m)
scratch_acc = torch.empty(
    num_queries * num_head_blocks * split_k * BLOCK_H * 512,
    device=q.device,
    dtype=torch.bfloat16,
)

Contributor

high

Allocating large scratch buffers (scratch_m, scratch_l, scratch_acc) using torch.empty on every forward pass is inefficient. Memory allocation in PyTorch involves synchronization and overhead that will bottleneck the decode kernel. These buffers should be pre-allocated or managed via a persistent workspace manager to avoid per-step allocation costs.
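A persistent workspace of the kind suggested here can be as simple as a grow-only cache keyed by buffer name. The sketch below is an illustration under assumptions: `ScratchWorkspace` is a hypothetical class, and `bytearray` stands in for a device allocator such as a wrapper around torch.empty.

```python
from typing import Callable, Dict, Tuple

class ScratchWorkspace:
    """Grow-only cache of scratch buffers keyed by name."""

    def __init__(self, alloc: Callable[[int], object]):
        self._alloc = alloc  # assumed: allocates a buffer of the given size
        self._buffers: Dict[str, Tuple[int, object]] = {}
        self.num_allocations = 0

    def get(self, name: str, numel: int):
        """Return a buffer with at least `numel` elements, reusing it when possible."""
        cached = self._buffers.get(name)
        if cached is not None and cached[0] >= numel:
            return cached[1]  # large enough already; caller uses the first numel elements
        buf = self._alloc(numel)  # first use, or the request outgrew the cache
        self.num_allocations += 1
        self._buffers[name] = (numel, buf)
        return buf

ws = ScratchWorkspace(alloc=bytearray)  # bytearray is only a stand-in allocator
for step_numel in (1024, 512, 1024, 2048, 2048):
    ws.get("scratch_acc", step_numel)
```

In this toy run only the first request and the one that outgrows the cache allocate; the other three steps reuse the existing buffer.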

Comment on lines +2527 to +2529
num_queries * num_head_blocks * split_k * BLOCK_H * 512,
device=q.device,
dtype=torch.bfloat16,

Contributor

high

The scratch_acc buffer for Split-K accumulation is allocated as torch.bfloat16. Partial sums in attention kernels can have a large dynamic range, and storing them in bf16 before the final reduction can lead to significant numerical inaccuracies. It is standard practice to use float32 for the accumulation workspace to maintain precision during the reduction phase.
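The size of the rounding error in question can be illustrated without a GPU by emulating bfloat16 storage in pure Python. This is a sketch with made-up partial sums, not a measurement of the PR's kernel; `to_bf16` and `to_f32` are hypothetical helpers, and `to_bf16` ignores NaN handling.

```python
import struct

def to_f32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even, no NaN handling)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16  # keep the top 16 bits of the float32
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

# Made-up per-split partial accumulators that are then summed in a reduce step.
partials = [1.001, 2.003, 3.005, 4.007]
exact = sum(partials)
err_f32 = abs(sum(to_f32(p) for p in partials) - exact)
err_bf16 = abs(sum(to_bf16(p) for p in partials) - exact)
print(f"float32 storage error:  {err_f32:.2e}")
print(f"bfloat16 storage error: {err_bf16:.2e}")
```

With these values the bfloat16 workspace loses the small fractional part of every partial, while float32 keeps it to within rounding noise. Whether this matters in practice depends on the magnitudes the kernel actually stores, which the accuracy runs in this PR are meant to check.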

    str(pathlib.Path(tempfile.gettempdir()) / "vllm_sparse_mla_hip_cache"),
)
os.makedirs(cache_dir, exist_ok=True)
os.environ["PYTORCH_ROCM_ARCH"] = "gfx950"

Contributor

high

Modifying os.environ["PYTORCH_ROCM_ARCH"] globally within a library function is problematic as it affects the entire process and may interfere with other JIT compilations. Since the target architecture is already explicitly passed via --offload-arch in extra_cuda_cflags, this global environment modification should be removed.
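The comment recommends removing the assignment outright. If some build step did still need the variable, one alternative is to scope the change so the previous state is restored afterwards. The sketch below shows that pattern; `scoped_env` is a hypothetical helper, and the approach is not thread-safe because os.environ is process-wide.

```python
import contextlib
import os

@contextlib.contextmanager
def scoped_env(name: str, value: str):
    """Temporarily set an environment variable, restoring the old state on exit."""
    missing = object()  # sentinel: distinguishes "unset" from an empty string
    previous = os.environ.get(name, missing)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is missing:
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous

os.environ.pop("PYTORCH_ROCM_ARCH", None)  # start from a known state for the demo
with scoped_env("PYTORCH_ROCM_ARCH", "gfx950"):
    inside = os.environ["PYTORCH_ROCM_ARCH"]  # visible only inside the block
restored = "PYTORCH_ROCM_ARCH" not in os.environ
```

Other JIT compilations in the same process then see the variable exactly as it was before the block ran.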

Signed-off-by: Hemanth Acharya <heachary@amd.com>
@heachary heachary marked this pull request as ready for review May 21, 2026 10:35
@heachary heachary requested a review from tjtanaa as a code owner May 21, 2026 10:35
@heachary heachary changed the title from "[ROCm][DSRv4][Perf] Optimized HIP kernel for sparse mla" to "[ROCm][DSv4][Perf] Optimized HIP kernel for sparse mla" May 21, 2026
Comment thread vllm/v1/attention/ops/rocm_aiter_mla_sparse.py Outdated
@tjtanaa

tjtanaa commented May 21, 2026

Member

Server command

#!/bin/bash

rm -rf ~/.cache/vllm

export VLLM_ROCM_USE_AITER=1

vllm serve deepseek-ai/DeepSeek-V4-Pro \
  --host localhost \
  --port 8001 \
  --tensor-parallel-size 8 \
  --distributed-executor-backend mp \
  --trust-remote-code \
  --gpu-memory-utilization 0.8 \
  --moe-backend triton_unfused \
  --tokenizer-mode deepseek_v4 \
  --reasoning-parser deepseek_v4 \
  --kv-cache-dtype fp8 \
  --compilation-config '{"mode":3,"cudagraph_mode": "FULL_AND_PIECEWISE"}'

lm eval command:

Please run gsm8k with 20 shots and a large concurrency of 256; the accuracy must reach 0.95. This is to ensure the boundary conditions are implemented correctly.

#!/bin/bash

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)
LM_EVAL_PYTHON=${LM_EVAL_PYTHON:-"${SCRIPT_DIR}/hfvenv/bin/python"}
MODEL=deepseek-ai/DeepSeek-V4-Pro
lm_eval --model local-completions --model_args model=$MODEL,base_url=http://0.0.0.0:8894/v1/completions,num_concurrent=256,max_retries=10,max_gen_toks=2048,max_length=1048576,timeout=60000,trust_remote_code=True --batch_size auto --tasks gsm8k --num_fewshot 20 \
| tee lmeval_deepseek-ai_DeepSeek-V4-Pro.log

Comment thread vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Comment thread vllm/v1/attention/ops/rocm_aiter_mla_sparse.py Outdated
heachary added 3 commits May 25, 2026 06:00
Signed-off-by: Hemanth Acharya <heachary@amd.com>
Signed-off-by: Hemanth Acharya <heachary@amd.com>
Signed-off-by: Hemanth Acharya <heachary@amd.com>
@heachary heachary requested a review from dllehr-amd as a code owner May 25, 2026 11:35
@heachary

Contributor Author

@tjtanaa I reran the accuracy evaluation with these settings:

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 20 | exact_match | 0.9538 | ± 0.0058 |
| | | strict-match | 20 | exact_match | 0.9545 | ± 0.0057 |

@heachary heachary requested a review from tjtanaa May 25, 2026 11:38
m.impl("decode_split", &sparse_mla_decode_split);
}
"""
_SPARSE_MLA_DECODE_CU = (

Member

@heachary please use CMakeLists.txt.

Member

All custom kernels should be compiled at build time, not JIT-compiled.

Contributor Author

Done.

@tjtanaa

tjtanaa commented May 25, 2026

Member

@heachary Does the kernel support the MTP use case? We now support the MTP feature (#43385), so we will need logic to handle the MTP path.

@heachary

heachary commented May 28, 2026

Contributor Author

> @heachary Does the kernel support the MTP use case? We now support the MTP feature (#43385), so we will need logic to handle the MTP path.

@tjtanaa the HIP kernel already supports MTP without changes because it operates on individual query tokens. With MTP, a single decode request produces multiple query tokens (e.g., 3 instead of 1). The upstream metadata pipeline already builds per-token SWA indices, per-token topk indices, and per-token ragged indptr arrays. So the kernel simply sees more queries with correctly populated per-query index ranges — it doesn't distinguish whether those queries came from 100 requests with 1 token each or 50 requests with 2 tokens each.
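The point about per-token ragged indptr arrays can be illustrated with a toy example. This is a sketch with made-up lengths, not vLLM's metadata builder; all function names are hypothetical. It shows that once lengths are flattened per query token, a per-query kernel receives the same index ranges regardless of how tokens are grouped into requests.

```python
def flatten_per_token_lens(requests):
    """requests: one list per request, holding one index-list length per query token."""
    return [n for req in requests for n in req]

def build_indptr(lens):
    """Prefix-sum per-query lengths into a ragged indptr array (num_queries + 1 entries)."""
    indptr = [0]
    for n in lens:
        indptr.append(indptr[-1] + n)
    return indptr

def per_query_ranges(indptr):
    """What a per-query kernel sees: one [start, end) slice of the index buffer per query."""
    return [(indptr[i], indptr[i + 1]) for i in range(len(indptr) - 1)]

no_mtp = [[5], [3], [4], [6]]  # 4 requests, 1 query token each
with_mtp = [[5, 3], [4, 6]]    # 2 requests, 2 query tokens each (MTP-style)

ranges_no_mtp = per_query_ranges(build_indptr(flatten_per_token_lens(no_mtp)))
ranges_with_mtp = per_query_ranges(build_indptr(flatten_per_token_lens(with_mtp)))
```

Both groupings yield four queries with identical ranges, so the kernel needs no MTP-specific branch as long as the upstream metadata is built per token.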

To confirm, I ran the tests just as you did in the PR:

| | ISL/OSL | conc | Output token throughput (tokens/s) |
|---|---|---|---|
| no-mtp | 1k1k | 4 | 85.06 |
| with-mtp (num_speculative_tokens=2) | 1k1k | 4 | 187 |

Accuracy (with MTP):

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 20 | exact_match | 0.9515 | ± 0.0059 |
| | | strict-match | 20 | exact_match | 0.9522 | ± 0.0059 |

Signed-off-by: Hemanth Acharya <heachary@amd.com>
@heachary heachary requested a review from tjtanaa May 28, 2026 16:16
@mergify mergify Bot added the ci/build label May 28, 2026
@mergify

mergify Bot commented May 29, 2026

Contributor

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heachary.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 29, 2026
Signed-off-by: Hemanth Acharya <heachary@amd.com>
@heachary heachary requested a review from khluu as a code owner June 1, 2026 06:11
@mergify mergify Bot removed the needs-rebase label Jun 1, 2026
@heachary

heachary commented Jun 2, 2026

Contributor Author

@tjtanaa could you take another look at this PR to see if the changes look okay and add the label which kicks off the unit tests please?



@torch.inference_mode()
def test_hip_decode_main_only_no_sink() -> None:

Member

@heachary For all of the tests, please parametrize the inputs and use the exact DeepSeek parameters, e.g. block size, num_heads, etc.

Contributor Author

Done, the tests now use the exact dsrv4 config.

Comment thread CMakeLists.txt Outdated
"-Wno-c++11-narrowing")

define_extension_target(
_rocm_sparse_mla_C

Member

We should compile it into _rocm_C as well. Your code will need to be guarded by GPU ARCH macros.

Contributor Author

Okay, I removed the extension and added the source to _rocm_C (for gfx950 only).

@tjtanaa

tjtanaa commented Jun 4, 2026

Member

@heachary Thanks again for the important optimization. I added the label so that CI can check whether compilation still works when we compile for all architectures.

@tjtanaa tjtanaa added the ready (ONLY add when PR is ready to merge/full CI is needed) label Jun 4, 2026
@AndreasKaratzas

Member

Please rebase @heachary

heachary added 3 commits June 5, 2026 01:36
Signed-off-by: Hemanth Acharya <heachary@amd.com>
Signed-off-by: Hemanth Acharya <heachary@amd.com>
@heachary heachary requested a review from tjtanaa June 5, 2026 08:06

@tjtanaa tjtanaa left a comment

Member

LGTM, but the AMD CI needs attention: it seems the Docker build was not triggered and the tests were not run.

@AndreasKaratzas could you take a look? Should we trigger it manually?

@tjtanaa

tjtanaa commented Jun 5, 2026

Member

I triggered the AMD CI manually to unblock this PR.

@AndreasKaratzas

Member

Yes, there is currently a PR on infra to mitigate this oversight. Sorry for the confusion.

@tjtanaa tjtanaa added the DSv4 label Jun 6, 2026
@tjtanaa

tjtanaa commented Jun 6, 2026

Member

@heachary it seems there is a compilation issue.

tjtanaa and others added 2 commits June 7, 2026 00:11
@heachary

heachary commented Jun 8, 2026

Contributor Author

@tjtanaa / @AndreasKaratzas: fixed the compilation issue. The tests appear to be passing now; I think the PR is ready to be merged.

Labels

ci/build, DSv4, ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1

Projects

Status: Todo
