diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 13da9aa38af0..adb6d9b1f6ae 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -539,10 +539,13 @@ def _get_logprobs(
             prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
+            # Swapped seqs have output tokens.
+            output_tokens = sampling_metadata.seq_data[
+                seq_ids[0]].output_token_ids
             batched_logprobs_query_seq_indices.extend(
                 sample_idx + j for j in range(prompt_len - 1))
             batched_logprobs_query_token_indices.extend(
-                token_id for token_id in prompt_tokens[1:])
+                token_id for token_id in prompt_tokens[1:] + output_tokens)
             sample_idx += prompt_len - 1
             batched_logprobs_query_seq_indices.extend(
                 [sample_idx + parent_id for parent_id in parent_ids])
@@ -586,8 +589,11 @@ def _get_logprobs(
             prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
+            # Swapped seqs have output tokens.
+            output_tokens = sampling_metadata.seq_data[
+                seq_ids[0]].output_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
-            for token_id in prompt_tokens[1:]:
+            for token_id in prompt_tokens[1:] + output_tokens:
                 prompt_logprobs_dict = {
                     token_id:
                     batched_logprobs_query_result[query_result_idx].item()