Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 112 additions & 17 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,93 @@ def generate(

return self.engine_class.validate_outputs(outputs, RequestOutput)

def enqueue(
    self,
    prompts: PromptType | Sequence[PromptType],
    sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
    lora_request: list[LoRARequest] | LoRARequest | None = None,
    priority: list[int] | None = None,
    use_tqdm: bool | Callable[..., tqdm] = True,
    tokenization_kwargs: dict[str, Any] | None = None,
) -> list[str]:
    """Queue prompts for generation without running the engine.

    Requests are validated, preprocessed, and added to the engine's
    queue, but no decoding happens here. Call ``wait_for_completion()``
    afterwards to drive the engine and collect the results.

    Args:
        prompts: The prompts to the LLM; see ``generate()`` for details.
        sampling_params: Sampling parameters for text generation. When
            ``None``, the model's default sampling parameters are used.
        lora_request: LoRA request(s) to use for generation, if any.
        priority: Per-request priorities, if any.
        use_tqdm: If True (or a tqdm factory), show a progress bar
            while adding requests.
        tokenization_kwargs: Overrides passed to ``tokenizer.encode``.

    Returns:
        The request IDs assigned to the enqueued requests.
    """
    if self.model_config.runner_type != "generate":
        raise ValueError("LLM.enqueue() is only supported for generative models.")

    params = sampling_params
    if params is None:
        params = self.get_default_sampling_params()

    # Mirror the preprocessing performed by _run_completion.
    seq_prompts = prompt_to_seq(prompts)
    seq_params = self._params_to_seq(params, len(seq_prompts))

    needs_truncation = any(
        param.truncate_prompt_tokens is not None for param in seq_params
    )
    if needs_truncation:
        # Per-request truncation limits require preprocessing one
        # prompt at a time with its own tokenization overrides.
        engine_prompts: Sequence[DictPrompt | TokPrompt] = [
            engine_prompt
            for prompt, param in zip(seq_prompts, seq_params)
            for engine_prompt in self._preprocess_completion(
                [prompt],
                tokenization_kwargs=merge_kwargs(
                    tokenization_kwargs,
                    dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
                ),
            )
        ]
    else:
        engine_prompts = self._preprocess_completion(
            seq_prompts,
            tokenization_kwargs=tokenization_kwargs,
        )

    lora_reqs = self._get_modality_specific_lora_reqs(engine_prompts, lora_request)
    return self._validate_and_add_requests(
        prompts=engine_prompts,
        params=seq_params,
        use_tqdm=use_tqdm,
        lora_request=lora_reqs,
        tokenization_kwargs=tokenization_kwargs,
        priority=priority,
    )

def wait_for_completion(
    self,
    use_tqdm: bool | Callable[..., tqdm] = True,
) -> list[RequestOutput]:
    """Drive the engine until every queued request finishes.

    Intended to be used after ``enqueue()``: processes all requests
    currently in the engine queue and returns their outputs.

    Args:
        use_tqdm: If True (or a tqdm factory), show a progress bar.

    Returns:
        A ``RequestOutput`` for each completed request.
    """
    results = self._run_engine(use_tqdm=use_tqdm)
    return self.engine_class.validate_outputs(results, RequestOutput)

def _get_modality_specific_lora_reqs(
self,
prompts: Sequence[DictPrompt | TokPrompt],
Expand Down Expand Up @@ -1618,19 +1705,22 @@ def sleep(self, level: int = 1):
during the sleep period, before `wake_up` is called.

Args:
level: The sleep level. Level 1 sleep will offload the model
weights and discard the kv cache. The content of kv cache
is forgotten. Level 1 sleep is good for sleeping and waking
up the engine to run the same model again. The model weights
are backed up in CPU memory. Please make sure there's enough
CPU memory to store the model weights. Level 2 sleep will
discard both the model weights and the kv cache. The content
of both the model weights and kv cache is forgotten. Level 2
sleep is good for sleeping and waking up the engine to run a
different model or update the model, where previous model
weights are not needed. It reduces CPU memory pressure.
level: The sleep level.
- Level 0: Pause scheduling but continue accepting requests.
Requests are queued but not processed.
- Level 1: Offload model weights to CPU, discard KV cache.
The content of kv cache is forgotten. Good for
sleeping and waking up the engine to run the same
model again. Please make sure there's enough CPU
memory to store the model weights.
- Level 2: Discard all GPU memory (weights + KV cache).
Good for sleeping and waking up the engine to run
a different model or update the model, where
previous model weights are not needed. It reduces
CPU memory pressure.
"""
self.reset_prefix_cache()
if level > 0:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the behavior of level 0?

Will this cause any breakage if user use level 0 before?

self.reset_prefix_cache()
self.llm_engine.sleep(level=level)

def wake_up(self, tags: list[str] | None = None):
    """Wake the engine up from sleep.

    Args:
        tags: An optional list of tags to reallocate the engine memory
            for specific memory allocations. Values must be in
            ``("weights", "kv_cache", "scheduling")``. If None, all
            memory is reallocated. ``wake_up`` should be called with
            all tags (or None) before the engine is used again.
            Use ``tags=["scheduling"]`` to resume from level 0 sleep.
    """
    self.llm_engine.wake_up(tags)

Expand Down Expand Up @@ -1810,7 +1901,7 @@ def _validate_and_add_requests(
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
tokenization_kwargs: dict[str, Any] | None = None,
priority: list[int] | None = None,
) -> None:
) -> list[str]:
num_requests = len(prompts)
seq_params = self._params_to_seq(params, num_requests)
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
Expand Down Expand Up @@ -1844,6 +1935,8 @@ def _validate_and_add_requests(
self.llm_engine.abort_request(added_request_ids, internal=True)
raise e

return added_request_ids

def _add_request(
self,
prompt: PromptType | DictPrompt | TokPrompt,
Expand Down Expand Up @@ -1895,7 +1988,9 @@ def _add_request(
return engine_request.request_id

def _run_engine(
self, *, use_tqdm: bool | Callable[..., tqdm] = True
self,
*,
use_tqdm: bool | Callable[..., tqdm] = True,
) -> list[RequestOutput | PoolingRequestOutput]:
# Initialize tqdm.
if use_tqdm:
Expand Down
3 changes: 2 additions & 1 deletion vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,8 @@ async def reset_encoder_cache(self) -> None:
await self.engine_core.reset_encoder_cache_async()

async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache()
if level > 0:
await self.reset_prefix_cache()
await self.engine_core.sleep_async(level)

if self.logger_manager is not None:
Expand Down
52 changes: 46 additions & 6 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,13 +614,43 @@ def reset_encoder_cache(self) -> None:
self.model_executor.reset_encoder_cache()

def sleep(self, level: int = 1):
    """Put the engine to sleep at the specified level.

    Args:
        level: Sleep level.
            - Level 0: pause scheduling only; requests are still
              accepted but not processed. GPU memory is untouched.
            - Level 1: offload model weights to CPU, discard KV cache.
            - Level 2: discard all GPU memory.
    """
    if level != 0:
        # Levels 1+ delegate GPU memory management to the executor.
        self.model_executor.sleep(level)
    else:
        # Level 0 only pauses scheduling; the GPU is left alone.
        self.pause_scheduler()

def wake_up(self, tags: list[str] | None = None):
    """Wake up the engine from sleep.

    Args:
        tags: Tags to wake up. Use ["scheduling"] for level 0 wake up.
    """
    if tags is None or "scheduling" not in tags:
        # Full wake up: resume scheduling and restore executor memory
        # (for all allocations, or only the requested tags).
        self.resume_scheduler()
        self.model_executor.wake_up(tags)
    else:
        # Level 0 wake up: resume scheduling, then forward any
        # remaining non-scheduler tags to the executor.
        self.resume_scheduler()
        rest = [tag for tag in tags if tag != "scheduling"]
        if rest:
            self.model_executor.wake_up(rest)

def is_sleeping(self) -> bool:
    """Report whether the engine is sleeping at any level."""
    # A level 0 sleep only pauses the scheduler, so that flag must be
    # checked in addition to the executor's own sleep state.
    if self._scheduler_paused:
        return True
    return self.model_executor.is_sleeping

def execute_dummy_batch(self):
self.model_executor.execute_dummy_batch()
Expand Down Expand Up @@ -1023,15 +1053,21 @@ def run_busy_loop(self):
# 1) Poll the input queue until there is work to do.
self._process_input_queue()
# 2) Step the engine core and return the outputs.
self._process_engine_step()
# Skip if scheduling is paused (level 0 sleep)
if not self._scheduler_paused:
self._process_engine_step()
else:
# When scheduling is paused, still need to check for wake up
# by processing any utility requests that might resume scheduling
pass

def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""

waited = False
while (
not self.engines_running
and not self.scheduler.has_requests()
and (not self.scheduler.has_requests() or self._scheduler_paused)
and not self.batch_queue
and not self._scheduler_paused
):
Expand Down Expand Up @@ -1414,11 +1450,15 @@ def run_busy_loop(self):
# 1) Poll the input queue until there is work to do.
self._process_input_queue()

# Skip processing if scheduling is paused (level 0 sleep)
if self._scheduler_paused:
continue

# 2) Step the engine core.
executed = self._process_engine_step()
self._maybe_publish_request_counts()

local_unfinished_reqs = self.scheduler.has_unfinished_requests()

Comment thread
jaewonlee-fb marked this conversation as resolved.
if not executed:
if not local_unfinished_reqs and not self.engines_running:
# All engines are idle.
Expand Down
3 changes: 2 additions & 1 deletion vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def collective_rpc(
raise NotImplementedError

def dp_engines_running(self) -> bool:
    """Return whether the data parallel engines are collectively in a
    running state.

    Concrete client implementations are expected to override this.
    """
    raise NotImplementedError

Expand Down Expand Up @@ -724,6 +724,7 @@ def get_output(self) -> EngineCoreOutputs:
# it is forwarded to the outputs_queue so we can raise it
# from this (run_output_handler) task to shut down the server.
outputs = self.outputs_queue.get()

if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None
if outputs.wave_complete is not None:
Expand Down
6 changes: 5 additions & 1 deletion vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,11 @@ def step(self) -> list[RequestOutput | PoolingRequestOutput]:

# 4) Record stats
with record_function_or_nullcontext("llm_engine step: record_stats"):
if self.logger_manager is not None and outputs.scheduler_stats is not None:
if (
self.logger_manager is not None
and outputs.scheduler_stats is not None
and len(outputs.outputs) > 0
):
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
Expand Down