diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index b700f0631392..e715a1d767fa 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -597,17 +597,6 @@ def __init__(
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
 
-        scheduler_config = vllm_config.scheduler_config
-        self.use_async_scheduling = scheduler_config.async_scheduling
-        if self.use_async_scheduling:
-            self.async_output_queue: queue.Queue = queue.Queue()
-            self.async_output_copy_thread = Thread(
-                target=self.async_output_busy_loop,
-                daemon=True,
-                name="WorkerAsyncOutputCopy",
-            )
-            self.async_output_copy_thread.start()
-
         self.setup_proc_title_and_log_prefix(
             enable_ep=vllm_config.parallel_config.enable_expert_parallel
         )
@@ -622,6 +611,17 @@ def __init__(
         )
         self.worker.load_model()
 
+        scheduler_config = vllm_config.scheduler_config
+        self.use_async_scheduling = scheduler_config.async_scheduling
+        if self.use_async_scheduling:
+            self.async_output_queue: queue.Queue = queue.Queue()
+            self.async_output_copy_thread = Thread(
+                target=self.async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy",
+            )
+            self.async_output_copy_thread.start()
+
         # Set block size based on the attention backends
         current_platform.update_block_size_for_backend(vllm_config)
 
@@ -911,6 +911,18 @@ def handle_output(self, output: Any):
 
     def async_output_busy_loop(self):
         """Entrypoint for the thread which handles outputs asynchronously."""
+
+        # Set the device for this thread to the worker's device.
+        # A new thread does not inherit the CUDA context of the main
+        # thread; calling any CUDA runtime function would implicitly
+        # create a new context on device 0, consuming extra memory.
+        # Setting the device here enforces that this thread uses the
+        # same context as the main thread.
+        from vllm.platforms import current_platform
+
+        if hasattr(self.worker, "device"):
+            current_platform.set_device(self.worker.device)
+
         while True:
             output = self.async_output_queue.get()
             self.enqueue_output(output)