
[DeepSeek] Separate deepseek v3.2 modeling from deepseek v2 #3531

Merged
wangxiyuan merged 1 commit into vllm-project:main from MengqingCao:slipt_ds
Oct 20, 2025
Conversation

@MengqingCao
Collaborator

@MengqingCao MengqingCao commented Oct 18, 2025

What this PR does / why we need it?

Separate deepseek v3.2 modeling from deepseek v2

How was this patch tested?

Signed-off-by: MengqingCao <cmq0113@163.com>
@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the DeepSeek model implementations by separating the v3.2 logic into its own file, deepseek_v3_2.py, from the v2 implementation. This is a good architectural improvement for clarity and maintenance. My review focuses on the new deepseek_v3_2.py file. I've identified a critical issue with monkey-patching a base vLLM class, which can cause unpredictable behavior and should be refactored. Additionally, there's an instance of code duplication that should be addressed to improve maintainability.

pass


DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__
Contributor


critical

Monkey-patching DeepseekV2DecoderLayer.__init__ is a dangerous practice. It modifies a base class from vllm globally at import time, which can lead to unexpected side effects and bugs that are difficult to trace. This is especially problematic as vllm_ascend/models/deepseek_v2.py performs a similar patch, making behavior dependent on import order.

Please refactor to avoid monkey-patching. You can achieve this by:

  1. Removing this line.
  2. Using CustomDeepseekV2DecoderLayer directly in AscendDeepseekV2Model's make_layers call (lines 102-106).

Here is how you can modify AscendDeepseekV2Model:

# vllm_ascend/models/deepseek_v3_2.py:102
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: CustomDeepseekV2DecoderLayer(vllm_config, prefix,
                                                  topk_indices_buffer),
            prefix=f"{prefix}.layers")
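The monkey-patching hazard flagged above can be illustrated with a minimal, self-contained sketch (the class and function names here are hypothetical stand-ins, not vllm code): once any module assigns over a base class's __init__, every consumer of that base class anywhere in the process sees the patched behavior, and when two modules patch the same class, the result depends on import order.

```python
# Hypothetical stand-in for a library base class like DeepseekV2DecoderLayer.
class BaseLayer:
    def __init__(self):
        self.flavor = "base"

# A plugin module that monkey-patches the base class at import time.
def custom_init(self):
    self.flavor = "custom"

BaseLayer.__init__ = custom_init  # global, process-wide mutation

# Any unrelated code constructing BaseLayer now gets the patched behavior,
# even though it never opted in.
layer = BaseLayer()
print(layer.flavor)  # custom
```

Subclassing, or passing the custom class directly to make_layers as suggested above, keeps the override local to the call site instead of mutating the shared base class.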

Comment on lines +117 to +213
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        # Divide the weight matrix along the first dimension.
        self.tp_rank = (get_tensor_model_parallel_rank()
                        if not disable_tp else 0)
        self.tp_size = (get_tensor_model_parallel_world_size()
                        if not disable_tp else 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        AscendLinearBase.__init__(self,
                                  input_size,
                                  output_size,
                                  skip_bias_add,
                                  params_dtype,
                                  quant_config,
                                  prefix,
                                  return_bias=return_bias,
                                  disable_tp=disable_tp)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        if bias:
            self.bias = nn.Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        self.update_param_tp_status()

    def forward(
        self,
        input_,
        is_prefill=True,
        is_force_scatter=False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias
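The forward pass above is the standard row-parallel linear pattern: each rank holds a slice of the input's last dimension and the matching rows of the weight matrix, computes a partial matmul, and the all-reduce sums the partials into the full product. A pure-Python sketch of that identity (toy sizes, tp_size simulated by a loop, no torch):

```python
# Toy row-parallel linear: Y = X @ W computed as a sum of rank-local shards.
# X: 1 x 4 input, W: 4 x 2 weight, tp_size = 2 (each rank owns 2 input rows).

def matmul(a, b):
    """Naive matrix multiply for small lists-of-lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def add(a, b):
    """Elementwise sum of two same-shaped matrices (the 'all-reduce')."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

X = [[1.0, 2.0, 3.0, 4.0]]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
tp_size = 2
shard = len(X[0]) // tp_size  # input_size_per_partition

# Each "rank" sees a slice of the input and the matching rows of W.
partials = []
for rank in range(tp_size):
    x_shard = [row[rank * shard:(rank + 1) * shard] for row in X]
    w_shard = W[rank * shard:(rank + 1) * shard]
    partials.append(matmul(x_shard, w_shard))

# The tensor_model_parallel_all_reduce step is just a sum of the partials.
output = partials[0]
for p in partials[1:]:
    output = add(output, p)

assert output == matmul(X, W)  # row-parallel result matches the full matmul
print(output)  # [[12.0, 1.0]]
```

This identity is also why the bias is fused into the GEMM only on rank 0: adding it on every rank before the all-reduce would count it tp_size times.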
Contributor


high

The CustomDeepseekV2RowParallelLinear class appears to be a reimplementation of AscendRowParallelLinear from vllm_ascend.ops.linear. This code duplication increases maintenance overhead and can lead to inconsistencies if one implementation is updated but the other is not.

To improve maintainability, please remove this duplicated class and use the existing AscendRowParallelLinear from vllm_ascend.ops.linear in CustomDeepseekV2SFAAttention (line 309).

@MengqingCao MengqingCao added the ready and ready-for-test labels Oct 18, 2025
@MengqingCao
Collaborator Author

@wangxiyuan @whx-sjtu plz take a look, thanks!

  ModelRegistry.register_model(
      "DeepseekV32ForCausalLM",
-     "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
+     "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
Collaborator


deepseek_v2 is not supported now?

Collaborator Author


Collaborator

@whx-sjtu whx-sjtu left a comment


LGTM. This is the pre-PR needed by PR #3189.

@wangxiyuan wangxiyuan merged commit daa4dd0 into vllm-project:main Oct 20, 2025
38 checks passed
wxsIcey pushed a commit to wxsIcey/vllm-ascend that referenced this pull request Oct 20, 2025
…ject#3531)

Separate deepseek v3.2 modeling from deepseek v2

- CI passed with existing test.
- test deepseek v3.2 locally

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
ZYang6263 pushed a commit to rjg-lyh/vllm-ascend that referenced this pull request Oct 23, 2025
luolun pushed a commit to luolun/vllm-ascend that referenced this pull request Nov 19, 2025
hwhaokun pushed a commit to hwhaokun/vllm-ascend that referenced this pull request Nov 19, 2025
NSDie pushed a commit to NSDie/vllm-ascend that referenced this pull request Nov 24, 2025
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025

Labels

ready, ready-for-test

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants