Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/v1/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass
from typing import TypeAlias

import numpy as np
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

Expand Down Expand Up @@ -369,9 +370,9 @@ def get_outputs(self) -> list[EngineCoreOutput]:
self.generated_logprobs_raw[req_idx][token_idx]
)
logprobs = LogprobsLists(
[logprobs_token_ids_],
[logprobs_],
[sampled_token_ranks_],
np.array([logprobs_token_ids_]),
np.array([logprobs_]),
np.array([sampled_token_ranks_]),
)
else:
logprobs = None
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/engine/logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:

token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists

for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst):
for rank_np, logprobs_np, token_ids_np in zip(
ranks_lst, logprobs_lst, token_ids_lst
):
rank = rank_np.tolist()
logprobs = logprobs_np.tolist()
token_ids = token_ids_np.tolist()
# Detokenize (non-incrementally).
decoded_tokens = (
NONES
Expand Down
13 changes: 7 additions & 6 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple

import numpy as np
import torch

if TYPE_CHECKING:
Expand All @@ -15,11 +16,11 @@

class LogprobsLists(NamedTuple):
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: list[list[int]]
logprob_token_ids: np.ndarray
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: list[list[float]]
logprobs: np.ndarray
# [num_reqs x num_generated_tokens]
sampled_token_ranks: list[int]
sampled_token_ranks: np.ndarray
# [num_reqs]
# Used for slicing the logprobs in cases like speculative
# decoding where the number of generated tokens may be
Expand Down Expand Up @@ -60,9 +61,9 @@ class LogprobsTensors(NamedTuple):

def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists(
self.logprob_token_ids.tolist(),
self.logprobs.tolist(),
self.selected_token_ranks.tolist(),
self.logprob_token_ids.cpu().numpy(),
self.logprobs.cpu().numpy(),
self.selected_token_ranks.cpu().numpy(),
cu_num_generated_tokens,
)

Expand Down