Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions tests/config/test_model_arch_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,61 @@ def test_draft_model_arch_config(
_assert_model_config_methods(
model_config, expected, check_head_size=check_head_size
)


@pytest.mark.parametrize(
    ("model_type", "architectures", "expected_architecture"),
    [
        (
            "qwen3_5_text",
            ["Qwen3_5ForConditionalGeneration"],
            "Qwen3_5ForCausalLM",
        ),
        (
            "qwen3_5_moe_text",
            ["Qwen3_5MoeForConditionalGeneration"],
            "Qwen3_5MoeForCausalLM",
        ),
    ],
)
def test_qwen3_5_text_model_arch_config(
    tmp_path: Path,
    model_type: str,
    architectures: list[str],
    expected_architecture: str,
) -> None:
    """Text-only Qwen3.5 configs resolve to the *ForCausalLM architectures.

    Writes a minimal config.json (with MoE extras for the MoE variant) and
    verifies that ModelConfig rewrites the VL "ForConditionalGeneration"
    architecture name to the internal text-only class.
    """
    model_dir = tmp_path / model_type
    model_dir.mkdir()

    base_config: dict = {
        "architectures": architectures,
        "head_dim": 4,
        "hidden_size": 16,
        "intermediate_size": 32,
        "model_type": model_type,
        "num_attention_heads": 4,
        "num_hidden_layers": 2,
        "num_key_value_heads": 4,
        "torch_dtype": "float16",
        "vocab_size": 128,
    }
    # The MoE variant additionally needs expert-layout fields.
    moe_extras = {
        "moe_intermediate_size": 8,
        "num_experts": 4,
        "num_experts_per_tok": 2,
        "shared_expert_intermediate_size": 8,
    }
    config = (
        base_config | moe_extras
        if model_type == "qwen3_5_moe_text"
        else base_config
    )

    (model_dir / "config.json").write_text(json.dumps(config))

    model_config = ModelConfig(str(model_dir), tokenizer=str(model_dir))

    assert model_config.hf_config.model_type == model_type
    assert model_config.hf_config.architectures == [expected_architecture]
    assert model_config.architectures == [expected_architecture]
    assert model_config.architecture == expected_architecture
    assert not model_config.is_multimodal_model
32 changes: 24 additions & 8 deletions tests/models/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@
from ..utils import create_new_process_for_each_test
from .registry import HF_EXAMPLE_MODELS

_INTERNAL_QWEN3_5_TEXT_ARCHS = {"Qwen3_5ForCausalLM", "Qwen3_5MoeForCausalLM"}


@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Skip if transformers version is incompatible
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(
on_fail="skip",
check_max_version=False,
check_version_reason="vllm",
)
if model_arch not in _INTERNAL_QWEN3_5_TEXT_ARCHS:
# Skip if transformers version is incompatible
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(
on_fail="skip",
check_max_version=False,
check_version_reason="vllm",
)
# Ensure all model classes can be imported successfully
model_cls = ModelRegistry._try_load_model_cls(model_arch)
assert model_cls is not None
Expand All @@ -54,6 +57,17 @@ def test_registry_imports(model_arch):
assert supports_multimodal(model_cls)


@pytest.mark.parametrize(
    "model_arch",
    ["Qwen3_5ForCausalLM", "Qwen3_5MoeForCausalLM"],
)
def test_qwen3_5_text_models_are_not_multimodal(model_arch):
    """Internal Qwen3.5 text-only archs load as pure text-generation models."""
    loaded_cls = ModelRegistry._try_load_model_cls(model_arch)
    assert loaded_cls is not None, f"{model_arch} failed to load"
    assert is_text_generation_model(loaded_cls)
    # Text-only variants must not advertise multimodal support.
    assert not supports_multimodal(loaded_cls)


@create_new_process_for_each_test()
@pytest.mark.parametrize(
"model_arch,is_mm,init_cuda,score_type",
Expand Down Expand Up @@ -118,7 +132,9 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):

def test_hf_registry_coverage():
untested_archs = (
ModelRegistry.get_supported_archs() - HF_EXAMPLE_MODELS.get_supported_archs()
ModelRegistry.get_supported_archs()
- _INTERNAL_QWEN3_5_TEXT_ARCHS
- HF_EXAMPLE_MODELS.get_supported_archs()
)

assert not untested_archs, (
Expand Down
38 changes: 37 additions & 1 deletion tests/models/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch

from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper

pytestmark = pytest.mark.cpu_test

Expand Down Expand Up @@ -155,3 +155,39 @@ def weight_generator():
)
assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1


class DummyQwen3_5TextOnlyModule(torch.nn.Module):
    """Minimal stand-in mirroring the Qwen3.5 text model's parameter layout.

    Exposes a nested ``model.proj`` linear plus a top-level ``lm_head``,
    both bias-free, so weight-loading paths can be exercised on CPU.
    """

    def __init__(self):
        super().__init__()
        inner = torch.nn.Module()
        inner.proj = torch.nn.Linear(2, 2, bias=False)
        self.model = inner
        self.lm_head = torch.nn.Linear(2, 2, bias=False)


def test_qwen3_5_text_loader_remaps_vl_weights_and_ignores_visual_keys():
    """Mirror the Qwen3.5 text-only weight-loading path on CPU."""
    module = DummyQwen3_5TextOnlyModule()
    expected_proj = torch.full_like(module.model.proj.weight, 1.25)
    expected_head = torch.full_like(module.lm_head.weight, 2.5)

    # A VL-style checkpoint: text weights under "model.language_model.",
    # plus vision and MTP keys that the text-only model must tolerate.
    checkpoint = [
        ("model.language_model.proj.weight", expected_proj),
        ("model.visual.proj.weight", torch.tensor([7.0])),
        ("lm_head.weight", expected_head),
        ("mtp.extra.weight", torch.tensor([9.0])),
    ]
    prefix_mapper = WeightsMapper(
        orig_to_new_prefix={"model.language_model.": "model."},
    )
    loader = AutoWeightsLoader(
        module,
        skip_prefixes=["mtp."],
        ignore_unexpected_prefixes=["model.visual."],
    )

    loaded = loader.load_weights(checkpoint, mapper=prefix_mapper)

    assert loaded == {"model.proj.weight", "lm_head.weight"}
    assert torch.allclose(module.model.proj.weight, expected_proj)
    assert torch.allclose(module.lm_head.weight, expected_head)
2 changes: 2 additions & 0 deletions vllm/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def __post_init__(
self.model, hf_token=self.hf_token, revision=self.revision
)
self.model_arch_config = self.get_model_arch_config()
if self.model_arch_config.architectures is not None:
self.hf_config.architectures = self.model_arch_config.architectures.copy()

architectures = self.architectures
registry = self.registry
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,8 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None:
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
"Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
"Qwen3_5ForCausalLM": Qwen3_5ForConditionalGenerationConfig,
"Qwen3_5MoeForCausalLM": Qwen3_5ForConditionalGenerationConfig,
"Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
"Qwen3_5MoeForConditionalGeneration": Qwen3_5ForConditionalGenerationConfig,
"VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
Expand Down
71 changes: 70 additions & 1 deletion vllm/model_executor/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
MultiModalEmbeddings,
SupportsEagle3,
SupportsLoRA,
SupportsMRoPE,
SupportsPP,
_require_is_multimodal,
)
Expand All @@ -100,6 +101,7 @@
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
_merge_multimodal_embeddings,
extract_layer_index,
is_pp_missing_parameter,
Expand Down Expand Up @@ -597,9 +599,67 @@ class Qwen3_5ForCausalLMBase(
nn.Module,
HasInnerState,
SupportsEagle3,
IsHybrid,
SupportsMRoPE,
SupportsLoRA,
SupportsPP,
):
# Qwen3.5 interleaves full-attention layers (every 4th) with
# GatedDeltaNet (Mamba-style) layers, making it a hybrid model.
is_hybrid = True
supports_mrope: typing.ClassVar[typing.Literal[True]] = True

def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    mm_features: list,
) -> tuple[torch.Tensor, int]:
    """Build M-RoPE positions for a text-only prompt.

    For a pure-text model all three rotary axes share identical 1-D
    position ids (the config inherits mrope_section from the VL parent),
    so a single arange is broadcast across three rows. The returned
    position delta is always 0.
    """
    seq_len = len(input_tokens)
    mrope_positions = torch.arange(seq_len).expand(3, seq_len)
    return mrope_positions, 0

@classmethod
def get_mamba_state_dtype_from_config(
    cls,
    vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
    """Resolve the (conv, ssm) state dtypes for the GatedDeltaNet layers."""
    cache_config = vllm_config.cache_config
    return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
        vllm_config.model_config.dtype,
        cache_config.mamba_cache_dtype,
        cache_config.mamba_ssm_cache_dtype,
    )

@classmethod
def get_mamba_state_shape_from_config(
    cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, int], tuple[int, int, int]]:
    """Compute the (conv, ssm) cache shapes for the GatedDeltaNet layers.

    The return annotation matches
    MambaStateShapeCalculator.gated_delta_net_state_shape, which yields a
    2-tuple conv-state shape and a 3-tuple ssm-state shape (per the
    IsHybrid protocol).
    """
    parallel_config = vllm_config.parallel_config
    hf_config = vllm_config.model_config.hf_text_config
    tp_size = parallel_config.tensor_parallel_size
    # Speculative decoding reserves extra state slots per sequence.
    num_spec = (
        vllm_config.speculative_config.num_speculative_tokens
        if vllm_config.speculative_config
        else 0
    )
    return MambaStateShapeCalculator.gated_delta_net_state_shape(
        tp_size,
        hf_config.linear_num_key_heads,
        hf_config.linear_num_value_heads,
        hf_config.linear_key_head_dim,
        hf_config.linear_value_head_dim,
        hf_config.linear_conv_kernel_dim,
        num_spec,
    )

@classmethod
def get_mamba_state_copy_func(
    cls,
) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
    """Return the (conv, ssm) state-copy callables for GatedDeltaNet caches."""
    copy_funcs = MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()
    return copy_funcs

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand Down Expand Up @@ -690,11 +750,20 @@ def compute_logits(
return self.logits_processor(self.lm_head, hidden_states)

def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights, accepting VL-formatted checkpoints.

    Checkpoints quantized from the VL model (Qwen3_5ForConditionalGeneration)
    store text weights under "model.language_model." and also carry
    vision-encoder keys under "model.visual.". The mapper rewrites the text
    prefix to the "model." layout this text-only model expects; vision
    weights are silently ignored and MTP weights are skipped.

    Returns the set of parameter names that were actually loaded.
    """
    prefix_mapper = WeightsMapper(
        orig_to_new_prefix={"model.language_model.": "model."},
    )
    weight_loader = AutoWeightsLoader(
        self,
        skip_prefixes=["mtp."],
        ignore_unexpected_prefixes=["model.visual."],
    )
    return weight_loader.load_weights(weights, mapper=prefix_mapper)


class Qwen3_5ForCausalLM(Qwen3_5ForCausalLMBase):
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"Qwen3_5ForCausalLM": ("qwen3_5", "Qwen3_5ForCausalLM"),
"Qwen3_5MoeForCausalLM": ("qwen3_5", "Qwen3_5MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"SarvamMoEForCausalLM": ("sarvam", "SarvamMoEForCausalLM"),
"SarvamMLAForCausalLM": ("sarvam", "SarvamMLAForCausalLM"),
Expand Down
2 changes: 2 additions & 0 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def __getitem__(self, key):
qwen3_asr="Qwen3ASRConfig",
qwen3_next="Qwen3NextConfig",
qwen3_5="Qwen3_5Config",
qwen3_5_text="Qwen3_5TextConfig",
qwen3_5_moe="Qwen3_5MoeConfig",
qwen3_5_moe_text="Qwen3_5MoeTextConfig",
lfm2_moe="Lfm2MoeConfig",
tarsier2="Tarsier2Config",
)
Expand Down
28 changes: 28 additions & 0 deletions vllm/transformers_utils/model_arch_config_convertor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@

logger = init_logger(__name__)

_QWEN3_5_TEXT_ARCHITECTURES = {
"Qwen3_5ForConditionalGeneration": "Qwen3_5ForCausalLM",
"Qwen3_5MoeForConditionalGeneration": "Qwen3_5MoeForCausalLM",
}

_QWEN3_5_TEXT_DEFAULT_ARCHITECTURES = {
"qwen3_5_text": "Qwen3_5ForCausalLM",
"qwen3_5_moe_text": "Qwen3_5MoeForCausalLM",
}


class ModelArchConfigConvertorBase:
def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
Expand Down Expand Up @@ -346,6 +356,22 @@ def get_total_num_kv_heads(self) -> int:
return 0


class Qwen3_5TextModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Convertor for text-only Qwen3.5 configs (qwen3_5_text / qwen3_5_moe_text)."""

    def get_architectures(self) -> list[str]:
        """Resolve architectures, remapping VL names to text-only classes.

        If the config declares architectures, each VL
        "ForConditionalGeneration" name is rewritten to its text-only
        counterpart (unknown names pass through unchanged). If the config
        declares none, fall back to the default architecture for the
        model_type when one is known; otherwise return the (empty) list
        from the base convertor.
        """
        architectures = super().get_architectures()
        if architectures:
            return [
                _QWEN3_5_TEXT_ARCHITECTURES.get(arch, arch)
                for arch in architectures
            ]
        fallback = _QWEN3_5_TEXT_DEFAULT_ARCHITECTURES.get(
            self.hf_config.model_type
        )
        return [fallback] if fallback is not None else architectures


class Zamba2ModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return getattr(self.hf_text_config, "attention_head_dim", 0)
Expand Down Expand Up @@ -452,6 +478,8 @@ def get_num_hidden_layers(self) -> int:
"falcon_mamba": MambaModelArchConfigConvertor,
"timm_wrapper": TerratorchModelArchConfigConvertor,
"medusa": MedusaModelArchConfigConvertor,
"qwen3_5_text": Qwen3_5TextModelArchConfigConvertor,
"qwen3_5_moe_text": Qwen3_5TextModelArchConfigConvertor,
"zamba2": Zamba2ModelArchConfigConvertor,
"mpt": MPTModelArchConfigConvertor,
"dbrx": DbrxModelArchConfigConvertor,
Expand Down
Loading