
[New Model] Support DeepseekV4 #40760

Closed
zyongye wants to merge 10 commits into vllm-project:main from zyongye:dsv4

Conversation

@zyongye
Member

@zyongye zyongye commented Apr 24, 2026

Congratulations to Deepseek-ai on releasing the model. Thanks to all Inferact members for their effort in supporting this.

Note: This model implementation is highly optimized. All of the components are tightly coupled, and there are many manually fused kernels. Please consult @WoosukKwon @zyongye @ivanium before making any changes.

Please see https://recipes.vllm.ai/deepseek-ai/DeepSeek-V4-Pro for recipes

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: yasong.wang <yasong.wang@inferact.ai>
Signed-off-by: Zhewen Li <zhewenli@inferact.ai>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
@Windswithyou

Any cookbook? And how can I run it on Hopper?

Please see https://recipes.vllm.ai/deepseek-ai/DeepSeek-V4-Pro for recipes

I am trying to serve deepseek-ai/DeepSeek-V4-Pro using vLLM across 2 nodes, each equipped with 8x H100 80GB GPUs (16 GPUs in total). I am following the instructions from the vLLM recipes page (https://recipes.vllm.ai/deepseek-ai/DeepSeek-V4-Pro).
However, when initializing the vLLM engine, it consistently crashes with an out-of-memory (OOM) error. Even after setting max-model-len to 32k and adjusting max-num-seqs, gpu-memory-utilization, etc., nothing worked.
Do you have any best practices for a 2-node H100 setup?
The 4-node setup works fine.
Just to add, using speculative_config causes issues as well.

@Yang1032 Hi, may I ask how you managed to run V4-Pro on 4 nodes? I still hit CUDA OOM on 4 nodes with the default DP size = number of cards setup. Thank you very much!

@wxsms For H100 80G and V4-Pro, 2-node TP4/DP4 may work, or increase TP; 4-node DP32 works well. One other note: streaming function calling has some problems.

You can follow my fork; just edit v32_parser, and then you can use streaming function calls.
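
For anyone debugging the OOM reports above, here is a minimal offline sketch of how the memory-related knobs map onto vLLM's Python API. The model name comes from the recipe link; the specific values are illustrative only, not a verified 2-node configuration, and data parallelism is normally set via --data-parallel-size when launching vllm serve:

    from vllm import LLM, SamplingParams

    # Illustrative values only; lowering max_model_len, max_num_seqs, and
    # gpu_memory_utilization are the usual levers when the engine OOMs
    # during initialization. Not a verified DeepSeek-V4-Pro recipe.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V4-Pro",
        tensor_parallel_size=4,
        max_model_len=32768,
        max_num_seqs=64,
        gpu_memory_utilization=0.90,
    )
    out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)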

junpuf added a commit to aws/deep-learning-containers that referenced this pull request Apr 24, 2026
Pin vLLM source to zyongye/vllm@bc34b25e (dsv4 branch) from
vllm-project/vllm#40760 which adds
[New Model] Support DeepseekV4.

Changes:
- Add docker/vllm/versions.env with custom VLLM_REPO/VLLM_REF
- Update image configs to point to the custom commit
- Add EXTRA_BUILD_ARGS forwarding in build_image.sh
- Add SETUPTOOLS_SCM_PRETEND_VERSION build-arg in Dockerfile
- Update workflows to source versions.env and include vllm-ref-short in tags
Comment on lines +15 to +16
tool_call_start_token: str = "<|DSML|tool_calls>"
tool_call_end_token: str = "</|DSML|tool_calls>"
Collaborator


These class attributes would not take effect; you can do something like this instead:

  def __init__(self, tokenizer, tools=None):
      super().__init__(tokenizer, tools)
      self.tool_call_start_token = "<|DSML|tool_calls>"
      self.tool_call_end_token = "</|DSML|tool_calls>"
      # Escape the delimiters: "|" is a regex metacharacter, so the raw
      # tokens cannot be dropped into the pattern unescaped.
      self.tool_call_complete_regex = re.compile(
          re.escape(self.tool_call_start_token)
          + r"(.*?)"
          + re.escape(self.tool_call_end_token),
          re.DOTALL,
      )
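
A quick standalone check, outside the parser, of why the delimiters need escaping (hypothetical example text): with an unescaped pattern, "|" acts as regex alternation rather than a literal character.

    import re

    start, end = "<|DSML|tool_calls>", "</|DSML|tool_calls>"
    # Escape the delimiters so "|" is matched literally instead of acting as alternation.
    pattern = re.compile(re.escape(start) + r"(.*?)" + re.escape(end), re.DOTALL)

    text = '<|DSML|tool_calls>[{"name": "get_weather"}]</|DSML|tool_calls>'
    print(pattern.search(text).group(1))  # [{"name": "get_weather"}]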

hash_indices_table: torch.Tensor | None = None,
routed_scaling_factor: float = 1.0,
) -> tuple[torch.Tensor, ...]:
ops.topk_hash_softplus_sqrt(
Member


When using DeepEP, this crashes with "expected scalar type Long but found Int"

The CUDA kernel in topk_softplus_sqrt_kernels.cu dispatches input_tokens and hash_indices_table data_ptr based on topk_indices.scalar_type(). DeepEP sets topk_indices_dtype to int64, but input_tokens and hash_indices_table are int32.

We can detect and handle this case:

Suggested change

    - ops.topk_hash_softplus_sqrt(
    + idx_dtype = topk_indices.dtype
    + if input_tokens is not None and input_tokens.dtype != idx_dtype:
    +     input_tokens = input_tokens.to(idx_dtype)
    + if hash_indices_table is not None and hash_indices_table.dtype != idx_dtype:
    +     hash_indices_table = hash_indices_table.to(idx_dtype)
    + ops.topk_hash_softplus_sqrt(
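
For context, a tiny standalone illustration of the dtype mismatch and the cast (shapes and values are made up, not the real kernel inputs):

    import torch

    # Standalone illustration only. DeepEP hands back int64 topk indices,
    # while the token/hash tables are built as int32.
    topk_indices = torch.zeros(8, 4, dtype=torch.int64)
    hash_indices_table = torch.arange(16, dtype=torch.int32)

    # Without this cast, a kernel dispatching on topk_indices.scalar_type()
    # would hit "expected scalar type Long but found Int".
    if hash_indices_table.dtype != topk_indices.dtype:
        hash_indices_table = hash_indices_table.to(topk_indices.dtype)
    assert hash_indices_table.dtype == torch.int64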

Member Author

@zyongye zyongye Apr 25, 2026


Is that a DeepEP-specific constraint? I think all the other all-to-all implementations assume topk_ids to be int32. Can we change the payload on the DeepEP side instead? (BTW, v2 just came out; I don't know whether it has this capability.)

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.deepseek_v4_attention import (
Collaborator

@tjtanaa tjtanaa Apr 25, 2026


The path seems to have been changed to:

from vllm.v1.attention.ops.deepseek_v4_ops import (
    quantize_and_insert_k_cache,
)

Member Author


oh great catch!

@ivanium ivanium mentioned this pull request Apr 25, 2026
@ChuanLi1101 ChuanLi1101 self-assigned this Apr 25, 2026
}
// Compute per-thread scale (using warp reduction when renormalizing).
if (renormalize) {
selected_sum = warpReduceSum(selected_sum);
Collaborator

@tjtanaa tjtanaa Apr 25, 2026


cuda_compat.h has a helper, VLLM_SHFL_XOR_SYNC_WIDTH, which can be used to handle both CUDA and ROCm differences.

How about writing it this way:

#pragma unroll
      for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
        selected_sum +=
            VLLM_SHFL_XOR_SYNC_WIDTH(selected_sum, mask, THREADS_PER_ROW);
      }

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@mergify mergify Bot mentioned this pull request Apr 25, 2026
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@zyongye
Member Author

zyongye commented Apr 27, 2026

#40860

@ivanium
Contributor

ivanium commented Apr 27, 2026

When decode SWA token usage is full, a 'get_cpu_copy NotImplementedError' is raised:

Scheduler hit an exception: Traceback (most recent call last):
   File "/workspace/sglang/python/sglang/srt/managers/scheduler.py", line 3041, in run_scheduler_process
     scheduler.event_loop_overlap_disagg_decode()
   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
     return func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/sglang/python/sglang/srt/disaggregation/decode.py", line 915, in event_loop_overlap_disagg_decode
     batch = self.get_next_disagg_decode_batch_to_run()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/sglang/python/sglang/srt/disaggregation/decode.py", line 977, in get_next_disagg_decode_batch_to_run
     self.running_batch = self.update_running_batch(self.running_batch)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/sglang/python/sglang/srt/managers/scheduler.py", line 2236, in update_running_batch
     retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
                                                      ^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/sglang/python/sglang/srt/managers/schedule_batch.py", line 1897, in retract_decode
     self.release_req(idx, len(sorted_indices), server_args)
   File "/workspace/sglang/python/sglang/srt/managers/schedule_batch.py", line 1930, in release_req
     req.offload_kv_cache(
   File "/workspace/sglang/python/sglang/srt/managers/schedule_batch.py", line 1117, in offload_kv_cache
     self.kv_cache_cpu = token_to_kv_pool_allocator.get_cpu_copy(token_indices)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/sglang/python/sglang/srt/mem_cache/swa_memory_pool.py", line 642, in get_cpu_copy
     return self._kvcache.get_cpu_copy(indices)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   File "/workspace/sglang/python/sglang/srt/mem_cache/memory_pool.py", line 671, in get_cpu_copy
     raise NotImplementedError()
 NotImplementedError

Thanks for the report. From the stack trace, this appears to be coming from SGLang’s path rather than vLLM. We don’t expect this issue to occur with vLLM, so I’d recommend trying the same scenario with vLLM and letting us know if it reproduces there.
Support is included in our v0.20.0 release, and recipes can be found at https://recipes.vllm.ai/deepseek-ai/DeepSeek-V4-Pro

@nskpro-cmd

nskpro-cmd commented Apr 30, 2026

Any cookbook? And how can I run it on Hopper?

Please see https://recipes.vllm.ai/deepseek-ai/DeepSeek-V4-Pro for recipes

I am trying to serve deepseek-ai/DeepSeek-V4-Pro using vLLM across 2 nodes, each equipped with 8x H100 80GB GPUs (16 GPUs in total). I am following the instructions from the vLLM recipes page (https://recipes.vllm.ai/deepseek-ai/DeepSeek-V4-Pro).
However, when initializing the vLLM engine, it consistently crashes with an out-of-memory (OOM) error. Even after setting max-model-len to 32k and adjusting max-num-seqs, gpu-memory-utilization, etc., nothing worked.
Do you have any best practices for a 2-node H100 setup?
The 4-node setup works fine.
Just to add, using speculative_config causes issues as well.

@Yang1032 Hi, may I ask how you managed to run V4-Pro on 4 nodes? I still hit CUDA OOM on 4 nodes with the default DP size = number of cards setup. Thank you very much!

@wxsms For H100 80G and V4-Pro, 2-node TP4/DP4 may work, or increase TP; 4-node DP32 works well. One other note: streaming function calling has some problems.

Hi, I initially tried TP=8, DP=2, PP=1 and got the same out-of-memory errors. Then, as you mentioned, I switched to TP=4, DP=4, PP=1.
Still the same issue. How can we run on 2 H100 nodes without encountering these errors?

@wxsms, did you resolve the issue? If yes, please share the config. Thank you.


Labels

ci/build, deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), gpt-oss (Related to GPT-OSS models), kv-connector, needs-rebase, new-model (Requests to new models), nvidia, performance (Performance-related issues), speculative-decoding, tool-calling, v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.