Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ def _init_executor(self) -> None:
is_driver_worker=is_driver_worker,
)

wrapper_kwargs = {
"vllm_config": self.vllm_config,
}

self.driver_worker = WorkerWrapperBase(**wrapper_kwargs)
self.driver_worker = WorkerWrapperBase()

self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
self.collective_rpc("init_device")
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: d
class DummyExecutor(UniProcExecutor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
self.driver_worker = WorkerWrapperBase(rpc_rank=0)
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
local_rank = 0
# set local rank as the device index if specified
Expand Down
11 changes: 0 additions & 11 deletions vllm/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,6 @@
logger = init_logger(__name__)


# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules

init_hf_modules()


def import_pynvml():
"""
Historical comments:
Expand Down
4 changes: 1 addition & 3 deletions vllm/v1/executor/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,7 @@ def __init__(
shared_worker_lock: LockType,
):
self.rank = rank
wrapper = WorkerWrapperBase(
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
)
wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs: list[dict] = [
{} for _ in range(vllm_config.parallel_config.world_size)
Expand Down
9 changes: 3 additions & 6 deletions vllm/v1/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,16 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
vllm_config=self.vllm_config, rpc_rank=rank
)
)(RayWorkerWrapper).remote(rpc_rank=rank)
else:
worker = ray.remote(
num_cpus=0,
num_gpus=0,
resources={current_platform.ray_device_key: num_gpus},
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote( # type: ignore[attr-defined]
vllm_config=self.vllm_config, rpc_rank=rank
)
)(RayWorkerWrapper).remote(rpc_rank=rank)

worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank))

worker_ips = ray.get(
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/executor/uniproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class UniProcExecutor(Executor):
def _init_executor(self) -> None:
"""Initialize the worker and load the model."""
self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
self.driver_worker = WorkerWrapperBase(rpc_rank=0)
distributed_init_method, rank, local_rank = self._distributed_args()
kwargs = dict(
vllm_config=self.vllm_config,
Expand Down
6 changes: 0 additions & 6 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,6 @@ def __init__(
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.set_float32_matmul_precision(precision)

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules

init_cached_hf_modules()

# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}

Expand Down
6 changes: 0 additions & 6 deletions vllm/v1/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,6 @@ def __init__(
else:
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype]

if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules

init_cached_hf_modules()

# Delay profiler initialization to the start of the profiling.
# This is because in vLLM V1, MP runtime is initialized before the
# TPU Worker is initialized. The profiler server needs to start after
Expand Down
57 changes: 23 additions & 34 deletions vllm/v1/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ class WorkerWrapperBase:

def __init__(
self,
vllm_config: VllmConfig,
rpc_rank: int = 0,
global_rank: int | None = None,
) -> None:
Expand All @@ -194,21 +193,10 @@ def __init__(
"""
self.rpc_rank = rpc_rank
self.global_rank = self.rpc_rank if global_rank is None else global_rank
self.worker: WorkerBase | None = None

# do not store this `vllm_config`, `init_worker` will set the final
# one.
# TODO: investigate if we can remove this field in `WorkerWrapperBase`,
# `init_cached_hf_modules` should be unnecessary now.
self.vllm_config: VllmConfig | None = None

# `model_config` can be None in tests
model_config = vllm_config.model_config
if model_config and model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils.import_utils import init_cached_hf_modules

init_cached_hf_modules()
# Initialized after init_worker is called
self.worker: WorkerBase
self.vllm_config: VllmConfig

def shutdown(self) -> None:
if self.worker is not None:
Expand Down Expand Up @@ -241,27 +229,34 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
Arguments are passed to the worker class constructor.
"""
kwargs = all_kwargs[self.rpc_rank]
self.vllm_config = kwargs.get("vllm_config")
assert self.vllm_config is not None, (

vllm_config: VllmConfig | None = kwargs.get("vllm_config")
assert vllm_config is not None, (
"vllm_config is required to initialize the worker"
)
self.vllm_config.enable_trace_function_call_for_thread()
self.vllm_config = vllm_config

vllm_config.enable_trace_function_call_for_thread()

from vllm.plugins import load_general_plugins

load_general_plugins()

if isinstance(self.vllm_config.parallel_config.worker_cls, str):
worker_class = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_cls
parallel_config = vllm_config.parallel_config
if isinstance(parallel_config.worker_cls, str):
worker_class: type[WorkerBase] = resolve_obj_by_qualname(
parallel_config.worker_cls
)
else:
raise ValueError(
"passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string." # noqa: E501
"passing worker_cls is no longer supported. "
"Please pass keep the class in a separate module "
"and pass the qualified name of the class as a string."
)
if self.vllm_config.parallel_config.worker_extension_cls:

if parallel_config.worker_extension_cls:
worker_extension_cls = resolve_obj_by_qualname(
self.vllm_config.parallel_config.worker_extension_cls
parallel_config.worker_extension_cls
)
extended_calls = []
if worker_extension_cls not in worker_class.__bases__:
Expand Down Expand Up @@ -294,7 +289,7 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
"This argument is needed for mm_processor_cache_type='shm'."
)

mm_config = self.vllm_config.model_config.multimodal_config
mm_config = vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_processor_cache_type == "shm":
raise ValueError(msg)
else:
Expand All @@ -303,15 +298,14 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = worker_receiver_cache_from_config(
self.vllm_config,
vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock,
)

with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs)
assert self.worker is not None

def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
kv_cache_config = kv_cache_configs[self.global_rank]
Expand Down Expand Up @@ -358,20 +352,15 @@ def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
)

def execute_model(
self,
scheduler_output: SchedulerOutput,
*args,
**kwargs,
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output)

assert self.worker is not None
return self.worker.execute_model(scheduler_output, *args, **kwargs)
return self.worker.execute_model(scheduler_output)

def reset_mm_cache(self) -> None:
mm_receiver_cache = self.mm_receiver_cache
if mm_receiver_cache is not None:
mm_receiver_cache.clear_cache()

assert self.worker is not None
self.worker.reset_mm_cache()