4 changes: 3 additions & 1 deletion pyproject.toml
@@ -76,7 +76,9 @@ exclude = [
"vllm/engine/**/*.py" = ["UP006", "UP035"]
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/lora/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/layers/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
4 changes: 1 addition & 3 deletions vllm/model_executor/custom_op.py
@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Type

import torch.nn as nn

from vllm.config import get_current_vllm_config
@@ -138,7 +136,7 @@ def default_on() -> bool:
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: Dict[str, Type['CustomOp']] = {}
op_registry: dict[str, type['CustomOp']] = {}

# Decorator to register custom ops.
@classmethod
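Because the op_registry annotation is evaluated when the class body runs, this change relies on dict and type being subscriptable at runtime (Python 3.9+); the quoted 'CustomOp' forward reference keeps working as before. A toy registry in the same style, purely illustrative and not vLLM's actual CustomOp.register implementation:

# Toy class registry in the same style; illustrative only, not vLLM's
# actual CustomOp API.
class Op:
    # dict and type are subscriptable at runtime on Python 3.9+ (PEP 585);
    # the quoted class name is an ordinary forward reference.
    registry: dict[str, type['Op']] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(op_cls: type['Op']) -> type['Op']:
            cls.registry[name] = op_cls
            return op_cls
        return decorator


@Op.register("toy_silu_and_mul")
class ToySiluAndMul(Op):
    pass


assert Op.registry["toy_silu_and_mul"] is ToySiluAndMul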
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, List
from typing import Any

import llguidance
import llguidance.hf
@@ -62,7 +62,7 @@ def _initialize(self):

def __call__(
self,
input_ids: List[int],
input_ids: list[int],
scores: torch.Tensor,
) -> torch.Tensor:
# we initialize the guidance model here
10 changes: 5 additions & 5 deletions vllm/model_executor/guided_decoding/guided_fields.py
@@ -1,16 +1,16 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Dict, List, Optional, TypedDict, Union
from typing import Optional, TypedDict, Union

from pydantic import BaseModel


# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str]
guided_json: Union[dict, BaseModel, str]
guided_regex: str
guided_choice: List[str]
guided_choice: list[str]
guided_grammar: str
guided_decoding_backend: str
guided_whitespace_pattern: str
@@ -20,9 +20,9 @@ class LLMGuidedOptions(TypedDict, total=False):
@dataclass
class GuidedDecodingRequest:
"""One of the fields will be used to retrieve the logit processor."""
guided_json: Optional[Union[Dict, BaseModel, str]] = None
guided_json: Optional[Union[dict, BaseModel, str]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
guided_choice: Optional[list[str]] = None
guided_grammar: Optional[str] = None
guided_decoding_backend: Optional[str] = None
guided_whitespace_pattern: Optional[str] = None
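Only the annotations change here, so callers keep passing plain dicts and lists; a usage sketch with a made-up schema (note the module marks these classes as deprecated in favor of SamplingParams):

# Usage sketch for the dataclass above; the schema and choices are made up.
json_request = GuidedDecodingRequest(
    guided_json={"type": "object", "properties": {"city": {"type": "string"}}},
    guided_whitespace_pattern=r"[\n\t ]*",
)
choice_request = GuidedDecodingRequest(guided_choice=["yes", "no"])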
4 changes: 2 additions & 2 deletions vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -6,7 +6,7 @@
from enum import Enum
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Optional, Tuple, Union
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

@@ -111,7 +111,7 @@ def get_local_outlines_guided_decoding_logits_processor(

def _get_guide_and_mode(
guided_params: GuidedDecodingParams
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
) -> Union[tuple[str, GuidedDecodingMode], tuple[None, None]]:
if guided_params.json:
if isinstance(guided_params.json, dict):
# turn dict into hashable string
@@ -19,7 +19,7 @@
import json
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union
from typing import Callable, Optional, Union

import numpy as np
import torch
@@ -53,10 +53,10 @@ def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide
self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
self._fsm_state: defaultdict[int, Union[int,
CFGState]] = defaultdict(int)

def __call__(self, input_ids: List[int],
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""

@@ -160,7 +160,7 @@ def __init__(

class JSONLogitsProcessor(RegexLogitsProcessor):

def __init__(self, schema: Union[str, Dict, BaseModel],
def __init__(self, schema: Union[str, dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[ReasoningParser]):
@@ -181,7 +181,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
"""
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, Dict):
elif isinstance(schema, dict):
schema_str = json.dumps(schema)
elif isinstance(schema, str):
schema_str = schema
@@ -252,11 +252,11 @@ def convert_token_to_string(token: str) -> str:
return string

def change_decoder(
decoder: Callable[[List[int]],
str]) -> Callable[[List[int]], List[str]]:
decoder: Callable[[list[int]],
str]) -> Callable[[list[int]], list[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""

def new_decoder(inp_tokens: List[int]) -> List[str]:
def new_decoder(inp_tokens: list[int]) -> list[str]:
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
and isinstance(inp_tokens[0], list)):
inp_tokens = inp_tokens[0]
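The isinstance(schema, Dict) -> isinstance(schema, dict) change is behavior-preserving: typing.Dict has been documented as deprecated since Python 3.9, and the builtin behaves the same in isinstance checks. A standalone sketch of the same schema-normalization dispatch used in JSONLogitsProcessor.__init__ above (the helper name is made up):

# Standalone sketch of the schema-normalization dispatch above; the helper
# name normalize_schema is made up for illustration.
import json
from typing import Union

from pydantic import BaseModel


def normalize_schema(schema: Union[str, dict, BaseModel]) -> str:
    if isinstance(schema, type(BaseModel)):
        # A pydantic model class: export its JSON Schema.
        return json.dumps(schema.model_json_schema())
    elif isinstance(schema, dict):
        # A schema already provided as a dict.
        return json.dumps(schema)
    elif isinstance(schema, str):
        return schema
    raise ValueError(
        f"Cannot parse schema {schema}. The schema must be a pydantic model "
        "class, a dict, or a JSON Schema string.")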
4 changes: 2 additions & 2 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -6,7 +6,7 @@
import json
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, List
from typing import TYPE_CHECKING, Any

import torch

@@ -273,7 +273,7 @@ def escape_ebnf_string(s: str) -> str:
return re.sub(r'(["\\])', r'\\\1', s)

@staticmethod
def choice_as_grammar(choice: List[str] | None) -> str:
def choice_as_grammar(choice: list[str] | None) -> str:
if choice is None:
raise ValueError("Choice is not set")
escaped_choices = (GrammarConfig.escape_ebnf_string(c) for c in choice)
8 changes: 4 additions & 4 deletions vllm/model_executor/pooling_metadata.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
from typing import Any

import torch

@@ -23,9 +23,9 @@ class PoolingMetadata:

def __init__(
self,
seq_groups: List[Tuple[List[int], PoolingParams]],
seq_data: Dict[int, Any], # Specific data related to sequences
prompt_lens: List[int],
seq_groups: list[tuple[list[int], PoolingParams]],
seq_data: dict[int, Any], # Specific data related to sequences
prompt_lens: list[int],
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
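A construction sketch for the updated signature; the ids and lengths are made up, and PoolingParams() is assumed to be constructible with defaults:

# Construction sketch for the signature above; values are made up and
# PoolingParams() is assumed to accept defaults.
metadata = PoolingMetadata(
    seq_groups=[([0, 1], PoolingParams())],  # list[tuple[list[int], PoolingParams]]
    seq_data={0: None, 1: None},             # dict[int, Any]
    prompt_lens=[7, 11],
)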
88 changes: 44 additions & 44 deletions vllm/model_executor/sampling_metadata.py
@@ -2,7 +2,7 @@

from array import array
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Optional

import torch

@@ -25,10 +25,10 @@ class SequenceGroupToSample:
# |-- query_len ---|

# Sequence ids for the sequence group in a previous step.
seq_ids: List[int]
seq_ids: list[int]
sampling_params: SamplingParams
# seq_id -> sequence data.
seq_data: Dict[int, SequenceData]
seq_data: dict[int, SequenceData]
# The length of the sequence (all tokens seen in the past + new token to
# compute attention) of the sequence group. None if it is in a decode
# stage.
@@ -44,9 +44,9 @@ class SequenceGroupToSample:
is_prompt: bool
# Query token indices from logits. to compute prompt logprob. Empty if
# prompt logprob is not required.
prompt_logprob_indices: List[int]
prompt_logprob_indices: list[int]
# Sample token indices from logits. Empty if sampling is not required.
sample_indices: List[int]
sample_indices: list[int]

@property
def do_sample(self):
@@ -78,7 +78,7 @@ class SamplingMetadataCache:
"""Used to cache SamplingMetadata objects between scheduler iterations"""

def __init__(self):
self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}

def get_cached_seq_group_to_sample(self, num_seqs):
if num_seqs not in self._seq_group_to_sample_cache:
@@ -130,9 +130,9 @@ def sample(logits):

def __init__(
self,
seq_groups: List[SequenceGroupToSample],
seq_groups: list[SequenceGroupToSample],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
categorized_sample_indices: dict[SamplingType, torch.Tensor],
num_prompts: int,
skip_sampler_cpu_output: bool = False,
reuse_sampling_tensors: bool = False,
@@ -146,12 +146,12 @@ def __init__(

@staticmethod
def prepare(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: List[int],
seq_group_metadata_list: list[SequenceGroupMetadata],
seq_lens: list[int],
query_lens: list[int],
device: str,
pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None,
generators: Optional[dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> "SamplingMetadata":
(
@@ -195,16 +195,16 @@ def __repr__(self) -> str:


def _prepare_seq_groups(
seq_group_metadata_list: List[SequenceGroupMetadata],
seq_lens: List[int],
query_lens: List[int],
seq_group_metadata_list: list[SequenceGroupMetadata],
seq_lens: list[int],
query_lens: list[int],
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
generators: Optional[dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[
List[SequenceGroupToSample],
List[int],
Dict[SamplingType, List[int]],
) -> tuple[
list[SequenceGroupToSample],
list[int],
dict[SamplingType, list[int]],
int,
]:
"""Prepare sequence groups and indices for sampling.
@@ -227,17 +227,17 @@ def _prepare_seq_groups(
num_prompts: Total number of prompts from `seq_group_metadata_list`.
"""
# Batched sequence groups for the current model forward step.
seq_groups: List[SequenceGroupToSample] = []
seq_groups: list[SequenceGroupToSample] = []
# A list of token indices to sample/compute logprob. It is used to
# prune the outcome logits from the model for the performance.
selected_token_indices: List[int] = []
selected_token_indices: list[int] = []
# Used for selected_token_indices.
model_output_idx = 0

# Sampling type -> (
# indices to sample/prompt logprob within pruned output logits,
# indices to sample within pruned logits)
categorized_sample_indices: Dict[SamplingType, List[int]] = {
categorized_sample_indices: dict[SamplingType, list[int]] = {
t: []
for t in SamplingType
}
@@ -265,9 +265,9 @@ def _prepare_seq_groups(
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
if cache is not None else [])
sample_indices: List[int] = (sample_obj.sample_indices
sample_indices: list[int] = (sample_obj.sample_indices
if cache is not None else [])
do_sample = seq_group_metadata.do_sample

@@ -389,16 +389,16 @@ def from_sampling_metadata(
vocab_size: int,
device: torch.device,
dtype: torch.dtype,
) -> Tuple["SamplingTensors", bool, bool, bool]:
prompt_tokens: List[array] = []
output_tokens: List[array] = []
top_ks: List[int] = []
temperatures: List[float] = []
top_ps: List[float] = []
min_ps: List[float] = []
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
) -> tuple["SamplingTensors", bool, bool, bool]:
prompt_tokens: list[array] = []
output_tokens: list[array] = []
top_ks: list[int] = []
temperatures: list[float] = []
top_ps: list[float] = []
min_ps: list[float] = []
presence_penalties: list[float] = []
frequency_penalties: list[float] = []
repetition_penalties: list[float] = []
do_penalties = False
do_top_p_top_k = False
do_min_p = False
@@ -496,15 +496,15 @@ def from_sampling_metadata(
@classmethod
def from_lists(
cls,
temperatures: List[float],
top_ps: List[float],
top_ks: List[int],
min_ps: List[float],
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
prompt_tokens: List[array],
output_tokens: List[array],
temperatures: list[float],
top_ps: list[float],
top_ks: list[int],
min_ps: list[float],
presence_penalties: list[float],
frequency_penalties: list[float],
repetition_penalties: list[float],
prompt_tokens: list[array],
output_tokens: list[array],
vocab_size: int,
device: torch.device,
dtype: torch.dtype,
4 changes: 2 additions & 2 deletions vllm/model_executor/utils.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""Utils for model executor."""
from typing import Any, Dict, Optional
from typing import Any, Optional

import torch

@@ -12,7 +12,7 @@ def set_random_seed(seed: int) -> None:

def set_weight_attrs(
weight: torch.Tensor,
weight_attrs: Optional[Dict[str, Any]],
weight_attrs: Optional[dict[str, Any]],
):
"""Set attributes on a weight tensor.

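The body of set_weight_attrs is outside this hunk; as a rough sketch of what a helper with this signature typically does (an assumption, not necessarily vLLM's exact implementation):

# Rough sketch of a helper with the signature above; an assumption, not
# necessarily vLLM's exact implementation.
from typing import Any, Optional

import torch


def set_weight_attrs_sketch(
    weight: torch.Tensor,
    weight_attrs: Optional[dict[str, Any]],
) -> None:
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        # Attach each attribute (e.g. a weight-loading callback) to the tensor;
        # torch.Tensor instances accept arbitrary Python attributes.
        setattr(weight, key, value)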