2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -402,7 +402,7 @@ def __del__(self):
cleanup()


@pytest.fixture
@pytest.fixture(scope="session")
def vllm_runner():
return VllmRunner

tests/engine/test_stop_reason.py
@@ -3,7 +3,7 @@
2. One of the provided stop tokens
3. The EOS token

Run `pytest tests/samplers/test_stop_reason.py`.
Run `pytest tests/engine/test_stop_reason.py`.
"""

import pytest
111 changes: 111 additions & 0 deletions tests/engine/test_stop_strings.py
@@ -0,0 +1,111 @@
from typing import Any, List, Optional

import pytest

from vllm import CompletionOutput, LLMEngine, SamplingParams

MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200


@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
return vllm_runner(MODEL)


@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")

_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")


@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")

_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")

_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
# token id 13013 => " organization"

_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)

_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)


def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)

output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs

# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason

assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
98 changes: 65 additions & 33 deletions vllm/engine/llm_engine.py
@@ -501,9 +501,11 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,

for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize:
self.detokenizer.decode_sequence_inplace(
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)
else:
new_char_count = 0
self._check_stop(seq, new_char_count, seq_group.sampling_params)

# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
@@ -795,56 +797,86 @@ def _get_stats(self,
time_e2e_requests=time_e2e_requests,
)

def _check_stop(self, seq: Sequence,
def _check_stop(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
"""Stop the finished sequences.

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""

# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return

if sampling_params.detokenize:
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
self._finalize_sequence(seq, sampling_params, stop_str)
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
Review comment (Collaborator):
Do we want to set seq.stop_reason to eos_token_id here?

Reply (Member, Author):
@sroy745 it's a good question. It was decided to keep the stop_reason as None in this case so that clients can know that it's due to EOS in simple cases without having to know the token ids. See discussion in #2976.

return

# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
last_token_id)
self._finalize_sequence(seq, sampling_params, stop_str)
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return

# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return

def _finalize_sequence(self, seq: Sequence,
sampling_params: SamplingParams,
stop_string: str) -> None:
if sampling_params.include_stop_str_in_output:
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

if stop_string and seq.output_text.endswith(stop_string):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_string)]
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

@staticmethod
def _check_stop_strings(seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> Optional[str]:
"""Check if any stop strings are matched and truncate sequence
output text accordingly.

Returns the stop string if matched or else None.
"""
if not new_char_count:
return None

for stop_str in sampling_params.stop:
stop_string_len = len(stop_str)
# Avoid searching already-searched text.
stop_index = seq.output_text.find(
stop_str, -new_char_count - stop_string_len)
if stop_index == -1:
continue

if sampling_params.include_stop_str_in_output:
# Truncate to end of stop string.
stop_index += stop_string_len
if stop_index >= len(seq.output_text):
# No truncation required.
return stop_str

# Truncate the output text to either the beginning
# or end of the stop string.
seq.output_text = seq.output_text[:stop_index]
return stop_str
return None

def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)
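A note on the windowed search in _check_stop_strings above: the negative start index passed to str.find limits each step's scan to the characters added by the newest token, plus len(stop_str) characters of overlap in case a stop string straddles the boundary with previously decoded text. A minimal standalone sketch of that behavior (the text and token boundary below are illustrative assumptions, not values from the diff):

output_text = "VLLM is a 100% volunteer organization"
new_char_count = len(" organization")  # chars appended by the latest decoded token
stop_str = "gani"

# Scan only the tail of the text: the newly added characters plus
# len(stop_str) extra characters of overlap with older text.
start = -new_char_count - len(stop_str)
stop_index = output_text.find(stop_str, start)
print(stop_index)  # 27, inside "organization"

# With include_stop_str_in_output=False, the output is cut at the match.
print(output_text[:stop_index])  # "VLLM is a 100% volunteer or"
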
4 changes: 3 additions & 1 deletion vllm/outputs.py
@@ -112,8 +112,10 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
CompletionOutput(seqs.index(seq), seq.output_text,
CompletionOutput(seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.output_logprobs if include_logprobs else None,
9 changes: 9 additions & 0 deletions vllm/sampling_params.py
@@ -166,6 +166,13 @@ def __init__(
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if self.stop and not include_stop_str_in_output:
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
else:
self.output_text_buffer_length = 0

self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@@ -226,6 +233,8 @@ def _verify_args(self) -> None:
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
if any(not stop_str for stop_str in self.stop):
raise ValueError("stop cannot contain an empty string.")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "
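The output_text_buffer_length computed above controls how many trailing characters are withheld from streamed output while a request is still running, so a partially generated stop string never reaches the client before it can be matched. A small sketch mirroring that logic (the stop strings are illustrative, borrowed from the tests in this PR):

stop = ["group of peo", "short"]
include_stop_str_in_output = False

if stop and not include_stop_str_in_output:
    # Anything shorter than the longest stop string may still grow into a
    # full match, so hold back (longest length - 1) characters.
    output_text_buffer_length = max(len(s) for s in stop) - 1
else:
    output_text_buffer_length = 0

print(output_text_buffer_length)  # 11
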
6 changes: 6 additions & 0 deletions vllm/sequence.py
@@ -235,6 +235,12 @@ def __init__(
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0

def get_output_text_to_return(self, buffer_length: int):
# We return the full output text if the sequence is finished.
truncate = buffer_length and not self.is_finished()
return self.output_text[:-buffer_length] if truncate else (
self.output_text)

def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size

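get_output_text_to_return pairs with that buffer: while the sequence is still generating, the last buffer_length characters are held back; once it finishes (and any stop-string truncation has happened), the full text is returned. A standalone restatement of the helper, for illustration only:

def get_output_text_to_return(output_text: str, buffer_length: int,
                              finished: bool) -> str:
    # Withhold the trailing buffer while generation is in progress, in case
    # those characters are the start of a not-yet-complete stop string.
    truncate = buffer_length and not finished
    return output_text[:-buffer_length] if truncate else output_text

print(get_output_text_to_return("We are a group of pe", 11, False))
# -> "We are a "  (last 11 chars withheld)
print(get_output_text_to_return("We are a group of pe", 11, True))
# -> "We are a group of pe"
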
7 changes: 6 additions & 1 deletion vllm/transformers_utils/detokenizer.py
@@ -87,12 +87,15 @@ def decode_prompt_logprobs_inplace(
prev_tokens.extend(next_iter_tokens)

def decode_sequence_inplace(self, seq: Sequence,
prms: SamplingParams) -> None:
prms: SamplingParams) -> int:
"""Decodes the new token for a sequence. In-place operation.

Args:
seq: The sequence to decode.
prms: The sampling parameters used to generate the sequence.

Returns:
The number of characters added to the output text.
"""
all_input_ids = seq.get_token_ids()
token_id_generated_this_iteration = all_input_ids[-1]
@@ -151,6 +154,8 @@ def decode_sequence_inplace(self, seq: Sequence,
seq.read_offset = read_offset
seq.output_text += new_decoded_token_text

return len(new_decoded_token_text)


def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],