[Core] Add sleep level 0 mode with enqueue/wait pattern #33195
Merged (+167 −26)

Changes from all commits (5 commits):
233b06f  [Core] Add sleep level 0 mode with enqueue/wait pattern (jaewonlee-fb)
7b0d059  Retrigger CI (jaewonlee-fb)
a333619  Fix: Skip reset_prefix_cache for level 0 sleep in async path (jaewonlee-fb)
b5fcb4c  Retrigger CI (jaewonlee-fb)
cff329a  Merge branch 'main' into sleep-level-0 (houseroad)
@@ -458,6 +458,93 @@ def generate(
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use wait_for_completion() to process the queued
        requests and get results.

        Args:
            prompts: The prompts to the LLM. See generate() for details.
            sampling_params: The sampling parameters for text generation.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
            use_tqdm: If True, shows a tqdm progress bar while adding requests.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of request IDs for the enqueued requests.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        # Use the same preprocessing as _run_completion
        seq_prompts = prompt_to_seq(prompts)
        seq_params = self._params_to_seq(sampling_params, len(seq_prompts))

        if any(param.truncate_prompt_tokens is not None for param in seq_params):
            engine_prompts: Sequence[DictPrompt | TokPrompt] = [
                engine_prompt
                for prompt, param in zip(seq_prompts, seq_params)
                for engine_prompt in self._preprocess_completion(
                    [prompt],
                    tokenization_kwargs=merge_kwargs(
                        tokenization_kwargs,
                        dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
                    ),
                )
            ]
        else:
            engine_prompts = self._preprocess_completion(
                seq_prompts,
                tokenization_kwargs=tokenization_kwargs,
            )

        request_ids = self._validate_and_add_requests(
            prompts=engine_prompts,
            params=seq_params,
            use_tqdm=use_tqdm,
            lora_request=self._get_modality_specific_lora_reqs(
                engine_prompts, lora_request
            ),
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        return request_ids

    def wait_for_completion(
        self,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput]:
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use after enqueue() to get results.

        Args:
            use_tqdm: If True, shows a tqdm progress bar.

        Returns:
            A list of RequestOutput objects for all completed requests.
        """
        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def _get_modality_specific_lora_reqs(
        self,
        prompts: Sequence[DictPrompt | TokPrompt],
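To make the intended call pattern concrete, here is a minimal usage sketch of the new methods (not part of this diff; the model name, sampling parameters, and the `enable_sleep_mode=True` flag are illustrative assumptions):

```python
from vllm import LLM, SamplingParams

# Illustrative setup; enable_sleep_mode is assumed here only because this PR
# targets the sleep-mode workflow, enqueue/wait do not strictly require it.
llm = LLM(model="facebook/opt-125m", enable_sleep_mode=True)
params = SamplingParams(temperature=0.8, max_tokens=32)

# Queue requests without driving the engine yet; only request IDs come back.
request_ids = llm.enqueue(
    ["Hello, my name is", "The capital of France is"],
    sampling_params=params,
)
print(f"queued {len(request_ids)} requests")

# Later, run the engine until every queued request finishes.
outputs = llm.wait_for_completion()
for out in outputs:
    print(out.request_id, out.outputs[0].text)
```

Per the docstrings above, enqueue() only adds the requests to the engine queue; nothing is processed until wait_for_completion() runs the engine loop.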
@@ -1618,19 +1705,22 @@ def sleep(self, level: int = 1):
        during the sleep period, before `wake_up` is called.

        Args:
-            level: The sleep level. Level 1 sleep will offload the model
-                weights and discard the kv cache. The content of kv cache
-                is forgotten. Level 1 sleep is good for sleeping and waking
-                up the engine to run the same model again. The model weights
-                are backed up in CPU memory. Please make sure there's enough
-                CPU memory to store the model weights. Level 2 sleep will
-                discard both the model weights and the kv cache. The content
-                of both the model weights and kv cache is forgotten. Level 2
-                sleep is good for sleeping and waking up the engine to run a
-                different model or update the model, where previous model
-                weights are not needed. It reduces CPU memory pressure.
+            level: The sleep level.
+                - Level 0: Pause scheduling but continue accepting requests.
+                  Requests are queued but not processed.
+                - Level 1: Offload model weights to CPU, discard KV cache.
+                  The content of kv cache is forgotten. Good for
+                  sleeping and waking up the engine to run the same
+                  model again. Please make sure there's enough CPU
+                  memory to store the model weights.
+                - Level 2: Discard all GPU memory (weights + KV cache).
+                  Good for sleeping and waking up the engine to run
+                  a different model or update the model, where
+                  previous model weights are not needed. It reduces
+                  CPU memory pressure.
        """
-        self.reset_prefix_cache()
+        if level > 0:
+            self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self, tags: list[str] | None = None):

Collaborator comment on the `if level > 0:` line: what's the behavior of level 0? Will this cause any breakage if users used level 0 before?
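Combining the new level with the enqueue/wait methods above, a hedged sketch of a level-0 sleep cycle (inferred from the docstrings in this diff; it reuses the `llm` and `params` objects from the earlier sketch):

```python
# Level 0 pauses scheduling but keeps accepting requests; reset_prefix_cache()
# is skipped on the way down (see the `if level > 0:` guard above).
llm.sleep(level=0)

# These requests are accepted and queued, but not processed while asleep.
ids = llm.enqueue(["Summarize this PR in one line."], sampling_params=params)

# Resume scheduling only, then drain the queue.
llm.wake_up(tags=["scheduling"])
outputs = llm.wait_for_completion()
```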
@@ -1641,9 +1731,10 @@ def wake_up(self, tags: list[str] | None = None):
        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
-                `("weights", "kv_cache")`. If None, all memory is reallocated.
-                wake_up should be called with all tags (or None) before the
-                engine is used again.
+                `("weights", "kv_cache", "scheduling")`. If None, all memory
+                is reallocated. wake_up should be called with all tags
+                (or None) before the engine is used again.
+                Use tags=["scheduling"] to resume from level 0 sleep.
        """
        self.llm_engine.wake_up(tags)
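For comparison, the pre-existing tag-based wake-up for the deeper sleep levels is unchanged; a short sketch, again assuming the same `llm` instance:

```python
# Level 1: offload weights to CPU and discard the KV cache.
llm.sleep(level=1)
# ... the freed GPU memory can be used elsewhere here ...
llm.wake_up(tags=["weights"])   # restore the weights first
llm.wake_up(tags=["kv_cache"])  # then reallocate the KV cache
# llm.wake_up() with no tags restores everything in one call.
```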
@@ -1810,7 +1901,7 @@ def _validate_and_add_requests(
        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: list[int] | None = None,
-    ) -> None:
+    ) -> list[str]:
        num_requests = len(prompts)
        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
@@ -1844,6 +1935,8 @@ def _validate_and_add_requests(
                self.llm_engine.abort_request(added_request_ids, internal=True)
                raise e

+        return added_request_ids
+
    def _add_request(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
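Since `_validate_and_add_requests` now returns the added request IDs (which `enqueue()` passes through to the caller), callers can correlate those IDs with the finished outputs; a small assumed-usage sketch, not taken from the diff:

```python
prompts = ["First prompt", "Second prompt"]
ids = llm.enqueue(prompts, sampling_params=params)

# wait_for_completion() returns RequestOutput objects; match them up by ID.
by_id = {out.request_id: out for out in llm.wait_for_completion()}
for prompt, rid in zip(prompts, ids):
    print(prompt, "->", by_id[rid].outputs[0].text)
```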
@@ -1895,7 +1988,9 @@ def _add_request(
        return engine_request.request_id

    def _run_engine(
-        self, *, use_tqdm: bool | Callable[..., tqdm] = True
+        self,
+        *,
+        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
Review comment: where do we expect to call this function?