Merged
23 commits
b7c8a8a
bump xgrammar dependency version (needed for bugfixes to grammar_matc…
benchislett Mar 12, 2025
5bc9ec3
validate and filter draft tokens according to structured output grammar
benchislett Mar 12, 2025
09acb1a
remove request validation allowing spec+structured output
benchislett Mar 12, 2025
a5b7521
implement compressed-batch structured output bitmask with speculative…
benchislett Mar 12, 2025
24e5e15
fix num spec tokens
benchislett Mar 12, 2025
36ed442
tweaks to benchmark script
benchislett Mar 12, 2025
730d81f
comment about perf improvement potential
benchislett Mar 12, 2025
a15d5a0
Merge branch 'main' into feat-v1-speculative-decoding-with-structured…
benchislett Mar 12, 2025
7726ae4
bugfix
benchislett Mar 12, 2025
a6a997a
bugfix
benchislett Mar 13, 2025
3b93462
fix formatting
benchislett Mar 13, 2025
72f736b
Merge branch 'main' into feat-v1-speculative-decoding-with-structured…
benchislett Mar 14, 2025
2ddbaa6
Merge branch 'main' into feat-v1-speculative-decoding-with-structured…
benchislett Mar 18, 2025
f6aff2a
Merge remote-tracking branch 'upstream/main' into feat-v1-speculative…
benchislett Apr 7, 2025
12715bb
fix guided decoding benchmarks for v1 spec compat
benchislett Apr 7, 2025
818bf9a
update structured output tests for spec compat
benchislett Apr 8, 2025
94f1011
Merge branch 'main' into feat-v1-speculative-decoding-with-structured…
benchislett Apr 11, 2025
cfd235b
tiny fixes for pr comments
benchislett Apr 17, 2025
8860165
Merge branch 'main' into feat-v1-speculative-decoding-with-structured…
benchislett Apr 17, 2025
617a63d
tests for EAGLE
benchislett Apr 17, 2025
8636227
comment for clarity
benchislett Apr 25, 2025
af7bbdc
re-disable test case
benchislett Apr 25, 2025
07292cb
Merge remote-tracking branch 'upstream/main' into feat-v1-speculative…
benchislett Apr 29, 2025
1 change: 1 addition & 0 deletions benchmarks/backend_request_func.py
@@ -260,6 +260,7 @@ async def async_request_openai_completions(
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"repetition_penalty": 1.0,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
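Note: a minimal sketch (illustrative values only, not taken from a real benchmark run) of the completions payload this helper now builds, with `repetition_penalty` set explicitly rather than left to the server default:

```python
# Sketch of the JSON body built by async_request_openai_completions.
# Values are placeholders; only the field names mirror the code above.
payload = {
    "model": "my-model",             # hypothetical model name
    "prompt": "Hello, world",
    "temperature": 0.0,              # greedy decoding for the benchmark
    "repetition_penalty": 1.0,       # explicitly no repetition penalty
    "max_tokens": 128,
    "logprobs": None,
    "stream": True,
}
```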
9 changes: 6 additions & 3 deletions benchmarks/benchmark_serving_structured_output.py
@@ -123,6 +123,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
copy.deepcopy(schema) for _ in range(args.num_prompts)
]
for i in range(len(json_schemas)):
if "properties" not in json_schemas[i]:
json_schemas[i]["properties"] = {}
json_schemas[i]["properties"][
f"__optional_field_{uuid.uuid4()}"] = {
"type":
@@ -134,7 +136,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
json_schemas = [schema] * args.num_prompts

def gen_prompt(index: int):
return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501
Collaborator: What is this change for?

Collaborator Author (benchislett): Some models were endlessly repeating outputs on the benchmark prompt; this tweak was enough to tip the scales.


def get_schema(index: int):
return json_schemas[index % len(json_schemas)]
@@ -231,7 +233,8 @@ def _filter_func(item):
idx -= len_dataset
schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
tokenize=False)
tokenize=False,
add_generation_prompt=True)
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]

@@ -849,7 +852,7 @@ def main(args: argparse.Namespace):
'json', 'json-unique', 'grammar', 'regex',
'choice', 'xgrammar_bench'
])
parser.add_argument("--json_schema_path",
parser.add_argument("--json-schema-path",
type=str,
default=None,
help="Path to json schema.")
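As context for the `add_generation_prompt=True` change, a minimal sketch (the model name and message are illustrative) of how the rendered prompt differs; without the flag the rendered string ends after the user turn, and some chat models then continue the user message instead of producing a completion:

```python
from transformers import AutoTokenizer

# Hypothetical model, used only for illustration.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
messages = [{"role": "user",
             "content": "Generate a user profile for the schema ..."}]

# Without the flag: the string stops at the end of the user turn.
prompt_plain = tokenizer.apply_chat_template(messages, tokenize=False)

# With the flag: the assistant header is appended, so generation starts
# in the assistant role -- this is what the benchmark now does.
prompt_ready = tokenizer.apply_chat_template(messages,
                                             tokenize=False,
                                             add_generation_prompt=True)
```

Note also that the schema-path flag is now spelled `--json-schema-path`, matching the dash style of the script's other arguments.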
35 changes: 28 additions & 7 deletions tests/v1/entrypoints/llm/test_struct_output_generate.py
@@ -16,13 +16,31 @@
from vllm.platforms import current_platform
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

NGRAM_SPEC_CONFIG = {
"model": "[ngram]",
"num_speculative_tokens": 5,
"prompt_lookup_max": 5,
"prompt_lookup_min": 1,
}

EAGLE_SPEC_CONFIG = {
"method": "eagle",
"model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
"num_speculative_tokens": 5,
}

PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral"),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
NGRAM_SPEC_CONFIG),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto",
EAGLE_SPEC_CONFIG)
]

PARAMS_MODELS_TOKENIZER_MODE = [
@@ -45,8 +63,9 @@ class CarDescription(BaseModel):


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name, guided_decoding_backend, tokenizer_mode",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
@pytest.mark.parametrize(
"model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
def test_structured_output(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
@@ -58,6 +77,7 @@ def test_structured_output(
guided_decoding_backend: str,
tokenizer_mode: str,
model_name: str,
speculative_config: dict[str, Any],
):
monkeypatch.setenv("VLLM_USE_V1", "1")

@@ -71,7 +91,8 @@
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=True,
tokenizer_mode=tokenizer_mode)
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)

#
# Test 1: Generate JSON output based on a provided schema
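For readers trying the new combination locally, a usage sketch assembled from the test above (the model choice and schema are illustrative, and the exact flags may evolve): structured output and speculative decoding are configured together on the same `LLM` instance.

```python
import os
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

os.environ["VLLM_USE_V1"] = "1"  # the test targets the V1 engine

# Mirrors NGRAM_SPEC_CONFIG from the test; an EAGLE config works the same way.
llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    max_model_len=1024,
    guided_decoding_backend="xgrammar",
    speculative_config={
        "model": "[ngram]",
        "num_speculative_tokens": 5,
        "prompt_lookup_max": 5,
        "prompt_lookup_min": 1,
    },
)

# Illustrative schema; any JSON schema accepted by the backend works.
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
params = SamplingParams(
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json=schema),
)
outputs = llm.generate(
    f"Generate a JSON object for this schema: {schema}", params)
print(outputs[0].outputs[0].text)
```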
17 changes: 12 additions & 5 deletions vllm/v1/core/sched/scheduler.py
@@ -441,7 +441,7 @@ def schedule(self) -> SchedulerOutput:
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
len(self.running),
scheduled_spec_decode_tokens,
)
# Construct the scheduler output.
new_reqs_data = [
@@ -682,10 +682,6 @@ def update_from_output(
self.encoder_cache_manager.free_encoder_input(
request, input_id)

# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
request.spec_token_ids = spec_token_ids[req_index]

stopped = False
new_logprobs = None
new_token_ids = generated_token_ids
@@ -717,6 +713,17 @@ def update_from_output(
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)

# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if request.use_structured_output:
metadata = request.structured_output_request
assert metadata is not None and metadata.grammar is not None
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens(
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids:
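In plain terms, the scheduler change above validates draft tokens only after the sampled tokens have been accepted by the grammar, and attaches just the grammar-accepted prefix of the drafts to the request. A condensed sketch of that ordering (hypothetical helper, not the actual scheduler code):

```python
def attach_spec_tokens(request, new_token_ids, draft_token_ids):
    """Sketch of the update_from_output ordering for one request."""
    grammar = None
    if request.use_structured_output:
        grammar = request.structured_output_request.grammar
        # 1) Advance the grammar with the tokens actually sampled this step.
        grammar.accept_tokens(request.request_id, new_token_ids)

    # 2) Only then filter the newly proposed drafts: keep the longest
    #    prefix the grammar can accept, without advancing its state.
    if grammar is not None:
        request.spec_token_ids = grammar.validate_tokens(draft_token_ids)
    else:
        request.spec_token_ids = draft_token_ids
```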
59 changes: 46 additions & 13 deletions vllm/v1/structured_output/__init__.py
@@ -27,6 +27,7 @@ class StructuredOutputManager:
def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None
self.vllm_config = vllm_config

self._grammar_bitmask: Optional[torch.Tensor] = None

# The default max_workers if not specified is the number of CPUs * 5,
@@ -80,28 +81,60 @@ def grammar_bitmask(
self,
requests: dict[str, Request],
structured_output_request_ids: dict[str, int],
batch_len: int,
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> Optional[npt.NDArray[np.int32]]:
# Prepare the structured output bitmask for this batch.
if not structured_output_request_ids:
return None

if self._grammar_bitmask is None:
assert self.backend is not None
self._grammar_bitmask = self.backend.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs)

# Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of
# the batch.
bitmask_tensor = self._grammar_bitmask
for req_id, batch_index in structured_output_request_ids.items():
max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
if self.vllm_config.speculative_config is not None:
max_num_spec_tokens = self.vllm_config.\
speculative_config.num_speculative_tokens
else:
max_num_spec_tokens = 0

# Allocate a bitmask for each token needing to be checked:
# one for each speculative position, and one more for the
# bonus token / non-speculative token.
self._grammar_bitmask = \
self.backend.allocate_token_bitmask(
max_batch_size * (1 + max_num_spec_tokens))

# Generate a batched bitmask for all structured output requests.
# When speculative decoding is enabled, we need to include multiple
# masks for each request, one for each possible bonus token position.
# These are stored inline in the tensor and unpacked by the gpu runner.
cumulative_index = 0
ordered_seq = sorted(structured_output_request_ids.items(),
key=lambda x: x[1])
# NOTE: This outer loop can likely be parallelized to improve
# performance of bitmask generation for large batches.
for req_id, _ in ordered_seq:
request = requests[req_id].structured_output_request
assert request is not None and request.grammar is not None
if not request.grammar.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len]
state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
for i, token in enumerate(req_tokens):
if not request.grammar.is_terminated():
request.grammar.fill_bitmask(self._grammar_bitmask,
cumulative_index)
if token is not None:
# In order to generate the correct bitmask for each
# position in the speculative sequence, we advance
# the FSM state for each speculative token and rollback
# to restore the previous state when we are finished.
assert request.grammar.accept_tokens(req_id, [token])
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
request.grammar.rollback(state_advancements)

bitmask_tensor = self._grammar_bitmask
if cumulative_index < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:cumulative_index]

# After finishing with the xgrammar operations, we convert to
# np.ndarray, because that is much more efficient for serialization
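The per-request inner loop above follows an advance-and-rollback pattern: one bitmask row is produced for each scheduled draft position plus one for the bonus/non-speculative position, the grammar is advanced past each draft so later rows reflect the right state, and then it is rolled back so bitmask generation leaves the grammar untouched. A condensed sketch (toy grammar object with the same method names, not the vLLM classes):

```python
def fill_rows_for_request(grammar, bitmask, start_row, draft_tokens):
    """Return the next free row index after filling this request's rows."""
    advanced = 0
    row = start_row
    for token in draft_tokens + [None]:     # None marks the bonus position
        if not grammar.is_terminated():
            grammar.fill_bitmask(bitmask, row)
            if token is not None:
                # Advance so the next row is masked against the state
                # *after* this draft token.
                assert grammar.accept_tokens("req-0", [token])
                advanced += 1
        row += 1
    if advanced > 0:
        grammar.rollback(advanced)          # restore the pre-draft state
    return row
```

With this packing, three structured-output requests with 2, 0, and 3 scheduled draft tokens occupy rows 0-2, 3, and 4-7 respectively; the GPU runner unpacks the rows back into per-request, per-position masks.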
21 changes: 21 additions & 0 deletions vllm/v1/structured_output/backend_guidance.py
@@ -144,6 +144,27 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:

return r

def validate_tokens(self, tokens: list[int]) -> list[int]:
"""Checks if the list of tokens are accepted by the parser in sequence.
Will not advance the parser.

Returns the prefix list of tokens that are accepted by the parser.
"""
if len(tokens) == 0:
return []
if self.ll_matcher.is_stopped():
return []

num_tokens = self.ll_matcher.validate_tokens(tokens)

self.check_error()

return tokens[:num_tokens]

def rollback(self, num_tokens: int) -> None:
self.ll_matcher.rollback(num_tokens)
self.check_error()

def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
# this will automatically return [EOS] mask if the matcher is stopped
# or otherwise in an error state
24 changes: 24 additions & 0 deletions vllm/v1/structured_output/backend_types.py
@@ -35,6 +35,30 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
bool: True if the tokens are accepted, False otherwise.
"""

@abstractmethod
def validate_tokens(self, tokens: list[int]) -> list[int]:
"""
Validates the provided tokens against the grammar.
Will not advance the FSM.

Args:
tokens (list[int]): A list of token IDs to validate.

Returns:
list[int]: A list of accepted token IDs. Will be a prefix
of the input tokens, and empty if none are accepted.
"""

@abstractmethod
def rollback(self, num_tokens: int) -> None:
"""
Rolls back the state of the grammar by a specified number of tokens.
Will also revert counters for the number of processed tokens.

Args:
num_tokens (int): The number of tokens to roll back.
"""

@abstractmethod
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
"""
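To make the new contract concrete, a small illustration (token IDs are made up, and `grammar` stands for any `StructuredOutputGrammar` implementation) of how `validate_tokens` and `rollback` are meant to compose:

```python
# The draft model proposed five tokens. validate_tokens() returns the
# accepted prefix and leaves the grammar state unchanged:
drafts = [11, 42, 7, 99, 3]            # hypothetical token IDs
accepted = grammar.validate_tokens(drafts)
assert accepted == drafts[:len(accepted)]

# Later, code that needs to look ahead (e.g. bitmask generation) may
# temporarily advance the grammar and then undo the advancement:
if grammar.accept_tokens("req-0", accepted):
    grammar.rollback(len(accepted))    # back to the state before the drafts
```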
32 changes: 30 additions & 2 deletions vllm/v1/structured_output/backend_xgrammar.py
@@ -40,6 +40,11 @@ def __init__(self, vllm_config: VllmConfig):
self.disable_any_whitespace = \
vllm_config.decoding_config.disable_any_whitespace

self.num_speculative_tokens = 0
if self.vllm_config.speculative_config is not None:
self.num_speculative_tokens = \
self.vllm_config.speculative_config.num_speculative_tokens

tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
@@ -118,7 +123,10 @@ def compile_grammar(self, request_type: StructuredOutputOptions,
f"grammar is not of valid supported types. ({request_type!s})")

return XgrammarGrammar(
matcher=xgr.GrammarMatcher(ctx),
matcher=xgr.GrammarMatcher(
ctx,
max_rollback_tokens=self.num_speculative_tokens,
),
vocab_size=self.vocab_size,
ctx=ctx,
)
@@ -136,7 +144,6 @@ class XgrammarGrammar(StructuredOutputGrammar):
# supporting different backends, in the future.
# For now, just xgrammar.
#
# TODO: support max_rollback_tokens
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding

@@ -163,6 +170,27 @@ def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
self.num_processed_tokens += 1
return True

def validate_tokens(self, tokens: list[int]) -> list[int]:
"""Checks if the list of tokens are accepted by the FSM in sequence.
Will not advance the FSM.

Returns the prefix list of tokens that are accepted by the FSM.
"""
accepted_tokens = []
for token in tokens:
if self.matcher.accept_token(token):
accepted_tokens.append(token)
else:
break
if len(accepted_tokens) > 0:
# Roll the FSM back to the state it was in before these tokens
self.matcher.rollback(len(accepted_tokens))
return accepted_tokens

def rollback(self, num_tokens: int) -> None:
self.matcher.rollback(num_tokens)
self.num_processed_tokens -= num_tokens

def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(bitmask, idx)

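For xgrammar, rollback support comes directly from the matcher: the backend now constructs `xgr.GrammarMatcher` with `max_rollback_tokens` equal to the number of speculative tokens, and `validate_tokens` is emulated by accepting tokens and immediately rolling back (whereas the guidance backend's `ll_matcher.validate_tokens` checks without side effects). A rough sketch of the matcher-level behaviour, assuming a compiled grammar `ctx` and illustrative token IDs:

```python
import xgrammar as xgr

# `ctx` is a compiled grammar produced elsewhere; max_rollback_tokens
# bounds how far we may need to rewind after trial acceptance.
matcher = xgr.GrammarMatcher(ctx, max_rollback_tokens=5)

drafts = [11, 42, 7]                    # hypothetical draft token IDs
accepted = []
for tok in drafts:
    if matcher.accept_token(tok):       # advances the matcher state
        accepted.append(tok)
    else:
        break

# Undo the trial acceptance so the matcher is back where it started.
if accepted:
    matcher.rollback(len(accepted))
```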