
[ROCm][Deepseekv4] DeepseekV4 Mi300 support #41451

Draft
ganyi1996ppo wants to merge 17 commits into vllm-project:main from ROCm:ganyi/dsv4_mi300_support

Conversation

Contributor

@ganyi1996ppo ganyi1996ppo commented May 1, 2026

Purpose

This PR is based on PRs #41217 and #40871 and will be reformatted after those two PRs are merged.
Machine: MI308
Test script:

max_num_seqs=16
max_num_batched_tokens=1024
tensor_parallel_size=4
export VLLM_TORCH_PROFILER_DIR="/app/vllm_profile"
export HF_HOME=/data/huggingface-cache
export VLLM_ROCM_USE_AITER=1

MODEL=/mnt/data/pretrained_model/deepseek-ai/DeepSeek-V4-Flash
vllm serve ${MODEL} \
    --host localhost \
    --dtype auto \
    --tensor-parallel-size ${tensor_parallel_size} \
    --max-num-seqs ${max_num_seqs} \
    --trust-remote-code \
    --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile", "torch_profiler_with_stack": "False"}' \
    --gpu-memory-utilization 0.35 \
    --moe-backend "triton_unfused" \
    --tokenizer-mode "deepseek_v4" \
    --async-scheduling \
    --enforce-eager

request:

curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
  "prompt": "Write me a poem about AMD and Deepseek",
  "max_tokens": 100,
  "temperature": 0.0
}'

response:

{"id":"cmpl-b180b64df0a5a360","object":"text_completion","created":1777619440,"model":"/mnt/data/pretrained_model/deepseek-ai/DeepSeek-V4","choices":[{"index":0,"text":"\", \"role\": \"user\" }, { \"content\": \"Here is a poem about AMD and DeepSeek.\\n\\n**The Silicon and the Spark**\\n\\nIn Santa Clara's sunlit halls, where silicon dreams are spun,\\nA titan works on tiny things, beneath the desert sun.\\nThey craft the threads of logic, a digital tapestry,\\nTo weave the future's canvas, for all the world to see.\\n\\nBut far across the ocean, in","logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null,"prompt_logprobs":null,"prompt_token_ids":null}],"service_tier":null,"system_fingerprint":"vllm-0.20.1rc1.dev137+gdde2fb080.d20260501-tp4-795d0827","usage":{"prompt_tokens":9,"total_tokens":109,"completion_tokens":100,"prompt_tokens_details":null},"kv_transfer_params":null}
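For scripted checks against the endpoint, the relevant fields of a `/v1/completions` response body can be extracted with a few lines of Python. This is a minimal sketch using a trimmed-down literal body with the same field layout as the response above:

```python
import json

# Trimmed-down /v1/completions response body (field layout as in the
# response above; the text value is shortened for illustration).
body = json.loads(
    '{"choices": [{"index": 0, "text": "Here is a poem...", '
    '"finish_reason": "length"}], '
    '"usage": {"prompt_tokens": 9, "completion_tokens": 100, "total_tokens": 109}}'
)

# The generated text lives under choices[0].text.
text = body["choices"][0]["text"]
assert body["choices"][0]["finish_reason"] == "length"

# Token accounting: total = prompt + completion.
usage = body["usage"]
assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
```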

A more thorough test will follow once the prerequisite PRs are merged.

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

ganyi1996ppo and others added 17 commits May 1, 2026 02:24
Signed-off-by: ganyi <ygan@amd.com>
Made-with: Cursor
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: whx-sjtu <xiaowang990929@gmail.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
@mergify
Contributor

mergify Bot commented May 1, 2026

Documentation preview: https://vllm--41451.org.readthedocs.build/en/41451/

@mergify mergify Bot added labels: documentation (Improvements or additions to documentation), ci/build, deepseek (Related to DeepSeek models), rocm (Related to AMD ROCm), v1 — May 1, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD May 1, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request enables support for DeepSeek V4 and associated MLA/MoE kernels on the ROCm platform. It introduces ROCm-compatible HIP kernels, Triton-based fallbacks for sparse attention indexing and FP8 einsum operations, and updates model executors to handle ROCm-specific constraints like FNUZ FP8 formats and the disabling of auxiliary streams. Feedback identifies a critical bug in the DeepSeek-V2 ROCm path where RoPE application results are ignored, and highlights performance bottlenecks in ROCm fallback implementations for MQA logits and MHC Sinkhorn iterations that rely on slow Python loops.

Comment on lines +688 to +690
rotary_emb(
positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1)
)
Contributor


critical

The return values of rotary_emb are ignored in the ROCm path. In vLLM, RotaryEmbedding is not an in-place operation. Furthermore, k[..., : self.rope_dim].unsqueeze(1) creates a temporary tensor, so even if the operation were in-place, the original k tensor would not be updated. This results in RoPE not being applied to the query and key tensors, which will lead to incorrect model outputs.

            q_pe, k_pe = rotary_emb(
                positions, q[..., : self.rope_dim], k[..., : self.rope_dim].unsqueeze(1)
            )
            q[..., : self.rope_dim] = q_pe.view_as(q[..., : self.rope_dim])
            k[..., : self.rope_dim] = k_pe.view_as(k[..., : self.rope_dim])
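The distinction the comment relies on can be demonstrated in a few lines of standalone PyTorch (tensor names here are illustrative, not from the PR): an out-of-place op on a slice rebinds the local name without touching the base tensor, while an explicit write-back, as in the suggested fix, does update it.

```python
import torch

# Basic slicing returns a view, but an out-of-place op on that view
# produces a NEW tensor; rebinding the name loses the link to `k`.
k = torch.zeros(2, 8)
k_pe = k[..., :4]        # view into k
k_pe = k_pe + 1.0        # out-of-place: result is a fresh tensor
assert k.sum().item() == 0.0   # original k is unchanged

# Explicitly writing the result back (the pattern in the suggested
# fix above) does update the base tensor.
k[..., :4] = k_pe
assert k[..., :4].sum().item() == 8.0
```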

if context_lens.dim() > 1:
context_lens = context_lens.squeeze(-1)
kv_cache_flat = kv_cache.view(-1, block_size * (dim + 4))
for i in range(batch_size):
Contributor


high

This fallback implementation uses a Python loop over the batch size and performs F.linear (GEMV) for each sequence. This will be extremely slow for large batches and will significantly degrade performance during the decode phase. While this is a fallback, a vectorized implementation using torch.bmm or torch.matmul should be preferred to avoid the Python loop overhead.
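The suggested vectorization can be sketched as follows. Shapes here are hypothetical stand-ins for the per-sequence weight and activation tensors; the point is only that a single `torch.bmm` computes what the per-sequence `F.linear` loop computes, without Python-loop overhead:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: one (dim_out, dim_in) weight matrix and one
# (dim_in,) vector per sequence in the batch.
batch, dim_in, dim_out = 16, 64, 32
w = torch.randn(batch, dim_out, dim_in)
q = torch.randn(batch, dim_in)

# Python-loop fallback: one F.linear (a GEMV) per sequence.
loop_out = torch.stack([F.linear(q[i], w[i]) for i in range(batch)])

# Vectorized equivalent: a single batched matmul over the whole batch.
bmm_out = torch.bmm(w, q.unsqueeze(-1)).squeeze(-1)

assert torch.allclose(loop_out, bmm_out, atol=1e-5)
```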

Comment on lines +257 to +259
for _ in range(sinkhorn_repeat - 1):
comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
Contributor


high

The Sinkhorn iterations are implemented using a Python loop. For models like DeepseekV4 that use MHC, this can become a performance bottleneck. It is recommended to vectorize these operations or use a specialized kernel if performance is critical on ROCm.
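One caveat: the outer Sinkhorn loop is inherently sequential (each iteration normalizes the output of the previous one), so it cannot be vectorized away; each iteration's body, however, is already pure tensor work, which makes the loop a candidate for fusion into a single kernel (or `torch.compile`) rather than elementwise rewriting. A self-contained sketch of the normalization from the quoted snippet, with names following the snippet and shapes assumed:

```python
import torch

# Sinkhorn normalization as in the quoted loop: alternately normalize
# rows (dim=-1) and columns (dim=-2); eps guards against division by zero.
def sinkhorn(comb_mix: torch.Tensor, repeats: int, eps: float = 1e-6) -> torch.Tensor:
    for _ in range(repeats):
        comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + eps)
        comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + eps)
    return comb_mix

# For a strictly positive matrix, the iterations converge toward a
# doubly stochastic matrix (rows and columns each sum to ~1).
m = torch.rand(4, 4) + 0.1
out = sinkhorn(m, repeats=20)
assert torch.allclose(out.sum(dim=-1), torch.ones(4), atol=1e-2)
assert torch.allclose(out.sum(dim=-2), torch.ones(4), atol=1e-2)
```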

@Alan-D-Chen

Besides the author, have any other developers successfully run DeepSeek V4 on an AMD MI308X using frameworks like vLLM or SGLang? I'd love to ask for some guidance. BIG THANKS.


Labels

ci/build · deepseek (Related to DeepSeek models) · documentation (Improvements or additions to documentation) · rocm (Related to AMD ROCm) · v1

Projects

Status: Todo

Development

Successfully merging this pull request may close these issues.

5 participants