57 changes: 42 additions & 15 deletions vllm/config.py
@@ -263,6 +263,10 @@ class ModelConfig:
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data items per modality
per prompt. Only applicable for multimodal models.
mm_processor_kwargs: Overrides for the multi-modal processor obtained
from `AutoProcessor.from_pretrained`.
disable_mm_preprocessor_cache: If True, disable caching of the
processed multi-modal inputs.
use_async_output_proc: Whether to use async output processor.
Defaults to True.
config_format: The config format which shall be loaded.
@@ -273,10 +277,6 @@ class ModelConfig:
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
disable_mm_preprocessor_cache: If true, then disables caching of the
multi-modal preprocessor/mapper. (not recommended)
override_neuron_config: Initialize a non-default neuron config or
override the default neuron config specific to Neuron devices;
this argument will be used to configure the neuron config that
@@ -320,7 +320,6 @@ def compute_hash(self) -> str:
factors.append(self.max_logprobs)
factors.append(self.disable_sliding_window)
factors.append(self.trust_remote_code)
factors.append(self.mm_processor_kwargs)
DarkLight1337 (Member, Author) commented on Apr 29, 2025:
Not sure why this was included in the first place. `mm_processor_kwargs` can be overridden per request anyway.
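For reference, a minimal sketch of the per-request override the comment refers to, assuming a multimodal model that accepts `num_crops` (such as Phi-3-Vision; the model name and prompt template are illustrative):

```python
from PIL import Image

from vllm import LLM

llm = LLM(model="microsoft/Phi-3-vision-128k-instruct")
image = Image.new("RGB", (336, 336))  # placeholder image for illustration

# The override applies to this request only, so the config-level value
# does not need to participate in the engine-wide hash.
outputs = llm.generate({
    "prompt": "<|user|>\n<|image_1|>\nDescribe the image.<|end|>\n<|assistant|>\n",
    "multi_modal_data": {"image": image},
    "mm_processor_kwargs": {"num_crops": 16},
})
```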

factors.append(self.generation_config)
factors.append(self.model_impl)
factors.append(self.override_generation_config)
@@ -359,12 +358,12 @@ def __init__(
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
@@ -469,8 +468,6 @@ def __init__(
self.model, hf_token=hf_token, revision=revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache

# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
@@ -515,7 +512,10 @@ def __init__(
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
limit_mm_per_prompt)
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_kwargs=mm_processor_kwargs,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()

@@ -581,14 +581,27 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
self.tokenizer = s3_tokenizer.dir

def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[dict[str, int]]
self,
limit_mm_per_prompt: Optional[dict[str, int]],
mm_processor_kwargs: Optional[dict[str, Any]],
disable_mm_preprocessor_cache: bool,
) -> Optional["MultiModalConfig"]:
if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
return MultiModalConfig(
limit_per_prompt=limit_mm_per_prompt or {},
mm_processor_kwargs=mm_processor_kwargs or {},
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)

if limit_mm_per_prompt:
raise ValueError("`limit_mm_per_prompt` is only supported for "
"multimodal models.")
if mm_processor_kwargs:
raise ValueError("`mm_processor_kwargs` is only supported for "
"multimodal models.")
if disable_mm_preprocessor_cache:
raise ValueError("`disable_mm_preprocessor_cache` is only "
"supported for multimodal models.")

return None

@@ -2776,7 +2789,23 @@ class MultiModalConfig:
Defaults to 1 (V0) or 999 (V1) for each modality.

For example, to allow up to 16 images and 2 videos per prompt:
``{"images": 16, "videos": 2}``
:code:`{"images": 16, "videos": 2}`
"""

mm_processor_kwargs: Optional[dict[str, object]] = None
"""
Overrides for the multi-modal processor obtained from
:meth:`transformers.AutoProcessor.from_pretrained`.

The available overrides depend on the model that is being run.

For example, for Phi-3-Vision:
:code:`{"num_crops": 4}`.
"""

disable_mm_preprocessor_cache: bool = False
"""
If :code:`True`, disable caching of the processed multi-modal inputs.
"""

def compute_hash(self) -> str:
@@ -4080,8 +4109,6 @@ def __str__(self):
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, " # noqa
f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
f"pooler_config={self.model_config.pooler_config!r}, "
f"compilation_config={self.compilation_config!r}")

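A short sketch of the effect of the validation added in `_init_multimodal_config` above (model names are illustrative):

```python
from vllm import LLM

# Accepted: the model is multimodal, so the setting is stored on the new
# MultiModalConfig fields.
llm = LLM(
    model="microsoft/Phi-3-vision-128k-instruct",
    mm_processor_kwargs={"num_crops": 4},
)

# Rejected: a multimodal-only argument on a text-only model now raises
# ValueError: `mm_processor_kwargs` is only supported for multimodal models.
llm = LLM(
    model="facebook/opt-125m",
    mm_processor_kwargs={"num_crops": 4},
)
```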
16 changes: 4 additions & 12 deletions vllm/engine/arg_utils.py
@@ -672,20 +672,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
)
multimodal_group.add_argument('--limit-mm-per-prompt',
**multimodal_kwargs["limit_per_prompt"])

parser.add_argument(
multimodal_group.add_argument(
'--mm-processor-kwargs',
default=None,
type=json.loads,
help=('Overrides for the multi-modal processor obtained from '
'``AutoProcessor.from_pretrained``. The available overrides '
'depend on the model that is being run.'
'For example, for Phi-3-Vision: ``{"num_crops": 4}``.'))
parser.add_argument(
**multimodal_kwargs["mm_processor_kwargs"])
multimodal_group.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If True, disable caching of the processed multi-modal '
'inputs.')
**multimodal_kwargs["disable_mm_preprocessor_cache"])

# LoRA related configs
lora_kwargs = get_kwargs(LoRAConfig)
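With the flags now registered through `get_kwargs(MultiModalConfig)`, both invocation styles below should remain equivalent; a sketch (the JSON value is parsed with `json.loads`, as in the registration above):

```python
# CLI form:
#   vllm serve microsoft/Phi-3-vision-128k-instruct \
#       --mm-processor-kwargs '{"num_crops": 4}' \
#       --disable-mm-preprocessor-cache
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="microsoft/Phi-3-vision-128k-instruct",
    mm_processor_kwargs={"num_crops": 4},
    disable_mm_preprocessor_cache=True,
)
```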
6 changes: 4 additions & 2 deletions vllm/inputs/registry.py
@@ -101,7 +101,8 @@ def init_processor(
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
base_kwargs = self.model_config.mm_processor_kwargs
mm_config = self.model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

@@ -139,7 +140,8 @@ def call_hf_processor(
"""
assert callable(hf_processor)

base_kwargs = self.model_config.mm_processor_kwargs
mm_config = self.model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

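A minimal sketch of the merge semantics assumed in both call sites above: keyword arguments supplied at call time take precedence over the config-level base kwargs:

```python
# Config-level defaults, now read from MultiModalConfig.mm_processor_kwargs.
base_kwargs = {"num_crops": 4}

# Per-call kwargs win over the base values.
call_kwargs = {"num_crops": 16}

merged = {**base_kwargs, **call_kwargs}
assert merged == {"num_crops": 16}
```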
5 changes: 3 additions & 2 deletions vllm/model_executor/models/qwen2_vl.py
@@ -774,8 +774,9 @@ def _get_image_processor_kwargs(
size: Optional[dict[str, int]] = None,
**kwargs: object,
):
if self.ctx.model_config.mm_processor_kwargs:
kwargs.update(self.ctx.model_config.mm_processor_kwargs)
mm_config = self.ctx.model_config.get_multimodal_config()
if mm_config.mm_processor_kwargs:
kwargs.update(mm_config.mm_processor_kwargs)

if min_pixels is not None:
kwargs["min_pixels"] = min_pixels
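For example, the Qwen2-VL path above lets resolution bounds come in either through the explicit parameters or via `mm_processor_kwargs`; a hedged usage sketch (`max_pixels` is assumed to be accepted through `**kwargs`, and the pixel values are illustrative):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    mm_processor_kwargs={
        "min_pixels": 256 * 28 * 28,   # illustrative lower bound
        "max_pixels": 1280 * 28 * 28,  # illustrative upper bound
    },
)
```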
3 changes: 2 additions & 1 deletion vllm/multimodal/registry.py
@@ -262,7 +262,8 @@ def create_processor(
if tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
if disable_cache is None:
disable_cache = model_config.disable_mm_preprocessor_cache
mm_config = model_config.get_multimodal_config()
disable_cache = mm_config.disable_mm_preprocessor_cache

model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
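A sketch of the fallback above: an explicit `disable_cache` argument still wins, and only a `None` default is resolved from the multimodal config:

```python
from vllm.multimodal import MULTIMODAL_REGISTRY

# `model_config` is assumed to be a multimodal ModelConfig built elsewhere.
processor = MULTIMODAL_REGISTRY.create_processor(model_config)

# Explicit override; the config value is never consulted.
processor = MULTIMODAL_REGISTRY.create_processor(model_config,
                                                 disable_cache=True)
```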
3 changes: 2 additions & 1 deletion vllm/transformers_utils/processor.py
@@ -33,7 +33,8 @@ def __hash__(self) -> int: # type: ignore[override]


def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
base_kwargs = model_config.mm_processor_kwargs
mm_config = model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

5 changes: 4 additions & 1 deletion vllm/v1/engine/mm_input_cache.py
@@ -33,7 +33,10 @@
class MirroredProcessingCache:

def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
# Use the cache unless this is a multimodal model whose config
# explicitly disables it.
mm_config = model_config.multimodal_config
disable_mm_preprocessor_cache = mm_config is not None and \
mm_config.disable_mm_preprocessor_cache
self.use_cache = not disable_mm_preprocessor_cache
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
MultiModalKwargs)

3 changes: 1 addition & 2 deletions vllm/v1/engine/processor.py
@@ -51,8 +51,7 @@ def __init__(
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)

# Multi-modal hasher (for images)
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
self.use_hash = self.mm_input_cache_client.use_cache or \
self.cache_config.enable_prefix_caching

def _validate_logprobs(
Expand Down