Support prompt hidden states #24202
Closed
First file in the diff (new file, 56 lines): a `HiddenStatesProcessor` that buffers a request's prompt hidden states until the end of prefill. Cleaned up here with a few fixes to the diff as posted: the stray trailing comma in `from_new_request`, the return annotation on `pop_prompt_hidden_states` (which was copy-pasted as `Optional[PromptLogprobs]` and left an unused `vllm.sequence` import), and the truthiness check on a tensor, which raises for multi-element tensors and is replaced with an explicit `is not None` test.

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.logger import init_logger
from vllm.v1.engine import EngineCoreOutput

logger = init_logger(__name__)

NONES = itertools.repeat(None)


@dataclass
class HiddenStatesProcessor:
    prompt_hidden_states: Optional[torch.Tensor]

    @classmethod
    def from_new_request(cls) -> "HiddenStatesProcessor":
        return cls(prompt_hidden_states=None)

    def _set_prompt_hidden_states(
        self,
        prompt_hidden_states_tensor: torch.Tensor,
    ) -> None:
        # We only need to set the prompt hidden states once.
        assert self.prompt_hidden_states is None

        self.prompt_hidden_states = prompt_hidden_states_tensor

    def pop_prompt_hidden_states(self) -> Optional[torch.Tensor]:
        """Pop and return all request prompt hidden states.

        The hidden states processor aggregates prompt chunk hidden states
        over one or more prefill chunks. This method returns all prompt
        hidden states at once and then forgets them, ensuring correct
        RequestOutputKind.DELTA semantics wherein all prompt hidden states
        are returned at once at the end of prefill.

        Returns:
            None if prompt hidden states are disabled for this request.
            Tensor of all prompt hidden states, otherwise.
        """
        hidden_states = self.prompt_hidden_states
        if hidden_states is not None:
            self.prompt_hidden_states = None
        return hidden_states

    def update_from_output(self, output: EngineCoreOutput) -> None:
        if output.prompt_hidden_states is not None:
            self._set_prompt_hidden_states(output.prompt_hidden_states)
```
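For illustration, the pop-once semantics above can be sketched with a standalone stand-in (names here are hypothetical, and a plain list replaces `torch.Tensor` so the sketch has no dependencies on the vLLM types):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class HiddenStatesSketch:
    # Simplified stand-in for HiddenStatesProcessor above.
    prompt_hidden_states: Optional[list] = None

    def set_once(self, states: list) -> None:
        # Prompt hidden states are set exactly once, at the end of prefill.
        assert self.prompt_hidden_states is None
        self.prompt_hidden_states = states

    def pop(self) -> Optional[list]:
        # DELTA semantics: return everything once, then forget.
        states = self.prompt_hidden_states
        if states is not None:
            self.prompt_hidden_states = None
        return states


proc = HiddenStatesSketch()
proc.set_once([[0.1, 0.2], [0.3, 0.4]])
first = proc.pop()   # the full prompt hidden states
second = proc.pop()  # None: already returned
```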
Second file in the diff (an existing file, four hunks; `+` lines are new):
```diff
@@ -1448,6 +1448,7 @@
     sampled_token_ids=[],
     logprobs=None,
     prompt_logprobs_dict={},
+    prompt_hidden_states_dict={},
     pooler_output=pooler_output,
     kv_connector_output=kv_connector_output,
 )
```
```diff
@@ -1683,6 +1684,10 @@
     hidden_states[:num_scheduled_tokens],
     scheduler_output.num_scheduled_tokens,
 )
+prompt_hidden_states_dict = self._get_prompt_hidden_states_dict(
+    hidden_states[:num_scheduled_tokens],
+    scheduler_output.num_scheduled_tokens,
+)

 # Get the valid generated tokens.
 sampled_token_ids = sampler_output.sampled_token_ids
```
```diff
@@ -1746,6 +1751,7 @@
     sampled_token_ids=valid_sampled_token_ids,
     logprobs=logprobs_lists,
     prompt_logprobs_dict=prompt_logprobs_dict,
+    prompt_hidden_states_dict=prompt_hidden_states_dict,
    pooler_output=[],
     kv_connector_output=kv_connector_output,
     num_nans_in_logits=num_nans_in_logits,
```
The fourth hunk (`@@ -2123,6 +2129,90 @@`) adds a new method between `_get_prompt_logprobs_dict` and `_get_nans_in_logits`. A reviewer left a comment on its signature line:

> **Collaborator:** try extracting this to a separate `utils.py`?

Shown here with a few fixes to the diff as posted: the buffer was allocated with `dtype=torch.int32` (a copy-paste slip from the token-id path; hidden states need the model's floating-point dtype, so the input tensor's dtype is used instead), the truthiness check on a tensor is replaced with `is None`, and several comments still said "logprobs" where they meant hidden states.

```python
def _get_prompt_hidden_states_dict(
    self,
    hidden_states: torch.Tensor,
    num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[torch.Tensor]]:
    """Analogous to _get_prompt_logprobs_dict, but for prompt hidden
    states."""
    return_prompt_hidden_states_reqs = (
        self.input_batch.return_prompt_hidden_states_reqs)
    if not return_prompt_hidden_states_reqs:
        return {}

    in_progress_dict = self.input_batch.in_progress_prompt_hidden_states_cpu
    prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]] = {}

    # Since prompt hidden states are a rare feature, prioritize a simple,
    # maintainable loop over optimal performance.
    completed_prefill_reqs = []
    for req_id in return_prompt_hidden_states_reqs:
        num_tokens = num_scheduled_tokens[req_id]

        # Get metadata for this request.
        request = self.requests[req_id]
        num_prompt_tokens = len(request.prompt_token_ids)

        # Set up the target hidden states tensor.
        hidden_states_tensor = in_progress_dict.get(req_id)
        if hidden_states_tensor is None:
            # Create an empty CPU tensor for the entire prompt.
            # If chunked, we'll copy in slice by slice.
            hidden_states_tensor = torch.empty(
                (num_prompt_tokens - 1, self.hidden_size),
                dtype=hidden_states.dtype,
                device="cpu")
            in_progress_dict[req_id] = hidden_states_tensor

        # Determine the number of hidden states to retrieve.
        start_idx = request.num_computed_tokens
        start_tok = start_idx + 1
        num_remaining_tokens = num_prompt_tokens - start_tok
        if num_tokens <= num_remaining_tokens:
            # This is a chunk; more tokens remain. In the == case there
            # are no more prompt hidden states to produce, but we defer
            # returning them to the next step, where we also have new
            # generated tokens to return.
            num_logits = num_tokens
        else:
            # This is the last chunk of prompt tokens to return.
            num_logits = num_remaining_tokens
            completed_prefill_reqs.append(req_id)
            prompt_hidden_states_dict[req_id] = hidden_states_tensor

        if num_logits <= 0:
            # This can happen for the final chunk if we prefilled exactly
            # (num_prompt_tokens - 1) tokens for this request in the prior
            # step. There are no more prompt hidden states to produce.
            continue

        # Get the hidden states corresponding to this req's prompt tokens.
        # If this is a partial request (i.e. chunked prefill), then there
        # are prompt hidden states produced for each index.
        req_idx = self.input_batch.req_id_to_index[req_id]
        offset = self.query_start_loc.np[req_idx].item()
        prompt_hidden_states = hidden_states[offset:offset + num_logits]

        # Transfer GPU->CPU async.
        chunk_slice = slice(start_idx, start_idx + num_logits)
        hidden_states_tensor[chunk_slice].copy_(prompt_hidden_states,
                                                non_blocking=True)

    # Remove requests that have completed prefill from the batch's
    # bookkeeping structures.
    for req_id in completed_prefill_reqs:
        return_prompt_hidden_states_reqs.remove(req_id)
        del in_progress_dict[req_id]

    # Must synchronize the non-blocking GPU->CPU transfers.
    if prompt_hidden_states_dict:
        self._sync_device()

    # The returned dict is empty for steps that complete no prefills.
    return prompt_hidden_states_dict
```
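The chunk bookkeeping above can be checked in isolation. The sketch below is a hypothetical, dependency-free reimplementation of just the slicing arithmetic (no tensors, no GPU transfers): a prompt of `N` tokens yields an `(N - 1, hidden_size)` buffer, filled slice by slice as chunked prefill advances `num_computed_tokens`.

```python
def plan_chunks(num_prompt_tokens: int, schedule: list[int]) -> list[tuple[int, int]]:
    """For each prefill step, return (start_row, num_rows) written into the
    per-request CPU buffer of shape (num_prompt_tokens - 1, hidden_size).

    Mirrors the arithmetic in _get_prompt_hidden_states_dict above:
    num_computed_tokens advances by the scheduled token count each step.
    """
    writes = []
    num_computed_tokens = 0
    for num_tokens in schedule:
        start_idx = num_computed_tokens
        num_remaining = num_prompt_tokens - (start_idx + 1)
        # The last chunk writes fewer rows than the tokens it schedules.
        num_rows = num_tokens if num_tokens <= num_remaining else num_remaining
        if num_rows > 0:
            writes.append((start_idx, num_rows))
        num_computed_tokens += num_tokens
    return writes


# A 10-token prompt prefilled in chunks of 4 + 4 + 2 fills rows 0..8:
# the final chunk schedules 2 tokens but contributes only 1 row,
# because a 10-token prompt produces exactly 9 hidden-state rows.
print(plan_chunks(10, [4, 4, 2]))  # [(0, 4), (4, 4), (8, 1)]
```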
> **Reviewer:** let's make this an option, disabled by default, configurable from the `vllm serve` arguments.