[Hardware] [Intel GPU] refactor xpu worker/executor (vllm-project#7686)
jikunshang authored Aug 20, 2024
1 parent 93b3870 commit 08619ff
Showing 3 changed files with 26 additions and 28 deletions.
38 changes: 15 additions & 23 deletions vllm/executor/xpu_executor.py
@@ -1,16 +1,16 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, PromptAdapterConfig,
-                         SchedulerConfig, SpeculativeConfig)
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
 from vllm.utils import make_async
-from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)

@@ -30,6 +30,7 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
         speculative_config: Optional[SpeculativeConfig],
+        observability_config: Optional[ObservabilityConfig],
     ) -> None:
         assert device_config.device_type == "xpu"
         assert (not speculative_config
@@ -46,32 +47,23 @@ def __init__(
         self.device_config = device_config
         self.prompt_adapter_config = prompt_adapter_config
         self.speculative_config = None
+        self.observability_config = observability_config
 
         # Instantiate the worker and load the model to GPU.
         self._init_executor()
 
-    def _create_worker(self,
-                       local_rank: int = 0,
-                       rank: int = 0,
-                       distributed_init_method: Optional[str] = None):
-        if self.speculative_config is None:
-            worker_module_name = "vllm.worker.xpu_worker"
-            worker_class_name = "XPUWorker"
-        else:
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.speculative_config is not None:
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
-
-        wrapper = WorkerWrapperBase(
-            worker_module_name=worker_module_name,
-            worker_class_name=worker_class_name,
-        )
-        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
-                                                      distributed_init_method))
-        return wrapper.worker
+        else:
+            worker_module_name = "vllm.worker.xpu_worker"
+            worker_class_name = "XPUWorker"
+        return (worker_module_name, worker_class_name)
 
     def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output

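The point of the executor change: XPUExecutor no longer builds its worker by hand. The WorkerWrapperBase plumbing deleted above now runs once in the code inherited from GPUExecutor, and the subclass only answers which worker to load via the _get_worker_module_and_class hook; execute_model also widens its return type to match the base executor, admitting PoolerOutput for pooling models alongside SamplerOutput. A minimal sketch of the inherited path, reconstructed from the deleted code above (the real implementation lives in vllm/executor/gpu_executor.py at this commit; _get_worker_kwargs is stubbed here for illustration):

    from typing import Any, Dict, Optional, Tuple

    from vllm.worker.worker_base import WorkerWrapperBase


    class ExecutorSketch:

        def _get_worker_module_and_class(self) -> Tuple[str, str]:
            # XPUExecutor overrides only this hook, as in the diff above.
            return ("vllm.worker.xpu_worker", "XPUWorker")

        def _get_worker_kwargs(
                self, local_rank: int, rank: int,
                distributed_init_method: Optional[str]) -> Dict[str, Any]:
            # Stub standing in for the base executor's helper, which also
            # passes along the model/parallel/cache config objects.
            return dict(local_rank=local_rank, rank=rank,
                        distributed_init_method=distributed_init_method)

        def _create_worker(self,
                           local_rank: int = 0,
                           rank: int = 0,
                           distributed_init_method: Optional[str] = None):
            # The generic version of the XPU-specific body deleted above:
            # resolve the worker class by name, then instantiate it lazily.
            module_name, class_name = self._get_worker_module_and_class()
            wrapper = WorkerWrapperBase(worker_module_name=module_name,
                                        worker_class_name=class_name)
            wrapper.init_worker(**self._get_worker_kwargs(
                local_rank, rank, distributed_init_method))
            return wrapper.worker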
1 change: 0 additions & 1 deletion vllm/worker/xpu_model_runner.py
Expand Up @@ -137,7 +137,6 @@ def load_model(self) -> None:
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
15 changes: 11 additions & 4 deletions vllm/worker/xpu_worker.py
@@ -9,8 +9,8 @@
 import torch.distributed
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
@@ -50,6 +50,7 @@ def __init__(
         speculative_config: Optional[SpeculativeConfig] = None,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
+        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
         assert device_config.device_type == "xpu"
         assert is_xpu()
@@ -67,8 +68,10 @@ def __init__(
         self.lora_config = lora_config
         self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
+        self.observability_config = observability_config
+        if parallel_config and is_driver_worker:
+            assert rank % parallel_config.tensor_parallel_size == 0, \
+                "Driver worker should be rank 0 of tensor parallel group."
 
         self.multimodal_config = multimodal_config
 
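The driver check is the behavioral change in this hunk: the old assertion pinned the driver to global rank 0, while the new one accepts the first rank of each tensor-parallel group. A worked example with illustrative sizes, not taken from the diff:

    tensor_parallel_size = 2
    world_size = 4  # two TP groups: ranks {0, 1} and {2, 3}

    # Ranks that satisfy the new assertion and may act as drivers:
    drivers = [rank for rank in range(world_size)
               if rank % tensor_parallel_size == 0]
    print(drivers)  # [0, 2]; the old `rank == 0` check admitted only rank 0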
@@ -183,7 +186,11 @@ def init_worker_distributed_environment(self) -> None:
         # dependency (libdrm and drm headers) on your system.
         ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                             "sockets")
+        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
+                                         str(parallel_config.world_size))
         os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
             world_size=parallel_config.world_size,
             rank=rank,
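The worker now mirrors LOCAL_WORLD_SIZE and LOCAL_RANK into the environment before calling init_distributed_environment, presumably so the oneCCL (ccl) backend can work out per-node placement. A sketch of the resulting environment for one worker; the node size and rank are hypothetical values for illustration:

    import os

    # From the diff: "sockets" avoids the libdrm/drm-headers dependency.
    os.environ.setdefault("CCL_ZE_IPC_EXCHANGE", "sockets")
    os.environ["LOCAL_WORLD_SIZE"] = "2"  # assumed: workers on this node
    os.environ["LOCAL_RANK"] = "0"        # assumed: this worker's device index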