diff --git a/.github/workflows/_accuracy_test.yaml b/.github/workflows/_e2e_nightly_single_node_models.yaml
similarity index 62%
rename from .github/workflows/_accuracy_test.yaml
rename to .github/workflows/_e2e_nightly_single_node_models.yaml
index b9d155f231..b7d55945b0 100644
--- a/.github/workflows/_accuracy_test.yaml
+++ b/.github/workflows/_e2e_nightly_single_node_models.yaml
@@ -1,4 +1,21 @@
-name: 'accuracy test'
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+name: 'e2e nightly models test'
on:
workflow_call:
@@ -16,7 +33,7 @@ on:
image:
required: true
type: string
- model_name:
+ model_list:
required: true
type: string
upload:
@@ -24,38 +41,44 @@ on:
type: boolean
default: false
-jobs:
- accuracy_tests:
+# Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly
+# declared as "shell: bash -el {0}" on steps that need to be properly activated.
+# It's used to activate ascend-toolkit environment variables.
+defaults:
+ run:
+ shell: bash -el {0}
+
+# only cancel in-progress runs of the same workflow,
+# ref, runner and model list
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}-${{ inputs.runner }}-${{ inputs.model_list }}
+ cancel-in-progress: true
+jobs:
+ e2e-nightly:
+    name: ${{ inputs.model_list }} accuracy test
runs-on: ${{ inputs.runner }}
- name: ${{ inputs.model_name }} accuracy
container:
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11
env:
VLLM_USE_MODELSCOPE: True
- # 1. If version specified (work_dispatch), do specified branch accuracy test
- # 2. If no version (labeled PR), do accuracy test by default ref:
- # The branch, tag or SHA to checkout. When checking out the repository that
- # triggered a workflow, this defaults to the reference or SHA for that event.
- # Otherwise, uses the default branch.
GHA_VLLM_ASCEND_VERSION: ${{ inputs.vllm-ascend }}
-
steps:
- - name: Checkout repository
- uses: actions/checkout@v4
-
- - name: Set model name as output
- id: set_output
+ - name: Check npu and CANN info
run: |
- echo "model_name=${{ inputs.model_name }}" >> $GITHUB_OUTPUT
+ npu-smi info
+ cat /usr/local/Ascend/ascend-toolkit/latest/"$(uname -i)"-linux/ascend_toolkit_install.info
- name: Config mirrors
run: |
- sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
- pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
- pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
+ sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list
+ pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
apt-get update -y
apt install git -y
+ git config --global url."https://gh-proxy.test.osinfra.cn/https://github.com/".insteadOf https://github.com/
+
+ - name: Checkout vllm-project/vllm-ascend repo
+ uses: actions/checkout@v4
- name: Install system dependencies
run: |
@@ -73,9 +96,16 @@ jobs:
working-directory: ./vllm-empty
run: |
VLLM_TARGET_DEVICE=empty pip install -e .
-
+
+ - name: Install vllm-project/vllm-ascend
+ env:
+ PIP_EXTRA_INDEX_URL: https://mirrors.huaweicloud.com/ascend/repos/pypi
+ run: |
+ pip install -r requirements-dev.txt
+ pip install -v -e .
+
- name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct)
- if: ${{ inputs.model_name == 'Qwen3-Next-80B-A3B-Instruct' }}
+ if: ${{ inputs.runner == 'linux-aarch64-a2-4' && contains(inputs.model_list, 'Qwen3-Next-80B-A3B-Instruct') }}
shell: bash -l {0}
run: |
wget -q https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/Ascend-BiSheng-toolkit_aarch64.run -O /tmp/Ascend-BiSheng-toolkit_aarch64.run
@@ -108,14 +138,6 @@ jobs:
path: ./vllm-ascend
ref: ${{ env.GHA_VLLM_ASCEND_VERSION }}
- - name: Install vllm-project/vllm-ascend
- working-directory: ./vllm-ascend
- env:
- PIP_EXTRA_INDEX_URL: https://mirrors.huaweicloud.com/ascend/repos/pypi
- run: |
- pip install -r requirements-dev.txt
- pip install -v -e .
-
- name: Get vLLM commit hash and URL
working-directory: ./vllm-empty
run: |
@@ -149,11 +171,12 @@ jobs:
pip show vllm | grep "Version:" | awk '{print "GHA_VLLM_VERSION="$2}' | sed 's/+.*//'
} >> "$GITHUB_ENV"
- - name: Run accuracy test
+ - name: Run vllm-project/vllm-ascend accuracy test
id: report
env:
VLLM_WORKER_MULTIPROC_METHOD: spawn
VLLM_USE_MODELSCOPE: True
+ VLLM_CI_RUNNER: ${{ inputs.runner }}
VLLM_VERSION: ${{ env.GHA_VLLM_VERSION }}
VLLM_COMMIT: ${{ env.VLLM_COMMIT }}
VLLM_ASCEND_VERSION: ${{ env.GHA_VLLM_ASCEND_VERSION || github.ref }}
@@ -162,24 +185,44 @@ jobs:
TORCH_VERSION: ${{ env.GHA_TORCH_VERSION }}
TORCH_NPU_VERSION: ${{ env.GHA_TORCH_NPU_VERSION }}
run: |
- model_base_name=$(basename ${{ inputs.model_name }})
- markdown_name="${model_base_name}"
- echo "markdown_name=$markdown_name" >> $GITHUB_OUTPUT
mkdir -p ./benchmarks/accuracy
- pytest -sv ./tests/e2e/models/test_lm_eval_correctness.py \
- --config ./tests/e2e/models/configs/${{ inputs.model_name }}.yaml
+ echo "Received model_list: ${{ inputs.model_list }}"
+ models=$(echo '${{ inputs.model_list }}' | jq -r '.[]')
+ any_failure=0
+ for model in $models; do
+ echo "Running test for model: $model"
+ pytest -sv ./tests/e2e/models/test_lm_eval_correctness.py \
+ --config "./tests/e2e/models/configs/${model}.yaml" || {
+ echo "Test failed for model: $model"
+ any_failure=1
+ }
+ done
+
+ if [ $any_failure -ne 0 ]; then
+ exit 1
+ fi
- name: Generate step summary
if: ${{ always() }}
run: |
- cat ./benchmarks/accuracy/${{ steps.report.outputs.markdown_name }}.md >> $GITHUB_STEP_SUMMARY
+ models=$(echo '${{ inputs.model_list }}' | jq -r '.[]')
+ for model in $models; do
+ echo "Processing model: $model"
+ model_base_name=$(basename "$model")
+ cat ./benchmarks/accuracy/${model_base_name}.md >> $GITHUB_STEP_SUMMARY
+ done
+
+ - name: Set artifact timestamp
+ id: ts
+ run: |
+ echo "artifact_ts=$(date -u +%Y%m%dT%H%M%SZ)" >> $GITHUB_OUTPUT
- name: Upload Report
if: ${{ inputs.upload == true }}
uses: actions/upload-artifact@v5
with:
- name: "report-${{ env.GHA_VLLM_ASCEND_VERSION }}-${{ steps.report.outputs.markdown_name }}"
- path: ./benchmarks/accuracy/${{ steps.report.outputs.markdown_name }}.md
+ name: report-${{ env.GHA_VLLM_ASCEND_VERSION }}-${{ steps.ts.outputs.artifact_ts }}
+ path: ./benchmarks/accuracy/
if-no-files-found: warn
retention-days: 90
- overwrite: true
+ overwrite: true
\ No newline at end of file
diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index 476948ba3b..6bbd4ba64f 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -182,7 +182,7 @@ jobs:
pytest -sv tests/e2e/multicard/test_torchair_graph_mode.py
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py
- # pytest -sv tests/e2e/multicard/test_external_launcher.py
+ pytest -sv tests/e2e/multicard/test_external_launcher.py
pytest -sv tests/e2e/multicard/test_single_request_aclgraph.py
pytest -sv tests/e2e/multicard/test_fused_moe_allgather_ep.py
pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
diff --git a/.github/workflows/accuracy_test.yaml b/.github/workflows/accuracy_test.yaml
deleted file mode 100644
index 7a1b5c398e..0000000000
--- a/.github/workflows/accuracy_test.yaml
+++ /dev/null
@@ -1,85 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-
-# This test will be triggered:
-# - PR labeled with: 'accuracy-test' & 'ready-for-test'
-name: ascend test / accuracy
-
-on:
- pull_request:
- branches:
- - 'main'
- - '*-dev'
- types: [ labeled, synchronize ]
-
-# Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly
-# declared as "shell: bash -el {0}" on steps that need to be properly activated.
-# It's used to activate ascend-toolkit environment variables.
-defaults:
- run:
- shell: bash -el {0}
-
-# only cancel in-progress runs of the same workflow
-concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}
- cancel-in-progress: true
-
-jobs:
- run:
- name: ""
- strategy:
- matrix:
- # Only top series models should be listed in here
- include:
- - runner: a2-1
- model_name: Qwen3-8B
- - runner: a2-1
- model_name: Qwen2.5-VL-7B-Instruct
- # To do: This model has a bug that needs to be fixed and readded
- # - runner: a2-1
- # model_name: Qwen2-Audio-7B-Instruct
- - runner: a2-2
- model_name: Qwen3-30B-A3B
- - runner: a2-2
- model_name: Qwen3-VL-30B-A3B-Instruct
- - runner: a2-2
- model_name: DeepSeek-V2-Lite
- - runner: a2-4
- model_name: Qwen3-Next-80B-A3B-Instruct
- - runner: a2-1
- model_name: Qwen3-8B-W8A8
- - runner: a2-1
- model_name: Qwen3-VL-8B-Instruct
- - runner: a2-1
- model_name: Qwen2.5-Omni-7B
- - runner: a2-1
- model_name: Meta-Llama-3.1-8B-Instruct
- - runner: a2-4
- model_name: Qwen3-30B-A3B-W8A8
- fail-fast: false
- # test will be triggered when tag 'accuracy-test' & 'ready-for-test'
- if: >-
- ${{
- contains(github.event.pull_request.labels.*.name, 'accuracy-test') &&
- contains(github.event.pull_request.labels.*.name, 'ready-for-test')
- }}
- uses: ./.github/workflows/_accuracy_test.yaml
- with:
- vllm: v0.11.0
- runner: linux-aarch64-${{ matrix.runner }}
- image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11
- model_name: ${{ matrix.model_name }}
diff --git a/.github/workflows/vllm_ascend_test_nightly_a2.yaml b/.github/workflows/vllm_ascend_test_nightly_a2.yaml
index 19fc3b5dc9..72b97d419e 100644
--- a/.github/workflows/vllm_ascend_test_nightly_a2.yaml
+++ b/.github/workflows/vllm_ascend_test_nightly_a2.yaml
@@ -27,6 +27,7 @@ on:
pull_request:
branches:
- 'main'
+ types: [ labeled, synchronize ]
# Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly
# declared as "shell: bash -el {0}" on steps that need to be properly activated.
@@ -88,3 +89,44 @@ jobs:
config_file_path: ${{ matrix.test_config.config_file_path }}
secrets:
KUBECONFIG_B64: ${{ secrets.KUBECONFIG_A2_B64 }}
+
+ single-node-accuracy-tests:
+ if: >-
+ ${{
+ github.event_name == 'schedule' ||
+ github.event_name == 'workflow_dispatch' ||
+ (
+ contains(github.event.pull_request.labels.*.name, 'accuracy-test') &&
+ contains(github.event.pull_request.labels.*.name, 'ready-for-test')
+ )
+ }}
+ strategy:
+ fail-fast: false
+ matrix:
+ test_config:
+ - os: linux-aarch64-a2-1
+ model_list:
+ - Qwen3-8B
+ - Qwen2.5-VL-7B-Instruct
+ # TODO: This model has a bug that needs to be fixed and readded
+ # - Qwen2-Audio-7B-Instruct
+ - Qwen3-8B-W8A8
+ - Qwen3-VL-8B-Instruct
+ - Qwen2.5-Omni-7B
+ - Meta-Llama-3.1-8B-Instruct
+ - os: linux-aarch64-a2-2
+ model_list:
+ - Qwen3-30B-A3B
+ - Qwen3-VL-30B-A3B-Instruct
+ - DeepSeek-V2-Lite
+ - Qwen3-30B-A3B-W8A8
+ - os: linux-aarch64-a2-4
+ model_list:
+ - Qwen3-Next-80B-A3B-Instruct
+ uses: ./.github/workflows/_e2e_nightly_single_node_models.yaml
+ with:
+ vllm: v0.11.0
+ runner: ${{ matrix.test_config.os }}
+ model_list: ${{ toJson(matrix.test_config.model_list) }}
+ image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.2.rc1-910b-ubuntu22.04-py3.11
+ upload: false
diff --git a/.github/workflows/vllm_ascend_test_nightly_a3.yaml b/.github/workflows/vllm_ascend_test_nightly_a3.yaml
index d880a8bffb..00e0565956 100644
--- a/.github/workflows/vllm_ascend_test_nightly_a3.yaml
+++ b/.github/workflows/vllm_ascend_test_nightly_a3.yaml
@@ -78,6 +78,12 @@ jobs:
- name: qwq-32b-a3
os: linux-aarch64-a3-4
tests: tests/e2e/nightly/models/test_qwq_32b.py
+ - name: qwen3-30b-w8a8
+ os: linux-aarch64-a3-2
+ tests: tests/e2e/nightly/models/test_qwen3_30b_w8a8.py
+ - name: qwen3-235b-w8a8
+ os: linux-aarch64-a3-16
+ tests: tests/e2e/nightly/models/test_qwen3_235b_w8a8.py
uses: ./.github/workflows/_e2e_nightly_single_node.yaml
with:
vllm: v0.11.0
@@ -119,3 +125,4 @@ jobs:
config_file_path: ${{ matrix.test_config.config_file_path }}
secrets:
KUBECONFIG_B64: ${{ secrets.KUBECONFIG_B64 }}
+
\ No newline at end of file
diff --git a/.github/workflows/vllm_ascend_test_models.yaml b/.github/workflows/vllm_ascend_test_report.yaml
similarity index 88%
rename from .github/workflows/vllm_ascend_test_models.yaml
rename to .github/workflows/vllm_ascend_test_report.yaml
index beba0e4464..0f7a06586e 100644
--- a/.github/workflows/vllm_ascend_test_models.yaml
+++ b/.github/workflows/vllm_ascend_test_report.yaml
@@ -20,18 +20,15 @@
# 2. pull_request change the related files
# 3. workflow_dispatch with models input
-name: ascend test / models
+name: ascend test / accuracy report
on:
- schedule:
- # Runs every 6 hours
- - cron: '0 */6 * * *'
pull_request:
branches:
- 'main'
- '*-dev'
paths:
- - '.github/workflows/vllm_ascend_test_models.yaml'
+ - '.github/workflows/vllm_ascend_test_report.yaml'
- 'tests/e2e/models/test_lm_eval_correctness.py'
workflow_dispatch:
inputs:
@@ -60,27 +57,26 @@ concurrency:
jobs:
run:
strategy:
+ fail-fast: false
matrix:
include:
- - model_name: Qwen3-8B
- runner: a2-1
- - model_name: Qwen2.5-VL-7B-Instruct
- runner: a2-1
- - model_name: Qwen2-Audio-7B-Instruct
- runner: a2-1
- - model_name: Qwen3-30B-A3B
- runner: a2-2
- - model_name: Qwen3-VL-30B-A3B-Instruct
- runner: a2-2
- - model_name: DeepSeek-V2-Lite
- runner: a2-2
- fail-fast: false
- uses: ./.github/workflows/_accuracy_test.yaml
+ - runner: linux-aarch64-a2-1
+ model_list:
+ - Qwen3-8B
+ - Qwen2.5-VL-7B-Instruct
+ # TODO: This model has a bug that needs to be fixed and readded
+ # - Qwen2-Audio-7B-Instruct
+ - runner: linux-aarch64-a2-2
+ model_list:
+ - Qwen3-30B-A3B
+ - Qwen3-VL-30B-A3B-Instruct
+ - DeepSeek-V2-Lite
+ uses: ./.github/workflows/_e2e_nightly_single_node_models.yaml
with:
vllm: v0.11.0
- runner: linux-aarch64-${{ matrix.runner }}
+ runner: ${{ matrix.runner }}
image: swr.cn-southwest-2.myhuaweicloud.com/base_image/ascend-ci/cann:8.3.rc1-910b-ubuntu22.04-py3.11
- model_name: ${{ matrix.model_name }}
+ model_list: ${{ toJson(matrix.model_list) }}
upload: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.vllm-ascend-version == 'latest' }}
create_pr:
diff --git a/docs/source/assets/disaggregated_prefill_pull.png b/docs/source/assets/disaggregated_prefill_pull.png
new file mode 100644
index 0000000000..4f01dafa5e
Binary files /dev/null and b/docs/source/assets/disaggregated_prefill_pull.png differ
diff --git a/docs/source/assets/disaggregated_prefill_push.png b/docs/source/assets/disaggregated_prefill_push.png
new file mode 100644
index 0000000000..ec5537ed5f
Binary files /dev/null and b/docs/source/assets/disaggregated_prefill_push.png differ
diff --git a/docs/source/assets/eplb.png b/docs/source/assets/eplb.png
new file mode 100644
index 0000000000..996d417b3a
Binary files /dev/null and b/docs/source/assets/eplb.png differ
diff --git a/docs/source/community/versioning_policy.md b/docs/source/community/versioning_policy.md
index 1b092885c1..c5d733dab8 100644
--- a/docs/source/community/versioning_policy.md
+++ b/docs/source/community/versioning_policy.md
@@ -18,7 +18,7 @@ For example:
## Release compatibility matrix
-The table below is the release compatibility matrix for vLLM Ascend Plugin.
+The table below is the release compatibility matrix for vLLM Ascend releases.
| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu | MindIE Turbo |
|-------------|--------------|------------------|-------------|--------------------|--------------|
@@ -38,6 +38,11 @@ The table below is the release compatibility matrix for vLLM Ascend Plugin.
| v0.7.3.post1| v0.7.3 | >= 3.9, < 3.12 | 8.1.RC1 | 2.5.1 / 2.5.1 | 2.0rc1 |
| v0.7.3 | v0.7.3 | >= 3.9, < 3.12 | 8.1.RC1 | 2.5.1 / 2.5.1 | 2.0rc1 |
+For the main branch of vLLM Ascend, we usually keep it compatible with the latest vLLM release as well as a newer vLLM commit. Note that this table is updated frequently, so please check it regularly.
+
+| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu |
+|-------------|--------------|------------------|-------------|--------------------|
+| main | v0.11.0/83f478bb19489b41e9d208b47b4bb5a95ac171ac | >= 3.10, < 3.12 | 8.3.RC1 | 2.7.1 / 2.7.1 |
+
## Release cadence
### Release window
diff --git a/docs/source/developer_guide/contribution/multi_node_test.md b/docs/source/developer_guide/contribution/multi_node_test.md
index 1d78c8e353..1fdcc3c590 100644
--- a/docs/source/developer_guide/contribution/multi_node_test.md
+++ b/docs/source/developer_guide/contribution/multi_node_test.md
@@ -51,7 +51,7 @@ From the workflow perspective, we can see how the final test script is executed,
# - no headless(have api server)
decoder_host_index: [1]
- # Add each node's vllm serve cli command just like you runs locally
+ # Add each node's vllm serve cli command just like you run locally
deployment:
-
server_cmd: >
diff --git a/docs/source/developer_guide/feature_guide/KV_Cache_Pool_Guide.md b/docs/source/developer_guide/feature_guide/KV_Cache_Pool_Guide.md
new file mode 100644
index 0000000000..f29595f5c1
--- /dev/null
+++ b/docs/source/developer_guide/feature_guide/KV_Cache_Pool_Guide.md
@@ -0,0 +1,83 @@
+# KV Cache Pool
+
+## Why KV Cache Pool?
+
+Prefix caching is an important feature in LLM inference that can reduce prefill computation time drastically.
+
+However, the performance gain from prefix caching is highly dependent on cache hit rate, while cache hit rate can be limited if one only uses HBM for kv cache storage.
+
+Hence, KV Cache Pool is proposed to utilize various types of storage including HBM, DRAM and SSD, forming a pool for KV cache storage while making request prefixes visible across all nodes, which increases the cache hit rate for all requests.
+
+vLLM Ascend currently supports [MooncakeStore](https://github.com/kvcache-ai/Mooncake): one of the most recognized KV cache storage engines.
+
+While one can utilize Mooncake Store in the vLLM V1 engine by setting it as a remote backend of LMCache on GPUs (see [Tutorial](https://github.com/LMCache/LMCache/blob/dev/examples/kv_cache_reuse/remote_backends/mooncakestore/README.md)), we find it better to integrate a connector that directly supports Mooncake Store and can tailor the data transfer strategy to one that best fits Huawei NPU hardware.
+
+Hence, we propose to integrate Mooncake Store with a brand new **MooncakeStoreConnectorV1**, which is largely inspired by **LMCacheConnectorV1** (see the `How is MooncakeStoreConnectorV1 Implemented?` section).
+
+## Usage
+
+vLLM Ascend currently supports Mooncake Store as the backend for the KV Cache Pool. To enable Mooncake Store, configure `kv-transfer-config` and choose `MooncakeStoreConnector` as the KV connector.
+
+For step-by-step deployment and configuration, please refer to the KV Pool User Guide at `vllm-ascend/docs/source/user_guide/feature_guide/kv_pool_mooncake.md`.
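+
+A minimal offline-inference sketch of such a configuration is shown below. It assumes vLLM's `KVTransferConfig`/`LLM` entry points; the model path is a placeholder, and the Mooncake-specific settings (metadata server, storage backends, etc.) must be taken from the linked user guide.
+
+```python
+from vllm import LLM
+from vllm.config import KVTransferConfig
+
+# Sketch only: the connector name follows this guide; extra Mooncake settings are
+# deployment-specific and documented in the KV Pool user guide.
+llm = LLM(
+    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
+    kv_transfer_config=KVTransferConfig(
+        kv_connector="MooncakeStoreConnector",  # KV connector chosen above
+        kv_role="kv_both",  # this instance both saves to and loads from the pool
+    ),
+)
+```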
+
+## How it works?
+The KV Cache Pool integrates multiple memory tiers (HBM, DRAM, SSD, etc.) through a connector-based architecture.
+
+Each connector implements a unified interface for storing, retrieving, and transferring KV blocks between tiers, depending on access frequency and hardware bandwidth.
+
+When combined with vLLM’s Prefix Caching mechanism, the pool enables efficient caching both locally (in HBM) and globally (via Mooncake), ensuring that frequently used prefixes remain hot while less frequently accessed KV data can spill over to lower-cost memory.
+
+### 1. Combining KV Cache Pool with HBM Prefix Caching
+Prefix Caching with HBM is already supported by the vLLM V1 Engine.
+By introducing KV Connector V1, users can seamlessly combine HBM-based Prefix Caching with Mooncake-backed KV Pool.
+
+The user can enable both features simply by enabling prefix caching, which is enabled by default in vLLM V1 unless the `--no_enable_prefix_caching` flag is set, and setting up the KV connector for the KV Pool (e.g. the MooncakeStoreConnector).
+
+**Workflow**:
+
+1. The engine first checks for prefix hits in the HBM cache.
+
+2. After getting the number of hit tokens in HBM, it queries the KV Pool via the connector. If there are additional hits in the KV Pool, only the **additional blocks** are fetched from the pool, while the rest of the blocks come directly from HBM to minimize data transfer latency (see the sketch after this list).
+
+3. After the KV caches from the KV Pool are loaded into HBM, the remaining process is the same as prefix caching in HBM.
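+
+The sketch below illustrates this block-selection logic with a hypothetical helper; it is not the actual connector code.
+
+```python
+def blocks_to_fetch_from_pool(hbm_hit_blocks: int, pool_hit_blocks: int) -> list[int]:
+    """Return indices of the *additional* blocks to load from the KV pool.
+
+    Blocks [0, hbm_hit_blocks) are reused directly from HBM; only the extra
+    pool hits are transferred, minimizing data movement.
+    """
+    return list(range(hbm_hit_blocks, max(hbm_hit_blocks, pool_hit_blocks)))
+
+# e.g. 4 blocks hit in HBM, 7 blocks hit in the pool -> fetch blocks 4..6 from the pool
+print(blocks_to_fetch_from_pool(4, 7))  # [4, 5, 6]
+```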
+
+### 2. Combining KV Cache Pool with Mooncake PD Disaggregation
+
+When used together with Mooncake PD (Prefill-Decode) Disaggregation, the KV Cache Pool can further decouple prefill and decode stages across devices or nodes.
+
+Currently, we only perform KV Pool put and get operations on **Prefill Nodes**; Decode Nodes get their KV cache from the Mooncake P2P KV connector, i.e. MooncakeConnector.
+
+The key benefit is that Prefill Nodes keep the performance gain of computing less, thanks to prefix caching from HBM and the KV Pool, without sacrificing data transfer efficiency between Prefill and Decode nodes, since the P2P KV connector transfers KV caches between NPU devices directly.
+
+To enable this feature, set up both the Mooncake Connector and the Mooncake Store Connector under a MultiConnector, a KV connector class provided by vLLM that can call multiple KV connectors in a specific order.
+
+For details, please also refer to the Mooncake Store Connector deployment guide.
+
+## How is MooncakeStoreConnectorV1 Implemented?
+**MooncakeStoreConnectorV1** inherits from the KV Connector V1 class in vLLM V1: by implementing the required methods defined in the KV Connector V1 base class, one can integrate a third-party KV cache transfer/storage backend into the vLLM framework.
+
+MooncakeStoreConnectorV1 is also largely inspired by LMCacheConnectorV1 in terms of the `Lookup Engine`/`Lookup Client` design for looking up KV cache keys, the `ChunkedTokenDatabase` class for processing tokens into prefix-aware hashes, and other hashing-related designs. On top of this, we have also added our own designs, including a `KVTransferThread` that allows async `get` and `put` of KV caches with multi-threading, and NPU-related data transfer optimizations such as removing the `LocalBuffer` used in LMCache to avoid redundant data transfer.
+
+The KV Connector methods that need to be implemented can be categorized into scheduler-side methods that are called in the V1 scheduler and worker-side methods that are called in the V1 worker (a schematic skeleton follows after the lists):
+### KV Connector Scheduler-Side Methods:
+- `get_num_new_matched_tokens`: Get the prefix cache hit in number of tokens by looking up the KV pool.
+- `update_states_after_alloc`: Update the KV connector state after the temporary buffer allocation.
+- `build_connector_meta`: Attach the connector metadata to the request object.
+- `request_finished`: Once a request is finished, determine whether its blocks should be freed now or will be sent asynchronously and freed later.
+### Connector Worker-Side Methods:
+- `register_kv_caches`: Register the KV cache buffers needed for KV cache transfer.
+- `start_load_kv`: Perform the KV cache load operation that transfers KV cache from storage to the device.
+- `wait_for_layer_load`: Optional; wait for a layer's load in the layerwise + async KV load scenario.
+- `save_kv_layer`: Optional; perform a layerwise KV cache put into the KV Pool.
+- `wait_for_save`: Wait for the KV save to finish when using async KV cache save/put.
+- `get_finished`: Get requests that finished KV transfer: `done_sending` if the `put` finished, `done_recving` if the `get` finished.
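+
+The schematic skeleton below groups these hooks by where they run. It is illustrative only: a real connector subclasses vLLM's KV Connector V1 base class, and the parameter names here are assumptions rather than the exact signatures.
+
+```python
+class KVPoolConnectorSketch:
+    """Illustrative outline only, not the real MooncakeStoreConnectorV1."""
+
+    # ---- scheduler-side (called by the V1 scheduler) ----
+    def get_num_new_matched_tokens(self, request, num_computed_tokens):
+        """Look up the pool and report extra cached tokens beyond the HBM hit."""
+
+    def update_states_after_alloc(self, request, blocks, num_external_tokens):
+        """Record the temporary buffer allocation for this request."""
+
+    def build_connector_meta(self, scheduler_output):
+        """Attach connector metadata (what to load/save) to the step output."""
+
+    def request_finished(self, request, block_ids):
+        """Decide whether blocks are freed now or after an async save."""
+
+    # ---- worker-side (called by the V1 worker) ----
+    def register_kv_caches(self, kv_caches):
+        """Register device KV buffers used for transfers."""
+
+    def start_load_kv(self, forward_context):
+        """Start loading KV from the pool into device memory."""
+
+    def wait_for_layer_load(self, layer_name):
+        """Optional: block until one layer finishes loading (layerwise mode)."""
+
+    def save_kv_layer(self, layer_name, kv_layer):
+        """Optional: put one layer's KV into the pool (layerwise mode)."""
+
+    def wait_for_save(self):
+        """Block until asynchronous saves complete."""
+
+    def get_finished(self, finished_req_ids):
+        """Return requests whose async sends/receives have completed."""
+```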
+
+## DFX
+1. When looking up a key in the KV Pool, if the key cannot be found, there is no cache hit for that specific block; we return no hit for this block and do not look up further blocks for the current request.
+2. Similarly, if putting a block into the KV Pool fails, we do not put further blocks (subject to change).
+
+## Limitation
+
+1. Currently, Mooncake Store for vLLM-Ascend only supports DRAM as the storage for KV Cache pool.
+
+2. For now, if we successfully looked up a key and found it exists, but failed to get it when calling the KV Pool's get function, we just output a log indicating that the get operation failed and keep going; hence, the accuracy of that specific request may be affected. We will handle this situation by rolling back the request and recomputing everything as if there were no prefix cache hit (or, even better, reverting only one block and keeping the prefix caches before it).
diff --git a/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md b/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md
new file mode 100644
index 0000000000..27986aabbb
--- /dev/null
+++ b/docs/source/developer_guide/feature_guide/Multi_Token_Prediction.md
@@ -0,0 +1,112 @@
+# Multi Token Prediction (MTP)
+
+## Why We Need MTP
+MTP boosts inference performance by parallelizing the prediction of multiple tokens, shifting from single-token to multi-token generation. This approach significantly increases generation throughput and achieves multiplicative acceleration in inference speed—all without compromising output quality.
+
+## How to Use MTP
+To enable MTP for DeepSeek-V3 models, add the following parameter when starting the service:
+
+`--speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}`
+
+- `num_speculative_tokens`: The number of speculative tokens, which enables the model to predict multiple tokens at once. If not provided, it defaults to the number in the draft model config when present; otherwise it is required.
+- `disable_padded_drafter_batch`: Disable input padding for speculative decoding. If set to True, speculative input batches can contain sequences of different lengths, which may only be supported by certain attention backends. This currently only affects the MTP method of speculation; the default is False. See the offline-inference sketch below.
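+
+For offline inference, the same settings can be passed to the `LLM` constructor. The sketch below assumes the offline entry point accepts `speculative_config` as a dict; the model path and parallel settings are placeholders.
+
+```python
+from vllm import LLM, SamplingParams
+
+# Placeholder model path and TP size; the speculative settings mirror the CLI flag above.
+llm = LLM(
+    model="deepseek-ai/DeepSeek-V3",
+    tensor_parallel_size=16,
+    speculative_config={
+        "method": "deepseek_mtp",
+        "num_speculative_tokens": 1,
+        "disable_padded_drafter_batch": False,
+    },
+)
+outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
+print(outputs[0].outputs[0].text)
+```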
+
+## How It Works
+
+### Module Architecture
+
+```
+vllm_ascend
+├── sample
+│ ├── rejection_sample.py
+├── spec_decode
+│ ├── mtp_proposer.py
+└───────────
+```
+
+**1. sample**
+
+- *rejection_sample.py*: During decoding, the main model processes the previous round’s output token and the predicted token together (computing 1+k tokens simultaneously). The first token is always correct, while the second token, referred to as the **bonus token**, is uncertain since it is derived from speculative prediction. We therefore employ a **Greedy Strategy** and a **Rejection Sampling Strategy** to determine whether the bonus token should be accepted. The module structure consists of an `AscendRejectionSampler` class with a forward method that implements the specific sampling logic.
+
+```
+rejection_sample.py
+├── AscendRejectionSampler
+│ ├── forward
+```
+
+**2. spec_decode**
+
+This section encompasses the model preprocessing for spec-decode, primarily structured as follows: it includes loading the model, executing a dummy run, and generating token ids. These steps collectively form the model data construction and forward invocation for a single spec-decode operation.
+- *mtp_proposer.py*: Configures vLLM Ascend to use speculative decoding where proposals are generated by the DeepSeek MTP layer.
+
+```
+mtp_proposer.py
+├── Proposer
+│ ├── load_model
+│ ├── dummy_run
+│ ├── generate_token_ids
+│ ├── _prepare_inputs
+│ ├── _propose
+```
+
+### Algorithm
+
+**1. Reject_Sample**
+- *Greedy Strategy*
+
+Verify whether the token generated by the main model matches the speculative token predicted by MTP in the previous round. If they match exactly, accept the bonus token; otherwise, reject it and any subsequent tokens derived from that speculation.
+
+- *Rejection Sampling Strategy*
+
+This method introduces stochasticity in rejection sampling.
+
+For each draft token, acceptance is determined by verifying whether the inequality `P_target / P_draft ≥ U` holds, where `P_target` represents the probability assigned to the current draft token by the target model, `P_draft` denotes the probability assigned by the draft model, and `U` is a random number sampled uniformly from the interval [0, 1).
+
+The decision logic for each draft token is as follows: if the inequality `P_target / P_draft ≥ U` holds, the draft token is accepted as output; conversely, if `P_target / P_draft < U`, the draft token is rejected.
+
+When a draft token is rejected, a recovery sampling process is triggered where a "recovered token" is resampled from the adjusted probability distribution defined as `Q = max(P_target - P_draft, 0)`. In the current MTP implementation, since `P_draft` is not provided and defaults to 1, the formulas simplify such that token acceptance occurs when `P_target ≥ U`, and the recovery distribution becomes `Q = max(P_target - 1, 0)`.
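+
+The numeric sketch below walks through this acceptance rule with made-up probabilities; it mirrors the formulas above, not the `AscendRejectionSampler` implementation.
+
+```python
+import torch
+
+# Probabilities the target model assigns to each draft token (made-up values).
+p_target = torch.tensor([0.85, 0.10, 0.60])
+# In the current MTP implementation the draft probability defaults to 1.
+p_draft = torch.ones_like(p_target)
+u = torch.rand_like(p_target)         # U ~ Uniform[0, 1)
+
+accepted = (p_target / p_draft) >= u  # accept while the ratio beats U
+print(accepted)
+# On the first rejection, a recovered token is resampled from
+# Q = max(P_target - P_draft, 0) (renormalized over the vocabulary),
+# and all later draft tokens are discarded.
+```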
+
+**2. Performance**
+
+If the bonus token is accepted, the MTP model performs inference for (num_speculative_tokens + 1) tokens, including the original main model output token and the bonus token. If it is rejected, inference is performed for fewer tokens, depending on how many tokens were accepted.
+
+## DFX
+
+### Method Validation
+
+- Currently, the spec_decode scenario only supports methods such as ngram, eagle, eagle3, and deepseek_mtp. If an incorrect parameter is passed for the method, the code will raise an error to alert the user that an incorrect method was provided.
+
+```
+def get_spec_decode_method(method,
+ vllm_config,
+ device,
+ runner,
+ is_torchair_graph=False):
+ if method == "ngram":
+ return NgramProposer(vllm_config, device, runner)
+ elif method in ["eagle", "eagle3"]:
+ return EagleProposer(vllm_config, device, runner)
+ elif method == 'deepseek_mtp':
+ if is_torchair_graph:
+ return TorchairMtpProposer(vllm_config, device, runner)
+ return MtpProposer(vllm_config, device, runner)
+ else:
+ raise ValueError("Unknown speculative decoding method: "
+ f"{method}")
+```
+
+### Integer Validation
+- The current npu_fused_infer_attention_score operator only supports at most 16 tokens per request per decode round. Therefore, the maximum supported value for MTP is 15. If a value greater than 15 is provided, the code will raise an error and alert the user.
+
+```
+if self.speculative_config:
+ spec_token_num = self.speculative_config.num_speculative_tokens
+ self.decode_threshold += spec_token_num
+ assert self.decode_threshold <= 16, f"decode_threshold exceeded \
+ npu_fused_infer_attention_score TND layout's limit of 16, \
+ got {self.decode_threshold}"
+```
+
+## Limitation
+- Because only a single layer of weights is exposed in DeepSeek's MTP, accuracy and performance are not effectively guaranteed in scenarios where MTP > 1 (especially MTP ≥ 3). Moreover, due to current operator limitations, MTP supports a maximum of 15.
+- In the fullgraph mode with MTP > 1, the capture size of each aclgraph must be an integer multiple of (num_speculative_tokens + 1).
diff --git a/docs/source/developer_guide/feature_guide/disaggregated_prefill.md b/docs/source/developer_guide/feature_guide/disaggregated_prefill.md
new file mode 100644
index 0000000000..46d3dbe9aa
--- /dev/null
+++ b/docs/source/developer_guide/feature_guide/disaggregated_prefill.md
@@ -0,0 +1,103 @@
+# Disaggregated-prefill
+
+## Why disaggregated-prefill?
+
+This feature addresses the need to optimize the **Time Per Output Token (TPOT)** and **Time To First Token (TTFT)** in large-scale inference tasks. The motivation is two-fold:
+
+1. **Adjusting Parallel Strategy and Instance Count for P and D Nodes**
+ Using the disaggregated-prefill strategy, this feature allows the system to flexibly adjust the parallelization strategy (e.g., data parallelism (dp), tensor parallelism (tp), and expert parallelism (ep)) and the instance count for both P (Prefiller) and D (Decoder) nodes. This leads to better system performance tuning, particularly for **TTFT** and **TPOT**.
+
+2. **Optimizing TPOT**
+   Without the disaggregated-prefill strategy, prefill tasks are inserted during decoding, which results in inefficiencies and delays. Disaggregated-prefill solves this by allowing better control over the system’s **TPOT**. By managing chunked prefill tasks effectively, the system avoids the challenge of determining the optimal chunk size and provides more reliable control over the time taken to generate output tokens.
+
+---
+
+## Usage
+
+vLLM Ascend currently supports two types of connectors for handling KV cache management:
+- **MooncakeConnector**: D nodes pull KV cache from P nodes.
+- **MooncakeLayerwiseConnector**: P nodes push KV cache to D nodes in a layered manner.
+
+For step-by-step deployment and configuration, refer to the following guide:
+[https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_node_pd_disaggregation_mooncake.html](https://vllm-ascend.readthedocs.io/en/latest/tutorials/multi_node_pd_disaggregation_mooncake.html)
+
+---
+
+## How It Works
+
+### 1. Design Approach
+
+Under disaggregated-prefill, a global proxy receives external requests, forwarding prefill to P nodes and decode to D nodes; the KV cache (key–value cache) is exchanged between P and D nodes via peer-to-peer (P2P) communication.
+
+### 2. Implementation Design
+
+Our design diagram is shown below, illustrating the pull and push schemes respectively.
+
+![Pull-based scheme](../../assets/disaggregated_prefill_pull.png)
+
+![Push-based scheme](../../assets/disaggregated_prefill_push.png)
+
+
+#### Mooncake Connector:
+
+1. The request is sent to the Proxy’s `_handle_completions` endpoint.
+2. The Proxy calls `select_prefiller` to choose a P node and forwards the request, configuring `kv_transfer_params` with `do_remote_decode=True`, `max_tokens=1`, and `min_tokens=1`.
+3. After the P node’s scheduler finishes prefill, `update_from_output` invokes the scheduler connector’s `request_finished` to defer KV cache release, constructs `kv_transfer_params` with `do_remote_prefill=True`, and returns to the Proxy.
+4. The Proxy calls `select_decoder` to choose a D node and forwards the request.
+5. On the D node, the scheduler marks the request as `RequestStatus.WAITING_FOR_REMOTE_KVS`, pre-allocates KV cache, calls `kv_connector_no_forward` to pull the remote KV cache, then notifies the P node to release KV cache and proceeds with decoding to return the result.
+
+#### Mooncake Layerwise Connector:
+
+1. The request is sent to the Proxy’s `_handle_completions` endpoint.
+2. The Proxy calls `select_decoder` to choose a D node and forwards the request, configuring `kv_transfer_params` with `do_remote_prefill=True` and setting the `metaserver` endpoint.
+3. On the D node, the scheduler uses `kv_transfer_params` to mark the request as `RequestStatus.WAITING_FOR_REMOTE_KVS`, pre-allocates KV cache, then calls `kv_connector_no_forward` to send a request to the metaserver and waits for the KV cache transfer to complete.
+4. The Proxy’s `metaserver` endpoint receives the request, calls `select_prefiller` to choose a P node, and forwards it with `kv_transfer_params` set to `do_remote_decode=True`, `max_tokens=1`, and `min_tokens=1`.
+5. During processing, the P node’s scheduler pushes the KV cache layer-wise; once all layers have been pushed, it releases the request and notifies the D node to begin decoding.
+6. The D node performs decoding and returns the result.
+
+### 3. Interface Design
+
+Taking MooncakeConnector as an example, the system is organized into three primary classes:
+- **MooncakeConnector**: Base class that provides core interfaces.
+- **MooncakeConnectorScheduler**: Interface for scheduling the connectors within the engine core, responsible for managing KV cache transfer requirements and completion.
+- **MooncakeConnectorWorker**: Interface for managing KV cache registration and transfer in worker processes.
+
+### 4. Specifications Design
+
+This feature is flexible and supports various configurations, including setups with MLA and GQA models. It is compatible with A2 and A3 hardware configurations and facilitates scenarios involving both equal and unequal TP setups across multiple P and D nodes.
+
+| Feature | Status |
+|-------------------------------|----------------|
+| A2 | 🟢 Functional |
+| A3 | 🟢 Functional |
+| equal TP configuration | 🟢 Functional |
+| unequal TP configuration | 🟢 Functional |
+| MLA | 🟢 Functional |
+| GQA | 🟢 Functional |
+
+- 🟢 Functional: Fully operational, with ongoing optimizations.
+- 🔵 Experimental: Experimental support, interfaces and functions may change.
+- 🚧 WIP: Under active development, will be supported soon.
+- 🟡 Planned: Scheduled for future implementation (some may have open PRs/RFCs).
+- 🔴 NO plan/Deprecated: No plan or deprecated by vLLM.
+
+---
+
+## DFX Analysis
+
+### 1. Config Parameter Validation
+
+Validate KV transfer config by checking whether the kv_connector type is supported and whether kv_connector_module_path exists and is loadable. On transfer failures, emit clear error logs for diagnostics.
+
+### 2. Port Conflict Detection
+
+Before startup, perform a port-usage check on configured ports (e.g., rpc_port, metrics_port, http_port/metaserver) by attempting to bind. If a port is already in use, fail fast and log an error.
+
+### 3. PD Ratio Validation
+
+Under non-symmetric PD scenarios, validate the P-to-D tp ratio against expected and scheduling constraints to ensure correct and reliable operation.
+
+---
+
+## Limitations
+
+- Heterogeneous P and D nodes are not supported—for example, running P nodes on A2 and D nodes on A3.
+
+- In non-symmetric TP configurations, only cases where the P nodes have a higher TP degree than the D nodes and the P TP count is an integer multiple of the D TP count are supported (i.e., P_tp > D_tp and P_tp % D_tp = 0).
diff --git a/docs/source/developer_guide/feature_guide/eplb_swift_balancer.md b/docs/source/developer_guide/feature_guide/eplb_swift_balancer.md
new file mode 100644
index 0000000000..af6e90db17
--- /dev/null
+++ b/docs/source/developer_guide/feature_guide/eplb_swift_balancer.md
@@ -0,0 +1,222 @@
+# Expert Parallelism Load Balancer (EPLB)
+
+## Why We Need EPLB?
+When using Expert Parallelism (EP), different experts are assigned to different NPUs. Given that the load of various experts may vary depending on the current workload, it is crucial to maintain balanced loads across different NPUs. We adopt a redundant experts strategy by duplicating heavily-loaded experts. Then, we heuristically pack these duplicated experts onto NPUs to ensure load balancing across them. Moreover, thanks to the group-limited expert routing used in MoE models, we also attempt to place experts of the same group on the same node to reduce inter-node data traffic, whenever possible.
+
+To facilitate reproduction and deployment, vLLM Ascend provides the EP load-balancing algorithm in `vllm_ascend/eplb/core/policy`. The algorithm computes a balanced expert replication and placement plan based on the estimated expert loads. Note that the exact method for predicting expert loads is outside the scope of this repository; a common approach is a moving average of historical statistics (see the sketch below).
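+
+As a toy illustration of such a prediction (not the algorithm shipped in `vllm_ascend/eplb/core/policy`), an exponential moving average over per-step expert load counts looks like this:
+
+```python
+import numpy as np
+
+# One row per step, one column per expert (made-up load counts).
+history = np.array([[120, 30, 50, 200],
+                    [100, 40, 60, 180],
+                    [140, 20, 55, 210]], dtype=float)
+
+alpha = 0.5  # weight of the most recent step
+predicted = history[0]
+for step_load in history[1:]:
+    predicted = alpha * step_load + (1 - alpha) * predicted
+
+print(predicted)  # the heaviest experts (columns 0 and 3) would be replicated first
+```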
+
+
+## How to Use EPLB?
+Please refer to the EPLB section of the user guide for detailed information: [How to Use EPLB](../../user_guide/feature_guide/eplb_swift_balancer.md)
+
+## How It Works?
+**EPLB Module Architecture**
+
+```
+vllm_ascend
+├── eplb
+│ ├── adaptor
+│ │ ├── abstract_adaptor.py
+│ │ ├── vllm_adaptor.py
+│ ├── core
+│ │ ├── policy
+│ │ │ ├── policy_abstract.py
+│ │ │ ├── policy_dynamic_ep.py
+│ │ │ ├── policy_dynamic_ep_v2.py
+│ │ │ ├── policy_factory.py
+│ │ │ ├── policy_flashlb.py
+│ │ ├── eplb_device_transfer_loader.py
+│ │ ├── eplb_utils.py
+│ │ ├── eplb_worker.py
+│ ├── eplb_updator.py
+│ ├── utils.py
+└───────────
+```
+
+**1. Adaptor Module**
+*Handles registration and adaptation for different MoE model types*
+- `abstract_adaptor.py`
+ Abstract base class defining unified registration interfaces for EPLB adapters
+- `vllm_adaptor.py`
+ Implementation supporting Qwen3-MoE and DeepSeek models, standardizing parameter handling for policy algorithms
+
+**2. Core Module**
+*Implements core algorithms, updates, and asynchronous processing*
+- **Policy Submodule**
+ *Load balancing algorithms with factory pattern instantiation*
+ - `policy_abstract.py`
+ Abstract class for load balancing strategy interfaces
+ - `policy_dynamic_ep.py`
+ Default implementation of open-source EPLB paper algorithm
+ - `policy_dynamic_ep_v2.py`
+ Enhanced version optimizing expert swaps for low-bandwidth devices (e.g., A2)
+ - `policy_flashlb.py`
+ Threshold-based adjustment reducing operational costs through layer-wise fluctuation detection
+ - `policy_factory.py`
+ Strategy factory for automatic algorithm instantiation
+
+- `eplb_device_transfer_loader.py`
+ Manages expert table/weight transmission and updates
+- `eplb_utils.py`
+ Utilities for expert table initialization and mapping
+- `eplb_worker.py`
+ Asynchronous algorithm orchestration and result processing
+
+**3. System Components**
+- `eplb_updator.py`
+ Central coordinator for load balancing during inference workflows
+- `utils.py`
+ General utilities for EPLB interface registration
+
+### Default Algorithm
+#### Hierarchical Load Balancing
+When the number of server nodes evenly divides the number of expert groups, we use the hierarchical load balancing policy to leverage group-limited expert routing. We first pack the expert groups onto nodes evenly, ensuring balanced loads across different nodes. Then, we replicate the experts within each node. Finally, we pack the replicated experts onto individual NPUs to ensure load balancing across them. The hierarchical load balancing policy can be used in the prefilling stage with a smaller expert-parallel size.
+
+#### Global Load Balancing
+In other cases, we use the global load balancing policy, which replicates experts globally regardless of expert groups, and packs the replicated experts onto individual NPUs. This policy can be adopted in the decoding stage with a larger expert-parallel size.
+
+### Add a New EPLB Policy
+If you want to add a new EPLB policy to vllm_ascend, you must follow these steps:
+1. Inherit the `EplbPolicy` abstract class from `policy_abstract.py` and override the `rebalance_experts` interface, keeping the input parameters (`current_expert_table`, `expert_workload`) and the return value (the new placement) consistent.
+For example:
+
+```python
+import copy
+
+# EplbPolicy is the abstract base class defined in policy_abstract.py
+class RandomLoadBalance(EplbPolicy):
+
+ def __init__(self, config: DynamicConfig):
+ super().__init__(config)
+
+ def rebalance_experts(self, current_expert_table, expert_workload):
+ new_table = copy.deepcopy(current_expert_table)
+ num_layers = len(current_expert_table)
+
+ for i in range(num_layers):
+            # randomly choose two cards; fixed here for a deterministic example:
+ # indices = random.sample(range(num_card), 2)
+ indices = [3, 1]
+
+ # swap redundant experts
+ expert_id_to_exchange = new_table[i][indices[0]][-1].clone()
+ new_table[i][indices[0]][-1] = new_table[i][indices[1]][-1]
+ new_table[i][indices[1]][-1] = expert_id_to_exchange
+
+ return 1, [-i for i in range(num_layers)], new_table
+```
+
+2. To add a new EPLB algorithm, include the policy type and its corresponding implementation class in the `PolicyFactory` of `policy_factory.py`, as sketched below.
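+
+The snippet below is a hypothetical illustration of the factory-dict pattern; the real `PolicyFactory` in `policy_factory.py` may use different keys and structure, so mirror its existing entries when registering your policy.
+
+```python
+# Hypothetical registry; RandomLoadBalance is the example policy from step 1.
+POLICY_REGISTRY = {
+    99: RandomLoadBalance,  # pick an unused policy id for the new policy
+}
+
+def generate_policy(policy_type, config):
+    # Fall back to the example policy if the id is unknown.
+    policy_cls = POLICY_REGISTRY.get(policy_type, RandomLoadBalance)
+    return policy_cls(config)
+```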
+
+### Add a New MoE Model
+**Implementation Guide for Model Integration**
+
+1. **Adapter File Modification**
+ - Inherit or modify `vllm_ascend/eplb/adaptor/vllm_adaptor.py`
+ - Add processing logic for key parameters:
+ - `num_dense_layers`
+ - `global_expert_num`
+     - `num_moe_layers`
+ - Ensure parameter synchronization in the `model_register` function.
+
+ For example:
+
+ Modify `__init__` of `vllm_adaptor.py` to add a new moe model eplb params:
+
+ ```python
+ if self.model.config.model_type == "qwen3_moe":
+ self.num_dense_layers = 0
+ self.global_expert_num = self.model.config.num_experts
+ ```
+
+ Modify `model_register` of `vllm_adaptor.py` to register eplb params for new moe model:
+
+ ```python
+ if config.model_type == "qwen3_moe":
+ model.num_moe_layers = config.num_hidden_layers
+ ```
+
+2. **MoE Feature Integration**
+ - Extend `vllm_ascend/eplb/utils.py` with MoE-specific methods
+ - Implement required functionality for expert routing or weight management
+
+3. **Registration Logic Update**
+ - Add patch logic within the `model_register` function
+ - Maintain backward compatibility with existing model types
+
+4. **Validation & Testing**
+ - Verify parameter consistency across layers
+ - Test cross-device communication for expert tables
+ - Benchmark against baseline implementations (e.g., Qwen3-MoE)
+
+*Key Implementation Notes:*
+- Preserve existing interface contracts in abstract classes
+- Use decorators for non-intrusive patch integration
+- Leverage `eplb_utils.py` for shared expert mapping operations
+## DFX
+### Parameter Validation
+#### Integer Parameters
+All integer input parameters must explicitly specify their maximum and minimum values and be subject to valid value validation. For example, `num_iterations_eplb_update` must be greater than 0:
+
+```python
+ @staticmethod
+ def check_iterations(iterations):
+ if not isinstance(iterations, int):
+ raise TypeError(f"The {iterations} is not int.")
+ if iterations <= 0:
+ raise ValueError(
+ f"The {iterations} can not less than or equal to 0.")
+ if iterations > sys.maxsize:
+ raise ValueError(
+ f"The {iterations} can not large than {sys.maxsize}")
+```
+
+#### File Path
+The file path for EPLB must be checked for legality, such as whether the file path is valid and whether it has appropriate read and write permissions. For example:
+
+```python
+ @staticmethod
+ def check_expert_map_path(expert_map):
+ if expert_map is None:
+ return
+ if not isinstance(expert_map, str):
+ raise TypeError("The expert_map is not str.")
+ if not expert_map.strip():
+ raise ValueError("The expert_map is not empty.")
+ _, ext = os.path.splitext(expert_map)
+ if ext.lower() != ".json":
+ raise TypeError("The expert_map is not json.")
+ if not os.path.exists(expert_map):
+ raise ValueError("The expert_map is not exist.")
+ try:
+ with open(expert_map, "w", encoding='utf-8') as f:
+ f.read()
+ except Exception as e:
+ raise IOError(
+ f"Fail read expert info from {expert_map}, please check the reading permission of {expert_map} : {e}"
+ )
+
+```
+
+### Function Specifications
+#### Initialization Function
+All EPLB parameters must be initialized by default during initialization, with specified parameter types and default values for proper handling.
+
+#### General Functions
+All method arguments must specify parameter types and default values, and functions must include default return value handling for default arguments. It is recommended to use `try-except` blocks to handle the function body, specifying the type of exception captured and the failure handling (e.g., logging exceptions or returning a failure status).
+
+### Consistency
+#### Expert Map
+The expert map must be globally unique during initialization and update. In a multi-node scenario during initialization, distributed communication should be used to verify the consistency of expert maps across each rank. If they are inconsistent, the user should be notified which ranks have inconsistent maps.
+During the update process, if only a few layers or the expert table of a certain rank has been changed, the updated expert table must be synchronized with the EPLB's context to ensure global consistency.
+
+#### Expert Weight
+When updating expert weights, ensure that the memory allocated for the expert weights has been released, or that the expert (referring to the old version) is no longer in use.
+
+## Limitation
+Before using EPLB, add `export DYNAMIC_EPLB="true"` before starting the serving script.
+Before collecting load data (or performance data), add `export EXPERT_MAP_RECORD="true"` before starting the script.
diff --git a/docs/source/developer_guide/feature_guide/index.md b/docs/source/developer_guide/feature_guide/index.md
index 6f7a5d31c6..91f6badb4b 100644
--- a/docs/source/developer_guide/feature_guide/index.md
+++ b/docs/source/developer_guide/feature_guide/index.md
@@ -7,5 +7,9 @@ This section provides an overview of the features implemented in vLLM Ascend. De
:maxdepth: 1
patch
ModelRunner_prepare_inputs
+disaggregated_prefill
+eplb_swift_balancer
+Multi_Token_Prediction
ACL_Graph
+KV_Cache_Pool_Guide
:::
diff --git a/docs/source/developer_guide/performance/optimization_and_tuning.md b/docs/source/developer_guide/performance/optimization_and_tuning.md
index fd5947031b..953ec389a2 100644
--- a/docs/source/developer_guide/performance/optimization_and_tuning.md
+++ b/docs/source/developer_guide/performance/optimization_and_tuning.md
@@ -70,7 +70,7 @@ Make sure your vLLM and vllm-ascend are installed after your python configuratio
#### 1.1. Install optimized `python`
-Python supports **LTO** and **PGO** optimization starting from version `3.6` and above, which can be enabled at compile time. And we have offered optimized `python` packages directly to users for the sake of convenience. You can also reproduce the `python` built following this [tutorial](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0063.html) according to your specific scenarios.
+Python supports **LTO** and **PGO** optimization starting from version `3.6` and above, which can be enabled at compile time. And we have offered optimized `python` packages directly to users for the sake of convenience. You can also reproduce the `python` build following this [tutorial](https://www.hiascend.com/document/detail/zh/Pytorch/600/ptmoddevg/trainingmigrguide/performance_tuning_0063.html) according to your specific scenarios.
```{code-block} bash
:substitutions:
@@ -116,7 +116,7 @@ export LD_PRELOAD=/usr/lib/"$(uname -i)"-linux-gnu/libjemalloc.so.2 $LD_PRELOAD
#### 2.2. Tcmalloc
-**Tcmalloc (Thread Counting Malloc)** is a universal memory allocator that improves overall performance while ensuring low latency by introducing a multi-level cache structure, reducing mutex competition and optimizing large object processing flow. Find more details [here](https://www.hiascend.com/document/detail/zh/Pytorch/700/ptmoddevg/trainingmigrguide/performance_tuning_0068.html).
+**Tcmalloc (Thread Caching Malloc)** is a universal memory allocator that improves overall performance while ensuring low latency by introducing a multi-level cache structure, reducing mutex competition and optimizing large object processing flow. Find more details [here](https://www.hiascend.com/document/detail/zh/Pytorch/700/ptmoddevg/trainingmigrguide/performance_tuning_0068.html).
```{code-block} bash
:substitutions:
diff --git a/docs/source/faqs.md b/docs/source/faqs.md
index f997eb9fcf..3145466e77 100644
--- a/docs/source/faqs.md
+++ b/docs/source/faqs.md
@@ -21,7 +21,7 @@ Below series are NOT supported yet:
- Atlas 200I A2 (Ascend-cann-kernels-310b) unplanned yet
- Ascend 910, Ascend 910 Pro B (Ascend-cann-kernels-910) unplanned yet
-From a technical view, vllm-ascend support would be possible if the torch-npu is supported. Otherwise, we have to implement it by using custom ops. We are also welcome to join us to improve together.
+From a technical view, vllm-ascend support would be possible if the torch-npu is supported. Otherwise, we have to implement it by using custom ops. We also welcome you to join us to improve together.
### 2. How to get our docker containers?
@@ -38,7 +38,7 @@ docker pull quay.nju.edu.cn/ascend/vllm-ascend:$TAG
```
#### Load Docker Images for offline environment
-If you want to use container image for offline environments (no internet connection), you need to download container image in a environment with internet access:
+If you want to use a container image in an offline environment (no internet connection), you need to download the container image in an environment with internet access first:
**Exporting Docker images:**
@@ -74,7 +74,7 @@ There are many channels that you can communicate with our community developers /
- Submit a GitHub [issue](https://github.com/vllm-project/vllm-ascend/issues?page=1).
- Join our [weekly meeting](https://docs.google.com/document/d/1hCSzRTMZhIB8vRq1_qOOjx4c9uYUxvdQvDsMV2JcSrw/edit?tab=t.0#heading=h.911qu8j8h35z) and share your ideas.
-- Join our [WeChat](https://github.com/vllm-project/vllm-ascend/issues/227) group and ask your quenstions.
+- Join our [WeChat](https://github.com/vllm-project/vllm-ascend/issues/227) group and ask your questions.
- Join our ascend channel in [vLLM forums](https://discuss.vllm.ai/c/hardware-support/vllm-ascend-support/6) and publish your topics.
### 5. What features does vllm-ascend V1 supports?
@@ -142,7 +142,7 @@ In scenarios where NPUs have limited high bandwidth memory (HBM) capacity, dynam
- **Configure `PYTORCH_NPU_ALLOC_CONF`**: Set this environment variable to optimize NPU memory management. For example, you can use `export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True` to enable virtual memory feature to mitigate memory fragmentation caused by frequent dynamic memory size adjustments during runtime. See details in: [PYTORCH_NPU_ALLOC_CONF](https://www.hiascend.com/document/detail/zh/Pytorch/700/comref/Envvariables/Envir_012.html).
### 14. Failed to enable NPU graph mode when running DeepSeek.
-You may encounter the following error if running DeepSeek with NPU graph mode is enabled. The allowed number of queries per KV when enabling both MLA and Graph mode is {32, 64, 128}. **Thus this is not supported for DeepSeek-V2-Lite**, as it only has 16 attention heads. The NPU graph mode support on DeepSeek-V2-Lite will be implemented in the future.
+Enabling NPU graph mode for DeepSeek may trigger an error. This is because when both MLA and NPU graph mode are active, the number of queries per KV head must be 32, 64, or 128. However, DeepSeek-V2-Lite has only 16 attention heads, which results in 16 queries per KV—a value outside the supported range. Support for NPU graph mode on DeepSeek-V2-Lite will be added in a future update.
And if you're using DeepSeek-V3 or DeepSeek-R1, please make sure after the tensor parallel split, num_heads/num_kv_heads is {32, 64, 128}.
diff --git a/docs/source/index.md b/docs/source/index.md
index 940a619b4b..8c087447a8 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -25,7 +25,7 @@ vLLM Ascend plugin (vllm-ascend) is a community maintained hardware plugin for r
This plugin is the recommended approach for supporting the Ascend backend within the vLLM community. It adheres to the principles outlined in the [[RFC]: Hardware pluggable](https://github.com/vllm-project/vllm/issues/11162), providing a hardware-pluggable interface that decouples the integration of the Ascend NPU with vLLM.
-By using vLLM Ascend plugin, popular open-source models, including Transformer-like, Mixture-of-Expert, Embedding, Multi-modal LLMs can run seamlessly on the Ascend NPU.
+By using vLLM Ascend plugin, popular open-source models, including Transformer-like, Mixture-of-Experts, Embedding, Multi-modal LLMs can run seamlessly on the Ascend NPU.
## Documentation
diff --git a/docs/source/tutorials/DeepSeek-V3.2-Exp.md b/docs/source/tutorials/DeepSeek-V3.2-Exp.md
index c3e7cbf65b..415134f4ed 100644
--- a/docs/source/tutorials/DeepSeek-V3.2-Exp.md
+++ b/docs/source/tutorials/DeepSeek-V3.2-Exp.md
@@ -32,13 +32,13 @@ If you want to deploy multi-node environment, you need to verify multi-node comm
:::::{tab-set}
::::{tab-item} Use deepseek-v3.2 docker image
-Currently, we provide the all-in-one images `quay.io/ascend/vllm-ascend:v0.11.0rc0-deepseek-v3.2-exp`(for Atlas 800 A2) and `quay.io/ascend/vllm-ascend:v0.11.0rc0-a3-deepseek-v3.2-exp`(for Atlas 800 A3).
+In the `vllm-ascend:v0.11.0rc0` release, we provide the all-in-one images `quay.io/ascend/vllm-ascend:v0.11.0rc0-deepseek-v3.2-exp` (for Atlas 800 A2) and `quay.io/ascend/vllm-ascend:v0.11.0rc0-a3-deepseek-v3.2-exp` (for Atlas 800 A3).
Refer to [using docker](../installation.md#set-up-using-docker) to set up the environment using Docker; remember to replace the image with the deepseek-v3.2 docker image.
:::{note}
-The image is based on a specific version and will not continue to release new version.
-Only AArch64 architecture are supported currently due to extra operator's installation limitations.
+- The image is based on the specific version `vllm-ascend:v0.11.0rc0` and will not receive new releases. Switch to the `Use vllm-ascend docker image` tab for the latest DeepSeek-V3.2 support on vllm-ascend.
+- Only the AArch64 architecture is currently supported due to limitations in installing the extra operators.
:::
::::
@@ -66,23 +66,7 @@ wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a3/custom_
pip install custom_ops-1.0-cp311-cp311-linux_aarch64.whl
```
-3. Download and install `MLAPO`.
-
-```shell
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a3/CANN-custom_ops-mlapo-linux.aarch64.run
-# please set a custom install-path, here take `/`vllm-workspace/CANN` as example.
-chmod +x ./CANN-custom_ops-mlapo-linux.aarch64.run
-./CANN-custom_ops-mlapo-linux.aarch64.run --quiet --install-path=/vllm-workspace/CANN
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a3/torch_npu-2.7.1%2Bgitb7c90d0-cp311-cp311-linux_aarch64.whl
-pip install torch_npu-2.7.1+gitb7c90d0-cp311-cp311-linux_aarch64.whl
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a3/libopsproto_rt2.0.so
-cp libopsproto_rt2.0.so /usr/local/Ascend/ascend-toolkit/8.2.RC1/opp/built-in/op_proto/lib/linux/aarch64/libopsproto_rt2.0.so
-# Don't forget to replace `/vllm-workspace/CANN/` to the custom path you set before.
-source /vllm-workspace/CANN/vendors/customize/bin/set_env.bash
-export LD_PRELOAD=/vllm-workspace/CANN/vendors/customize/op_proto/lib/linux/aarch64/libcust_opsproto_rt2.0.so:${LD_PRELOAD}
-```
-
-For `A2` image, you should change all `wget` commands as above, and replace `A3` with `A2` release file.
+For the `A2` image:
1. Start the docker image on your node, refer to [using docker](../installation.md#set-up-using-docker).
@@ -98,22 +82,6 @@ wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a2/custom_
pip install custom_ops-1.0-cp311-cp311-linux_aarch64.whl
```
-3. Download and install `MLAPO`.
-
-```shell
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a2/CANN-custom_ops-mlapo-linux.aarch64.run
-# please set a custom install-path, here take `/`vllm-workspace/CANN` as example.
-chmod +x ./CANN-custom_ops-mlapo-linux.aarch64.run
-./CANN-custom_ops-mlapo-linux.aarch64.run --quiet --install-path=/vllm-workspace/CANN
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a2/torch_npu-2.7.1%2Bgitb7c90d0-cp311-cp311-linux_aarch64.whl
-pip install torch_npu-2.7.1+gitb7c90d0-cp311-cp311-linux_aarch64.whl
-wget https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/a2/libopsproto_rt2.0.so
-cp libopsproto_rt2.0.so /usr/local/Ascend/ascend-toolkit/8.2.RC1/opp/built-in/op_proto/lib/linux/aarch64/libopsproto_rt2.0.so
-# Don't forget to replace `/vllm-workspace/CANN/` to the custom path you set before.
-source /vllm-workspace/CANN/vendors/customize/bin/set_env.bash
-export LD_PRELOAD=/vllm-workspace/CANN/vendors/customize/op_proto/lib/linux/aarch64/libcust_opsproto_rt2.0.so:${LD_PRELOAD}
-```
-
::::
::::{tab-item} Build from source
diff --git a/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md b/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md
index fefb86f2f7..1db83e071f 100644
--- a/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md
+++ b/docs/source/tutorials/multi_node_pd_disaggregation_mooncake.md
@@ -57,23 +57,14 @@ for i in {0..15}; do hccn_tool -i $i -ping -g address x.x.x.x;done
Mooncake is the serving platform for Kimi, a leading LLM service provided by Moonshot AI. First, we need to obtain the Mooncake project. Refer to the following command:
```shell
-git clone https://github.com/kvcache-ai/Mooncake.git
+git clone -b v0.3.7.post2 --depth 1 https://github.com/kvcache-ai/Mooncake.git
```
-Update and install Python
-
-```shell
-apt-get update
-apt-get install python3
-```
-
-Modify Mooncake compilation option
+(Optional) Replace the Go download URL if the network is poor
```shell
cd Mooncake
-vi mooncake-common/common.cmake
-# find this row and set USE_ASCEND_DIRECT ON.
-option(USE_ASCEND_DIRECT "option for using ascend npu with adxl engine" ON)
+sed -i 's|https://go.dev/dl/|https://golang.google.cn/dl/|g' dependencies.sh
```
Install mpi
@@ -93,7 +84,7 @@ Compile and install
```shell
mkdir build
cd build
-cmake ..
+cmake .. -DUSE_ASCEND_DIRECT=ON
make -j
make install
```
diff --git a/docs/source/tutorials/multi_npu.md b/docs/source/tutorials/multi_npu.md
index 80a0929e4b..3dedc9723e 100644
--- a/docs/source/tutorials/multi_npu.md
+++ b/docs/source/tutorials/multi_npu.md
@@ -1,4 +1,4 @@
-# Multi-NPU (QwQ 32B)
+# Multi-NPU (QwQ-32B)
## Run vllm-ascend on Multi-NPU
diff --git a/docs/source/tutorials/multi_npu_moge.md b/docs/source/tutorials/multi_npu_moge.md
index 57ff41e26b..e426c0f337 100644
--- a/docs/source/tutorials/multi_npu_moge.md
+++ b/docs/source/tutorials/multi_npu_moge.md
@@ -1,4 +1,4 @@
-# Multi-NPU (Pangu Pro MoE)
+# Multi-NPU (Pangu-Pro-MoE)
## Run vllm-ascend on Multi-NPU
diff --git a/docs/source/tutorials/multi_npu_quantization.md b/docs/source/tutorials/multi_npu_quantization.md
index 7e664b2b75..23b183dbd1 100644
--- a/docs/source/tutorials/multi_npu_quantization.md
+++ b/docs/source/tutorials/multi_npu_quantization.md
@@ -1,4 +1,4 @@
-# Multi-NPU (QwQ 32B W8A8)
+# Multi-NPU (QwQ-32B-W8A8)
## Run Docker Container
:::{note}
diff --git a/docs/source/tutorials/single_npu.md b/docs/source/tutorials/single_npu.md
index 0759e3ede8..4b10d009a9 100644
--- a/docs/source/tutorials/single_npu.md
+++ b/docs/source/tutorials/single_npu.md
@@ -1,4 +1,4 @@
-# Single NPU (Qwen3 8B)
+# Single NPU (Qwen3-8B)
## Run vllm-ascend on Single NPU
diff --git a/docs/source/tutorials/single_npu_qwen2.5_vl.md b/docs/source/tutorials/single_npu_qwen2.5_vl.md
index 45aeeaa764..2454e0c710 100644
--- a/docs/source/tutorials/single_npu_qwen2.5_vl.md
+++ b/docs/source/tutorials/single_npu_qwen2.5_vl.md
@@ -1,4 +1,4 @@
-# Single NPU (Qwen2.5-VL 7B)
+# Single NPU (Qwen2.5-VL-7B)
## Run vllm-ascend on Single NPU
diff --git a/docs/source/tutorials/single_npu_qwen2_audio.md b/docs/source/tutorials/single_npu_qwen2_audio.md
index 94d86c5a9e..e093e84511 100644
--- a/docs/source/tutorials/single_npu_qwen2_audio.md
+++ b/docs/source/tutorials/single_npu_qwen2_audio.md
@@ -1,4 +1,4 @@
-# Single NPU (Qwen2-Audio 7B)
+# Single NPU (Qwen2-Audio-7B)
## Run vllm-ascend on Single NPU
diff --git a/docs/source/tutorials/single_npu_qwen3_quantization.md b/docs/source/tutorials/single_npu_qwen3_quantization.md
index bd735d79a5..40acff3468 100644
--- a/docs/source/tutorials/single_npu_qwen3_quantization.md
+++ b/docs/source/tutorials/single_npu_qwen3_quantization.md
@@ -1,4 +1,4 @@
-# Single-NPU (Qwen3 8B W4A8)
+# Single-NPU (Qwen3-8B-W4A8)
## Run Docker Container
:::{note}
diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index 78e6d33a68..ec1e1a429b 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -1,6 +1,6 @@
# Additional Configuration
-Additional configuration is a mechanism provided by vLLM to allow plugins to control inner behavior by their own. vLLM Ascend uses this mechanism to make the project more flexible.
+Additional configuration is a mechanism provided by vLLM that allows plugins to control their own internal behavior. vLLM Ascend uses this mechanism to make the project more flexible.
## How to use
@@ -35,7 +35,7 @@ The following table lists additional configuration options available in vLLM Asc
| `enable_shared_expert_dp` | bool | `False` | When the expert is shared in DP, it delivers better performance but consumes more memory. Currently only DeepSeek series models are supported. |
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
-| `multistream_overlap_shared_expert` | bool | `False` | Whether to enable multistream shared expert. This option only takes effects on MoE models with shared experts. |
+| `multistream_overlap_shared_expert` | bool | `False` | Whether to enable multistream shared expert. This option only takes effect on MoE models with shared experts. |
| `dynamic_eplb` | bool | `False` | Whether to enable dynamic EPLB. |
| `num_iterations_eplb_update` | int | `400` | Forward iterations when EPLB begins. |
| `gate_eplb` | bool | `False` | Whether to enable EPLB only once. |
@@ -70,14 +70,14 @@ The details of each configuration option are as follows:
| `max_long_partial_prefills` | Union[int, float] | `float('inf')` | The maximum number of prompts longer than long_prefill_token_threshold that will be prefilled concurrently. |
| `long_prefill_token_threshold` | Union[int, float] | `float('inf')` | a request is considered long if the prompt is longer than this number of tokens. |
-ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
+ascend_scheduler_config also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
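+
+For example, a minimal sketch (the model name is only a placeholder) of passing a scheduler option through `ascend_scheduler_config`:
+
+```python
+from vllm import LLM
+
+llm = LLM(
+    model="Qwen/Qwen3-8B",  # placeholder model
+    additional_config={
+        "ascend_scheduler_config": {
+            "enabled": True,
+            # any field from vLLM's SchedulerConfig can be passed through here
+            "enable_chunked_prefill": True,
+        },
+    },
+)
+```
+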
**weight_prefetch_config**
| Name | Type | Default | Description |
|------------------|------|-------------------------------------------------------------|------------------------------------|
| `enabled` | bool | `False` | Whether to enable weight prefetch. |
-| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weights. |
+| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}, "moe": {"gate_up": 0.8}}` | Prefetch ratio of each weight. |
### Example
diff --git a/docs/source/user_guide/feature_guide/dynamic_batch.md b/docs/source/user_guide/feature_guide/dynamic_batch.md
index c1e76354e9..7c68b2a930 100644
--- a/docs/source/user_guide/feature_guide/dynamic_batch.md
+++ b/docs/source/user_guide/feature_guide/dynamic_batch.md
@@ -11,9 +11,9 @@ We are working on further improvements and this feature will support more XPUs i
### Prerequisites
-1. Dynamic batch now depends on a offline cost model saved in a look-up table to refine the token budget. The lookup-table is saved in '.csv' file, which should be first downloaded from [here](https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/dynamic_batch_scheduler/A2-B3-BLK128.csv), renamed, and saved to the path `vllm_ascend/core/profile_table.csv`
+1. Dynamic batch now depends on an offline cost model saved in a lookup table to refine the token budget. The lookup table is saved in '.csv' file, which should be first downloaded from [here](https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/dynamic_batch_scheduler/A2-B3-BLK128.csv), renamed, and saved to the path `vllm_ascend/core/profile_table.csv`
-2. `Pandas` is needed to load the look-up table, in case `pandas` is not installed.
+2. `pandas` is needed to load the lookup table. Install it if it is not already available:
```bash
pip install pandas
diff --git a/docs/source/user_guide/feature_guide/graph_mode.md b/docs/source/user_guide/feature_guide/graph_mode.md
index 3af9a41809..90aba6a3ee 100644
--- a/docs/source/user_guide/feature_guide/graph_mode.md
+++ b/docs/source/user_guide/feature_guide/graph_mode.md
@@ -8,7 +8,7 @@ This guide provides instructions for using Ascend Graph Mode with vLLM Ascend. P
## Getting Started
-From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by default to keep the same behavior with vLLM. If you hit any issues, please feel free to open an issue on GitHub and fallback to the eager mode temporarily by set `enforce_eager=True` when initializing the model.
+From v0.9.1rc1 with V1 Engine, vLLM Ascend will run models in graph mode by default to keep the same behavior with vLLM. If you hit any issues, please feel free to open an issue on GitHub and fallback to the eager mode temporarily by setting `enforce_eager=True` when initializing the model.
There are two kinds for graph mode supported by vLLM Ascend:
- **ACLGraph**: This is the default graph mode supported by vLLM Ascend. In v0.9.1rc1, only Qwen series models are well tested.
@@ -45,14 +45,14 @@ import os
from vllm import LLM
# TorchAirGraph only works without chunked-prefill for now
-model = LLM(model="deepseek-ai/DeepSeek-R1-0528", additional_config={"torchair_graph_config": {"enabled": True},"ascend_scheduler_config": {"enabled": True,}})
+model = LLM(model="deepseek-ai/DeepSeek-R1-0528", additional_config={"torchair_graph_config": {"enabled": True},"ascend_scheduler_config": {"enabled": True}})
outputs = model.generate("Hello, how are you?")
```
Online example:
```shell
-vllm serve Qwen/Qwen2-7B-Instruct --additional-config='{"torchair_graph_config": {"enabled": true},"ascend_scheduler_config": {"enabled": true,}}'
+vllm serve deepseek-ai/DeepSeek-R1-0528 --additional-config='{"torchair_graph_config": {"enabled": true},"ascend_scheduler_config": {"enabled": true}}'
```
You can find more details about additional configuration [here](../configuration/additional_config.md).
@@ -74,5 +74,5 @@ outputs = model.generate("Hello, how are you?")
Online example:
```shell
-vllm serve Qwen/Qwen2-7B-Instruct --enforce-eager
+vllm serve <your_model> --enforce-eager
```
diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md
index 61c333b0f6..b0c0fd7d46 100644
--- a/docs/source/user_guide/feature_guide/index.md
+++ b/docs/source/user_guide/feature_guide/index.md
@@ -13,4 +13,5 @@ lora
eplb_swift_balancer
netloader
dynamic_batch
+kv_pool_mooncake
:::
diff --git a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md b/docs/source/user_guide/feature_guide/kv_pool_mooncake.md
similarity index 63%
rename from examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md
rename to docs/source/user_guide/feature_guide/kv_pool_mooncake.md
index 28dd83b7cd..34ab047907 100644
--- a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md
+++ b/docs/source/user_guide/feature_guide/kv_pool_mooncake.md
@@ -5,17 +5,26 @@
* Software:
* Python >= 3.9, < 3.12
* CANN >= 8.3.rc1
- * PyTorch == 2.7.1, torch-npu == 2.7.1
+ * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
* vLLM:main branch
* vLLM-Ascend:main branch
- * Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)(Currently available branch code, continuously updated.)
- Installation and Compilation Guide:https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy?tab=readme-ov-file#build-and-use-binaries
+ * Mooncake:main branch
+
+ Installation and Compilation Guide:https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#build-and-use-binaries
+
+ Make sure to build with `-DUSE_ASCEND_DIRECT` to enable ADXL engine.
+
+ An example command for compiling ADXL:
+
+ `rm -rf build && mkdir -p build && cd build && cmake .. -DCMAKE_INSTALL_PREFIX=/opt/transfer-engine/ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DUSE_ASCEND_DIRECT=ON -DBUILD_SHARED_LIBS=ON -DBUILD_UNIT_TESTS=OFF && make -j && make install`
+
+ Also, you need to set `LD_LIBRARY_PATH` so that the Mooncake libraries can be found, e.g. `export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64/python3.11/site-packages/mooncake`, or copy the .so files to the `/usr/local/lib64` directory after compilation.
### KV Pooling Parameter Description
-**kv_connector_extra_config**:Additional Configurable Parameters for Pooling
- **mooncake_rpc_port**:Port for RPC Communication Between Pooling Scheduler Process and Worker Process: Each Instance Requires a Unique Port Configuration.
- **load_async**:Whether to Enable Asynchronous Loading. The default value is false.
- **register_buffer**:Whether to Register Video Memory with the Backend. Registration is Not Required When Used with MooncakeConnectorV1; It is Required in All Other Cases. The Default Value is false.
+**kv_connector_extra_config**: Additional configurable parameters for pooling.
+**mooncake_rpc_port**: Port for RPC communication between the pooling scheduler process and the worker processes. Each instance requires a unique port configuration.
+**load_async**: Whether to enable asynchronous loading. The default value is false.
+**register_buffer**: Whether to register device memory with the backend. Registration is not required when used with MooncakeConnectorV1; it is required in all other cases. The default value is false.
## run mooncake master
@@ -29,26 +38,32 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
"metadata_server": "P2PHANDSHAKE",
"protocol": "ascend",
"device_name": "",
+ "use_ascend_direct": true,
+ "alloc_in_same_node": true,
"master_server_address": "xx.xx.xx.xx:50088",
- "global_segment_size": 30000000000
+    "global_segment_size": "1GB"
}
```
-**local_hostname**: Configured as the IP address of the current master node,
-**metadata_server**: Configured as **P2PHANDSHAKE**,
-**protocol:** Configured for Ascend to use Mooncake's HCCL communication,
-**device_name**: ""
-**master_server_address**: Configured with the IP and port of the master service
-**global_segment_size**: Expands the kvcache size registered by the PD node to the master
+**local_hostname**: Configured as the IP address of the current master node.
+**metadata_server**: Configured as **P2PHANDSHAKE**.
+**protocol**: Configured for Ascend to use Mooncake's HCCL communication.
+**device_name**: ""
+**use_ascend_direct**: Indicator for using ADXL engine.
+**alloc_in_same_node**: Indicator for preferring local buffer allocation strategy.
+**master_server_address**: Configured with the IP and port of the master service.
+**global_segment_size**: Expands the kvcache size registered by the PD node to the master. Accepts human-readable sizes such as `"1GB"`, `"1024MB"`, `"1048576KB"`, `"1073741824B"`, or a plain byte count such as `1073741824`.
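+
+As a sketch, the file pointed to by **MOONCAKE_CONFIG_PATH** can also be generated programmatically; every value below is the placeholder from the example above and should be replaced with your deployment settings:
+
+```python
+import json
+
+mooncake_config = {
+    "local_hostname": "xx.xx.xx.xx",             # IP of the current master node
+    "metadata_server": "P2PHANDSHAKE",
+    "protocol": "ascend",
+    "device_name": "",
+    "use_ascend_direct": True,
+    "alloc_in_same_node": True,
+    "master_server_address": "xx.xx.xx.xx:50088",
+    "global_segment_size": "1GB",
+}
+
+with open("mooncake.json", "w") as f:
+    json.dump(mooncake_config, f, indent=2)
+```
+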
### 2. Start mooncake_master
Under the mooncake folder:
```
-mooncake_master --port 50088
+mooncake_master --port 50088 --eviction_high_watermark_ratio 0.95 --eviction_ratio 0.05
```
+`eviction_high_watermark_ratio` determines the watermark at which Mooncake Store starts eviction, and `eviction_ratio` determines the portion of stored objects that will be evicted.
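+
+A small illustrative calculation of how these two flags interact (the numbers are only the examples used above):
+
+```python
+segment_bytes = 1 * 1024**3   # the 1GB global_segment_size from the example configuration
+high_watermark = 0.95         # --eviction_high_watermark_ratio
+eviction_ratio = 0.05         # --eviction_ratio
+
+# eviction starts once usage crosses ~95% of the registered segment
+print(f"eviction starts above ~{segment_bytes * high_watermark / 1024**2:.0f} MB")
+# and roughly 5% of the stored objects are then evicted
+print(f"about {eviction_ratio:.0%} of stored objects are evicted")
+```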
+
## Pooling and Prefill Decode Disaggregate Scenario
### 1.Run `prefill` Node and `decode` Node
@@ -69,11 +84,11 @@ export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm
export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json"
export VLLM_USE_V1=1
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
-export ASCEND_TRANSPORT_PRINT=1
export ACL_OP_INIT_MODE=1
-# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled.
-export ASCEND_AGGREGATE_ENABLE=1
-# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off.
+export ASCEND_BUFFER_POOL=4:8
+# ASCEND_BUFFER_POOL configures the number and size of buffers on the NPU device used for aggregation and KV transfer; the value 4:8 means we allocate 4 buffers of 8MB each.
+export ASCEND_CONNECT_TIMEOUT=10000
+export ASCEND_TRANSFER_TIMEOUT=10000
python3 -m vllm.entrypoints.openai.api_server \
--model /xxxxx/Qwen2.5-7B-Instruct \
@@ -88,34 +103,34 @@ python3 -m vllm.entrypoints.openai.api_server \
--max-num-batched-tokens 4096 \
--kv-transfer-config \
'{
- "kv_connector": "MultiConnector",
- "kv_role": "kv_producer",
- "kv_connector_extra_config": {
- "use_layerwise": false,
- "connectors": [
- {
- "kv_connector": "MooncakeConnectorV1",
- "kv_role": "kv_producer",
- "kv_port": "20001",
- "kv_connector_extra_config": {
- "prefill": {
- "dp_size": 1,
- "tp_size": 1
- },
- "decode": {
- "dp_size": 1,
- "tp_size": 1
- }
- }
- },
- {
- "kv_connector": "MooncakeConnectorStoreV1",
- "kv_role": "kv_producer",
+ "kv_connector": "MultiConnector",
+ "kv_role": "kv_producer",
+ "kv_connector_extra_config": {
+ "use_layerwise": false,
+ "connectors": [
+ {
+ "kv_connector": "MooncakeConnectorV1",
+ "kv_role": "kv_producer",
+ "kv_port": "20001",
+ "kv_connector_extra_config": {
+ "prefill": {
+ "dp_size": 1,
+ "tp_size": 1
+ },
+ "decode": {
+ "dp_size": 1,
+ "tp_size": 1
+ }
+ }
+ },
+ {
+ "kv_connector": "MooncakeConnectorStoreV1",
+ "kv_role": "kv_producer",
"mooncake_rpc_port":"0"
- }
- ]
- }
-}' > p.log 2>&1
+ }
+ ]
+ }
+ }' > p.log 2>&1
```
`decode` Node:
@@ -133,10 +148,9 @@ export MOONCAKE_CONFIG_PATH="/xxxxx/mooncake.json"
export VLLM_USE_V1=1
export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7
export ACL_OP_INIT_MODE=1
-export ASCEND_TRANSPORT_PRINT=1
-# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled.
-export ASCEND_AGGREGATE_ENABLE=1
-# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off.
+export ASCEND_BUFFER_POOL=4:8
+export ASCEND_CONNECT_TIMEOUT=10000
+export ASCEND_TRANSFER_TIMEOUT=10000
python3 -m vllm.entrypoints.openai.api_server \
--model /xxxxx/Qwen2.5-7B-Instruct \
@@ -151,33 +165,34 @@ python3 -m vllm.entrypoints.openai.api_server \
--max-num-batched-tokens 4096 \
--kv-transfer-config \
'{
- "kv_connector": "MultiConnector",
- "kv_role": "kv_consumer",
- "kv_connector_extra_config": {
- "use_layerwise": false,
- "connectors": [
- {
- "kv_connector": "MooncakeConnectorV1",
- "kv_role": "kv_consumer",
- "kv_port": "20002",
- "kv_connector_extra_config": {
- "prefill": {
- "dp_size": 1,
- "tp_size": 1
- },
- "decode": {
- "dp_size": 1,
- "tp_size": 1
- }
- }
- },
- {
- "kv_connector": "MooncakeConnectorStoreV1",
- "kv_role": "kv_consumer",
+ "kv_connector": "MultiConnector",
+ "kv_role": "kv_consumer",
+ "kv_connector_extra_config": {
+ "use_layerwise": false,
+ "connectors": [
+ {
+ "kv_connector": "MooncakeConnectorV1",
+ "kv_role": "kv_consumer",
+ "kv_port": "20002",
+ "kv_connector_extra_config": {
+ "use_ascend_direct": true,
+ "prefill": {
+ "dp_size": 1,
+ "tp_size": 1
+ },
+ "decode": {
+ "dp_size": 1,
+ "tp_size": 1
+ }
+ }
+ },
+ {
+ "kv_connector": "MooncakeConnectorStoreV1",
+ "kv_role": "kv_consumer",
"mooncake_rpc_port":"1"
- }
- ]
- }
+ }
+ ]
+ }
}' > d.log 2>&1
```
@@ -234,10 +249,9 @@ export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json"
export VLLM_USE_V1=1
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
export ACL_OP_INIT_MODE=1
-export ASCEND_TRANSPORT_PRINT=1
-# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled.
-export ASCEND_AGGREGATE_ENABLE=1
-# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off.
+export ASCEND_BUFFER_POOL=4:8
+export ASCEND_CONNECT_TIMEOUT=10000
+export ASCEND_TRANSFER_TIMEOUT=10000
python3 -m vllm.entrypoints.openai.api_server \
--model /xxxxx/Qwen2.5-7B-Instruct \
@@ -252,12 +266,12 @@ python3 -m vllm.entrypoints.openai.api_server \
--max-num-batched-tokens 4096 \
--kv-transfer-config \
'{
- "kv_connector": "MooncakeConnectorStoreV1",
- "kv_role": "kv_both",
- "kv_connector_extra_config": {
- "use_layerwise": false,
+ "kv_connector": "MooncakeConnectorStoreV1",
+ "kv_role": "kv_both",
+ "kv_connector_extra_config": {
+ "use_layerwise": false,
"mooncake_rpc_port":"0"
- }
+ }
}' > mix.log 2>&1
```
@@ -275,4 +289,4 @@ Long question:
```
curl -s http://localhost:8100/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }'
-```
\ No newline at end of file
+```
diff --git a/docs/source/user_guide/feature_guide/lora.md b/docs/source/user_guide/feature_guide/lora.md
index ad4bc2d320..4678c0241e 100644
--- a/docs/source/user_guide/feature_guide/lora.md
+++ b/docs/source/user_guide/feature_guide/lora.md
@@ -20,4 +20,4 @@ vllm serve meta-llama/Llama-2-7b \
We have implemented LoRA-related AscendC operators, such as bgmv_shrink, bgmv_expand, sgmv_shrink and sgmv_expand. You can find them under the "csrc/kernels" directory of [vllm-ascend repo](https://github.com/vllm-project/vllm-ascend.git).
-When you install vllm and vllm-ascend, those operators mentioned above will be compiled and installed automatically. If you do not want to use AscendC operators when you run vllm-ascend, you should set `COMPILE_CUSTOM_KERNELS=0` and reinstall vllm-ascend. To require more instructions about installation and compilation, you can refer to [installation guide](../../installation.md).
+When you install vllm and vllm-ascend, those operators mentioned above will be compiled and installed automatically. If you do not want to use AscendC operators when you run vllm-ascend, you should set `COMPILE_CUSTOM_KERNELS=0` and reinstall vllm-ascend. For more instructions about installation and compilation, you can refer to [installation guide](../../installation.md).
diff --git a/docs/source/user_guide/feature_guide/quantization.md b/docs/source/user_guide/feature_guide/quantization.md
index e2a48ff35a..8a6e36765d 100644
--- a/docs/source/user_guide/feature_guide/quantization.md
+++ b/docs/source/user_guide/feature_guide/quantization.md
@@ -28,7 +28,7 @@ See https://www.modelscope.cn/models/vllm-ascend/Kimi-K2-Instruct-W8A8.
This conversion process requires a larger CPU memory, ensure that the RAM size is greater than 2 TB.
:::
-### Adapt to changes
+### Adaptations and changes
1. Ascend does not support the `flash_attn` library. To run the model, you need to follow the [guide](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#deepseek-v3r1) and comment out certain parts of the code in `modeling_deepseek.py` located in the weights folder.
2. The current version of transformers does not support loading weights in FP8 quantization format. you need to follow the [guide](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#deepseek-v3r1) and delete the quantization related fields from `config.json` in the weights folder.
diff --git a/docs/source/user_guide/feature_guide/sleep_mode.md b/docs/source/user_guide/feature_guide/sleep_mode.md
index c616f7e815..6fc3652115 100644
--- a/docs/source/user_guide/feature_guide/sleep_mode.md
+++ b/docs/source/user_guide/feature_guide/sleep_mode.md
@@ -80,7 +80,7 @@ The following is a simple example of how to use sleep mode.
vllm serve Qwen/Qwen2.5-0.5B-Instruct --enable-sleep-mode
- # after serveing is up, post these endpoints
+ # after serving is up, post to these endpoints
# sleep level 1
curl -X POST http://127.0.0.1:8000/sleep \
diff --git a/docs/source/user_guide/release_notes.md b/docs/source/user_guide/release_notes.md
index 56d101dabf..307d15357e 100644
--- a/docs/source/user_guide/release_notes.md
+++ b/docs/source/user_guide/release_notes.md
@@ -39,7 +39,7 @@ This is the 1st release candidate of v0.10.2 for vLLM Ascend. Please follow the
- MTP now works with the token > 1. [#2708](https://github.com/vllm-project/vllm-ascend/pull/2708)
- Qwen2.5 VL now works with quantization. [#2778](https://github.com/vllm-project/vllm-ascend/pull/2778)
- Improved the performance with async scheduler enabled. [#2783](https://github.com/vllm-project/vllm-ascend/pull/2783)
-- Fixed the performance regression with non MLA model when use default scheduler. [#2894](https://github.com/vllm-project/vllm-ascend/pull/2894)
+- Fixed the performance regression with non MLA model when using default scheduler. [#2894](https://github.com/vllm-project/vllm-ascend/pull/2894)
### Others
- The performance of W8A8 quantization is improved. [#2275](https://github.com/vllm-project/vllm-ascend/pull/2275)
@@ -106,7 +106,7 @@ This is the 1st release candidate of v0.10.1 for vLLM Ascend. Please follow the
* Environment variable `VLLM_LLMDD_RPC_PORT` is renamed to `VLLM_ASCEND_LLMDD_RPC_PORT` now. [#2450](https://github.com/vllm-project/vllm-ascend/pull/2450)
* Added `VLLM_ASCEND_ENABLE_MLP_OPTIMIZE` in environment variables, whether to enable mlp optimize when tensor parallel is enabled. This feature provides better performance in eager mode. [#2120](https://github.com/vllm-project/vllm-ascend/pull/2120)
* Removed `MOE_ALL2ALL_BUFFER` and `VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ` in environment variables. [#2612](https://github.com/vllm-project/vllm-ascend/pull/2612)
- * Added `enable_prefetch` in `additional_config`, whether to enable weight prefetch. [#2465](https://github.com/vllm-project/vllm-ascend/pull/2465)
+ * Added `enable_prefetch` in `additional_config`, which controls whether to enable weight prefetch. [#2465](https://github.com/vllm-project/vllm-ascend/pull/2465)
* Added `mode` in `additional_config.torchair_graph_config`, When using reduce-overhead mode for torchair, mode needs to be set. [#2461](https://github.com/vllm-project/vllm-ascend/pull/2461)
* `enable_shared_expert_dp` in `additional_config` is disabled by default now, and it is recommended to be enabled when inferencing with deepseek. [#2457](https://github.com/vllm-project/vllm-ascend/pull/2457)
@@ -461,7 +461,7 @@ This is the 1st release candidate of v0.9.0 for vllm-ascend. Please follow the [
### Highlights
- DeepSeek works with graph mode now. Follow the [official doc](https://vllm-ascend.readthedocs.io/en/latest/user_guide/feature_guide/graph_mode.html) to take a try. [#789](https://github.com/vllm-project/vllm-ascend/pull/789)
-- Qwen series models works with graph mode now. It works by default with V1 Engine. Please note that in this release, only Qwen series models are well tested with graph mode. We'll make it stable and generalize in the next release. If you hit any issues, please feel free to open an issue on GitHub and fallback to eager mode temporarily by set `enforce_eager=True` when initializing the model.
+- Qwen series models work with graph mode now. It works by default with the V1 Engine. Please note that in this release, only Qwen series models are well tested with graph mode. We'll make it stable and more general in the next release. If you hit any issues, please feel free to open an issue on GitHub and fall back to eager mode temporarily by setting `enforce_eager=True` when initializing the model.
### Core
@@ -590,13 +590,13 @@ This is the first release candidate of v0.8.4 for vllm-ascend. Please follow the
- vLLM V1 engine experimental support is included in this version. You can visit [official guide](https://docs.vllm.ai/en/latest/getting_started/v1_user_guide.html) to get more detail. By default, vLLM will fallback to V0 if V1 doesn't work, please set `VLLM_USE_V1=1` environment if you want to use V1 forcibly.
- LoRA、Multi-LoRA And Dynamic Serving is supported now. The performance will be improved in the next release. Please follow the [official doc](https://docs.vllm.ai/en/latest/features/lora.html) for more usage information. Thanks for the contribution from China Merchants Bank. [#521](https://github.com/vllm-project/vllm-ascend/pull/521).
-- Sleep Mode feature is supported. Currently it's only work on V0 engine. V1 engine support will come soon. [#513](https://github.com/vllm-project/vllm-ascend/pull/513)
+- Sleep Mode feature is supported. Currently it only works on V0 engine. V1 engine support will come soon. [#513](https://github.com/vllm-project/vllm-ascend/pull/513)
### Core
- The Ascend scheduler is added for V1 engine. This scheduler is more affinity with Ascend hardware. More scheduler policy will be added in the future. [#543](https://github.com/vllm-project/vllm-ascend/pull/543)
- Disaggregated Prefill feature is supported. Currently only 1P1D works. NPND is under design by vllm team. vllm-ascend will support it once it's ready from vLLM. Follow the [official guide](https://docs.vllm.ai/en/latest/features/disagg_prefill.html) to use. [#432](https://github.com/vllm-project/vllm-ascend/pull/432)
-- Spec decode feature works now. Currently it's only work on V0 engine. V1 engine support will come soon. [#500](https://github.com/vllm-project/vllm-ascend/pull/500)
+- Spec decode feature works now. Currently it only works on V0 engine. V1 engine support will come soon. [#500](https://github.com/vllm-project/vllm-ascend/pull/500)
- Structured output feature works now on V1 Engine. Currently it only supports xgrammar backend while using guidance backend may get some errors. [#555](https://github.com/vllm-project/vllm-ascend/pull/555)
### Others
diff --git a/docs/source/user_guide/support_matrix/supported_features.md b/docs/source/user_guide/support_matrix/supported_features.md
index 10816a4092..72d8811e86 100644
--- a/docs/source/user_guide/support_matrix/supported_features.md
+++ b/docs/source/user_guide/support_matrix/supported_features.md
@@ -10,7 +10,7 @@ You can check the [support status of vLLM V1 Engine][v1_user_guide]. Below is th
| Automatic Prefix Caching | 🟢 Functional | Functional, see detailed note: [vllm-ascend#732][apc] |
| LoRA | 🟢 Functional | [vllm-ascend#396][multilora], [vllm-ascend#893][v1 multilora] |
| Speculative decoding | 🟢 Functional | Basic support |
-| Pooling | 🟢 Functional | CI needed to adapt to more models; V1 support rely on vLLM support. |
+| Pooling | 🟢 Functional | CI needed to adapt to more models; V1 support relies on vLLM support. |
| Enc-dec | 🟡 Planned | vLLM should support this feature first. |
| Multi Modality | 🟢 Functional | [Tutorial][multimodal], optimizing and adapting more models |
| LogProbs | 🟢 Functional | CI needed |
diff --git a/docs/source/user_guide/support_matrix/supported_models.md b/docs/source/user_guide/support_matrix/supported_models.md
index 256f0333f7..c5a718b0cc 100644
--- a/docs/source/user_guide/support_matrix/supported_models.md
+++ b/docs/source/user_guide/support_matrix/supported_models.md
@@ -6,78 +6,78 @@ Get the latest info here: https://github.com/vllm-project/vllm-ascend/issues/160
### Generative Models
-| Model | Support | Note |
-|-------------------------------|-----------|----------------------------------------------------------------------|
-| DeepSeek V3/3.1 | ✅ | |
-| DeepSeek V3.2 EXP | ✅ | |
-| DeepSeek R1 | ✅ | |
-| DeepSeek Distill (Qwen/LLama) | ✅ | |
-| Qwen3 | ✅ | |
-| Qwen3-based | ✅ | |
-| Qwen3-Coder | ✅ | |
-| Qwen3-Moe | ✅ | |
-| Qwen3-Next | ✅ | |
-| Qwen2.5 | ✅ | |
-| Qwen2 | ✅ | |
-| Qwen2-based | ✅ | |
-| QwQ-32B | ✅ | |
-| LLama2/3/3.1 | ✅ | |
-| Internlm | ✅ | [#1962](https://github.com/vllm-project/vllm-ascend/issues/1962) |
-| Baichuan | ✅ | |
-| Baichuan2 | ✅ | |
-| Phi-4-mini | ✅ | |
-| MiniCPM | ✅ | |
-| MiniCPM3 | ✅ | |
-| Ernie4.5 | ✅ | |
-| Ernie4.5-Moe | ✅ | |
-| Gemma-2 | ✅ | |
-| Gemma-3 | ✅ | |
-| Phi-3/4 | ✅ | |
-| Mistral/Mistral-Instruct | ✅ | |
-| GLM-4.5 | ✅ | |
-| GLM-4 | ❌ | [#2255](https://github.com/vllm-project/vllm-ascend/issues/2255) |
-| GLM-4-0414 | ❌ | [#2258](https://github.com/vllm-project/vllm-ascend/issues/2258) |
-| ChatGLM | ❌ | [#554](https://github.com/vllm-project/vllm-ascend/issues/554) |
-| DeepSeek V2.5 | 🟡 | Need test |
-| Mllama | 🟡 | Need test |
-| MiniMax-Text | 🟡 | Need test |
+| Model | Support | Note | BF16 | Supported Hardware | W8A8 | Chunked Prefill | Automatic Prefix Cache | LoRA | Speculative Decoding | Async Scheduling | Tensor Parallel | Pipeline Parallel | Expert Parallel | Data Parallel | Prefill-decode Disaggregation | Piecewise AclGraph | Fullgraph AclGraph | max-model-len | MLP Weight Prefetch | Doc |
+|-------------------------------|-----------|----------------------------------------------------------------------|------|--------------------|------|-----------------|------------------------|------|----------------------|------------------|-----------------|-------------------|-----------------|---------------|-------------------------------|--------------------|--------------------|---------------|---------------------|-----|
+| DeepSeek V3/3.1 | ✅ | |||||||||||||||||||
+| DeepSeek V3.2 EXP | ✅ | | ✅ | A2/A3 | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ | ✅ | ✅ | ✅ | ❌ | | | 163840 | | [DeepSeek-V3.2-Exp tutorial](../../tutorials/DeepSeek-V3.2-Exp.md) |
+| DeepSeek R1 | ✅ | |||||||||||||||||||
+| DeepSeek Distill (Qwen/LLama) | ✅ | |||||||||||||||||||
+| Qwen3 | ✅ | |||||||||||||||||||
+| Qwen3-based | ✅ | |||||||||||||||||||
+| Qwen3-Coder | ✅ | |||||||||||||||||||
+| Qwen3-Moe | ✅ | |||||||||||||||||||
+| Qwen3-Next | ✅ | |||||||||||||||||||
+| Qwen2.5 | ✅ | |||||||||||||||||||
+| Qwen2 | ✅ | |||||||||||||||||||
+| Qwen2-based | ✅ | |||||||||||||||||||
+| QwQ-32B | ✅ | |||||||||||||||||||
+| LLama2/3/3.1 | ✅ | |||||||||||||||||||
+| Internlm | ✅ | [#1962](https://github.com/vllm-project/vllm-ascend/issues/1962) |||||||||||||||||||
+| Baichuan | ✅ | |||||||||||||||||||
+| Baichuan2 | ✅ | |||||||||||||||||||
+| Phi-4-mini | ✅ | |||||||||||||||||||
+| MiniCPM | ✅ | |||||||||||||||||||
+| MiniCPM3 | ✅ | |||||||||||||||||||
+| Ernie4.5 | ✅ | |||||||||||||||||||
+| Ernie4.5-Moe | ✅ | |||||||||||||||||||
+| Gemma-2 | ✅ | |||||||||||||||||||
+| Gemma-3 | ✅ | |||||||||||||||||||
+| Phi-3/4 | ✅ | |||||||||||||||||||
+| Mistral/Mistral-Instruct | ✅ | |||||||||||||||||||
+| GLM-4.5 | ✅ | |||||||||||||||||||
+| GLM-4 | ❌ | [#2255](https://github.com/vllm-project/vllm-ascend/issues/2255) |||||||||||||||||||
+| GLM-4-0414 | ❌ | [#2258](https://github.com/vllm-project/vllm-ascend/issues/2258) |||||||||||||||||||
+| ChatGLM | ❌ | [#554](https://github.com/vllm-project/vllm-ascend/issues/554) |||||||||||||||||||
+| DeepSeek V2.5 | 🟡 | Need test |||||||||||||||||||
+| Mllama | 🟡 | Need test |||||||||||||||||||
+| MiniMax-Text | 🟡 | Need test |||||||||||||||||||
### Pooling Models
-| Model | Support | Note |
-|-------------------------------|-----------|----------------------------------------------------------------------|
-| Qwen3-Embedding | ✅ | |
-| Molmo | ✅ | [1942](https://github.com/vllm-project/vllm-ascend/issues/1942) |
-| XLM-RoBERTa-based | ❌ | [1960](https://github.com/vllm-project/vllm-ascend/issues/1960) |
+| Model | Support | Note | BF16 | Supported Hardware | W8A8 | Chunked Prefill | Automatic Prefix Cache | LoRA | Speculative Decoding | Async Scheduling | Tensor Parallel | Pipeline Parallel | Expert Parallel | Data Parallel | Prefill-decode Disaggregation | Piecewise AclGraph | Fullgraph AclGraph | max-model-len | MLP Weight Prefetch | Doc |
+|-------------------------------|-----------|----------------------------------------------------------------------|------|--------------------|------|-----------------|------------------------|------|----------------------|------------------|-----------------|-------------------|-----------------|---------------|-------------------------------|--------------------|--------------------|---------------|---------------------|-----|
+| Qwen3-Embedding | ✅ | |||||||||||||||||||
+| Molmo | ✅ | [1942](https://github.com/vllm-project/vllm-ascend/issues/1942) |||||||||||||||||||
+| XLM-RoBERTa-based | ❌ | [1960](https://github.com/vllm-project/vllm-ascend/issues/1960) |||||||||||||||||||
## Multimodal Language Models
### Generative Models
-| Model | Support | Note |
-|--------------------------------|---------------|----------------------------------------------------------------------|
-| Qwen2-VL | ✅ | |
-| Qwen2.5-VL | ✅ | |
-| Qwen3-VL | ✅ | |
-| Qwen3-VL-MOE | ✅ | |
-| Qwen2.5-Omni | ✅ | [1760](https://github.com/vllm-project/vllm-ascend/issues/1760) |
-| QVQ | ✅ | |
-| LLaVA 1.5/1.6 | ✅ | [1962](https://github.com/vllm-project/vllm-ascend/issues/1962) |
-| InternVL2 | ✅ | |
-| InternVL2.5 | ✅ | |
-| Qwen2-Audio | ✅ | |
-| Aria | ✅ | |
-| LLaVA-Next | ✅ | |
-| LLaVA-Next-Video | ✅ | |
-| MiniCPM-V | ✅ | |
-| Mistral3 | ✅ | |
-| Phi-3-Vison/Phi-3.5-Vison | ✅ | |
-| Gemma3 | ✅ | |
-| LLama4 | ❌ | [1972](https://github.com/vllm-project/vllm-ascend/issues/1972) |
-| LLama3.2 | ❌ | [1972](https://github.com/vllm-project/vllm-ascend/issues/1972) |
-| Keye-VL-8B-Preview | ❌ | [1963](https://github.com/vllm-project/vllm-ascend/issues/1963) |
-| Florence-2 | ❌ | [2259](https://github.com/vllm-project/vllm-ascend/issues/2259) |
-| GLM-4V | ❌ | [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) |
-| InternVL2.0/2.5/3.0
InternVideo2.5/Mono-InternVL | ❌ | [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) |
-| Whisper | ❌ | [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) |
-| Ultravox | 🟡 | Need test |
+| Model | Support | Note | BF16 | Supported Hardware | W8A8 | Chunked Prefill | Automatic Prefix Cache | LoRA | Speculative Decoding | Async Scheduling | Tensor Parallel | Pipeline Parallel | Expert Parallel | Data Parallel | Prefill-decode Disaggregation | Piecewise AclGraph | Fullgraph AclGraph | max-model-len | MLP Weight Prefetch | Doc |
+|--------------------------------|---------------|----------------------------------------------------------------------|------|--------------------|------|-----------------|------------------------|------|----------------------|------------------|-----------------|-------------------|-----------------|---------------|-------------------------------|--------------------|--------------------|---------------|---------------------|-----|
+| Qwen2-VL | ✅ | |||||||||||||||||||
+| Qwen2.5-VL | ✅ | |||||||||||||||||||
+| Qwen3-VL | ✅ | |||||||||||||||||||
+| Qwen3-VL-MOE | ✅ | |||||||||||||||||||
+| Qwen2.5-Omni | ✅ | [1760](https://github.com/vllm-project/vllm-ascend/issues/1760) |||||||||||||||||||
+| QVQ | ✅ | |||||||||||||||||||
+| LLaVA 1.5/1.6 | ✅ | [1962](https://github.com/vllm-project/vllm-ascend/issues/1962) |||||||||||||||||||
+| InternVL2 | ✅ | |||||||||||||||||||
+| InternVL2.5 | ✅ | |||||||||||||||||||
+| Qwen2-Audio | ✅ | |||||||||||||||||||
+| Aria | ✅ | |||||||||||||||||||
+| LLaVA-Next | ✅ | |||||||||||||||||||
+| LLaVA-Next-Video | ✅ | |||||||||||||||||||
+| MiniCPM-V | ✅ | |||||||||||||||||||
+| Mistral3 | ✅ | |||||||||||||||||||
+| Phi-3-Vison/Phi-3.5-Vison | ✅ | |||||||||||||||||||
+| Gemma3 | ✅ | |||||||||||||||||||
+| LLama4 | ❌ | [1972](https://github.com/vllm-project/vllm-ascend/issues/1972) |||||||||||||||||||
+| LLama3.2 | ❌ | [1972](https://github.com/vllm-project/vllm-ascend/issues/1972) |||||||||||||||||||
+| Keye-VL-8B-Preview | ❌ | [1963](https://github.com/vllm-project/vllm-ascend/issues/1963) |||||||||||||||||||
+| Florence-2 | ❌ | [2259](https://github.com/vllm-project/vllm-ascend/issues/2259) |||||||||||||||||||
+| GLM-4V | ❌ | [2260](https://github.com/vllm-project/vllm-ascend/issues/2260) |||||||||||||||||||
+| InternVL2.0/2.5/3.0<br>InternVideo2.5/Mono-InternVL | ❌ | [2064](https://github.com/vllm-project/vllm-ascend/issues/2064) |||||||||||||||||||
+| Whisper | ❌ | [2262](https://github.com/vllm-project/vllm-ascend/issues/2262) |||||||||||||||||||
+| Ultravox | 🟡 | Need test |||||||||||||||||||
diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
index 67c34ee899..8bbc3595ee 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
@@ -447,7 +447,7 @@ def get_api_request_id(api, req_id):
def get_origin_request_id(api, req_id):
if api == "/completions":
- return req_id.replace("cmpl-", "").replace("-0", "")
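+        # drop the "cmpl-" prefix and the trailing two-character "-<index>" suffix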
+ return req_id.replace("cmpl-", "")[:-2]
elif api == "/chat/completions":
return req_id.replace("chatcmpl-", "")
@@ -561,9 +561,12 @@ async def metaserver(request: Request):
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+ proxy_state.release_prefiller_kv(prefiller_idx,prefiller_score)
except Exception as e:
logger.error(f"Post metaserver failed with: {str(e)}")
+ proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+ proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
if __name__ == '__main__':
diff --git a/tests/e2e/multicard/test_external_launcher.py b/tests/e2e/multicard/test_external_launcher.py
index 9bf855e30a..d5441691c3 100644
--- a/tests/e2e/multicard/test_external_launcher.py
+++ b/tests/e2e/multicard/test_external_launcher.py
@@ -108,6 +108,7 @@ def test_moe_external_launcher(model):
assert proc.returncode == 0
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
def test_external_launcher_and_sleepmode():
script = Path(
__file__
@@ -154,6 +155,7 @@ def test_external_launcher_and_sleepmode():
assert proc.returncode == 0
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
def test_external_launcher_and_sleepmode_level2():
script = Path(
__file__
diff --git a/tests/e2e/multicard/test_qwen3_next.py b/tests/e2e/multicard/test_qwen3_next.py
index 9fda522021..cf3382318d 100644
--- a/tests/e2e/multicard/test_qwen3_next.py
+++ b/tests/e2e/multicard/test_qwen3_next.py
@@ -20,10 +20,17 @@
Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
"""
+import os
+from unittest.mock import patch
from tests.e2e.conftest import VllmRunner
+# NZ will cause precision errors in Qwen3-Next
+# When it is fixed, this setup can be removed
+_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ"
+
+@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4():
example_prompts = [
"Hello, my name is",
@@ -36,8 +43,10 @@ def test_models_distributed_Qwen3_NEXT_TP4():
distributed_executor_backend="mp",
enforce_eager=True) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
+ del vllm_model
+@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
example_prompts = [
"Hello, my name is",
@@ -54,3 +63,50 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
}) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
+ del vllm_model
+
+
+@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
+def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
+ example_prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+ max_tokens = 20
+
+ with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+ tensor_parallel_size=4,
+ max_model_len=4096,
+ gpu_memory_utilization=0.8,
+ distributed_executor_backend="mp") as vllm_model:
+ ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+ del vllm_model
+
+ with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
+ tensor_parallel_size=4,
+ max_model_len=4096,
+ gpu_memory_utilization=0.8,
+ distributed_executor_backend="mp",
+ speculative_config={
+ "method": "qwen3_next_mtp",
+ "num_speculative_tokens": 1
+ }) as spec_vllm_model:
+ spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
+ max_tokens)
+ del spec_vllm_model
+
+ matches = 0
+ misses = 0
+ for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+ ref_token_ids = ref_output[0]
+ spec_token_ids = spec_output[0]
+ if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
+ matches += 1
+ else:
+ misses += 1
+ print(f"ref_output: {ref_output[1]}")
+ print(f"spec_output: {spec_output[1]}")
+
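+    # require more than ~66% of prompts to reproduce the reference greedy output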
+ assert matches > int(0.66 * len(ref_outputs))
diff --git a/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py b/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py
index 26bcfa9104..89449ac4c3 100644
--- a/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py
+++ b/tests/e2e/nightly/models/test_deepseek_r1_w8a8_eplb.py
@@ -14,6 +14,7 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
+import json
from typing import Any
import openai
@@ -27,8 +28,7 @@
"vllm-ascend/DeepSeek-R1-W8A8",
]
-TENSOR_PARALLELS = [8]
-DATA_PARALLELS = [2]
+MODES = ["eplb"]
prompts = [
"San Francisco is a",
@@ -38,55 +38,52 @@
"max_tokens": 10,
}
-aisbench_cases = [{
+aisbench_gsm8k = [{
"case_type": "accuracy",
"dataset_path": "vllm-ascend/gsm8k-lite",
"request_conf": "vllm_api_general_chat",
"dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_chat_prompt",
"max_out_len": 32768,
"batch_size": 32,
- "baseline": 93,
+ "top_k": 20,
+ "baseline": 95,
"threshold": 5
-}, {
- "case_type": "performance",
- "dataset_path": "vllm-ascend/GSM8K-in3500-bs400",
- "request_conf": "vllm_api_stream_chat",
- "dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_str_perf",
- "num_prompts": 80,
- "max_out_len": 1500,
- "batch_size": 20,
- "request_rate": 0,
- "baseline": 1,
- "threshold": 0.97
}]
+mode_aisbench = {"eplb": aisbench_gsm8k}
+
@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
-@pytest.mark.parametrize("dp_size", DATA_PARALLELS)
-async def test_models(model: str, tp_size: int, dp_size: int) -> None:
+@pytest.mark.parametrize("mode", MODES)
+async def test_models(model: str, mode: str) -> None:
port = get_open_port()
env_dict = {
- "TASK_QUEUE_ENABLE": "1",
+ "OMP_NUM_THREADS": "10",
"OMP_PROC_BIND": "false",
- "HCCL_OP_EXPANSION_MODE": "AIV",
- "PAGED_ATTENTION_MASK_LEN": "5500",
- "DYNAMIC_EPLB": "true",
- "HCCL_BUFFSIZE": "1024"
+ "HCCL_BUFFSIZE": "1024",
+ "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True",
+ "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"
+ }
+ additional_config: dict[str, Any] = {
+ "ascend_scheduler_config": {
+ "enabled": False
+ },
}
server_args = [
- "--no-enable-prefix-caching", "--enable-expert-parallel",
- "--tensor-parallel-size",
- str(tp_size), "--data-parallel-size",
- str(dp_size), "--port",
- str(port), "--max-model-len", "36864", "--max-num-batched-tokens",
- "36864", "--block-size", "128", "--trust-remote-code",
- "--quantization", "ascend", "--gpu-memory-utilization", "0.9",
- "--additional-config", '{"enable_weight_nz_layout":true, '
- '"torch_air_graph_config":{"enabled": true, "enable_multistream_mla": true, "graph_batch_size": [16], "use_cached_graph": true},'
- '"dynamic_eplb": true, "num_iterations_eplb_update": 1000, "num_wait_worker_iterations": 200'
+ "--quantization", "ascend", "--async-scheduling",
+ "--data-parallel-size", "4", "--tensor-parallel-size", "4",
+ "--enable-expert-parallel", "--port",
+ str(port), "--max-model-len", "40960", "--max-num-batched-tokens",
+ "8192", "--max-num-seqs", "12", "--trust-remote-code",
+ "--gpu-memory-utilization", "0.9"
]
+ if mode == "eplb":
+ env_dict["DYNAMIC_EPLB"] = "true"
+ additional_config["dynamic_eplb"] = True
+ additional_config["num_iterations_eplb_update"] = 2048
+ additional_config["num_wait_worker_iterations"] = 200
+ server_args.extend(["--additional-config", json.dumps(additional_config)])
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
}
@@ -103,5 +100,10 @@ async def test_models(model: str, tp_size: int, dp_size: int) -> None:
)
choices: list[openai.types.CompletionChoice] = batch.choices
assert choices[0].text, "empty response"
+ print(choices)
# aisbench test
- run_aisbench_cases(model, port, aisbench_cases)
+ aisbench_cases = mode_aisbench[mode]
+ run_aisbench_cases(model,
+ port,
+ aisbench_cases,
+ server_args=server_args)
diff --git a/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py b/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py
index 760f8deeff..fe6bbedf2e 100644
--- a/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py
+++ b/tests/e2e/nightly/models/test_qwen2_5_vl_32b.py
@@ -45,7 +45,7 @@
"dataset_conf": "textvqa/textvqa_gen_base64",
"max_out_len": 2048,
"batch_size": 128,
- "baseline": 76,
+ "baseline": 76.22,
"temperature": 0,
"top_k": -1,
"top_p": 1,
diff --git a/tests/e2e/nightly/models/test_qwen2_5_vl_7b.py b/tests/e2e/nightly/models/test_qwen2_5_vl_7b.py
index bc35ff88c7..d3a726bf07 100644
--- a/tests/e2e/nightly/models/test_qwen2_5_vl_7b.py
+++ b/tests/e2e/nightly/models/test_qwen2_5_vl_7b.py
@@ -45,7 +45,7 @@
"dataset_conf": "textvqa/textvqa_gen_base64",
"max_out_len": 2048,
"batch_size": 128,
- "baseline": 81,
+ "baseline": 82.05,
"threshold": 5
}, {
"case_type": "performance",
diff --git a/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py b/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py
index 52aafa156f..8debeecb2a 100644
--- a/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py
+++ b/tests/e2e/nightly/models/test_qwen3_235b_a22b_w8a8_eplb.py
@@ -14,6 +14,7 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
+import json
from typing import Any
import openai
@@ -27,7 +28,7 @@
"vllm-ascend/Qwen3-235B-A22B-W8A8",
]
-TENSOR_PARALLELS = [16]
+MODES = ["eplb"]
prompts = [
"San Francisco is a",
@@ -37,53 +38,53 @@
"max_tokens": 10,
}
-aisbench_cases = [{
+aisbench_gsm8k = [{
"case_type": "accuracy",
"dataset_path": "vllm-ascend/gsm8k-lite",
"request_conf": "vllm_api_general_chat",
"dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_chat_prompt",
"max_out_len": 32768,
"batch_size": 32,
- "baseline": 93,
- "threshold": 5
-}, {
- "case_type": "performance",
- "dataset_path": "vllm-ascend/GSM8K-in3500-bs400",
- "request_conf": "vllm_api_stream_chat",
- "dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_str_perf",
- "num_prompts": 80,
- "max_out_len": 1500,
- "batch_size": 20,
- "request_rate": 0,
- "baseline": 1,
- "threshold": 0.97
+ "top_k": 20,
+ "baseline": 95,
+    "threshold": 5
}]
+mode_aisbench = {"eplb": aisbench_gsm8k}
+
@pytest.mark.asyncio
@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
-async def test_models(model: str, tp_size: int) -> None:
+@pytest.mark.parametrize("mode", MODES)
+async def test_models(model: str, mode: str) -> None:
port = get_open_port()
env_dict = {
- "TASK_QUEUE_ENABLE": "1",
+ "OMP_NUM_THREADS": "10",
"OMP_PROC_BIND": "false",
- "HCCL_OP_EXPANSION_MODE": "AIV",
- "PAGED_ATTENTION_MASK_LEN": "5500",
- "DYNAMIC_EPLB": "true",
- "HCCL_BUFFSIZE": "1024"
+ "HCCL_BUFFSIZE": "1024",
+ "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True",
+ "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"
+ }
+ additional_config: dict[str, Any] = {
+ "ascend_scheduler_config": {
+ "enabled": False
+ },
}
server_args = [
- "--no-enable-prefix-caching", "--enable-expert-parallel",
- "--tensor-parallel-size",
- str(tp_size), "--port",
- str(port), "--max-model-len", "36864", "--max-num-batched-tokens",
- "36864", "--block-size", "128", "--trust-remote-code",
- "--quantization", "ascend", "--gpu-memory-utilization", "0.9",
- "--additional-config",
- '{"enable_weight_nz_layout":true, "dynamic_eplb": true, '
- '"num_iterations_eplb_update": 1000, "num_wait_worker_iterations": 200}'
+ "--quantization", "ascend", "--async-scheduling",
+ "--data-parallel-size", "4", "--tensor-parallel-size", "4",
+ "--enable-expert-parallel", "--port",
+ str(port), "--max-model-len", "40960", "--max-num-batched-tokens",
+ "8192", "--max-num-seqs", "12", "--trust-remote-code",
+ "--gpu-memory-utilization", "0.9"
]
+ if mode == "eplb":
+ env_dict["DYNAMIC_EPLB"] = "true"
+ additional_config["dynamic_eplb"] = True
+ additional_config["num_iterations_eplb_update"] = 2048
+ additional_config["num_wait_worker_iterations"] = 200
+ server_args.extend(["--additional-config", json.dumps(additional_config)])
request_keyword_args: dict[str, Any] = {
**api_keyword_args,
}
@@ -100,5 +101,10 @@ async def test_models(model: str, tp_size: int) -> None:
)
choices: list[openai.types.CompletionChoice] = batch.choices
assert choices[0].text, "empty response"
+ print(choices)
# aisbench test
- run_aisbench_cases(model, port, aisbench_cases)
+ aisbench_cases = mode_aisbench[mode]
+ run_aisbench_cases(model,
+ port,
+ aisbench_cases,
+ server_args=server_args)
diff --git a/tests/e2e/nightly/models/test_qwen3_235b_w8a8.py b/tests/e2e/nightly/models/test_qwen3_235b_w8a8.py
new file mode 100644
index 0000000000..8220e4d59a
--- /dev/null
+++ b/tests/e2e/nightly/models/test_qwen3_235b_w8a8.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import json
+from typing import Any
+
+import openai
+import pytest
+from vllm.utils import get_open_port
+
+from tests.e2e.conftest import RemoteOpenAIServer
+from tools.aisbench import run_aisbench_cases
+
+MODELS = [
+ "vllm-ascend/Qwen3-235B-A22B-W8A8",
+]
+
+MODES = ["full_graph", "piecewise"]
+
+prompts = [
+ "San Francisco is a",
+]
+
+api_keyword_args = {
+ "max_tokens": 10,
+}
+
+aisbench_cases = [{
+ "case_type": "accuracy",
+ "dataset_path": "vllm-ascend/gsm8k-lite",
+ "request_conf": "vllm_api_general_chat",
+ "dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_chat_prompt",
+ "max_out_len": 32768,
+ "batch_size": 32,
+ "top_k": 20,
+ "baseline": 95,
+ "threshold": 5
+}]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("mode", MODES)
+async def test_models(model: str, mode: str) -> None:
+ port = get_open_port()
+ env_dict = {
+ "OMP_NUM_THREADS": "10",
+ "OMP_PROC_BIND": "false",
+ "HCCL_BUFFSIZE": "1024",
+ "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True",
+ "VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"
+ }
+ additional_config: dict[str, Any] = {
+ "ascend_scheduler_config": {
+ "enabled": False
+ },
+ }
+ compilation_config = {"cudagraph_mode": "FULL_DECODE_ONLY"}
+ server_args = [
+ "--quantization", "ascend", "--async-scheduling",
+ "--data-parallel-size", "4", "--tensor-parallel-size", "4",
+ "--enable-expert-parallel", "--port",
+ str(port), "--max-model-len", "40960", "--max-num-batched-tokens",
+ "8192", "--max-num-seqs", "12", "--trust-remote-code",
+ "--gpu-memory-utilization", "0.9"
+ ]
+ if mode == "piecewise":
+ compilation_config["cudagraph_mode"] = "PIECEWISE"
+ server_args.extend(
+ ["--compilation-config",
+ json.dumps(compilation_config)])
+ server_args.extend(["--additional-config", json.dumps(additional_config)])
+ request_keyword_args: dict[str, Any] = {
+ **api_keyword_args,
+ }
+ with RemoteOpenAIServer(model,
+ server_args,
+ server_port=port,
+ env_dict=env_dict,
+ auto_port=False) as server:
+ client = server.get_async_client()
+ batch = await client.completions.create(
+ model=model,
+ prompt=prompts,
+ **request_keyword_args,
+ )
+ choices: list[openai.types.CompletionChoice] = batch.choices
+ assert choices[0].text, "empty response"
+ print(choices)
+ # aisbench test
+ run_aisbench_cases(model,
+ port,
+ aisbench_cases,
+ server_args=server_args)
diff --git a/tests/e2e/nightly/models/test_qwen3_30b_w8a8.py b/tests/e2e/nightly/models/test_qwen3_30b_w8a8.py
new file mode 100644
index 0000000000..307a1575cc
--- /dev/null
+++ b/tests/e2e/nightly/models/test_qwen3_30b_w8a8.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+from typing import Any
+
+import openai
+import pytest
+from vllm.utils import get_open_port
+
+from tests.e2e.conftest import RemoteOpenAIServer
+from tools.aisbench import run_aisbench_cases
+
+MODELS = [
+ "vllm-ascend/Qwen3-30B-A3B-W8A8",
+]
+
+TENSOR_PARALLELS = [1]
+
+prompts = [
+ "San Francisco is a",
+]
+
+api_keyword_args = {
+ "max_tokens": 10,
+}
+
+aisbench_cases = [{
+ "case_type": "performance",
+ "dataset_path": "vllm-ascend/GSM8K-in3500-bs400",
+ "request_conf": "vllm_api_stream_chat",
+ "dataset_conf": "gsm8k/gsm8k_gen_0_shot_cot_str_perf",
+ "num_prompts": 180,
+ "max_out_len": 1500,
+ "batch_size": 45,
+ "request_rate": 0,
+ "baseline": 1,
+ "threshold": 0.97
+}]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
+async def test_models(model: str, tp_size: int) -> None:
+ port = get_open_port()
+ env_dict = {
+ "OMP_PROC_BIND": "false",
+ "OMP_NUM_THREADS": "10",
+ "HCCL_BUFFSIZE": "1024",
+ "HCCL_OP_EXPANSION_MODE": "AIV",
+ "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:True"
+ }
+ server_args = [
+ "--quantization", "ascend", "--async-scheduling",
+ "--no-enable-prefix-caching", "--tensor-parallel-size",
+ str(tp_size), "--port",
+ str(port), "--max-model-len", "5600", "--max-num-batched-tokens",
+ "16384", "--max-num-seqs", "100", "--trust-remote-code",
+ "--gpu-memory-utilization", "0.9", "--compilation-config",
+ '{"cudagraph_mode": "FULL_DECODE_ONLY"}'
+ ]
+ request_keyword_args: dict[str, Any] = {
+ **api_keyword_args,
+ }
+ with RemoteOpenAIServer(model,
+ server_args,
+ server_port=port,
+ env_dict=env_dict,
+ auto_port=False) as server:
+ client = server.get_async_client()
+ batch = await client.completions.create(
+ model=model,
+ prompt=prompts,
+ **request_keyword_args,
+ )
+ choices: list[openai.types.CompletionChoice] = batch.choices
+ assert choices[0].text, "empty response"
+ # aisbench test
+ run_aisbench_cases(model, port, aisbench_cases)
diff --git a/tests/e2e/nightly/models/test_qwen3_32b_int8.py b/tests/e2e/nightly/models/test_qwen3_32b_int8.py
index e245f3d74f..bbaf863aa9 100644
--- a/tests/e2e/nightly/models/test_qwen3_32b_int8.py
+++ b/tests/e2e/nightly/models/test_qwen3_32b_int8.py
@@ -58,7 +58,7 @@
"max_out_len": 32768,
"batch_size": 32,
"baseline": 83.33,
- "threshold": 17
+ "threshold": 7
}, {
"case_type": "performance",
"dataset_path": "vllm-ascend/GSM8K-in3500-bs400",
diff --git a/tests/e2e/nightly/multi_node/scripts/build_mooncake.sh b/tests/e2e/nightly/multi_node/scripts/build_mooncake.sh
index 8fb4610bce..7627cf0c95 100644
--- a/tests/e2e/nightly/multi_node/scripts/build_mooncake.sh
+++ b/tests/e2e/nightly/multi_node/scripts/build_mooncake.sh
@@ -9,15 +9,13 @@ YELLOW="\033[0;33m"
RED="\033[0;31m"
NC="\033[0m" # No Color
-branch=${1:-pooling_async_memecpy_v1}
-point=${2:-8fce1ffab3930fec2a8b8d3be282564dfa1bb186}
+branch=${1:-v0.3.7.post2}
-repo_url="https://github.com/AscendTransport/Mooncake"
+repo_url="https://github.com/kvcache-ai/Mooncake"
repo_name="Mooncake"
state_file=".build_state"
echo "[INFO] Branch: $branch"
-echo "[INFO] Commit: $point"
echo "-------------------------------------------"
@@ -29,22 +27,36 @@ if ! is_done "clone"; then
if [ -d "$repo_name" ]; then
echo "[WARN] Directory $repo_name already exists, skipping clone."
else
- git clone -b "$branch" "$repo_url" "$repo_name"
+ git clone --branch "$branch" --depth 1 "$repo_url" "$repo_name"
fi
- cd "$repo_name"
- git fetch --all
- git checkout "$point" || { echo "[ERROR] Checkout failed."; exit 1; }
- cd ..
mark_done "clone"
else
echo "[SKIP] Clone step already done."
fi
+init_ascend_env() {
+    cann_in_sys_path=/usr/local/Ascend/ascend-toolkit
+    cann_in_user_path=$HOME/Ascend/ascend-toolkit
+    uname_m=$(uname -m)
+    if [ -f "${cann_in_sys_path}/set_env.sh" ]; then
+        source "${cann_in_sys_path}/set_env.sh"
+        export LD_LIBRARY_PATH=${cann_in_sys_path}/latest/lib64:${cann_in_sys_path}/latest/${uname_m}-linux/devlib:${LD_LIBRARY_PATH}
+    elif [ -f "${cann_in_user_path}/set_env.sh" ]; then
+        source "${cann_in_user_path}/set_env.sh"
+        export LD_LIBRARY_PATH=${cann_in_user_path}/latest/lib64:${cann_in_user_path}/latest/${uname_m}-linux/devlib:${LD_LIBRARY_PATH}
+    else
+        echo "No Ascend Toolkit found"
+        exit 1
+    fi
+}
+
+init_ascend_env
if ! is_done "deps"; then
cd "$repo_name"
- echo "[STEP]Installing dependencies (ignore Go failure)..."
- yes | bash dependencies.sh || echo "⚠️ dependencies.sh failed (Go install likely failed), continuing..."
+ echo "[STEP]Installing dependencies..."
+ sed -i 's|https://go.dev/dl/|https://golang.google.cn/dl/|g' dependencies.sh
+ bash dependencies.sh -y
cd ..
mark_done "deps"
else
@@ -74,7 +86,7 @@ if ! is_done "build"; then
fi
mkdir build && cd build
- cmake .. || { echo "[ERROR] cmake failed."; exit 1; }
+    cmake .. -DUSE_ASCEND_DIRECT=ON || { echo "[ERROR] cmake failed."; exit 1; }
make -j || { echo "[ERROR] make failed."; exit 1; }
make install || { echo "[ERROR] make install failed."; exit 1; }
mark_done "build"
@@ -83,19 +95,6 @@ else
fi
-if ! is_done "copy_lib"; then
- echo "[STEP] Copy library files..."
- cp mooncake-transfer-engine/src/transport/ascend_transport/hccl_transport/ascend_transport_c/libascend_transport_mem.so \
- /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/
- cp mooncake-transfer-engine/src/libtransfer_engine.so \
- /usr/local/Ascend/ascend-toolkit/latest/python/site-packages/
- cd ..
- mark_done "copy_lib"
-else
- echo "[SKIP] Library copy already done."
-fi
-
-
if ! grep -q "export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH" ~/.bashrc; then
echo -e "${YELLOW}Adding LD_LIBRARY_PATH to your PATH in ~/.bashrc${NC}"
echo 'export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH' >> ~/.bashrc
diff --git a/tests/e2e/nightly/multi_node/scripts/run.sh b/tests/e2e/nightly/multi_node/scripts/run.sh
index 60a7ce0883..e4c2555534 100644
--- a/tests/e2e/nightly/multi_node/scripts/run.sh
+++ b/tests/e2e/nightly/multi_node/scripts/run.sh
@@ -9,7 +9,6 @@ RED="\033[0;31m"
NC="\033[0m" # No Color
# Configuration
-GOVER=1.23.8
LOG_DIR="/root/.cache/tests/logs"
OVERWRITE_LOGS=true
SRC_DIR="$WORKSPACE/source_code"
@@ -97,34 +96,6 @@ install_vllm() {
pip install -r "$SRC_DIR/vllm-ascend/requirements-dev.txt"
}
-download_go() {
- ARCH=$(uname -m)
- GOVER=1.23.8
- if [ "$ARCH" = "aarch64" ]; then
- ARCH="arm64"
- elif [ "$ARCH" = "x86_64" ]; then
- ARCH="amd64"
- else
- echo "Unsupported architecture: $ARCH"
- exit 1
- fi
- # Download Go
- echo "Downloading Go $GOVER..."
- wget -q --show-progress https://golang.google.cn/dl/go$GOVER.linux-$ARCH.tar.gz
- check_success "Failed to download Go $GOVER"
-
- # Install Go
- echo "Installing Go $GOVER..."
- tar -C /usr/local -xzf go$GOVER.linux-$ARCH.tar.gz
- check_success "Failed to install Go $GOVER"
-
- # Clean up downloaded file
- rm -f go$GOVER.linux-$ARCH.tar.gz
- check_success "Failed to clean up Go installation file"
-
- print_success "Go $GOVER installed successfully"
-}
-
install_ais_bench() {
local AIS_BENCH="$SRC_DIR/vllm-ascend/benchmark"
git clone https://gitee.com/aisbench/benchmark.git $AIS_BENCH
@@ -136,29 +107,6 @@ install_ais_bench() {
cd -
}
-install_go() {
- # Check if Go is already installed
- if command -v go &> /dev/null; then
- GO_VERSION=$(go version | awk '{print $3}')
- if [[ "$GO_VERSION" == "go$GOVER" ]]; then
- echo -e "${YELLOW}Go $GOVER is already installed. Skipping...${NC}"
- else
- echo -e "${YELLOW}Found Go $GO_VERSION. Will install Go $GOVER...${NC}"
- download_go
- fi
- else
- download_go
- fi
-
- # Add Go to PATH if not already there
- if ! grep -q "export PATH=\$PATH:/usr/local/go/bin" ~/.bashrc; then
- echo -e "${YELLOW}Adding Go to your PATH in ~/.bashrc${NC}"
- echo 'export PATH=$PATH:/usr/local/go/bin' >> ~/.bashrc
- echo -e "${YELLOW}Please run 'source ~/.bashrc' or start a new terminal to use Go${NC}"
- fi
- export PATH=$PATH:/usr/local/go/bin
-}
-
kill_npu_processes() {
pgrep python3 | xargs -r kill -9
pgrep VLLM | xargs -r kill -9
@@ -193,11 +141,8 @@ main() {
install_sys_dependencies
install_vllm
install_ais_bench
- # to speed up mooncake build process, install Go here
- install_go
cd "$WORKSPACE/source_code"
- . $SRC_DIR/vllm-ascend/tests/e2e/nightly/multi_node/scripts/build_mooncake.sh \
- "pooling_async_memecpy_v1" "8fce1ffab3930fec2a8b8d3be282564dfa1bb186"
+ . $SRC_DIR/vllm-ascend/tests/e2e/nightly/multi_node/scripts/build_mooncake.sh
cd "$WORKSPACE/source_code/vllm-ascend"
run_tests_with_log
}
diff --git a/tests/e2e/singlecard/test_camem.py b/tests/e2e/singlecard/test_camem.py
index 04643c8082..2fe4a8553a 100644
--- a/tests/e2e/singlecard/test_camem.py
+++ b/tests/e2e/singlecard/test_camem.py
@@ -18,6 +18,8 @@
#
import gc
+import os
+from unittest.mock import patch
import torch
from vllm import SamplingParams
@@ -71,6 +73,7 @@ def test_basic_camem():
@fork_new_process_for_each_test
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
def test_end_to_end():
free, total = torch.npu.mem_get_info()
used_bytes_baseline = total - free # in case other process is running
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index dfb9a2a07f..9c732f61de 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -63,10 +63,26 @@ def test_copy_blocks(self):
class TestAscendAttentionMetadataBuilder(TestBase):
- def setUp(self):
+ @patch('vllm.distributed.parallel_state.get_dcp_group')
+ @patch('vllm.distributed.parallel_state._DCP',
+ new_callable=lambda: MagicMock(spec=GroupCoordinator))
+ @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
+ return_value=1)
+ def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group):
+ mock_dcp.world_size = 1
+ dcp_group = MagicMock(spec=GroupCoordinator)
+ dcp_group.rank_in_group = 0
+ dcp_group.world_size = 1
+ dcp_group.device_group = MagicMock()
+ mock_get_dcp_group.return_value = dcp_group
+
self.mock_vllm_config = MagicMock()
+ self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
+ self.mock_vllm_config.compilation_config.cudagraph_mode = None
+ self.mock_vllm_config.scheduler_config.max_num_seqs = 10
+ self.mock_vllm_config.scheduler_config.decode_max_num_seqs = 10
self.mock_device = 'cpu:0'
self.builder = AscendAttentionMetadataBuilder(None, None,
self.mock_vllm_config,
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index d8ddc6a6d6..8d15bcaab1 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -82,7 +82,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
- chunk_seq_lens=chunk_seq_lens)
+ chunk_seq_lens=chunk_seq_lens,
+ chunk_seq_lens_npu=chunk_seq_lens)
metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -103,6 +104,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+ self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+ chunk_seq_lens)
class TestAscendMLADecodeMetadata(TestBase):
@@ -478,6 +481,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+ chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
prefill_meta = MagicMock()
diff --git a/tests/ut/distributed/mooncake/test_config_data.py b/tests/ut/distributed/mooncake/test_config_data.py
new file mode 100644
index 0000000000..4408b41a82
--- /dev/null
+++ b/tests/ut/distributed/mooncake/test_config_data.py
@@ -0,0 +1,68 @@
+import unittest
+
+from vllm_ascend.distributed.mooncake.config_data import (
+ _convert_to_bytes, _parse_global_segment_size)
+
+
+class TestParseGlobalSegmentSize(unittest.TestCase):
+
+ def test_int_input(self):
+ self.assertEqual(_parse_global_segment_size(1024), 1024)
+ self.assertEqual(_parse_global_segment_size(0), 0)
+
+ def test_gb_unit(self):
+ self.assertEqual(_parse_global_segment_size("2GB"), 2 * 1024**3)
+ self.assertEqual(_parse_global_segment_size("1.5GB"),
+ int(1.5 * 1024**3))
+ self.assertEqual(_parse_global_segment_size(" 2 GB "), 2 * 1024**3)
+
+ def test_gb_unit_edge_cases(self):
+ with self.assertRaises(ValueError):
+ _parse_global_segment_size("GB")
+ with self.assertRaises(ValueError):
+ _parse_global_segment_size("abcGB")
+
+ def test_mb_unit(self):
+ self.assertEqual(_parse_global_segment_size("512MB"), 512 * 1024**2)
+ self.assertEqual(_parse_global_segment_size("0.5MB"),
+ int(0.5 * 1024**2))
+ self.assertEqual(_parse_global_segment_size("1024MB"), 1024 * 1024**2)
+
+ def test_kb_unit(self):
+ self.assertEqual(_parse_global_segment_size("256KB"), 256 * 1024)
+ self.assertEqual(_parse_global_segment_size("1.25KB"),
+ int(1.25 * 1024))
+
+ def test_b_unit(self):
+ self.assertEqual(_parse_global_segment_size("4096B"), 4096)
+ self.assertEqual(_parse_global_segment_size("1024b"), 1024)
+
+ def test_no_unit(self):
+ self.assertEqual(_parse_global_segment_size("2048"), 2048)
+ self.assertEqual(_parse_global_segment_size("0"), 0)
+
+ def test_non_string_non_int_input(self):
+ self.assertEqual(_parse_global_segment_size(2048.0), 2048)
+ self.assertEqual(_parse_global_segment_size(True), 1)
+
+ with self.assertRaises(TypeError):
+ _parse_global_segment_size(None)
+
+ with self.assertRaises(TypeError):
+ _parse_global_segment_size({"size": 1024})
+
+
+class TestConvertToBytes(unittest.TestCase):
+
+ def test_valid_conversion(self):
+ self.assertEqual(_convert_to_bytes("10", 1, "10"), 10)
+ self.assertEqual(_convert_to_bytes("1.5", 1024, "1.5KB"),
+ int(1.5 * 1024))
+ self.assertEqual(_convert_to_bytes("0", 1024**3, "0GB"), 0)
+
+ def test_invalid_numbers(self):
+ with self.assertRaises(ValueError):
+ _convert_to_bytes("abc", 1, "abc")
+
+ with self.assertRaises(ValueError):
+ _convert_to_bytes("1.2.3", 1024, "1.2.3KB")
diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py
index fa78a46f3e..a5bc066f1a 100644
--- a/tests/ut/kv_connector/test_mooncake_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_connector.py
@@ -978,9 +978,6 @@ def __init__(self, *args, **kwargs):
self.data_ptr = MagicMock(return_value=0x1000)
-mock_envs_ascend = MagicMock()
-mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
-
mock_logger = MagicMock()
@@ -1017,14 +1014,15 @@ def mock_string_to_int64_hash(s):
class TestMooncakeConnectorWorker(unittest.TestCase):
def setUp(self):
- self.envs_ascend_mock = MockEnvsAscend()
self.mock_transfer_engine = MagicMock()
self.mock_transfer_engine.get_rpc_port.return_value = 9090
self.mock_transfer_engine.initialize.return_value = 0
self.mock_transfer_engine.register_memory.return_value = 0
self.patches = [
- patch('os.getenv', return_value="10,11"),
+ patch(
+ 'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
+ '10,11'),
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
patch('torch.Tensor.element_size', return_value=4),
patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -1053,8 +1051,6 @@ def setUp(self):
MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
MagicMock()),
- patch.dict('sys.modules',
- {'vllm_ascend.envs': self.envs_ascend_mock}),
]
for p in self.patches:
diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
index bc9ba253a4..28504c9b79 100644
--- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py
@@ -32,6 +32,14 @@ def setUp(self):
self.engine = MagicMock()
self.engine.register_memory.return_value = 0
self.engine.batch_transfer_sync_write.return_value = 1
+ self._patcher_cs = patch(
+ 'vllm_ascend.distributed.mooncake_layerwise_connector.torch_npu.npu.current_stream'
+ )
+ self.mock_current_stream = self._patcher_cs.start()
+ self.addCleanup(self._patcher_cs.stop)
+ fake_stream = MagicMock(name="FakeStream")
+ fake_stream.synchronize = MagicMock()
+ self.mock_current_stream.return_value = fake_stream
self.first_kv_cache = torch.zeros((2, 2, 2, 8),
dtype=torch.float32,
@@ -792,15 +800,15 @@ def test_request_finished(self, mock_method):
class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
def setUp(self):
- self.envs_ascend_mock = type("MockEnvsAscend", (),
- {"PHYSICAL_DEVICES": "10,11"})()
self.mock_transfer_engine = MagicMock()
self.mock_transfer_engine.get_rpc_port.return_value = 9090
self.mock_transfer_engine.initialize.return_value = 0
self.mock_transfer_engine.register_memory.return_value = 0
self.patches = [
- patch('os.getenv', return_value="10,11"),
+ patch(
+ 'vllm_ascend.distributed.mooncake_layerwise_connector.envs_ascend.PHYSICAL_DEVICES',
+ '10,11'),
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
patch('torch.Tensor.element_size', return_value=4),
patch('torch.Tensor.data_ptr', return_value=0x1000),
@@ -833,8 +841,6 @@ def setUp(self):
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.threading.Event',
MagicMock()),
- patch.dict('sys.modules',
- {'vllm_ascend.envs': self.envs_ascend_mock}),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
return_value=SimpleNamespace(pd_tp_ratio=1,
diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py
index d5d4309878..9e0c0b3295 100644
--- a/tests/ut/test_platform.py
+++ b/tests/ut/test_platform.py
@@ -4,6 +4,7 @@
import pytest
import torch
from vllm.config.compilation import CUDAGraphMode
+from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import PlatformEnum
from tests.ut.base import TestBase
@@ -722,3 +723,32 @@ def test_get_static_graph_wrapper_cls_returns_correct_value(self):
self.platform.get_static_graph_wrapper_cls(),
"vllm_ascend.compilation.acl_graph.ACLGraphWrapper",
)
+
+ def test_aclgraph_enable(self):
+ config = EngineArgs()
+ VllmConfig = config.create_engine_config()
+ self.assertEqual(VllmConfig.compilation_config.cudagraph_mode,
+ CUDAGraphMode.PIECEWISE)
+
+ with self.assertLogs(logger="vllm", level="INFO") as cm:
+ from vllm_ascend import platform
+
+ importlib.reload(platform)
+ self.platform.check_and_update_config(VllmConfig)
+ self.assertTrue(
+ "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
+ "using only ACL Graph mode" in cm.output[1])
+ if vllm_version_is("0.11.0"):
+ self.assertEqual(
+ VllmConfig.compilation_config.level,
+ CompilationLevel.PIECEWISE,
+ )
+ else:
+ self.assertEqual(
+ VllmConfig.compilation_config.mode,
+ CompilationMode.VLLM_COMPILE,
+ )
+ self.assertEqual(
+ VllmConfig.compilation_config.cudagraph_mode,
+ CUDAGraphMode.PIECEWISE,
+ )
diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py
index 3dd1d2f7f6..1f108b3eb0 100644
--- a/tests/ut/torchair/test_torchair_mla.py
+++ b/tests/ut/torchair/test_torchair_mla.py
@@ -86,7 +86,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
- chunk_seq_lens=chunk_seq_lens)
+ chunk_seq_lens=chunk_seq_lens,
+ chunk_seq_lens_npu=chunk_seq_lens)
metadata = AscendMLATorchairPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
@@ -107,6 +108,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
+ self.assertIs(metadata.chunked_context.chunk_seq_lens_npu,
+ chunk_seq_lens)
class TestAscendMLATorchairDecodeMetadata(TestBase):
@@ -661,6 +664,7 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
+ chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
prefill_meta = MagicMock()
diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py
index 1ead0c5750..2fbad2f810 100644
--- a/tests/ut/worker/test_worker_v1.py
+++ b/tests/ut/worker/test_worker_v1.py
@@ -1,3 +1,4 @@
+import os
import unittest
from unittest.mock import MagicMock, patch
@@ -273,6 +274,7 @@ def test_sleep_mode_disabled_raises_error(self, mock_sleep_mode_enabled):
@patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled")
@patch("vllm_ascend.worker.worker_v1.CaMemAllocator")
+ @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
def test_wake_up_mode_enabled(self, mock_allocator_class,
mock_sleep_mode_enabled):
"""Test wake_up method when sleep mode is enabled"""
@@ -295,6 +297,7 @@ def test_wake_up_mode_enabled(self, mock_allocator_class,
mock_allocator.wake_up.assert_called_once_with(tags=["test_tag"])
@patch("vllm_ascend.worker.worker_v1.sleep_mode_enabled")
+ @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
def test_wake_up_mode_disabled_raises_error(self, mock_sleep_mode_enabled):
"""Test wake_up method raises exception when sleep mode is disabled"""
from vllm_ascend.worker.worker_v1 import NPUWorker
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 80e6541e5f..a0f1edd59f 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -115,12 +115,10 @@ def set_ascend_forward_context(
# the performance may degrade due to the switching of communication methods.
mmrs_fusion = True
if is_moe_model(vllm_config):
- sp_enabled = enable_sp(vllm_config) and \
- tp_world_size > 1 and num_tokens is not None
+ sp_enabled = enable_sp(vllm_config) and num_tokens is not None
mmrs_fusion = False
else:
sp_enabled = enable_sp(vllm_config) and \
- tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
forward_context.mmrs_fusion = mmrs_fusion
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 258d5e3aac..098e77c543 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -127,7 +127,7 @@ def copy_blocks(
@staticmethod
def get_supported_block_size() -> list[int]:
- return [64]
+ return [128]
class AscendAttentionState(Enum):
@@ -163,8 +163,8 @@ class AscendMetadataForPrefill:
@dataclass
class AscendMetadataForDecode:
""" Decode Specific Metadata for Ascend"""
- num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
- list[int]]]]]] = None
+ num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
+ batch_seq_mask: torch.Tensor = None
@dataclass
@@ -232,10 +232,36 @@ def __init__(
):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
+ self.compilation_config = vllm_config.compilation_config
self.device = device
self.max_num_blocks_per_req = cdiv(
self.model_config.max_model_len,
AscendAttentionBackend.get_supported_block_size()[0])
+ decode_max_num_seqs = getattr(vllm_config.scheduler_config,
+ 'decode_max_num_seqs', 0)
+ max_num_seqs = max(vllm_config.scheduler_config.max_num_seqs,
+ decode_max_num_seqs)
+ self.batch_seq_mask_buf = torch.empty(max_num_seqs,
+ dtype=torch.uint8,
+ device=device)
+ self.pcp_size = get_prefill_context_model_parallel_world_size(
+ ) if prefill_context_parallel_enable() else 1
+ self.pcp_rank = get_prefill_context_model_parallel_rank(
+ ) if self.pcp_size > 1 else 0
+ self.dcp_size = get_decode_context_model_parallel_world_size()
+ self.dcp_rank = get_decode_context_model_parallel_rank(
+ ) if self.dcp_size > 1 else 0
+
+ self.speculative_config = vllm_config.speculative_config
+ self.decode_threshold = 1
+ if self.speculative_config:
+ spec_token_num = self.speculative_config.num_speculative_tokens
+ self.decode_threshold += spec_token_num
+            assert self.decode_threshold <= 16, (
+                f"decode_threshold {self.decode_threshold} exceeds "
+                "npu_fused_infer_attention_score's TND layout limit of 16")
+
+ AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
@@ -356,11 +382,22 @@ def build(
common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
if common_long_seq_metadata is not None:
num_computed_tokens_of_pcp_dcp = common_long_seq_metadata.num_computed_tokens_of_pcp_dcp
- num_computed_tokens_of_pcp_dcp = np.array(
+ assert num_computed_tokens_of_pcp_dcp is not None
+ num_computed_tokens_array = np.array(
num_computed_tokens_of_pcp_dcp)
+ num_computed_tokens_array = num_computed_tokens_array[:
+ num_decodes]
+ batch_seq_mask = (
+ num_computed_tokens_array[:, self.pcp_rank,
+ self.dcp_rank] == 0)
+                # TODO: use the shared-memory numpy array path here to improve performance
+ self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+ torch.from_numpy(batch_seq_mask), non_blocking=True)
decode_metadata = AscendMetadataForDecode(
- num_computed_tokens_of_pcp_dcp=
- num_computed_tokens_of_pcp_dcp)
+ num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
+ batch_seq_mask=self.batch_seq_mask_buf[:batch_seq_mask.
+ shape[0]],
+ )
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
@@ -869,7 +906,6 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
else:
num_heads = self.num_heads
- q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
k_nope = self.key_cache.view(self.key_cache.shape[0],
self.key_cache.shape[1], -1)
value = self.value_cache.view(self.key_cache.shape[0],
@@ -880,7 +916,7 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
'num_key_value_heads':
self.num_kv_heads,
'input_layout':
- "BSND",
+ 'TND',
'atten_mask':
None,
'scale':
@@ -898,10 +934,12 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
'actual_seq_lengths_kv':
attn_metadata.decode_meta.
num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+ 'actual_seq_lengths':
+ attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
- num_tokens = q_nope.shape[0]
+ num_tokens = query.shape[0]
if forward_context.capturing:
stream = torch_npu.npu.current_stream()
@@ -913,26 +951,27 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
- q_nope, k_nope, value, **common_kwargs)
+ query, k_nope, value, **common_kwargs)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
- attn_out = torch.empty_like(q_nope)
- attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
+ attn_out = torch.empty_like(query)
+ attn_lse = torch.empty((num_tokens, num_heads, 1),
dtype=torch.float,
- device=q_nope.device)
-
- graph_params.attn_params[num_tokens].append(
- (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
- weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
- self.scale, attn_metadata.block_tables,
- self.key_cache.shape[1], attn_metadata.decode_meta.
- num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
- self.dcp_rank],
- weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
- self.pcp_rank, self.dcp_rank, self.dcp_size))
+ device=query.device)
+
+ graph_params.attn_params[num_tokens].append((
+ weak_ref_tensors(query), weak_ref_tensors(k_nope),
+ weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
+ self.scale, attn_metadata.block_tables,
+ self.key_cache.shape[1], attn_metadata.decode_meta.
+ num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
+ self.dcp_rank],
+ attn_metadata.actual_seq_lengths_q[:attn_metadata.num_decodes],
+ weak_ref_tensors(attn_out), weak_ref_tensors(attn_lse),
+ self.dcp_size, self.pcp_rank, self.dcp_rank))
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
- q_nope,
+ query,
k_nope,
value,
**common_kwargs,
@@ -942,11 +981,17 @@ def _forward_decode_pcp_dcp(self, query: torch.Tensor,
graph_params.handles[num_tokens].append(handle)
else:
attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
- q_nope, k_nope, value, **common_kwargs)
+ query, k_nope, value, **common_kwargs)
+
+ out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
+ None].expand_as(
+ attn_out)
+ attn_out = torch.where(out_mask, 0, attn_out)
- attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
- attn_out.shape[3])
- attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+ lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
+ None].expand_as(
+ attn_lse)
+ attn_lse = torch.where(lse_mask, -torch.inf, attn_lse)
attn_out_lse_list = []
# Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index faf032536b..6d5c13971a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1,11 +1,10 @@
from dataclasses import dataclass
-from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
- TypeVar)
+from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
+ Type, TypeVar)
import numpy as np
import torch
import torch.distributed as dist
-import torch.nn.functional as F
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
@@ -111,6 +110,7 @@ class ChunkedContextMetadata:
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
+ chunk_seq_lens_npu: torch.Tensor
attn_mask: torch.Tensor
query_lens: torch.Tensor
@@ -140,11 +140,8 @@ class AscendMLADecodeMetadata:
attn_mask: Optional[torch.Tensor] = None
sin: torch.Tensor = None
cos: torch.Tensor = None
- num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
- list[int]]]]]] = None
- seq_mask_pcp: torch.Tensor = None
- seq_mask_dcp: torch.Tensor = None
cp_seq_len: torch.Tensor = None
+ batch_seq_mask: torch.Tensor = None
@dataclass
@@ -264,9 +261,10 @@ def __init__(self,
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
+
self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
- self.cp_rank = get_prefill_context_model_parallel_rank(
+ self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
@@ -274,6 +272,9 @@ def __init__(self,
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+ self.batch_seq_mask_buf = torch.empty(max_num_seqs,
+ dtype=torch.uint8,
+ device=device)
self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
self.pcp_size,
dtype=torch.uint8,
@@ -449,6 +450,7 @@ def build(
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
+ chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
@@ -490,36 +492,19 @@ def build(
num_computed_tokens_of_cp_dcp_array = np.array(
num_computed_tokens_of_pcp_dcp
)[:num_decodes] # [bs, pcp_size, dcp_size]
- seq_mask_pcp = torch.where(
- torch.tensor(
- num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
- 1).to(torch.uint8)
- self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
- shape[1]].copy_(seq_mask_pcp,
- non_blocking=True)
- seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
- seq_mask_pcp.shape[1])
-
- seq_mask_dcp = torch.where(
- torch.tensor(
- num_computed_tokens_of_cp_dcp_array[:,
- self.cp_rank, :])
- == 0, 0, 1).to(torch.uint8)
- self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
- shape[1]].copy_(seq_mask_dcp,
- non_blocking=True)
- seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
- seq_mask_dcp.shape[1])
cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
- self.cp_rank,
+ self.pcp_rank,
self.dcp_rank]
cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
+ batch_seq_mask = (cp_seq_len == 0)
+ self.batch_seq_mask_buf[:batch_seq_mask.shape[0]].copy_(
+ batch_seq_mask, non_blocking=True)
+ batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
+ shape[0]]
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
else:
- seq_mask_pcp_shape = (0, 0)
- seq_mask_dcp_shape = (0, 0)
- cp_seq_len = None
+ cp_seq_len, batch_seq_mask = None, None
# TODO: After the fullgraph supports MTP, the if branch needs to be deleted
assert self.cos_cache is not None
@@ -542,15 +527,8 @@ def build(
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos,
- num_computed_tokens_of_pcp_dcp=
- num_computed_tokens_of_pcp_dcp,
- seq_mask_pcp=self.
- seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
- seq_mask_pcp_shape[1]],
- seq_mask_dcp=self.
- seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
- seq_mask_dcp_shape[1]],
- cp_seq_len=cp_seq_len)
+ cp_seq_len=cp_seq_len,
+ batch_seq_mask=batch_seq_mask)
else:
cos[:num_decode_tokens,
...] = self.cos_cache[input_positions].unsqueeze(
@@ -569,15 +547,8 @@ def build(
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin[:num_decode_tokens, ...],
cos=cos[:num_decode_tokens, ...],
- num_computed_tokens_of_pcp_dcp=
- num_computed_tokens_of_pcp_dcp,
- seq_mask_pcp=self.
- seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
- seq_mask_pcp_shape[1]],
- seq_mask_dcp=self.
- seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
- seq_mask_dcp_shape[1]],
- cp_seq_len=cp_seq_len)
+ cp_seq_len=cp_seq_len,
+ batch_seq_mask=batch_seq_mask)
return self.metadata_cls( # type: ignore
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
@@ -919,7 +890,8 @@ def _compute_prefill_context(
iters = len(prefill_metadata.chunked_context.seq_tot)
- seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+ current_seq_len = torch.tensor(prefill_metadata.query_lens,
+ dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
@@ -927,8 +899,11 @@ def _compute_prefill_context(
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
- seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
- seq_len = torch.stack([seq_len1, seq_len2])
+ context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+ i]
+ context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+ i]
+ seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
@@ -944,7 +919,7 @@ def _compute_prefill_context(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
- seq_len2.to(q_nope.device),
+ context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
@@ -1664,8 +1639,6 @@ def _forward_decode_pcp_dcp(
q_nope = q_nope.view(num_tokens, num_heads, -1)
q_pe = q_pe.view(num_tokens, num_heads, -1)
# use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
- seq_mask_pcp = decode_meta.seq_mask_pcp
- seq_mask_dcp = decode_meta.seq_mask_dcp
seq_len = decode_meta.cp_seq_len
common_kwargs = {
@@ -1735,9 +1708,56 @@ def _forward_decode_pcp_dcp(
output=attn_output,
lse=softmax_lse)
+ # Update out&lse
+ attn_out_lse_list = self._process_attn_out_lse(attn_output,
+ softmax_lse,
+ decode_meta)
+ attn_output = self._npu_attention_update(attn_out_lse_list)
+ return self._v_up_proj(attn_output)
+
+ def _npu_attention_update(
+ self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+ attn_out_split_cp = []
+ attn_lse_split_cp = []
+
+ for attn_out_lse in attn_out_lse_list:
+ attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+ *torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
+ attn_out_split_cp.append(attn_out_allgather)
+ attn_lse_split_cp.append(attn_lse_allgather)
+ attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
+ attn_out_split_cp, 0)
+ attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
+ self.kv_lora_rank)
+ return attn_out
+
+ def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ attn_out = attn_out.contiguous().view(
+ attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+ attn_lse = attn_lse.contiguous().view(
+ attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+ return attn_out, attn_lse
+
+ def _process_attn_out_lse(
+ self,
+ attn_output: torch.Tensor,
+ softmax_lse: torch.Tensor,
+ decode_meta: AscendMLADecodeMetadata,
+ ) -> List[torch.Tensor]:
+ attn_out_lse_list = []
+ out_mask = decode_meta.batch_seq_mask[:, None,
+ None].expand_as(attn_output)
+ attn_output = torch.where(out_mask, 0, attn_output)
+ lse_mask = decode_meta.batch_seq_mask[:, None,
+ None].expand_as(softmax_lse)
+ softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
+
+ softmax_lse = softmax_lse.to(torch.float32)
+ attn_output = attn_output.to(torch.float32)
+ # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+ attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
if self.dcp_size > 1:
- # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
- attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
# permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -1746,24 +1766,12 @@ def _forward_decode_pcp_dcp(
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
- attn_out_lse_split_on_seq = list(
+ if self.pcp_size > 1:
+ attn_out_lse = attn_out_lse_all2all.contiguous()
+ attn_out_lse_list = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
- # Update out&lse
- attn_out_g = None
- attn_lse_g = None
- for i, attn_out_lse_l in enumerate(attn_out_lse_split_on_seq):
- attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
- [self.kv_lora_rank, 1],
- dim=-1)
- attn_out_g, attn_lse_g = self._update_out_and_lse(
- attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
- seq_mask_dcp[:, i])
- attn_output = attn_out_g
- softmax_lse = attn_lse_g
if self.pcp_size > 1:
- # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
- attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)
# AllGather out&lse within PCP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
@@ -1771,45 +1779,12 @@ def _forward_decode_pcp_dcp(
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
- # Update out&lse
- attn_out_g = None
- attn_lse_g = None
- for i, attn_out_lse_l in enumerate(attn_out_lse_list):
- attn_out_l, attn_lse_l = torch.split(attn_out_lse_l,
- [self.kv_lora_rank, 1],
- dim=-1)
- attn_out_g, attn_lse_g = self._update_out_and_lse(
- attn_out_g, attn_lse_g, attn_out_l, attn_lse_l,
- seq_mask_pcp[:, i])
- attn_output = attn_out_g
- return self._v_up_proj(attn_output)
-
- # TODO use update op to replace this
- def _update_out_and_lse(
- self,
- out: torch.Tensor,
- lse: torch.Tensor,
- block_out: torch.Tensor,
- block_lse: torch.Tensor,
- mask: torch.Tensor = None,
- ):
- if out is None:
- out = block_out.to(torch.float32)
- lse = block_lse
- else:
- if mask is None:
- mask = torch.ones([block_out.size(0)],
- dtype=torch.uint8,
- device=block_out.device)
- out_mask = mask[:, None, None].expand_as(block_out)
- lse_mask = mask[:, None, None].expand_as(block_lse)
- block_out = block_out.to(torch.float32)
- out_without_update = out.clone()
- lse_without_update = lse.clone()
-
- out = out - F.sigmoid(block_lse - lse) * (out - block_out)
- lse = lse - F.logsigmoid(lse - block_lse)
- # mask
- out = torch.where(out_mask, out, out_without_update)
- lse = torch.where(lse_mask, lse, lse_without_update)
- return out, lse
+ if self.dcp_size > 1 and self.pcp_size > 1:
+ attn_out_lse_list_pcp_dcp = []
+ for s in attn_out_lse_list:
+ attn_out_lse_list_split = list(
+ torch.chunk(s, self.dcp_size, dim=1))
+ attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+ attn_out_lse_list = attn_out_lse_list_pcp_dcp
+
+ return attn_out_lse_list
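
A minimal sketch of why publishing out = 0 and lse = -inf for masked ranks is safe, assuming npu_attention_update merges the partial results with the same log-sum-exp rule as the removed _update_out_and_lse helper (a rank with lse = -inf gets weight exp(-inf) = 0 and does not disturb the merged output):

import torch
import torch.nn.functional as F


def merge_two(out_a, lse_a, out_b, lse_b):
    # Two-way LSE merge, mirroring the removed _update_out_and_lse helper.
    out = out_a - torch.sigmoid(lse_b - lse_a) * (out_a - out_b)
    lse = lse_a - F.logsigmoid(lse_a - lse_b)
    return out, lse


out_a, lse_a = torch.tensor([1.0, 2.0]), torch.tensor(0.5)
out_b, lse_b = torch.zeros(2), torch.tensor(float("-inf"))  # masked rank
merged_out, merged_lse = merge_two(out_a, lse_a, out_b, lse_b)
assert torch.allclose(merged_out, out_a)  # masked contribution is ignored
assert torch.allclose(merged_lse, lse_a)
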
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index ede83f74a5..ccc845c655 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -16,8 +16,7 @@ class AscendPrefillContextParallelMetadata:
num_actual_tokens_pcp_padded: Optional[int] = None
- num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
- list[int]]]]]] = None
+ num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
q_head_idx_tensor: torch.Tensor = None
@@ -47,7 +46,7 @@ class AscendCommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
-
+
For many of the tensors we keep both GPU and CPU versions.
"""
@@ -104,7 +103,16 @@ class AscendCommonAttentionMetadata:
sin: torch.Tensor = None
prefill_context_parallel_metadata: Optional[
- AscendPrefillContextParallelMetadata] = None
+ AscendPrefillContextParallelMetadata
+ ] = None
+
+ max_seq_len: int = -1
+
+ def batch_size(self) -> int:
+ return self.seq_lens_cpu.shape[0]
+
+ def query_lens(self) -> torch.Tensor:
+ return self.query_start_loc[1:] - self.query_start_loc[:-1]
def split_decodes_and_prefills(
@@ -190,7 +198,8 @@ def trans_rope_weight(weight, rope_dim):
nope_part = weight[..., :-rope_dim, :]
rope_part = weight[..., -rope_dim:, :]
reordered_rope_part = torch.cat(
- (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
+ (rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2
+ )
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
@@ -203,12 +212,34 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
nz_mat = torch.permute(
torch.reshape(
nd_mat,
- (r // block_size[0], block_size[0], c // block_size[1],
- block_size[1]),
+ (r // block_size[0], block_size[0], c // block_size[1], block_size[1]),
),
[2, 0, 1, 3],
)
nz_mat = torch.reshape(
- nz_mat,
- (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
+ nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])
+ )
return nz_mat
+
+
+def extend_flat_seqs(
+ seqs: torch.Tensor, end_locs: torch.Tensor, new_vals: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function appends a single new value into multiple sequences
+ that are stored in a flat format. E.g.
+ [x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2]
+ """
+ new_len = seqs.shape[0] + new_vals.shape[0]
+ new_seqs = torch.zeros(new_len, device=seqs.device, dtype=seqs.dtype)
+ # indices for previous seqs
+ start_locs = end_locs[:-1] + 1
+ seqs_new_idxs = torch.ones_like(seqs)
+ seqs_new_idxs[start_locs] += 1
+ seqs_new_idxs = seqs_new_idxs.cumsum(0) - 1
+ # indices for new values
+ new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], device=seqs.device)
+ # assign seqs and new vals
+ new_seqs[seqs_new_idxs] = seqs
+ new_seqs[new_val_idxs] = new_vals
+ return new_seqs
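
A small usage sketch for extend_flat_seqs above (it assumes this change is applied, so the helper is importable from vllm_ascend.attention.utils):

import torch

from vllm_ascend.attention.utils import extend_flat_seqs

# Two flat sequences [10, 20] and [30]; they end at indices 1 and 2 of `seqs`.
seqs = torch.tensor([10, 20, 30])
end_locs = torch.tensor([1, 2])
new_vals = torch.tensor([25, 35])  # append 25 to the first sequence, 35 to the second
print(extend_flat_seqs(seqs, end_locs, new_vals))  # tensor([10, 20, 25, 30, 35])
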
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index b929326790..5c65936e4c 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -215,8 +215,16 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
output,
) = param
seq_lens = forward_context.attn_metadata[key].seq_lens
- torch.npu.graph_task_update_begin(update_stream, handle)
- torch_npu._npu_paged_attention(
+
+            # In FULL_DECODE_ONLY mode with GQA, _npu_paged_attention can hit a rare
+            # workspace bug in torch_npu: a smaller seq_lens may occasionally require a
+            # larger workspace than the one computed from max_model_len during capture,
+            # so the workspace recorded at capture time can be too small at replay.
+            # Query the workspace again here before updating the graph task to avoid
+            # this.
+            # TODO(Angazenn): remove this once _npu_paged_attention is fully replaced
+            # by npu_fused_infer_attention_score, which does not have this bug.
+ workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=key_cache,
value_cache=value_cache,
@@ -225,8 +233,18 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
- out=output,
- workspace=graph_params.workspaces.get(runtime_shape))
+ out=output)
+ torch.npu.graph_task_update_begin(update_stream, handle)
+ torch_npu._npu_paged_attention(query=query,
+ key_cache=key_cache,
+ value_cache=value_cache,
+ num_kv_heads=num_kv_heads,
+ num_heads=num_heads,
+ scale_value=scale,
+ block_table=block_table,
+ context_lens=seq_lens,
+ out=output,
+ workspace=workspace)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -302,16 +320,28 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
graph_params.events[runtime_shape],
):
(q_nope, k_nope, value, num_heads, num_kv_heads, scale,
- block_table, block_size, actual_seq_lengths_kv, attn_output,
- softmax_lse, pcp_rank, dcp_rank, dcp_size) = param
- actual_seq_lengths_kv = forward_context.attn_metadata[
- key].decode_meta.num_computed_tokens_of_pcp_dcp[:, pcp_rank,
- dcp_rank]
+ block_table, block_size, actual_seq_lengths_kv,
+ actual_seq_lengths_q, attn_output, softmax_lse, dcp_size,
+ pcp_rank, dcp_rank) = param
+ attn_metadata = forward_context.attn_metadata[key]
+ actual_seq_lengths_kv = attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[:,
+ pcp_rank,
+ dcp_rank]
pad_length = runtime_shape - len(actual_seq_lengths_kv)
- pad_tensor = np.zeros(pad_length,
- dtype=actual_seq_lengths_kv.dtype)
- actual_seq_lengths_kv = np.concatenate(
- [actual_seq_lengths_kv, pad_tensor])
+ if pad_length > 0:
+ pad_tensor = np.zeros(pad_length,
+ dtype=actual_seq_lengths_kv.dtype)
+ actual_seq_lengths_kv = np.concatenate(
+ [actual_seq_lengths_kv, pad_tensor])
+
+            actual_seq_lengths_q = attn_metadata.actual_seq_lengths_q[
+                :attn_metadata.num_decode_tokens]
+            q_pad_length = runtime_shape - len(actual_seq_lengths_q)
+            if q_pad_length > 0:
+                actual_seq_lengths_q = actual_seq_lengths_q + [
+                    actual_seq_lengths_q[-1]
+                ] * q_pad_length
if dcp_size > 1:
num_heads = num_heads * dcp_size
@@ -323,7 +353,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
value,
num_heads=num_heads,
num_key_value_heads=num_kv_heads,
- input_layout="BSND",
+ input_layout="TND",
atten_mask=None,
scale=scale,
antiquant_mode=0,
@@ -332,6 +362,7 @@ def update_attn_dcp_pcp_params(update_stream, forward_context, runtime_shape):
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=actual_seq_lengths_kv,
+ actual_seq_lengths=actual_seq_lengths_q,
workspace=graph_params.workspaces.get(runtime_shape),
out=[attn_output, softmax_lse])
torch.npu.graph_task_update_end(update_stream)
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 32d63cbc40..cbc3f977bd 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -27,6 +27,7 @@
class AscendSchedulerConfig(SchedulerConfig):
enable_chunked_prefill: bool = False
max_long_partial_prefills: int = 1
+ max_num_partial_prefills: int = 1
long_prefill_token_threshold: int = MAX_INT
policy: str = "fcfs"
scheduler_cls: Union[str, Type[object]] = (
@@ -47,6 +48,7 @@ def initialize_from_config(
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["max_long_partial_prefills"] = None
+ scheduler_config["max_num_partial_prefills"] = None
scheduler_config["long_prefill_token_threshold"] = None
scheduler_config["policy"] = "fcfs"
scheduler_config["scheduler_cls"] = (
@@ -78,6 +80,9 @@ def __post_init__(self, *args) -> None:
self.max_long_partial_prefills = 1
self.long_prefill_token_threshold = MAX_INT
+ if self.max_num_partial_prefills is None:
+ self.max_num_partial_prefills = 1
+
if self.long_prefill_token_threshold is None or \
self.long_prefill_token_threshold <= 0:
if self.max_model_len is None:
diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
index e72f4eba26..d92b724f1a 100644
--- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
+++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
@@ -31,6 +31,7 @@
from vllm.v1.request import Request, RequestStatus
import vllm_ascend.envs as envs_ascend
+from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
prefill_context_parallel_enable,
vllm_version_is)
@@ -438,7 +439,7 @@ def init_llm_datadist(self):
assert self.local_agent_metadata is not None
llm_config = LLMConfig()
llm_config.device_id = self.local_rank
- llm_config.sync_kv_timeout = 20000
+ llm_config.sync_kv_timeout = get_transfer_timeout_value()
llm_config.enable_switch_role = True
llm_config.enable_cache_manager = True
llm_config.enable_remote_cache_accessible = True
diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py
index 745d91131f..36c820b089 100644
--- a/vllm_ascend/distributed/mooncake/config_data.py
+++ b/vllm_ascend/distributed/mooncake/config_data.py
@@ -2,6 +2,7 @@
import hashlib
import json
import os
+import re
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
@@ -11,6 +12,9 @@
from vllm.utils import cdiv, logger
from vllm.v1.core.sched.output import NewRequestData
+DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
+DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
+
@dataclass
class MooncakeEngineMetadata:
@@ -419,7 +423,7 @@ class LasyerMultiBlockReqMeta:
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
- global_segment_size: int
+ global_segment_size: Union[int, str]
local_buffer_size: int
protocol: str
device_name: str
@@ -433,8 +437,11 @@ def from_file(file_path: str) -> "MooncakeStoreConfig":
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
- global_segment_size=config.get("global_segment_size", 3355443200),
- local_buffer_size=config.get("local_buffer_size", 1073741824),
+ global_segment_size=_parse_global_segment_size(
+ config.get("global_segment_size",
+ DEFAULT_GLOBAL_SEGMENT_SIZE)),
+ local_buffer_size=(config.get("local_buffer_size",
+ DEFAULT_LOCAL_BUFFER_SIZE)),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
@@ -446,4 +453,81 @@ def load_from_env() -> "MooncakeStoreConfig":
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
- return MooncakeStoreConfig.from_file(config_path)
\ No newline at end of file
+ return MooncakeStoreConfig.from_file(config_path)
+
+
+def _parse_global_segment_size(value) -> int:
+ """
+ Parse storage size strings with support for units: GB, MB, KB, B
+
+ Args:
+ value: Input value (int, str, or other convertible types)
+
+ Returns:
+ int: Size in bytes
+
+ Raises:
+ ValueError: For invalid format, missing number, or negative values
+ TypeError: For unsupported input types
+ """
+
+ if isinstance(value, int):
+ return value
+ elif not isinstance(value, str):
+ try:
+ return int(value)
+ except (TypeError, ValueError) as e:
+ raise TypeError(
+ f"Unsupported type for global_segment_size: {type(value)}"
+ ) from e
+
+ cleaned_input = value.strip().lower()
+ if not cleaned_input:
+ raise ValueError("global segment size cannot be empty.")
+
+ UNIT_MULTIPLIERS = {
+ 'gb': 1024**3, # 1 GB = 1024^3 bytes
+ 'mb': 1024**2, # 1 MB = 1024^2 bytes
+ 'kb': 1024, # 1 KB = 1024 bytes
+ 'b': 1 # 1 B = 1 byte
+ }
+ pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
+ match = re.match(pattern, cleaned_input)
+
+ if not match:
+ raise ValueError(f"Invalid format: '{value}'")
+
+ number_str = match.group(1)
+ unit = match.group(2) or 'b'
+
+ multiplier = UNIT_MULTIPLIERS[unit]
+ return _convert_to_bytes(number_str, multiplier, value)
+
+
+def _convert_to_bytes(number_str: str, multiplier: int,
+ original_input: str) -> int:
+ """
+ Convert numeric string to byte count
+
+ Args:
+ number_str: Numeric portion of input
+ multiplier: Unit conversion factor
+ original_input: Original input string (for error messages)
+
+ Returns:
+ int: Byte count
+
+ Raises:
+ ValueError: For invalid numbers or negative results
+ """
+ try:
+ numeric_value = float(number_str)
+ except ValueError:
+ raise ValueError(
+ f"Invalid numeric value '{number_str}' in: '{original_input}'")
+ # Calculate byte count
+ try:
+ byte_count = int(numeric_value * multiplier)
+ except OverflowError:
+ raise ValueError(f"Storage size too large: '{original_input}'")
+ return byte_count
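
# Illustrative-only sketch (not part of the patch): expected conversions for the
# size parser added above, assuming the patch is applied and the module path
# vllm_ascend.distributed.mooncake.config_data is importable.
from vllm_ascend.distributed.mooncake.config_data import _parse_global_segment_size

assert _parse_global_segment_size(4096) == 4096              # plain ints pass through
assert _parse_global_segment_size("2gb") == 2 * 1024**3      # binary units, case-insensitive
assert _parse_global_segment_size("3.125 GB") == 3355443200  # matches DEFAULT_GLOBAL_SEGMENT_SIZE
assert _parse_global_segment_size("512 mb") == 512 * 1024**2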
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 5dfb125ef4..7951760d1d 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -2,6 +2,7 @@
import contextlib
import hashlib
import math
+import os
import queue
import random
import struct
@@ -33,6 +34,7 @@
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
+from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
@@ -855,6 +857,8 @@ class MooncakeConnectorWorker:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._get_prefill_decode_size(vllm_config)
+ os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(
+ get_transfer_timeout_value())
if self._prefill_tp_size < self._decode_tp_size:
raise ValueError(
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py
index 874adb3edd..1c5c0a9260 100644
--- a/vllm_ascend/distributed/mooncake_layerwise_connector.py
+++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py
@@ -3,6 +3,7 @@
import copy
import hashlib
import math
+import os
import queue
import struct
import threading
@@ -18,6 +19,7 @@
import numpy as np
import numpy.typing as npt
import torch
+import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm.config import VllmConfig
@@ -31,6 +33,7 @@
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.utils import (align_memory,
+ get_transfer_timeout_value,
kv_alltoall_and_rearrange)
from vllm_ascend.utils import vllm_version_is
@@ -91,6 +94,8 @@ def __init__(self,
self.total_layers = total_layers
self.use_mla = use_mla
self.block_len = block_len
+ self.model_stream = torch_npu.npu.current_stream()
+ self.current_layer = -1
if self.pd_head_ratio > 1:
# regesit kv buffer for tp inequal
@@ -190,7 +195,9 @@ def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
- torch.npu.synchronize()
+ if self.current_layer != layer_index:
+ self.current_layer = layer_index
+ self.model_stream.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
@@ -241,7 +248,7 @@ def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
((self.tp_rank // self.num_head_replica) %
self.pd_head_ratio))
src_layer_addr += length
- torch.npu.synchronize()
+ self.model_stream.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
@@ -602,6 +609,8 @@ class MooncakeLayerwiseConnectorWorker:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._get_prefill_decode_size(vllm_config)
+ os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(
+ get_transfer_timeout_value())
if self._prefill_tp_size < self._decode_tp_size:
raise ValueError(
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py
index 4b1344a16e..c25c1f15f2 100644
--- a/vllm_ascend/distributed/utils.py
+++ b/vllm_ascend/distributed/utils.py
@@ -1,3 +1,5 @@
+import os
+
import torch
import torch.distributed as dist
@@ -45,3 +47,15 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
+
+
+def get_transfer_timeout_value() -> int:
+    """Return the KV transfer timeout in milliseconds.
+
+    ASCEND_TRANSFER_TIMEOUT takes precedence when set. Otherwise the value is
+    derived from the RDMA retransmission window (4.096 us * 2**HCCL_RDMA_TIMEOUT
+    per attempt, retried HCCL_RDMA_RETRY_CNT times), converted to milliseconds,
+    plus a 3000 ms safety margin.
+    """
+    ascend_transfer_timeout = os.getenv("ASCEND_TRANSFER_TIMEOUT", "")
+    if ascend_transfer_timeout:
+        return int(ascend_transfer_timeout)
+    hccl_rdma_timeout = int(os.getenv('HCCL_RDMA_TIMEOUT', '20'))
+    hccl_rdma_retry_cnt = int(os.getenv('HCCL_RDMA_RETRY_CNT', '7'))
+    return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
+                3000)
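
# Quick arithmetic check (illustration only): with the default HCCL settings the
# timeout derived by get_transfer_timeout_value() above works out to roughly 33 s,
# replacing the previous hard-coded 20000 ms.
hccl_rdma_timeout = 20
hccl_rdma_retry_cnt = 7
timeout_ms = int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 + 3000)
assert timeout_ms == 33064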
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 8f9e1d9899..a6b4081a79 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -63,7 +63,7 @@
"ASCEND_HOME_PATH":
lambda: os.getenv("ASCEND_HOME_PATH", None),
# The path for HCCL library, it's used by pyhccl communicator backend. If
- # not set, the default value is libhccl.so。
+ # not set, the default value is libhccl.so.
"HCCL_SO_PATH":
lambda: os.environ.get("HCCL_SO_PATH", None),
# The version of vllm is installed. This value is used for developers who
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index 21ea48e3ab..956df2eb31 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -35,6 +35,10 @@ def register_model():
"PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
+
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
+
+ ModelRegistry.register_model(
+ "Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")
diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py
index f5b4b8a142..b0bfde0eb6 100644
--- a/vllm_ascend/models/qwen3_next.py
+++ b/vllm_ascend/models/qwen3_next.py
@@ -260,6 +260,24 @@ def _forward(
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
+        # 2.1: process the multi-query part
+ if spec_sequence_masks is not None:
+ mixed_qkv_spec = mixed_qkv_spec.view(
+ attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
+ mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
+ mixed_qkv_spec = causal_conv1d_update(
+ mixed_qkv_spec,
+ conv_state,
+ conv_weights,
+ self.conv1d.bias,
+ self.activation,
+ conv_state_indices=spec_state_indices_tensor[:, 0]
+ [:attn_metadata.num_spec_decodes],
+ num_accepted_tokens=num_accepted_tokens,
+ validate_data=False,
+ )
+ mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
+
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
# - "cache_indices" updates the conv_state cache in positions
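
# Shape walkthrough (illustration only) of the multi-query/spec branch added above.
# The NPU causal_conv1d_update kernel is omitted; this only shows how the flat
# (tokens, dim) tensor is reshaped to (batch, dim, seq) for the conv update and
# flattened back afterwards.
import torch
from einops import rearrange

num_spec_decodes, spec_len, qkv_dim = 2, 3, 8
flat = torch.randn(num_spec_decodes * spec_len, qkv_dim)  # flat spec tokens

x = flat.view(num_spec_decodes, -1, flat.size(-1))        # (b, l, d)
x = rearrange(x, 'b l d -> b d l')                        # conv expects channels first
# ... causal_conv1d_update(x, conv_state, ...) would run here on NPU ...
x = rearrange(x, 'b d l -> (b l) d')                      # back to flat (b*l, d)
assert x.shape == flat.shape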
diff --git a/vllm_ascend/models/qwen3_next_mtp.py b/vllm_ascend/models/qwen3_next_mtp.py
new file mode 100644
index 0000000000..c17d969cb2
--- /dev/null
+++ b/vllm_ascend/models/qwen3_next_mtp.py
@@ -0,0 +1,109 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Inference-only Qwen3Next MTP model."""
+import torch
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.linear import ColumnParallelLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.qwen3_next_mtp import (
+ Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
+from vllm.model_executor.models.utils import (
+ make_empty_intermediate_tensors_factory, maybe_prefix)
+from vllm.transformers_utils.configs import Qwen3NextConfig
+
+from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
+ Qwen3NextRMSNorm)
+
+
+@support_torch_compile
+class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super(Qwen3NextMultiTokenPredictor, self).__init__()
+
+ model_config = vllm_config.model_config
+ quant_config = vllm_config.quant_config
+ lora_config = vllm_config.lora_config
+ config: Qwen3NextConfig = model_config.hf_config
+
+ self.config = config
+ lora_vocab = ((lora_config.lora_extra_vocab_size *
+ (lora_config.max_loras or 1)) if lora_config else 0)
+ self.vocab_size = config.vocab_size + lora_vocab
+ self.org_vocab_size = config.vocab_size
+
+ self.mtp_start_layer_idx = config.num_hidden_layers
+ self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
+
+ self.embed_tokens = VocabParallelEmbedding(
+ self.vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ )
+
+ self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
+ self.config.hidden_size,
+ gather_output=True,
+ bias=False,
+ return_bias=False,
+ quant_config=quant_config,
+ prefix=f'{prefix}.fc')
+
+        # use the old-version MTP layer name to avoid an exception in vllm
+ self.layers = torch.nn.ModuleList(
+ CustomQwen3NextDecoderLayer(
+ vllm_config,
+ layer_type="full_attention",
+ prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
+ ) for idx in range(self.num_mtp_layers))
+
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size))
+
+ self.norm = Qwen3NextRMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+
+@support_torch_compile
+class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ config = vllm_config.model_config.hf_config
+ self.vllm_config = vllm_config
+ cache_config = vllm_config.cache_config
+ assert not cache_config.enable_prefix_caching, \
+ "Qwen3NextMTP currently does not support prefix caching"
+
+ self.quant_config = vllm_config.quant_config
+
+ super(Qwen3NextMTP, self).__init__()
+ self.config = config
+ self.model = CustomQwen3NextMultiTokenPredictor(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
+ self.unpadded_vocab_size = config.vocab_size
+ self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+ prefix=maybe_prefix(prefix, "lm_head"))
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+ config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
diff --git a/vllm_ascend/ops/casual_conv1d.py b/vllm_ascend/ops/casual_conv1d.py
index 2d008899ad..7ddc9cecca 100644
--- a/vllm_ascend/ops/casual_conv1d.py
+++ b/vllm_ascend/ops/casual_conv1d.py
@@ -55,7 +55,7 @@ def causal_conv1d_ref(
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
- final_states_out.copy_(final_states)
+ final_states_out[..., :(width - 1)].copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
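
# Minimal repro (illustration only) of the copy fixed above: the reference conv
# produces a final state with exactly width-1 columns, but the caller's state
# buffer can have a wider last dimension (state_len below is a hypothetical size),
# so the copy must target only the first width-1 columns of the buffer.
import torch
import torch.nn.functional as F

batch, dim, seqlen, width, state_len = 2, 4, 2, 4, 6
x = torch.randn(batch, dim, seqlen)
final_states = F.pad(x, (width - 1 - x.shape[-1], 0))    # (batch, dim, width-1)

final_states_out = torch.zeros(batch, dim, state_len)     # wider destination buffer
final_states_out[..., :(width - 1)].copy_(final_states)   # works for any state_len >= width-1
# final_states_out.copy_(final_states) would raise a shape-mismatch error here.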
diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
index f158a4bfca..f54d4579ca 100644
--- a/vllm_ascend/ops/fused_moe/prepare_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -29,7 +29,10 @@
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
+
+if prefill_context_parallel_enable():
+ from vllm.distributed import get_pcp_group
class QuantType(Enum):
@@ -382,6 +385,17 @@ def _prepare_with_dp_group(
hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
+
+ if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+ hidden_states = get_pcp_group().all_gather(
+ hidden_states,
+ dim=0,
+ )
+ router_logits = get_pcp_group().all_gather(
+ router_logits,
+ dim=0,
+ )
+
return hidden_states, router_logits, None, None
def finalize(self,
@@ -431,6 +445,9 @@ def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
hidden_states = hidden_states[:self.num_tokens]
+ if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+ hidden_states = get_pcp_group().reduce_scatter(hidden_states,
+ dim=0)
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
@@ -504,6 +521,16 @@ def prepare(
router_logits = self._naive_multicast(router_logits,
self.cu_tokens_across_dp_cpu)
+ if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+ hidden_states = get_pcp_group().all_gather(
+ hidden_states,
+ dim=0,
+ )
+ router_logits = get_pcp_group().all_gather(
+ router_logits,
+ dim=0,
+ )
+
return hidden_states, router_logits, None, None
def finalize(self,
@@ -528,6 +555,10 @@ def finalize(self,
hidden_states) # Sum across DP
hidden_states = hidden_states[start:end, :]
+ if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
+ hidden_states = get_pcp_group().reduce_scatter(hidden_states,
+ dim=0)
+
if reduce_results and (self.moe_config.tp_size > 1
or self.moe_config.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
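
# Single-process stand-in (illustration only) for the PCP collectives used in the
# prepare/finalize path above: all_gather along dim 0 duplicates every rank's
# tokens before the experts run, and reduce_scatter sums the partial outputs and
# hands each rank back only its own token slice.
import torch

pcp_size, num_local_tokens, hidden = 2, 3, 4
local_hidden = [torch.randn(num_local_tokens, hidden) for _ in range(pcp_size)]

gathered = torch.cat(local_hidden, dim=0)                        # all_gather(dim=0): (pcp_size*n, h)
partials = [gathered * (rank + 1) for rank in range(pcp_size)]   # pretend per-rank expert outputs

summed = torch.stack(partials).sum(dim=0)                        # reduce step
scattered = summed.chunk(pcp_size, dim=0)                        # scatter step: rank r keeps chunk r
assert scattered[0].shape == (num_local_tokens, hidden)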
diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py
index 077163c54a..c6bdfe4d1c 100644
--- a/vllm_ascend/ops/fused_moe/token_dispatcher.py
+++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py
@@ -225,7 +225,7 @@ def token_dispatch(self,
"expand_scales": expand_scales
}
- group_list_type = 1 if dynamic_eplb else 0
+ group_list_type = 0
return {
"group_list_type": group_list_type,
diff --git a/vllm_ascend/patch/platform/patch_config.py b/vllm_ascend/patch/platform/patch_config.py
index d6150383f0..94f9131001 100644
--- a/vllm_ascend/patch/platform/patch_config.py
+++ b/vllm_ascend/patch/platform/patch_config.py
@@ -155,11 +155,6 @@ def __post_init__(self):
)
else:
self.method = "draft_model"
- raise NotImplementedError(
- "Speculative decoding with draft model is not "
- "supported yet. Please consider using other "
- "speculative decoding methods such as ngram, medusa, "
- "eagle, or deepseek_mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
diff --git a/vllm_ascend/patch/platform/patch_mamba_config.py b/vllm_ascend/patch/platform/patch_mamba_config.py
index ad083f51c9..1b077b4135 100644
--- a/vllm_ascend/patch/platform/patch_mamba_config.py
+++ b/vllm_ascend/patch/platform/patch_mamba_config.py
@@ -58,7 +58,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
block_size=model_config.max_model_len,
).page_size_bytes
- block_alignment_bytes = 64
+ block_alignment_bytes = 128
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py
index 3e17944c5c..b0b75f396c 100644
--- a/vllm_ascend/spec_decode/__init__.py
+++ b/vllm_ascend/spec_decode/__init__.py
@@ -20,21 +20,21 @@
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
+from vllm_ascend.spec_decode.draft_proposer import DraftModelProposer
-def get_spec_decode_method(method,
- vllm_config,
- device,
- runner,
- is_torchair_graph=False):
+def get_spec_decode_method(
+ method, vllm_config, device, runner, is_torchair_graph=False
+):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
- elif method in ["eagle", "eagle3"]:
+ elif method in ("eagle", "eagle3"):
return EagleProposer(vllm_config, device, runner)
- elif method == 'deepseek_mtp':
+ elif method in ('deepseek_mtp', 'qwen3_next_mtp'):
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
+ elif method == "draft_model":
+ return DraftModelProposer(vllm_config, device, runner)
else:
- raise ValueError("Unknown speculative decoding method: "
- f"{method}")
+ raise ValueError(f"Unknown speculative decoding method: {method}")
diff --git a/vllm_ascend/spec_decode/draft_proposer.py b/vllm_ascend/spec_decode/draft_proposer.py
new file mode 100644
index 0000000000..1b87a7305d
--- /dev/null
+++ b/vllm_ascend/spec_decode/draft_proposer.py
@@ -0,0 +1,280 @@
+from dataclasses import replace
+from typing import Any
+
+import torch
+from vllm.attention.layer import Attention
+from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config.speculative import SpeculativeConfig
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
+from vllm_ascend.attention.attention_v1 import AscendMetadata
+from vllm_ascend.attention.utils import extend_flat_seqs
+from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer
+
+logger = init_logger(__name__)
+
+
+class DraftModelProposer(SpecDecodeBaseProposer):
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ device: torch.device,
+ runner=None,
+ ):
+ super().__init__(
+ vllm_config=vllm_config,
+ device=device,
+ pass_hidden_states_to_model=False,
+ runner=runner,
+ )
+ self.draft_model_config = vllm_config.speculative_config.draft_model_config
+ self._raise_if_mrope()
+ self._raise_if_padded_drafter_batch()
+ self._raise_if_vocab_size_mismatch()
+ self._raise_if_draft_tp_mismatch()
+
+ def generate_token_ids(
+ self,
+ valid_sampled_token_ids: list[list[int]],
+ sampling_metadata: SamplingMetadata = None,
+ scheduler_output: SchedulerOutput = None,
+ spec_decode_metadata: SpecDecodeMetadata = None,
+ positions: torch.Tensor = None,
+ num_scheduled_tokens: int = 0,
+ hidden_states: torch.Tensor = None,
+ attn_metadata=None,
+ aux_hidden_states: torch.Tensor = None,
+ ):
+ attn_metadata = self._get_atten_dict(scheduler_output)
+ attn_metadata = attn_metadata[self.attn_layer_name]
+ next_token_ids: list[int] = []
+ for i, token_ids in enumerate(valid_sampled_token_ids):
+ if token_ids:
+ # Common case.
+ next_token_id = token_ids[-1]
+ else:
+ # Partial prefill (rare case).
+ # Get the next token id from the request state.
+ req_id = self.runner.input_batch.req_ids[i]
+ req_state = self.runner.requests[req_id]
+ seq_len = (
+ req_state.num_computed_tokens
+ + scheduler_output.num_scheduled_tokens[req_id]
+ )
+
+ next_token_id = req_state.get_token_id(seq_len)
+ next_token_ids.append(next_token_id)
+ next_token_ids = torch.tensor(
+ next_token_ids, dtype=torch.int32, device=self.device
+ )
+
+ if spec_decode_metadata is None:
+ # input_ids can be None for multimodal models.
+ target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
+ target_positions = positions[:num_scheduled_tokens]
+ cu_num_tokens = attn_metadata.query_start_loc
+ else:
+ num_draft_tokens = spec_decode_metadata.num_draft_tokens
+ num_rejected_tokens = [
+ n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+ for i, n in enumerate(num_draft_tokens)
+ ]
+ num_rejected_tokens = torch.tensor(
+ num_rejected_tokens,
+ dtype=torch.int32,
+ device=self.device,
+ )
+ cu_num_tokens, token_indices = self._prepare_inputs(
+ attn_metadata, num_rejected_tokens
+ )
+ target_token_ids = self.runner.input_ids[token_indices]
+ target_positions = positions[token_indices]
+
+ num_reqs = self.runner.input_batch.num_reqs
+        (target_token_ids, target_positions, target_slot_mapping,
+         cu_num_tokens) = merge_next_token_ids_into_token_ids(
+ input_token_ids=target_token_ids,
+ input_positions=target_positions,
+ cad=attn_metadata,
+ next_token_ids=next_token_ids,
+ block_size=self.block_size,
+ max_model_len=self.vllm_config.model_config.max_model_len,
+ arange=self.arange,
+ cu_num_tokens=cu_num_tokens,
+ num_reqs=num_reqs
+ )
+
+ draft_token_ids = self._propose(
+ target_token_ids=target_token_ids,
+ target_positions=target_positions,
+ target_hidden_states=None,
+ target_slot_mapping=target_slot_mapping.to(torch.int32),
+ next_token_ids=next_token_ids,
+ cu_num_tokens=cu_num_tokens,
+ block_table=attn_metadata.block_tables,
+ sampling_metadata=sampling_metadata,
+ )
+ spec_token_ids = draft_token_ids.tolist()
+
+ return spec_token_ids
+
+ def _raise_if_mrope(self):
+ if self.draft_model_config.uses_mrope:
+ raise NotImplementedError(
+ "Speculative Decoding with draft models does not support M-RoPE yet"
+ )
+
+ def _raise_if_padded_drafter_batch(self):
+ if not self.vllm_config.speculative_config.disable_padded_drafter_batch:
+ raise NotImplementedError(
+ "Speculative Decoding with draft models does not support "
+ "padded drafter batch yet. Please pass --disable-padded-drafter-batch "
+ "in the speculative_config."
+ )
+
+ def _raise_if_vocab_size_mismatch(self):
+ speculative_config = self.vllm_config.speculative_config
+ if (
+ speculative_config.method == "draft_model"
+ and speculative_config.target_model_config is not None
+ and speculative_config.draft_model_config is not None
+ ):
+ target_vocab_size = speculative_config.target_model_config.get_vocab_size()
+ draft_vocab_size = speculative_config.draft_model_config.get_vocab_size()
+ if target_vocab_size != draft_vocab_size:
+ raise ValueError(
+ f"Target and draft model should have the same vocabulary size. "
+ f"Target model vocab_size={target_vocab_size}. "
+ f"Draft model vocab_size={draft_vocab_size}. "
+ f"Using models with different tokenizers can cause out-of-bounds "
+ f"errors during speculative decoding."
+ )
+
+ def _raise_if_draft_tp_mismatch(self):
+        # Note(Tomas Ruiz): if the target model runs with TP > 1 while the
+        # draft model runs with TP = 1, the TP ranks collide: every rank
+        # compiles the draft model as rank 0 (because TP=1), so the torch
+        # compile cache is overwritten and corrupted. A mechanism like
+        # https://github.com/vllm-project/vllm/pull/5414 would be needed to
+        # support this; until then we require both TP sizes to be equal.
+ spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config
+ tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
+ draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
+ if draft_tp != tgt_tp:
+ raise ValueError(
+ f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
+ f"must be the same. Got {draft_tp} and {tgt_tp}. "
+ "Please pass 'draft_tensor_parallel_size' in the speculative_config."
+ )
+
+ def set_input_ids_first_pass(
+ self,
+ target_token_ids: torch.Tensor,
+ next_token_ids: torch.Tensor,
+ num_tokens: int,
+ last_token_indices: torch.Tensor,
+ ) -> None:
+ self.input_ids[:num_tokens] = target_token_ids
+
+ def load_model(self, target_model: Any) -> None:
+        """The target_model argument is unused; it is only accepted to
+        match the Proposer load_model() interface."""
+
+ # This must be computed before loading the draft model
+ # because that mutates the forward_context of the vllm_config
+ target_attn_layer_names = set(
+ get_layers_from_vllm_config(self.vllm_config, Attention).keys()
+ )
+
+ from vllm.compilation.backends import set_model_tag
+
+ draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(
+ target_model_vllm_config=self.vllm_config
+ )
+ logger.info(
+ "Starting to load draft model %s. TP=%d, rank=%d",
+ draft_vllm_config.model_config.model,
+ draft_vllm_config.parallel_config.tensor_parallel_size,
+ draft_vllm_config.parallel_config.rank,
+ )
+ with set_model_tag("draft_model"):
+ self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model")
+
+ # This must be computed after loading the draft model
+ # because that mutates the forward_context of the vllm_config
+ draft_attn_layer_names = (
+ get_layers_from_vllm_config(self.vllm_config, Attention).keys()
+ - target_attn_layer_names
+ )
+ self.attn_layer_name = next(iter(draft_attn_layer_names))
+
+
+def create_vllm_config_for_draft_model(
+ target_model_vllm_config: VllmConfig,
+) -> VllmConfig:
+ """The vllm_config is configured for the target model, e.g.
+ its quant_config and parallel_config. But the draft model is potentially
+ quantized differently, and has potentially different tensor_parallel_size.
+ This function creates a new vllm_config configured for the draft model.
+ The vllm_config is useful when loading the draft model with get_model().
+ """
+ old = target_model_vllm_config
+ new_parallel_config = replace(
+ old.speculative_config.draft_parallel_config, rank=old.parallel_config.rank
+ )
+
+ new: VllmConfig = replace(
+ old,
+ quant_config=None, # quant_config is recomputed in __init__()
+ model_config=old.speculative_config.draft_model_config,
+ parallel_config=new_parallel_config,
+ )
+ return new
+
+
+def merge_next_token_ids_into_token_ids(
+ input_token_ids: torch.Tensor,
+ input_positions: torch.Tensor,
+ cad: AscendMetadata,
+ next_token_ids: torch.Tensor,
+ block_size: int,
+ max_model_len: int,
+ arange: torch.Tensor,
+ cu_num_tokens,
+ num_reqs
+):
+    """
+    Appends each request's next token id to its flat token sequence and does
+    the same for the positions, then recomputes the slot mapping and the
+    cumulative token counts. The inputs are not modified in place.
+    """
+ query_end_locs = cu_num_tokens[1:] - 1
+ new_token_ids = extend_flat_seqs(
+ seqs=input_token_ids, end_locs=query_end_locs, new_vals=next_token_ids
+ )
+
+ # append new positions
+ positions_to_append = input_positions[query_end_locs] + 1
+ new_positions = extend_flat_seqs(
+ seqs=input_positions, end_locs=query_end_locs, new_vals=positions_to_append
+ )
+ # recompute slot mapping
+ batch_size, n_blocks_per_req = cad.block_tables.shape
+ req_indices = torch.arange(num_reqs, device=cad.query_start_loc.device)
+
+ query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
+ req_indices = torch.repeat_interleave(
+ req_indices, query_lens.to(cad.query_start_loc.device) + 1
+ )
+ block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size
+ block_nums = cad.block_tables.view(-1)[block_table_indices]
+ block_offsets = new_positions % block_size
+ new_slot_mapping = block_nums * block_size + block_offsets
+ # Mask out the position ids that exceed the max model length.
+ exceeds_max_model_len = new_positions >= max_model_len
+ new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)
+
+ cu_num_tokens = cu_num_tokens + arange[: len(cu_num_tokens)]
+ return (new_token_ids, new_positions, new_slot_mapping, cu_num_tokens)
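
# Worked example (illustration only) of the slot-mapping arithmetic used in
# merge_next_token_ids_into_token_ids above: a token at position p of request r
# lives in block block_tables[r, p // block_size] at offset p % block_size, and
# positions beyond max_model_len are masked to the padding slot.
import torch

PADDING_SLOT_ID = -1
block_size, max_model_len = 4, 8
block_tables = torch.tensor([[10, 11, 12],      # request 0 owns blocks 10..12
                             [20, 21, 22]])     # request 1 owns blocks 20..22
num_reqs, n_blocks_per_req = block_tables.shape

positions = torch.tensor([0, 1, 2, 5, 6, 9])    # flat positions, last one overflows
req_indices = torch.tensor([0, 0, 0, 1, 1, 1])  # which request each position belongs to

block_table_indices = req_indices * n_blocks_per_req + positions // block_size
block_nums = block_tables.view(-1)[block_table_indices]
slot_mapping = block_nums * block_size + positions % block_size
slot_mapping.masked_fill_(positions >= max_model_len, PADDING_SLOT_ID)
# -> tensor([40, 41, 42, 85, 86, -1])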
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 74e2917806..e5e4a31dde 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -34,16 +34,23 @@
PADDING_SLOT_ID = -1
-class EagleProposer(Proposer):
+class SpecDecodeBaseProposer(Proposer):
def __init__(self,
vllm_config: VllmConfig,
device: torch.device,
+ pass_hidden_states_to_model: bool,
runner=None):
- self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
+ if vllm_config.speculative_config.method == "eagle":
+ self.name = SpecDcodeType.EAGLE
+ elif vllm_config.speculative_config.method == "draft_model":
+ self.name = SpecDcodeType.DRAFT_MODEL
+ else:
+ self.name = SpecDcodeType.EAGLE3
self.vllm_config = vllm_config
self.device = device
self.runner = runner
+ self.pass_hidden_states_to_model = pass_hidden_states_to_model
self.block_size = vllm_config.cache_config.block_size
# We need to get the hidden size from the draft model config because
@@ -143,11 +150,13 @@ def dummy_run(self,
self.vllm_config,
moe_comm_type=moe_comm_type,
num_tokens=num_tokens):
- self.model(
+ model_kwargs = dict(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
- hidden_states=self.hidden_states[:num_tokens],
)
+ if self.pass_hidden_states_to_model:
+ model_kwargs["hidden_states"] = self.hidden_states[:num_tokens]
+ self.model(**model_kwargs)
def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],
@@ -160,7 +169,7 @@ def generate_token_ids(self,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
- attn_metadata = self._get_eagle_atten_dict(scheduler_output)
+ attn_metadata = self._get_atten_dict(scheduler_output)
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
@@ -179,7 +188,7 @@ def generate_token_ids(self,
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
- eagle_attn_metadata = attn_metadata[self.attn_layer_name]
+ draft_attn_metadata = attn_metadata[self.attn_layer_name]
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
@@ -190,8 +199,8 @@ def generate_token_ids(self,
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
- target_slot_mapping = eagle_attn_metadata.slot_mapping
- cu_num_tokens = eagle_attn_metadata.query_start_loc
+ target_slot_mapping = draft_attn_metadata.slot_mapping
+ cu_num_tokens = draft_attn_metadata.query_start_loc
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
@@ -204,7 +213,7 @@ def generate_token_ids(self,
device=self.device,
)
cu_num_tokens, token_indices =\
- self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
+ self._prepare_inputs(draft_attn_metadata, num_rejected_tokens)
target_token_ids = self.runner.input_ids[token_indices]
target_positions = positions[token_indices]
if self.name == SpecDcodeType.EAGLE3:
@@ -212,7 +221,7 @@ def generate_token_ids(self,
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
- target_slot_mapping = eagle_attn_metadata.slot_mapping[
+ target_slot_mapping = draft_attn_metadata.slot_mapping[
token_indices]
draft_token_ids = self._propose(
@@ -222,13 +231,13 @@ def generate_token_ids(self,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
- block_table=eagle_attn_metadata.block_tables,
+ block_table=draft_attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
- def _get_eagle_atten_dict(
+ def _get_atten_dict(
self,
scheduler_output: "SchedulerOutput",
):
@@ -431,12 +440,15 @@ def _propose(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
+
# Shift the input ids by one token.
- # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
- self.input_ids[:num_tokens - 1] = target_token_ids[1:]
- # Replace the last token with the next token.
- # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
- self.input_ids[last_token_indices] = next_token_ids
+        # The shift-by-one and next-token replacement now live in
+        # set_input_ids_first_pass(), which draft-model proposers override to
+        # feed the target tokens through unshifted.
+        self.set_input_ids_first_pass(target_token_ids, next_token_ids,
+                                      num_tokens, last_token_indices)
seq_lens = (target_positions[last_token_indices] + 1).int()
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
@@ -483,15 +495,24 @@ def _propose(
self.positions[:num_tokens] = target_positions.to(device)
self.hidden_states[:num_tokens] = target_hidden_states
attn_metadata.block_tables = block_table.to(device)
+ model_kwargs = {
+ "input_ids": self.input_ids[:num_input_tokens],
+ "positions": self.positions[:num_input_tokens]
+ }
+ if self.pass_hidden_states_to_model:
+ model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens]
+
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_type=moe_comm_type,
num_tokens=num_input_tokens):
- last_hidden_states, hidden_states = self.model(
- input_ids=self.input_ids[:num_input_tokens],
- positions=self.positions[:num_input_tokens],
- hidden_states=self.hidden_states[:num_input_tokens],
- )
+ ret_hidden_states = self.model(**model_kwargs)
+ if not self.model_returns_tuple():
+ last_hidden_states = ret_hidden_states
+ hidden_states = ret_hidden_states
+ else:
+ last_hidden_states, hidden_states = ret_hidden_states
+
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids = logits.argmax(dim=-1)
@@ -586,16 +607,23 @@ def _propose(
attn_metadata.attn_mask = attn_mask
attn_metadata.block_tables = block_table.to(device)
# Run the model.
+ model_kwargs = {
+ "input_ids": self.input_ids[:input_batch_size],
+ "positions": self.positions[:input_batch_size]
+ }
+ if self.pass_hidden_states_to_model:
+ model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size]
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_type=moe_comm_type,
num_tokens=input_batch_size):
- last_hidden_states, hidden_states = self.model(
- input_ids=self.input_ids[:input_batch_size],
- positions=self.positions[:input_batch_size],
- hidden_states=self.hidden_states[:input_batch_size],
- )
+ ret_hidden_states = self.model(**model_kwargs)
+ if not self.model_returns_tuple():
+ last_hidden_states = ret_hidden_states
+ hidden_states = ret_hidden_states
+ else:
+ last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
@@ -694,3 +722,23 @@ def _prepare_inputs(
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
return cu_num_tokens, token_indices
+
+ def set_input_ids_first_pass(
+ self,
+ target_token_ids: torch.Tensor,
+ next_token_ids: torch.Tensor,
+ num_tokens: int,
+ last_token_indices: torch.Tensor,
+ ) -> None:
+ self.input_ids[: num_tokens - 1] = target_token_ids[1:]
+ self.input_ids[last_token_indices] = next_token_ids
+
+ def model_returns_tuple(self) -> bool:
+ return self.name != SpecDcodeType.DRAFT_MODEL
+
+class EagleProposer(SpecDecodeBaseProposer):
+ def __init__(self,
+ vllm_config: VllmConfig,
+ device: torch.device,
+ runner=None):
+ super().__init__(vllm_config, device, pass_hidden_states_to_model=True, runner=runner)
\ No newline at end of file
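
# Condensed sketch (illustration only) of the forward-call pattern introduced in
# SpecDecodeBaseProposer._propose above: EAGLE-style drafters take hidden_states
# and return (last_hidden_states, hidden_states), while a plain draft model takes
# only input_ids/positions and returns a single tensor. run_drafter and fake_draft
# are hypothetical names used just for this sketch.
import torch

def run_drafter(model, input_ids, positions, hidden_states=None, returns_tuple=True):
    kwargs = {"input_ids": input_ids, "positions": positions}
    if hidden_states is not None:            # pass_hidden_states_to_model
        kwargs["hidden_states"] = hidden_states
    out = model(**kwargs)
    if returns_tuple:                        # model_returns_tuple()
        last_hidden_states, hidden_states = out
    else:
        last_hidden_states = hidden_states = out
    return last_hidden_states, hidden_states

# e.g. a draft model that returns a single tensor:
fake_draft = lambda input_ids, positions: torch.randn(input_ids.shape[0], 16)
last, hidden = run_drafter(fake_draft, torch.zeros(4, dtype=torch.long),
                           torch.arange(4), returns_tuple=False)
assert last is hidden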
diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py
index 3f0a36b13c..ddbabd17ff 100644
--- a/vllm_ascend/spec_decode/interface.py
+++ b/vllm_ascend/spec_decode/interface.py
@@ -13,6 +13,7 @@ class SpecDcodeType(enum.Enum):
EAGLE = 1
EAGLE3 = 2
MTP = 4
+ DRAFT_MODEL = 5
class Proposer:
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 9f6d787471..2d4e239e27 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -1,3 +1,4 @@
+import importlib
from typing import Optional
import numpy as np
@@ -12,7 +13,6 @@
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
-from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import cdiv
@@ -42,6 +42,26 @@
PADDING_SLOT_ID = -1
+_MTP_MODELS = {
+ "DeepseekV3ForCausalLM":
+ ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
+ "Qwen3NextForCausalLM":
+ ("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
+}
+
+_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
+
+_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
+
+
+def _load_model(architecture):
+ if architecture not in _MTP_MODELS:
+        raise ValueError(
+            f"Unsupported architecture for MTP speculative decoding: {architecture}")
+ module_name, model_name = _MTP_MODELS[architecture]
+ module = importlib.import_module(module_name)
+ model = getattr(module, model_name)
+ return model
+
class MtpProposer(Proposer):
@@ -150,9 +170,7 @@ def load_model(self, model) -> None:
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
- self.model = DeepSeekMTP(
- vllm_config=self.vllm_config).to(target_device)
-
+ self._init_mtp_model()
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
@@ -228,8 +246,7 @@ def generate_token_ids(self,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
- if attn_metadata is not None and isinstance(attn_metadata, dict):
- attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
+ attn_metadata = self._get_attn_metadata(attn_metadata)
if self.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -311,6 +328,20 @@ def generate_token_ids(self,
return draft_token_ids
+ def _init_mtp_model(self):
+ architecture = self.vllm_config.model_config.architecture
+ target_device = self.vllm_config.device_config.device
+ model = _load_model(architecture)
+ self.model = model(vllm_config=self.vllm_config).to(target_device)
+
+ def _get_attn_metadata(self, attn_metadata):
+ if attn_metadata is not None and isinstance(attn_metadata, dict):
+ architecture = self.vllm_config.model_config.architecture
+ layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
+ attn_metadata = attn_metadata[layer_name]
+
+ return attn_metadata
+
def _prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
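
# Pattern sketch (illustration only) of the lazy class loading used by _load_model
# above: the MTP architecture table stores (module_path, class_name) pairs and the
# class is imported only when requested. Shown with a stdlib class so the snippet
# runs without vllm installed; _DEMO_MODELS and load_class are hypothetical names.
import importlib

_DEMO_MODELS = {"OrderedDictModel": ("collections", "OrderedDict")}

def load_class(architecture: str):
    if architecture not in _DEMO_MODELS:
        raise ValueError(f"Unsupported architecture: {architecture}")
    module_name, class_name = _DEMO_MODELS[architecture]
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

cls = load_class("OrderedDictModel")
assert cls.__name__ == "OrderedDict"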
diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py
index ce539b7d68..51becad900 100644
--- a/vllm_ascend/torchair/torchair_mla.py
+++ b/vllm_ascend/torchair/torchair_mla.py
@@ -69,6 +69,7 @@ class TorchairChunkedContextMetadata:
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
+ chunk_seq_lens_npu: torch.Tensor
attn_mask: torch.Tensor
query_lens: torch.Tensor
@@ -447,6 +448,7 @@ def build(
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
+ chunk_seq_lens_npu=chunk_seq_lens.npu(),
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
@@ -760,7 +762,8 @@ def _compute_prefill_context(
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
- seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
+ current_seq_len = torch.tensor(prefill_metadata.query_lens,
+ dtype=torch.int32)
cache_kv_c = kv_c_and_k_pe_cache[0]
cache_k_pe = kv_c_and_k_pe_cache[1]
num_heads = cache_k_pe.size(2)
@@ -768,8 +771,11 @@ def _compute_prefill_context(
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
- seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
- seq_len = torch.stack([seq_len1, seq_len2])
+ context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[
+ i]
+ context_seq_len_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+ i]
+ seq_len = torch.stack([current_seq_len, context_seq_len])
kv_c_normed = torch.empty(toks,
num_heads,
latent_kv_dim,
@@ -785,7 +791,7 @@ def _compute_prefill_context(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
- seq_len2.to(query.device),
+ context_seq_len_npu,
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index e1afd24a08..46e80606fd 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -659,6 +659,17 @@ def enable_sp(vllm_config=None) -> bool:
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
+ if not _ENABLE_SP:
+ return _ENABLE_SP
+
+ assert vllm_config.parallel_config.tensor_parallel_size > 1, \
+ "Flash Comm v1 (Sequence Parallelism) is only supported when tp_size > 1."
+
+ assert (
+ not is_moe_model(vllm_config)
+ or vllm_config.parallel_config.enable_expert_parallel
+ ), "Flash Comm v1 (Sequence Parallelism) requires enable_expert_parallel=True for MoE models."
+
return _ENABLE_SP
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8e4acdd0ec..41f63de598 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -38,6 +38,7 @@
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
+import vllm_ascend.envs as envs_ascend
from tqdm import tqdm # type: ignore
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
@@ -102,8 +103,6 @@
gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders)
-
-import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import (MoECommType,
set_ascend_forward_context)
@@ -132,6 +131,7 @@
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.spec_decode import get_spec_decode_method
+from vllm_ascend.spec_decode.draft_proposer import DraftModelProposer
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
@@ -163,7 +163,6 @@
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel
from vllm.utils import LazyLoader, is_pin_memory_available
-
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
else:
from vllm.attention.layer import MLAAttention
@@ -209,8 +208,7 @@ def graph_capture(device: torch.device):
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
- graph_capture_context = GraphCaptureContext(
- torch.npu.Stream(device=device))
+ graph_capture_context = GraphCaptureContext(torch.npu.Stream(device=device))
stream = graph_capture_context.stream
# we use nullcontext now
@@ -228,7 +226,6 @@ def graph_capture(device: torch.device):
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncNPUModelRunnerOutput(AsyncModelRunnerOutput):
-
def __init__(
self,
model_runner_output: ModelRunnerOutput,
@@ -251,7 +248,8 @@ def __init__(
with torch.npu.stream(async_output_copy_stream):
async_output_copy_stream.wait_stream(default_stream)
self._sampled_token_ids_cpu = self._sampled_token_ids.to(
- 'cpu', non_blocking=True)
+ "cpu", non_blocking=True
+ )
self._async_copy_ready_event.record()
def get_output(self) -> ModelRunnerOutput:
@@ -274,7 +272,6 @@ def get_output(self) -> ModelRunnerOutput:
class NPUModelRunner(LoRAModelRunnerMixin):
-
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
@@ -287,19 +284,22 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.block_size = vllm_config.cache_config.block_size
- self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
- self.block_size)
+ self.max_num_blocks_per_req = cdiv(
+ self.model_config.max_model_len, self.block_size
+ )
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
- decode_max_num_seqs = getattr(self.scheduler_config,
- 'decode_max_num_seqs', 0)
- self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
- decode_max_num_seqs)
+ decode_max_num_seqs = getattr(self.scheduler_config, "decode_max_num_seqs", 0)
+ self.max_num_reqs = max(self.scheduler_config.max_num_seqs, decode_max_num_seqs)
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
- self.pcp_size = get_prefill_context_model_parallel_world_size(
- ) if prefill_context_parallel_enable() else 1
- self.pcp_rank = get_prefill_context_model_parallel_rank(
- ) if self.pcp_size > 1 else 0
+ self.pcp_size = (
+ get_prefill_context_model_parallel_world_size()
+ if prefill_context_parallel_enable()
+ else 1
+ )
+ self.pcp_rank = (
+ get_prefill_context_model_parallel_rank() if self.pcp_size > 1 else 0
+ )
self.dcp_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
self.device = device
@@ -335,13 +335,15 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
else:
self.chunked_prefill_enabled = True
self.weight_prefetch_method = WeightPrefetchMethod(
- self.ascend_config.weight_prefetch_config)
+ self.ascend_config.weight_prefetch_config
+ )
if self.cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
- self.cache_config.cache_dtype]
+ self.cache_config.cache_dtype
+ ]
# use_hybrid_blocks: if hybrid blocks is used.
self.use_hybrid_blocks: bool = False
self.need_accepted_tokens: bool = False
@@ -354,25 +356,26 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.max_num_tokens,
self.model_config.get_hidden_size(),
dtype=self.dtype,
- numpy=False)
- self.is_token_ids = self._make_buffer(self.max_num_tokens,
- dtype=torch.bool)
+ numpy=False,
+ )
+ self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
# Set up Attention
- self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
- "index_topk")
- self.attn_backend = get_attn_backend(0,
- self.dtype,
- None,
- self.block_size,
- use_mla=self.model_config.use_mla,
- use_sparse=self.use_sparse)
+ self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, "index_topk")
+ self.attn_backend = get_attn_backend(
+ 0,
+ self.dtype,
+ None,
+ self.block_size,
+ use_mla=self.model_config.use_mla,
+ use_sparse=self.use_sparse,
+ )
if self.pcp_size > 1:
self.attn_mask_builder = None
else:
self.attn_mask_builder = AttentionMaskBuilder(
- self.scheduler_config.max_num_batched_tokens, self.dtype,
- self.device)
+ self.scheduler_config.max_num_batched_tokens, self.dtype, self.device
+ )
self._set_up_drafter()
@@ -386,36 +389,40 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
self._may_pad_kv_consumer_num_seq()
# Persistent batch.
- self.input_ids = torch.zeros(self.max_num_tokens,
- dtype=torch.int32,
- device=self.device)
- self.positions = torch.zeros(self.max_num_tokens,
- dtype=torch.int64,
- device=self.device)
- self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
- dtype=torch.int32,
- device=self.device)
- self.seq_lens = torch.zeros(self.max_num_reqs,
- dtype=torch.int32,
- device=self.device)
-
- if self.vllm_config.model_config.use_mla and \
- self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+ self.input_ids = torch.zeros(
+ self.max_num_tokens, dtype=torch.int32, device=self.device
+ )
+ self.positions = torch.zeros(
+ self.max_num_tokens, dtype=torch.int64, device=self.device
+ )
+ self.query_start_loc = torch.zeros(
+ self.max_num_reqs + 1, dtype=torch.int32, device=self.device
+ )
+ self.seq_lens = torch.zeros(
+ self.max_num_reqs, dtype=torch.int32, device=self.device
+ )
+
+ if (
+ self.vllm_config.model_config.use_mla
+ and self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+ ):
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
- self.cos = torch.ones(self.max_num_reqs *
- self.decode_token_per_req,
- 1,
- 1,
- rope_dim,
- dtype=self.dtype,
- device=self.device)
- self.sin = torch.zeros(self.max_num_reqs *
- self.decode_token_per_req,
- 1,
- 1,
- rope_dim,
- dtype=self.dtype,
- device=self.device)
+ self.cos = torch.ones(
+ self.max_num_reqs * self.decode_token_per_req,
+ 1,
+ 1,
+ rope_dim,
+ dtype=self.dtype,
+ device=self.device,
+ )
+ self.sin = torch.zeros(
+ self.max_num_reqs * self.decode_token_per_req,
+ 1,
+ 1,
+ rope_dim,
+ dtype=self.dtype,
+ device=self.device,
+ )
else:
self.cos = None
self.sin = None
@@ -433,80 +440,80 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
- self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
- dtype=torch.int64,
- device=self.device)
+ self.mrope_positions = torch.zeros(
+ (3, self.max_num_tokens + 1), dtype=torch.int64, device=self.device
+ )
self.mrope_positions_cpu = torch.zeros(
(3, self.max_num_tokens + 1),
dtype=torch.int64,
device="cpu",
- pin_memory=True)
+ pin_memory=True,
+ )
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
# OPTIMIZATION: Cache the tensors rather than creating them every step.
- self.arange_np: npt.NDArray[np.int32] = np.arange(max(
- self.max_num_reqs + 1, self.model_config.max_model_len,
- self.max_num_tokens),
- dtype=np.int32)
+ self.arange_np: npt.NDArray[np.int32] = np.arange(
+ max(
+ self.max_num_reqs + 1,
+ self.model_config.max_model_len,
+ self.max_num_tokens,
+ ),
+ dtype=np.int32,
+ )
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
- self.input_ids_cpu = torch.zeros(self.max_num_tokens,
- dtype=torch.int32,
- device="cpu",
- pin_memory=True)
- self.positions_cpu = torch.zeros(self.max_num_tokens,
- dtype=torch.int64,
- device="cpu",
- pin_memory=True)
+ self.input_ids_cpu = torch.zeros(
+ self.max_num_tokens, dtype=torch.int32, device="cpu", pin_memory=True
+ )
+ self.positions_cpu = torch.zeros(
+ self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=True
+ )
self.positions_np = self.positions_cpu.numpy()
- self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
- dtype=torch.int32,
- device="cpu",
- pin_memory=True)
+ self.query_start_loc_cpu = torch.zeros(
+ self.max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=True
+ )
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
- self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
- dtype=torch.int32,
- device="cpu",
- pin_memory=True)
+ self.seq_lens_cpu = torch.zeros(
+ self.max_num_reqs, dtype=torch.int32, device="cpu", pin_memory=True
+ )
self.seq_lens_np = self.seq_lens_cpu.numpy()
- self.pcp_allgather_restore_idx = torch.zeros(self.max_num_tokens,
- dtype=torch.int32,
- device=self.device)
+ self.pcp_allgather_restore_idx = torch.zeros(
+ self.max_num_tokens, dtype=torch.int32, device=self.device
+ )
self.num_pcp_pads = torch.zeros(self.max_num_reqs, dtype=torch.int32)
- self.pcp_padded_slot_mapping = torch.zeros(self.max_num_tokens,
- dtype=torch.int32,
- device=self.device)
+ self.pcp_padded_slot_mapping = torch.zeros(
+ self.max_num_tokens, dtype=torch.int32, device=self.device
+ )
self.num_actual_tokens_pcp_padded = 0
if self.speculative_config and self.pcp_size > 1:
- self.input_ids_pcp_full = torch.zeros(self.max_num_tokens,
- dtype=torch.int32,
- device="cpu",
- pin_memory=True)
- self.query_start_loc_pcp_full = torch.zeros(self.max_num_reqs + 1,
- dtype=torch.int32,
- device="cpu",
- pin_memory=True)
- self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy(
+ self.input_ids_pcp_full = torch.zeros(
+ self.max_num_tokens, dtype=torch.int32, device="cpu", pin_memory=True
+ )
+ self.query_start_loc_pcp_full = torch.zeros(
+ self.max_num_reqs + 1, dtype=torch.int32, device="cpu", pin_memory=True
+ )
+ self.query_start_loc_pcp_full_np = self.query_start_loc_pcp_full.numpy()
+ self.positions_pcp_full = torch.zeros(
+ self.max_num_tokens, dtype=torch.int64, device="cpu", pin_memory=True
)
- self.positions_pcp_full = torch.zeros(self.max_num_tokens,
- dtype=torch.int64,
- device="cpu",
- pin_memory=True)
self.positions_np_pcp_full = self.positions_pcp_full.numpy()
self.use_aclgraph = self._use_aclgraph()
self.aclgraph_batch_sizes = list(
- reversed(self.compilation_config.cudagraph_capture_sizes))
+ reversed(self.compilation_config.cudagraph_capture_sizes)
+ )
- self.uniform_decode_query_len = 1 if not self.speculative_config else \
- 1 + self.speculative_config.num_speculative_tokens
+ self.uniform_decode_query_len = (
+ 1
+ if not self.speculative_config
+ else 1 + self.speculative_config.num_speculative_tokens
+ )
# aclgraph dispatcher for runtime aclgraph dispatching.
self.aclgraph_dispatcher = CudagraphDispatcher(self.vllm_config)
# Cached outputs.
- self._draft_token_ids: Optional[Union[list[list[int]],
- torch.Tensor]] = None
+ self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
@@ -520,31 +527,36 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
)
else:
self.reserved_mc2_mask = None
- self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
+ self.dynamic_eplb = (
+ self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
+ )
if self.dynamic_eplb:
EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb)
EPLBParamUtils.check_expert_map_record_path(
- self.ascend_config.expert_map_record_path)
+ self.ascend_config.expert_map_record_path
+ )
self.is_eplb_warmuped = False
self.policy_type = self.ascend_config.eplb_policy_type
self.eplb_loader = D2DExpertWeightLoader()
self.manager = Manager()
- self.shared_dict = self.manager.dict({
- "expert_map": None,
- "moe_load": None,
- "expert_maps": None
- })
- self.eplb_process = EplbProcess(shared_dict=self.shared_dict,
- policy_type=self.policy_type,
- enable_d2d=True)
+ self.shared_dict = self.manager.dict(
+ {"expert_map": None, "moe_load": None, "expert_maps": None}
+ )
+ self.eplb_process = EplbProcess(
+ shared_dict=self.shared_dict,
+ policy_type=self.policy_type,
+ enable_d2d=True,
+ )
self.process = self.eplb_process._launch_process()
ascend_config = get_ascend_config()
- self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader,
- self.eplb_process, self.process)
+ self.eplb_updator = EplbUpdator(
+ ascend_config, self.eplb_loader, self.eplb_process, self.process
+ )
self.use_async_scheduling = self.scheduler_config.async_scheduling
- self.async_output_copy_stream = torch.npu.Stream() if \
- self.use_async_scheduling else None
+ self.async_output_copy_stream = (
+ torch.npu.Stream() if self.use_async_scheduling else None
+ )
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
@@ -564,61 +576,75 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
block_sizes=[self.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
- self.vllm_config, self.device, self.pin_memory,
+ self.vllm_config,
+ self.device,
+ self.pin_memory,
self.is_pooling_model,
- self.vllm_config.model_config.logits_processors),
+ self.vllm_config.model_config.logits_processors,
+ ),
is_pooling_model=self.is_pooling_model,
kernel_block_sizes=[[self.vllm_config.cache_config.block_size]],
- cp_kv_cache_interleave_size=self.parallel_config.
- cp_kv_cache_interleave_size
- if prefill_context_parallel_enable() else 1,
+ cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size
+ if prefill_context_parallel_enable()
+ else 1,
+ )
+ self.num_accepted_tokens = self._make_buffer(
+ self.max_num_reqs, dtype=torch.int64
)
- self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
- dtype=torch.int64)
- self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
- dtype=torch.int32)
+ self.num_draft_tokens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
# Only relevant for multimodal models
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
- self.model_config)
+ self.model_config
+ )
if self.supports_mm_inputs:
- self.is_mm_embed = self._make_buffer(self.max_num_tokens,
- dtype=torch.bool)
+ self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
# TODO: EVS Support (Video tokens pruning) (see vllm#22980)
self.is_multimodal_pruning_enabled = False
def _set_up_drafter(self):
# Set up speculative decoding.
self.spec_attn_mask = None
- self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
- TorchairMtpProposer]] = None
+ self.drafter: Optional[
+ Union[
+ NgramProposer,
+ EagleProposer,
+ MtpProposer,
+ TorchairMtpProposer,
+ DraftModelProposer,
+ ]
+ ] = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
assert spec_token_num > 0
self.decode_token_per_req = 1 + spec_token_num
- self.spec_attn_mask = torch.triu(torch.ones(2048,
- 2048,
- dtype=torch.bool),
- diagonal=1).to(self.device)
+ self.spec_attn_mask = torch.triu(
+ torch.ones(2048, 2048, dtype=torch.bool), diagonal=1
+ ).to(self.device)
if get_pp_group().is_last_rank:
self.drafter = self._get_drafter()
if vllm_version_is("0.11.0"):
self.rejection_sampler = AscendRejectionSampler()
else:
- self.rejection_sampler = AscendRejectionSampler(
- self.sampler)
+ self.rejection_sampler = AscendRejectionSampler(self.sampler)
self.actual_seq_lengths_q = list(
- range(self.decode_token_per_req, self.max_num_tokens + 1,
- self.decode_token_per_req))
- self.discard_request_indices = self._make_buffer(self.max_num_reqs,
- dtype=torch.int64)
+ range(
+ self.decode_token_per_req,
+ self.max_num_tokens + 1,
+ self.decode_token_per_req,
+ )
+ )
+ self.discard_request_indices = self._make_buffer(
+ self.max_num_reqs, dtype=torch.int64
+ )
self.num_discarded_requests = 0
def _get_drafter(self):
- return get_spec_decode_method(self.speculative_config.method,
- self.vllm_config, self.device, self)
+ return get_spec_decode_method(
+ self.speculative_config.method, self.vllm_config, self.device, self
+ )
def _may_pad_kv_consumer_num_seq(self):
# For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
@@ -634,28 +660,29 @@ def _init_mc2_tokens_capacity(self):
max_num_tokens = self.compilation_config.cudagraph_capture_sizes[0]
else:
# NOTE: To save memory, we cap the max number of tokens to 512.
- max_num_tokens = min(
- self.max_num_reqs * self.uniform_decode_query_len, 512)
+ max_num_tokens = min(self.max_num_reqs * self.uniform_decode_query_len, 512)
tp_size = self.parallel_config.tensor_parallel_size
# Use integer arithmetic for ceiling division.
num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
self.mc2_tokens_capacity: int = num_tokens_per_tp_rank * tp_size
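# Editor's illustrative sketch, not part of the patch: the two lines above round
# max_num_tokens up to the next multiple of tp_size via integer ceiling division,
# so the MC2 token capacity divides evenly across tensor-parallel ranks. The
# helper name below is hypothetical and exists only for this example.
def _round_up_to_multiple(n: int, multiple: int) -> int:
    # (n + multiple - 1) // multiple is the usual integer ceiling division.
    return ((n + multiple - 1) // multiple) * multiple

assert _round_up_to_multiple(512, 8) == 512
assert _round_up_to_multiple(510, 8) == 512  # rounded up, never truncated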
- def _make_buffer(self,
- *size: Union[int, torch.SymInt],
- dtype: torch.dtype,
- numpy: bool = True) -> CpuGpuBuffer:
+ def _make_buffer(
+ self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True
+ ) -> CpuGpuBuffer:
# Bfloat16 torch tensors cannot be directly cast to a numpy array, so
# if a bfloat16 buffer is needed without a corresponding numpy array,
# don't bother instantiating the numpy array.
- return CpuGpuBuffer(*size,
- dtype=dtype,
- device=self.device,
- pin_memory=self.pin_memory,
- with_numpy=numpy)
+ return CpuGpuBuffer(
+ *size,
+ dtype=dtype,
+ device=self.device,
+ pin_memory=self.pin_memory,
+ with_numpy=numpy,
+ )
def _update_states_after_model_execute(
- self, output_token_ids: torch.Tensor) -> None:
+ self, output_token_ids: torch.Tensor
+ ) -> None:
"""Update the cached states after model execution.
This is used for MTP/EAGLE for hybrid models, as in linear attention,
@@ -668,22 +695,42 @@ def _update_states_after_model_execute(
return
# Find the number of accepted tokens for each sequence.
- num_accepted_tokens = (torch.cat(
- [
- output_token_ids,
- torch.full((output_token_ids.size(0), 1),
- -1,
- device=output_token_ids.device),
- ],
- dim=1) == -1).int().argmax(-1).cpu().numpy()
+ num_accepted_tokens = (
+ (
+ torch.cat(
+ [
+ output_token_ids,
+ torch.full(
+ (output_token_ids.size(0), 1),
+ -1,
+ device=output_token_ids.device,
+ ),
+ ],
+ dim=1,
+ )
+ == -1
+ )
+ .int()
+ .argmax(-1)
+ .cpu()
+ .numpy()
+ )
for i, num_tokens in enumerate(num_accepted_tokens):
self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
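# Editor's illustrative sketch, not part of the patch: the chained ops above count
# accepted tokens per request by appending a sentinel -1 column and taking argmax
# of the `== -1` mask, i.e. the index of the first rejected slot. A standalone
# reproduction with toy data:
import torch

out = torch.tensor([[11, 12, -1, -1],   # 2 accepted tokens
                    [21, -1, -1, -1],   # 1 accepted token
                    [31, 32, 33, 34]])  # all 4 accepted
sentinel = torch.full((out.size(0), 1), -1)
num_accepted = (torch.cat([out, sentinel], dim=1) == -1).int().argmax(-1)
assert num_accepted.tolist() == [2, 1, 4]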
def _use_aclgraph(self) -> bool:
if vllm_version_is("0.11.0"):
- return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager
+ return (
+ self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+ and self.compilation_config.level == CompilationLevel.PIECEWISE
+ and not self.model_config.enforce_eager
+ )
else:
- return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager
+ return (
+ self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+ and self.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ and not self.model_config.enforce_eager
+ )
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove finished requests from the cached states.
@@ -722,8 +769,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
- if sampling_params and \
- sampling_params.sampling_type == SamplingType.RANDOM_SEED:
+ if (
+ sampling_params
+ and sampling_params.sampling_type == SamplingType.RANDOM_SEED
+ ):
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
@@ -731,7 +780,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
if pooling_params:
assert (task := pooling_params.task) is not None, (
- "You did not set `task` in the API")
+ "You did not set `task` in the API"
+ )
model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
@@ -778,21 +828,20 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens.
- num_new_tokens = (num_computed_tokens + len(new_token_ids) -
- req_state.num_tokens)
+ num_new_tokens = (
+ num_computed_tokens + len(new_token_ids) - req_state.num_tokens
+ )
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
- req_state.output_token_ids.extend(
- new_token_ids[-num_new_tokens:])
+ req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
# Update the block IDs.
if not resumed_from_preemption:
if new_block_ids is not None:
# Append the new blocks to the existing block IDs.
- for block_ids, new_ids in zip(req_state.block_ids,
- new_block_ids):
+ for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
block_ids.extend(new_ids)
else:
assert new_block_ids is not None
@@ -809,11 +858,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
continue
# Update the persistent batch.
- self.input_batch.num_computed_tokens_cpu[req_index] = (
- num_computed_tokens)
+ self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
if new_block_ids is not None:
- self.input_batch.block_table.append_row(
- new_block_ids, req_index)
+ self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu
# because the sampled tokens are already cached.
@@ -822,21 +869,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
- req_index,
- start_token_index:end_token_index] = new_token_ids
- self.input_batch.num_tokens_no_spec[
- req_index] = end_token_index
+ req_index, start_token_index:end_token_index
+ ] = new_token_ids
+ self.input_batch.num_tokens_no_spec[req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
- spec_token_ids = (
- scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
+ spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+ req_id, ()
+ )
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[
- req_index, start_index:end_token_index] = spec_token_ids
+ req_index, start_index:end_token_index
+ ] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] += num_spec_tokens
@@ -877,7 +925,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
use_audio_in_video = True
if vllm_version_is("0.11.0"):
- req_state.mrope_positions, req_state.mrope_position_delta = \
+ req_state.mrope_positions, req_state.mrope_position_delta = (
MRotaryEmbedding.get_input_positions_tensor(
req_state.prompt_token_ids,
hf_config=self.model_config.hf_config,
@@ -887,9 +935,10 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
+ )
else:
if supports_mrope(self.model):
- req_state.mrope_positions, req_state.mrope_position_delta = \
+ req_state.mrope_positions, req_state.mrope_position_delta = (
self.model.get_mrope_input_positions(
req_state.prompt_token_ids,
hf_config=self.model_config.hf_config,
@@ -899,10 +948,11 @@ def _init_mrope_positions(self, req_state: CachedRequestState):
audio_feature_lengths=audio_feature_lengths,
use_audio_in_video=use_audio_in_video,
)
+ )
def _sync_metadata_across_dp(
- self, num_tokens: int,
- with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
+ self, num_tokens: int, with_prefill: bool
+ ) -> tuple[int, Optional[torch.Tensor], bool]:
# TODO: In vLLM, the only thing that needs to be synced is num_tokens, but in
# our case, we still need to sync the other two flags as well. So we need to
# include them in the all_reduce operation, and more over, we CANNOT skip it
@@ -913,15 +963,15 @@ def _sync_metadata_across_dp(
return num_tokens, None, with_prefill
# Sync num_tokens, with_prefill across dp ranks
- num_tokens_tensor = torch.tensor([
- num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
- ],
- dtype=torch.int32,
- device="npu")
+ num_tokens_tensor = torch.tensor(
+ [num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)],
+ dtype=torch.int32,
+ device="npu",
+ )
- flags_tensor = torch.tensor([int(with_prefill)],
- dtype=torch.int32,
- device="npu")
+ flags_tensor = torch.tensor(
+ [int(with_prefill)], dtype=torch.int32, device="npu"
+ )
packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])
@@ -935,10 +985,9 @@ def _sync_metadata_across_dp(
global_with_prefill = bool(synced_flags[0])
# Create a tensor for num_tokens_after_padding
- num_tokens_after_padding = torch.tensor([max_tokens_across_dp] *
- self.dp_size,
- device="cpu",
- dtype=torch.int32)
+ num_tokens_after_padding = torch.tensor(
+ [max_tokens_across_dp] * self.dp_size, device="cpu", dtype=torch.int32
+ )
return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill
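# Editor's illustrative sketch, not part of the patch: the method above packs the
# per-rank token count and the with_prefill flag into one int32 tensor so a single
# all_reduce syncs both across DP ranks (zeros in the other ranks' slots mean a
# SUM reduction recovers every rank's value, and any nonzero flag marks prefill).
# Assumes an initialized process group; `dp_size`, `dp_rank` and `pg` are
# placeholders for this sketch only.
import torch
import torch.distributed as dist

def sync_dp_metadata(num_tokens: int, with_prefill: bool,
                     dp_size: int, dp_rank: int, pg=None):
    counts = torch.tensor(
        [num_tokens if i == dp_rank else 0 for i in range(dp_size)],
        dtype=torch.int32)
    flags = torch.tensor([int(with_prefill)], dtype=torch.int32)
    packed = torch.cat([counts, flags])
    dist.all_reduce(packed, group=pg)  # default op: SUM
    return int(packed[:dp_size].max().item()), bool(packed[dp_size].item())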
@@ -973,21 +1022,28 @@ def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
return tuple(tasks)
- def _make_attention_mask(self, seq_lens, position,
- attn_state) -> torch.Tensor:
+ def _make_attention_mask(self, seq_lens, position, attn_state) -> torch.Tensor:
if self.pcp_size > 1:
return None
if self.attn_mask_builder is None:
raise ValueError("Attn mask builder is None")
# Pooling situation.
- if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
+ if (
+ self.model_config.runner_type == "pooling"
+ and self.model_config.pooler_config.pooling_type == "CLS"
+ ):
return self.attn_mask_builder.get_pooling_mask(self.device)
# Chunk Prefill situation.
- elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
+ elif (
+ attn_state == AscendAttentionState.ChunkedPrefill
+ and not self.vllm_config.model_config.use_mla
+ and not self.use_sparse
+ ):
if self.dcp_size > 1:
max_seq_len = max(seq_lens.max().item(), 0)
return self.attn_mask_builder.get_attn_mask(
- max_seq_len, self.dtype, self.device)
+ max_seq_len, self.dtype, self.device
+ )
else:
return self.attn_mask_builder.get_splitfuse_attn_mask()
@@ -995,11 +1051,11 @@ def _make_attention_mask(self, seq_lens, position,
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens.max().item(), 0)
return self.attn_mask_builder.get_attn_mask(
- max_seq_len, self.dtype, self.device)
+ max_seq_len, self.dtype, self.device
+ )
# Prefill with cache hit.
elif attn_state == AscendAttentionState.PrefillCacheHit:
- return self.attn_mask_builder.get_attn_mask(
- 2048, self.dtype, self.device)
+ return self.attn_mask_builder.get_attn_mask(2048, self.dtype, self.device)
# Decode-only situation.
else:
return None
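# Editor's illustrative sketch, not part of the patch: the masks returned above are
# upper-triangular "no peeking at future tokens" masks, like the 2048x2048 spec
# decode mask built in _set_up_drafter. On a 4-token toy sequence:
import torch

mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
assert mask[0].tolist() == [False, True, True, True]     # token 0 sees only itself
assert mask[3].tolist() == [False, False, False, False]  # last token sees everything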
@@ -1010,18 +1066,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
req = self.requests[req_id]
assert req.mrope_positions is not None
- num_computed_tokens = \
- self.input_batch.num_computed_tokens_cpu[index]
- num_scheduled_tokens = \
- scheduler_output.num_scheduled_tokens[req_id]
+ num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
- req.prompt_token_ids, req.prompt_embeds)
+ req.prompt_token_ids, req.prompt_embeds
+ )
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
- prompt_part_len = max(0,
- num_prompt_tokens - num_computed_tokens)
- completion_part_len = max(
- 0, num_scheduled_tokens - prompt_part_len)
+ prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
else:
prompt_part_len = num_scheduled_tokens
completion_part_len = 0
@@ -1035,8 +1088,9 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
- self.mrope_positions_cpu[:, dst_start:dst_end] = \
- req.mrope_positions[:, src_start:src_end]
+ self.mrope_positions_cpu[:, dst_start:dst_end] = req.mrope_positions[
+ :, src_start:src_end
+ ]
mrope_pos_ptr += prompt_part_len
@@ -1061,7 +1115,8 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
# Batch the multi-modal inputs.
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
- scheduler_output)
+ scheduler_output
+ )
encoder_outputs = []
if vllm_version_is("0.11.0"):
@@ -1086,8 +1141,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
- curr_group_outputs = self.model.get_multimodal_embeddings(
- **mm_kwargs_group)
+ curr_group_outputs = self.model.get_multimodal_embeddings(**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
@@ -1140,19 +1194,20 @@ def _gather_mm_embeddings_0110(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
-
def _iter_mm_features(req_state: CachedRequestState):
assert req_state.mm_features is not None
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
- yield mm_feature.identifier, pos_info, getattr(
- pos_info, "is_embed", None)
+ yield (
+ mm_feature.identifier,
+ pos_info,
+ getattr(pos_info, "is_embed", None),
+ )
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
- num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
- req_id]
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
@@ -1173,8 +1228,7 @@ def _iter_mm_features(req_state: CachedRequestState):
assert start_idx < end_idx
encoder_output = self.encoder_cache.get(mm_hash, None)
- assert encoder_output is not None, \
- f"Encoder cache miss for {mm_hash}."
+ assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if is_embed is not None:
is_embed = is_embed[start_idx:end_idx]
@@ -1202,11 +1256,9 @@ def _gather_mm_embeddings(
for req_id in self.input_batch.req_ids:
mm_embeds_req: list[torch.Tensor] = []
- num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
- req_id]
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_state = self.requests[req_id]
- num_computed_tokens = \
- req_state.num_computed_tokens + shift_computed_tokens
+ num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens
for mm_feature in req_state.mm_features: # type: ignore
pos_info = mm_feature.mm_position
@@ -1234,15 +1286,15 @@ def _gather_mm_embeddings(
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
- assert encoder_output is not None,\
- f"Encoder cache miss for {mm_hash}."
+ assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
- is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
- = True if is_embed is None else is_embed
+ is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
+ True if is_embed is None else is_embed
+ )
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
@@ -1277,8 +1329,9 @@ def _get_cumsum_and_arange(
return cu_num_tokens, arange
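# Editor's illustrative sketch, not part of the patch: the cumsum/arange pair used
# throughout _prepare_inputs turns per-request token counts into request-local
# offsets, e.g. [2, 5, 3] -> cu_num_tokens [2, 7, 10] and
# arange [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]:
import numpy as np

num_tokens = np.array([2, 5, 3], dtype=np.int32)
cu_num_tokens = np.cumsum(num_tokens)
# Repeat each request's starting offset, then subtract it from a global arange.
starts = np.repeat(cu_num_tokens - num_tokens, num_tokens)
arange = np.arange(cu_num_tokens[-1]) - starts
assert cu_num_tokens.tolist() == [2, 7, 10]
assert arange.tolist() == [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]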
- def _prepare_input_ids(self, total_num_scheduled_tokens: int,
- cu_num_tokens: np.ndarray) -> None:
+ def _prepare_input_ids(
+ self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray
+ ) -> None:
"""Prepare the input IDs for the current batch.
Carefully handles the `prev_sampled_token_ids` which can be cached
@@ -1288,8 +1341,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case
self.input_ids[:total_num_scheduled_tokens].copy_(
- self.input_ids_cpu[:total_num_scheduled_tokens],
- non_blocking=True)
+ self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True
+ )
if self.is_multimodal_model or self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
@@ -1311,15 +1364,15 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
# last token in each common request.
flattened_index = cu_num_tokens[cur_index].item() - 1
flattened_indices.append(flattened_index)
- indices_match &= (prev_index == flattened_index)
+ indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_commmon_tokens = len(flattened_indices)
if num_commmon_tokens < total_num_scheduled_tokens:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the NPU first.
self.input_ids[:total_num_scheduled_tokens].copy_(
- self.input_ids_cpu[:total_num_scheduled_tokens],
- non_blocking=True)
+ self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True
+ )
if self.is_multimodal_model or self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
@@ -1333,26 +1386,26 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
# The indices are both the same permutation of 0..N-1 so
# we can copy directly using a single slice.
self.input_ids[:num_commmon_tokens].copy_(
- self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
- 0],
- non_blocking=True)
+ self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0],
+ non_blocking=True,
+ )
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
- input_ids_index_tensor = torch.tensor(flattened_indices,
- dtype=torch.int64,
- pin_memory=self.pin_memory).to(
- self.device,
- non_blocking=True)
+ input_ids_index_tensor = torch.tensor(
+ flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
+ ).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
- prev_common_req_indices,
- dtype=torch.int64,
- pin_memory=self.pin_memory).to(self.device, non_blocking=True)
- self.input_ids.scatter_(dim=0,
- index=input_ids_index_tensor,
- src=self.input_batch.prev_sampled_token_ids[
- prev_common_req_indices_tensor, 0])
+ prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
+ ).to(self.device, non_blocking=True)
+ self.input_ids.scatter_(
+ dim=0,
+ index=input_ids_index_tensor,
+ src=self.input_batch.prev_sampled_token_ids[
+ prev_common_req_indices_tensor, 0
+ ],
+ )
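# Editor's illustrative sketch, not part of the patch: under async scheduling the
# scatter above writes each continuing request's previously sampled token (still on
# device) into its last-token slot of the flattened input_ids, skipping a host
# round-trip. Toy shapes only:
import torch

input_ids = torch.zeros(10, dtype=torch.int64)
prev_sampled = torch.tensor([[101], [202], [303]])  # one sampled token per request
last_token_slots = torch.tensor([1, 6, 9])          # i.e. cu_num_tokens - 1
input_ids.scatter_(dim=0, index=last_token_slots, src=prev_sampled[:, 0])
assert input_ids[[1, 6, 9]].tolist() == [101, 202, 303]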
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
@@ -1376,15 +1429,27 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
reorder_batch_to_split_decodes_and_prefills(
self.input_batch,
scheduler_output,
- decode_threshold=self.reorder_batch_threshold)
+ decode_threshold=self.reorder_batch_threshold,
+ )
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
- ) -> tuple[dict[str, Any], torch.Tensor, np.ndarray, int, torch.Tensor,
- int, torch.Tensor, SpecDecodeMetadata, Optional[torch.Tensor],
- Optional[torch.Tensor], Optional[torch.Tensor], int]:
+ ) -> tuple[
+ dict[str, Any],
+ torch.Tensor,
+ np.ndarray,
+ int,
+ torch.Tensor,
+ int,
+ torch.Tensor,
+ SpecDecodeMetadata,
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ int,
+ ]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -1399,58 +1464,75 @@ def _prepare_inputs(
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
- req_indices = np.repeat(self.arange_np[:num_reqs],
- num_scheduled_tokens)
+ req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
_, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
positions_np = np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
)
- self.input_batch.block_table.compute_slot_mapping(
- req_indices, positions_np)
- self.input_batch.block_table.commit_slot_mapping(
- total_num_scheduled_tokens)
- tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
- tokens)
- num_scheduled_tokens = np.array(tokens, dtype=np.int32)
- # update total_num_scheduled_tokens
- total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+ self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
+ self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
+ if self.pcp_size > 1:
+     tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(tokens)
+     num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+     # update total_num_scheduled_tokens
+     total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
+ else:
+     position_pcp, pcp_unpad_mask = None, None
+     self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
total_num_pcp_pads = sum(self.num_pcp_pads)
max_num_scheduled_tokens = max(tokens)
- num_valid_tokens = np.array([
- num_tokens -
- len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
- for num_tokens, i in zip(tokens, req_ids)
- ],
- dtype=np.int32)
-
- if (self.use_aclgraph and total_num_scheduled_tokens
- <= self.aclgraph_batch_sizes[-1]):
+ num_valid_tokens = np.array(
+ [
+ num_tokens
+ - len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+ for num_tokens, i in zip(tokens, req_ids)
+ ],
+ dtype=np.int32,
+ )
+
+ if (
+ self.use_aclgraph
+ and total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]
+ ):
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
- total_num_scheduled_tokens)
+ total_num_scheduled_tokens
+ )
elif self.use_aclgraph and enable_sp(self.vllm_config):
# When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
# the model will fall back to running its FX graph in eager mode.
# In this case, when sequence parallelism is enabled, we need to pad tokens to align
# with tp_size because pad_size cannot be captured by the FX graph
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
- num_input_tokens = math.ceil(
- total_num_scheduled_tokens / tp_size) * tp_size
+ num_input_tokens = math.ceil(total_num_scheduled_tokens / tp_size) * tp_size
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
# Get the attention state.
- attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
- num_valid_tokens)
+ attn_state = self._build_attn_state(
+ num_reqs, num_scheduled_tokens, num_valid_tokens
+ )
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [
- AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
+ AscendAttentionState.DecodeOnly,
+ AscendAttentionState.SpecDecoding,
]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
@@ -1458,9 +1540,9 @@ def _prepare_inputs(
# Get info across DP ranks.
# NOTE: maybe_padded_num_tokens is only used when using TorchAir with DP,
# Otherwise, it's just max_tokens_across_dp_cpu
- (maybe_padded_num_tokens, num_tokens_across_dp,
- with_prefill) = self._sync_metadata_across_dp(num_input_tokens,
- with_prefill)
+ (maybe_padded_num_tokens, num_tokens_across_dp, with_prefill) = (
+ self._sync_metadata_across_dp(num_input_tokens, with_prefill)
+ )
# TODO: Now that num_input_tokens is basically identical with maybe_padded_num_tokens
# We should consider removing maybe_padded_num_tokens later
@@ -1472,19 +1554,19 @@ def _prepare_inputs(
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
- req_indices = np.repeat(self.arange_np[:num_reqs],
- num_scheduled_tokens)
+ req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
- cu_num_tokens, arange = self._get_cumsum_and_arange(
- num_scheduled_tokens)
+ cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
if self.pcp_size > 1:
positions_np = self.positions_np[:total_num_scheduled_tokens]
- np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
- position_pcp[:total_num_scheduled_tokens],
- out=positions_np)
+ np.add(
+ self.input_batch.num_computed_tokens_cpu[req_indices],
+ position_pcp[:total_num_scheduled_tokens],
+ out=positions_np,
+ )
else:
self.positions_np[:total_num_scheduled_tokens] = positions_np
@@ -1496,35 +1578,41 @@ def _prepare_inputs(
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
- non_blocking=True)
+ non_blocking=True,
+ )
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
- token_indices = (positions_np +
- req_indices * self.input_batch.token_ids_cpu.shape[1])
+ token_indices = (
+ positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
+ )
token_indices_tensor = torch.from_numpy(token_indices)
# Prepare input_ids.
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
- torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
- 0,
- token_indices_tensor,
- out=self.input_ids_cpu[:total_num_scheduled_tokens])
+ torch.index_select(
+ self.input_batch.token_ids_cpu_tensor.flatten(),
+ 0,
+ token_indices_tensor,
+ out=self.input_ids_cpu[:total_num_scheduled_tokens],
+ )
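# Editor's illustrative sketch, not part of the patch: token_indices maps each
# (request, position) pair into the flattened token_ids_cpu buffer, whose rows are
# max_model_len wide, so a single index_select gathers all scheduled token ids.
# With a toy max_model_len M = 100 and two requests:
import numpy as np

M = 100
req_indices = np.array([0, 0, 1, 1, 1])  # 2 tokens for req 0, 3 for req 1
positions = np.array([0, 1, 0, 1, 2])
token_indices = positions + req_indices * M
assert token_indices.tolist() == [0, 1, 100, 101, 102]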
is_token_ids = self.input_batch.is_token_ids.flatten()
torch.index_select(
is_token_ids,
0,
token_indices_tensor,
- out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
+ out=self.is_token_ids.cpu[:total_num_scheduled_tokens],
+ )
# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
# the InputBatch, we need to fill in the prompt embeds into the expected
# spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
- if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or
- self.enable_prompt_embeds):
+ if self.input_batch.req_prompt_embeds and (
+ self.is_multimodal_model or self.enable_prompt_embeds
+ ):
output_idx = 0
for req_idx in range(num_reqs):
num_sched = num_scheduled_tokens[req_idx]
@@ -1553,26 +1641,25 @@ def _prepare_inputs(
actual_num_sched = actual_end - start_pos
if actual_num_sched > 0:
- self.inputs_embeds.cpu[output_idx:output_idx +
- actual_num_sched].copy_(
- req_embeds[start_pos:actual_end]
- )
+ self.inputs_embeds.cpu[
+ output_idx : output_idx + actual_num_sched
+ ].copy_(req_embeds[start_pos:actual_end])
output_idx += num_sched
self.query_start_loc_np[0] = 0
- self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
- self.query_start_loc[:num_reqs + 1].copy_(
- self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
+ self.query_start_loc_np[1 : num_reqs + 1] = cu_num_tokens
+ self.query_start_loc[: num_reqs + 1].copy_(
+ self.query_start_loc_cpu[: num_reqs + 1], non_blocking=True
+ )
self.seq_lens_np[:num_reqs] = (
- self.input_batch.num_computed_tokens_cpu[:num_reqs] +
- num_scheduled_tokens)
- self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
- non_blocking=True)
+ self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
+ )
+ self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
- self.query_start_loc[num_reqs + 1:].fill_(-1)
+ self.query_start_loc[num_reqs + 1 :].fill_(-1)
self.seq_lens[num_reqs:].fill_(0)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
@@ -1581,18 +1668,20 @@ def _prepare_inputs(
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
- self.positions_cpu[:num_input_tokens], non_blocking=True)
+ self.positions_cpu[:num_input_tokens], non_blocking=True
+ )
# Make Attention metadata
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
- attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
- num_valid_tokens)
- self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
- position=positions_cpu,
- attn_state=attn_state)
+ attn_state = self._build_attn_state(
+ num_reqs, num_scheduled_tokens, num_valid_tokens
+ )
+ self.attn_mask = self._make_attention_mask(
+ seq_lens=seq_lens_cpu, position=positions_cpu, attn_state=attn_state
+ )
self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill
@@ -1602,9 +1691,7 @@ def _prepare_inputs(
# Record the index of requests that should not be sampled,
# so that we could clear the sampled tokens before returning
- num_tokens = [
- self.requests[r].num_tokens for r in self.input_batch.req_ids
- ]
+ num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
num_reqs = self.input_batch.num_reqs
if self.pcp_size == 1:
@@ -1612,14 +1699,15 @@ def _prepare_inputs(
else:
# while pcp > 1, we need the original num_scheduled_tokens before split
# to calculate discard_requests_mask
- original_seq_lens_np = (
- self.input_batch.num_computed_tokens_cpu[:num_reqs] +
- np.array(list(scheduler_output.num_scheduled_tokens.values())))
+ original_seq_lens_np = self.input_batch.num_computed_tokens_cpu[
+ :num_reqs
+ ] + np.array(list(scheduler_output.num_scheduled_tokens.values()))
discard_requests_mask = original_seq_lens_np < num_tokens_np
discard_request_indices = np.nonzero(discard_requests_mask)[0]
self.num_discarded_requests = len(discard_request_indices)
- self.discard_request_indices.np[:self.num_discarded_requests] = (
- discard_request_indices)
+ self.discard_request_indices.np[: self.num_discarded_requests] = (
+ discard_request_indices
+ )
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
# _prepare_inputs may reorder the batch, so we must gather
@@ -1636,12 +1724,12 @@ def _prepare_inputs(
mm_embeds = self._gather_mm_embeddings_0110(scheduler_output)
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
- input_ids, mm_embeds)
+ input_ids, mm_embeds
+ )
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
else:
- mm_embeds, is_mm_embed = self._gather_mm_embeddings(
- scheduler_output)
+ mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)
inputs_embeds = self.model.get_input_embeddings(
input_ids,
@@ -1650,8 +1738,7 @@ def _prepare_inputs(
)
# TODO(woosuk): Avoid the copy. Optimize.
- self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(
- inputs_embeds)
+ self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
input_ids = None
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
@@ -1667,14 +1754,15 @@ def _prepare_inputs(
# If a batch only has token ids, then including the embedding layer
# in the acl graph will be more performant (like in the else case
# below).
- token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens] \
- .nonzero(as_tuple=False) \
+ token_ids_idx = (
+ self.is_token_ids.gpu[:total_num_scheduled_tokens]
+ .nonzero(as_tuple=False)
.squeeze(1)
+ )
# Some tokens ids may need to become embeds
if token_ids_idx.numel() > 0:
token_ids = self.input_ids[token_ids_idx]
- tokens_to_embeds = self.model.get_input_embeddings(
- input_ids=token_ids)
+ tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids)
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
@@ -1688,8 +1776,12 @@ def _prepare_inputs(
inputs_embeds = None
positions = self.positions[:num_input_tokens]
input_ids, positions = self._update_input_ids_and_positions(
- input_ids, positions, num_input_tokens, with_prefill,
- maybe_padded_num_tokens)
+ input_ids,
+ positions,
+ num_input_tokens,
+ with_prefill,
+ maybe_padded_num_tokens,
+ )
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1698,14 +1790,13 @@ def _prepare_inputs(
assert self.intermediate_tensors is not None
for k, v in intermediate_tensors.items():
self.intermediate_tensors[k][:num_input_tokens].copy_(
- v[:num_input_tokens], non_blocking=True)
- intermediate_tensors = IntermediateTensors({
- k: v[:num_input_tokens]
- for k, v in self.intermediate_tensors.items()
- })
-
- use_spec_decode = len(
- scheduler_output.scheduled_spec_decode_tokens) > 0
+ v[:num_input_tokens], non_blocking=True
+ )
+ intermediate_tensors = IntermediateTensors(
+ {k: v[:num_input_tokens] for k, v in self.intermediate_tensors.items()}
+ )
+
+ use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
@@ -1713,10 +1804,24 @@ def _prepare_inputs(
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
spec_decode_metadata = None
- logits_indices = torch.from_numpy(
- cu_num_tokens
- ) * self.pcp_size - self.num_pcp_pads[:num_reqs] - 1
- logits_indices = logits_indices.to(self.device, non_blocking=True)
+ if self.pcp_size * self.dcp_size > 1:
+     logits_indices = (
+         torch.from_numpy(cu_num_tokens) * self.pcp_size
+         - self.num_pcp_pads[:num_reqs]
+         - 1
+     )
+     logits_indices = logits_indices.to(self.device, non_blocking=True)
+ else:
+     logits_indices = torch.from_numpy(cu_num_tokens - 1).to(
+         self.device, non_blocking=True
+     )
else:
# pcp not supported now
assert self.pcp_size == 1
@@ -1724,13 +1829,16 @@ def _prepare_inputs(
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
- for req_id, draft_token_ids in (
- scheduler_output.scheduled_spec_decode_tokens.items()):
+ for (
+ req_id,
+ draft_token_ids,
+ ) in scheduler_output.scheduled_spec_decode_tokens.items():
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
- num_draft_tokens, cu_num_tokens)
+ num_draft_tokens, cu_num_tokens
+ )
logits_indices = spec_decode_metadata.logits_indices
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
self.num_draft_tokens.np[num_reqs:].fill(0)
@@ -1738,34 +1846,40 @@ def _prepare_inputs(
# Used in the below loop.
# query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
- num_computed_tokens_cpu = (
- self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
+ num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
+ :num_reqs
+ ]
self.spec_decode_common_attn_metadata = None
if use_spec_decode and self.need_accepted_tokens:
self.num_accepted_tokens.np[:num_reqs] = (
- self.input_batch.num_accepted_tokens_cpu[:num_reqs])
+ self.input_batch.num_accepted_tokens_cpu[:num_reqs]
+ )
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
is_prefill = len(scheduler_output.scheduled_new_reqs) > 0
if self.speculative_config and self.pcp_size > 1 and is_prefill:
self._generate_pcp_mtp_input(
- num_reqs, scheduler_output.total_num_scheduled_tokens,
- scheduler_output.num_scheduled_tokens)
+ num_reqs,
+ scheduler_output.total_num_scheduled_tokens,
+ scheduler_output.num_scheduled_tokens,
+ )
# prepare pcp meta data
long_seq_metadata = self._generate_pcp_metadata(
- total_num_scheduled_tokens, seq_lens_cpu)
+ total_num_scheduled_tokens, seq_lens_cpu
+ )
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
- self.kv_cache_config.kv_cache_groups):
- slot_mapping_size = (total_num_scheduled_tokens
- if self.pcp_size == 1 else
- total_num_scheduled_tokens * self.pcp_size -
- total_num_pcp_pads)
- if isinstance(kv_cache_group_spec.kv_cache_spec,
- EncoderOnlyAttentionSpec):
+ self.kv_cache_config.kv_cache_groups
+ ):
+ slot_mapping_size = (
+ total_num_scheduled_tokens
+ if self.pcp_size == 1
+ else total_num_scheduled_tokens * self.pcp_size - total_num_pcp_pads
+ )
+ if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor = torch.zeros(
@@ -1774,7 +1888,7 @@ def _prepare_inputs(
device=self.device,
)
slot_mapping = torch.zeros(
- (total_num_scheduled_tokens, ),
+ (total_num_scheduled_tokens,),
dtype=torch.int64,
device=self.device,
)
@@ -1784,29 +1898,27 @@ def _prepare_inputs(
slot_mapping = blk_table.slot_mapping[:slot_mapping_size]
blk_table.slot_mapping[slot_mapping_size:].fill_(0)
if self.pcp_size > 1:
- slot_mapping_for_pcp = blk_table.slot_mapping[:
- long_seq_metadata
- .
- num_actual_tokens_pcp_padded]
+ slot_mapping_for_pcp = blk_table.slot_mapping[
+ : long_seq_metadata.num_actual_tokens_pcp_padded
+ ]
slot_mapping_for_pcp[slot_mapping_size:].fill_(-1)
assert pcp_unpad_mask is not None
- pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[:
- pcp_unpad_mask
- .
- shape[
- 0]]
+ pcp_padded_slot_mapping = self.pcp_padded_slot_mapping[
+ : pcp_unpad_mask.shape[0]
+ ]
pcp_padded_slot_mapping.fill_(-1)
- pcp_padded_slot_mapping[
- pcp_unpad_mask] = slot_mapping_for_pcp[:
- slot_mapping_size]
- slot_mapping_for_pcp[:long_seq_metadata.
- num_actual_tokens_pcp_padded] = pcp_padded_slot_mapping
+ pcp_padded_slot_mapping[pcp_unpad_mask] = slot_mapping_for_pcp[
+ :slot_mapping_size
+ ]
+ slot_mapping_for_pcp[
+ : long_seq_metadata.num_actual_tokens_pcp_padded
+ ] = pcp_padded_slot_mapping
slot_mapping = slot_mapping_for_pcp
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
- query_start_loc=self.query_start_loc[:num_reqs + 1],
- query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+ query_start_loc=self.query_start_loc[: num_reqs + 1],
+ query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
@@ -1830,54 +1942,81 @@ def _prepare_inputs(
prefill_context_parallel_metadata=long_seq_metadata,
)
- if self.speculative_config and \
- self.spec_decode_common_attn_metadata is None:
+ if (
+ self.speculative_config
+ and self.spec_decode_common_attn_metadata is None
+ ):
self.spec_decode_common_attn_metadata = common_attn_metadata
for attn_group in self.attn_groups[kv_cache_group_id]:
common_prefix_len = 0
extra_attn_metadata_args = {}
builder = attn_group.get_metadata_builder()
- if isinstance(builder, GDNAttentionMetadataBuilder
- ) or self.model_config.runner_type == "pooling":
+ if (
+ isinstance(builder, GDNAttentionMetadataBuilder)
+ or self.model_config.runner_type == "pooling"
+ ):
if use_spec_decode:
extra_attn_metadata_args = dict(
- num_accepted_tokens=self.num_accepted_tokens.
- gpu[:num_reqs],
- num_draft_tokens=self.num_draft_tokens.
- gpu[:num_reqs],
+ num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
+ num_decode_draft_tokens_cpu=self.num_draft_tokens.gpu[:num_reqs],
)
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
- **extra_attn_metadata_args)
+ **extra_attn_metadata_args,
+ )
else:
attn_metadata_i = builder.build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
model=self.get_model(),
- **extra_attn_metadata_args)
+ **extra_attn_metadata_args,
+ )
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
if lmhead_tp_enable():
- max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
+ max_num_reqs_across_dp = (
+ maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
+ )
logits_indices = nn.functional.pad(
- logits_indices,
- (0, max_num_reqs_across_dp - logits_indices.shape[0]))
-
- return (attn_metadata, positions, num_scheduled_tokens,
- num_input_tokens, num_tokens_across_dp,
- maybe_padded_num_tokens, logits_indices, spec_decode_metadata,
- input_ids, inputs_embeds, intermediate_tensors,
- max_num_scheduled_tokens)
-
- def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
- maybe_padded_num_tokens,
- input_ids, positions,
- intermediate_tensors,
- inputs_embeds):
+ logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0])
+ )
+
+ return (
+ attn_metadata,
+ positions,
+ num_scheduled_tokens,
+ num_input_tokens,
+ num_tokens_across_dp,
+ maybe_padded_num_tokens,
+ logits_indices,
+ spec_decode_metadata,
+ input_ids,
+ inputs_embeds,
+ intermediate_tensors,
+ max_num_scheduled_tokens,
+ )
+
+ def _generate_process_reqs_hidden_states(
+ self,
+ attn_metadata,
+ with_prefill,
+ maybe_padded_num_tokens,
+ input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds,
+ ):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
@@ -1887,29 +2026,37 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
)
forward_context = get_forward_context()
- if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \
- and not self.use_sparse:
+ if (
+ forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
+ and not self.use_sparse
+ ):
# TODO: maybe_padded_num_tokens will be removed, use num_input_tokens instead
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
# FIXME: Try using `auto_dispatch_capture=True`
- update_mla_attn_dcp_pcp_params(self.update_stream,
- forward_context,
- maybe_padded_num_tokens,
- self.speculative_config)
+ update_mla_attn_dcp_pcp_params(
+ self.update_stream,
+ forward_context,
+ maybe_padded_num_tokens,
+ self.speculative_config,
+ )
else:
# FIXME: Try using `auto_dispatch_capture=True`
- update_mla_attn_params(self.update_stream, forward_context,
- maybe_padded_num_tokens,
- self.speculative_config)
+ update_mla_attn_params(
+ self.update_stream,
+ forward_context,
+ maybe_padded_num_tokens,
+ self.speculative_config,
+ )
else:
if self.pcp_size * self.dcp_size > 1:
- update_attn_dcp_pcp_params(self.update_stream,
- forward_context,
- maybe_padded_num_tokens)
+ update_attn_dcp_pcp_params(
+ self.update_stream, forward_context, maybe_padded_num_tokens
+ )
else:
- update_attn_params(self.update_stream, forward_context,
- maybe_padded_num_tokens)
+ update_attn_params(
+ self.update_stream, forward_context, maybe_padded_num_tokens
+ )
if get_forward_context().sp_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -1919,34 +2066,50 @@ def _generate_process_reqs_hidden_states(self, attn_metadata, with_prefill,
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
- hidden_states[:self.num_actual_tokens_pcp_padded //
- self.pcp_size], 0)
+ hidden_states[: self.num_actual_tokens_pcp_padded // self.pcp_size], 0
+ )
hidden_states = torch.index_select(
- hidden_states, 0,
- self.pcp_allgather_restore_idx[:hidden_states.shape[0]])
+ hidden_states,
+ 0,
+ self.pcp_allgather_restore_idx[: hidden_states.shape[0]],
+ )
return hidden_states
- def _build_attn_state(self, num_reqs, num_scheduled_tokens,
- num_valid_tokens):
+ def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens):
ascend_config = get_ascend_config()
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
attn_state = AscendAttentionState.PrefillNoCache
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
- if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
+ if (
+ self.speculative_config
+ and self.speculative_config.method == "deepseek_mtp"
+ ):
# SpecDecoding now supports seq_len=1 and seq_len=2
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
- if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
- or self.drafter.name == SpecDcodeType.EAGLE3):
- attn_state = AscendAttentionState.ChunkedPrefill
- else:
- attn_state = AscendAttentionState.SpecDecoding
+ if (
+     self.speculative_config
+     and self.speculative_config.method == "deepseek_mtp"
+ ):
+     attn_state = AscendAttentionState.SpecDecoding
+ else:
+     attn_state = AscendAttentionState.ChunkedPrefill
# splitfuse
- elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
+ elif (
+ not ascend_config.ascend_scheduler_config.enabled
+ or self.chunked_prefill_enabled
+ ):
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
@@ -1955,9 +2118,14 @@ def _build_attn_state(self, num_reqs, num_scheduled_tokens,
def _update_graph_pad_size(self, with_prefill, graph_pad_size):
self.graph_pad_size = -1
- def _update_input_ids_and_positions(self, input_ids, positions,
- num_input_tokens, with_prefill,
- maybe_padded_num_tokens):
+ def _update_input_ids_and_positions(
+ self,
+ input_ids,
+ positions,
+ num_input_tokens,
+ with_prefill,
+ maybe_padded_num_tokens,
+ ):
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
return input_ids, positions
@@ -1984,13 +2152,15 @@ def _calc_spec_decode_metadata(
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
total_num_sampled_tokens = cu_num_sampled_tokens[-1]
# Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
- cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
- num_sampled_tokens)
+ cumsums_offsets = np.repeat(
+ cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens
+ )
# Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
# Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_indices = np.repeat(
- cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
+ cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
+ )
# Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange
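# Editor's illustrative sketch, not part of the patch: reproducing Steps 1-5 above
# with the example arrays from the comments (5 requests, each sampling its bonus
# token plus any draft tokens); the scheduled-token offsets are the toy values
# implied by the Step 4 comment.
import numpy as np

num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)
num_sampled_tokens = num_draft_tokens + 1
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                            num_sampled_tokens)
arange = np.arange(cu_num_sampled_tokens[-1], dtype=np.int32) - cumsums_offsets
logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                           num_sampled_tokens) + arange
assert logits_indices.tolist() == [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]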
@@ -2002,28 +2172,35 @@ def _calc_spec_decode_metadata(
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
total_num_draft_tokens = cu_num_draft_tokens[-1]
# [0, 0, 0, 3, 3, 5]
- cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
- num_draft_tokens)
+ cumsums_offsets = np.repeat(
+ cu_num_draft_tokens - num_draft_tokens, num_draft_tokens
+ )
# [0, 1, 2, 0, 1, 0]
arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
# [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat(
- cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
+ cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens
+ )
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
# TODO: Optimize the CPU -> NPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
- self.device, non_blocking=True)
+ self.device, non_blocking=True
+ )
if not vllm_version_is("0.11.0"):
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
- self.device, non_blocking=True)
- logits_indices = torch.from_numpy(logits_indices).to(self.device,
- non_blocking=True)
+ self.device, non_blocking=True
+ )
+ logits_indices = torch.from_numpy(logits_indices).to(
+ self.device, non_blocking=True
+ )
target_logits_indices = torch.from_numpy(target_logits_indices).to(
- self.device, non_blocking=True)
+ self.device, non_blocking=True
+ )
bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
- self.device, non_blocking=True)
+ self.device, non_blocking=True
+ )
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
@@ -2068,47 +2245,49 @@ def apply_grammar_bitmask(
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
- seq = sorted(self.input_batch.req_id_to_index.items(),
- key=lambda x: x[1])
+ seq = sorted(self.input_batch.req_id_to_index.items(), key=lambda x: x[1])
for req_id, batch_index in seq:
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
- scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
+ )
if req_id in scheduler_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index
out_indices = []
# Reorder the bitmask to match the order of the requests in the batch.
- sorted_bitmask = np.zeros_like(grammar_bitmask,
- shape=(logits.shape[0],
- grammar_bitmask.shape[1]))
+ sorted_bitmask = np.zeros_like(
+ grammar_bitmask, shape=(logits.shape[0], grammar_bitmask.shape[1])
+ )
cumulative_index = 0
if vllm_version_is("0.11.0"):
seq = sorted(
scheduler_output.structured_output_request_ids.items(),
- key=lambda x: x[1])
+ key=lambda x: x[1],
+ )
for req_id, _ in seq:
logit_index = struct_out_req_batch_indices[req_id]
num_spec_tokens = len(
- scheduler_output.scheduled_spec_decode_tokens.get(
- req_id, []))
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
+ )
for i in range(1 + num_spec_tokens):
- sorted_bitmask[logit_index + i] = \
- grammar_bitmask[cumulative_index + i]
+ sorted_bitmask[logit_index + i] = grammar_bitmask[
+ cumulative_index + i
+ ]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
else:
for req_id in scheduler_output.structured_output_request_ids:
num_spec_tokens = len(
- scheduler_output.scheduled_spec_decode_tokens.get(
- req_id, []))
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
+ )
if req_id in struct_out_req_batch_indices:
logit_index = struct_out_req_batch_indices[req_id]
for i in range(1 + num_spec_tokens):
- sorted_bitmask[logit_index +
- i] = grammar_bitmask[cumulative_index +
- i]
+ sorted_bitmask[logit_index + i] = grammar_bitmask[
+ cumulative_index + i
+ ]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
@@ -2147,9 +2326,16 @@ def propose_draft_token_ids(
draft_token_ids = None
else:
draft_token_ids = self.drafter.generate_token_ids(
- valid_sampled_token_ids, sampling_metadata, scheduler_output,
- spec_decode_metadata, positions, num_scheduled_tokens,
- hidden_states, attn_metadata, aux_hidden_states)
+ valid_sampled_token_ids,
+ sampling_metadata,
+ scheduler_output,
+ spec_decode_metadata,
+ positions,
+ num_scheduled_tokens,
+ hidden_states,
+ attn_metadata,
+ aux_hidden_states,
+ )
return draft_token_ids
def _pool(
@@ -2161,16 +2347,16 @@ def _pool(
finished_recving: Optional[set[str]] = None,
kv_connector_output: Optional["KVConnectorOutput"] = None,
) -> ModelRunnerOutput:
- assert self.input_batch.num_reqs ==\
- len(self.input_batch.pooling_params), \
- "Either all or none of the requests in" \
- " a batch must be pooling request"
+ assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), (
+ "Either all or none of the requests in a batch must be pooling request"
+ )
hidden_states = hidden_states[:num_scheduled_tokens]
pooling_metadata = self.input_batch.pooling_metadata
- pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
- device=hidden_states.device)
- seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
+ pooling_metadata.build_pooling_cursor(
+ num_scheduled_tokens_np.tolist(), device=hidden_states.device
+ )
+ seq_lens_cpu = self.seq_lens_cpu[: self.input_batch.num_reqs]
model = cast(VllmModelForPooling, self.model)
raw_pooler_output = model.pooler(
@@ -2185,7 +2371,8 @@ def _pool(
pooler_output: list[Optional[torch.Tensor]] = []
for raw_output, seq_len, prompt_len in zip(
- raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
+ raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens
+ ):
output = raw_output if seq_len == prompt_len else None
pooler_output.append(output)
@@ -2199,8 +2386,9 @@ def _pool(
kv_connector_output=kv_connector_output,
)
- def _select_moe_comm_method(self, num_tokens: int,
- with_prefill: bool) -> Optional[MoECommType]:
+ def _select_moe_comm_method(
+ self, num_tokens: int, with_prefill: bool
+ ) -> Optional[MoECommType]:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
@@ -2227,15 +2415,18 @@ def _select_moe_comm_method(self, num_tokens: int,
return None
soc_version = get_ascend_soc_version()
- quant_type = getattr(self.vllm_config.model_config.hf_config,
- 'moe_quantize', None)
+ quant_type = getattr(
+ self.vllm_config.model_config.hf_config, "moe_quantize", None
+ )
model_type = self.vllm_config.model_config.hf_config.model_type
if not self.parallel_config.enable_expert_parallel:
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendSocVersion.A2}:
- if (num_tokens <= self.mc2_tokens_capacity
- and self.parallel_config.world_size_across_dp >= 16):
+ if (
+ num_tokens <= self.mc2_tokens_capacity
+ and self.parallel_config.world_size_across_dp >= 16
+ ):
moe_comm_type = MoECommType.MC2
else:
# Currently, w4a8_dynamic does not support allgatherep
@@ -2245,9 +2436,11 @@ def _select_moe_comm_method(self, num_tokens: int,
moe_comm_type = MoECommType.ALLGATHER
elif soc_version in {AscendSocVersion.A3}:
- moe_comm_type = (MoECommType.MC2
- if num_tokens <= self.mc2_tokens_capacity else
- MoECommType.ALLTOALL)
+ moe_comm_type = (
+ MoECommType.MC2
+ if num_tokens <= self.mc2_tokens_capacity
+ else MoECommType.ALLTOALL
+ )
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
@@ -2262,8 +2455,7 @@ def _select_moe_comm_method(self, num_tokens: int,
moe_comm_type = MoECommType.ALLGATHER
if is_global_first_rank():
- logger.debug(f"num_tokens: {num_tokens}, "
- f"moe_comm_type: {moe_comm_type}")
+ logger.debug(f"num_tokens: {num_tokens}, moe_comm_type: {moe_comm_type}")
return moe_comm_type
@torch.inference_mode()
@@ -2286,61 +2478,80 @@ def execute_model(
if self.dynamic_eplb:
self.eplb_updator.forward_before()
- (attn_metadata, positions, num_scheduled_tokens_np,
- num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
- logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
- intermediate_tensors,
- max_query_len) = (self._prepare_inputs(scheduler_output,
- intermediate_tensors))
+ (
+ attn_metadata,
+ positions,
+ num_scheduled_tokens_np,
+ num_input_tokens,
+ num_tokens_across_dp,
+ maybe_padded_num_tokens,
+ logits_indices,
+ spec_decode_metadata,
+ input_ids,
+ inputs_embeds,
+ intermediate_tensors,
+ max_query_len,
+ ) = self._prepare_inputs(scheduler_output, intermediate_tensors)
if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
- moe_comm_type = self._select_moe_comm_method(num_input_tokens,
- self.with_prefill)
+ moe_comm_type = self._select_moe_comm_method(
+ num_input_tokens, self.with_prefill
+ )
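+        # NOTE (illustrative): a batch counts as "uniform decode" when every
+        # scheduled request contributes exactly uniform_decode_query_len tokens,
+        # e.g. pure decode (1 token/req) or MTP-style decode (1 + spec tokens/req).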
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
scheduler_output.total_num_scheduled_tokens
- == self.input_batch.num_reqs * max_query_len)
- batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
- uniform_decode=uniform_decode)
- aclgraph_runtime_mode, batch_descriptor = \
- self.aclgraph_dispatcher.dispatch(batch_descriptor)
+ == self.input_batch.num_reqs * max_query_len
+ )
+ batch_descriptor = BatchDescriptor(
+ num_tokens=num_input_tokens, uniform_decode=uniform_decode
+ )
+ aclgraph_runtime_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch(
+ batch_descriptor
+ )
# Run forward pass
with ProfileExecuteDuration().capture_async("forward"):
with set_ascend_forward_context(
- attn_metadata,
- self.vllm_config,
- num_tokens=num_input_tokens,
- num_tokens_across_dp=num_tokens_across_dp,
- with_prefill=self.with_prefill,
- reserved_mc2_mask=self.reserved_mc2_mask,
- moe_comm_type=moe_comm_type,
- aclgraph_runtime_mode=aclgraph_runtime_mode,
- batch_descriptor=batch_descriptor,
- num_actual_tokens=scheduler_output.
- total_num_scheduled_tokens,
- prefetch_stream=self.prefetch_stream,
- model_instance=self.model,
- weight_prefetch_method=self.weight_prefetch_method):
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_input_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
+ with_prefill=self.with_prefill,
+ reserved_mc2_mask=self.reserved_mc2_mask,
+ moe_comm_type=moe_comm_type,
+ aclgraph_runtime_mode=aclgraph_runtime_mode,
+ batch_descriptor=batch_descriptor,
+ num_actual_tokens=scheduler_output.total_num_scheduled_tokens,
+ prefetch_stream=self.prefetch_stream,
+ model_instance=self.model,
+ weight_prefetch_method=self.weight_prefetch_method,
+ ):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
- attn_metadata, self.with_prefill, maybe_padded_num_tokens,
- input_ids, positions, intermediate_tensors, inputs_embeds)
+ attn_metadata,
+ self.with_prefill,
+ maybe_padded_num_tokens,
+ input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds,
+ )
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = self.get_finished_kv_transfer(
- scheduler_output)
+ scheduler_output
+ )
aux_hidden_states = None
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
hidden_states, aux_hidden_states = hidden_states
kv_connector_output = KVConnectorOutput(
- finished_sending=finished_sending,
- finished_recving=finished_recving)
+ finished_sending=finished_sending, finished_recving=finished_recving
+ )
finished_sending = None
finished_recving = None
with ProfileExecuteDuration().capture_async("post process"):
@@ -2348,9 +2559,10 @@ def execute_model(
# to make sure we are synced across pp ranks
             # TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
- broadcast_pp_output = \
- self.parallel_config.distributed_executor_backend \
- == "external_launcher" and len(get_pp_group().ranks) > 0
+ broadcast_pp_output = (
+ self.parallel_config.distributed_executor_backend == "external_launcher"
+ and len(get_pp_group().ranks) > 0
+ )
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
@@ -2358,57 +2570,62 @@ def execute_model(
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(
- hidden_states.tensors, all_gather_group=get_tp_group())
+ hidden_states.tensors, all_gather_group=get_tp_group()
+ )
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(
hidden_states,
scheduler_output.total_num_scheduled_tokens,
- num_scheduled_tokens_np, finished_sending,
- finished_recving, kv_connector_output)
+ num_scheduled_tokens_np,
+ finished_sending,
+ finished_recving,
+ kv_connector_output,
+ )
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if broadcast_pp_output:
- model_output_broadcast_data = {
- "logits": logits.contiguous(),
- } if logits is not None else {}
- model_output_broadcast_data = get_pp_group(
- ).broadcast_tensor_dict(model_output_broadcast_data,
- src=len(get_pp_group().ranks) - 1)
+ model_output_broadcast_data = (
+ {
+ "logits": logits.contiguous(),
+ }
+ if logits is not None
+ else {}
+ )
+ model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
+ model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
+ )
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if vllm_version_is("0.11.0"):
if scheduler_output.grammar_bitmask is not None:
- logits = self.apply_grammar_bitmask(
- scheduler_output, logits)
+ logits = self.apply_grammar_bitmask(scheduler_output, logits)
else:
if scheduler_output.structured_output_request_ids:
- logits = self.apply_grammar_bitmask(
- scheduler_output, logits)
+ logits = self.apply_grammar_bitmask(scheduler_output, logits)
with ProfileExecuteDuration().capture_async("Sample"):
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if spec_decode_metadata is None:
if lmhead_tp_enable() and logits is not None:
- logits = logits[:self.input_batch.num_reqs]
+ logits = logits[: self.input_batch.num_reqs]
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
else:
if lmhead_tp_enable() and logits is not None:
- logits = logits[:len(spec_decode_metadata.logits_indices)]
+ logits = logits[: len(spec_decode_metadata.logits_indices)]
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
- bonus_logits = logits[
- spec_decode_metadata.bonus_logits_indices]
+ bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
@@ -2418,8 +2635,7 @@ def execute_model(
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
- target_logits = logits[
- spec_decode_metadata.target_logits_indices]
+ target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
@@ -2431,8 +2647,9 @@ def execute_model(
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)
- discard_sampled_tokens_req_indices = \
- self.discard_request_indices.np[:self.num_discarded_requests]
+ discard_sampled_tokens_req_indices = self.discard_request_indices.np[
+ : self.num_discarded_requests
+ ]
for i in discard_sampled_tokens_req_indices:
generator = self.input_batch.generators.get(int(i))
if generator is not None:
@@ -2441,18 +2658,18 @@ def execute_model(
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
req_ids_output_copy = self.input_batch.req_ids.copy()
- req_id_to_index_output_copy = \
- self.input_batch.req_id_to_index.copy()
+ req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
# NOTE: NPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
- logprobs_lists = logprobs_tensors.tolists() \
- if logprobs_tensors is not None else None
+ logprobs_lists = (
+ logprobs_tensors.tolists() if logprobs_tensors is not None else None
+ )
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
- hidden_states[:scheduler_output.total_num_scheduled_tokens],
+ hidden_states[: scheduler_output.total_num_scheduled_tokens],
scheduler_output,
)
@@ -2475,18 +2692,17 @@ def execute_model(
valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
- invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
- )
+ invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1
# Cache the sampled tokens on the NPU and avoid CPU sync.
# These will be copied into input_ids in the next step
# when preparing inputs.
- self.input_batch.prev_sampled_token_ids = \
- sampled_token_ids
- self.input_batch.prev_sampled_token_ids_invalid_indices = \
+ self.input_batch.prev_sampled_token_ids = sampled_token_ids
+ self.input_batch.prev_sampled_token_ids_invalid_indices = (
invalid_req_indices_set
+ )
self.input_batch.prev_req_id_to_index = {
req_id: i
for i, req_id in enumerate(self.input_batch.req_ids)
@@ -2499,8 +2715,9 @@ def execute_model(
# between the first-stage worker and the last-stage worker.
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
- sampled_ids = [-1] * 1 if \
- req_idx not in invalid_req_indices_set else None
+ sampled_ids = (
+ [-1] * 1 if req_idx not in invalid_req_indices_set else None
+ )
else:
sampled_ids = valid_sampled_token_ids[req_idx]
if not sampled_ids:
@@ -2511,12 +2728,11 @@ def execute_model(
assert end_idx <= self.model_config.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
- f"{self.model_config.max_model_len}")
+ f"{self.model_config.max_model_len}"
+ )
- self.input_batch.token_ids_cpu[req_idx,
- start_idx:end_idx] = sampled_ids
- self.input_batch.is_token_ids[req_idx,
- start_idx:end_idx] = True
+ self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
+ self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
@@ -2539,9 +2755,17 @@ def propose_draft_token_ids(sampled_token_ids):
with ProfileExecuteDuration().capture_async("Draft"):
if self.speculative_config:
-                use_padded_batch_for_eagle = self.speculative_config and \
-                    self.speculative_config.method == "deepseek_mtp" and \
-                    not self.speculative_config.disable_padded_drafter_batch
+                use_padded_batch_for_eagle = (
+                    self.speculative_config
+                    and self.speculative_config.method
+                    in ("deepseek_mtp", "qwen3_next_mtp")
+                    and not self.speculative_config.disable_padded_drafter_batch
+                )
if use_padded_batch_for_eagle:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
@@ -2554,7 +2778,7 @@ def propose_draft_token_ids(sampled_token_ids):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
- extra_args = ({"kv_connector_output": kv_connector_output})
+ extra_args = {"kv_connector_output": kv_connector_output}
model_runner_output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
@@ -2569,12 +2793,16 @@ def propose_draft_token_ids(sampled_token_ids):
durations = ProfileExecuteDuration().pop_captured_sync()
if durations:
dr_str = [
- f"[{tag}]:{duration:.2f}ms"
- for tag, duration in durations.items()
+ f"[{tag}]:{duration:.2f}ms" for tag, duration in durations.items()
]
- captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
- logger.info("Profile execute duration [%s]:%s", captured_name,
- " ".join(dr_str))
+ captured_name = (
+ "Decode"
+ if self.attn_state == AscendAttentionState.DecodeOnly
+ else "Prefill"
+ )
+ logger.info(
+ "Profile execute duration [%s]:%s", captured_name, " ".join(dr_str)
+ )
if self.dynamic_eplb:
self.eplb_updator.forward_end()
if not self.use_async_scheduling:
@@ -2599,19 +2827,21 @@ def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
return DraftTokenIds(req_ids, draft_token_ids)
def kv_connector_no_forward(
- self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
+ self, scheduler_output: "SchedulerOutput"
+ ) -> ModelRunnerOutput:
with set_ascend_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
- finished_sending, finished_recving = (
- self.get_finished_kv_transfer(scheduler_output))
+ finished_sending, finished_recving = self.get_finished_kv_transfer(
+ scheduler_output
+ )
# For the case of no forward caused by receiving remote kv,
# one round of dummy inference is necessary
             # to prevent hangs in the collective calls.
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = KVConnectorOutput(
- finished_sending=finished_sending,
- finished_recving=finished_recving)
+ finished_sending=finished_sending, finished_recving=finished_recving
+ )
return output
@staticmethod
@@ -2621,8 +2851,7 @@ def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
- kv_connector.bind_connector_metadata(
- scheduler_output.kv_connector_metadata)
+ kv_connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
kv_connector.start_load_kv(get_forward_context())
@@ -2637,7 +2866,8 @@ def get_finished_kv_transfer(
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
- scheduler_output.finished_req_ids)
+ scheduler_output.finished_req_ids
+ )
return None, None
def _build_dummy_attn_metadata(
@@ -2653,8 +2883,9 @@ def _build_dummy_attn_metadata(
attn_metadata: Optional[dict[str, Any]] = None
if force_attention or aclgraph_runtime_mode == CUDAGraphMode.FULL:
- assert with_prefill is False, \
+ assert with_prefill is False, (
"Full decode graph only supports uniform batch now."
+ )
attn_metadata = {}
@@ -2662,47 +2893,57 @@ def _build_dummy_attn_metadata(
self.seq_lens_np[:num_reqs] = seq_lens
self.seq_lens_np[num_reqs:] = 0
- cu_num_tokens, arange = self._get_cumsum_and_arange(
- num_scheduled_tokens)
- query_start_loc_tensor = torch.Tensor(cu_num_tokens).to(
- self.device).to(torch.int32)
- self.query_start_loc[1:num_reqs + 1] = query_start_loc_tensor
- self.query_start_loc_cpu[1:num_reqs +
- 1] = torch.Tensor(cu_num_tokens)
+ cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
+ query_start_loc_tensor = (
+ torch.Tensor(cu_num_tokens).to(self.device).to(torch.int32)
+ )
+ self.query_start_loc[1 : num_reqs + 1] = query_start_loc_tensor
+ self.query_start_loc_cpu[1 : num_reqs + 1] = torch.Tensor(cu_num_tokens)
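+            # Illustrative (assuming _get_cumsum_and_arange returns the running
+            # sum): num_scheduled_tokens=[2, 2, 3] gives cu_num_tokens=[2, 4, 7],
+            # so query_start_loc becomes [0, 2, 4, 7] for these requests.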
- num_computed_tokens_cpu = (
- self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
+ num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
+ :num_reqs
+ ]
for kv_cache_group_id, kv_cache_group_spec in enumerate(
- self.kv_cache_config.kv_cache_groups):
+ self.kv_cache_config.kv_cache_groups
+ ):
block_table_tensor = self.input_batch.block_table[
- kv_cache_group_id].get_device_tensor()
+ kv_cache_group_id
+ ].get_device_tensor()
slot_mapping = self.input_batch.block_table[
- kv_cache_group_id].slot_mapping
- self.cp_kv_recover_idx = torch.zeros(self.max_num_tokens,
- dtype=torch.int32,
- device=self.device)
+ kv_cache_group_id
+ ].slot_mapping
+ self.cp_kv_recover_idx = torch.zeros(
+ self.max_num_tokens, dtype=torch.int32, device=self.device
+ )
long_seq_metadata = self._generate_pcp_metadata(
- num_tokens, self.seq_lens_cpu)
+ num_tokens, self.seq_lens_cpu
+ )
if long_seq_metadata is not None:
- pcp_world_size = get_pcp_group(
- ).world_size if prefill_context_parallel_enable() else 1
+ pcp_world_size = (
+ get_pcp_group().world_size
+ if prefill_context_parallel_enable()
+ else 1
+ )
dcp_world_size = get_dcp_group().world_size
- num_computed_tokens_of_pcp_dcp = [[
- [0] * dcp_world_size for _ in range(pcp_world_size)
- ] for _ in range(num_tokens)]
- long_seq_metadata.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp
+ num_computed_tokens_of_pcp_dcp = [
+ [[0] * dcp_world_size for _ in range(pcp_world_size)]
+ for _ in range(num_tokens)
+ ]
+ long_seq_metadata.num_computed_tokens_of_pcp_dcp = (
+ num_computed_tokens_of_pcp_dcp
+ )
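+                # The nested list above is indexed as [token][pcp_rank][dcp_rank];
+                # it stays all-zero for this dummy capture since nothing has been
+                # computed yet.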
if self.speculative_config:
query_start_loc = torch.tensor(
[0] + self.actual_seq_lengths_q[:num_reqs],
device=self.device,
- dtype=torch.int32)
+ dtype=torch.int32,
+ )
else:
- query_start_loc = self.query_start_loc[:num_reqs + 1]
+ query_start_loc = self.query_start_loc[: num_reqs + 1]
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=query_start_loc,
- query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
- 1],
+ query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu,
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
@@ -2722,14 +2963,15 @@ def _build_dummy_attn_metadata(
prefill_context_parallel_metadata=long_seq_metadata,
)
attn_state = AscendAttentionState.DecodeOnly
- if self.speculative_config and \
- self.speculative_config.method == "deepseek_mtp":
+ if (
+ self.speculative_config
+ and self.speculative_config.method == "deepseek_mtp"
+ ):
attn_state = AscendAttentionState.SpecDecoding
common_metadata = CommonAttentionMetadata(
- query_start_loc=self.query_start_loc[:num_reqs + 1],
- query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
- 1],
+ query_start_loc=self.query_start_loc[: num_reqs + 1],
+ query_start_loc_cpu=self.query_start_loc_cpu[: num_reqs + 1],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens_cpu[:num_reqs],
num_reqs=num_reqs,
@@ -2738,59 +2980,78 @@ def _build_dummy_attn_metadata(
slot_mapping=slot_mapping,
num_computed_tokens_cpu=num_computed_tokens_cpu,
max_query_len=max_query_len,
- max_seq_len=seq_lens)
+ max_seq_len=seq_lens,
+ )
for attn_group in self.attn_groups[kv_cache_group_id]:
builder = attn_group.get_metadata_builder()
if isinstance(builder, AscendAttentionMetadataBuilder):
attn_metadata_full_attention = builder.build_for_graph_capture(
- common_attn_metadata, attn_state, self.get_model())
+ common_attn_metadata, attn_state, self.get_model()
+ )
elif isinstance(builder, GDNAttentionMetadataBuilder):
- attn_metadata_gdn_attention = builder.build_for_cudagraph_capture(
- common_metadata)
+ attn_metadata_gdn_attention = (
+ builder.build_for_cudagraph_capture(common_metadata)
+ )
for layer_name in kv_cache_group_spec.layer_names:
if "linear_attn" in layer_name:
- attn_metadata[
- layer_name] = attn_metadata_gdn_attention
+ attn_metadata[layer_name] = attn_metadata_gdn_attention
else:
- attn_metadata[
- layer_name] = attn_metadata_full_attention
+ attn_metadata[layer_name] = attn_metadata_full_attention
return attn_metadata
- def _generate_dummy_run_hidden_states(self, with_prefill,
- is_torchair_compile, input_ids,
- positions, attn_metadata, num_tokens,
- intermediate_tensors, inputs_embeds):
- hidden_states = self.model(input_ids=input_ids,
- positions=positions,
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds)
+ def _generate_dummy_run_hidden_states(
+ self,
+ with_prefill,
+ is_torchair_compile,
+ input_ids,
+ positions,
+ attn_metadata,
+ num_tokens,
+ intermediate_tensors,
+ inputs_embeds,
+ ):
+ hidden_states = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
forward_context = get_forward_context()
assert forward_context is not None
- if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and \
- not forward_context.capturing and not self.use_sparse:
+ if (
+ forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
+ and not forward_context.capturing
+ and not self.use_sparse
+ ):
if self.vllm_config.model_config.use_mla:
# FIXME: Try using `auto_dispatch_capture=True`
if self.pcp_size * self.dcp_size > 1:
# FIXME: Try using `auto_dispatch_capture=True`
- update_mla_attn_dcp_pcp_params(self.update_stream,
- forward_context,
- positions.shape[0],
- self.speculative_config)
+ update_mla_attn_dcp_pcp_params(
+ self.update_stream,
+ forward_context,
+ positions.shape[0],
+ self.speculative_config,
+ )
else:
# FIXME: Try using `auto_dispatch_capture=True`
- update_mla_attn_params(self.update_stream, forward_context,
- positions.shape[0],
- self.speculative_config)
+ update_mla_attn_params(
+ self.update_stream,
+ forward_context,
+ positions.shape[0],
+ self.speculative_config,
+ )
else:
if self.pcp_size * self.dcp_size > 1:
- update_attn_dcp_pcp_params(self.update_stream,
- forward_context,
- positions.shape[0])
+ update_attn_dcp_pcp_params(
+ self.update_stream, forward_context, positions.shape[0]
+ )
else:
- update_attn_params(self.update_stream, forward_context,
- positions.shape[0])
+ update_attn_params(
+ self.update_stream, forward_context, positions.shape[0]
+ )
if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3:
hidden_states, _ = hidden_states
@@ -2810,7 +3071,9 @@ def _dummy_run(
) -> torch.Tensor:
# only support eager mode and piecewise graph now
assert aclgraph_runtime_mode is None or aclgraph_runtime_mode in {
- CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
+ CUDAGraphMode.NONE,
+ CUDAGraphMode.PIECEWISE,
+ CUDAGraphMode.FULL,
}
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
@@ -2824,9 +3087,9 @@ def _dummy_run(
with_prefill = True
# Padding for DP
- (num_tokens, num_tokens_across_dp,
- with_prefill) = self._sync_metadata_across_dp(num_tokens,
- with_prefill)
+ (num_tokens, num_tokens_across_dp, with_prefill) = (
+ self._sync_metadata_across_dp(num_tokens, with_prefill)
+ )
moe_comm_type = self._select_moe_comm_method(num_tokens, with_prefill)
@@ -2843,8 +3106,7 @@ def _dummy_run(
# When setting max_query_len = 1, we switch to and capture the optimized
# routine of FA2 for pure decode, i.e., Flashdecode + an optimization
# for GQA/MQA.
- max_query_len = self.uniform_decode_query_len if uniform_decode else \
- num_tokens
+ max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@@ -2860,22 +3122,21 @@ def _dummy_run(
if with_prefill:
num_reqs = num_tokens
else:
- num_reqs = (num_tokens + self.decode_token_per_req -
- 1) // self.decode_token_per_req
+ num_reqs = (
+ num_tokens + self.decode_token_per_req - 1
+ ) // self.decode_token_per_req
num_reqs = min(num_reqs, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
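+            # Illustrative example: num_tokens=10 with decode_token_per_req=2
+            # gives num_reqs=5 and num_scheduled_tokens_list=[2, 2, 2, 2, 2];
+            # if max_num_reqs caps num_reqs at 4, it becomes [2, 2, 2, 4].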
- num_scheduled_tokens = np.array(num_scheduled_tokens_list,
- dtype=np.int32)
+ num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.forward_before()
- with self.maybe_dummy_run_with_lora(self.lora_config,
- num_scheduled_tokens):
+ with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens):
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
@@ -2897,26 +3158,27 @@ def _dummy_run(
if self.intermediate_tensors is None:
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
- batch_size=num_tokens,
- dtype=self.dtype,
- device=self.device))
- intermediate_tensors = IntermediateTensors({
- k: v[:num_tokens]
- for k, v in self.intermediate_tensors.items()
- })
+ batch_size=num_tokens, dtype=self.dtype, device=self.device
+ )
+ )
+ intermediate_tensors = IntermediateTensors(
+ {k: v[:num_tokens] for k, v in self.intermediate_tensors.items()}
+ )
# filter out the valid batch descriptor
- _ag_mode, batch_descriptor = \
- self.aclgraph_dispatcher.dispatch(
- BatchDescriptor(num_tokens=num_tokens,
- uniform_decode=uniform_decode))
+ _ag_mode, batch_descriptor = self.aclgraph_dispatcher.dispatch(
+ BatchDescriptor(num_tokens=num_tokens, uniform_decode=uniform_decode)
+ )
if aclgraph_runtime_mode is not None:
# we allow forcing NONE when the dispatcher disagrees to support
# warm ups for aclgraph capture
- assert aclgraph_runtime_mode == CUDAGraphMode.NONE or \
- aclgraph_runtime_mode == _ag_mode, (
+ assert (
+ aclgraph_runtime_mode == CUDAGraphMode.NONE
+ or aclgraph_runtime_mode == _ag_mode
+ ), (
f"Aclgraph runtime mode mismatch at dummy_run. "
- f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}.")
+ f"Expected {_ag_mode}, but got {aclgraph_runtime_mode}."
+ )
else:
aclgraph_runtime_mode = _ag_mode
@@ -2932,37 +3194,43 @@ def _dummy_run(
num_scheduled_tokens=num_scheduled_tokens,
)
- need_dummy_logits = (not self.in_profile_run
- and lmhead_tp_enable())
+ need_dummy_logits = not self.in_profile_run and lmhead_tp_enable()
if need_dummy_logits:
- max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
- dummy_indices = torch.zeros(max_num_reqs_across_dp,
- dtype=torch.int32)
+ max_num_reqs_across_dp = (
+ num_tokens if not with_prefill else max_num_reqs
+ )
+ dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32)
def dummy_compute_logits(hidden_states):
- return self.model.compute_logits(
- hidden_states[dummy_indices])
+ return self.model.compute_logits(hidden_states[dummy_indices])
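+            # NOTE (assumption): dummy_compute_logits keeps every rank joining
+            # the collective ops inside compute_logits when lm-head tensor
+            # parallelism is enabled, even though the dummy run discards the result.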
with set_ascend_forward_context(
- attn_metadata,
- self.vllm_config,
- num_tokens=num_tokens,
- num_tokens_across_dp=num_tokens_across_dp,
- with_prefill=with_prefill,
- in_profile_run=self.in_profile_run,
- reserved_mc2_mask=self.reserved_mc2_mask,
- moe_comm_type=moe_comm_type,
- num_actual_tokens=0,
- aclgraph_runtime_mode=aclgraph_runtime_mode,
- batch_descriptor=batch_descriptor,
- prefetch_stream=self.prefetch_stream,
- model_instance=self.model,
- weight_prefetch_method=self.weight_prefetch_method):
+ attn_metadata,
+ self.vllm_config,
+ num_tokens=num_tokens,
+ num_tokens_across_dp=num_tokens_across_dp,
+ with_prefill=with_prefill,
+ in_profile_run=self.in_profile_run,
+ reserved_mc2_mask=self.reserved_mc2_mask,
+ moe_comm_type=moe_comm_type,
+ num_actual_tokens=0,
+ aclgraph_runtime_mode=aclgraph_runtime_mode,
+ batch_descriptor=batch_descriptor,
+ prefetch_stream=self.prefetch_stream,
+ model_instance=self.model,
+ weight_prefetch_method=self.weight_prefetch_method,
+ ):
hidden_states = self._generate_dummy_run_hidden_states(
- with_prefill, is_torchair_compile, input_ids, positions,
- attn_metadata, num_tokens, intermediate_tensors,
- inputs_embeds)
+ with_prefill,
+ is_torchair_compile,
+ input_ids,
+ positions,
+ attn_metadata,
+ num_tokens,
+ intermediate_tensors,
+ inputs_embeds,
+ )
if need_dummy_logits:
dummy_compute_logits(hidden_states)
@@ -2974,10 +3242,10 @@ def dummy_compute_logits(hidden_states):
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode,
- batch_descriptor=batch_descriptor)
+ batch_descriptor=batch_descriptor,
+ )
if need_dummy_logits:
- self.drafter.model.compute_logits(
- hidden_states[dummy_indices])
+ self.drafter.model.compute_logits(hidden_states[dummy_indices])
if self.in_profile_run and self.dynamic_eplb:
self.model.clear_all_moe_loads()
if not self.in_profile_run and self.dynamic_eplb:
@@ -2997,16 +3265,21 @@ def profile_run(self) -> None:
# Trigger compilation for general shape.
with self.set_in_profile_run():
hidden_states = self._dummy_run(
- self.max_num_tokens //
- self.pcp_size if self.pcp_size > 1 else self.max_num_tokens,
- with_prefill=True)
+ self.max_num_tokens // self.pcp_size
+ if self.pcp_size > 1
+ else self.max_num_tokens,
+ with_prefill=True,
+ )
# MC2 will consume additional NPU memory.
# Therefore, we need to run the MC2 path once here to complete its initialization,
# allowing vLLM to correctly estimate the maximum memory required.
- if self.max_num_tokens > self.mc2_tokens_capacity and \
- self._select_moe_comm_method(
- self.mc2_tokens_capacity,
- with_prefill=True) == MoECommType.MC2:
+ if (
+ self.max_num_tokens > self.mc2_tokens_capacity
+ and self._select_moe_comm_method(
+ self.mc2_tokens_capacity, with_prefill=True
+ )
+ == MoECommType.MC2
+ ):
self._dummy_run(self.mc2_tokens_capacity, with_prefill=True)
output = None
@@ -3017,12 +3290,11 @@ def profile_run(self) -> None:
             # For profiling, use the maximum num_reqs such that the requests
             # collectively cover the maximum num_tokens.
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
- num_scheduled_tokens_list = [min_tokens_per_req
- ] * self.max_num_reqs
- num_scheduled_tokens_list[
- -1] += self.max_num_tokens % self.max_num_reqs
- num_scheduled_tokens = np.array(num_scheduled_tokens_list,
- dtype=np.int32)
+ num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
+ num_scheduled_tokens_list[-1] += self.max_num_tokens % self.max_num_reqs
+ num_scheduled_tokens = np.array(
+ num_scheduled_tokens_list, dtype=np.int32
+ )
logit_indices = np.cumsum(num_scheduled_tokens) - 1
             # TODO: need to run a dummy sampler for generate task
hidden_states = hidden_states[logit_indices]
@@ -3049,9 +3321,9 @@ def _dummy_pooler_run_task(
req_num_tokens = num_tokens // num_reqs
- dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
- dtype=torch.int32,
- device=self.device)
+ dummy_token_ids = torch.zeros(
+ (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device
+ )
model = cast(VllmModelForPooling, self.get_model())
dummy_pooling_params = PoolingParams(task=task)
@@ -3068,19 +3340,22 @@ def _dummy_pooler_run_task(
pooling_params=[dummy_pooling_params] * num_reqs,
)
- dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
- device=hidden_states.device)
+ dummy_metadata.build_pooling_cursor(
+ num_scheduled_tokens_list, device=hidden_states.device
+ )
try:
- return model.pooler(hidden_states=hidden_states,
- pooling_metadata=dummy_metadata)
+ return model.pooler(
+ hidden_states=hidden_states, pooling_metadata=dummy_metadata
+ )
except RuntimeError as e:
- if 'out of memory' in str(e):
+ if "out of memory" in str(e):
raise RuntimeError(
"CUDA out of memory occurred when warming up pooler "
f"({task=}) with {num_reqs} dummy requests. Please try "
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
- "initializing the engine.") from e
+ "initializing the engine."
+ ) from e
else:
raise e
@@ -3119,36 +3394,45 @@ def load_model(self) -> None:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear,
RowParallelLinear)
+
for module in self.model.modules():
- if isinstance(module,
- (MergedColumnParallelLinear,
- QKVParallelLinear, RowParallelLinear)):
+ if isinstance(
+ module,
+ (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+ ),
+ ):
module.weight.data = self._convert_torch_format(
- module.weight.data)
+ module.weight.data
+ )
if self.drafter:
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.drafter.name == SpecDcodeType.EAGLE3:
self.model.set_aux_hidden_state_layers(
- self.model.get_eagle3_aux_hidden_state_layers())
+ self.model.get_eagle3_aux_hidden_state_layers()
+ )
if self.lora_config:
- self.model = self.load_lora_model(self.model, self.vllm_config,
- self.device)
- logger.info("Loading model weights took %.4f GB",
- m.consumed_memory / float(2**30))
+ self.model = self.load_lora_model(
+ self.model, self.vllm_config, self.device
+ )
+ logger.info(
+ "Loading model weights took %.4f GB", m.consumed_memory / float(2**30)
+ )
# wrap the model with full graph wrapper if needed.
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.update_stream: torch.npu.Stream = torch.npu.Stream()
set_graph_params(self.compilation_config.cudagraph_capture_sizes)
- self.model = ACLGraphWrapper(self.model,
- self.vllm_config,
- runtime_mode=CUDAGraphMode.FULL)
+ self.model = ACLGraphWrapper(
+ self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
+ )
def _convert_torch_format(self, tensor):
- if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
- and not is_enable_nz():
+ if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ and not is_enable_nz():
return tensor
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor
@@ -3165,12 +3449,14 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
self.may_add_encoder_only_layers_to_kv_cache_config()
# NOTE(cmq): initialize_attn_backend must before using self.attn_groups
self.initialize_attn_backend(kv_cache_config)
- self.use_hybrid_blocks = (len(self.attn_groups) > 1)
+ self.use_hybrid_blocks = len(self.attn_groups) > 1
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
- self.need_accepted_tokens = any([
- isinstance(attn_group[0].kv_cache_spec, MambaSpec)
- for attn_group in self.attn_groups
- ])
+ self.need_accepted_tokens = any(
+ [
+ isinstance(attn_group[0].kv_cache_spec, MambaSpec)
+ for attn_group in self.attn_groups
+ ]
+ )
self.may_reinitialize_input_batch(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
@@ -3178,15 +3464,15 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
- def _align_memory(self, tensor: torch.Tensor,
- alignment: int) -> torch.Tensor:
+ def _align_memory(self, tensor: torch.Tensor, alignment: int) -> torch.Tensor:
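+        # Rounds the start of `tensor` up to the next `alignment` boundary.
+        # Illustrative: data_ptr=0x1001 with a 4 KiB alignment gives
+        # aligned_addr=0x2000, so the first (0x2000 - 0x1001) // element_size
+        # elements are skipped.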
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
- return tensor[int(offset):]
+ return tensor[int(offset) :]
def initialize_kv_cache_tensors(
- self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+ self, kv_cache_config: KVCacheConfig
+ ) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.
@@ -3199,16 +3485,18 @@ def initialize_kv_cache_tensors(
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
- kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
- kv_cache_raw_tensors)
+ kv_caches = self._reshape_kv_cache_tensors(
+ kv_cache_config, kv_cache_raw_tensors
+ )
- bind_kv_cache(kv_caches,
- self.compilation_config.static_forward_context,
- self.kv_caches)
+ bind_kv_cache(
+ kv_caches, self.compilation_config.static_forward_context, self.kv_caches
+ )
return kv_caches
def _allocate_kv_cache_tensors(
- self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
+ self, kv_cache_config: KVCacheConfig
+ ) -> dict[str, torch.Tensor]:
"""
Initializes the KV cache buffer with the correct size. The buffer needs
to be reshaped to the desired shape before being used by the models.
@@ -3223,36 +3511,42 @@ def _allocate_kv_cache_tensors(
corresponding memory buffer for KV cache.
dict[str, tuple(torch.Tensor, torch.Tensor)] A map between layer names
to their corresponding memory buffer for K cache and V cache.
- """
+ """
# init kv cache tensors
- kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
- Optional[torch.Tensor]]] = {}
+ kv_cache_raw_tensors: dict[
+ str, Union[torch.Tensor, Optional[torch.Tensor]]
+ ] = {}
         # llmdatadist needs the cache tensor address to be aligned to 2 MB
alignment = 2 * 1024 * 1024
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
             # TODO: REFACTOR ME to share the hybrid cache
for idx in range(len(kv_cache_tensor.shared_by)):
layer_name = kv_cache_tensor.shared_by[idx]
- if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+ if (
+ "linear_attn" in layer_name
+ and layer_name not in kv_cache_raw_tensors.keys()
):
# for mamba linear attention
if self.vllm_config.kv_transfer_config is None:
- tensor = torch.zeros(kv_cache_tensor.size,
- dtype=torch.int8,
- device=self.device)
+ tensor = torch.zeros(
+ kv_cache_tensor.size, dtype=torch.int8, device=self.device
+ )
else:
cache_size_aligned = kv_cache_tensor.size + alignment
- tensor = torch.zeros(cache_size_aligned,
- dtype=torch.int8,
- device=self.device)
- tensor = self._align_memory(
- tensor, alignment)[:kv_cache_tensor.size]
+ tensor = torch.zeros(
+ cache_size_aligned, dtype=torch.int8, device=self.device
+ )
+ tensor = self._align_memory(tensor, alignment)[
+ : kv_cache_tensor.size
+ ]
for layer_name_inner in kv_cache_tensor.shared_by:
                         # share the kvcache between the self_attn specs in the same group
if "linear_attn" in layer_name_inner:
kv_cache_raw_tensors[layer_name_inner] = tensor
- elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+ elif (
+ "attn" in layer_name
+ and layer_name not in kv_cache_raw_tensors.keys()
):
# NOTE: We need to init k cache tensor (nope cache tensor in mla) and
# v cache tensor (rope cache tensor in mla) separately to support llmdatadist,
@@ -3260,8 +3554,10 @@ def _allocate_kv_cache_tensors(
                     # For deepseek mla, we need to split the cache tensor according to the nope head dim
# and rope head dim.
if self.model_config.is_deepseek_mla:
- head_size = self.model_config.hf_text_config.qk_rope_head_dim + \
- self.model_config.hf_text_config.kv_lora_rank
+ head_size = (
+ self.model_config.hf_text_config.qk_rope_head_dim
+ + self.model_config.hf_text_config.kv_lora_rank
+ )
dsa_k_cache_factor = None
dsa_k_cache_size = None
@@ -3274,61 +3570,87 @@ def _allocate_kv_cache_tensors(
# FullAttentionSpec allocate 2 * mla page size bytes,
# and we use half of that for k cache in DSA
dsa_k_cache_factor = 2
- k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank
- v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim
- dsa_k_cache_size = int(kv_cache_tensor.size //
- dsa_k_cache_factor)
+ k_tensor_split_factor = (
+ 2
+ * head_size
+ / self.model_config.hf_text_config.kv_lora_rank
+ )
+ v_tensor_split_factor = (
+ 2
+ * head_size
+ / self.model_config.hf_text_config.qk_rope_head_dim
+ )
+ dsa_k_cache_size = int(
+ kv_cache_tensor.size // dsa_k_cache_factor
+ )
else:
# for other deepseek models, use MLAAttentionSpec
- k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
- v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim
+ k_tensor_split_factor = (
+ head_size / self.model_config.hf_text_config.kv_lora_rank
+ )
+ v_tensor_split_factor = (
+ head_size
+ / self.model_config.hf_text_config.qk_rope_head_dim
+ )
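+                            # Illustrative (assuming DeepSeek-style dims,
+                            # kv_lora_rank=512, qk_rope_head_dim=64): head_size=576,
+                            # so the nope/k cache gets 512/576 of the buffer and
+                            # the rope/v cache the remaining 64/576.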
- k_tensor_size = int(kv_cache_tensor.size //
- k_tensor_split_factor)
- v_tensor_size = int(kv_cache_tensor.size //
- v_tensor_split_factor)
+ k_tensor_size = int(kv_cache_tensor.size // k_tensor_split_factor)
+ v_tensor_size = int(kv_cache_tensor.size // v_tensor_split_factor)
# for other attentions, e.g., self_attn, sliding window attn
if self.vllm_config.kv_transfer_config is None:
- k_tensor = torch.zeros(k_tensor_size,
- dtype=torch.int8,
- device=self.device)
- v_tensor = torch.zeros(v_tensor_size,
- dtype=torch.int8,
- device=self.device)
+ k_tensor = torch.zeros(
+ k_tensor_size, dtype=torch.int8, device=self.device
+ )
+ v_tensor = torch.zeros(
+ v_tensor_size, dtype=torch.int8, device=self.device
+ )
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None:
dsa_k_cache_tensor = torch.zeros(
- dsa_k_cache_size,
- dtype=torch.int8,
- device=self.device)
+ dsa_k_cache_size, dtype=torch.int8, device=self.device
+ )
else:
- k_tensor = torch.zeros(k_tensor_size + alignment,
- dtype=torch.int8,
- device=self.device)
- v_tensor = torch.zeros(v_tensor_size + alignment,
- dtype=torch.int8,
- device=self.device)
- k_tensor = self._align_memory(
- k_tensor, alignment)[:k_tensor_size]
- v_tensor = self._align_memory(
- v_tensor, alignment)[:v_tensor_size]
+ k_tensor = torch.zeros(
+ k_tensor_size + alignment,
+ dtype=torch.int8,
+ device=self.device,
+ )
+ v_tensor = torch.zeros(
+ v_tensor_size + alignment,
+ dtype=torch.int8,
+ device=self.device,
+ )
+ k_tensor = self._align_memory(k_tensor, alignment)[
+ :k_tensor_size
+ ]
+ v_tensor = self._align_memory(v_tensor, alignment)[
+ :v_tensor_size
+ ]
#### k cache: for deepseek sparse attention
- if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
+ if (
+ dsa_k_cache_factor is not None
+ and dsa_k_cache_size is not None
+ ):
dsa_k_cache_tensor = torch.zeros(
dsa_k_cache_size + alignment,
dtype=torch.int8,
- device=self.device)
+ device=self.device,
+ )
dsa_k_cache_tensor = self._align_memory(
- dsa_k_cache_tensor,
- alignment)[:dsa_k_cache_size]
+ dsa_k_cache_tensor, alignment
+ )[:dsa_k_cache_size]
for layer_name_inner in kv_cache_tensor.shared_by:
                         # share the kvcache between the self_attn specs in the same group
- if ("attn" in layer_name_inner
- and "linear_attn" not in layer_name_inner):
- kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor) if \
- not self.use_sparse else (k_tensor, v_tensor, dsa_k_cache_tensor)
+ if (
+ "attn" in layer_name_inner
+ and "linear_attn" not in layer_name_inner
+ ):
+ kv_cache_raw_tensors[layer_name_inner] = (
+ (k_tensor, v_tensor)
+ if not self.use_sparse
+ else (k_tensor, v_tensor, dsa_k_cache_tensor)
+ )
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
@@ -3336,8 +3658,9 @@ def _allocate_kv_cache_tensors(
if layer_name in self.runner_only_attn_layers:
continue
layer_names.add(layer_name)
- assert layer_names == set(kv_cache_raw_tensors.keys(
- )), "Some layers are not correctly initialized"
+ assert layer_names == set(kv_cache_raw_tensors.keys()), (
+ "Some layers are not correctly initialized"
+ )
return kv_cache_raw_tensors
@@ -3370,16 +3693,24 @@ def _reshape_kv_cache_tensors(
if isinstance(kv_cache_spec, FullAttentionSpec):
raw_dsa_k_tensor = None
if self.use_sparse:
- raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
- layer_name]
+ raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = (
+ kv_cache_raw_tensors[ # type: ignore
+ layer_name
+ ]
+ )
assert raw_dsa_k_tensor is not None
- sum_page_size_bytes = raw_k_tensor.numel(
- ) + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
+ sum_page_size_bytes = (
+ raw_k_tensor.numel()
+ + raw_v_tensor.numel()
+ + raw_dsa_k_tensor.numel()
+ )
else:
raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore
- layer_name]
- sum_page_size_bytes = raw_k_tensor.numel(
- ) + raw_v_tensor.numel()
+ layer_name
+ ]
+ sum_page_size_bytes = (
+ raw_k_tensor.numel() + raw_v_tensor.numel()
+ )
assert raw_k_tensor is not None
assert raw_v_tensor is not None
assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
@@ -3394,26 +3725,36 @@ def _reshape_kv_cache_tensors(
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
- if self.vllm_config.additional_config.get(
- "kv_cache_dtype", None) == 'int8':
+ if (
+ self.vllm_config.additional_config.get("kv_cache_dtype", None)
+ == "int8"
+ ):
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
- num_blocks, kv_cache_spec.block_size,
+ num_blocks,
+ kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size)
- elif hasattr(attn_backend, "get_supported_block_size"
- ) and self.use_hybrid_blocks:
+ kv_cache_spec.head_size,
+ )
+ elif (
+ hasattr(attn_backend, "get_supported_block_size")
+ and self.use_hybrid_blocks
+ ):
block_size = attn_backend.get_supported_block_size()[0]
block_size_chunk = kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
- num_blocks * block_size_chunk, block_size,
+ num_blocks * block_size_chunk,
+ block_size,
kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size)
+ kv_cache_spec.head_size,
+ )
else:
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
- num_blocks, kv_cache_spec.block_size,
+ num_blocks,
+ kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size)
+ kv_cache_spec.head_size,
+ )
dtype = kv_cache_spec.dtype
if not self.model_config.is_deepseek_mla:
k_shape = kv_cache_shape[1:]
@@ -3422,35 +3763,47 @@ def _reshape_kv_cache_tensors(
# k_cache: nope_cache v_cache: rope_cache
mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
k_shape = [
- mla_num_blocks, mla_block_size, num_kv_heads,
- self.model_config.hf_text_config.kv_lora_rank
+ mla_num_blocks,
+ mla_block_size,
+ num_kv_heads,
+ self.model_config.hf_text_config.kv_lora_rank,
]
v_shape = [
- mla_num_blocks, mla_block_size, num_kv_heads,
- self.model_config.hf_text_config.qk_rope_head_dim
+ mla_num_blocks,
+ mla_block_size,
+ num_kv_heads,
+ self.model_config.hf_text_config.qk_rope_head_dim,
]
k_cache = raw_k_tensor.view(dtype).view(k_shape)
k_cache = self._convert_torch_format(k_cache)
v_cache = raw_v_tensor.view(dtype).view(v_shape)
v_cache = self._convert_torch_format(v_cache)
if self.use_sparse and raw_dsa_k_tensor is not None:
- dsa_k_cache_shape = (num_blocks,
- kv_cache_spec.block_size, 1, 128)
+ dsa_k_cache_shape = (
+ num_blocks,
+ kv_cache_spec.block_size,
+ 1,
+ 128,
+ )
dsa_k_cache_size = (
- num_blocks
- ) * kv_cache_spec.block_size * 128 * dtype.itemsize
- dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(
- dtype).view(dsa_k_cache_shape)
+ (num_blocks)
+ * kv_cache_spec.block_size
+ * 128
+ * dtype.itemsize
+ )
+ dsa_k_cache = (
+ raw_dsa_k_tensor[:dsa_k_cache_size]
+ .view(dtype)
+ .view(dsa_k_cache_shape)
+ )
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
else:
kv_caches[layer_name] = (k_cache, v_cache)
elif isinstance(kv_cache_spec, MambaSpec):
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor is not None
- assert raw_tensor.numel(
- ) % kv_cache_spec.page_size_bytes == 0
- num_blocks = raw_tensor.numel(
- ) // kv_cache_spec.page_size_bytes
+ assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
+ num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
@@ -3463,11 +3816,11 @@ def _reshape_kv_cache_tensors(
state_tensors = []
storage_offset_bytes = 0
- for (shape, dtype) in zip(kv_cache_spec.shapes,
- kv_cache_spec.dtypes):
+ for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
dtype_size = get_dtype_size(dtype)
num_element_per_page = (
- kv_cache_spec.page_size_bytes // dtype_size)
+ kv_cache_spec.page_size_bytes // dtype_size
+ )
target_shape = (num_blocks, *shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
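+                        # Each per-layer state tensor strides over a full page per
+                        # block; the state tensors of one block therefore share a
+                        # page and are presumably separated via storage_offset_bytes.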
@@ -3486,8 +3839,7 @@ def _reshape_kv_cache_tensors(
return kv_caches
- def may_reinitialize_input_batch(self,
- kv_cache_config: KVCacheConfig) -> None:
+ def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
@@ -3499,8 +3851,7 @@ def may_reinitialize_input_batch(self,
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
- if not isinstance(kv_cache_group.kv_cache_spec,
- EncoderOnlyAttentionSpec)
+ if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
]
# Generate kernel_block_sizes that matches each block_size
@@ -3509,10 +3860,9 @@ def may_reinitialize_input_batch(self,
# For other backends (like Mamba), use [0] (no splitting)
kernel_block_sizes = []
for kv_cache_group_id, kv_cache_group in enumerate(
- kv_cache_config.kv_cache_groups):
-
- if isinstance(kv_cache_group.kv_cache_spec,
- EncoderOnlyAttentionSpec):
+ kv_cache_config.kv_cache_groups
+ ):
+ if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
continue
elif isinstance(kv_cache_group.kv_cache_spec, AttentionSpec):
# This is an attention backend that supports virtual
@@ -3528,9 +3878,11 @@ def may_reinitialize_input_batch(self,
supported_sizes = backend.get_supported_block_size()
# If no specific sizes supported, use cache config
# block_size
- kernel_block_size_list = (supported_sizes
- if supported_sizes else
- [self.cache_config.block_size])
+ kernel_block_size_list = (
+ supported_sizes
+ if supported_sizes
+ else [self.cache_config.block_size]
+ )
else:
# Fallback to cache config block_size if no backend found
kernel_block_size_list = [self.cache_config.block_size]
@@ -3543,13 +3895,14 @@ def may_reinitialize_input_batch(self,
# to kernel_block_sizes[0]
kernel_block_sizes.append([0])
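+        # Example (illustrative): a full-attention backend advertising block
+        # size 128 presumably contributes [128] here, while a Mamba group
+        # contributes [0] (no kernel-level splitting).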
- if block_sizes != [
- self.cache_config.block_size
- ] or kernel_block_sizes != [[self.cache_config.block_size]]:
+ if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
+ [self.cache_config.block_size]
+ ]:
assert self.cache_config.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
- "for more details.")
+ "for more details."
+ )
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.model_config.max_model_len,
@@ -3563,7 +3916,9 @@ def may_reinitialize_input_batch(self,
is_pooling_model=self.is_pooling_model,
num_speculative_tokens=(
self.vllm_config.speculative_config.num_speculative_tokens
- if self.vllm_config.speculative_config else 0),
+ if self.vllm_config.speculative_config
+ else 0
+ ),
kernel_block_sizes=kernel_block_sizes,
)
@@ -3572,8 +3927,7 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
Add encoder-only layers to the KV cache config.
"""
block_size = self.vllm_config.cache_config.block_size
- encoder_only_attn_specs: dict[AttentionSpec,
- list[str]] = defaultdict(list)
+ encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
@@ -3581,23 +3935,24 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype)
+ dtype=self.kv_cache_dtype,
+ )
encoder_only_attn_specs[attn_spec].append(layer_name)
self.runner_only_attn_layers.add(layer_name)
if len(encoder_only_attn_specs) > 0:
- assert len(
- encoder_only_attn_specs
- ) == 1, "Only support one encoder-only attention spec now"
+ assert len(encoder_only_attn_specs) == 1, (
+ "Only support one encoder-only attention spec now"
+ )
spec, layer_names = encoder_only_attn_specs.popitem()
self.kv_cache_config.kv_cache_groups.append(
- KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+ KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)
+ )
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
"""
- assert len(self.attn_groups) == 0, \
- "Attention backends are already initialized"
+ assert len(self.attn_groups) == 0, "Attention backends are already initialized"
class AttentionGroupKey(NamedTuple):
attn_backend: type[AttentionBackend]
@@ -3607,8 +3962,8 @@ def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]:
layers = get_layers_from_vllm_config(
- self.vllm_config, AttentionLayerBase,
- kv_cache_group_spec.layer_names)
+ self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
+ )
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than
@@ -3621,39 +3976,38 @@ def get_attn_backends_for_group(
full_cls_name = attn_backend.full_cls_name()
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
- layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
- layer_name]
+ layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name]
key = (full_cls_name, layer_kv_cache_spec)
- attn_backends[key] = AttentionGroupKey(attn_backend,
- layer_kv_cache_spec)
+ attn_backends[key] = AttentionGroupKey(
+ attn_backend, layer_kv_cache_spec
+ )
attn_backend_layers[key].append(layer_name)
- return {
- attn_backends[k]: v
- for k, v in attn_backend_layers.items()
- }
+ return {attn_backends[k]: v for k, v in attn_backend_layers.items()}
def create_attn_groups(
attn_backends_map: dict[AttentionBackend, list[str]],
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
- for (attn_backend,
- kv_cache_spec), layer_names in attn_backends_map.items():
+ for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
attn_metadata_builders = []
- attn_metadata_builders.append(attn_backend.get_builder_cls()(
- kv_cache_spec,
- layer_names,
- self.vllm_config,
- self.device,
- ))
- attn_group = AttentionGroup(attn_backend,
- attn_metadata_builders,
- layer_names, kv_cache_spec)
+ attn_metadata_builders.append(
+ attn_backend.get_builder_cls()(
+ kv_cache_spec,
+ layer_names,
+ self.vllm_config,
+ self.device,
+ )
+ )
+ attn_group = AttentionGroup(
+ attn_backend, attn_metadata_builders, layer_names, kv_cache_spec
+ )
attn_groups.append(attn_group)
return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group( # type: ignore
- kv_cache_group_spec)
+ kv_cache_group_spec
+ )
self.attn_groups.append(create_attn_groups(attn_backends))
# Calculate reorder batch threshold (if needed)
@@ -3679,16 +4033,17 @@ def calculate_reorder_batch_threshold(self) -> None:
# check that if any backends reorder batches; that the reordering
# is compatible (e.g., decode threshold is the same)
reorder_batch_threshold_i = (
- attn_metadata_builder_i.reorder_batch_threshold)
+ attn_metadata_builder_i.reorder_batch_threshold
+ )
if reorder_batch_threshold_i is not None:
if self.reorder_batch_threshold is not None:
- if reorder_batch_threshold_i != \
- self.reorder_batch_threshold:
+ if reorder_batch_threshold_i != self.reorder_batch_threshold:
raise ValueError(
f"Attention backend reorders decodes with "
f"threshold {reorder_batch_threshold_i} but other "
f"backend uses threshold "
- f"{self.reorder_batch_threshold}")
+ f"{self.reorder_batch_threshold}"
+ )
else:
self.reorder_batch_threshold = reorder_batch_threshold_i
@@ -3707,8 +4062,7 @@ def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]:
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
- if (kv_tgt_layer :=
- attn_module.kv_sharing_target_layer_name) is not None:
+ if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
@@ -3731,7 +4085,8 @@ def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]:
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
- cache_dtype_str=self.cache_config.cache_dtype)
+ cache_dtype_str=self.cache_config.cache_dtype,
+ )
else:
# TODO(cmq): This is a hacky way to fix the deepseek kvcache when
# using DSA. Fixing the spec in vLLM is the final way.
@@ -3739,31 +4094,36 @@ def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]:
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype)
- elif attn_module.attn_type in (AttentionType.ENCODER,
- AttentionType.ENCODER_ONLY):
+ dtype=self.kv_cache_dtype,
+ )
+ elif attn_module.attn_type in (
+ AttentionType.ENCODER,
+ AttentionType.ENCODER_ONLY,
+ ):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
- raise ValueError(
- f"Unknown attention type: {attn_module.attn_type}")
+ raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0:
- if (self.vllm_config.speculative_config is not None
- and self.vllm_config.model_config.hf_config.model_type
- not in ["qwen3_next"]):
+ if (
+ self.vllm_config.speculative_config is not None
+ and self.vllm_config.model_config.hf_config.model_type
+ not in ["qwen3_next"]
+ ):
raise NotImplementedError(
- "Mamba with speculative decoding is not supported yet.")
+ "Mamba with speculative decoding is not supported yet."
+ )
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
- "Prefix caching is not supported for Mamba yet.")
+ "Prefix caching is not supported for Mamba yet."
+ )
max_model_len = self.vllm_config.model_config.max_model_len
- page_size_padded = (
- self.vllm_config.cache_config.mamba_page_size_padded)
+ page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
@@ -3776,7 +4136,9 @@ def get_kv_cache_spec_v0110(self) -> dict[str, KVCacheSpec]:
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
- if self.speculative_config else 0),
+ if self.speculative_config
+ else 0
+ ),
)
return kv_cache_spec
@@ -3795,12 +4157,12 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
- attn_layers = get_layers_from_vllm_config(self.vllm_config,
- AttentionLayerBase)
+ attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):
- if (kv_tgt_layer :=
- attn_module.kv_sharing_target_layer_name) is not None:
+ if (
+ kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+ ) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
@@ -3819,16 +4181,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype)
- elif attn_module.attn_type in (AttentionType.ENCODER,
- AttentionType.ENCODER_ONLY):
+ dtype=self.kv_cache_dtype,
+ )
+ elif attn_module.attn_type in (
+ AttentionType.ENCODER,
+ AttentionType.ENCODER_ONLY,
+ ):
# encoder-only attention does not need KV cache.
continue
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
raise NotImplementedError
else:
- raise ValueError(
- f"Unknown attention type: {attn_module.attn_type}")
+ raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
elif isinstance(attn_module, MLAAttention):
if use_mla and not self.use_sparse:
@@ -3837,7 +4201,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
- cache_dtype_str=self.cache_config.cache_dtype)
+ cache_dtype_str=self.cache_config.cache_dtype,
+ )
else:
# TODO(cmq): This is a hacky way to fix the deepseek kvcache when
# using DSA. Fixing the spec in vLLM is the final way.
@@ -3845,22 +4210,26 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
- dtype=self.kv_cache_dtype)
+ dtype=self.kv_cache_dtype,
+ )
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0:
- if (self.vllm_config.speculative_config is not None
- and self.vllm_config.model_config.hf_config.model_type
- not in ["qwen3_next"]):
+ if (
+ self.vllm_config.speculative_config is not None
+ and self.vllm_config.model_config.hf_config.model_type
+ not in ["qwen3_next"]
+ ):
raise NotImplementedError(
- "Mamba with speculative decoding is not supported yet.")
+ "Mamba with speculative decoding is not supported yet."
+ )
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
- "Prefix caching is not supported for Mamba yet.")
+ "Prefix caching is not supported for Mamba yet."
+ )
max_model_len = self.vllm_config.model_config.max_model_len
- page_size_padded = (
- self.vllm_config.cache_config.mamba_page_size_padded)
+ page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
@@ -3873,7 +4242,9 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
mamba_type=mamba_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
- if self.speculative_config else 0),
+ if self.speculative_config
+ else 0
+ ),
)
return kv_cache_spec
@@ -3885,7 +4256,7 @@ def initialize_aclgraph_capture(self) -> None:
for attn_group in self._attn_group_iterator():
builder = attn_group.get_metadata_builder()
graph_support = None
- if hasattr(builder, 'aclgraph_support'):
+ if hasattr(builder, "aclgraph_support"):
graph_support = builder.aclgraph_support.value
else:
graph_support = builder.cudagraph_support.value
@@ -3896,68 +4267,93 @@ def initialize_aclgraph_capture(self) -> None:
# This is an imitation of compilation_config.splitting_ops_contain_attention()
splitting_ops_contain_attention = (
self.compilation_config.splitting_ops is not None
- and all(op in self.compilation_config.splitting_ops for op in [
- "vllm.mla_forward",
- ]))
+ and all(
+ op in self.compilation_config.splitting_ops
+ for op in [
+ "vllm.mla_forward",
+ ]
+ )
+ )
# Flexibly resolve the aclgraph mode
aclgraph_mode = self.compilation_config.cudagraph_mode
# check graph for mixed batch is supported
- if aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL \
- and min_ag_support != AttentionCGSupport.ALWAYS:
- msg = (f"ACLGraphMode.{aclgraph_mode.name} is not supported "
- f"with {min_ag_builder_name} backend (support: "
- f"{min_ag_support})")
+ if (
+ aclgraph_mode.mixed_mode() == CUDAGraphMode.FULL
+ and min_ag_support != AttentionCGSupport.ALWAYS
+ ):
+ msg = (
+ f"ACLGraphMode.{aclgraph_mode.name} is not supported "
+ f"with {min_ag_builder_name} backend (support: "
+ f"{min_ag_support})"
+ )
if min_ag_support == AttentionCGSupport.NEVER:
# if full graphs are not supported at all, just raise it.
- msg += "; please try cudagraph_mode=PIECEWISE, and "\
+ msg += (
+ "; please try cudagraph_mode=PIECEWISE, and "
"make sure compilation level is piecewise"
+ )
raise ValueError(msg)
# attempt to resolve the full graph related mode
if splitting_ops_contain_attention:
msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
- CUDAGraphMode.FULL_AND_PIECEWISE)
+ CUDAGraphMode.FULL_AND_PIECEWISE
+ )
else:
msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
aclgraph_mode = self.compilation_config.cudagraph_mode = (
- CUDAGraphMode.FULL_DECODE_ONLY)
+ CUDAGraphMode.FULL_DECODE_ONLY
+ )
logger.warning(msg)
# double check that we can support full graph if they are requested
# even after automatic downgrades
- if aclgraph_mode.has_full_cudagraphs() \
- and min_ag_support == AttentionCGSupport.NEVER:
- raise ValueError(f"CUDAGraphMode.{aclgraph_mode.name} is not "
- f"supported with {min_ag_builder_name} backend ("
- f"support:{min_ag_support}) "
- "; please try cudagraph_mode=PIECEWISE, "
- "and make sure compilation level is piecewise")
+ if (
+ aclgraph_mode.has_full_cudagraphs()
+ and min_ag_support == AttentionCGSupport.NEVER
+ ):
+ raise ValueError(
+ f"CUDAGraphMode.{aclgraph_mode.name} is not "
+ f"supported with {min_ag_builder_name} backend ("
+ f"support:{min_ag_support}) "
+ "; please try cudagraph_mode=PIECEWISE, "
+ "and make sure compilation level is piecewise"
+ )
self.aclgraph_dispatcher.initialize_cudagraph_keys(
- self.compilation_config.cudagraph_mode,
- self.uniform_decode_query_len)
+ self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
+ )
- def _capture_aclgraphs(self, compilation_cases: list[int],
- aclgraph_runtime_mode: CUDAGraphMode,
- uniform_decode: bool):
- assert aclgraph_runtime_mode != CUDAGraphMode.NONE and \
- aclgraph_runtime_mode in [CUDAGraphMode.FULL,
- CUDAGraphMode.PIECEWISE]
+ def _capture_aclgraphs(
+ self,
+ compilation_cases: list[int],
+ aclgraph_runtime_mode: CUDAGraphMode,
+ uniform_decode: bool,
+ ):
+ assert (
+ aclgraph_runtime_mode != CUDAGraphMode.NONE
+ and aclgraph_runtime_mode in [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE]
+ )
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
logger.info(
"Starting to capture ACL graphs for cases: %s, "
- "mode: %s, uniform_decode: %s", compilation_cases,
- aclgraph_runtime_mode.name, uniform_decode)
+ "mode: %s, uniform_decode: %s",
+ compilation_cases,
+ aclgraph_runtime_mode.name,
+ uniform_decode,
+ )
compilation_cases = tqdm(
compilation_cases,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing ACL graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
- aclgraph_runtime_mode.name))
+ aclgraph_runtime_mode.name,
+ ),
+ )
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens in compilation_cases:
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
@@ -3966,21 +4362,26 @@ def _capture_aclgraphs(self, compilation_cases: list[int],
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
- force_attention = (aclgraph_runtime_mode == CUDAGraphMode.FULL)
- self._dummy_run(num_tokens,
- aclgraph_runtime_mode=CUDAGraphMode.NONE,
- force_attention=force_attention,
- uniform_decode=uniform_decode)
- self._dummy_run(num_tokens,
- aclgraph_runtime_mode=aclgraph_runtime_mode,
- force_attention=force_attention,
- uniform_decode=uniform_decode)
+ force_attention = aclgraph_runtime_mode == CUDAGraphMode.FULL
+ self._dummy_run(
+ num_tokens,
+ aclgraph_runtime_mode=CUDAGraphMode.NONE,
+ force_attention=force_attention,
+ uniform_decode=uniform_decode,
+ )
+ self._dummy_run(
+ num_tokens,
+ aclgraph_runtime_mode=aclgraph_runtime_mode,
+ force_attention=force_attention,
+ uniform_decode=uniform_decode,
+ )
def _capture_model(self):
if not self.use_aclgraph:
logger.warning(
"Skipping ACL graph capture. To turn on ACL graph capture, "
- "ensure `aclraph_mode` was not manually set to `NONE`")
+ "ensure `aclraph_mode` was not manually set to `NONE`"
+ )
return
else:
self.initialize_aclgraph_capture()
@@ -4000,11 +4401,12 @@ def _capture_model(self):
self._capture_aclgraphs(
compilation_cases,
aclgraph_runtime_mode=aclgraph_runtime_mode,
- uniform_decode=False)
+ uniform_decode=False,
+ )
except Exception as e:
error_msg = str(e)
- error_code = '0x7020023'
- pattern = r'retCode=([^,\s\.]+)'
+ error_code = "0x7020023"
+ pattern = r"retCode=([^,\s\.]+)"
match = re.search(pattern, error_msg)
if match:
retCode = match.group(1)
@@ -4018,21 +4420,28 @@ def _capture_model(self):
"1. Manually configure the compilation_config parameter "
"with a reduced set of sizes: '{\"cudagraph_capture_sizes\":[size1, size2, size3, ...]}'.\n"
"2. Utilize ACLgraph's full graph mode as an alternative to the piece-wise approach.\n\n"
- f"{str(e)}")
+ f"{str(e)}"
+ )
raise
- if aclgraph_mode.decode_mode() == CUDAGraphMode.FULL and \
- aclgraph_mode.separate_routine():
- max_num_tokens = self.scheduler_config.max_num_seqs * \
- self.uniform_decode_query_len
+ if (
+ aclgraph_mode.decode_mode() == CUDAGraphMode.FULL
+ and aclgraph_mode.separate_routine()
+ ):
+ max_num_tokens = (
+ self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
+ )
decode_cudagraph_batch_sizes = [
- x for x in self.aclgraph_batch_sizes if x <= max_num_tokens
- and x >= self.uniform_decode_query_len
+ x
+ for x in self.aclgraph_batch_sizes
+ if x <= max_num_tokens and x >= self.uniform_decode_query_len
]
compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
# TODO: refactor this when vLLM supports mtp>1
- if not all(x % self.uniform_decode_query_len == 0
- for x in decode_cudagraph_batch_sizes):
+ if not all(
+ x % self.uniform_decode_query_len == 0
+ for x in decode_cudagraph_batch_sizes
+ ):
raise ValueError(
"In the MTP fullgraph scenario, each graph size must be an integer multiple of "
f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
@@ -4043,7 +4452,8 @@ def _capture_model(self):
self._capture_aclgraphs(
compilation_cases=compilation_cases_decode,
aclgraph_runtime_mode=CUDAGraphMode.FULL,
- uniform_decode=True)
+ uniform_decode=True,
+ )
# Disable aclgraph capturing globally, so any unexpected aclgraph
# capturing will be detected and raise an error after here.
@@ -4053,7 +4463,6 @@ def _capture_model(self):
set_cudagraph_capturing_enabled(False)
def capture_model(self) -> None:
-
compilation_counter.num_gpu_runner_capture_triggers += 1
start_time = time.perf_counter()
@@ -4066,8 +4475,11 @@ def capture_model(self) -> None:
elapsed_time = end_time - start_time
npu_graph_size = start_free_npu_memory - end_free_npu_memory
# This usually takes 5~20 seconds.
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
- elapsed_time, npu_graph_size / (1 << 30))
+ logger.info(
+ "Graph capturing finished in %.0f secs, took %.2f GiB",
+ elapsed_time,
+ npu_graph_size / (1 << 30),
+ )
def _get_prompt_logprobs_dict(
self,
@@ -4085,7 +4497,6 @@ def _get_prompt_logprobs_dict(
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
-
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# Get metadata for this request.
@@ -4095,7 +4506,8 @@ def _get_prompt_logprobs_dict(
continue
num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
- self.device, non_blocking=True)
+ self.device, non_blocking=True
+ )
# Set up target LogprobsTensors object.
logprobs_tensors = in_progress_dict.get(req_id)
@@ -4103,7 +4515,8 @@ def _get_prompt_logprobs_dict(
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprobs_tensors = LogprobsTensors.empty_cpu(
- num_prompt_tokens - 1, num_prompt_logprobs + 1)
+ num_prompt_tokens - 1, num_prompt_logprobs + 1
+ )
in_progress_dict[req_id] = logprobs_tensors
# Determine number of logits to retrieve.
@@ -4133,27 +4546,29 @@ def _get_prompt_logprobs_dict(
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc_np[req_idx].item()
- prompt_hidden_states = hidden_states[offset:offset + num_logits]
+ prompt_hidden_states = hidden_states[offset : offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states)
# Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want
# to gather the logprob for.
- tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+ tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits]
# Compute prompt logprobs.
logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
- logprobs, num_prompt_logprobs, tgt_token_ids)
+ logprobs, num_prompt_logprobs, tgt_token_ids
+ )
# Transfer NPU->CPU async.
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
- token_ids, non_blocking=True)
- logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
- non_blocking=True)
+ token_ids, non_blocking=True
+ )
+ logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
- ranks, non_blocking=True)
+ ranks, non_blocking=True
+ )
# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
@@ -4180,46 +4595,50 @@ def _build_drafter_prepare_inputs_torchair_param(self):
def _update_tokens_for_pcp(self, tokens):
num_reqs = self.input_batch.num_reqs
self.num_pcp_pads = self.num_pcp_pads[:num_reqs]
- if not self.pcp_size > 1:
- return tokens, None, None
tokens = np.array(tokens, dtype=np.int32)
num_decode_reqs = sum(
- self.input_batch.num_computed_tokens_cpu[:num_reqs] >=
- self.input_batch.num_prompt_tokens[:num_reqs])
- num_padded_scheduled_tokens = np.ceil(
- tokens /
- (2 * self.pcp_size)).astype(np.int32) * (2 * self.pcp_size)
+ self.input_batch.num_computed_tokens_cpu[:num_reqs]
+ >= self.input_batch.num_prompt_tokens[:num_reqs]
+ )
+ num_padded_scheduled_tokens = np.ceil(tokens / (2 * self.pcp_size)).astype(
+ np.int32
+ ) * (2 * self.pcp_size)
num_padded_scheduled_tokens[:num_decode_reqs] = self.pcp_size
self.num_pcp_pads = num_padded_scheduled_tokens - tokens
- cu_padded_tokens, pcp_padded_arange = \
- self._get_cumsum_and_arange(num_padded_scheduled_tokens)
+ cu_padded_tokens, pcp_padded_arange = self._get_cumsum_and_arange(
+ num_padded_scheduled_tokens
+ )
unpad_mask = torch.from_numpy(
- pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens))
+ pcp_padded_arange < np.repeat(tokens, num_padded_scheduled_tokens)
+ )
pcp_tokens = num_padded_scheduled_tokens // self.pcp_size
pcp_chunk_sizes = (pcp_tokens // 2).clip(min=1)
_, pcp_arange = self._get_cumsum_and_arange(pcp_tokens)
_, pcp_chunk_arange = self._get_cumsum_and_arange(pcp_chunk_sizes)
- pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes,
- pcp_tokens)
+ pcp_head_chunk_mask = pcp_arange < np.repeat(pcp_chunk_sizes, pcp_tokens)
def get_current_rank_positions(cu_tokens, rank):
positions_start_loc = np.zeros_like(cu_tokens)
positions_start_loc[1:] = cu_tokens[:-1]
positions = np.zeros(len(pcp_head_chunk_mask), dtype=np.int32)
head_start_loc = positions_start_loc + rank * pcp_chunk_sizes
- tail_start_loc = positions_start_loc + \
- (2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
- positions[pcp_head_chunk_mask] = pcp_chunk_arange + \
- np.repeat(head_start_loc, pcp_chunk_sizes)
+ tail_start_loc = (
+ positions_start_loc + (2 * self.pcp_size - rank - 1) * pcp_chunk_sizes
+ )
+ positions[pcp_head_chunk_mask] = pcp_chunk_arange + np.repeat(
+ head_start_loc, pcp_chunk_sizes
+ )
# Decode reqs do not have tail chunks.
- positions[~pcp_head_chunk_mask] = \
- pcp_chunk_arange[num_decode_reqs:] + \
- np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:]
+ positions[~pcp_head_chunk_mask] = (
+ pcp_chunk_arange[num_decode_reqs:]
+ + np.repeat(tail_start_loc, pcp_chunk_sizes)[num_decode_reqs:]
+ )
return positions
positions = get_current_rank_positions(
- np.zeros(num_reqs, dtype=np.int32), self.pcp_rank)
+ np.zeros(num_reqs, dtype=np.int32), self.pcp_rank
+ )
# Decode tokens are duplicated and their positions are always 0.
positions[:num_decode_reqs] = 0
@@ -4228,8 +4647,9 @@ def get_current_rank_positions(cu_tokens, rank):
for rank_i in range(self.pcp_size)
]
all_positions_tensor = torch.from_numpy(np.concatenate(all_positions))
- self.pcp_allgather_restore_idx[:all_positions_tensor.shape[0]].copy_(
- all_positions_tensor.float().argsort().long(), non_blocking=True)
+ self.pcp_allgather_restore_idx[: all_positions_tensor.shape[0]].copy_(
+ all_positions_tensor.float().argsort().long(), non_blocking=True
+ )
pcp_tokens[:num_decode_reqs] = 1
return pcp_tokens, positions, unpad_mask
@@ -4246,11 +4666,17 @@ def _get_pcp_local_seq_lens(
num_requests = seq_lens.size(0)
total_world_size = pcp_world_size * dcp_world_size
seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
- rank_offsets = (torch.arange(total_world_size,
- dtype=torch.int32).unsqueeze(0).repeat(
- num_requests, 1))
- base = (seq_lens_tiled // cp_kv_cache_interleave_size //
- total_world_size * cp_kv_cache_interleave_size)
+ rank_offsets = (
+ torch.arange(total_world_size, dtype=torch.int32)
+ .unsqueeze(0)
+ .repeat(num_requests, 1)
+ )
+ base = (
+ seq_lens_tiled
+ // cp_kv_cache_interleave_size
+ // total_world_size
+ * cp_kv_cache_interleave_size
+ )
remainder = seq_lens_tiled - base * total_world_size
remainder = torch.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size,
@@ -4258,13 +4684,16 @@ def _get_pcp_local_seq_lens(
cp_kv_cache_interleave_size,
)
dcp_local_seq_lens = (base + remainder).reshape(
- [-1, pcp_world_size, dcp_world_size])
+ [-1, pcp_world_size, dcp_world_size]
+ )
return dcp_local_seq_lens
def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
num_reqs = self.input_batch.num_reqs
- num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
- >= self.input_batch.num_prompt_tokens[:num_reqs])
+ num_decodes = sum(
+ self.input_batch.num_computed_tokens_cpu[:num_reqs]
+ >= self.input_batch.num_prompt_tokens[:num_reqs]
+ )
num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
self.num_actual_tokens_pcp_padded = num_actual_tokens_pcp_padded
long_seq_metadata = None
@@ -4294,47 +4723,58 @@ def _generate_pcp_metadata(self, total_num_scheduled_tokens, seq_lens):
chunk_len = seq_len // 2
chunk_seqlens.append(chunk_len)
q_head_idx.extend(
- list(range(q_req_offset, q_req_offset + chunk_len)))
+ list(range(q_req_offset, q_req_offset + chunk_len))
+ )
kv_with_q_head_nomask_idx.extend(
list(
- range(kv_req_offset, kv_req_offset +
- chunk_len * q_head_chunk_id)))
+ range(
+ kv_req_offset,
+ kv_req_offset + chunk_len * q_head_chunk_id,
+ )
+ )
+ )
kv_with_q_head_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_head_chunk_id,
- kv_req_offset + chunk_len *
- (q_head_chunk_id + 1))))
- kv_with_q_head_nomask_seqlens.append(chunk_len *
- q_head_chunk_id)
+ kv_req_offset + chunk_len * (q_head_chunk_id + 1),
+ )
+ )
+ )
+ kv_with_q_head_nomask_seqlens.append(chunk_len * q_head_chunk_id)
q_tail_idx.extend(
list(
- range(q_req_offset + chunk_len,
- q_req_offset + chunk_len * 2)))
+ range(
+ q_req_offset + chunk_len, q_req_offset + chunk_len * 2
+ )
+ )
+ )
kv_with_q_tail_nomask_idx.extend(
list(
- range(kv_req_offset, kv_req_offset +
- chunk_len * q_tail_chunk_id)))
+ range(
+ kv_req_offset,
+ kv_req_offset + chunk_len * q_tail_chunk_id,
+ )
+ )
+ )
kv_with_q_tail_mask_idx.extend(
list(
range(
kv_req_offset + chunk_len * q_tail_chunk_id,
- kv_req_offset + chunk_len *
- (q_tail_chunk_id + 1))))
- kv_with_q_tail_nomask_seqlens.append(chunk_len *
- q_tail_chunk_id)
+ kv_req_offset + chunk_len * (q_tail_chunk_id + 1),
+ )
+ )
+ )
+ kv_with_q_tail_nomask_seqlens.append(chunk_len * q_tail_chunk_id)
q_req_offset += seq_len
kv_req_offset += seq_len * self.pcp_size
# Convert lists to tensors and move to device
def _list_to_tensor(lst, device, dtype=torch.int32):
- tensor_npu = torch.zeros(len(lst),
- dtype=dtype,
- device=device)
- tensor_npu.copy_(torch.tensor(lst, dtype=dtype),
- non_blocking=True)
+ tensor_npu = torch.zeros(len(lst), dtype=dtype, device=device)
+ tensor_npu.copy_(torch.tensor(lst, dtype=dtype), non_blocking=True)
return tensor_npu
q_head_idx_tensor = _list_to_tensor(q_head_idx, self.device)
@@ -4343,70 +4783,76 @@ def _list_to_tensor(lst, device, dtype=torch.int32):
self.q_tail_idx_tensor = q_tail_idx_tensor
q_full_idx = torch.cat([q_head_idx_tensor, q_tail_idx_tensor])
- q_full_idx = q_full_idx.to(torch.float32).argsort().to(
- torch.int32)
+ q_full_idx = q_full_idx.to(torch.float32).argsort().to(torch.int32)
self.q_full_idx = q_full_idx
self.kv_idx_names = {
- 'kv_with_q_head_nomask_idx_tensor':
- kv_with_q_head_nomask_idx,
- 'kv_with_q_head_mask_idx_tensor': kv_with_q_head_mask_idx,
- 'kv_with_q_tail_nomask_idx_tensor':
- kv_with_q_tail_nomask_idx,
- 'kv_with_q_tail_mask_idx_tensor': kv_with_q_tail_mask_idx
+ "kv_with_q_head_nomask_idx_tensor": kv_with_q_head_nomask_idx,
+ "kv_with_q_head_mask_idx_tensor": kv_with_q_head_mask_idx,
+ "kv_with_q_tail_nomask_idx_tensor": kv_with_q_tail_nomask_idx,
+ "kv_with_q_tail_mask_idx_tensor": kv_with_q_tail_mask_idx,
}
for key, value in self.kv_idx_names.items():
tensor_npu = _list_to_tensor(value, self.device)
self.kv_idx_names[key] = tensor_npu
attn_mask_seqlens = torch.tensor(
- [chunk_seqlens, chunk_seqlens], dtype=torch.int32)
+ [chunk_seqlens, chunk_seqlens], dtype=torch.int32
+ )
head_attn_nomask_seqlens = torch.tensor(
- [chunk_seqlens, kv_with_q_head_nomask_seqlens],
- dtype=torch.int32)
+ [chunk_seqlens, kv_with_q_head_nomask_seqlens], dtype=torch.int32
+ )
tail_attn_nomask_seqlens = torch.tensor(
- [chunk_seqlens, kv_with_q_tail_nomask_seqlens],
- dtype=torch.int32)
+ [chunk_seqlens, kv_with_q_tail_nomask_seqlens], dtype=torch.int32
+ )
if self.vllm_config.model_config.use_mla:
pcp_prefill_mask = torch.triu(
- torch.ones(512,
- 512,
- device=self.device,
- dtype=self.dtype), 1)
+ torch.ones(512, 512, device=self.device, dtype=self.dtype), 1
+ )
else:
pcp_prefill_mask = torch.triu(
- torch.full((2048, 2048),
- True,
- device=self.device,
- dtype=torch.bool), 1)
+ torch.full(
+ (2048, 2048), True, device=self.device, dtype=torch.bool
+ ),
+ 1,
+ )
self.extra_long_seq_kwargs = {
- 'attn_mask_seqlens': attn_mask_seqlens,
- 'head_attn_nomask_seqlens': head_attn_nomask_seqlens,
- 'tail_attn_nomask_seqlens': tail_attn_nomask_seqlens,
- 'pcp_prefill_mask': pcp_prefill_mask
+ "attn_mask_seqlens": attn_mask_seqlens,
+ "head_attn_nomask_seqlens": head_attn_nomask_seqlens,
+ "tail_attn_nomask_seqlens": tail_attn_nomask_seqlens,
+ "pcp_prefill_mask": pcp_prefill_mask,
}
- long_seq_metadata.pcp_allgather_restore_idx = self.pcp_allgather_restore_idx[:
- num_actual_tokens_pcp_padded]
+ long_seq_metadata.pcp_allgather_restore_idx = (
+ self.pcp_allgather_restore_idx[:num_actual_tokens_pcp_padded]
+ )
long_seq_metadata.q_head_idx_tensor = self.q_head_idx_tensor
long_seq_metadata.q_tail_idx_tensor = self.q_tail_idx_tensor
long_seq_metadata.q_full_idx = self.q_full_idx
long_seq_metadata.kv_with_q_head_nomask_idx_tensor = self.kv_idx_names[
- 'kv_with_q_head_nomask_idx_tensor']
+ "kv_with_q_head_nomask_idx_tensor"
+ ]
long_seq_metadata.kv_with_q_head_mask_idx_tensor = self.kv_idx_names[
- 'kv_with_q_head_mask_idx_tensor']
+ "kv_with_q_head_mask_idx_tensor"
+ ]
long_seq_metadata.kv_with_q_tail_nomask_idx_tensor = self.kv_idx_names[
- 'kv_with_q_tail_nomask_idx_tensor']
+ "kv_with_q_tail_nomask_idx_tensor"
+ ]
long_seq_metadata.kv_with_q_tail_mask_idx_tensor = self.kv_idx_names[
- 'kv_with_q_tail_mask_idx_tensor']
+ "kv_with_q_tail_mask_idx_tensor"
+ ]
long_seq_metadata.attn_mask_seqlens = self.extra_long_seq_kwargs[
- 'attn_mask_seqlens']
+ "attn_mask_seqlens"
+ ]
long_seq_metadata.head_attn_nomask_seqlens = self.extra_long_seq_kwargs[
- 'head_attn_nomask_seqlens']
+ "head_attn_nomask_seqlens"
+ ]
long_seq_metadata.tail_attn_nomask_seqlens = self.extra_long_seq_kwargs[
- 'tail_attn_nomask_seqlens']
+ "tail_attn_nomask_seqlens"
+ ]
long_seq_metadata.pcp_prefill_mask = self.extra_long_seq_kwargs[
- 'pcp_prefill_mask']
+ "pcp_prefill_mask"
+ ]
self.long_seq_metadata = long_seq_metadata
return long_seq_metadata
@@ -4425,28 +4871,36 @@ def _generate_pcp_mtp_input(
num_scheduled_tokens_pcp_full = np.empty(num_reqs, dtype=np.int32)
for i, req_id in enumerate(self.input_batch.req_ids):
num_scheduled_tokens_pcp_full[i] = num_scheduled_tokens[req_id]
- req_indices_pcp_full = np.repeat(self.arange_np[:num_reqs],
- num_scheduled_tokens_pcp_full)
+ req_indices_pcp_full = np.repeat(
+ self.arange_np[:num_reqs], num_scheduled_tokens_pcp_full
+ )
cu_num_tokens_pcp_full = np.cumsum(num_scheduled_tokens_pcp_full)
self.query_start_loc_pcp_full_np[0] = 0
- self.query_start_loc_pcp_full_np[1:num_reqs +
- 1] = cu_num_tokens_pcp_full
- self.query_start_loc_pcp_full_np[num_reqs + 1:].fill(-1)
+ self.query_start_loc_pcp_full_np[1 : num_reqs + 1] = cu_num_tokens_pcp_full
+ self.query_start_loc_pcp_full_np[num_reqs + 1 :].fill(-1)
cumsums_offsets_pcp_full = np.repeat(
cu_num_tokens_pcp_full - num_scheduled_tokens_pcp_full,
- num_scheduled_tokens_pcp_full)
- arange_pcp_full = self.arange_np[:
- total_num_scheduled_tokens_pcp_full] - cumsums_offsets_pcp_full
- positions_np_pcp_full = self.positions_np_pcp_full[:
- total_num_scheduled_tokens_pcp_full]
- np.add(self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
- arange_pcp_full,
- out=positions_np_pcp_full)
+ num_scheduled_tokens_pcp_full,
+ )
+ arange_pcp_full = (
+ self.arange_np[:total_num_scheduled_tokens_pcp_full]
+ - cumsums_offsets_pcp_full
+ )
+ positions_np_pcp_full = self.positions_np_pcp_full[
+ :total_num_scheduled_tokens_pcp_full
+ ]
+ np.add(
+ self.input_batch.num_computed_tokens_cpu[req_indices_pcp_full],
+ arange_pcp_full,
+ out=positions_np_pcp_full,
+ )
token_indices_pcp_full = (
- positions_np_pcp_full +
- req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1])
+ positions_np_pcp_full
+ + req_indices_pcp_full * self.input_batch.token_ids_cpu.shape[1]
+ )
torch.index_select(
self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices_pcp_full),
- out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full])
+ out=self.input_ids_pcp_full[:total_num_scheduled_tokens_pcp_full],
+ )
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index e8729925fa..a90883cdcb 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -47,7 +47,7 @@
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import (init_ascend_soc_version,
+from vllm_ascend.utils import (init_ascend_soc_version, is_enable_nz,
prefill_context_parallel_enable,
register_ascend_customop, sleep_mode_enabled,
try_register_lib, vllm_version_is)
@@ -184,6 +184,11 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
raise ValueError(
"Sleep mode is not enabled. Please compile vllm-ascend with COMPILE_CUSTOM_KERNELS=1."
)
+
+ if is_enable_nz():
+ raise ValueError(
+ "FRACTAL_NZ mode is enabled. This may cause model parameter precision issues "
+ "in the RL scenarios. Please set VLLM_ASCEND_ENABLE_NZ=0.")
allocator = CaMemAllocator.get_instance()
allocator.wake_up(tags=tags)