diff --git a/tests/model_executor/model_loader/runai_streamer_loader/conftest.py b/tests/model_executor/model_loader/runai_streamer_loader/conftest.py
index 9a022f6bbd9d..bad9dea1bf65 100644
--- a/tests/model_executor/model_loader/runai_streamer_loader/conftest.py
+++ b/tests/model_executor/model_loader/runai_streamer_loader/conftest.py
@@ -29,11 +29,7 @@ def _init_executor(self) -> None:
             is_driver_worker=is_driver_worker,
         )
 
-        wrapper_kwargs = {
-            "vllm_config": self.vllm_config,
-        }
-
-        self.driver_worker = WorkerWrapperBase(**wrapper_kwargs)
+        self.driver_worker = WorkerWrapperBase()
 
         self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
         self.collective_rpc("init_device")
diff --git a/tests/model_executor/model_loader/tensorizer_loader/conftest.py b/tests/model_executor/model_loader/tensorizer_loader/conftest.py
index 826ecec71e6c..6c85a1399196 100644
--- a/tests/model_executor/model_loader/tensorizer_loader/conftest.py
+++ b/tests/model_executor/model_loader/tensorizer_loader/conftest.py
@@ -67,7 +67,7 @@ def assert_from_collective_rpc(engine: LLM, closure: Callable, closure_kwargs: d
 class DummyExecutor(UniProcExecutor):
     def _init_executor(self) -> None:
         """Initialize the worker and load the model."""
-        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
+        self.driver_worker = WorkerWrapperBase(rpc_rank=0)
         distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
         local_rank = 0
         # set local rank as the device index if specified
diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py
index ff0f0350fd94..192ac69efa26 100644
--- a/vllm/utils/import_utils.py
+++ b/vllm/utils/import_utils.py
@@ -23,17 +23,6 @@
 logger = init_logger(__name__)
 
 
-# TODO: This function can be removed if transformer_modules classes are
-# serialized by value when communicating between processes
-def init_cached_hf_modules() -> None:
-    """
-    Lazy initialization of the Hugging Face modules.
- """ - from transformers.dynamic_module_utils import init_hf_modules - - init_hf_modules() - - def import_pynvml(): """ Historical comments: diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 7b5c28eeb317..7b427b4a6cde 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -519,9 +519,7 @@ def __init__( shared_worker_lock: LockType, ): self.rank = rank - wrapper = WorkerWrapperBase( - vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank - ) + wrapper = WorkerWrapperBase(rpc_rank=local_rank, global_rank=rank) # TODO: move `init_worker` to executor level as a collective rpc call all_kwargs: list[dict] = [ {} for _ in range(vllm_config.parallel_config.world_size) diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 2fd64e5c2277..292fa877f5a4 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -208,9 +208,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( # type: ignore[attr-defined] - vllm_config=self.vllm_config, rpc_rank=rank - ) + )(RayWorkerWrapper).remote(rpc_rank=rank) else: worker = ray.remote( num_cpus=0, @@ -218,9 +216,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar resources={current_platform.ray_device_key: num_gpus}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( # type: ignore[attr-defined] - vllm_config=self.vllm_config, rpc_rank=rank - ) + )(RayWorkerWrapper).remote(rpc_rank=rank) + worker_metadata.append(RayWorkerMetaData(worker=worker, created_rank=rank)) worker_ips = ray.get( diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py index b8ca92255430..b9c7b550170b 100644 --- a/vllm/v1/executor/uniproc_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -26,7 +26,7 @@ class UniProcExecutor(Executor): def _init_executor(self) -> None: """Initialize the worker and load the model.""" - self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0) + self.driver_worker = WorkerWrapperBase(rpc_rank=0) distributed_init_method, rank, local_rank = self._distributed_args() kwargs = dict( vllm_config=self.vllm_config, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index fd4ee596c30e..30e7125c0f0e 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -86,12 +86,6 @@ def __init__( precision = envs.VLLM_FLOAT32_MATMUL_PRECISION torch.set_float32_matmul_precision(precision) - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils.import_utils import init_cached_hf_modules - - init_cached_hf_modules() - # Buffers saved before sleep self._sleep_saved_buffers: dict[str, torch.Tensor] = {} diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 3ece4c58214a..d37997eadf63 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -85,12 +85,6 @@ def __init__( else: self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype] - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils.import_utils import init_cached_hf_modules - - init_cached_hf_modules() - # Delay profiler initialization to the start 
         # This is because in vLLM V1, MP runtime is initialized before the
         # TPU Worker is initialized. The profiler server needs to start after
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index d06ae2fdf6fd..99b931dca716 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -178,7 +178,6 @@ class WorkerWrapperBase:
 
     def __init__(
         self,
-        vllm_config: VllmConfig,
         rpc_rank: int = 0,
         global_rank: int | None = None,
     ) -> None:
@@ -194,21 +193,10 @@ def __init__(
         """
         self.rpc_rank = rpc_rank
         self.global_rank = self.rpc_rank if global_rank is None else global_rank
-        self.worker: WorkerBase | None = None
-
-        # do not store this `vllm_config`, `init_worker` will set the final
-        # one.
-        # TODO: investigate if we can remove this field in `WorkerWrapperBase`,
-        # `init_cached_hf_modules` should be unnecessary now.
-        self.vllm_config: VllmConfig | None = None
-
-        # `model_config` can be None in tests
-        model_config = vllm_config.model_config
-        if model_config and model_config.trust_remote_code:
-            # note: lazy import to avoid importing torch before initializing
-            from vllm.utils.import_utils import init_cached_hf_modules
-
-            init_cached_hf_modules()
+
+        # Initialized after init_worker is called
+        self.worker: WorkerBase
+        self.vllm_config: VllmConfig
 
     def shutdown(self) -> None:
         if self.worker is not None:
@@ -241,27 +229,34 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
         Arguments are passed to the worker class constructor.
         """
         kwargs = all_kwargs[self.rpc_rank]
-        self.vllm_config = kwargs.get("vllm_config")
-        assert self.vllm_config is not None, (
+
+        vllm_config: VllmConfig | None = kwargs.get("vllm_config")
+        assert vllm_config is not None, (
             "vllm_config is required to initialize the worker"
         )
-        self.vllm_config.enable_trace_function_call_for_thread()
+        self.vllm_config = vllm_config
+
+        vllm_config.enable_trace_function_call_for_thread()
 
         from vllm.plugins import load_general_plugins
 
         load_general_plugins()
 
-        if isinstance(self.vllm_config.parallel_config.worker_cls, str):
-            worker_class = resolve_obj_by_qualname(
-                self.vllm_config.parallel_config.worker_cls
+        parallel_config = vllm_config.parallel_config
+        if isinstance(parallel_config.worker_cls, str):
+            worker_class: type[WorkerBase] = resolve_obj_by_qualname(
+                parallel_config.worker_cls
             )
         else:
             raise ValueError(
-                "passing worker_cls is no longer supported. Please pass keep the class in a separate module and pass the qualified name of the class as a string."  # noqa: E501
+                "passing worker_cls is no longer supported. "
+                "Please keep the class in a separate module "
+                "and pass the qualified name of the class as a string."
             )
-        if self.vllm_config.parallel_config.worker_extension_cls:
+
+        if parallel_config.worker_extension_cls:
             worker_extension_cls = resolve_obj_by_qualname(
-                self.vllm_config.parallel_config.worker_extension_cls
+                parallel_config.worker_extension_cls
             )
             extended_calls = []
             if worker_extension_cls not in worker_class.__bases__:
@@ -294,7 +289,7 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
                     "This argument is needed for mm_processor_cache_type='shm'."
                 )
-                mm_config = self.vllm_config.model_config.multimodal_config
+                mm_config = vllm_config.model_config.multimodal_config
                 if mm_config and mm_config.mm_processor_cache_type == "shm":
                     raise ValueError(msg)
             else:
@@ -303,7 +298,7 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
                 self.mm_receiver_cache = None
         else:
             self.mm_receiver_cache = worker_receiver_cache_from_config(
-                self.vllm_config,
+                vllm_config,
                 MULTIMODAL_REGISTRY,
                 shared_worker_lock,
             )
@@ -311,7 +306,6 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
         with set_current_vllm_config(self.vllm_config):
             # To make vLLM config available during worker initialization
             self.worker = worker_class(**kwargs)
-        assert self.worker is not None
 
     def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
         kv_cache_config = kv_cache_configs[self.global_rank]
@@ -358,20 +352,15 @@ def _apply_mm_cache(self, scheduler_output: SchedulerOutput) -> None:
         )
 
     def execute_model(
-        self,
-        scheduler_output: SchedulerOutput,
-        *args,
-        **kwargs,
+        self, scheduler_output: SchedulerOutput
     ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
         self._apply_mm_cache(scheduler_output)
-        assert self.worker is not None
-        return self.worker.execute_model(scheduler_output, *args, **kwargs)
+        return self.worker.execute_model(scheduler_output)
 
     def reset_mm_cache(self) -> None:
         mm_receiver_cache = self.mm_receiver_cache
         if mm_receiver_cache is not None:
             mm_receiver_cache.clear_cache()
-        assert self.worker is not None
         self.worker.reset_mm_cache()
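
Reviewer note: the net effect of this patch is that WorkerWrapperBase no longer receives vllm_config at construction time; the config now arrives exclusively through init_worker, which is what makes init_cached_hf_modules dead code (removed above). A minimal sketch of the new two-step lifecycle follows; make_vllm_config() is a placeholder, not a real vLLM helper, and the remaining worker kwargs are elided:

    # Sketch only: the initialization contract implied by this patch,
    # following the UniProcExecutor path above.
    from vllm.v1.worker.worker_base import WorkerWrapperBase

    vllm_config = make_vllm_config()  # placeholder for however the config is built

    # Step 1: construct the wrapper with ranks only; vllm_config is no
    # longer accepted (or stored) here.
    wrapper = WorkerWrapperBase(rpc_rank=0)

    # Step 2: vllm_config must now be part of the init_worker kwargs, where
    # it is asserted non-None and assigned to wrapper.vllm_config, and the
    # remaining kwargs are forwarded to the Worker constructor.
    wrapper.init_worker(
        [
            {
                "vllm_config": vllm_config,
                # ... remaining Worker constructor kwargs (local_rank, rank,
                # distributed_init_method, is_driver_worker, ...) elided
            }
        ]
    )

One sharp edge worth flagging: self.worker is now only annotated in __init__, not assigned, so the unchanged `if self.worker is not None` check in shutdown() raises AttributeError on a wrapper whose init_worker was never called, where it previously took the None branch.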