131 changes: 92 additions & 39 deletions vllm/v1/structured_output/__init__.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
@@ -40,6 +40,17 @@ def __init__(self, vllm_config: VllmConfig):
self._grammar_bitmask: Optional[torch.Tensor] = None
self._full_mask = torch.tensor(-1, dtype=torch.int32)

max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
self.fill_bitmask_parallel_threshold = 128
if self.fill_bitmask_parallel_threshold < max_batch_size:
self.fill_bitmask_parallel_batch_size = 16
# Use at least 1 worker, and at most half the
# available CPUs, capped at 8.
max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
self.executor_for_fillmask = ThreadPoolExecutor(
max_workers=max_workers)
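
For intuition, here is a standalone sketch of the worker-sizing heuristic; the helper name is hypothetical and the formula simply mirrors the line above:

```python
# Standalone sketch of the worker-sizing heuristic above; the helper
# name is hypothetical, the formula mirrors the diff.
import multiprocessing


def fillmask_worker_count() -> int:
    # At least 1 worker; at most half the CPUs, never more than 8.
    return max(1, min(multiprocessing.cpu_count() // 2, 8))


# 4 CPUs -> 2 workers, 16 CPUs -> 8 workers, 64 CPUs -> 8 workers
```

The cap presumably keeps the fill pool from oversubscribing cores that other CPU-bound engine work also needs.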

if not self.vllm_config.model_config.skip_tokenizer_init:
# The default max_workers if not specified is the number of
# CPUs * 5, which is way too high since these tasks are CPU-bound,
@@ -120,6 +131,26 @@ def _async_create_grammar(
assert self.backend is not None
return self.backend.compile_grammar(request_type, grammar_spec)

def _fill_bitmasks(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> None:
assert self._grammar_bitmask is not None
for grammar, index, apply_bitmask in batch:
if apply_bitmask and not grammar.is_terminated():
grammar.fill_bitmask(self._grammar_bitmask, index)
else:
# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for subsequent
# requests here.
self._grammar_bitmask[index].fill_(self._full_mask)

def _async_submit_fill_bitmask(
self,
batch: list[tuple[StructuredOutputGrammar, int, bool]],
) -> Future:
return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
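
These two helpers implement a simple fan-out/fan-in pattern: chunk the work into fixed-size batches, submit each batch to the pool, then block on the returned futures. A self-contained sketch of the same pattern, with placeholder work items rather than vLLM types:

```python
# Self-contained sketch of the fan-out/fan-in pattern used by
# _fill_bitmasks / _async_submit_fill_bitmask above. The work
# function and item type are placeholders, not vLLM APIs.
from concurrent.futures import Future, ThreadPoolExecutor

BATCH_SIZE = 16
executor = ThreadPoolExecutor(max_workers=4)


def process_batch(batch: list[int]) -> None:
    # Stand-in for filling one bitmask row per work item.
    for _item in batch:
        pass


def submit_all(items: list[int]) -> list[Future]:
    promises: list[Future] = []
    batch: list[int] = []
    for item in items:
        batch.append(item)
        if len(batch) == BATCH_SIZE:
            promises.append(executor.submit(process_batch, batch))
            batch = []  # rebind; the submitted list is untouched
    if batch:  # flush the final partial batch
        promises.append(executor.submit(process_batch, batch))
    return promises


# Block until every batch is done; .result() re-raises worker errors.
for promise in submit_all(list(range(100))):
    promise.result()
```

Batching amortizes the per-future submission overhead, which would otherwise dominate when each unit of work (one bitmask row) is small.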

def grammar_bitmask(
self,
requests: dict[str, Request],
@@ -146,7 +177,6 @@ def grammar_bitmask(
self.backend.allocate_token_bitmask(
max_batch_size * (1 + max_num_spec_tokens))

bitmask_tensor = self._grammar_bitmask
# Generate a batched bitmask for all structured output requests.
# When speculative decoding is enabled, we need to include multiple
# masks for each request, one for each possible bonus token position.
@@ -155,47 +185,61 @@
ordered_seq = sorted(structured_output_request_ids.items(),
key=lambda x: x[1])

# Note that for thinking support, we will need to
# reset the relevant part of the bitmask for subsequent
# requests here.
bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
self._full_mask)

# NOTE: This outer loop can likely be parallelized to improve
# performance of bitmask generation for large batches.
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request

if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask: bool = True
if self.reasoner is not None:
if structured_output_request.reasoning_ended is None:
structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
apply_bitmask = structured_output_request.reasoning_ended

state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
for i, token in enumerate(req_tokens):
if apply_bitmask and not \
structured_output_request.grammar.is_terminated():
structured_output_request.grammar.fill_bitmask(
bitmask_tensor, cumulative_index)
if token is not None:
# In order to generate the correct bitmask for each
# position in the speculative sequence, we advance
# the FSM state for each speculative token and rollback
# to restore the previous state when we are finished.
# Optimized parallel filling of bitmasks for the
# non-speculative, large-batch case
if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
max_num_spec_tokens == 0:
promises = []
batch = []
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request
if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None

apply_bitmask = self.should_fill_bitmask(request)
batch.append((structured_output_request.grammar,
cumulative_index, apply_bitmask))
if len(batch) == self.fill_bitmask_parallel_batch_size:
promises.append(self._async_submit_fill_bitmask(batch))
batch = []

cumulative_index += 1
if batch:
promises.append(self._async_submit_fill_bitmask(batch))

# Wait for all bitmask filling tasks to complete.
for promise in promises:
promise.result()
else:
# Fall back to serial bitmask filling for small batches
for req_id, _ in ordered_seq:
request = requests[req_id]
structured_output_request = request.structured_output_request

if TYPE_CHECKING:
assert structured_output_request is not None
assert structured_output_request.grammar is not None
apply_bitmask = self.should_fill_bitmask(request)

state_advancements = 0
req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
for i, token in enumerate(req_tokens + [None]):
self._fill_bitmasks([(structured_output_request.grammar,
cumulative_index, apply_bitmask)])

if apply_bitmask and token is not None and \
not structured_output_request.grammar.is_terminated():
assert structured_output_request.grammar.accept_tokens(
req_id, [token])
state_advancements += 1
cumulative_index += 1
if state_advancements > 0:
structured_output_request.grammar.rollback(state_advancements)
cumulative_index += 1
if state_advancements > 0:
structured_output_request.grammar.rollback(
state_advancements)

bitmask_tensor = self._grammar_bitmask
if cumulative_index < bitmask_tensor.shape[0]:
bitmask_tensor = bitmask_tensor[:cumulative_index]
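
In the serial path, the mask for each speculative position must be computed from the FSM state *after* accepting the preceding speculative tokens, so the code advances the grammar one token at a time and rolls back when finished. A toy sketch of that discipline (ToyGrammar is a stand-in, not a vLLM class):

```python
# Toy illustration of the advance-and-rollback discipline in the
# serial path above; ToyGrammar is a stand-in, not a vLLM class.
class ToyGrammar:
    def __init__(self) -> None:
        self.state = 0

    def fill_bitmask(self, masks: list[int], idx: int) -> None:
        masks[idx] = self.state  # a real grammar writes allowed tokens

    def accept_tokens(self, tokens: list[int]) -> bool:
        self.state += len(tokens)  # advance the FSM
        return True

    def rollback(self, n: int) -> None:
        self.state -= n


grammar = ToyGrammar()
spec_tokens = [11, 22]  # two scheduled speculative tokens
masks = [0] * (len(spec_tokens) + 1)  # one mask per position + bonus

advanced = 0
for i, token in enumerate(spec_tokens + [None]):
    grammar.fill_bitmask(masks, i)  # mask from the current FSM state
    if token is not None:
        assert grammar.accept_tokens([token])
        advanced += 1
grammar.rollback(advanced)  # restore the pre-speculation state
assert masks == [0, 1, 2] and grammar.state == 0
```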

Expand All @@ -204,6 +248,15 @@ def grammar_bitmask(
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()

def should_fill_bitmask(self, request: Request) -> bool:
if self.reasoner is not None:
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
request.structured_output_request.reasoning_ended = \
self.reasoner.is_reasoning_end(request.prompt_token_ids)
return request.structured_output_request.reasoning_ended
return True
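
should_fill_bitmask computes the reasoning-ended flag lazily and caches it on the request, so the (potentially costly) check runs at most once per request. A simplified sketch of that pattern with stub types (not the real vLLM interfaces):

```python
# Simplified sketch of the lazy, cached reasoning-ended check in
# should_fill_bitmask; the Stub* types are illustrative stand-ins,
# not the real vLLM interfaces.
from typing import Optional


class StubReasoner:
    END_THINK = 999  # hypothetical end-of-reasoning token id

    def is_reasoning_end(self, prompt_token_ids: list[int]) -> bool:
        return self.END_THINK in prompt_token_ids


class StubRequest:
    def __init__(self, prompt_token_ids: list[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.reasoning_ended: Optional[bool] = None


def should_fill_bitmask(reasoner: Optional[StubReasoner],
                        request: StubRequest) -> bool:
    if reasoner is not None:
        if request.reasoning_ended is None:  # compute once, then cache
            request.reasoning_ended = reasoner.is_reasoning_end(
                request.prompt_token_ids)
        return request.reasoning_ended
    return True  # no reasoner configured: always apply the grammar
```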

def should_advance(self, request: Request) -> bool:
if not request.use_structured_output:
return False
7 changes: 6 additions & 1 deletion vllm/v1/structured_output/backend_xgrammar.py
@@ -148,20 +148,24 @@ class XgrammarGrammar(StructuredOutputGrammar):
repr=False,
hash=False,
init=False)
_is_terminated: bool = field(default=False, repr=False, hash=False)

def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.

Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
if self._is_terminated:
return False
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
"Failed to advance FSM for request %s "
"for token %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
self._is_terminated = self.matcher.is_terminated()
return True

def validate_tokens(self, tokens: list[int]) -> list[int]:
@@ -184,12 +188,13 @@ def validate_tokens(self, tokens: list[int]) -> list[int]:
def rollback(self, num_tokens: int) -> None:
self.matcher.rollback(num_tokens)
self.num_processed_tokens -= num_tokens
self._is_terminated = self.matcher.is_terminated()

def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
self.matcher.fill_next_token_bitmask(bitmask, idx)

def is_terminated(self) -> bool:
return self.matcher.is_terminated()
return self._is_terminated

def reset(self):
self.num_processed_tokens = 0
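
The new _is_terminated field caches the matcher's termination status whenever the FSM advances or rolls back, making is_terminated() a plain field read on the hot path and letting accept_tokens refuse input once the grammar has finished. A toy version of the caching pattern (ToyMatcher is illustrative, not the xgrammar API):

```python
# Toy version of caching a matcher's terminated state, as the diff
# does with _is_terminated; ToyMatcher is illustrative, not the
# xgrammar API.
class ToyMatcher:
    def __init__(self, limit: int) -> None:
        self.pos, self.limit = 0, limit

    def accept_token(self, token: int) -> bool:
        self.pos += 1
        return True

    def rollback(self, n: int) -> None:
        self.pos -= n

    def is_terminated(self) -> bool:
        return self.pos >= self.limit  # potentially costly in reality


class CachedGrammar:
    def __init__(self, matcher: ToyMatcher) -> None:
        self.matcher = matcher
        self._is_terminated = False  # refreshed on advance and rollback

    def accept_tokens(self, tokens: list[int]) -> bool:
        if self._is_terminated:
            return False  # refuse to advance a finished grammar
        for token in tokens:
            if not self.matcher.accept_token(token):
                return False
        self._is_terminated = self.matcher.is_terminated()
        return True

    def rollback(self, n: int) -> None:
        self.matcher.rollback(n)
        self._is_terminated = self.matcher.is_terminated()

    def is_terminated(self) -> bool:
        return self._is_terminated  # cheap field read on the hot path
```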
9 changes: 7 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1278,17 +1278,22 @@ def apply_grammar_bitmask(
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask

# If the grammar bitmask and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = grammar_bitmask.shape[0] == logits.shape[0]

# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask)
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
Review comment (PR author): This is just for sanity, could someone confirm that the above numpy-indexing logic with torch.from_numpy will always create a contiguous array and not a row-wise view?

Reply (collaborator): afaik torch.from_numpy creates a view, so as long as the original numpy array is contiguous, it should be ok?

# Force use of the torch.compile implementation from xgrammar to work
# around issues with the Triton kernel in concurrent structured output
# scenarios. See PR #19565 and issues #19493, #18376 for details.
xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
logits,
grammar_bitmask.to(self.device, non_blocking=True),
indices=out_indices,
indices=out_indices if not skip_out_indices else None,
)
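
On the review thread above: numpy fancy indexing materializes a new C-contiguous array (it cannot return a view), and torch.from_numpy wraps that buffer zero-copy, so the explicit .contiguous() should be a no-op here. A quick check one can run, using plain numpy/torch and no vLLM:

```python
# Quick check of the question from the review thread: numpy fancy
# indexing materializes a new C-contiguous array (never a view), and
# torch.from_numpy shares that buffer zero-copy.
import numpy as np
import torch

src = np.arange(12, dtype=np.int32).reshape(4, 3)
reordered = src[[2, 0, 3, 1]]  # advanced indexing -> contiguous copy
assert reordered.flags["C_CONTIGUOUS"]

t = torch.from_numpy(reordered)  # zero-copy view of the numpy buffer
assert t.is_contiguous()
reordered[0, 0] = -1
assert t[0, 0] == -1  # same memory: the mutation is visible

# .contiguous() on an already-contiguous tensor returns it unchanged.
assert t.contiguous() is t
```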

def sync_and_slice_intermediate_tensors(
Expand Down