diff --git a/requirements/common.txt b/requirements/common.txt index c5eb6dab955d..0a4b27c03447 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,7 +7,7 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.53.2 +transformers >= 4.55.0 huggingface-hub[hf_xet] >= 0.33.0 # Required for Xet downloads. tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. diff --git a/requirements/test.in b/requirements/test.in index 9ecaaae92727..9c8c75dd6f70 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -35,7 +35,7 @@ opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.53.2 +transformers==4.55.0 tokenizers==0.21.1 huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. diff --git a/requirements/test.txt b/requirements/test.txt index 691420df87c4..08ba964f22a4 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -214,7 +214,7 @@ fiona==1.10.1 # via torchgeo flask==3.1.1 # via mlflow -fonttools==4.54.1 +fonttools==4.55.0 # via matplotlib fqdn==1.5.1 # via jsonschema @@ -286,7 +286,7 @@ httpx==0.27.2 # via # -r requirements/test.in # schemathesis -huggingface-hub==0.33.1 +huggingface-hub==0.34.3 # via # -r requirements/test.in # accelerate @@ -1148,7 +1148,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.53.2 +transformers==4.55.0 # via # -r requirements/test.in # genai-perf diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 8cb826c1144d..2a65d7e244d7 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -337,6 +337,10 @@ vllm_output_post_proc=model_utils.fuyu_vllm_to_hf_output, num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + # FIXME(Isotr0py): This model is broken in Transformers v4.54.1, we + # should enable this again after the fix is released: + # https://github.com/huggingface/transformers/pull/39915 + marks=[pytest.mark.skip("HF model is broken")], ), "gemma3": VLMTestInfo( models=["google/gemma-3-4b-it"], diff --git a/tests/models/registry.py b/tests/models/registry.py index 47057d32e9cd..92a719d7a92d 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -179,8 +179,7 @@ def check_available_online( min_transformers_version="4.54"), "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), - "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base", - min_transformers_version="4.53"), + "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"), @@ -223,7 +222,10 @@ def check_available_online( trust_remote_code=True), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", - extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 + extras={ + "tiny": "ai21labs/Jamba-tiny-dev", + "random": "ai21labs/Jamba-tiny-random", # noqa: E501 + }), "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501 @@ -239,8 +241,7 @@ def check_available_online( trust_remote_code=True), "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", trust_remote_code=True), - "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf", - min_transformers_version="4.53"), + "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf"), "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", trust_remote_code=True, revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501 @@ -272,6 +273,8 @@ def check_available_online( "PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", + max_transformers_version="4.53", + transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501 trust_remote_code=True), "QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat", trust_remote_code=True), @@ -299,8 +302,7 @@ def check_available_online( "Zamba2ForCausalLM": _HfExamplesInfo("Zyphra/Zamba2-7B-instruct"), "MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True), - "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst", - min_transformers_version="4.53"), + "Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"), # [Encoder-decoder] "BartModel": _HfExamplesInfo("facebook/bart-base"), "BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"), @@ -326,8 +328,12 @@ def check_available_online( "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe", trust_remote_code=True, v0_only=True), # noqa: E501 "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), - "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), - "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), + "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 + "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B", + max_transformers_version="4.53", + transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 diff --git a/tests/quantization/test_experts_int8.py b/tests/quantization/test_experts_int8.py index 84a656a3b9da..1e3e69e008bd 100644 --- a/tests/quantization/test_experts_int8.py +++ b/tests/quantization/test_experts_int8.py @@ -9,6 +9,8 @@ from tests.quantization.utils import is_quant_method_supported +from ..models.registry import HF_EXAMPLE_MODELS + MODELS = ["ai21labs/Jamba-tiny-random", "pfnet/plamo-2-1b"] @@ -25,6 +27,8 @@ def test_model_experts_int8_startup( dtype: str, max_tokens: int, ) -> None: + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_transformers_version(on_fail="skip") with vllm_runner(model, dtype=dtype, quantization="experts_int8") as vllm_model: diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 4d68227b2af8..697fa020deb4 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, +from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol, Union, overload, runtime_checkable) import torch @@ -14,6 +14,10 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.sampling_metadata import SamplingMetadata +else: + VllmConfig = Any + Pooler = Any + SamplingMetadata = Any logger = init_logger(__name__) @@ -34,7 +38,7 @@ class VllmModel(Protocol[T_co]): def __init__( self, - vllm_config: "VllmConfig", + vllm_config: VllmConfig, prefix: str = "", ) -> None: ... @@ -96,7 +100,7 @@ class VllmModelForTextGeneration(VllmModel[T], Protocol[T]): def compute_logits( self, hidden_states: T, - sampling_metadata: "SamplingMetadata", + sampling_metadata: SamplingMetadata, ) -> Optional[T]: """Return `None` if TP rank > 0.""" ... @@ -140,7 +144,7 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ - pooler: "Pooler" + pooler: Pooler """The pooler is only called on TP rank 0.""" diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 40d77312b72c..633f8598e879 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1395,11 +1395,12 @@ def __init__( **kwargs, ): self.image_processor = Tarsier2ImageProcessor(**vision_config) - super().__init__(image_processor=self.image_processor, - tokenizer=tokenizer, - video_processor=Qwen2VLVideoProcessor(), - chat_template=None, - **kwargs) + super().__init__( + image_processor=self.image_processor, + tokenizer=tokenizer, + video_processor=Qwen2VLVideoProcessor(**vision_config), + chat_template=None, + **kwargs) class Tarsier2ProcessingInfo(Qwen2VLProcessingInfo): diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 5059d1e1d9fe..0c3df267edb1 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -90,7 +90,7 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module): def replace_linear_class( linear: nn.Linear, style: Literal["colwise", "rowwise"], quant_config: QuantizationConfig -) -> Union[ColumnParallelLinear, RowParallelLinear]: +) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. @@ -445,7 +445,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Set correct attn and init on "meta" to delay allocating GPU tensors # TODO: @raushan, use the public `model.set_attn_implementation()` - # method after v4.54.0 is released + # method once its checks are fixed in Transformers. self.text_config._attn_implementation = "vllm" with init_on_device_without_buffers("meta"), config_override: self.model: PreTrainedModel = AutoModel.from_config( @@ -520,7 +520,7 @@ def pipeline_parallel(self): for i in range(len(layers)): if start_layer <= i and i < end_layer: continue - layers[i] = PPMissingLayer(return_tuple=True) + layers[i] = PPMissingLayer() # Layers after module list for name in pp_plan[module_list_idx + 1:]: @@ -533,14 +533,16 @@ def tensor_parallel(self): Apply the model's tensor parallelization plan. Currently only supports linear layers. """ - if not self.model.supports_tp_plan: - if self.tp_size <= 1: - return + tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {} + if not tp_plan and self.tp_size > 1: raise ValueError( f"{type(self.model)} does not support tensor parallel yet!") - tp_plan = self.model._tp_plan + # Some weight loaders expect linear layers to inherit from vLLM's + # LinearBase class, so we set a default style which causes any + # unspecified linear layers to be replaced with ReplicatedLinear + tp_plan[".*"] = "replicated" def _tensor_parallel(module: nn.Module, prefix: str = ""): for child_name, child_module in module.named_children(): @@ -552,6 +554,7 @@ def _tensor_parallel(module: nn.Module, prefix: str = ""): child_module, style, self.quant_config) setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) + break else: _tensor_parallel(child_module, prefix=qual_name) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 28508e1bac1e..fecd14dde4a8 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -534,16 +534,10 @@ class PPMissingLayer(torch.nn.Identity): def __init__(self, *args, **kwargs): super().__init__() - self.return_tuple = kwargs.get("return_tuple", False) def forward(self, *args, **kwargs): - """ - Return the first arg from args or the first value from kwargs. - - Wraps the input in a tuple if `self.return_tuple` is True. - """ - input = args[0] if args else next(iter(kwargs.values())) - return (input, ) if self.return_tuple else input + """Return the first arg from args or the first value from kwargs.""" + return args[0] if args else next(iter(kwargs.values())) _CPU_OFFLOAD_BYTES = 0 diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8fe153464d36..bce24ef74cde 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -35,7 +35,8 @@ MllamaConfig, MLPSpeculatorConfig, Nemotron_Nano_VL_Config, NemotronConfig, NVLM_D_Config, - RWConfig, SpeculatorsConfig, + OvisConfig, RWConfig, + SpeculatorsConfig, Step3TextConfig, Step3VLConfig, UltravoxConfig) # yapf: enable @@ -85,6 +86,7 @@ def _get_hf_token() -> Optional[str]: "speculators": SpeculatorsConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, + "ovis": OvisConfig, "ultravox": UltravoxConfig, "step3_vl": Step3VLConfig, "step3_text": Step3TextConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 64ace167a5a0..82d24bb16ba5 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -24,6 +24,7 @@ from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config +from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig, Step3VisionEncoderConfig, @@ -45,6 +46,7 @@ "NemotronHConfig", "Nemotron_Nano_VL_Config", "NVLM_D_Config", + "OvisConfig", "SpeculatorsConfig", "UltravoxConfig", "Step3VLConfig", diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py new file mode 100644 index 000000000000..550f5e15dbcc --- /dev/null +++ b/vllm/transformers_utils/configs/ovis.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# yapf: disable +# ruff: noqa: E501 +# adapted from https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_aimv2.py +# and https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/configuration_ovis.py +# Ovis Config with AimV2 config registration removed for Transformers compatibility +from typing import Any, Optional, Union + +from transformers import AutoConfig, PretrainedConfig + + +class AIMv2Config(PretrainedConfig): + """This is the configuration class to store the configuration of an [`AIMv2Model`]. + Instantiating a configuration with the defaults will yield a similar configuration + to that of the [apple/aimv2-large-patch14-224](https://huggingface.co/apple/aimv2-large-patch14-224). + Args: + hidden_size: Dimension of the hidden representations. + intermediate_size: Dimension of the SwiGLU representations. + num_hidden_layers: Number of hidden layers in the Transformer. + num_attention_heads: Number of attention heads for each attention layer + in the Transformer. + num_channels: Number of input channels. + image_size: Image size. + patch_size: Patch size. + rms_norm_eps: Epsilon value used for the RMS normalization layer. + attention_dropout: Dropout ratio for attention probabilities. + projection_dropout: Dropout ratio for the projection layer after the attention. + qkv_bias: Whether to add a bias to the queries, keys and values. + use_bias: Whether to add a bias in the feed-forward and projection layers. + kwargs: Keyword arguments for the [`PretrainedConfig`]. + """ + + model_type: str = "aimv2" + + def __init__( + self, + hidden_size: int = 1024, + intermediate_size: int = 2816, + num_hidden_layers: int = 24, + num_attention_heads: int = 8, + num_channels: int = 3, + image_size: int = 224, + patch_size: int = 14, + rms_norm_eps: float = 1e-5, + attention_dropout: float = 0.0, + projection_dropout: float = 0.0, + qkv_bias: bool = False, + use_bias: bool = False, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.rms_norm_eps = rms_norm_eps + + self.projection_dropout = projection_dropout + self.qkv_bias = qkv_bias + self.use_bias = use_bias + + +# ---------------------------------------------------------------------- +# Visual Tokenizer Configuration +# ---------------------------------------------------------------------- +class BaseVisualTokenizerConfig(PretrainedConfig): + + def __init__(self, + vocab_size=16384, + tokenize_function="softmax", + tau=1.0, + depths=None, + drop_cls_token=False, + backbone_config: Optional[Union[PretrainedConfig, + dict]] = None, + hidden_stride: int = 1, + **kwargs): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.tokenize_function = tokenize_function + self.tau = tau + if isinstance(depths, str): + depths = [int(x) for x in depths.split('|')] + self.depths = depths + self.backbone_kwargs = dict[str, Any]() + self.drop_cls_token = drop_cls_token + if backbone_config is not None: + assert isinstance(backbone_config, (PretrainedConfig, dict)), \ + f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type" + if not isinstance(backbone_config, PretrainedConfig): + model_type = backbone_config['model_type'] + if model_type != "aimv2": + backbone_config.pop('model_type') + backbone_config = AutoConfig.for_model(model_type, **backbone_config) + else: + backbone_config = AIMv2Config(**backbone_config) + self.backbone_config = backbone_config + self.hidden_stride = hidden_stride + + +class Aimv2VisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "aimv2_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig): + model_type = "siglip_visual_tokenizer" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.drop_cls_token: + self.drop_cls_token = False + if self.depths: + assert len(self.depths) == 1 + self.backbone_kwargs['num_hidden_layers'] = self.depths[0] + + +AutoConfig.register("siglip_visual_tokenizer", SiglipVisualTokenizerConfig) +AutoConfig.register("aimv2_visual_tokenizer", Aimv2VisualTokenizerConfig) + + +# ---------------------------------------------------------------------- +# Ovis Configuration +# ---------------------------------------------------------------------- +class OvisConfig(PretrainedConfig): + model_type = "ovis" + + def __init__(self, + llm_config: Optional[Union[PretrainedConfig, dict]] = None, + visual_tokenizer_config: Optional[Union[PretrainedConfig, + dict]] = None, + multimodal_max_length=8192, + hidden_size=None, + conversation_formatter_class=None, + llm_attn_implementation=None, + disable_tie_weight=False, + **kwargs): + super().__init__(**kwargs) + if llm_config is not None: + assert isinstance(llm_config, (PretrainedConfig, dict)), \ + f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type" + if not isinstance(llm_config, PretrainedConfig): + model_type = llm_config['model_type'] + llm_config.pop('model_type') + llm_config = AutoConfig.for_model(model_type, **llm_config) + + # map llm_config to text_config + self.text_config = llm_config + if visual_tokenizer_config is not None: + assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \ + f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type" + if not isinstance(visual_tokenizer_config, PretrainedConfig): + model_type = visual_tokenizer_config['model_type'] + visual_tokenizer_config.pop('model_type') + visual_tokenizer_config = AutoConfig.for_model( + model_type, **visual_tokenizer_config) + + self.visual_tokenizer_config = visual_tokenizer_config + self.multimodal_max_length = multimodal_max_length + self.hidden_size = hidden_size + self.conversation_formatter_class = conversation_formatter_class + self.llm_attn_implementation = llm_attn_implementation + self.disable_tie_weight = disable_tie_weight