[ROCm][Deepseekv4] DeepseekV4 Mi300 support #41451
ganyi1996ppo wants to merge 17 commits into vllm-project:main from
Conversation
Signed-off-by: ganyi <ygan@amd.com> Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Documentation preview: https://vllm--41451.org.readthedocs.build/en/41451/

Code Review
This pull request enables support for DeepSeek V4 and associated MLA/MoE kernels on the ROCm platform. It introduces ROCm-compatible HIP kernels, Triton-based fallbacks for sparse attention indexing and FP8 einsum operations, and updates model executors to handle ROCm-specific constraints like FNUZ FP8 formats and the disabling of auxiliary streams. Feedback identifies a critical bug in the DeepSeek-V2 ROCm path where RoPE application results are ignored, and highlights performance bottlenecks in ROCm fallback implementations for MQA logits and MHC Sinkhorn iterations that rely on slow Python loops.
rotary_emb(
    positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1)
)
The return values of rotary_emb are ignored in the ROCm path. In vLLM, RotaryEmbedding is not an in-place operation. Furthermore, k[..., : self.rope_dim].unsqueeze(1) creates a temporary tensor, so even if the operation were in-place, the original k tensor would not be updated. This results in RoPE not being applied to the query and key tensors, which will lead to incorrect model outputs.
q_pe, k_pe = rotary_emb(
positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1)
)
q[..., : self.rope_dim] = q_pe.view_as(q[..., : self.rope_dim])
k[..., : self.rope_dim] = k_pe.view_as(k[..., : self.rope_dim])

if context_lens.dim() > 1:
    context_lens = context_lens.squeeze(-1)
kv_cache_flat = kv_cache.view(-1, block_size * (dim + 4))
for i in range(batch_size):
This fallback implementation uses a Python loop over the batch size and performs F.linear (GEMV) for each sequence. This will be extremely slow for large batches and will significantly degrade performance during the decode phase. While this is a fallback, a vectorized implementation using torch.bmm or torch.matmul should be preferred to avoid the Python loop overhead.
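As a minimal sketch of the suggested vectorization (shapes and names here are illustrative, not taken from the PR): a per-sequence `F.linear` loop over the batch can be collapsed into a single `torch.bmm` call, which computes the same per-sequence GEMVs in one batched kernel launch.

```python
import torch
import torch.nn.functional as F

def per_sequence_logits_loop(q, kv_flat):
    # Slow fallback: one GEMV per sequence via a Python loop.
    # q: [B, D], kv_flat: [B, N, D] -> logits: [B, N]
    batch_size = q.shape[0]
    return torch.stack([F.linear(q[i], kv_flat[i]) for i in range(batch_size)])

def per_sequence_logits_bmm(q, kv_flat):
    # Vectorized: one batched matmul replaces the whole loop.
    # q: [B, D] -> [B, 1, D]; kv_flat: [B, N, D] -> [B, D, N]
    return torch.bmm(q.unsqueeze(1), kv_flat.transpose(1, 2)).squeeze(1)

B, N, D = 4, 16, 8
q = torch.randn(B, D)
kv_flat = torch.randn(B, N, D)
assert torch.allclose(
    per_sequence_logits_loop(q, kv_flat),
    per_sequence_logits_bmm(q, kv_flat),
    atol=1e-5,
)
```

If the per-sequence context lengths differ, the batched version would additionally need a mask over the padded positions, which the loop version sidesteps by slicing each sequence individually.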
for _ in range(sinkhorn_repeat - 1):
    comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
    comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
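For context, the loop above is a Sinkhorn-Knopp iteration: alternately normalizing the last two dimensions drives the matrix toward doubly-stochastic form. Unlike the batch loop flagged earlier, each pass here is already fully vectorized over the tensor; the Python loop only runs for a small, fixed repeat count. A standalone sketch (function name and defaults are illustrative):

```python
import torch

def sinkhorn_normalize(mix: torch.Tensor, repeats: int = 3, eps: float = 1e-6) -> torch.Tensor:
    # Alternately normalize rows (dim=-1) and columns (dim=-2);
    # eps guards against division by zero, matching hc_sinkhorn_eps in the PR.
    for _ in range(repeats):
        mix = mix / (mix.sum(dim=-1, keepdim=True) + eps)
        mix = mix / (mix.sum(dim=-2, keepdim=True) + eps)
    return mix

m = sinkhorn_normalize(torch.rand(4, 4) + 0.1, repeats=10)
# The final step normalizes columns, so column sums are ~1.
assert torch.allclose(m.sum(dim=-2), torch.ones(4), atol=1e-3)
```

Because each iteration depends on the previous one, this fixed-point loop is inherently sequential; fusing the two normalizations into a single kernel would help on ROCm, but the iteration count itself cannot be parallelized away.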
Besides the author, have any other developers successfully run DeepSeek V4 on AMD MI308X with frameworks like vLLM or SGLang? I'd love to ask for some guidance. Big thanks!
Purpose
This PR is based on PRs #41217 and #40871 and will be reformatted after those two PRs are merged.
machine: mi308
test script:
request:
response:
A more thorough test will follow once the preceding PRs are merged.
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.