diff --git a/tests/e2e/multicard/test_ep_etp.py b/tests/e2e/multicard/test_ep_etp.py
new file mode 100644
index 00000000000..0232b3871d7
--- /dev/null
+++ b/tests/e2e/multicard/test_ep_etp.py
@@ -0,0 +1,38 @@
+import os
+
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.model_utils import check_outputs_equal
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="ep is not supported on v0")
+@pytest.mark.parametrize("model_name", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
+def test_e2e_ep_etp_correctness(model_name):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    max_tokens = 5
+
+    with VllmRunner(model_name,
+                    tensor_parallel_size=2,
+                    additional_config={
+                        "expert_tensor_parallel_size": 2,
+                    }) as vllm_model:
+        etp_output = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    with VllmRunner(model_name,
+                    tensor_parallel_size=2,
+                    enable_expert_parallel=True) as vllm_model:
+        ep_output = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=ep_output,
+        outputs_1_lst=etp_output,
+        name_0="ep_output",
+        name_1="etp_output",
+    )
diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py
index 2778a6ef277..a056fa17fae 100644
--- a/vllm_ascend/distributed/parallel_state.py
+++ b/vllm_ascend/distributed/parallel_state.py
@@ -28,37 +28,34 @@ def model_parallel_initialized():
 
 def init_ascend_model_parallel(
     expert_parallel_size: int = 1,
     expert_tensor_parallel_size: int = 1,
-    world_size: Optional[int] = None,
     backend: Optional[str] = None,
 ):
     if model_parallel_initialized():
         return
     assert torch.distributed.is_initialized()
-    world_size = world_size or torch.distributed.get_world_size()
+    world_size = torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
-    num_expert_parallel_groups = expert_tensor_parallel_size
-    num_expert_tensor_parallel_groups = expert_parallel_size
-    global _EP
-    group_ranks = []
-    for i in range(num_expert_parallel_groups):
-        ranks = list(range(i, world_size, num_expert_parallel_groups))
-        group_ranks.append(ranks)
+    # The layout of all ranks: ExternalDP * EP * ETP
+    # ExternalDP is the data parallel group that is not part of the model,
+    # every dp rank can generate independently (in verl integration).
+    all_ranks = torch.arange(world_size).reshape(-1, expert_parallel_size,
+                                                 expert_tensor_parallel_size)
+    global _EP
+    group_ranks = all_ranks.transpose(1,
+                                      2).view(-1,
+                                              expert_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
 
     _EP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="ep")
 
-    group_ranks = []
     global _ETP
-    for i in range(num_expert_tensor_parallel_groups):
-        ranks = list(
-            range(i * expert_tensor_parallel_size,
-                  (i + 1) * expert_tensor_parallel_size))
-        group_ranks.append(ranks)
-
+    group_ranks = all_ranks.view(-1, expert_tensor_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
     _ETP = init_model_parallel_group(group_ranks,
                                      get_world_group().local_rank,
                                      backend,
                                      group_name="etp")
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index bffc6a8de89..4d163abb872 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -554,7 +554,6 @@ def _init_worker_distributed_environment(
     init_ascend_model_parallel(
         parallel_config.expert_parallel_size,
         parallel_config.expert_tensor_parallel_size,
-        parallel_config.world_size_across_dp,
     )
     ensure_kv_transfer_initialized(vllm_config)
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 6fe84a45808..5535be7e464 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -272,7 +272,6 @@ def _init_worker_distributed_environment(self) -> None:
         init_ascend_model_parallel(
             parallel_config.expert_parallel_size,
             parallel_config.expert_tensor_parallel_size,
-            parallel_config.world_size_across_dp,
         )
         ensure_kv_transfer_initialized(self.vllm_config)
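Note (illustration, not part of the patch): the rewritten init_ascend_model_parallel derives both process groups from a single ExternalDP * EP * ETP rank tensor instead of two hand-written loops. The sketch below replays that indexing with assumed sizes (world_size=8, expert_parallel_size=4, expert_tensor_parallel_size=2, one external-DP replica); it only exercises the tensor arithmetic, so it runs without an initialized torch.distributed environment.

```python
# Standalone sketch of the rank-grouping logic; the sizes below are assumed
# for demonstration and are not taken from the patch or the test.
import torch

world_size = 8
expert_parallel_size = 4
expert_tensor_parallel_size = 2

# Rank layout: ExternalDP * EP * ETP. With 8 ranks, EP=4 and ETP=2, the
# leading -1 resolves the external-DP dimension to 1.
all_ranks = torch.arange(world_size).reshape(-1, expert_parallel_size,
                                             expert_tensor_parallel_size)

# EP groups: swap the EP and ETP axes so ranks sharing an ETP index end up
# in the same group; members are strided by expert_tensor_parallel_size.
ep_groups = [x.tolist() for x in all_ranks.transpose(1, 2).view(
    -1, expert_parallel_size).unbind(0)]
print(ep_groups)   # [[0, 2, 4, 6], [1, 3, 5, 7]]

# ETP groups: consecutive ranks shard the weights of the same experts.
etp_groups = [x.tolist() for x in all_ranks.view(
    -1, expert_tensor_parallel_size).unbind(0)]
print(etp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

For this configuration the result matches what the removed loops produced: EP group members are strided by expert_tensor_parallel_size, while ETP groups are contiguous blocks of ranks.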