diff --git a/tests/ut/torchair/test_torchair_worker.py b/tests/ut/torchair/test_torchair_worker.py
index 32d5a92e655..0397aee17c7 100644
--- a/tests/ut/torchair/test_torchair_worker.py
+++ b/tests/ut/torchair/test_torchair_worker.py
@@ -59,6 +59,7 @@ def test_init_device(self, mock_platform, mock_init_dist_env):
         worker.vllm_config = MagicMock()
         worker.parallel_config = MagicMock()
         worker.parallel_config.local_world_size = 0
+        worker.parallel_config.data_parallel_size = 1

         result = worker._init_device()

@@ -93,6 +94,7 @@ def test_init_device_torchair_worker(self, mock_platform,
         worker.vllm_config = MagicMock()
         worker.parallel_config = MagicMock()
         worker.parallel_config.local_world_size = 0
+        worker.parallel_config.data_parallel_size = 1

         result = worker._init_device()

diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py
index fbc7fdc4299..5a12981a370 100644
--- a/tests/ut/worker/test_worker_v1.py
+++ b/tests/ut/worker/test_worker_v1.py
@@ -329,6 +329,8 @@ def test_init_device(self, mock_platform, mock_init_dist_env):
         worker.model_config = MagicMock()
         worker.parallel_config = MagicMock()
         worker.parallel_config.local_world_size = 0
+        worker.parallel_config.data_parallel_size = 1
+        worker.model_config.seed = 42

         # Test _init_device

diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index e9000eae38e..df7fec602d0 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -208,12 +208,18 @@ def _init_device(self):
         NPUPlatform.set_device(device)
         NPUPlatform.empty_cache()

-        visible_device_count = (torch.npu.device_count()
-                                if torch.npu.is_available() else 0)
-        assert self.parallel_config.local_world_size <= visible_device_count, (
-            f"local_world_size ({self.parallel_config.local_world_size}) must be "
-            f"less than or equal to the number of visible devices "
-            f"({visible_device_count}).")
+        if (self.parallel_config.data_parallel_size > 1
+                and self.parallel_config.data_parallel_size_local > 0
+                and self.parallel_config.distributed_executor_backend
+                not in ["ray", "external_launcher"] and
+                self.vllm_config.parallel_config.data_parallel_backend != "ray"
+                and self.vllm_config.parallel_config.nnodes_within_dp == 1):
+            visible_device_count = (torch.npu.device_count()
+                                    if torch.npu.is_available() else 0)
+            assert self.parallel_config.local_world_size <= visible_device_count, (
+                f"local_world_size ({self.parallel_config.local_world_size}) must "
+                f"be less than or equal to the number of visible devices "
+                f"({visible_device_count}).")
         self.init_npu_memory = NPUPlatform.mem_get_info()[0]

         # Initialize the distributed environment.