12 changes: 7 additions & 5 deletions tests/entrypoints/openai/test_guided_processors.py
@@ -6,7 +6,7 @@
 
 from vllm.entrypoints.openai.protocol import CompletionRequest
 from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    get_guided_decoding_logits_processor_factory)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)
 
@@ -47,9 +47,10 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     regex_request = CompletionRequest(model='test',
                                       prompt=token_ids,
                                       guided_regex=sample_regex)
-    regex_lp = await get_guided_decoding_logits_processor(
+    regex_lpf = await get_guided_decoding_logits_processor_factory(
         backend, regex_request, tokenizer)
-    assert regex_lp is not None
+    assert regex_lpf is not None
+    regex_lp = await regex_lpf.get_processor_async()
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     tensor = regex_lp(token_ids, tensor)
@@ -62,9 +63,10 @@
     json_request = CompletionRequest(model='test',
                                      prompt=token_ids,
                                      guided_json=sample_json_schema)
-    json_lp = await get_guided_decoding_logits_processor(
+    json_lpf = await get_guided_decoding_logits_processor_factory(
         backend, json_request, tokenizer)
-    assert json_lp is not None
+    assert json_lpf is not None
+    json_lp = await json_lpf.get_processor_async()
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     tensor = json_lp(token_ids, tensor)
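For context: these tests call `get_processor_async()` on the object returned by `get_guided_decoding_logits_processor_factory`. The `LogitsProcessorFactory` base class itself is introduced in `vllm/sampling_params.py`, which is not among the hunks shown here. A minimal sketch of the interface the tests appear to assume — the method names come from the calls above, while the async fallback is an assumption:

# Sketch of the assumed factory interface; the real definition lives in
# vllm/sampling_params.py, outside this diff.
from typing import Callable, List

import torch

# A logits processor maps (generated token ids, logits row) -> logits row.
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]


class LogitsProcessorFactory:
    """Defers construction of a potentially expensive logits processor."""

    def get_processor(self) -> LogitsProcessor:
        raise NotImplementedError

    async def get_processor_async(self) -> LogitsProcessor:
        # Assumed default: build synchronously; subclasses such as the
        # outlines factory later in this diff override this to build off
        # the event loop.
        return self.get_processor()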
4 changes: 3 additions & 1 deletion tests/test_logits_processor.py
@@ -65,11 +65,13 @@ def pick_ith(token_ids, logits):
     seq_group_metadata_list = []
     seq_lens = []
     for i in range(batch_size):
+        seq_data = {0: SequenceData([1, 2, 3])}
+        seq_data[0].logits_processors = [pick_ith]
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data=seq_data,
                 sampling_params=SamplingParams(temperature=0,
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},
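The test now builds `seq_data` first so it can attach `pick_ith` directly to `SequenceData` — the per-sequence location the sampler reads from after this PR (see the `_apply_logits_processors` hunk at the end of this diff) — while still passing it through `SamplingParams` as before. `pick_ith` is defined earlier in this test file and is not shown in the hunk; a plausible reconstruction, consistent with how the test uses it (the exact body is an assumption):

# Assumed shape of the helper named in the hunk header; the real definition
# sits earlier in tests/test_logits_processor.py, outside the lines shown.
from typing import List

import torch


def pick_ith(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    # Make the logit at index len(token_ids) dominate, so step i of a
    # sequence deterministically picks token i.
    logits[len(token_ids)] = float("inf")
    return logits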
15 changes: 14 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -458,7 +458,7 @@ async def add_request_async(
             prompt_adapter_request=prompt_adapter_request,
         )
 
-        self._add_processed_request(
+        seq_group = self._create_sequence_group(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -468,6 +468,19 @@
             trace_headers=trace_headers,
         )
 
+        if isinstance(params, SamplingParams):
+            for seq in seq_group.get_seqs():
+                seq.data.logits_processors = \
+                    await params.get_logits_processors_async()
+
+        # Add the sequence group to the scheduler with least unfinished seqs.
+        costs = [
+            scheduler.get_num_unfinished_seq_groups()
+            for scheduler in self.scheduler
+        ]
+        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
+        min_cost_scheduler.add_seq_group(seq_group)
+
     async def check_health_async(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()
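`params.get_logits_processors_async()` is not defined in any hunk shown here; presumably it ships with the `SamplingParams` changes. A sketch of the resolution step it likely performs, assuming the `LogitsProcessorFactory` interface sketched after the first test file is in scope (names and behavior are inferred, not confirmed by this diff):

# Hypothetical resolution of a mixed list of plain callables and factories,
# mirroring what SamplingParams.get_logits_processors_async() presumably does.
from typing import Callable, List, Union


async def resolve_processors(
        entries: List[Union[Callable, "LogitsProcessorFactory"]]
) -> List[Callable]:
    processors: List[Callable] = []
    for entry in entries:
        if isinstance(entry, LogitsProcessorFactory):
            # Let the factory build off the event loop if it supports it
            # (the outlines factory below offloads to a thread pool).
            processors.append(await entry.get_processor_async())
        else:
            processors.append(entry)  # already a plain callable
    return processors

Resolving the factories before `add_seq_group` means a sequence group only becomes schedulable once its processors are fully built.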
26 changes: 16 additions & 10 deletions vllm/engine/llm_engine.py
@@ -570,7 +570,7 @@ def _get_decoder_start_token_id(self) -> Optional[int]:
 
         return dec_start_token_id
 
-    def _add_processed_request(
+    def _create_sequence_group(
         self,
         request_id: str,
         processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
@@ -579,7 +579,7 @@ def _add_processed_request(
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> None:
+    ) -> SequenceGroup:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -622,13 +622,7 @@ def _add_processed_request(
             raise ValueError(
                 "Either SamplingParams or PoolingParams must be provided.")
 
-        # Add the sequence group to the scheduler with least unfinished seqs.
-        costs = [
-            scheduler.get_num_unfinished_seq_groups()
-            for scheduler in self.scheduler
-        ]
-        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
-        min_cost_scheduler.add_seq_group(seq_group)
+        return seq_group
 
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
@@ -1026,7 +1020,7 @@ def add_request(
             prompt_adapter_request=prompt_adapter_request,
         )
 
-        self._add_processed_request(
+        seq_group = self._create_sequence_group(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -1036,6 +1030,18 @@
             trace_headers=trace_headers,
         )
 
+        if isinstance(params, SamplingParams):
+            for seq in seq_group.get_seqs():
+                seq.data.logits_processors = params.get_logits_processors()
+
+        # Add the sequence group to the scheduler with least unfinished seqs.
+        costs = [
+            scheduler.get_num_unfinished_seq_groups()
+            for scheduler in self.scheduler
+        ]
+        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
+        min_cost_scheduler.add_seq_group(seq_group)
+
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,
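This is the synchronous twin of the async path above: `params.get_logits_processors()` (also outside this diff) presumably resolves factories via `get_processor()` instead of awaiting. The least-loaded dispatch that used to live in `_add_processed_request` is now repeated at both call sites, after the processors are attached. The pattern itself is a plain argmin over per-scheduler load; a self-contained illustration with a stand-in scheduler type (not vLLM's `Scheduler`):

# Standalone illustration of "pick the scheduler with the fewest unfinished
# sequence groups"; FakeScheduler is a stand-in, not vLLM's Scheduler.
from typing import List


class FakeScheduler:

    def __init__(self, unfinished: int):
        self.unfinished = unfinished

    def get_num_unfinished_seq_groups(self) -> int:
        return self.unfinished

    def add_seq_group(self, seq_group: str) -> None:
        self.unfinished += 1


schedulers: List[FakeScheduler] = [
    FakeScheduler(2), FakeScheduler(0), FakeScheduler(1)
]
costs = [s.get_num_unfinished_seq_groups() for s in schedulers]
min_cost_scheduler = schedulers[costs.index(min(costs))]  # the idle one
min_cost_scheduler.add_seq_group("new-request")
assert schedulers[1].unfinished == 1

Keeping the dispatch at the call sites, rather than in the shared helper, lets the async path await factory resolution before the group is handed to a scheduler.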
11 changes: 6 additions & 5 deletions vllm/entrypoints/openai/protocol.py
@@ -12,7 +12,8 @@
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sampling_params import (LogitsProcessor, LogitsProcessorFactory,
+                                  SamplingParams)
 from vllm.utils import random_uuid
 
 # torch is mocked during docs generation,
@@ -232,10 +233,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
-            guided_decode_logits_processor: Optional[LogitsProcessor],
-            default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(self, tokenizer: PreTrainedTokenizer,
+                           guided_decode_logits_processor: Optional[Union[
+                               LogitsProcessor, LogitsProcessorFactory]],
+                           default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -128,7 +128,8 @@ async def create_chat_completion(
         request_id = f"chat-{random_uuid()}"
         try:
             guided_decode_logits_processor = (
-                await self._guided_decode_logits_processor(request, tokenizer))
+                await self._guided_decode_logits_processor_factory(
+                    request, tokenizer))
 
             prompt_inputs = self._tokenize_prompt_input(
                 request,
3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -96,7 +96,8 @@ async def create_completion(self, request: CompletionRequest,
                                     lora_request)
 
         guided_decode_logits_processor = (
-            await self._guided_decode_logits_processor(request, tokenizer))
+            await self._guided_decode_logits_processor_factory(
+                request, tokenizer))
         prompts = list(
             self._tokenize_prompt_input_or_inputs(
                 request,
10 changes: 5 additions & 5 deletions vllm/entrypoints/openai/serving_engine.py
@@ -26,10 +26,10 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    get_guided_decoding_logits_processor_factory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sampling_params import LogitsProcessorFactory, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer_group import AnyTokenizer
 
@@ -152,13 +152,13 @@ def create_streaming_error_response(
             })
         return json_str
 
-    async def _guided_decode_logits_processor(
+    async def _guided_decode_logits_processor_factory(
             self, request: Union[ChatCompletionRequest, CompletionRequest],
-            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
+            tokenizer: AnyTokenizer) -> Optional[LogitsProcessorFactory]:
         decoding_config = await self.async_engine_client.get_decoding_config()
         guided_decoding_backend = request.guided_decoding_backend \
             or decoding_config.guided_decoding_backend
-        return await get_guided_decoding_logits_processor(
+        return await get_guided_decoding_logits_processor_factory(
             guided_decoding_backend, request, tokenizer)
 
     async def _check_model(
8 changes: 4 additions & 4 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -8,13 +8,13 @@
 from vllm.model_executor.guided_decoding.outlines_decoding import (
     get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
-from vllm.sampling_params import LogitsProcessor
+from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory
 
 
-async def get_guided_decoding_logits_processor(
+async def get_guided_decoding_logits_processor_factory(
     guided_decoding_backend: str, request: Union[CompletionRequest,
                                                  ChatCompletionRequest],
-    tokenizer) -> Optional[LogitsProcessor]:
+    tokenizer) -> Optional[LogitsProcessorFactory]:
     request = _adapt_request_for_tool_use(request)
 
     if guided_decoding_backend == 'outlines':
@@ -70,4 +70,4 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
         tool = tools[tool_name]
         request.guided_json = tool.parameters
 
-    return request
\ No newline at end of file
+    return request
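The dispatch body is truncated after the `'outlines'` branch. Judging from the imports at the top of this file, the elided remainder presumably routes each backend name to its module, roughly as follows (a reconstruction for orientation, not lines from this PR):

# Plausible shape of the elided dispatch, inferred from the imports above.
if guided_decoding_backend == 'outlines':
    return await get_outlines_guided_decoding_logits_processor(
        request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
    return await get_lm_format_enforcer_guided_decoding_logits_processor(
        request, tokenizer)
raise ValueError(
    f"Unknown guided decoding backend '{guided_decoding_backend}'")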
20 changes: 15 additions & 5 deletions vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -17,12 +17,23 @@
 from vllm.model_executor.guided_decoding.outlines_decoding import (
     get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
-from vllm.sampling_params import LogitsProcessor
+from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory
 
 
+class LMFormatDecodingLogitsProcessorFactory(LogitsProcessorFactory):
+
+    def __init__(self, tokenizer_data, character_level_parser):
+        self.tokenizer_data = tokenizer_data
+        self.character_level_parser = character_level_parser
+
+    def get_processor(self) -> LogitsProcessor:
+        return build_vllm_logits_processor(self.tokenizer_data,
+                                           self.character_level_parser)
+
+
 async def get_lm_format_enforcer_guided_decoding_logits_processor(
     request: Union[CompletionRequest, ChatCompletionRequest],
-    tokenizer) -> Optional[LogitsProcessor]:
+    tokenizer) -> Optional[LMFormatDecodingLogitsProcessorFactory]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -52,9 +63,8 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
     else:
         return None
 
-    logits_processor = build_vllm_logits_processor(tokenizer_data,
-                                                   character_level_parser)
-    return logits_processor
+    return LMFormatDecodingLogitsProcessorFactory(tokenizer_data,
+                                                  character_level_parser)
 
 
 def get_local_lm_format_enforcer_guided_decoding_logits_processor(
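Note that this factory overrides only `get_processor()`. The expensive step for lm-format-enforcer — building `tokenizer_data` — already happened before the factory is constructed, so there is presumably little to gain from an async build here; `await factory.get_processor_async()` on the engine path would fall through to the base class's assumed synchronous default:

# Assumed LogitsProcessorFactory fallback (see the interface sketch after
# the first test file); not shown anywhere in this diff.
async def get_processor_async(self) -> LogitsProcessor:
    return self.get_processor()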
43 changes: 32 additions & 11 deletions vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -3,7 +3,7 @@
 from enum import Enum
 from json import dumps as json_dumps
 from re import escape as regex_escape
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
@@ -14,6 +14,7 @@
     GuidedDecodingRequest)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
+from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory
 
 
 class GuidedDecodingMode(Enum):
@@ -53,11 +54,37 @@ class GuidedDecodingMode(Enum):
 global_thread_pool = None  # used for generating logits processor fsm
 
 
+class OutlinesDecodingLogitsProcessorFactory(LogitsProcessorFactory):
+
+    def __init__(self, guide: str, tokenizer: PreTrainedTokenizerBase,
+                 mode: GuidedDecodingMode, whitespace_pattern: Union[str,
+                                                                     None]):
+        self.guide = guide
+        self.tokenizer = tokenizer
+        self.mode = mode
+        self.whitespace_pattern = whitespace_pattern
+
+    def get_processor(self) -> LogitsProcessor:
+        return _get_logits_processor(self.guide, self.tokenizer, self.mode,
+                                     self.whitespace_pattern)
+
+    async def get_processor_async(self) -> LogitsProcessor:
+        global global_thread_pool
+        if global_thread_pool is None:
+            global_thread_pool = concurrent.futures.ThreadPoolExecutor(
+                max_workers=2)
+        loop = asyncio.get_running_loop()
+
+        return await loop.run_in_executor(global_thread_pool,
+                                          _get_logits_processor, self.guide,
+                                          self.tokenizer, self.mode,
+                                          self.whitespace_pattern)
+
+
 async def get_outlines_guided_decoding_logits_processor(
     request: Union[CompletionRequest,
                    ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
-) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
-           None]:
+) -> Optional[OutlinesDecodingLogitsProcessorFactory]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -69,14 +96,8 @@
     if not guide or not mode:
         return None
 
-    if global_thread_pool is None:
-        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
-            max_workers=2)
-    loop = asyncio.get_running_loop()
-
-    return await loop.run_in_executor(global_thread_pool,
-                                      _get_logits_processor, guide, tokenizer,
-                                      mode, request.guided_whitespace_pattern)
+    return OutlinesDecodingLogitsProcessorFactory(
+        guide, tokenizer, mode, request.guided_whitespace_pattern)
 
 
 def get_local_outlines_guided_decoding_logits_processor(
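The net effect of this file's changes: the thread-pool offload moves from the request entry point into `get_processor_async()`, so the outlines FSM for a guide is compiled only when a processor is actually materialized, and compilation still does not block the event loop. The same `run_in_executor` pattern in isolation, with a stand-in for the expensive `_get_logits_processor` build:

# Minimal, runnable demonstration of the offload pattern used above;
# slow_build stands in for outlines' FSM compilation.
import asyncio
import concurrent.futures
import time

executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)


def slow_build(guide: str) -> str:
    time.sleep(0.1)  # pretend this is an expensive FSM compilation
    return f"processor for {guide!r}"


async def main() -> None:
    loop = asyncio.get_running_loop()
    # The event loop stays responsive while the build runs on the pool.
    processor = await loop.run_in_executor(executor, slow_build, r"\d+")
    print(processor)


asyncio.run(main())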
12 changes: 6 additions & 6 deletions vllm/model_executor/layers/logits_processor.py
@@ -121,13 +121,13 @@ def _apply_logits_processors(
     logits_processed = 0
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
-        sampling_params = seq_group.sampling_params
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
 
-            for seq_id, logits_row_idx in zip(seq_ids,
-                                              seq_group.sample_indices):
+        for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices):
+            logits_processors = seq_group.seq_data[seq_id].logits_processors
+
+            if logits_processors:
+                found_logits_processors = True
+
             logits_row = logits[logits_row_idx]
             past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
             prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
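Reading processors from `seq_group.seq_data[seq_id]` rather than from the group's `sampling_params` is the change that makes processors truly per-sequence: two sequences in the same batch can now carry different processors. A standalone sketch of the new inner loop, with simplified stand-ins for vLLM's types:

# Simplified stand-ins, not vLLM's classes: each sequence applies its own
# processors to its own logits row, as in the loop above.
from typing import Callable, Dict, List

import torch


class FakeSeqData:

    def __init__(self, processors: List[Callable]):
        self.output_token_ids: List[int] = []
        self.logits_processors = processors


def ban(token: int) -> Callable:

    def proc(token_ids: List[int], row: torch.Tensor) -> torch.Tensor:
        row[token] = float("-inf")
        return row

    return proc


seq_data: Dict[int, FakeSeqData] = {
    0: FakeSeqData([ban(1)]),  # sequence 0 bans token 1
    1: FakeSeqData([ban(2)]),  # sequence 1 bans token 2
}
logits = torch.zeros(2, 4)
for row_idx, data in seq_data.items():
    for processor in data.logits_processors:
        logits[row_idx] = processor(data.output_token_ids, logits[row_idx])
assert logits[0, 1] == float("-inf") and logits[1, 2] == float("-inf")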