Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
from vllm.transformers_utils.processors.pixtral import (
MistralCommonImageProcessor,
MistralCommonPixtralProcessor,
)
from vllm.utils.collection_utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape

Expand Down Expand Up @@ -128,18 +131,20 @@ def get_tokenizer(self) -> MistralTokenizer:

return tokenizer

def get_image_processor(self) -> MistralCommonImageProcessor:
return MistralCommonImageProcessor(self.get_tokenizer().instruct.mm_encoder)

def get_hf_processor(self, **kwargs) -> MistralCommonPixtralProcessor:
return self.ctx.init_processor(
MistralCommonPixtralProcessor,
return MistralCommonPixtralProcessor(
tokenizer=self.get_tokenizer(),
**kwargs,
image_processor=self.get_image_processor(),
)

def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}

def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_hf_processor().image_processor
image_processor = self.get_image_processor()
max_image_size = image_processor.mm_encoder.mm_config.max_image_size

return ImageSize(width=max_image_size, height=max_image_size)
Expand Down
27 changes: 18 additions & 9 deletions vllm/model_executor/models/voxtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
from vllm.transformers_utils.processors.voxtral import (
MistralCommonFeatureExtractor,
MistralCommonVoxtralProcessor,
)
from vllm.utils.collection_utils import is_list_of

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
Expand Down Expand Up @@ -84,15 +87,19 @@ def get_tokenizer(self) -> MistralTokenizer:

return tokenizer

def get_feature_extractor(self) -> MistralCommonFeatureExtractor:
return MistralCommonFeatureExtractor(
self.get_tokenizer().instruct.audio_encoder
)

def get_hf_processor(self, **kwargs) -> MistralCommonVoxtralProcessor:
return self.ctx.init_processor(
MistralCommonVoxtralProcessor,
return MistralCommonVoxtralProcessor(
tokenizer=self.get_tokenizer(),
**kwargs,
feature_extractor=self.get_feature_extractor(),
)

def get_data_parser(self):
feature_extractor = self.get_hf_processor().feature_extractor
feature_extractor = self.get_feature_extractor()

return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
Expand All @@ -114,7 +121,7 @@ def get_max_audio_tokens(self) -> int:
return self.ctx.model_config.max_model_len

def get_max_audio_array_len(self) -> int:
feature_extractor = self.get_hf_processor().feature_extractor
feature_extractor = self.get_feature_extractor()

return self.get_max_audio_tokens() * int(
feature_extractor.sampling_rate // feature_extractor.frame_rate
Expand Down Expand Up @@ -153,7 +160,7 @@ def get_dummy_processor_inputs(
mm_data: MultiModalDataDict | None = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
feature_extractor = self.info.get_hf_processor().feature_extractor
feature_extractor = self.info.get_feature_extractor()

dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = (
Expand Down Expand Up @@ -480,8 +487,10 @@ def get_num_audio_tokens(
This is used for estimating the amount of processing for this audio.
"""
tokenizer = cached_tokenizer_from_config(model_config)
adapter = MistralCommonVoxtralProcessor(tokenizer)
return adapter.feature_extractor.get_num_audio_tokens(
feature_extractor = MistralCommonFeatureExtractor(
tokenizer.instruct.audio_encoder
)
return feature_extractor.get_num_audio_tokens(
int(audio_duration_s * stt_config.sample_rate)
)

Expand Down
10 changes: 6 additions & 4 deletions vllm/transformers_utils/processors/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,18 @@ def get_number_of_image_patches(
class MistralCommonPixtralProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]

def __init__(self, tokenizer: MistralTokenizer) -> None:
def __init__(
Comment thread
allgather marked this conversation as resolved.
self,
tokenizer: MistralTokenizer,
image_processor: MistralCommonImageProcessor,
) -> None:
self.tokenizer = tokenizer.transformers_tokenizer

# Back-compatibility for Transformers v4
if not hasattr(self.tokenizer, "init_kwargs"):
self.tokenizer.init_kwargs = {}

self.image_processor = MistralCommonImageProcessor(
tokenizer.instruct.mm_encoder
)
self.image_processor = image_processor

image_special_ids = self.image_processor.mm_encoder.special_ids
self.image_break_id = image_special_ids.img_break
Expand Down
10 changes: 6 additions & 4 deletions vllm/transformers_utils/processors/voxtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,18 @@ def get_num_audio_tokens(self, audio_length: int) -> int:
class MistralCommonVoxtralProcessor(ProcessorMixin):
attributes = ["feature_extractor", "tokenizer"]

def __init__(self, tokenizer: MistralTokenizer) -> None:
def __init__(
self,
tokenizer: MistralTokenizer,
feature_extractor: MistralCommonFeatureExtractor,
) -> None:
self.tokenizer = tokenizer.transformers_tokenizer

# Back-compatibility for Transformers v4
if not hasattr(self.tokenizer, "init_kwargs"):
self.tokenizer.init_kwargs = {}

self.feature_extractor = MistralCommonFeatureExtractor(
tokenizer.instruct.audio_encoder
)
self.feature_extractor = feature_extractor

audio_special_ids = self.feature_extractor.audio_encoder.special_ids
self.audio_token_id = audio_special_ids.audio
Expand Down
Loading