
[Bugfix] Fix torch.compile() error when using MultiprocessingGPUExecutor#5229

Merged
simon-mo merged 1 commit into vllm-project:main from zifeitong:cohere
Jun 4, 2024

Conversation

@zifeitong
Contributor

torch.compile() by default uses a process pool for asynchronous compilation, which doesn't work with MultiprocessingGPUExecutor. Models that use torch.compile(), such as commandr, therefore fail:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/var/data/home/vllm-upstream/benchmarks/benchmark_latency.py", line 226, in <module>
[rank0]:     main(args)
[rank0]:   File "/var/data/home/vllm-upstream/benchmarks/benchmark_latency.py", line 22, in main
[rank0]:     llm = LLM(model=args.model,
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/var/data/home/vllm-upstream/vllm/entrypoints/llm.py", line 143, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/var/data/home/vllm-upstream/vllm/engine/llm_engine.py", line 359, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/var/data/home/vllm-upstream/vllm/engine/llm_engine.py", line 222, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:                           ^^^^^^^^^^^^^^^
[rank0]:   File "/var/data/home/vllm-upstream/vllm/executor/distributed_gpu_executor.py", line 25, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/var/data/home/vllm-upstream/vllm/executor/executor_base.py", line 41, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/var/data/home/vllm-upstream/vllm/executor/multiproc_gpu_executor.py", line 66, in _init_executor
[rank0]:     self._run_workers("load_model",
[rank0]:   File "/var/data/home/vllm-upstream/vllm/executor/multiproc_gpu_executor.py", line 123, in _run_workers
[rank0]:     ] + [output.get() for output in worker_outputs]
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/var/data/home/vllm-upstream/vllm/executor/multiproc_gpu_executor.py", line 123, in <listcomp>
[rank0]:     ] + [output.get() for output in worker_outputs]
[rank0]:          ^^^^^^^^^^^^
[rank0]:   File "/var/data/home/vllm-upstream/vllm/executor/multiproc_worker_utils.py", line 58, in get
[rank0]:     raise self.result.exception
[rank0]: AssertionError: daemonic processes are not allowed to have children
ERROR 06-03 20:34:21 multiproc_worker_utils.py:119] Worker VllmWorkerProcess pid 114600 died, exit code: -15

Setting the TORCHINDUCTOR_COMPILE_THREADS environment variable to 1 disables the async compiling process pool.
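As a minimal sketch (not necessarily the exact patch in this PR), forcing Inductor to compile in-process could look like the following; the variable has to be set before Inductor reads its config, i.e. before the first torch.compile() call:

```python
import os

# Inductor spawns a pool of compile worker processes by default.
# MultiprocessingGPUExecutor runs vLLM workers as daemonic processes,
# and daemonic processes may not spawn children, hence the
# "daemonic processes are not allowed to have children" AssertionError
# in the traceback above. With a single compile thread, Inductor
# compiles in-process and never forks a worker pool.
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
```

Placing this in the worker-process setup path (before the model is loaded and compiled) would keep the main process's compile behavior unchanged while avoiding the daemonic-child restriction inside workers.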

@zifeitong force-pushed the cohere branch 2 times, most recently from 7a6d29c to 0c36c6d on June 3, 2024 20:47
@youkaichao (Member) left a comment

LGTM, thanks!

@simon-mo merged commit a58f24e into vllm-project:main on Jun 4, 2024
chengzhi-lu pushed a commit to chengzhi-lu/vllm-infersche that referenced this pull request Jun 6, 2024
robertgshaw2-redhat pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jun 11, 2024
joerunde pushed a commit to joerunde/vllm that referenced this pull request Jun 17, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 27, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
3 participants