Skip to content

[Feature]Supports DSv3.1 PD separation and C8 quantization#7222

Merged
zzzzwwjj merged 16 commits intovllm-project:mainfrom
pichangping:fa_0313
Mar 16, 2026
Merged

[Feature]Supports DSv3.1 PD separation and C8 quantization#7222
zzzzwwjj merged 16 commits intovllm-project:mainfrom
pichangping:fa_0313

Conversation

@pichangping
Copy link
Copy Markdown
Contributor

@pichangping pichangping commented Mar 13, 2026

Co-authored-by: kunpengW-code 1289706727@qq.com
Co-authored-by: linsheng1 1950916997@qq.com

What this PR does / why we need it?

Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8 supports only the PD separation scenario. C8 refers to quantizing the KV cache to int8, which aims to reduce the GPU memory usage of the KV cache and improve the inference throughput.
Constraints:

  1. Only the PD separation mode can be used and MooncakeLayerwiseConnector can be used to run the model.
  2. Currently, only the activation value supports dynamic quantization, and the KV cache supports static quantization. C8 quantization with MTP is not supported. You can use ModelSlim for quantization. The quantization procedure is as follows:
    pip install transformers==4.48.2
    git clone https://gitcode.com/Ascend/msmodelslim.git
    cd msmodelslim
    bash install.sh
    cd example/DeepSeek/
    python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path <path/quant_weight> --anti_dataset../common/deepseek_anti_prompt_50_v3_1.json --calib_dataset../common/deepseek_calib_prompt_50_v3_1.json --rot --trust_remote_code True --fa_quant --dynamic --anti_method m6

Does this PR introduce any user-facing change?

no

How was this patch tested?

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates DSv3.1 C8 quantization, a significant enhancement for running large language models on Ascend NPUs. The changes span across the attention mechanism, distributed KV cache management, and model loading infrastructure, ensuring that quantized models can leverage specialized hardware operations for improved efficiency and reduced memory footprint without compromising accuracy.

Highlights

  • DSv3.1 C8 Quantization Support: Introduced comprehensive support for DSv3.1 C8 quantization, enabling efficient execution of quantized models on Ascend NPUs.
  • Quantized Attention Mechanism: Implemented a new fused attention score function (npu_fused_infer_attention_score_v2) specifically designed to handle quantized query and key inputs, improving performance for quantized attention layers.
  • Distributed KV Cache Quantization: Enhanced the MooncakeLayerwiseConnector to manage and transfer quantized KV cache data, including new buffers and logic for handling quantized key and value tensors in a distributed environment.
  • Dynamic Weight Remapping for DeepSeek: Added a dynamic patching mechanism to correctly remap weight scale and offset names for DeepSeek models during quantization, ensuring proper loading of quantized parameters.
  • Quantization Configuration and Metadata: Updated the AscendModelSlimConfig to automatically detect and load FA quantization metadata from quant_model_description.json, and introduced AscendFAQuantAttentionMethod to manage quantization-specific parameters.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • vllm_ascend/attention/mla_v1.py
    • Imported enabling_fa_quant and updated the import path for AscendW8A8LinearMethod.
    • Added dequant_scale_q_nope to DecodeMLAPreprocessResult.
    • Introduced layer_name, fa_quant_layer, and dtype attributes to the MLAPreprocessor class.
    • Modified update_graph_params to include dequant_scale_q_nope and fak_descale_float.
    • Updated update_graph_params to conditionally use npu_fused_infer_attention_score_v2 for quantized attention.
    • Added a conditional call to _process_weights_for_fused_fa_quant in process_weights_after_loading.
    • Implemented _process_weights_for_fused_fa_quant to handle weights and scales for FA quantization.
    • Modified _forward_decode to support fa_quant_layer for KV cache reshaping and pass dequantization scales.
    • Updated _mla_preprocess_only_decode to use npu_mla_prolog_v2 for quantized preprocessing.
    • Extended the forward method to enable MLA Preprocess for FA quantized layers.
  • vllm_ascend/attention/utils.py
    • Added the enabling_fa_quant utility function to determine if FA quantization is active for a given layer.
  • vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py
    • Added k_quant_cache and v_quant_cache fields to the SendTask NamedTuple.
    • Included enable_kv_quant, k_quant_buffer, and v_quant_buffer in KVTransferSender initialization.
    • Modified get_transfer_meta to support quantized KV cache data transfer.
    • Added logic in _transfer_kv_cache to copy quantized KV cache data to dedicated buffers.
    • Ensured resharding_stream synchronization for quantized KV cache transfers.
    • Updated the type hint for kv_layer in save_kv_layer to accept a list of tensors.
    • Adjusted __init__ to enable resharding_stream if KV quantization is enabled.
    • Modified create_kv_buffer to handle tuples of KV caches and allocate buffers for quantized data.
    • Updated register_kv_caches to use the new KV buffer mechanism and pass quantization parameters.
    • Changed start_load_kv to correctly determine the device for buffers and update send_task for quantized layers.
    • Extended save_kv_layer to quantize and transfer KV cache data when FA quantization is active.
    • Added the trans_nd_to_nz method for transforming tensor formats.
  • vllm_ascend/ops/mla.py
    • Passed the layer_name argument to the MLAPreprocessor constructor.
  • vllm_ascend/patch/worker/init.py
    • Imported the new patch_weight_utils module.
  • vllm_ascend/patch/worker/patch_weight_utils.py
    • Added a new file to implement dynamic patching for weight utility functions.
    • Defined ImportPatchDecorator to register and apply module patches.
    • Implemented patch_deepseek to remap specific scale and offset names for DeepSeek models.
    • Implemented patch_weight_utils to apply the DeepSeek remapping during module import.
    • Overrode the built-in __import__ function to apply patches dynamically.
  • vllm_ascend/quantization/methods/init.py
    • Imported AscendFAQuantAttentionMethod.
    • Added AscendFAQuantAttentionMethod to the list of exported quantization methods.
  • vllm_ascend/quantization/methods/kv_c8.py
    • Added a new file defining AscendFAQuantAttentionMethod for FAKQuant attention.
    • Implemented weight_loader for fa_q weights with tensor parallelism support.
    • The AscendFAQuantAttentionMethod creates fa_q, fa_k, fa_v modules and their scale/offset parameters.
    • The process_weights_after_loading method calculates various dequantization scales and quant_kscale.
  • vllm_ascend/quantization/modelslim_config.py
    • Imported necessary modules: glob, json, os, re, and AttentionLayerBase.
    • Defined MODELSLIM_CONFIG_FILENAME for the quantization configuration file.
    • Modified AscendModelSlimConfig.__init__ to accept an optional quant_config and apply extra adaptations.
    • Changed get_config_filenames to return an empty list, deferring config loading.
    • Updated get_quant_method to use is_fa_quant_layer for AttentionLayerBase instances.
    • Added is_fa_quant_layer to check if a layer is configured for FA quantization.
    • Implemented maybe_update_config to load the ModelSlim config from the model directory and provide error handling.
    • Added _apply_extra_quant_adaptations to handle specific key transformations in the quantization description.
    • Implemented _add_kvcache_quant_metadata to extract and store FA quantization layer information.
  • vllm_ascend/worker/model_runner_v1.py
    • Initialized self.kvbytes dictionary in the __init__ method.
    • Modified _allocate_kv_cache_tensors to dynamically adjust head_size and split factors based on self.kvbytes for quantized layers.
    • Updated _reshape_kv_cache_tensors to use the correct dtype for v_cache in quantized scenarios.
    • Added a dtype_to_bytes helper function.
    • Enhanced get_kv_cache_spec to handle fa_quant_layer and populate self.kvbytes with byte sizes for K/V tensors.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

Signed-off-by: pichangping <1337510399@qq.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for DSv3.1 C8 quantization, specifically for attention layers and KV cache. The changes are comprehensive, integrating the new quantization method across various components including attention implementation, KV transfer mechanisms, weight loading utilities, and quantization configuration. The implementation includes new data structures, conditional logic for new NPU operators, and specific weight processing for quantized tensors. The approach seems well-structured, ensuring that the new quantization is applied dynamically and correctly without impacting existing functionalities. The code adheres to the specified requirements for enabling and managing this new feature.

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

@pichangping pichangping changed the title Supports DSv3.1 C8 quantization [Feature]Supports DSv3.1 PD separation and C8 quantization Mar 13, 2026
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@kunpengW-code
Copy link
Copy Markdown
Contributor

/gemini review

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for C8 quantization and prefill/decode separation, which involves extensive changes across attention mechanisms, KV cache transfer, and quantization configurations. The core logic for quantized attention is implemented in vllm_ascend/attention/mla_v1.py, utilizing a new version of the fused attention kernel. A new quantization method, AscendFAQuantAttentionMethod, is also added. While the changes are comprehensive, I've identified a couple of critical issues. One is a potential ValueError due to an unsafe string-to-integer conversion. The other is a logic error where a decode-specific quantization scale is incorrectly passed to the prefill function, while the decode function call is missing this necessary argument. Addressing these issues is crucial for the correctness of the new quantization feature.

Comment thread vllm_ascend/attention/mla_v1.py Outdated
prefill_preprocess_res.value,
kv_cache,
attn_metadata,
decode_preprocess_res.dequant_scale_q_nope,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There seems to be a mis-wiring of the dequant_scale_q_nope argument. This scale is generated during decode preprocessing and should be used in the decode path, not the prefill path.

  1. This argument should be removed from the _forward_prefill call.
  2. It should be passed to the _forward_decode call around line 1655, which is currently missing it. The _forward_decode function signature has been updated to accept it, but the call site has not been updated.

Comment thread vllm_ascend/attention/utils.py Outdated
Comment on lines +346 to +348
id = "".join(re.findall(r"\.(\d+)\.", layer_name))
if int(id) in quant_config.kvcache_quant_layers:
fa_quant_layer = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

If re.findall returns an empty list, id will be an empty string, and int(id) will raise a ValueError. This can happen if layer_name does not contain a number between dots. It's safer to check if a valid ID was found before converting to an integer. Also, using id as a variable name shadows the built-in function.

Suggested change
id = "".join(re.findall(r"\.(\d+)\.", layer_name))
if int(id) in quant_config.kvcache_quant_layers:
fa_quant_layer = True
layer_id_str = "".join(re.findall(r"\.(\d+)\.", layer_name))
if layer_id_str.isdigit():
if int(layer_id_str) in quant_config.kvcache_quant_layers:
fa_quant_layer = True

)

# Load cache data into buffers
torch_npu.atb.npu_paged_cache_load(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The atb operator may report an error in A5.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The atb operator needs to be migrated to the vllm-ascend custom ops and adapted to A5.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to do this in the future #7246

Comment thread vllm_ascend/attention/mla_v1.py Outdated
actual_seq_qlen=actual_seq_lengths,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May can replace v1 with v2.

Comment thread vllm_ascend/worker/model_runner_v1.py Outdated
# ordering expected by graph parameter update logic in attention backends.
mamba_layers: dict[str, MambaBase] = {}

def dtype_to_bytes(dtype: torch.dtype) -> int:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz move this function to utils

self.kv_send_layer_thread.send_queue.put(layer_send_task)
self.current_layer += 1

def trans_nd_to_nz(self, cache_tensor: torch.Tensor, layer_group_idx: int):
Copy link
Copy Markdown
Collaborator

@MengqingCao MengqingCao Mar 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz add instruction on this function, and maybe this should also in utils?

)

# Load cache data into buffers
torch_npu.atb.npu_paged_cache_load(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The atb operator needs to be migrated to the vllm-ascend custom ops and adapted to A5.

)
buffer_list.append(self.k_buffer)
buffer_list.append(self.v_buffer)
if self.enable_kv_quant:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the future GQA model will also use C8 quantization, the buffer can be allocated based on the input dtype without the need to write additional branches.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a TODO, refactor it in the next pr

for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids
# kv cache quantization scenario
if self.enable_kv_quant and send_task.k_quant_cache is not None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch behaves similarly to branches where pd_head_ratio > 1. Can these be simplified into the same code segment?

Comment thread vllm_ascend/worker/model_runner_v1.py Outdated
@@ -2678,80 +2666,73 @@ def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str
# For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim
# and rope head dim.
if self.model_config.use_mla:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MengqingCao I think we should refactor the initialize_kvcache_tensors function to reduce unnecessary branching.

Comment thread vllm_ascend/worker/model_runner_v1.py Outdated
self.query_lens: torch.Tensor | None = None
self.cpu_slot_mapping = None
self.sampling_done_event: torch.npu.Event | None = None
self.kvbytes = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removable optional parameters

…to fa_0313

# Conflicts:
#	vllm_ascend/quantization/modelslim_config.py
#	vllm_ascend/worker/model_runner_v1.py
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
@chenxi-hh chenxi-hh added ready read for review ready-for-test start test by label for PR labels Mar 14, 2026
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

…to fa_0313

# Conflicts:
#	vllm_ascend/patch/__init__.py
#	vllm_ascend/worker/model_runner_v1.py
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

1 similar comment
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

…to fa_0313

# Conflicts:
#	vllm_ascend/worker/model_runner_v1.py
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
setattr(layer, name, torch.nn.Module())
params_dict = {}
dtype = torch.get_default_dtype()
layer.num_kv_heads = 1
Copy link
Copy Markdown
Contributor

@kunpengW-code kunpengW-code Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The layer.num_kv_heads parameter does not need to be assigned

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
@zzzzwwjj zzzzwwjj merged commit 3f39ac9 into vllm-project:main Mar 16, 2026
38 checks passed
Nagisa125 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Mar 17, 2026
…ect#7222)

Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints: 
1. Only the PD separation mode can be used and
MooncakeLayerwiseConnector can be used to run the model.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@4034c3d

---------

Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
ichaoren pushed a commit to ichaoren/vllm-ascend that referenced this pull request Mar 17, 2026
…ect#7222)

Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints:
1. Only the PD separation mode can be used and
MooncakeLayerwiseConnector can be used to run the model.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@4034c3d

---------

Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: xutianyi <xutianyi5@huawei.com>
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Mar 18, 2026
…scend into qwen3next_graph

* 'qwen3next_graph' of https://github.com/845473182/vllm-ascend: (62 commits)
  [doc] Refresh the documentation for DeepSeek-V3.2 (vllm-project#7403)
  [bugfix][accuracy] Fix ds indexer accuracy problem caused by k rope (vllm-project#7341)
  [P/D] LayerwiseConnector supports the virtual push functionality on node D. (vllm-project#7361)
  [CI] Add PAT_TOKEN when checkout (vllm-project#7400)
  [main2main] upgrade vllm to 0308 (vllm-project#7213)
  [CI] add scheduled stale issue management (vllm-project#7354)
  [CI] expand issue labeler rules for feature/model triage (vllm-project#7356)
  [Bugfix] Assertion error when decode prefix cache fully hits (vllm-project#7236)
  [doc] Refresh the documentation for GLM-4.7 (vllm-project#7292)
  [BugFix]A2 MOE method&& layerwise MTP bugfix && Mamba gdn_metadata bugfix (vllm-project#7364)
  [doc] Upload doc for qwen3.5-27B and qwen3.5-397B-A17B on Ascend (vllm-project#7313)
  [bugfix]Enable dispatch_ffn_combine feature for qwen3.5 (vllm-project#7066)
  [bugfix] fix unzip file path for fia operator (vllm-project#7367)
  [Perf] Optimize bias handling in AscendRMSNorm (vllm-project#7226)
  [eagle3][pcp] fix bug for eagle3 and cp enable (vllm-project#7309)
  [Bugfix] fix TransposeKvCacheByBlock op error report in plog (vllm-project#7235)
  [Feature]Supports DSv3.1 PD separation and C8 quantization (vllm-project#7222)
  [main][bugfix] Fixed the problem that eagle3 will crash in FULL_DECODE_ONLY (vllm-project#7290)
  [xlite][Bugfix] Support mrope and deepstack features in xlite backend (vllm-project#7295)
  [model_runner_v2]optimize the performance of the _topk_log_softmax_kernel (vllm-project#7221)
  ...
MengqingCao added a commit to yiz-liu/vllm-ascend that referenced this pull request Mar 26, 2026
…notes

- Refine Balance scheduling description (line 9)
- Relocate Flash Comm V1 from Highlights to Features (line 10)
- Add C8 INT8 KV (vllm-project#7474) and DSv3.1 C8 quant (vllm-project#7222) to Highlights (line 12)
- Remove LayerwiseConnector entry from Features (line 18)
- Remove Dependencies section / vLLM upgrade entry (line 27)
- Remove enable_sparse_c8 doc entry (line 31)
- Remove lowered PD log level entry from Others (line 36)
- Remove speculative decoding proposer fix entry (line 40)

Signed-off-by: MengqingCao <cmq0113@163.com>
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Apr 1, 2026
…ect#7222)

Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints: 
1. Only the PD separation mode can be used and
MooncakeLayerwiseConnector can be used to run the model.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
vllm-project/vllm@4034c3d

---------

Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module:ops module:quantization ready read for review ready-for-test start test by label for PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants