Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions tests/v1/engine/test_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,29 @@ def find_metric(name) -> list[Metric]:
assert len(num_accepted_tokens_per_pos) == 1
assert isinstance(num_accepted_tokens_per_pos[0], Vector)
assert len(num_accepted_tokens_per_pos[0].values) == 5


@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
def test_skip_tokenizer_initialization(model: str,
monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("VLLM_USE_V1", "1")
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(
model=model,
skip_tokenizer_init=True,
enforce_eager=True,
)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)

with pytest.raises(ValueError, match="cannot pass text prompts when"):
llm.generate("abc", sampling_params)

outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs
assert len(completions) > 0
assert completions[0].text == ""
assert completions[0].token_ids
9 changes: 7 additions & 2 deletions vllm/v1/engine/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def _validate_sampling_params(
return
if not params.allowed_token_ids:
raise ValueError("allowed_token_ids is not None and empty!")
if self.tokenizer is None:
# When skip_tokenizer_init=True, we can't validate token IDs
# Skip validation and let the model handle invalid tokens
return
Comment on lines +92 to +95
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This change correctly handles the case where allowed_token_ids are provided with skip_tokenizer_init=True. However, the PR is incomplete and will still lead to a crash in other common scenarios.

When skip_tokenizer_init=True, OutputProcessor is initialized with tokenizer=None. This will cause a crash in process_outputs when it tries to create a Detokenizer, which will raise a ValueError: Tokenizer not initialized. This happens even if the user sets detokenize=False in SamplingParams, because the default SamplingParams has detokenize=True. The test case in the PR description will trigger this crash.

Additionally, if bad_words are provided, they are silently ignored as update_from_tokenizer is skipped, but no warning or error is raised to the user.

To properly fix this, we should validate these unsupported parameter combinations when skip_tokenizer_init=True and raise an error, similar to how structured output is handled. A good place for these checks would be in _validate_params or at the beginning of this method (_validate_sampling_params).

For example:

if self.tokenizer is None:
    if params.detokenize:
        raise ValueError(
            "Detokenization is not supported when `skip_tokenizer_init=True`. "
            "Please set `detokenize=False` in `SamplingParams`."
        )
    if params.bad_words:
        raise ValueError(
            "`bad_words` is not supported when `skip_tokenizer_init=True`."
        )

Without this, the bug is not fully fixed and the feature remains partially broken.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will cause a crash in process_outputs when it tries to create a Detokenizer, which will raise a ValueError: Tokenizer not initialized.

I don't think this is correct, may be getting confused with v0.

tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
vocab_size = len(tokenizer)
if not all(0 <= tid < vocab_size for tid in params.allowed_token_ids):
Expand Down Expand Up @@ -283,8 +287,9 @@ def process_inputs(
len(decoder_inputs["prompt_token_ids"]))
sampling_params.update_from_generation_config(
self.generation_config_fields, eos_token_id)
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
if self.tokenizer is not None:
sampling_params.update_from_tokenizer(
self.tokenizer.get_lora_tokenizer(lora_request))
else:
pooling_params = params.clone()

Expand Down