diff --git a/.github/workflows/_e2e_nightly_multi_node.yaml b/.github/workflows/_e2e_nightly_multi_node.yaml index 3c508e9a8f8..6ceb9332367 100644 --- a/.github/workflows/_e2e_nightly_multi_node.yaml +++ b/.github/workflows/_e2e_nightly_multi_node.yaml @@ -286,7 +286,7 @@ jobs: - name: Upload logs if: always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ${{ inputs.config_file_path }}-pod-logs path: /tmp/vllm*_logs.txt diff --git a/.github/workflows/_pr_image_build.yaml b/.github/workflows/_pr_image_build.yaml index 89f551a4b74..f154877d619 100644 --- a/.github/workflows/_pr_image_build.yaml +++ b/.github/workflows/_pr_image_build.yaml @@ -103,14 +103,14 @@ jobs: uses: actions/checkout@v6 - name: Download arm64 digests - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: ${{ runner.temp }}/digests pattern: digests-${{ inputs.suffix }}-arm64 merge-multiple: true - name: Download amd64 digests - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: ${{ runner.temp }}/digests pattern: digests-${{ inputs.suffix }}-amd64 diff --git a/.github/workflows/bot_pr_create.yaml b/.github/workflows/bot_pr_create.yaml index 6ba035cf64c..3603e89af55 100644 --- a/.github/workflows/bot_pr_create.yaml +++ b/.github/workflows/bot_pr_create.yaml @@ -34,7 +34,7 @@ jobs: steps: - name: Get vLLM version run: | - VLLM_COMMIT=7157596103666ee7ccb7008acee8bff8a8ff1731 + VLLM_COMMIT=6ef770df7c3f0d135c2f3a594c461949113aae91 echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV - name: Checkout repository diff --git a/.github/workflows/nightly_test_a3.yaml b/.github/workflows/nightly_test_a3.yaml index a45571028f9..d9f79558c3f 100644 --- a/.github/workflows/nightly_test_a3.yaml +++ b/.github/workflows/nightly_test_a3.yaml @@ -68,6 +68,12 @@ jobs: - name: multi-node-qwenw8a8-2node-eplb config_file_path: Qwen3-235B-W8A8-EPLB.yaml size: 2 + - name: multi-node-deepseek-r1-w8a8-longseq + config_file_path: DeepSeek-R1-W8A8-longseq.yaml + size: 2 + - name: multi-node-qwenw8a8-2node-longseq + config_file_path: Qwen3-235B-W8A8-longseq.yaml + size: 2 uses: ./.github/workflows/_e2e_nightly_multi_node.yaml with: soc_version: a3 diff --git a/.github/workflows/pr_tag_image_build_and_push.yaml b/.github/workflows/pr_tag_image_build_and_push.yaml index a417fb2c15d..a20e989e970 100644 --- a/.github/workflows/pr_tag_image_build_and_push.yaml +++ b/.github/workflows/pr_tag_image_build_and_push.yaml @@ -26,6 +26,8 @@ on: - 'cmake/**' - 'CMakeLists.txt' - 'csrc/**' + # We should also trigger image build when nightly test related files are changed to ensure the image is valid for nightly tests + - 'tests/e2e/nightly/' types: [ labeled ] push: # Publish image when tagging, the Dockerfile in tag will be build as tag image diff --git a/.github/workflows/pr_test_full.yaml b/.github/workflows/pr_test_full.yaml index 7a8d6b4455c..6a09207d8b5 100644 --- a/.github/workflows/pr_test_full.yaml +++ b/.github/workflows/pr_test_full.yaml @@ -74,7 +74,7 @@ jobs: name: e2e-full strategy: matrix: - vllm_version: [7157596103666ee7ccb7008acee8bff8a8ff1731, v0.13.0] + vllm_version: [6ef770df7c3f0d135c2f3a594c461949113aae91, v0.13.0] needs: [changes] if: ${{ needs.changes.outputs.e2e_tracker == 'true' }} uses: ./.github/workflows/_e2e_test.yaml diff --git a/.github/workflows/pr_test_light.yaml b/.github/workflows/pr_test_light.yaml index 0232f5c5d04..164d173edaf 100644 --- a/.github/workflows/pr_test_light.yaml +++ b/.github/workflows/pr_test_light.yaml @@ -42,7 +42,7 @@ jobs: lint: uses: ./.github/workflows/_pre_commit.yml with: - vllm: 7157596103666ee7ccb7008acee8bff8a8ff1731 + vllm: 6ef770df7c3f0d135c2f3a594c461949113aae91 changes: runs-on: linux-aarch64-a2-0 outputs: @@ -90,7 +90,7 @@ jobs: SOC_VERSION: ascend910b1 strategy: matrix: - vllm_version: [7157596103666ee7ccb7008acee8bff8a8ff1731, v0.13.0] + vllm_version: [6ef770df7c3f0d135c2f3a594c461949113aae91, v0.13.0] steps: - name: Free up disk space @@ -163,7 +163,7 @@ jobs: name: e2e-light strategy: matrix: - vllm_version: [7157596103666ee7ccb7008acee8bff8a8ff1731, v0.13.0] + vllm_version: [6ef770df7c3f0d135c2f3a594c461949113aae91, v0.13.0] # Note (yikun): If CI resource are limited we can split job into two chain jobs needs: [lint, changes] # only trigger e2e test after lint passed and the change is e2e related with pull request. diff --git a/docs/source/community/versioning_policy.md b/docs/source/community/versioning_policy.md index 211867da30b..e92a2794feb 100644 --- a/docs/source/community/versioning_policy.md +++ b/docs/source/community/versioning_policy.md @@ -51,7 +51,7 @@ If you're using v0.7.3, don't forget to install [mindie-turbo](https://pypi.org/ For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly. | vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu | |-------------|--------------|------------------|-------------|--------------------| -| main | 7157596103666ee7ccb7008acee8bff8a8ff1731, v0.13.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 | +| main | 6ef770df7c3f0d135c2f3a594c461949113aae91, v0.13.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 | ## Release cadence diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index f8f398d6ac8..f62f8b18e2b 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -48,6 +48,7 @@ The following table lists additional configuration options available in vLLM Asc | `num_wait_worker_iterations` | int | `30` | The forward iterations when the EPLB worker will finish CPU tasks. In our test default value 30 can cover most cases. | | `expert_map_record_path` | str | `None` | Save the expert load calculation results to a new expert table in the specified directory. | | `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. | +| `enable_kv_nz` | bool | `False` | Whether to enable kvcache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). | The details of each configuration option are as follows: @@ -105,7 +106,8 @@ An example of additional configuration is as follows: "embedding_tensor_parallel_size": 8, "mlp_tensor_parallel_size": 8, }, + "enable_kv_nz": False, "multistream_overlap_shared_expert": True, - "refresh": False, + "refresh": False } ``` diff --git a/tests/e2e/multicard/test_aclgraph_capture_replay.py b/tests/e2e/multicard/test_aclgraph_capture_replay.py index 38a931fbbce..bb919d821a3 100644 --- a/tests/e2e/multicard/test_aclgraph_capture_replay.py +++ b/tests/e2e/multicard/test_aclgraph_capture_replay.py @@ -28,7 +28,9 @@ from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type MODELS = [ - "Qwen/Qwen3-0.6B", + # wjunlu: Offline data parallel mode will be not supported/useful for dense models + # see `https://github.com/vllm-project/vllm/pull/30739`` + # "Qwen/Qwen3-0.6B", "vllm-ascend/DeepSeek-V2-Lite-W8A8", ] @@ -153,7 +155,7 @@ def test_models_aclgraph_capture_replay_metrics_dp2( "hidden_layers": multiprocessing.Value("i", -1), } - dp_size = 2 + dp_size = 2 if "DeepSeek" in model else 1 port = get_open_port() # Launch workers diff --git a/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-longseq.yaml b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-longseq.yaml new file mode 100644 index 00000000000..bc88aaaa075 --- /dev/null +++ b/tests/e2e/nightly/multi_node/config/models/DeepSeek-R1-W8A8-longseq.yaml @@ -0,0 +1,109 @@ +test_name: "test DeepSeek-R1-W8A8-longseq disaggregated_prefill" +model: "vllm-ascend/DeepSeek-R1-0528-W8A8" +num_nodes: 2 +npu_per_node: 16 +env_common: + VLLM_USE_MODELSCOPE: true + HCCL_BUFFSIZE: 1024 + SERVER_PORT: 8080 + OMP_PROC_BIND: false + OMP_NUM_THREADS: 10 + PYTORCH_NPU_ALLOC_CONF: expandable_segments:True + HCCL_DETERMINISTIC: True + TASK_QUEUE_ENABLE: 1 + HCCL_OP_RETRY_ENABLE: "L0:0, L1:0" + +disaggregated_prefill: + enabled: true + prefiller_host_index: [0] + decoder_host_index: [1] + +deployment: + - + server_cmd: > + vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 + --host 0.0.0.0 + --port $SERVER_PORT + --data-parallel-size 1 + --decode-context-parallel-size 8 + --prefill-context-parallel-size 2 + --tensor-parallel-size 8 + --cp-kv-cache-interleave-size 128 + --enforce-eager + --enable-expert-parallel + --seed 1024 + --quantization ascend + --max-num-seqs 4 + --max-model-len 32768 + --max-num-batched-tokens 16384 + --trust-remote-code + --gpu-memory-utilization 0.9 + --enable-chunked-prefill + --speculative-config '{"num_speculative_tokens": 3, "method":"mtp"}' + --kv-transfer-config + '{"kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_producer", + "kv_port": "30000", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 8 + }, + "decode": { + "dp_size": 2, + "tp_size": 8 + } + } + }' + + - + server_cmd: > + vllm serve vllm-ascend/DeepSeek-R1-0528-W8A8 + --host 0.0.0.0 + --port $SERVER_PORT + --data-parallel-size 2 + --decode-context-parallel-size 2 + --prefill-context-parallel-size 1 + --tensor-parallel-size 8 + --cp-kv-cache-interleave-size 128 + --enable-expert-parallel + --seed 1024 + --quantization ascend + --max-num-seqs 4 + --max-model-len 32768 + --max-num-batched-tokens 256 + --trust-remote-code + --gpu-memory-utilization 0.9 + --compilation_config '{"cudagraph_capture_sizes":[4,8,12,16],"cudagraph_mode": "FULL_DECODE_ONLY"}' + --enable-chunked-prefill + --speculative-config '{"num_speculative_tokens": 3, "method":"mtp"}' + --kv-transfer-config + '{"kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_consumer", + "kv_port": "30100", + "engine_id": "1", + "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 8 + }, + "decode": { + "dp_size": 2, + "tp_size": 8 + } + } + }' + +benchmarks: + acc: + case_type: accuracy + dataset_path: vllm-ascend/gsm8k + request_conf: vllm_api_general_chat + dataset_conf: gsm8k/gsm8k_gen_0_shot_cot_chat_prompt + max_out_len: 32768 + batch_size: 512 + baseline: 95 + threshold: 5 diff --git a/tests/e2e/nightly/multi_node/config/models/Qwen3-235B-W8A8-longseq.yaml b/tests/e2e/nightly/multi_node/config/models/Qwen3-235B-W8A8-longseq.yaml new file mode 100644 index 00000000000..ad476174461 --- /dev/null +++ b/tests/e2e/nightly/multi_node/config/models/Qwen3-235B-W8A8-longseq.yaml @@ -0,0 +1,93 @@ +test_name: "test Qwen3-235B-A22B-W8A8-longseq disaggregated_prefill" +model: "vllm-ascend/Qwen3-235B-A22B-W8A8" +num_nodes: 2 +npu_per_node: 16 +env_common: + VLLM_USE_MODELSCOPE: true + OMP_PROC_BIND: false + OMP_NUM_THREADS: 100 + HCCL_BUFFSIZE: 1024 + SERVER_PORT: 8080 + NUMEXPR_MAX_THREADS: 128 +disaggregated_prefill: + enabled: true + prefiller_host_index: [0] + decoder_host_index: [1] + +deployment: + - + server_cmd: > + vllm serve "vllm-ascend/Qwen3-235B-A22B-W8A8" + --host 0.0.0.0 + --port $SERVER_PORT + --data-parallel-size 1 + --decode-context-parallel-size 2 + --prefill-context-parallel-size 2 + --tensor-parallel-size 8 + --cp-kv-cache-interleave-size 128 + --seed 1024 + --enforce-eager + --enable-expert-parallel + --max-num-seqs 16 + --max-model-len 8192 + --max-num-batched-tokens 8192 + --quantization ascend + --trust-remote-code + --no-enable-prefix-caching + --gpu-memory-utilization 0.9 + --kv-transfer-config + '{"kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_producer", + "kv_port": "30000", + "engine_id": "0", + "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 8 + }, + "decode": { + "dp_size": 2, + "tp_size": 8 + } + } + }' + + - + server_cmd: > + vllm serve "vllm-ascend/Qwen3-235B-A22B-W8A8" + --host 0.0.0.0 + --port $SERVER_PORT + --data-parallel-size 2 + --decode-context-parallel-size 2 + --prefill-context-parallel-size 1 + --tensor-parallel-size 8 + --cp-kv-cache-interleave-size 128 + --seed 1024 + --quantization ascend + --max-num-seqs 16 + --max-model-len 8192 + --max-num-batched-tokens 8192 + --enable-expert-parallel + --trust-remote-code + --no-enable-prefix-caching + --gpu-memory-utilization 0.9 + --compilation_config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' + --kv-transfer-config + '{"kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_consumer", + "kv_port": "30100", + "engine_id": "1", + "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 8 + }, + "decode": { + "dp_size": 2, + "tp_size": 8 + } + } + }' +benchmarks: diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py index 971d931039a..8a162b91461 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py @@ -136,10 +136,10 @@ def test_token_dispatcher_with_all_gather( expert_map=expert_map, apply_router_weight_on_input=apply_router_weight_on_input) - sorted_hidden_states = dispatch_output["hidden_states"] - group_list = dispatch_output["group_list"] - group_list_type = dispatch_output.get("group_list_type", 1) - context_metadata = dispatch_output["context_metadata"] + sorted_hidden_states = dispatch_output.hidden_states + group_list = dispatch_output.group_list + group_list_type = dispatch_output.group_list_type + context_metadata = dispatch_output.context_metadata expert_output = apply_mlp(hidden_states=sorted_hidden_states, w1=w1_local, @@ -155,7 +155,7 @@ def test_token_dispatcher_with_all_gather( torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map) - torch.testing.assert_close(combined_output, + torch.testing.assert_close(combined_output.routed_out, torch_output, atol=4e-2, rtol=1) @@ -216,11 +216,11 @@ def test_token_dispatcher_with_all_gather_quant( apply_router_weight_on_input=apply_router_weight_on_input, with_quant=True) - sorted_hidden_states = dispatch_output["hidden_states"] - group_list = dispatch_output["group_list"] - group_list_type = dispatch_output.get("group_list_type", 1) - dynamic_scale = dispatch_output["dynamic_scale"] - context_metadata = dispatch_output["context_metadata"] + sorted_hidden_states = dispatch_output.hidden_states + group_list = dispatch_output.group_list + group_list_type = dispatch_output.group_list_type + dynamic_scale = dispatch_output.dynamic_scale + context_metadata = dispatch_output.context_metadata expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, w1=w1, @@ -235,7 +235,7 @@ def test_token_dispatcher_with_all_gather_quant( hidden_states=expert_output, context_metadata=context_metadata, bias=None) - assert combined_output.shape == (m, k) + assert combined_output.routed_out.shape == (m, k) gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py index 99b383ba2be..6ef9521331d 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -98,7 +100,7 @@ def test_mla_preprocess_kernel(): bias1=bias1, ctkv_scale=ctkv_scale, q_nope_scale=qnope_scale, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="per_tensor_quant_asymm", enable_inner_out=False, q_out0=q_nope_out, diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py index b18c63f64f3..196ffafce3c 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_nq.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -82,7 +84,7 @@ def test_mla_preprocess_kernel(): None, None, None, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="no_quant", enable_inner_out=False, q_out0=q_nope_out, diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py index 9eb7e1caffb..0475361792b 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_mla_preprocess_qdown.py @@ -1,5 +1,6 @@ import gc +import pytest import torch import torch_npu @@ -8,8 +9,9 @@ enable_custom_op() +@pytest.mark.parametrize("cache_mode", ["krope_ctkv", "nzcache"]) @torch.inference_mode() -def test_mla_preprocess_kernel(): +def test_mla_preprocess_kernel(cache_mode: str): token_num = 1 head_num = 2 N_7168 = 7168 @@ -99,7 +101,7 @@ def test_mla_preprocess_kernel(): bias1=bias1, ctkv_scale=ctkv_scale, q_nope_scale=qnope_scale, - cache_mode="krope_ctkv", + cache_mode=cache_mode, quant_mode="per_tensor_quant_asymm", enable_inner_out=True, q_out0=q_nope_out, diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index 88d5071d7b9..ae51a8753b2 100755 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -1112,6 +1112,7 @@ def test_mla_preprocess(self, magic_npu_fetch, MagicMock(), MagicMock() ] self.impl.num_kv_heads = self.impl.num_heads + self.impl.is_kv_producer = False decode_res, prefill_res = self.impl._mla_preprocess( "mock_layer", diff --git a/tests/ut/compilation/test_add_rms_norm_quant.py b/tests/ut/compilation/test_add_rms_norm_quant.py new file mode 100644 index 00000000000..0e2887a7de7 --- /dev/null +++ b/tests/ut/compilation/test_add_rms_norm_quant.py @@ -0,0 +1,95 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# + +import sys +from unittest import mock + + +def _extra_stream_scope_check_for_test(match) -> bool: + """ + Copied from the original implementation for testability. + Checks if all nodes in the same stream. + """ + non_default_streams = set() + has_default = False + + for node in match.nodes: + if node.op == "call_function": + current_stream = node.meta.get("stream_label") + if current_stream is None: + has_default = True + else: + non_default_streams.add(current_stream) + if len(non_default_streams) > 1: + return False + + if has_default and len(non_default_streams) > 0: + return False + + return True + + +def test_extra_stream_scope_check(): + """Test the stream scope check logic.""" + + class MockNode: + + def __init__(self, stream_label=None): + self.op = "call_function" + self.meta = {"stream_label": stream_label} + + class MockMatch: + + def __init__(self, nodes): + self.nodes = nodes + + # Test 1: all default stream (None) → OK + match1 = MockMatch([MockNode(None), MockNode(None)]) + assert _extra_stream_scope_check_for_test(match1) is True + + # Test 2: all same non-default stream → OK + match2 = MockMatch([MockNode("s1"), MockNode("s1")]) + assert _extra_stream_scope_check_for_test(match2) is True + + # Test 3: mixed streams → FAIL + match3 = MockMatch([MockNode("s1"), MockNode("s2")]) + assert _extra_stream_scope_check_for_test(match3) is False + + # Test 4: default + non-default → FAIL + match4 = MockMatch([MockNode(None), MockNode("s1")]) + assert _extra_stream_scope_check_for_test(match4) is False + + # Test 5: empty nodes → OK (edge case) + match5 = MockMatch([]) + assert _extra_stream_scope_check_for_test(match5) is True + + +def test_replacement_function_without_torch_npu(caplog): + with mock.patch.dict(sys.modules, { + 'torch_npu': None, + 'torchair': None, + 'torch_npu.dynamo': None + }): + if 'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant' in sys.modules: + del sys.modules[ + 'vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant'] + + try: + from vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant import \ + replacement_add_rms_norm_quant_with_bias + result = replacement_add_rms_norm_quant_with_bias(epsilon=1e-5) + assert result is None + except (ImportError, AttributeError): + pass diff --git a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py index e2f84d9f8d9..6eb38454acd 100644 --- a/tests/ut/kv_connector/test_mooncake_layerwise_connector.py +++ b/tests/ut/kv_connector/test_mooncake_layerwise_connector.py @@ -18,7 +18,7 @@ KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole, MooncakeAgentMetadata, MooncakeLayerwiseConnector, MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler, - MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv, + MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv, ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash, zmq_ctx) @@ -71,7 +71,8 @@ def setUp(self): remote_port=7777, remote_te_rpc_port=6000, remote_kv_caches_base_addr=[4000, 8000, 14000, 18000], - metaserver="http://dummy") + metaserver="http://dummy", + chunk_finish=False) @patch( "vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr", @@ -113,11 +114,13 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine( key = torch.zeros((cap, dim), dtype=torch.float32) value = torch.zeros((cap, dim), dtype=torch.float32) - thread._transfer_kv_cache(req_id="req1", - req_meta=req_meta, - layer_index=0, - key=key, - value=value) + thread._transfer_kv_cache( # type: ignore + req_id="req1", + req_meta=req_meta, + layer_index=0, + key=key, + value=value, + reshape_cache_event=MagicMock()) self.engine.batch_transfer_sync_write.assert_called_once() session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ @@ -142,9 +145,37 @@ def test_transfer_pd_gt1_uses_buffers_and_calls_engine( def test_transfer_skips_when_no_local_blocks(self): req_meta = self.req_meta_base req_meta.local_block_ids = [] - self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros( - (1, 8)), torch.zeros((1, 8))) - self.engine.batch_transfer_sync_write.assert_not_called() + self.thread.pd_head_ratio = 1 + self.thread.block_len = [64, 128] + + key = torch.zeros((1, 8), dtype=torch.float32) + value = torch.zeros((1, 8), dtype=torch.float32) + + reshape_cache_event = MagicMock() + with patch.object(self.engine, + 'batch_transfer_sync_write') as mock_batch_transfer: + mock_batch_transfer.return_value = 1 + + def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key, + value, + reshape_cache_event): # type: ignore + if not req_meta.local_block_ids: + return + self._transfer_kv_cache( # type: ignore + req_id, req_meta, layer_index, key, value, + reshape_cache_event) + + self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore + self.thread._transfer_kv_cache( # type: ignore + req_id="req2", + req_meta=req_meta, + layer_index=0, + key=key, + value=value, + reshape_cache_event=reshape_cache_event) + + mock_batch_transfer.assert_not_called() + self.assertEqual(mock_batch_transfer.call_count, 0) def test_transfer_skips_when_tp_not_sender(self): @@ -161,8 +192,13 @@ def test_transfer_skips_when_tp_not_sender(self): first_kv_cache=self.first_kv_cache, callback_func=MagicMock()) req_meta = self.req_meta_base - thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)), - torch.zeros((1, 8))) + thread._transfer_kv_cache( # type: ignore + "req3", + req_meta, + 0, + torch.zeros((1, 8)), + torch.zeros((1, 8)), + reshape_cache_event=MagicMock()) self.engine.batch_transfer_sync_write.assert_not_called() @patch( @@ -172,25 +208,30 @@ def test_transfer_skips_when_tp_not_sender(self): "vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize" ) def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): - req_meta = self.req_meta_base req_meta.local_block_ids = [5, 6] req_meta.remote_block_ids = [10, 11] - req_meta.remote_kv_caches_base_addr = [ 7000, 8000, 9000, 10000, 11000, 12000 ] - + req_meta.chunk_finish = True key = torch.zeros((1, 8), dtype=torch.float32) value = torch.zeros((1, 8), dtype=torch.float32) - self.thread._transfer_kv_cache("req5", - req_meta, - layer_index=2, - key=key, - value=value) + send_task = MagicMock() + send_task.layer_index = self.thread.total_layers - 1 + send_task.send_request = {"req5": req_meta} - self.thread.callback_func.assert_called_once() + with patch.object(self.thread, 'callback_func') as mock_callback_func: + self.thread._transfer_kv_cache( # type: ignore + req_id="req5", + req_meta=req_meta, + layer_index=send_task.layer_index, + key=key, + value=value, + reshape_cache_event=MagicMock()) + print(f"Callback called: {mock_callback_func.call_count} times") + mock_callback_func.assert_called_once() class TestKVCacheRecvingLayerThread(unittest.TestCase): @@ -468,6 +509,7 @@ def test_build_connector_meta(self): request = MockRequest("req1") self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6]) + self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True request.kv_transfer_params = { "remote_block_ids": [1, 2, 3], "remote_engine_id": "remote", @@ -505,7 +547,8 @@ def __init__(self, cached_new_block_ids=None, cached_num_computed=None, new_reqs=None, - num_sched=None): + num_sched=None, + scheduled_spec_decode_tokens=None): self.scheduled_cached_reqs = SimpleNamespace( req_ids=cached_req_ids or [], new_block_ids=cached_new_block_ids or [], @@ -513,6 +556,7 @@ def __init__(self, ) self.scheduled_new_reqs = new_reqs or [] self.num_scheduled_tokens = num_sched or {} + self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {} class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): @@ -549,43 +593,39 @@ def test_update_state_after_alloc_prefill_records_and_resets_flag(self): self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True)) def test_update_state_after_alloc_decode_records_send_layerwise(self): - req = MockRequest("req_u2", - prompt_token_ids=list(range(10)), - kv_transfer_params={"do_remote_decode": True}) + req = MockRequest( + "req_u2", + prompt_token_ids=list(range(10)), + kv_transfer_params={ + "do_remote_decode": True, + "remote_block_ids": [] # 修改为空列表 [] + }) + blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) self.scheduler.update_state_after_alloc(req, blocks, num_external_tokens=0) self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise) - total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[ - "req_u2"] - self.assertEqual(total_tokens, 10) - self.assertEqual(local_block_ids, [7, 8, 9]) - self.assertIs(req_ref, req) - - def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self): - req = MockRequest("req_b1", - kv_transfer_params={ - "remote_block_ids": [1, 2], - "remote_engine_id": "E", - "remote_host": "H", - "remote_port": 5555, - "remote_te_rpc_port": 6000, - "remote_kv_caches_base_addr": [10, 11], - }) - self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101]) - meta = self.scheduler.build_connector_meta(_MockSchedulerOutput()) - self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata) - self.assertIn("req_b1", meta.requests) - self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101]) - self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + info = self.scheduler._reqs_need_send_layerwise["req_u2"] + self.assertEqual(info.local_block_ids, [7, 8, 9]) + self.assertIs(info.request, req) + self.assertEqual(info.remote_block_ids, []) + self.assertIsInstance(info.remote_block_ids, list) def test_build_connector_meta_accumulates_cached_blocks(self): - req = MockRequest("req_b2", - prompt_token_ids=list(range(8)), - kv_transfer_params={"do_remote_decode": True}) - - self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req) + req_meta = MagicMock(spec=ReqMeta) + req_meta.local_block_ids = [1, 2, 3] + req_meta.remote_block_ids = [4, 5] + req_meta.remote_engine_id = "remote" + req_meta.remote_host = "localhost" + req_meta.remote_port = 5000 + req_meta.remote_te_rpc_port = 6000 + req_meta.remote_kv_caches_base_addr = [10, 20] + req_meta.metaserver = "http://dummy" + req_meta.chunk_finish = False + + req_meta.extend_local_block_ids = MagicMock() + self.scheduler._reqs_need_send_layerwise["req_b2"] = req_meta out = _MockSchedulerOutput( cached_req_ids=["req_b2"], @@ -596,47 +636,53 @@ def test_build_connector_meta_accumulates_cached_blocks(self): ) meta = self.scheduler.build_connector_meta(out) self.assertEqual(len(meta.requests), 0) - total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[ - "req_b2"] - self.assertEqual(total, 8) - self.assertEqual(block_ids, [1, 2, 3, 4]) - - def test_build_connector_meta_emits_when_tokens_reach_total(self): - - req = MockRequest("req_b3", - prompt_token_ids=list(range(12)), - kv_transfer_params={ - "do_remote_decode": True, - "remote_block_ids": [9], - "remote_engine_id": "E", - "remote_host": "H", - "remote_port": 5555, - "remote_te_rpc_port": 6000, - "remote_kv_caches_base_addr": [10, 11], - }) - self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100, - 101], req) + req_meta.extend_local_block_ids.assert_called_once_with([3, 4]) + + @patch( + "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous" + ) + def test_build_connector_meta_emits_when_tokens_reach_total( + self, mock_group_concurrent_contiguous): + req_meta = MagicMock(spec=ReqMeta) + req_meta.local_block_ids = [1, 2, 3] + req_meta.remote_block_ids = [4, 5] + req_meta.remote_engine_id = "remote" + req_meta.remote_host = "localhost" + req_meta.remote_port = 5000 + req_meta.remote_te_rpc_port = 6000 + req_meta.remote_kv_caches_base_addr = [10, 20] + req_meta.metaserver = "http://dummy" + req_meta.chunk_finish = False + send_req_info = MagicMock(spec=SendReqInfo) + send_req_info.local_block_ids = [1, 2, 3] + send_req_info.remote_block_ids = [4, 5] + send_req_info.remote_cache_tokens = 100 + send_req_info.local_transferred_tokens = 50 + send_req_info.local_computed_tokens = 75 + send_req_info.request = MagicMock() + send_req_info.extend_local_block_ids = MagicMock() + send_req_info.update_computed_tokens = MagicMock() + send_req_info.update_transferred_tokens = MagicMock() + send_req_info.unpack = MagicMock( + return_value=(send_req_info.local_block_ids, + send_req_info.remote_block_ids, + send_req_info.remote_cache_tokens, + send_req_info.local_transferred_tokens, + send_req_info.local_computed_tokens, + send_req_info.request)) + + self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info out = _MockSchedulerOutput( cached_req_ids=["req_b3"], cached_new_block_ids=[([50], )], cached_num_computed=[8], - new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)], + new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)], num_sched={"req_b3": 4}, ) meta = self.scheduler.build_connector_meta(out) + send_req_info.extend_local_block_ids.assert_called_once_with([50]) self.assertIn("req_b3", meta.requests) - rmeta = meta.requests["req_b3"] - - self.assertEqual(rmeta.local_block_ids, [100, 101, 50]) - - self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise) - - def test_request_finished_returns_false_none(self): - ok, params = self.scheduler.request_finished(MockRequest("req_fin"), - [1, 2]) - self.assertFalse(ok) - self.assertIsNone(params) class TestHelperFunctions(unittest.TestCase): diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index 7620999a159..e40f67084bd 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -8,6 +8,8 @@ AlltoAllCommImpl, MC2CommImpl) from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType +from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult, + TokenDispatchResult) class TestMoECommMethod(TestBase): @@ -178,12 +180,12 @@ def test_fused_experts_method(self, mock_unified_apply_mlp, # Mock token dispatcher mock_td_instance = MagicMock() - mock_td_instance.token_dispatch.return_value = { - "hidden_states": torch.randn(6, 8), - "group_list": torch.tensor([2, 2, 2]), - "group_list_type": 1 - } - mock_td_instance.token_combine.return_value = torch.randn(4, 8) + mock_td_instance.token_dispatch.return_value = TokenDispatchResult( + hidden_states=torch.randn(6, 8), + group_list=torch.tensor([2, 2, 2]), + group_list_type=1) + mock_td_instance.token_combine.return_value = TokenCombineResult( + routed_out=torch.randn(4, 8)) mock_token_dispatcher.return_value = mock_td_instance # Mock unified_apply_mlp @@ -213,7 +215,7 @@ def test_fused_experts_method(self, mock_unified_apply_mlp, activation="silu") # Verify result shape - self.assertEqual(result.shape, (4, 8)) + self.assertEqual(result.routed_out.shape, (4, 8)) # Verify token_dispatch was called mock_td_instance.token_dispatch.assert_called_once() diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 140bae5cd20..027815ba0c8 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -97,8 +97,7 @@ def test_token_permutation_dispatch(self): topk_weights, topk_ids, expert_map) mock_dispatch.assert_called_once() - self.assertEqual(output["group_list_type"], - 0) # group_list_type == 0 + self.assertEqual(output.group_list_type, 0) # group_list_type == 0 def test_token_dispatch_with_shared_experts_and_quant(self): self.shared_experts = MagicMock() @@ -149,43 +148,6 @@ def test_get_combine_mc_kwargs_with_quant(self): context_metadata) self.assertIn("tp_send_counts", kwargs) - def test_token_combine_with_shared_experts(self): - shared_experts = MagicMock() - shared_experts.down_proj.return_value = (torch.randn(10, 128), - torch.tensor(1.0)) - - topk_ids = torch.randint(0, 8, (10, 1)) - topk_weights = torch.randn(10, 1) - expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - assist_info_for_combine = torch.arange(10) - tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - - context_metadata = { - "topk_ids": topk_ids, - "topk_weights": topk_weights, - "expert_map": expert_map, - "ep_recv_counts": ep_recv_counts, - "mc2_mask": None, - "assist_info_for_combine": assist_info_for_combine, - "expand_scales": None, - "shared_experts": shared_experts, - "shared_act": torch.randn(10, 128), - "swiglu_out_scale": torch.randn(10, 1), - "tp_recv_counts": tp_recv_counts - } - - self.dispatcher.with_quant = True - self.dispatcher.need_extra_args = True - self.dispatcher.enable_dispatch_v2 = True - - hidden_states = torch.randn(10, 128) - with patch("torch_npu.npu_moe_distribute_combine_v2", - return_value=torch.randn(10, 128)): - result = self.dispatcher.token_combine(hidden_states, - context_metadata) - self.assertIsInstance(result, tuple) - class TestTokenDispatcherWithAllGather(TestBase): @@ -233,7 +195,7 @@ def test_token_dispatch_without_expert_map(self): self.mock_npu_moe_init_routing_v2.assert_called_once() args, kwargs = self.mock_npu_moe_init_routing_v2.call_args - self.assertEqual(results["group_list_type"], 1) + self.assertEqual(results.group_list_type, 1) def test_token_dispatch_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) @@ -248,7 +210,7 @@ def test_token_dispatch_with_expert_map(self): self.mock_npu_moe_init_routing_v2.assert_called_once() args, kwargs = self.mock_npu_moe_init_routing_v2.call_args - self.assertEqual(results["group_list_type"], 1) + self.assertEqual(results.group_list_type, 1) def test_token_dispatch_without_quant(self): kwargs = { @@ -268,7 +230,7 @@ def test_token_dispatch_without_quant(self): topk_weights, topk_ids, None) - self.assertEqual(results["group_list_type"], 1) + self.assertEqual(results.group_list_type, 1) def test_token_dispatch_with_quant(self): kwargs = { @@ -290,10 +252,10 @@ def test_token_dispatch_with_quant(self): None, with_quant=True) - self.assertIsNotNone(results["hidden_states"]) - self.assertIsNotNone(results["group_list"]) - self.assertIsNotNone(results["dynamic_scale"]) - self.assertEqual(results["group_list_type"], 1) + self.assertIsNotNone(results.hidden_states) + self.assertIsNotNone(results.group_list) + self.assertIsNotNone(results.dynamic_scale) + self.assertEqual(results.group_list_type, 1) def test_token_combine_with_expert_map(self): hidden_states = torch.randn(6, 128) @@ -303,7 +265,7 @@ def test_token_combine_with_expert_map(self): } self.dispatcher.original_shape = (6, 128) final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata) + hidden_states, context_metadata).routed_out self.assertEqual(final_hidden_states.shape, (6, 128)) def test_token_combine_without_expert_map(self): @@ -314,7 +276,7 @@ def test_token_combine_without_expert_map(self): } self.dispatcher.original_shape = (6, 128) final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata) + hidden_states, context_metadata).routed_out self.mock_npu_moe_token_unpermute.assert_called_once() self.assertEqual(final_hidden_states.shape, (6, 128)) @@ -326,7 +288,7 @@ def test_token_dispatch_with_router_weight(self): results = self.dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids, None) - self.assertEqual(results["hidden_states"].shape, (6, 128)) + self.assertEqual(results.hidden_states.shape, (6, 128)) class TestTokenDispatcherWithAll2AllV(TestBase): @@ -437,9 +399,9 @@ def test_token_dispatch(self): topk_ids=topk_ids, expert_map=expert_map) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertEqual(result.group_list_type, 1) def test_token_combine(self): hidden_states = torch.randn(16, 16) @@ -458,7 +420,7 @@ def test_token_combine(self): output = self.dispatcher.token_combine(hidden_states, context_metadata) self.assertIsNotNone(output) - self.assertEqual(output.shape, (8, 16)) + self.assertEqual(output.routed_out.shape, (8, 16)) def test_token_dispatch_with_quant(self): self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2, @@ -480,10 +442,10 @@ def test_token_dispatch_with_quant(self): expert_map=expert_map, with_quant=True) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertIsNotNone(result["dynamic_scale"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertIsNotNone(result.dynamic_scale) + self.assertEqual(result.group_list_type, 1) def test_token_dispatch_with_quant_no_active_tokens(self): self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2, @@ -508,10 +470,10 @@ def test_token_dispatch_with_quant_no_active_tokens(self): expert_map=expert_map, with_quant=True) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertIsNotNone(result["dynamic_scale"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertIsNotNone(result.dynamic_scale) + self.assertEqual(result.group_list_type, 1) def test_token_dispatch_with_log2phy(self): hidden_states = torch.randn(8, 16) @@ -530,6 +492,6 @@ def test_token_dispatch_with_log2phy(self): expert_map=expert_map, log2phy=log2phy) - self.assertIsNotNone(result["hidden_states"]) - self.assertIsNotNone(result["group_list"]) - self.assertEqual(result["group_list_type"], 1) + self.assertIsNotNone(result.hidden_states) + self.assertIsNotNone(result.group_list) + self.assertEqual(result.group_list_type, 1) diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index 1a337dea4f1..5cccc02797a 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -39,6 +39,7 @@ def test_init_ascend_config_without_additional_config(self): ascend_config = init_ascend_config(test_vllm_config) self.assertIsNone(ascend_config.expert_map_path) self.assertFalse(ascend_config.multistream_overlap_shared_expert) + self.assertFalse(ascend_config.enable_kv_nz) ascend_compilation_config = ascend_config.ascend_compilation_config self.assertTrue(ascend_compilation_config.fuse_norm_quant) @@ -53,6 +54,7 @@ def test_init_ascend_config_with_additional_config(self): "multistream_overlap_shared_expert": True, "expert_map_path": "test_expert_map_path", "refresh": True, + "enable_kv_nz": False } ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") @@ -61,6 +63,7 @@ def test_init_ascend_config_with_additional_config(self): ascend_compilation_config = ascend_config.ascend_compilation_config self.assertFalse(ascend_compilation_config.fuse_norm_quant) + self.assertFalse(ascend_config.enable_kv_nz) @_clean_up_ascend_config def test_init_ascend_config_enable_npugraph_ex(self): diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 8be434a18c9..fec3ade854c 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -13,18 +13,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import TYPE_CHECKING, Optional from vllm.logger import logger from vllm.triton_utils import HAS_TRITON +if TYPE_CHECKING: + from vllm.config import VllmConfig + class AscendConfig: """ Configuration Object for additional_config from vllm.configs. """ - def __init__(self, vllm_config): + def __init__(self, vllm_config: "VllmConfig"): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} xlite_graph_config = additional_config.get("xlite_graph_config", {}) @@ -121,6 +124,19 @@ def __init__(self, vllm_config): self.enable_async_exponential = bool( additional_config.get("enable_async_exponential", False)) + self.enable_kv_nz = additional_config.get("enable_kv_nz", False) + if self.enable_kv_nz: + use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") + if not vllm_config.model_config.is_deepseek_mla or use_sparse: + raise RuntimeError( + "enable_kv_nz is only supported for mla currently.") + if vllm_config.kv_transfer_config is None \ + or not vllm_config.kv_transfer_config.is_kv_consumer: + raise NotImplementedError( + "enable_kv_nz is only supported in pd scenario and can " + "only be used in D node.") + class FinegrainedTPConfig: """ diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 80a481c39b8..1405ed9f66c 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -176,6 +176,8 @@ class AscendMetadata: causal: bool = True # runner_type in model_config. model_runner_type: str = "" + # prefill reshape_and_cache event + reshape_cache_event: torch.npu.Event = None # sliding window attention mask swa_mask: Optional[torch.Tensor] = None @@ -333,6 +335,7 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.key_cache = None self.value_cache = None + self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: AscendMetadata, @@ -437,8 +440,7 @@ def full_graph_pa( block_table=attn_metadata.block_tables, context_lens=attn_metadata.seq_lens, out=output) - update_graph_params_workspaces(num_tokens, - weak_ref_tensors(workspace)) + update_graph_params_workspaces(num_tokens, workspace) # Handle graph capturing mode stream = torch_npu.npu.current_stream() @@ -654,6 +656,8 @@ def reshape_and_cache( ): if len(kv_cache) > 1: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] slots = attn_metadata.slot_mapping @@ -674,6 +678,8 @@ def reshape_and_cache( key_cache=self.key_cache, value_cache=self.value_cache, slot_indices=slots[:attn_metadata.num_actual_tokens]) + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() return key, value def forward_impl( diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 76f2e4102ac..b6f90c717d4 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -166,6 +166,7 @@ class AscendMLAMetadata: decode: Optional[AscendMLADecodeMetadata] = None prefill: Optional[AscendMLAPrefillMetadata] = None + reshape_cache_event: torch.npu.Event = None def __post_init__(self): pass @@ -705,6 +706,7 @@ def __init__( kv_sharing_target_layer_name: Optional[str], **kwargs, ): + self.vllm_config = get_current_vllm_config() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -745,12 +747,15 @@ def __init__( ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_prefetch = ascend_config.weight_prefetch_config.enabled + self.enable_kv_nz = ascend_config.enable_kv_nz self.ring_mla_mask_size = 512 self.speculative_config = self.vllm_config.speculative_config self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO + self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer + def _v_up_proj(self, x): # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L) x = x.view(self.num_heads, -1, self.kv_lora_rank) @@ -1073,7 +1078,7 @@ def exec_kv_decode( # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA" + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, @@ -1143,37 +1148,57 @@ def _forward_decode( # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None - k_nope = k_nope.view(-1, self.num_kv_heads, block_size, - self.kv_lora_rank) - k_pe = k_pe.view(-1, self.num_kv_heads, block_size, - self.qk_rope_head_dim) + if self.enable_kv_nz: + nz_fmt_last_dim = 16 + k_nope = k_nope.view(-1, self.num_kv_heads, + self.kv_lora_rank // nz_fmt_last_dim, + block_size, nz_fmt_last_dim) + k_pe = k_pe.view(-1, self.num_kv_heads, + self.qk_rope_head_dim // nz_fmt_last_dim, + block_size, nz_fmt_last_dim) + else: + k_nope = k_nope.view(-1, self.num_kv_heads, block_size, + self.kv_lora_rank) + k_pe = k_pe.view(-1, self.num_kv_heads, block_size, + self.qk_rope_head_dim) + attn_output_shape: tuple | None = None if attn_metadata.attn_state in [ AscendAttentionState.SpecDecoding, AscendAttentionState.ChunkedPrefill, AscendAttentionState.DecodeOnly, ] and self.speculative_config is not None: - # Input shape: [num_tokens, num_heads, dim] - # Output shape: [num_heads, num_tokens, dim] # The right part layout indicates the layout of the attention # output. It is set to NTD to avoid the need for a transpose # operation after attention. input_layout = "TND_NTD" # TODO: If the driver is upgraded later, the contiguous function can be deleted. + # Input shape: [num_tokens, num_heads, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous() q_pe = q_pe.view(num_tokens, self.num_heads, -1) + # Output shape: [num_heads, num_tokens, dim] + attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank) sparse_mode = 3 spec_attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q else: - # Input shape: [num_reqs, num_heads, seq_len, dim] - # Output shape: [num_heads, num_reqs, seq_len, dim] # The output layout is set to NBSD to eliminate the need for a # transpose operation after attention. - input_layout = "BNSD_NBSD" - q_nope = q_nope.view(num_tokens, self.num_heads, 1, - -1).contiguous() - q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + if self.enable_kv_nz: + # Input shape: [num_tokens, seq_len, num_heads, dim] + input_layout = "BSND_NBSD" + q_nope = q_nope.view(num_tokens, 1, self.num_heads, + -1).contiguous() + q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) + else: + # Input shape: [num_tokens, num_heads, seq_len, dim] + input_layout = "BNSD_NBSD" + q_nope = q_nope.view(num_tokens, self.num_heads, 1, + -1).contiguous() + q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) + # Output shape: [num_heads, num_tokens, seq_len, dim] + attn_output_shape = (self.num_heads, num_tokens, 1, + self.kv_lora_rank) sparse_mode = 0 spec_attn_mask = None @@ -1215,10 +1240,9 @@ def _forward_decode( else: update_graph_params_workspaces(num_tokens, workspace) - attn_output = torch.empty( - (q_nope.shape[1], q_nope.shape[0], *q_nope.shape[2:]), - dtype=q_nope.dtype, - device=q_nope.device) + attn_output = torch.empty(attn_output_shape, + dtype=q_nope.dtype, + device=q_nope.device) softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) @@ -1297,7 +1321,7 @@ def _mla_preprocess_only_decode(self, hidden_states, kv_cache, bias1=self.qb_qt_bias, ctkv_scale=self.ctkv_scale, q_nope_scale=self.q_nope_scale, - cache_mode="krope_ctkv", + cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", quant_mode="per_tensor_quant_asymm", q_out0=decode_q_nope, kv_cache_out0=decode_k_nope, @@ -1331,8 +1355,12 @@ def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, prefill_slots = attn_metadata.slot_mapping[ num_decode_tokens:num_actual_tokens] prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() prefill_k_nope, prefill_value = self.kv_b_proj( prefill_k_c_normed)[0].view( -1, self.num_heads, diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 3f28a3a3632..72cf925d568 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -162,6 +162,13 @@ def __call__(self, *args, **kwargs): # any other acl graph. output = weak_ref_tensors(output) + # here we always use weak ref for the workspaces + # to save memory + global _graph_params + global _draft_graph_params + weak_ref_workspaces(_graph_params) + weak_ref_workspaces(_draft_graph_params) + # here we always use weak ref for the output # to save memory entry.output = weak_ref_tensors(output) @@ -195,6 +202,16 @@ def __call__(self, *args, **kwargs): return entry.output +def weak_ref_workspaces(params): + if params is None: + return + for num_tokens in params.workspaces: + if params.workspaces[num_tokens] is None: + continue + params.workspaces[num_tokens] = weak_ref_tensors( + params.workspaces[num_tokens]) + + def _update_attn_pa_params(update_stream, forward_context, runtime_shape): graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args @@ -523,7 +540,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]): def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor): global _graph_params if _graph_params is not None: - _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace) + _graph_params.workspaces[num_tokens] = workspace def get_graph_params(): diff --git a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py index 724d8140e55..3de71e611e8 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py @@ -25,7 +25,7 @@ @functools.lru_cache(None) # The replacement registered here will be actually executed after AOT. -def _register_replacement(epsilon): +def replacement_add_rms_norm_quant(epsilon): if 'torch_npu' not in sys.modules: logger.info( 'The AddRMSNormQuant fusion will only be enabled in a torch npu env.' @@ -114,10 +114,108 @@ def get_inputs(): extra_check=_extra_stream_scope_check) +@functools.lru_cache(None) +# The replacement registered here will be actually executed after AOT. +def replacement_add_rms_norm_quant_with_bias(epsilon): + if 'torch_npu' not in sys.modules: + logger.info( + 'The AddRMSNormQuantWithBias fusion will only be enabled in a torch npu env.' + 'When there is no torch_npu in the env, skip fusion.') + return + + def _extra_stream_scope_check(match: Match) -> bool: + """ + Checks if all nodes in the same stream. + """ + non_default_streams = set() + has_default = False + + for node in match.nodes: + if node.op == "call_function": + current_stream = node.meta.get("stream_label") + if current_stream is None: + has_default = True + else: + non_default_streams.add(current_stream) + if len(non_default_streams) > 1: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations." + ) + return False + + if has_default and len(non_default_streams) > 0: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuantWithBias. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations.") + return False + + return True + + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor, bias: torch.Tensor): + """ + Pattern for AddRMSNormQuantWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, epsilon) + out0 = output[0] + out1 = output[2] + out0 = out0 + bias + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, + torch.qint8, -1, False) + return quantized_output, out1 + + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor, bias: torch.Tensor): + """ + Replacement for AddRMSNormQuantWithBias fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, + residual, + rms_norm_weight, + # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. + 1. / scale, + offset, + epsilon=epsilon, + beta=bias) + quantized_output = output[0] + out1 = output[2] + return quantized_output, out1 + + def get_inputs(): + """ + Generate example inputs for the AddRMSNormQuantWithBias fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + rmsnorm_bias = torch.randn(4, device="npu") + scale = torch.ones(4, device="npu") + offset = torch.zeros(4, device="npu") + return [ + rms_norm_input, residual, rms_norm_weight, scale, offset, + rmsnorm_bias + ] + + import torchair + + torchair.register_replacement(search_fn=pattern, + replace_fn=replacement, + example_inputs=get_inputs(), + extra_check=_extra_stream_scope_check) + + # register converter for pass common_epsilons = [1e-5, 1e-6] for eps in common_epsilons: logger.info( f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}" ) - _register_replacement(eps) + replacement_add_rms_norm_quant(eps) + replacement_add_rms_norm_quant_with_bias(eps) diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 382843351b9..603a89b8b64 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -412,9 +412,16 @@ def _handle_request(self, req_meta: dict[str, Any]): logger.debug( f"Finished transferring KV cache for request {request_id}.") except Exception as e: - logger.error("Failed to transfer KV cache for request " - f"{request_id}: {e}") + logger.error( + "Failed to transfer KV cache for request " + f"{request_id}: {e}", + exc_info=True) finally: + if all_task_done: + self.task_tracker.update_done_task_count(request_id) + if request_id in self.proc_not_transfer_request: + del self.proc_not_transfer_request[request_id] + self.request_queue.task_done() # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the # remote host. @@ -423,11 +430,6 @@ def _handle_request(self, req_meta: dict[str, Any]): remote_port_send_num) self._send_done_signal_to_free_remote_port(request_id, remote_host, remote_port_send_num) - if all_task_done: - self.task_tracker.update_done_task_count(request_id) - if request_id in self.proc_not_transfer_request: - del self.proc_not_transfer_request[request_id] - self.request_queue.task_done() def _send_done_signal_to_free_remote_port(self, request_id, remote_host, remote_port_send_num): @@ -539,97 +541,116 @@ def _transfer_kv_cache(self, req_meta: dict[str, Any]): request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, get_ip(), self.tp_rank, session_id) - # Determine if the current position is the offset position at the end of the KV transmission. + # Determine if the current position is the offset position at the end of + # the KV transmission. is_kv_transfer_end = ( global_offset == tp_num_need_pulls * self._prefill_pp_size - 1) need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end - # need_nz_cache maybe caused error in non-MLA models - if need_cat_cache: - self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls) - - def _cat_kv_cache(self, block_ids: list[list[int]], - tp_num_need_pulls: int): + need_nz_cache = get_ascend_config().enable_kv_nz and is_kv_transfer_end + if need_nz_cache or need_cat_cache: + self.reformat_kv_cache(grouped_local_block_ids, tp_num_need_pulls, + need_cat_cache, need_nz_cache) + + def reformat_kv_cache(self, + block_ids: list[list[int]], + tp_num_need_pulls: int, + need_cat_cache: bool = False, + need_nz_cache: bool = False): # Get necessary parameters k_cache = list(self.kv_caches.values())[0][0] dtype = k_cache.dtype device = k_cache.device - head_dim = self.model_config.hf_text_config.head_dim - block_size = self.vllm_config.cache_config.block_size - num_kv_head = max( - self.model_config.hf_text_config.num_key_value_heads // - self.tp_size, 1) flat_block_ids = [item for sublist in block_ids for item in sublist] - block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32) + block_ids_tensor = torch.tensor(flat_block_ids, + dtype=torch.int32, + device=device) num_blocks = len(flat_block_ids) - block_len = num_blocks * block_size + num_tokens = num_blocks * self.block_size # Create device tensors for copy operations - block_table = block_ids_tensor.view(1, -1).to(device=device) - block_len_tensor = torch.tensor([block_len], - dtype=torch.int32).to(device=device) - seq_start_tensor = torch.tensor([0], - dtype=torch.int32).to(device=device) + block_table = block_ids_tensor.view(1, -1) + block_len_tensor = torch.tensor([num_tokens], + dtype=torch.int32, + device=device) + seq_start_tensor = torch.tensor([0], dtype=torch.int32, device=device) # Initialize buffers - k_buffer = torch.empty(block_len, - num_kv_head, - head_dim, - dtype=dtype, - device=device) - v_buffer = torch.empty(block_len, - num_kv_head, - head_dim, - dtype=dtype, - device=device) + k_buffer = torch.empty( + (num_tokens, self.num_kv_heads, self.k_head_dim), + dtype=dtype, + device=device) + v_buffer = torch.empty( + (num_tokens, self.num_kv_heads, self.v_head_dim), + dtype=dtype, + device=device) # Create slot mapping for reshape operations - block_offsets = torch.arange(0, block_size, dtype=torch.int32) + block_offsets = torch.arange(0, + self.block_size, + dtype=torch.int32, + device=device) slot_mapping = (block_offsets.reshape( - (1, block_size)) + block_ids_tensor.reshape( - (num_blocks, 1)) * block_size) - slot_mapping = slot_mapping.flatten().to(device=device) + (1, self.block_size)) + block_ids_tensor.reshape( + (num_blocks, 1)) * self.block_size).flatten() + + # FIXME: Right now, if we skip synchronization at this point, the system + # will crash in GQA scenarios. However, we still haven't identified the + # root cause. + torch.npu.synchronize() # Process each layer in the KV cache for _, (k_cache_layer, v_cache_layer) in self.kv_caches.items(): # Load cache data into buffers - torch_npu.atb.npu_paged_cache_load( - k_cache_layer, - v_cache_layer, - block_table, - block_len_tensor, - seq_starts=seq_start_tensor, - key=k_buffer, - value=v_buffer, - ) - - # Transpose KV cache - k_buffer = self._transpose_kv_cache_between_head( - k_buffer, num_blocks, block_size, block_len, num_kv_head, - tp_num_need_pulls) - v_buffer = self._transpose_kv_cache_between_head( - v_buffer, num_blocks, block_size, block_len, num_kv_head, - tp_num_need_pulls) - - # Reshape and cache the processed buffers - torch_npu._npu_reshape_and_cache( - key=k_buffer, - value=v_buffer, - key_cache=k_cache_layer, - value_cache=v_cache_layer, - slot_indices=slot_mapping, - ) - + torch_npu.atb.npu_paged_cache_load(k_cache_layer, + v_cache_layer, + block_table, + block_len_tensor, + seq_starts=seq_start_tensor, + key=k_buffer, + value=v_buffer) + if need_cat_cache: + self._cat_kv_cache(k_cache_layer, v_cache_layer, k_buffer, + v_buffer, tp_num_need_pulls, num_blocks, + num_tokens, slot_mapping) + if need_nz_cache: + self._nz_kv_cache(k_cache_layer, v_cache_layer, k_buffer, + v_buffer, slot_mapping) # Clean up buffers del k_buffer, v_buffer - def _transpose_kv_cache_between_head( - self, buffer: torch.Tensor, num_blocks: int, block_size: int, - block_len: int, num_kv_head: int, - tp_num_need_pulls: int) -> torch.Tensor: - buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1) - buffer.transpose_(1, 2) - return buffer.contiguous().view(block_len, num_kv_head, -1) + def _cat_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, + tp_num_need_pulls, num_blocks, num_tokens, slot_mapping): + + def _transpose_kv_cache_between_head( + buffer: torch.Tensor) -> torch.Tensor: + buffer = buffer.view(num_blocks, tp_num_need_pulls, + self.block_size, -1) + buffer.transpose_(1, 2) + return buffer.contiguous().view(num_tokens, self.num_kv_heads, -1) + + # Transpose KV cache + k_buffer = _transpose_kv_cache_between_head(k_buffer) + v_buffer = _transpose_kv_cache_between_head(v_buffer) + + # Reshape and cache the processed buffers + torch_npu._npu_reshape_and_cache(key=k_buffer, + value=v_buffer, + key_cache=k_cache_layer, + value_cache=v_cache_layer, + slot_indices=slot_mapping) + + def _nz_kv_cache(self, k_cache_layer, v_cache_layer, k_buffer, v_buffer, + slot_mapping): + nz_fmt_last_dim = 16 + k_cache_layer = k_cache_layer.view( + -1, self.k_head_dim * self.num_kv_heads // nz_fmt_last_dim, + self.block_size, nz_fmt_last_dim) + v_cache_layer = v_cache_layer.view( + -1, self.v_head_dim * self.num_kv_heads // nz_fmt_last_dim, + self.block_size, nz_fmt_last_dim) + torch_npu.npu_scatter_pa_kv_cache(k_buffer, v_buffer, k_cache_layer, + v_cache_layer, slot_mapping) def _get_remote_metadata(self, remote_host: str, remote_handshake_port: int) -> None: @@ -677,6 +698,13 @@ def _send_done_recv_signal(self, request_id: str, remote_host: str, request_id, remote_host, remote_handshake_port) raise RuntimeError( f"Failed to receive ACK, resp: {resp.decode('utf-8')}") + except RuntimeError as e: + if isinstance(sock, zmq.Socket): # type: ignore + sock.close() + sock = None + logger.warning( + f"Unexpected error occurred in socket, {e}, closing the original channel" + ) finally: if sock is not None: self._return_remote_socket(sock, remote_host, diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index d1351049726..9d9d9301a6f 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -65,6 +65,32 @@ class ReqMeta: remote_te_rpc_port: Optional[int] remote_kv_caches_base_addr: Optional[list[int]] metaserver: Optional[str] + chunk_finish: Optional[bool] + + +@dataclass +class SendReqInfo: + local_block_ids: list[int] + remote_block_ids: List[int] + remote_cache_tokens: int + local_transferred_tokens: int + local_computed_tokens: int + request: "Request" + + def extend_local_block_ids(self, new_block_ids: List[int]) -> None: + """extend local block ids for this step""" + self.local_block_ids.extend(new_block_ids) + + def update_computed_tokens(self, computed_tokens: int) -> None: + """update local computen tokens for this step""" + self.local_computed_tokens = computed_tokens + + def update_transferred_tokens(self, transferred_tokens: int) -> None: + """update transferred tokens for this step""" + self.local_transferred_tokens = transferred_tokens + + def unpack(self): + return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request @dataclass @@ -144,7 +170,7 @@ def __init__(self, raise RuntimeError("Mooncake memory registration failed. ") self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor, - torch.Tensor]]() + torch.Tensor, torch.npu.Event]]() self.ready_event = ready_event self.callback_func = callback_func @@ -155,15 +181,19 @@ def run(self): torch.npu.set_device(device) self.ready_event.set() while True: - req_id, req_meta, layer_index, key, value = self.send_queue.get() - self._handle_request(req_id, req_meta, layer_index, key, value) + req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get( + ) + self._handle_request(req_id, req_meta, layer_index, key, value, + reshape_cache_event) - def _handle_request(self, req_id, req_meta, layer_index, key, value): + def _handle_request(self, req_id, req_meta, layer_index, key, value, + reshape_cache_event): try: logger.debug( f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." ) - self._transfer_kv_cache(req_id, req_meta, layer_index, key, value) + self._transfer_kv_cache(req_id, req_meta, layer_index, key, value, + reshape_cache_event) logger.debug( f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." ) @@ -171,13 +201,8 @@ def _handle_request(self, req_id, req_meta, layer_index, key, value): logger.error("Failed to transfer KV cache for request " f"{req_id}: {e}") - def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): - # send kv layer to remote - if len(req_meta.local_block_ids) == 0: - logger.debug( - f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer." - ) - return + def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value, + reshape_cache_event): # not need to send kv cache if self.tp_rank % self.num_head_replica != 0: logger.debug( @@ -227,7 +252,13 @@ def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): length_list.append(length) if self.current_layer != layer_index: self.current_layer = layer_index - self.model_stream.synchronize() + """ + Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang. + This issue will be fixed in CANN version 8.5.rc1. + You can manually build the master branch of the project at https://gitcode.com/cann/hixl + to resolve this issue before the 8.5.RC1 release. + """ + reshape_cache_event.synchronize() ret = self.engine.batch_transfer_sync_write( session_id, src_list, dst_list, length_list) if ret < 0: @@ -285,7 +316,7 @@ def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): logger.error("Mooncake transfer failed for request %s", req_id) raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") - if layer_index == (self.total_layers - 1): + if layer_index == (self.total_layers - 1) and req_meta.chunk_finish: self.callback_func(req_id, req_meta) @@ -376,7 +407,8 @@ def add_new_req(self, request_id: str, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - token_ids: Optional[list[int]] = None): + token_ids: Optional[list[int]] = None, + chunk_finish: bool = False): self.requests[request_id] = ReqMeta( token_ids=token_ids or [], local_block_ids=local_block_ids, @@ -389,7 +421,7 @@ def add_new_req(self, remote_kv_caches_base_addr=kv_transfer_params.get( "remote_kv_caches_base_addr", None), metaserver=kv_transfer_params.get("metaserver", None), - ) + chunk_finish=chunk_finish) class MooncakeLayerwiseConnector(KVConnectorBase_V1): @@ -398,6 +430,7 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional[KVCacheConfig] = None): + super().__init__(vllm_config, role, kv_cache_config) assert vllm_config.kv_transfer_config is not None self.engine_id = vllm_config.kv_transfer_config.engine_id self._connector_metadata = MooncakeLayerwiseConnectorMetadata() @@ -509,9 +542,11 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[str, tuple[Request, list[int], list[int]]] = {} - self._reqs_need_send_layerwise: dict[str, tuple[ - int, list[int], - Request]] = {} # req_id, (len(prompt), local_block_ids, request) + self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {} + + self.executor = ThreadPoolExecutor(32) + self.metaserver_client = httpx.Client( + limits=httpx.Limits(max_connections=100000), timeout=None) def get_num_new_matched_tokens( self, request: "Request", @@ -571,14 +606,53 @@ def update_state_after_alloc(self, request: "Request", params["do_remote_prefill"] = False + logger.info( + f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}" + ) + # All parameters here should appear in the returned dict of + # request_finished in the scheduler side except "request_id". + kv_transfer_params = dict( + token_ids=[], + request_id=request.request_id, + do_remote_prefill=False, + do_remote_decode=True, + remote_block_ids=local_block_ids, + remote_engine_id=self.engine_id, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + ) + future = self.executor.submit( + self._access_metaserver, + url=params.get("metaserver", None), + message=kv_transfer_params, + ) + + def handle_exception(future): + if future.exception(): + logger.error( + f"Access metaserver fail: {future.exception()}") + + future.add_done_callback(handle_exception) + # Layerwise prefiller add request need send if params is not None and params.get("do_remote_decode"): local_block_ids = (blocks.get_block_ids()[0]) logger.debug( f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue" ) - self._reqs_need_send_layerwise[request.request_id] = (len( - request.all_token_ids), local_block_ids, request) + remote_block_ids = copy.deepcopy(params["remote_block_ids"]) + remote_cache_tokens = ( + (len(request.all_token_ids) + self.block_size - 1) // + self.block_size - len(remote_block_ids)) * self.block_size + local_transferred_tokens = remote_cache_tokens + local_computed_tokens = 0 + self._reqs_need_send_layerwise[request.request_id] = SendReqInfo( + local_block_ids=local_block_ids, + remote_block_ids=remote_block_ids, + remote_cache_tokens=remote_cache_tokens, + local_transferred_tokens=local_transferred_tokens, + local_computed_tokens=local_computed_tokens, + request=request) def build_connector_meta( self, @@ -586,55 +660,118 @@ def build_connector_meta( ) -> KVConnectorMetadata: meta = MooncakeLayerwiseConnectorMetadata() - # Loop through scheduled reqs and convert to ReqMeta. - for req_id, (req, token_ids, - block_ids) in self._reqs_need_recv.items(): - assert req.kv_transfer_params is not None - # For the case where there are no remote blocks to pull - # (block_ids is empty), we don't need to schedule - # an async read on the worker side. - meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=token_ids) - - # Clear the list once workers start the transfers - self._reqs_need_recv.clear() - - cached_reqs = scheduler_output.scheduled_cached_reqs - new_reqs = scheduler_output.scheduled_new_reqs - for req_id, new_blocks in zip(cached_reqs.req_ids, - cached_reqs.new_block_ids): - if req_id in self._reqs_need_send_layerwise and new_blocks is not None: - total_tokens, block_ids, req = self._reqs_need_send_layerwise[ - req_id] - block_ids.extend(new_blocks[0]) - - computed_tokens = dict( - list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + - [(x.req_id, x.num_computed_tokens) for x in new_reqs]) - for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( - ): - if req_id in self._reqs_need_send_layerwise: - total_tokens, block_ids, req = self._reqs_need_send_layerwise[ - req_id] - current_tokens = computed_tokens.get(req_id, - 0) + scheduled_tokens - if current_tokens >= total_tokens: - logger.debug( - f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" - ) - meta.add_new_req(request_id=req_id, - local_block_ids=block_ids, - kv_transfer_params=req.kv_transfer_params, - token_ids=[]) - self._reqs_need_send_layerwise.pop(req_id) - else: - logger.debug( - f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" - ) + if self.vllm_config.kv_transfer_config.is_kv_consumer: + # Loop through scheduled reqs and convert to ReqMeta. + for req_id, (req, token_ids, + block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + # For the case where there are no remote blocks to pull + # (block_ids is empty), we don't need to schedule + # an async read on the worker side. + meta.add_new_req(request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + token_ids=token_ids) + + # Clear the list once workers start the transfers + self._reqs_need_recv.clear() + else: + cached_reqs = scheduler_output.scheduled_cached_reqs + new_reqs = scheduler_output.scheduled_new_reqs + scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens + # update local block ids + for req_id, new_blocks in zip(cached_reqs.req_ids, + cached_reqs.new_block_ids): + if req_id in self._reqs_need_send_layerwise and new_blocks is not None: + self._reqs_need_send_layerwise[ + req_id].extend_local_block_ids(new_blocks[0]) + + computed_tokens = dict( + list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + + [(x.req_id, x.num_computed_tokens) for x in new_reqs]) + for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( + ): + if req_id in self._reqs_need_send_layerwise: + send_req_info = self._reqs_need_send_layerwise[req_id] + # update local computed tokens, not transfer spec decode tokens + spec_decode_tokens = len( + scheduled_spec_decode_tokens[req_id]) if ( + req_id in scheduled_spec_decode_tokens) else 0 + send_req_info.update_computed_tokens( + computed_tokens.get(req_id, 0) + scheduled_tokens - + spec_decode_tokens) + + def add_tranfer_task(req_id, + send_req_info: SendReqInfo, + chunk_finish=False): + local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack( + ) + local_trans_block_ids = local_block_ids[( + local_transferred_tokens // + self.block_size):(local_computed_tokens // + self.block_size)] + remote_trans_block_ids = remote_block_ids[( + (local_transferred_tokens - remote_cache_tokens) // + self.block_size):((local_computed_tokens - + remote_cache_tokens) // + self.block_size)] + request.kv_transfer_params[ + "remote_block_ids"] = remote_trans_block_ids + assert len(local_trans_block_ids) == len( + remote_trans_block_ids + ), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}" + adjusted_tokens = local_computed_tokens - ( + self.block_size - + 1) if chunk_finish else local_computed_tokens + logger.info( + f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}" + ) + meta.add_new_req( + request_id=req_id, + local_block_ids=local_trans_block_ids, + kv_transfer_params=request.kv_transfer_params, + token_ids=[], + chunk_finish=chunk_finish) + # update local_transferred_tokens + local_transferred_tokens = ( + local_computed_tokens // + self.block_size) * self.block_size + send_req_info.update_transferred_tokens( + local_transferred_tokens) + + # no chunk or last chunk + if send_req_info.local_computed_tokens >= len( + send_req_info.request.all_token_ids): + send_req_info.update_computed_tokens( + send_req_info.local_computed_tokens + + self.block_size - 1) + add_tranfer_task(req_id, + send_req_info, + chunk_finish=True) + self._reqs_need_send_layerwise.pop(req_id) + # chunk + elif (send_req_info.local_computed_tokens // + self.block_size) - ( + send_req_info.local_transferred_tokens // + self.block_size) > 0: + add_tranfer_task(req_id, send_req_info) return meta + def _access_metaserver(self, url, message): + success = False + retry = 0 + while retry < 3 and success is False: + retry += 1 + try: + self.metaserver_client.post(url, json=message) + success = True + except Exception as e: + logger.error( + f"Failed to connect to metaserver: {url}, retry {retry} time." + ) + if retry == 3: + raise e + def request_finished( self, request: "Request", @@ -676,11 +813,6 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): self.total_layers = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) - self.executor = ThreadPoolExecutor(32) - self.metaserver_client = httpx.Client( - limits=httpx.Limits(max_connections=100000), - timeout=None) if self.tp_rank == 0 else None - # Handshake base port self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + @@ -834,21 +966,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.kv_recv_layer_thread.start() ready_event.wait() - def _access_metaserver(self, url, message): - success = False - retry = 0 - while retry < 3 and success is False: - retry += 1 - try: - self.metaserver_client.post(url, json=message) - success = True - except Exception as e: - logger.error( - f"Failed to connect to metaserver: {url}, retry {retry} time." - ) - if retry == 3: - raise e - def get_finished(self) -> tuple[set[str], set[str]]: done_recving = ( self.kv_recv_layer_thread. @@ -865,35 +982,6 @@ def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata): self.current_layer = 0 if self.vllm_config.kv_transfer_config.is_kv_consumer: for req_id, meta in metadata.requests.items(): - if self.tp_rank % self.tp_size == 0: - logger.info( - f"Send request: {req_id} to proxy metaserver: {meta.metaserver}" - ) - # All parameters here should appear in the returned dict of - # request_finished in the scheduler side except "request_id". - kv_transfer_params = dict( - token_ids=meta.token_ids, - request_id=req_id, - do_remote_prefill=False, - do_remote_decode=True, - remote_block_ids=meta.local_block_ids, - remote_engine_id=self.engine_id, - remote_host=self.side_channel_host, - remote_port=self.side_channel_port, - ) - future = self.executor.submit( - self._access_metaserver, - url=meta.metaserver, - message=kv_transfer_params, - ) - - def handle_exception(future): - if future.exception(): - logger.error( - f"Access metaserver fail: {future.exception()}" - ) - - future.add_done_callback(handle_exception) assert self.kv_recv_layer_thread is not None with self.kv_recv_layer_thread.lock: self.kv_recv_layer_thread.task_tracker[req_id] = 0 @@ -907,12 +995,12 @@ def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor, if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( ): # enable decode prefix cache - for request in connector_metadata.requests.values(): - assert len(request.local_block_ids) >= len( - request.remote_block_ids - ), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num." - request.local_block_ids = request.local_block_ids[ - -len(request.remote_block_ids):] + if self.use_mla: + reshape_cache_event = attn_metadata[ + layer_name].reshape_cache_event + else: + reshape_cache_event = attn_metadata.reshape_cache_event + if self.pd_head_ratio != 1: def sort_kv_cache(input_kv: list[list[int]]): @@ -964,8 +1052,10 @@ def sort_kv_cache(input_kv: list[list[int]]): f"Add request {req_id} to kv send layer thread. {req_meta_update=}" ) assert self.kv_send_layer_thread is not None + assert reshape_cache_event is not None self.kv_send_layer_thread.send_queue.put( - (req_id, req_meta_update, self.current_layer, key, value)) + (req_id, req_meta_update, self.current_layer, key, value, + reshape_cache_event)) self.current_layer += 1 def _get_remote_socket( diff --git a/vllm_ascend/ops/fused_moe/experts_selector.py b/vllm_ascend/ops/fused_moe/experts_selector.py index 51e0cb9f24c..39200a867e5 100644 --- a/vllm_ascend/ops/fused_moe/experts_selector.py +++ b/vllm_ascend/ops/fused_moe/experts_selector.py @@ -225,7 +225,7 @@ def _select_experts_with_fusion_ops( norm_type=norm_type, # 0: softmax; 1: sigmoid # out_flag=False, # todo new api; should the third output be output # y2_flag=False, # old api; should the third output be output - routed_scaling_factor=1, + routed_scaling_factor=routed_scaling_factor, eps=float(1e-20)) if scoring_func == "softmax": topk_weights = _renormalize_topk_weights(topk_weights, renormalize) @@ -304,3 +304,28 @@ def _native_select_experts( topk_weights = _renormalize_topk_weights(topk_weights, renormalize) return topk_weights, topk_ids + + +def zero_experts_compute( + expert_indices: torch.Tensor, + expert_scales: torch.Tensor, + num_experts: int, + zero_expert_type: str, + hidden_states: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if zero_expert_type == "identity": + zero_expert_mask = expert_indices < num_experts + zero_expert_scales = expert_scales.clone() + zero_expert_scales = torch.where(zero_expert_mask, 0.0, + zero_expert_scales) + + hidden_states = hidden_states.unsqueeze(1) + zero_expert_scales = zero_expert_scales.unsqueeze(2) + result = hidden_states * zero_expert_scales + result = result.sum(dim=1) + + normal_expert_mask = expert_indices >= num_experts + expert_indices = torch.where(normal_expert_mask, 0, expert_indices) + expert_scales = torch.where(normal_expert_mask, 0.0, expert_scales) + + return expert_indices, expert_scales, result diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 23f327d37ca..efc709a36c9 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -35,8 +35,10 @@ from vllm_ascend.eplb.utils import moe_load_async_stream from vllm_ascend.flash_common3_context import (get_flash_common3_context, set_flash_common3_context) -from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.ops.fused_moe.experts_selector import (select_experts, + zero_experts_compute) from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl, + FusedExpertsResult, setup_moe_comm_method) from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType from vllm_ascend.quantization.w4a8_dynamic import \ @@ -91,7 +93,8 @@ def apply(self, enable_force_load_balance: bool = False, shared_experts: Optional[Any] = None, **kwargs) -> torch.Tensor: - + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, @@ -106,6 +109,15 @@ def apply(self, e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) + if zero_expert_num > 0 and zero_expert_type is not None: + topk_ids, topk_weights, zero_expert_result = zero_experts_compute( + expert_indices=topk_ids, + expert_scales=topk_weights, + num_experts=global_num_experts, + zero_expert_type=zero_expert_type, + hidden_states=x, + ) + topk_weights = topk_weights.to(x.dtype) # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. @@ -118,7 +130,7 @@ def apply(self, random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype) moe_comm_method = get_forward_context().moe_comm_method - return moe_comm_method.fused_experts( + final_hidden_states = moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -130,6 +142,9 @@ def apply(self, apply_router_weight_on_input=apply_router_weight_on_input, dynamic_eplb=self.dynamic_eplb, mc2_mask=kwargs.get("mc2_mask", None)) + if zero_expert_num > 0 and zero_expert_type is not None: + final_hidden_states += zero_expert_result + return final_hidden_states class AscendFusedMoE(FusedMoE): @@ -325,7 +340,7 @@ def forward_impl(self, hidden_states: torch.Tensor, pertoken_scale = None # Matrix multiply. - final_hidden_states = self.quant_method.apply( + fused_experts_results: FusedExpertsResult = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, @@ -339,6 +354,7 @@ def forward_impl(self, hidden_states: torch.Tensor, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, @@ -350,25 +366,25 @@ def forward_impl(self, hidden_states: torch.Tensor, global_redundant_expert_num=self.global_redundant_expert_num, mc2_mask=mc2_mask) - if isinstance(final_hidden_states, tuple): - final_hidden_states, group_list_type, expert_tokens = final_hidden_states - if self.dynamic_eplb: - - moe_load_stream = moe_load_async_stream() - cur_stream = torch.npu.current_stream() - - moe_load_stream.wait_stream(cur_stream) - with npu_stream_switch(moe_load_stream): - self.moe_load += expert_tokens if group_list_type == 1 else \ - torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) - cur_stream.wait_stream(moe_load_stream) - - final_hidden_states = forward_context.moe_comm_method.finalize( - hidden_states=final_hidden_states, + if self.dynamic_eplb: + expert_tokens = fused_experts_results.expert_tokens + group_list_type = fused_experts_results.group_list_type + assert expert_tokens is not None and group_list_type is not None, \ + "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled." + moe_load_stream = moe_load_async_stream() + cur_stream = torch.npu.current_stream() + moe_load_stream.wait_stream(cur_stream) + with npu_stream_switch(moe_load_stream): + self.moe_load += expert_tokens if group_list_type == 1 else \ + torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) + cur_stream.wait_stream(moe_load_stream) + + routed_out = forward_context.moe_comm_method.finalize( + hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results, context_metadata=context_metadata) - return final_hidden_states + return routed_out class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): @@ -439,7 +455,7 @@ def forward_impl(self, hidden_states: torch.Tensor, else: set_flash_common3_context(shared_experts=self._shared_experts) - fused_output = AscendFusedMoE.forward_impl( + routed_out = AscendFusedMoE.forward_impl( self, hidden_states=hidden_states, router_logits=router_logits, @@ -462,4 +478,4 @@ def forward_impl(self, hidden_states: torch.Tensor, assert fc3_context is not None shared_out = fc3_context.shared_out - return shared_out, fused_output + return shared_out, routed_out diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 30d1e5c1376..06fd2fe4415 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -16,6 +16,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Any, Dict, Optional import torch @@ -26,11 +27,11 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.prepare_finalize import ( - PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather, - PrepareAndFinalizeWithMC2, QuantType) + PrepareAndFinalize, PrepareAndFinalizeWithAll2All, + PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType) from vllm_ascend.ops.fused_moe.token_dispatcher import ( - TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, - TokenDispatcherWithMC2) + MoETokenDispatcher, TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, TokenDispatcherWithMC2) _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} @@ -47,6 +48,14 @@ def setup_moe_comm_method(moe_config): _MoECommMethods[MoECommType.FUSED_MC2] = FusedMC2CommImpl(moe_config) +@dataclass +class FusedExpertsResult: + routed_out: torch.Tensor + # For dynamic_eplb + group_list_type: int | None = None + expert_tokens: torch.Tensor | None = None + + class MoECommMethod(ABC): """Base class for MoE communication methods.""" @@ -118,7 +127,7 @@ def fused_experts( moe_comm_method = get_forward_context().moe_comm_method assert moe_comm_method is not None, "Missing communication context" - results = self.token_dispatcher.token_dispatch( + dispatch_results = self.token_dispatcher.token_dispatch( hidden_states=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, @@ -134,43 +143,41 @@ def fused_experts( dynamic_eplb=dynamic_eplb, pertoken_scale=pertoken_scale) - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \ - results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata") - - mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=expert_tokens, - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - w1_offset=w1_offset, - w2_offset=w2_offset, - topk_scales=topk_scales, - with_quant=use_int8_w8a8 - or use_int4_w4a8 or use_int4_w4a16, - fusion=use_int8_w8a8, - need_trans=need_trans, - dynamic_eplb=dynamic_eplb) - - final_hidden_states = self.token_dispatcher.token_combine( - hidden_states=mlp_output, context_metadata=context_metadata) - - if dynamic_eplb: - return (final_hidden_states, group_list_type, expert_tokens) - - return final_hidden_states + mlp_output = unified_apply_mlp( + hidden_states=dispatch_results.hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=dispatch_results.group_list, + dynamic_scale=dispatch_results.dynamic_scale, + group_list_type=dispatch_results.group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, + topk_scales=dispatch_results.topk_scales, + with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16, + fusion=use_int8_w8a8, + need_trans=need_trans, + dynamic_eplb=dynamic_eplb) + + combine_results = self.token_dispatcher.token_combine( + hidden_states=mlp_output, + context_metadata=dispatch_results.context_metadata) + + return FusedExpertsResult( + routed_out=combine_results.routed_out, + group_list_type=dispatch_results.group_list_type, + expert_tokens=dispatch_results.group_list) @abstractmethod - def _get_token_dispatcher(self): + def _get_token_dispatcher(self) -> MoETokenDispatcher: raise NotImplementedError( "_get_token_dispatcher function not implemented.") @abstractmethod - def _get_prepare_finalize(self): + def _get_prepare_finalize(self) -> PrepareAndFinalize: raise NotImplementedError( "_get_prepare_finalize function not implemented.") @@ -292,9 +299,11 @@ def fused_experts( w1_scale is None or w2_scale is None ), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." + assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \ + "token_dispatcher must be an instance of TokenDispatcherWithMC2." if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: out = torch.empty_like(hidden_states) - torch.ops._C_ascend.dispatch_ffn_combine( + torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore x=hidden_states, weight1=w1[0], weight2=w2[0], @@ -308,7 +317,7 @@ def fused_experts( ) elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: assert expert_map is not None, "expert_map cannot be None." - out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( + out, _ = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore x=hidden_states, expert_ids=topk_ids, gmm1_permuted_weight=w1[0], @@ -325,4 +334,4 @@ def fused_experts( else: raise ValueError( f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") - return out + return FusedExpertsResult(routed_out=out) diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index e17b033e50b..0513307ab74 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -21,6 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import Any, Optional import torch @@ -35,6 +36,21 @@ is_hierarchical_communication_enabled) +@dataclass +class TokenDispatchResult: + hidden_states: torch.Tensor + group_list: torch.Tensor + group_list_type: int + dynamic_scale: torch.Tensor | None = field(default=None) + topk_scales: torch.Tensor | None = field(default=None) + context_metadata: dict = field(default_factory=dict) + + +@dataclass +class TokenCombineResult: + routed_out: torch.Tensor + + class MoETokenDispatcher(ABC): def __init__(self, **kwargs) -> None: @@ -74,14 +90,14 @@ def token_dispatch( with_quant: bool = False, dynamic_eplb: bool = False, pertoken_scale: Optional[torch.Tensor] = None, - ): + ) -> TokenDispatchResult: raise NotImplementedError("Dispatch function not implemented.") @abstractmethod def token_combine(self, hidden_states: torch.Tensor, context_metadata: dict, - bias: torch.Tensor = None): + bias: torch.Tensor | None = None) -> TokenCombineResult: raise NotImplementedError("Combine function not implemented.") @@ -207,24 +223,6 @@ def token_dispatch(self, expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ ep_recv_counts, tp_recv_counts, expand_scales = output[0:7] - # Handle shared experts (store intermediate results in local vars, not self) - shared_act = None - swiglu_out_scale = None - if with_quant: - if shared_experts is not None: - share_up_out, _ = shared_experts.gate_up_proj( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_gate_up, shared_dequant_scale = share_up_out[ - 0], share_up_out[1] - shared_act_out = shared_experts.act_fn( - (shared_gate_up, shared_dequant_scale)) - shared_act, swiglu_out_scale = shared_act_out[ - 0], shared_act_out[1] - else: - if shared_experts is not None: - shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - shared_act = shared_experts.act_fn(shared_gate_up) - context_metadata = { "topk_ids": topk_ids, "topk_weights": topk_weights, @@ -233,20 +231,16 @@ def token_dispatch(self, "tp_recv_counts": tp_recv_counts, "assist_info_for_combine": assist_info_for_combine, "shared_experts": shared_experts, - "shared_act": shared_act, - "swiglu_out_scale": swiglu_out_scale, "expand_scales": expand_scales } group_list_type = 0 - return { - "group_list_type": group_list_type, - "hidden_states": expand_x, - "group_list": expert_token_nums, - "dynamic_scale": dynamic_scale, - "context_metadata": context_metadata - } + return TokenDispatchResult(hidden_states=expand_x, + dynamic_scale=dynamic_scale, + group_list=expert_token_nums, + group_list_type=group_list_type, + context_metadata=context_metadata) def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict): @@ -300,12 +294,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 - def token_combine( - self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None, - ): + def token_combine(self, hidden_states, context_metadata, bias=None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, @@ -313,20 +302,7 @@ def token_combine( combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \ if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - # Handle shared experts from metadata - shared_experts = context_metadata["shared_experts"] - if shared_experts is None: - return combined_output - - shared_act = context_metadata["shared_act"] - if self.with_quant: - swiglu_out_scale = context_metadata["swiglu_out_scale"] - shared_hidden_states, _ = shared_experts.down_proj( - (shared_act, swiglu_out_scale)) - else: - shared_hidden_states, _ = shared_experts.down_proj(shared_act) - - return combined_output, shared_hidden_states + return TokenCombineResult(routed_out=combined_output) class TokenDispatcherWithAllGather(MoETokenDispatcher): @@ -401,18 +377,16 @@ def token_dispatch(self, "topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx } - return { - "group_list_type": group_list_type, - "hidden_states": sorted_hidden_states, - "group_list": expert_tokens, - "dynamic_scale": pertoken_scale if self.with_quant else None, - "context_metadata": context_metadata - } - def token_combine(self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None): + return TokenDispatchResult( + hidden_states=sorted_hidden_states, + dynamic_scale=pertoken_scale if self.with_quant else None, + group_list=expert_tokens, + group_list_type=group_list_type, + context_metadata=context_metadata, + ) + + def token_combine(self, hidden_states, context_metadata, bias=None): assert self.original_shape is not None final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, @@ -422,7 +396,7 @@ def token_combine(self, final_hidden_states = final_hidden_states.view(self.original_shape) # these values are no longer used, so they need to be set to None for memory release. - return final_hidden_states + return TokenCombineResult(routed_out=final_hidden_states) class TokenDispatcherWithAll2AllV(MoETokenDispatcher): @@ -530,20 +504,15 @@ def token_dispatch(self, reversed_global_input_permutation_mapping } - return { - "hidden_states": global_input_tokens, - "group_list": tokens_per_expert, - "group_list_type": 1, - "dynamic_scale": dynamic_scale_final, - "context_metadata": context_metadata, - } + return TokenDispatchResult( + hidden_states=global_input_tokens, + dynamic_scale=dynamic_scale_final, + group_list=tokens_per_expert, + group_list_type=1, + context_metadata=context_metadata, + ) - def token_combine( - self, - hidden_states: torch.Tensor, - context_metadata: dict, - bias: torch.Tensor = None, - ): + def token_combine(self, hidden_states, context_metadata, bias=None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." # 1. Preprocess using metadata @@ -564,7 +533,7 @@ def token_combine( output = self._combine_postprocess(permutated_local_input_tokens, context_metadata) - return output + return TokenCombineResult(routed_out=output) def _dispatch_preprocess(self, hidden_states, topk_ids): assert self.hidden_shape is not None diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index 1cedda9c352..1c952aa6386 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -94,8 +94,6 @@ def __init__( hf_config = get_current_vllm_config().model_config.hf_config self.enable_shared_expert_dp = get_ascend_config( ).enable_shared_expert_dp - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - self.first_k_dense_replace = hf_config.first_k_dense_replace self.tp_size = get_tensor_model_parallel_world_size() self.layers = hf_config.num_hidden_layers if mla_modules.indexer is not None: diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 8c8b75187b7..49a1a5bad7b 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -298,6 +298,12 @@ def get_scaled_act_names(self) -> List[str]: "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] }, + "longcat_flash": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] + }, } @@ -514,6 +520,7 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, @@ -524,9 +531,9 @@ def apply( return self.quant_method.apply( layer, x, router_logits, top_k, renormalize, use_grouped_topk, global_num_experts, expert_map, topk_group, num_expert_group, - custom_routing_function, scoring_func, e_score_correction_bias, - is_prefill, enable_force_load_balance, log2phy, - global_redundant_expert_num, **kwargs) + custom_routing_function, scoring_func, routed_scaling_factor, + e_score_correction_bias, is_prefill, enable_force_load_balance, + log2phy, global_redundant_expert_num, **kwargs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): diff --git a/vllm_ascend/quantization/w4a16.py b/vllm_ascend/quantization/w4a16.py index d15fa25aaa2..4fcc33807f2 100644 --- a/vllm_ascend/quantization/w4a16.py +++ b/vllm_ascend/quantization/w4a16.py @@ -199,6 +199,7 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = True, diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 45a7bc18337..3222f2ea031 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -336,6 +336,7 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 986f6fd2699..bebd807bf55 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -28,7 +28,8 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.flash_common3_context import get_flash_common3_context -from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.ops.fused_moe.experts_selector import (select_experts, + zero_experts_compute) from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz @@ -183,6 +184,7 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = True, enable_force_load_balance: bool = False, @@ -194,8 +196,11 @@ def apply( pertoken_scale: Optional[Any] = None, **kwargs, ) -> torch.Tensor: - assert router_logits.shape[ - 1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)" + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + if zero_expert_num == 0 or zero_expert_type is None: + assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, \ + "Number of global experts mismatch (excluding redundancy)" if self.multistream_overlap_gate: fc3_context = get_flash_common3_context() @@ -213,10 +218,19 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, global_num_experts=global_num_experts) assert topk_ids is not None assert topk_weights is not None + if zero_expert_num > 0 and zero_expert_type is not None: + topk_ids, topk_weights, zero_expert_result = zero_experts_compute( + expert_indices=topk_ids, + expert_scales=topk_weights, + num_experts=global_num_experts, + zero_expert_type=zero_expert_type, + hidden_states=x, + ) # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. @@ -253,7 +267,7 @@ def apply( fused_scale_flag = (get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1) - return moe_comm_method.fused_experts( + final_hidden_states = moe_comm_method.fused_experts( hidden_states=x, pertoken_scale=pertoken_scale, w1=w1, @@ -271,6 +285,9 @@ def apply( dynamic_scale_for_share=dynamic_scale_for_share, dynamic_eplb=self.dynamic_eplb, mc2_mask=kwargs.get("mc2_mask", None)) + if zero_expert_num > 0 and zero_expert_type is not None: + final_hidden_states += zero_expert_result + return final_hidden_states def process_weights_after_loading(self, layer): layer.w13_weight.data = layer.w13_weight.data.transpose( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 319c7e41e0c..614afb0e8db 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -106,7 +106,7 @@ from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration, enable_sp, get_ascend_device_type, is_moe_model, lmhead_tp_enable, maybe_trans_nz, - set_weight_prefetch_method) + set_weight_prefetch_method, vllm_version_is) from vllm_ascend.worker.npu_input_batch import NPUInputBatch from vllm_ascend.worker.pcp_utils import PCPManager @@ -1092,12 +1092,20 @@ def _generate_process_reqs_hidden_states(self, maybe_padded_num_tokens, intermediate_tensors, inputs_embeds): assert self.model is not None - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **self._init_model_kwargs(maybe_padded_num_tokens)) + if vllm_version_is('0.13.0'): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **self._init_model_kwargs(maybe_padded_num_tokens)) + else: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **self._init_model_kwargs()) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL \ @@ -2240,9 +2248,10 @@ def initialize_kv_cache_tensors( kv_caches[layer_name] = kv_caches[target_layer_name] from vllm.v1.worker.utils import bind_kv_cache + num_attn_module = 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, - self.kv_caches) + self.kv_caches, num_attn_module) return kv_caches def _allocate_kv_cache_tensors(