diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 8a718eb8558..1d152289551 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -203,6 +203,7 @@ jobs: pytest -sv --durations=0 tests/e2e/multicard/test_qwen3_moe.py pytest -sv --durations=0 tests/e2e/multicard/test_offline_weight_load.py + e2e-4-cards: name: multicard-4 needs: [e2e, e2e-2-cards] @@ -267,6 +268,7 @@ jobs: pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe pytest -sv --durations=0 tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Kimi_K2_Thinking_W4A16 pytest -sv --durations=0 tests/e2e/multicard/test_data_parallel_tp2.py + pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_flashcomm2.py pytest -sv --durations=0 tests/e2e/multicard/long_sequence/test_basic.py - name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct) diff --git a/tests/e2e/multicard/long_sequence/test_flashcomm2.py b/tests/e2e/multicard/long_sequence/test_flashcomm2.py new file mode 100644 index 00000000000..a423161851f --- /dev/null +++ b/tests/e2e/multicard/long_sequence/test_flashcomm2.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +"""Smoke-test greedy vLLM generation with FlashComm2 under PCP and DCP. + +Run `pytest tests/e2e/multicard/long_sequence/test_flashcomm2.py`. +""" + +import os + +import pytest +from vllm import SamplingParams + +from tests.e2e.conftest import VllmRunner +from vllm_ascend.utils import vllm_version_is + +os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM2_OSHARED"] = "1" +os.environ["VLLM_ASCEND_FLASHCOMM2_PARALLEL_SIZE"] = "1" + + +@pytest.mark.skipif(vllm_version_is('0.12.0'), + reason="0.12.0 is not supported for context sequence.") +def test_pcp_dcp_flashcomm2(): + prompts = [ + "The capital of France is", "Hello, my name is Tom, I am", + "The president of United States is", "AI future is" + ] + model = "deepseek-ai/DeepSeek-V2-Lite-Chat" + sampling_params = SamplingParams(max_tokens=32, temperature=0.0) + with VllmRunner(model, + enforce_eager=True, + max_model_len=1024, + tensor_parallel_size=2, + prefill_context_parallel_size=2, + decode_context_parallel_size=2, + max_num_batched_tokens=1024, + enable_expert_parallel=True, + block_size=128) as runner: + runner.model.generate(prompts, sampling_params) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index e886a31113e..73ec2c9c3f3 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -41,6 +41,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ): global_tp_size = parallel_config.tensor_parallel_size global_dp_size = parallel_config.data_parallel_size global_pp_size = parallel_config.pipeline_parallel_size + global_pcp_size = parallel_config.prefill_context_parallel_size # The layout of all ranks: ExternalDP * EP # ExternalDP is the data parallel group that is not part of the model, @@ -154,16 +155,22 @@ def _create_shared_weight_group(group_name: str) -> GroupCoordinator: for pp_idx in range(global_pp_size): group = [] for dp_idx in 
range(global_dp_size): - base = (dp_idx * global_pp_size + pp_idx) * global_tp_size - for i in range(global_tp_size): - global_rank = base + i - group.append(global_rank) + for pcp_idx in range(global_pcp_size): + base = (dp_idx * global_pp_size * global_pcp_size * + global_tp_size + + pp_idx * global_pcp_size * global_tp_size + + pcp_idx * global_tp_size) + for tp_idx in range(global_tp_size): + global_rank = base + tp_idx + group.append(global_rank) group_ranks.append(group) - return init_model_parallel_group(group_ranks, - get_world_group().local_rank, - backend, - group_name=group_name) + return init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name=group_name, + ) global _SHARED_WEIGHT # TODO: Check if the model is Deepseek V3.2 with enabled SFA CP and activated shared weights. It will then be normalized within the PCP parameters. -- clrs97