[ROCm][DSv4][Perf] Optimized HIP kernel for sparse mla #43306
Signed-off-by: Hemanth Acharya <heachary@amd.com>
Code Review
This pull request introduces a HIP MFMA kernel implementation for sparse-MLA decode, specifically targeting the gfx950 architecture. The changes include the addition of C++ source for the kernels, JIT compilation logic using load_inline, and a Python wrapper to manage execution and split-K logic. Feedback highlights several critical issues: the lack of fallback logic for non-gfx950 hardware which will cause runtime failures, performance bottlenecks caused by synchronous GPU-to-CPU copies (.item()) and frequent memory allocations for scratch buffers, potential numerical precision loss due to using bf16 for accumulation, and the problematic global modification of environment variables during JIT compilation.
max_main_len = int(swa_lens.max().item()) if swa_lens is not None else 0
max_extra_len = 0
if has_extra and topk_lens is not None:
    max_extra_len = int(topk_lens.max().item())
Calling .item() on GPU tensors (swa_lens and topk_lens) triggers synchronous host-device copies and blocks the CPU. In the decode path, which is extremely latency-sensitive, these synchronizations can significantly degrade performance. These values should be retrieved from the CPU-side metadata (e.g., from the scheduler) instead of being synchronized from the GPU here.
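A minimal sketch of the suggestion above, assuming the scheduler already holds the per-query lengths as plain Python integers on the host. The names `DecodeLengthsMeta` and `build_lengths_meta` are hypothetical and do not come from the PR; the point is only that the maxima can be computed once on the CPU and passed along, so the decode wrapper never has to read them back from the GPU.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class DecodeLengthsMeta:
    """Hypothetical host-side summary built while the batch is scheduled."""
    max_main_len: int
    max_extra_len: int


def build_lengths_meta(
    swa_lens_cpu: Sequence[int],
    topk_lens_cpu: Optional[Sequence[int]],
) -> DecodeLengthsMeta:
    # Both inputs are assumed to be host-side lists, so max() here never
    # touches the device and cannot block on the GPU stream.
    max_main = max(swa_lens_cpu, default=0)
    max_extra = max(topk_lens_cpu, default=0) if topk_lens_cpu is not None else 0
    return DecodeLengthsMeta(max_main_len=max_main, max_extra_len=max_extra)


meta = build_lengths_meta([3, 7, 5], [2, 9])
```

The decode wrapper would then read `meta.max_main_len` and `meta.max_extra_len` instead of calling `.item()` on device tensors.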
scratch_m = torch.empty(
    num_queries * num_head_blocks * split_k * BLOCK_H,
    device=q.device,
    dtype=torch.float32,
)
scratch_l = torch.empty_like(scratch_m)
scratch_acc = torch.empty(
    num_queries * num_head_blocks * split_k * BLOCK_H * 512,
    device=q.device,
    dtype=torch.bfloat16,
)
Allocating large scratch buffers (scratch_m, scratch_l, scratch_acc) using torch.empty on every forward pass is inefficient. Memory allocation in PyTorch involves synchronization and overhead that will bottleneck the decode kernel. These buffers should be pre-allocated or managed via a persistent workspace manager to avoid per-step allocation costs.
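One way to read the "persistent workspace" suggestion is a grow-only cache that reallocates only when a request exceeds the current capacity. The sketch below is illustrative only: `GrowOnlyWorkspace` and its `alloc` callback are hypothetical names, and a `bytearray` stands in for a device buffer so the example runs without a GPU.

```python
from typing import Any, Callable, Dict, Hashable


class GrowOnlyWorkspace:
    """Hypothetical scratch-buffer cache: reallocates only when it must grow."""

    def __init__(self, alloc: Callable[[int], Any]) -> None:
        self._alloc = alloc              # e.g. a closure that creates a 1-D buffer
        self._buffers: Dict[Hashable, Any] = {}
        self.num_allocations = 0         # counter kept only for illustration

    def get(self, key: Hashable, numel: int) -> Any:
        buf = self._buffers.get(key)
        if buf is None or len(buf) < numel:
            # Capacity is too small (or nothing cached yet): allocate once.
            buf = self._alloc(numel)
            self._buffers[key] = buf
            self.num_allocations += 1
        # The caller is expected to use only the first `numel` elements.
        return buf


workspace = GrowOnlyWorkspace(alloc=lambda n: bytearray(n))
first = workspace.get("scratch_acc", 16)
second = workspace.get("scratch_acc", 8)    # reuses the 16-element buffer
third = workspace.get("scratch_acc", 32)    # grows, so one more allocation
```

After a few warm-up steps the buffer reaches the largest size seen so far, and later decode steps reuse it without allocating.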
    num_queries * num_head_blocks * split_k * BLOCK_H * 512,
    device=q.device,
    dtype=torch.bfloat16,
The scratch_acc buffer for Split-K accumulation is allocated as torch.bfloat16. Partial sums in attention kernels can have a large dynamic range, and storing them in bf16 before the final reduction can lead to significant numerical inaccuracies. It is standard practice to use float32 for the accumulation workspace to maintain precision during the reduction phase.
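The precision concern can be illustrated with a small pure-Python model of a split-K reduction. This is a sketch under stated assumptions, not the kernel's actual layout: each split is assumed to store a running maximum `m`, a sum of exponentials `l`, and an unnormalized accumulator `acc`. The helper names are hypothetical, `to_bf16` emulates bfloat16 rounding by keeping the top 16 bits of the float32 encoding, and the "wide" path uses Python floats as a stand-in for a float32 workspace.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def partial_softmax(scores, values):
    """Per-split state: running max, sum of exponentials, unnormalized output."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return m, sum(weights), sum(w * v for w, v in zip(weights, values))


def combine_splits(parts):
    """Merge per-split (m, l, acc) triples by rescaling to the global max."""
    m = max(p[0] for p in parts)
    l = sum(math.exp(pm - m) * pl for pm, pl, _ in parts)
    acc = sum(math.exp(pm - m) * pacc for pm, _, pacc in parts)
    return acc / l


scores = [0.0, 1.0, 2.0, 3.0]
values = [1.0, 2.0, 3.0, 4.0]
reference = combine_splits([partial_softmax(scores, values)])

parts = [partial_softmax(scores[:2], values[:2]),
         partial_softmax(scores[2:], values[2:])]
wide = combine_splits(parts)                       # accumulators kept wide
narrow = combine_splits([(m, l, to_bf16(acc)) for m, l, acc in parts])

err_wide = abs(wide - reference)
err_narrow = abs(narrow - reference)
```

In this toy case the wide path reproduces the reference to rounding error, while rounding each partial accumulator to bf16 before the merge leaves a visibly larger error in the final output.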
    str(pathlib.Path(tempfile.gettempdir()) / "vllm_sparse_mla_hip_cache"),
)
os.makedirs(cache_dir, exist_ok=True)
os.environ["PYTORCH_ROCM_ARCH"] = "gfx950"
Modifying os.environ["PYTORCH_ROCM_ARCH"] globally within a library function is problematic as it affects the entire process and may interfere with other JIT compilations. Since the target architecture is already explicitly passed via --offload-arch in extra_cuda_cflags, this global environment modification should be removed.
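The review asks for the assignment to be removed outright. For cases where an environment override genuinely cannot be avoided, a scoped override that restores the previous value is one common alternative. The sketch below is illustrative only: `scoped_env` is a hypothetical helper, `DEMO_ARCH_VAR` is a made-up variable name used so the example does not touch real build settings, and the approach is still process-wide while the block is active, so it is not thread-safe.

```python
import contextlib
import os


@contextlib.contextmanager
def scoped_env(name: str, value: str):
    """Temporarily set an environment variable, then restore the old state."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            # The variable did not exist before, so remove it again.
            os.environ.pop(name, None)
        else:
            os.environ[name] = previous


os.environ.pop("DEMO_ARCH_VAR", None)
with scoped_env("DEMO_ARCH_VAR", "gfx950"):
    inside = os.environ["DEMO_ARCH_VAR"]
leaked = "DEMO_ARCH_VAR" in os.environ
```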
Server command lm eval command: Please use 20-shot gsm8k with a large concurrency of 256; the accuracy must be 0.95. This is to ensure the boundary conditions are implemented correctly.
@tjtanaa I reran accuracy with these settings:
  m.impl("decode_split", &sparse_mla_decode_split);
}
"""
_SPARSE_MLA_DECODE_CU = (
All custom kernels should be compiled at build time, not JIT-compiled.
@tjtanaa the HIP kernel already supports MTP without changes because it operates on individual query tokens. With MTP, a single decode request produces multiple query tokens (e.g., 3 instead of 1). The upstream metadata pipeline already builds per-token SWA indices, per-token topk indices, and per-token ragged indptr arrays. So the kernel simply sees more queries with correctly populated per-query index ranges; it doesn't distinguish whether those queries came from 100 requests with 1 token each or 50 requests with 2 tokens each. To confirm, I ran the tests just as you did in the PR:
Accuracy (with MTP):
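The per-token flattening described in the MTP comment above can be sketched as follows. This is a simplified illustration, not the upstream metadata pipeline: `flatten_requests` and `build_indptr` are hypothetical helpers, and the lengths are made-up numbers chosen only to show that the kernel-facing arrays look the same with or without MTP.

```python
from itertools import accumulate
from typing import List, Sequence


def flatten_requests(requests: Sequence[Sequence[int]]) -> List[int]:
    """Flatten per-request lists of per-token KV lengths into one query list."""
    return [length for request in requests for length in request]


def build_indptr(per_query_lens: Sequence[int]) -> List[int]:
    """Ragged index pointer: query i owns indices[indptr[i]:indptr[i + 1]]."""
    return [0] + list(accumulate(per_query_lens))


# Four requests with one query token each (no MTP) ...
without_mtp = flatten_requests([[4], [4], [6], [6]])
# ... versus two requests with two query tokens each (MTP).
with_mtp = flatten_requests([[4, 4], [6, 6]])

indptr_without = build_indptr(without_mtp)
indptr_with = build_indptr(with_mtp)
```

Both batches flatten to the same per-query lengths and the same `indptr`, which is why a kernel that iterates over queries does not need to know how they were grouped into requests.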
This pull request has merge conflicts that must be resolved before it can be
@tjtanaa could you take another look at this PR to see if the changes look okay, and add the label that kicks off the unit tests, please?
@torch.inference_mode()
def test_hip_decode_main_only_no_sink() -> None:
@heachary For all of the tests, please parametrize the inputs and use the exact DeepSeek params, e.g. block size, num_heads, etc.
Done, the tests now use the exact DSv4 config.
"-Wno-c++11-narrowing")

define_extension_target(
  _rocm_sparse_mla_C
We should compile it into _rocm_C as well. Your code will need to be guarded by GPU ARCH macros.
Okay, I removed the extension and added the source to _rocm_C (for gfx950 only).
@heachary Thanks again for the important optimization. I added the label so that it can be used to check whether compilation still works when compiling for all architectures.
Please rebase, @heachary.
LGTM, but the AMD CI needs attention: it seems the docker build was not triggered and the tests were not run.
@AndreasKaratzas could you take a look? Should we trigger it manually?
I triggered the AMD CI manually to unblock this PR.
Yes, there is currently an infra PR to mitigate this oversight. Sorry for the confusion.
@heachary it seems the compilation has an issue.
@tjtanaa / @AndreasKaratzas: fixed the compilation issue. The tests look to be passing now; I think the PR is ready to be merged.
Summary
This PR replaces the Triton-based rocm_sparse_attn_decode implementation with a hand-written HIP kernel using gfx950 MFMA (mfma_f32_16x16x32_bf16) instructions. On non-gfx950 hardware (e.g. gfx942/MI300), the existing Triton implementation is preserved as a fallback.
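The dispatch described in the summary can be sketched roughly as below. This is not the PR's code: `select_decode_backend` and the returned labels are hypothetical, and the sketch assumes the device reports an architecture string that may carry feature suffixes after a colon (for example `gfx942:sramecc+:xnack-`), which are stripped before comparing.

```python
def select_decode_backend(gcn_arch_name: str) -> str:
    """Pick the sparse-MLA decode path from a GPU architecture string."""
    # Keep only the base architecture, dropping any ":feature" suffixes.
    base_arch = gcn_arch_name.split(":", 1)[0].strip()
    if base_arch == "gfx950":
        return "hip_mfma"   # hand-written HIP kernel path
    return "triton"         # existing Triton implementation as fallback


backend_mi300 = select_decode_backend("gfx942:sramecc+:xnack-")
backend_gfx950 = select_decode_backend("gfx950:sramecc+:xnack-")
```

Keeping the check in one small function makes the fallback behavior easy to unit-test without the target hardware.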
Key changes:
Perf
1k1k, conc4/64
8k1k, conc4/64
The new kernel gives a ~6% improvement in total token throughput across different scenarios.
Accuracy