Your current environment
Possibly relevant packages from pip list
Package Version Editable project location
---------------------------------------- ------------- -------------------------
accelerate 1.6.0
huggingface-hub 0.30.2
mistral_common 1.5.4
nest_asyncio 1.6.0
numpy 1.26.4
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-cufile-cu12 1.11.1.6
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-cusparselt-cu12 0.6.2
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
openai 1.55.2
pillow 11.2.1
pip 25.0.1
protobuf 3.20.3
pyarrow 19.0.1
pycountry 24.6.1
pycparser 2.22
pydantic 2.11.3
pydantic_core 2.33.1
ray 2.43.0
regex 2024.11.6
ruff 0.2.2
safetensors 0.5.3
stanza 1.10.1
starlette 0.46.2
tokenizers 0.21.1
torch 2.6.0
torchaudio 2.6.0
torchvision 0.21.0
tqdm 4.67.1
transformers 4.51.3
triton 3.2.0
urllib3 2.4.0
uvicorn 0.34.2
vllm 0.8.4
Possibly relevant env vars
"VLLM_LOGGING_LEVEL":"DEBUG",
"VLLM_TRACE_FUNCTION": "1",
"CUBLAS_WORKSPACE_CONFIG": ":4096:8",
"VLLM_USE_V1":"1",
"VLLM_HOST_IP":"localhost",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
🐛 Describe the bug
Expected behavior: launch an AsyncLLM model, then use it to run generation on X successive batches without issue (here, a batch is a group of inputs submitted together in one asyncio.run call).
Actual behavior: the AsyncLLM model processes the first batch without issue, then hangs on the second batch.
Notes:
- MWE was run on a node with 2 H100 GPUs.
- As you can see, we're using DP=2, TP=1, PP=1.
import asyncio
from vllm.v1.engine.async_llm import AsyncLLM
from vllm import AsyncEngineArgs, RequestOutput, SamplingParams
async def _async_one_item(
    model,
    params,
    index: str,
):
    output = await anext(
        model.generate(request_id=index, prompt=f"This is a story about {index}", sampling_params=params)
    )
    return output

async def _async_batch(model, params, name) -> list:
    processed_requests = [
        _async_one_item(model, params, index=f"{name}_{index}")
        for index in range(50)
    ]
    results = await asyncio.gather(*processed_requests)
    return results

# This version is also failing
async def _async_batch_v2(model, params, name) -> list:
    if await model.engine_core.is_sleeping_async():
        await model.engine_core.wake_up_async()
    processed_requests = [
        _async_one_item(model, params, index=f"{name}_{index}")
        for index in range(50)
    ]
    results = await asyncio.gather(*processed_requests)
    await model.engine_core.sleep_async()
    return results

def main():
    # Target configuration:
    # "model_name=,is_async=True,data_parallel_size=2,tensor_parallel_size=1,dtype=bfloat16,max_model_length=32768,max_num_batched_tokens=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
    model_args = {
        "model": "Qwen/Qwen2.5-7B",
        "gpu_memory_utilization": 0.8,
        "revision": "main",
        "dtype": "bfloat16",
        "trust_remote_code": True,
        "tensor_parallel_size": 1,
        "data_parallel_size": 2,
        "pipeline_parallel_size": 1,
        "swap_space": 4,
        "seed": 0,
        "enforce_eager": True,
    }
    model = AsyncLLM.from_engine_args(AsyncEngineArgs(**model_args))

    sampling_params = SamplingParams()
    sampling_params.n = 1
    sampling_params.max_tokens = 10
    sampling_params.logprobs = 1

    # This works just fine
    responses_1: list[RequestOutput] = asyncio.run(_async_batch(model=model, params=sampling_params, name="exp1"))
    for response in responses_1:
        print(response)

    # This hangs after adding all the requests
    # Last log lines of debug mode:
    # INFO 04-29 12:28:56 [async_llm.py:228] Added request exp2_49.
    # (EngineCore_1 pid=1452003) (EngineCore_0 pid=1452002)
    # DEBUG 04-29 12:28:59 [core.py:411] EngineCore waiting for work.
    # DEBUG 04-29 12:28:59 [core.py:411] EngineCore waiting for work.
    responses_2: list[RequestOutput] = asyncio.run(_async_batch(model=model, params=sampling_params, name="exp2"))
    # We never reach this
    for response in responses_2:
        print(response)


if __name__ == "__main__":
    main()
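For reference, an untested variant (an assumption on our side, not something we have verified) that drives both batches from a single asyncio.run call instead of one call per batch, included only to make the event-loop usage of the failing script above explicit:

# Untested variant: both batches awaited inside one event loop,
# instead of calling asyncio.run once per batch as main() does above.
async def _run_both_batches(model, params):
    responses_1 = await _async_batch(model, params, name="exp1")
    responses_2 = await _async_batch(model, params, name="exp2")
    return responses_1, responses_2

# Inside main(), after building model and sampling_params:
#     responses_1, responses_2 = asyncio.run(_run_both_batches(model, sampling_params))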