Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
prompt_hidden_states: Optional[torch.Tensor] = None,
# Forward compatibility, code that uses args added in new release can
# still run with older versions of vLLM without breaking.
**kwargs: Any,
Expand All @@ -139,6 +140,7 @@ def __init__(
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
self.prompt_hidden_states = prompt_hidden_states

def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""
Expand Down
2 changes: 2 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class SamplingParams(
response. When set to -1, return all `vocab_size` log probabilities."""
prompt_logprobs: Optional[int] = None
"""Number of log probabilities to return per prompt token."""
return_prompt_hidden_states: bool = False

# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,7 @@
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
prompt_hidden_states_dict = model_runner_output.prompt_hidden_states_dict

Check failure on line 848 in vllm/v1/core/sched/scheduler.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/core/sched/scheduler.py:848:81: E501 Line too long (81 > 80)
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
Expand Down Expand Up @@ -932,6 +933,7 @@

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
prompt_hidden_states = prompt_hidden_states_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:

Expand All @@ -943,6 +945,7 @@
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
prompt_hidden_states=prompt_hidden_states,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class EngineCoreOutput(
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

prompt_hidden_states: Optional[torch.Tensor] = None
pooling_output: Optional[torch.Tensor] = None

finish_reason: Optional[FinishReason] = None
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output)
print("lxy model_output to enginecoreoutput")
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore

Expand Down
56 changes: 56 additions & 0 deletions vllm/v1/engine/hidden_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.logger import init_logger
from vllm.sequence import PromptLogprobs
from vllm.v1.engine import EngineCoreOutput

logger = init_logger(__name__)

NONES = itertools.repeat(None)


@dataclass
class HiddenStatesProcessor:
    """Accumulates prompt hidden states for a single request.

    The engine core reports the full prompt hidden-states tensor at most
    once per request (prefill chunks are aggregated on the model-runner
    side); this class holds that tensor until the output processor hands
    it to the user, honoring RequestOutputKind.DELTA semantics.
    """

    # Full prompt hidden states for the request; None until the engine
    # core reports them, and None again after they have been popped.
    prompt_hidden_states: Optional[torch.Tensor]

    @classmethod
    def from_new_request(cls) -> "HiddenStatesProcessor":
        """Create an empty processor for a newly added request."""
        return cls(prompt_hidden_states=None)

    def _set_prompt_hidden_states(
        self,
        prompt_hidden_states_tensor: torch.Tensor,
    ) -> None:
        # The prompt hidden states arrive exactly once per request, so a
        # second set indicates a logic error upstream.
        assert self.prompt_hidden_states is None

        self.prompt_hidden_states = prompt_hidden_states_tensor

    def pop_prompt_hidden_states(self) -> Optional[torch.Tensor]:
        """Pop and return all request prompt hidden states.

        The hidden states processor aggregates prompt chunk hidden states
        over one or more prefill chunks. This method returns all prompt
        hidden states at once and then forgets them. Ensures correct
        RequestOutputKind.DELTA semantics wherein all prompt hidden states
        are returned at once at the end of prefill.

        Returns:
            None if prompt hidden states are disabled for this request or
            not yet available; the full prompt hidden-states tensor
            otherwise.
        """
        # NOTE: compare against None -- truthiness of a multi-element
        # torch.Tensor raises RuntimeError, so `if tensor:` is a bug here.
        hidden_states = self.prompt_hidden_states
        if hidden_states is not None:
            self.prompt_hidden_states = None
        return hidden_states

    def update_from_output(self, output: "EngineCoreOutput") -> None:
        """Record prompt hidden states from an engine-core output, if any."""
        if output.prompt_hidden_states is not None:
            self._set_prompt_hidden_states(output.prompt_hidden_states)
1 change: 1 addition & 0 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:

# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
print("lxy call process_outputs")
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
Expand Down
19 changes: 18 additions & 1 deletion vllm/v1/engine/output_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.hidden_states import HiddenStatesProcessor
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
Expand Down Expand Up @@ -93,6 +94,7 @@
arrival_time: float,
queue: Optional[RequestOutputCollector],
log_stats: bool,
hidden_states_processor: Optional[HiddenStatesProcessor],
):
self.request_id = request_id
self.parent_req = parent_req
Expand All @@ -111,6 +113,7 @@

self.stats = RequestStateStats(
arrival_time=arrival_time) if log_stats else None
self.hidden_states_processor = hidden_states_processor

@classmethod
def from_new_request(
Expand All @@ -137,10 +140,12 @@
request=request,
)
max_tokens_param = sampling_params.max_tokens
hidden_states_processor = HiddenStatesProcessor.from_new_request()
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
hidden_states_processor = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind

Expand All @@ -159,6 +164,7 @@
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
hidden_states_processor=hidden_states_processor,
)

def make_request_output(
Expand Down Expand Up @@ -204,7 +210,7 @@
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> Union[RequestOutput, PoolingRequestOutput]:

# Seeems here to process outputs
first_output = outputs[0]
if isinstance(first_output, PoolingOutput):
assert len(outputs) == 1
Expand All @@ -215,17 +221,23 @@
finished=finished,
)
assert self.logprobs_processor is not None
assert self.hidden_states_processor is not None
if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
prompt_hidden_states = self.hidden_states_processor.pop_prompt_hidden_states(

Check failure on line 228 in vllm/v1/engine/output_processor.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/engine/output_processor.py:228:81: E501 Line too long (89 > 80)
)
else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs
prompt_hidden_states = self.hidden_states_processor.prompt_hidden_states

Check failure on line 232 in vllm/v1/engine/output_processor.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/engine/output_processor.py:232:81: E501 Line too long (84 > 80)

# prompt logprobs is added here
return RequestOutput(
request_id=request_id,
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
prompt_hidden_states=prompt_hidden_states,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
kv_transfer_params=kv_transfer_params,
Expand Down Expand Up @@ -399,6 +411,7 @@
kv_transfer_params = engine_core_output.kv_transfer_params
req_state.num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
prompt_hidden_states = engine_core_output.prompt_hidden_states

if pooling_output is None:
assert req_state.detokenizer is not None
Expand All @@ -414,8 +427,12 @@
# if required.
req_state.logprobs_processor.update_from_output(
engine_core_output)
assert req_state.hidden_states_processor is not None
req_state.hidden_states_processor.update_from_output(
engine_core_output)

# 4) Create and handle RequestOutput objects.
print("lxy here make_request_output", prompt_hidden_states is None)
if request_output := req_state.make_request_output(
new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params):
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

# req_id ->
prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]]

# [num_reqs, hidden_size]
pooler_output: list[Optional[torch.Tensor]]

Expand All @@ -128,5 +131,6 @@ class DraftTokenIds:
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
prompt_hidden_states_dict={},
pooler_output=[],
num_nans_in_logits=None)
8 changes: 7 additions & 1 deletion vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,12 @@ def __init__(
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}

# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

self.return_prompt_hidden_states_reqs: set[str] = set()
self.in_progress_prompt_hidden_states_cpu: dict[str, torch.Tensor] = {}

# Internal representation of per-step batch state changes, used for
# reordering persistent batch and generating logitsprocs batch state
# updates. Should reset each step.
Expand Down Expand Up @@ -358,6 +360,9 @@ def add_request(
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs

if sampling_params.return_prompt_hidden_states:
self.return_prompt_hidden_states_reqs.add(req_id)

if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
Expand Down Expand Up @@ -447,6 +452,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.in_progress_prompt_hidden_states_cpu.pop(req_id, None)

self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
Expand Down
90 changes: 90 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,7 @@
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
prompt_hidden_states_dict={},
pooler_output=pooler_output,
kv_connector_output=kv_connector_output,
)
Expand Down Expand Up @@ -1683,6 +1684,10 @@
hidden_states[:num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)
prompt_hidden_states_dict = self._get_prompt_hidden_states_dict(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's make this an option by default disabled from vllm serve arguments

hidden_states[:num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)

# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
Expand Down Expand Up @@ -1746,6 +1751,7 @@
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
prompt_hidden_states_dict=prompt_hidden_states_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
Expand Down Expand Up @@ -2123,6 +2129,90 @@

return prompt_logprobs_dict

def _get_prompt_hidden_states_dict(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try extact to separate utils.py?

self,
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[torch.Tensor]]:
"""
This function is similar to _get_prompt_logprobs_dict but for prompt hidden states

Check failure on line 2138 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/worker/gpu_model_runner.py:2138:81: E501 Line too long (90 > 80)
"""

return_prompt_hidden_states_reqs = self.input_batch.return_prompt_hidden_states_reqs

Check failure on line 2141 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/worker/gpu_model_runner.py:2141:81: E501 Line too long (92 > 80)
if not return_prompt_hidden_states_reqs:
return {}

in_progress_dict = self.input_batch.in_progress_prompt_hidden_states_cpu
prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]] = {}

# Since prompt hidden states are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id in return_prompt_hidden_states_reqs:
num_tokens = num_scheduled_tokens[req_id]

# Get metadata for this request.
request = self.requests[req_id]
num_prompt_tokens = len(request.prompt_token_ids)

# Set up target hidden_states_tensors object.
hidden_states_tensors = in_progress_dict.get(req_id)
if not hidden_states_tensors:
# Create empty hidden_states_tensors CPU tensors for the entire prompt.

Check failure on line 2161 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/v1/worker/gpu_model_runner.py:2161:81: E501 Line too long (87 > 80)
# If chunked, we'll copy in slice by slice.
hidden_states_tensors = torch.empty(
(num_prompt_tokens - 1, self.hidden_size),
dtype=torch.int32,
device="cpu")
in_progress_dict[req_id] = hidden_states_tensors

# Determine number of hidden states to retrieve.
start_idx = request.num_computed_tokens
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
prompt_hidden_states_dict[req_id] = hidden_states_tensors

if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt hidden states to produce.
continue

# Get the hidden states corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt hidden states generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc.np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]

# Transfer GPU->CPU async.
chunk_slice = slice(start_idx, start_idx + num_logits)
hidden_states_tensors[chunk_slice].copy_(prompt_hidden_states,
non_blocking=True)

# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
return_prompt_hidden_states_reqs.remove(req_id)
del in_progress_dict[req_id]

# Must synchronize the non-blocking GPU->CPU transfers.
if prompt_hidden_states_dict:
self._sync_device()

# the return would be empty for prior steps
return prompt_hidden_states_dict

def _get_nans_in_logits(
self,
logits: Optional[torch.Tensor],
Expand Down
Loading
Loading