28 changes: 13 additions & 15 deletions vllm/model_executor/models/qwen2_audio.py
@@ -334,20 +334,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config

self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size
)

self.quant_config = quant_config

self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size
)

with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -441,9 +442,6 @@ def _process_audio_input(
masked_audio_features, audio_output_lengths.flatten().tolist()
)

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
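Note: the `_mark_tower_model` / `_mark_language_model` context managers wrapped around submodule construction in this diff are never defined in the hunks shown; only their call sites are. A minimal sketch of the pattern, assuming the markers simply record which attributes are created inside each block — the `_component_map` bookkeeping below is purely illustrative, not vLLM's actual implementation:

from contextlib import contextmanager

class TowerMarkingSketch:
    """Hypothetical stand-in for the markers used at the call sites above."""

    def __init__(self):
        self._component_map = {}

    @contextmanager
    def _mark_tower_model(self, vllm_config, modalities):
        # Attributes assigned inside the block are tagged with the
        # modality (or set of modalities) the tower serves.
        before = set(vars(self))
        yield
        for name in set(vars(self)) - before:
            self._component_map[name] = modalities

    @contextmanager
    def _mark_language_model(self, vllm_config):
        # Same idea, but tags the text backbone.
        before = set(vars(self))
        yield
        for name in set(vars(self)) - before:
            self._component_map[name] = "language"

model = TowerMarkingSketch()
with model._mark_tower_model(vllm_config=None, modalities="audio"):
    model.audio_tower = object()  # placeholder for Qwen2AudioEncoder(...)
with model._mark_language_model(vllm_config=None):
    model.language_model = object()
assert model._component_map == {"audio_tower": "audio", "language_model": "language"}

Scoping construction this way presumably lets downstream machinery (weight loading, skipping a tower for a disabled modality) address the towers and the text backbone as separate components, which would also explain why the per-model `get_language_model` accessors are deleted throughout this PR.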
97 changes: 52 additions & 45 deletions vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -1612,32 +1612,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config
self.multimodal_config = multimodal_config

self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)

self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
self.quant_config = quant_config

self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config.with_hf_config(
thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
),
prefix=maybe_prefix(prefix, "language_model"),
)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)

self.use_deepstack = hasattr(
thinker_config.vision_config, "deepstack_visual_indexes"
@@ -1647,22 +1629,48 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
if self.use_deepstack
else 0
)
# register buffer for deepstack
self.deepstack_input_embeds = (
[
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
thinker_config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
if self.use_deepstack
else None
)
self.visual_dim = thinker_config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level

def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)

# register buffer for deepstack
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
thinker_config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]

with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config.with_hf_config(
thinker_config.text_config,
architectures=["Qwen3MoeForCausalLM"],
),
prefix=maybe_prefix(prefix, "language_model"),
)

self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)

def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped

# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(
{
@@ -1674,6 +1682,9 @@ def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
)

def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return

# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
@@ -1692,6 +1703,9 @@ def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
)

def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return

# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
@@ -1726,9 +1740,6 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
)
return mm_input_by_modality

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
@@ -1844,11 +1855,7 @@ def forward(
if intermediate_tensors is not None:
inputs_embeds = None

if (
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
if inputs_embeds is not None and get_pp_group().is_first_rank:
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0)
)
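Because `deepstack_input_embeds` is now only assigned when the vision tower is actually built, the deepstack helpers guard with `getattr(..., None)` instead of touching the attribute directly; `not` additionally treats an empty buffer list as absent. A small self-contained illustration of the guard (names reused from the diff, the surrounding model code omitted):

import torch

class DeepstackGuardDemo:
    # Mirrors the guard added to _get_deepstack_input_embeds: the buffer
    # list may never have been created if the vision tower was skipped.
    def get_buffers(self):
        if not getattr(self, "deepstack_input_embeds", None):
            return None  # vision tower skipped, or no buffers registered
        return self.deepstack_input_embeds

demo = DeepstackGuardDemo()
assert demo.get_buffers() is None                 # attribute never assigned
demo.deepstack_input_embeds = []
assert demo.get_buffers() is None                 # empty list also treated as absent
demo.deepstack_input_embeds = [torch.zeros(4, 8)]
assert demo.get_buffers() is not None             # buffers present

This None-returning contract is also what allows the `forward` paths below to drop the explicit `self.use_deepstack` check from their conditions.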
20 changes: 14 additions & 6 deletions vllm/model_executor/models/qwen3_vl.py
@@ -1321,7 +1321,13 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)

def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped

# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(
{
@@ -1333,6 +1339,9 @@ def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
)

def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return

# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
@@ -1351,6 +1360,9 @@ def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
)

def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return

# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
@@ -2037,11 +2049,7 @@ def forward(
if intermediate_tensors is not None:
inputs_embeds = None

if (
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
if inputs_embeds is not None and get_pp_group().is_first_rank:
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0)
)
1 change: 0 additions & 1 deletion vllm/model_executor/models/radio.py
@@ -620,7 +620,6 @@ def forward(
x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None,
) -> torch.FloatTensor:
assert self.patch_generator is not None
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
attn_mask = None
if imgs_sizes is not None and len(imgs_sizes) > 1:
29 changes: 14 additions & 15 deletions vllm/model_executor/models/siglip.py
@@ -1033,20 +1033,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_projection_size = text_config.projection_size

self.text_model = SiglipTextTransformer(
text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "text_model"),
)
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
with self._mark_language_model(vllm_config):
self.text_model = SiglipTextTransformer(
text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "text_model"),
)

self.text_projection_size = text_config.projection_size
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)

pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
@@ -1155,9 +1157,6 @@ def _process_image_inputs(self, inputs: SiglipImagePixelInputs) -> torch.Tensor:

return self.get_image_features(pixel_values)

def get_language_model(self) -> torch.nn.Module:
return self.text_model

def _embed_text_input_ids(
self,
input_ids: torch.Tensor,
39 changes: 18 additions & 21 deletions vllm/model_executor/models/skyworkr1v.py
@@ -674,24 +674,26 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version

self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == "SkyworkLM2VEForCausalLM"
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
llm_arch_name = config.text_config.architectures[0]
self.is_mono = llm_arch_name == "SkyworkLM2VEForCausalLM"

self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)

self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)

self.img_context_token_id = None
self.visual_token_mask = None
@@ -838,8 +840,6 @@ def _process_image_input(
if image_input["type"] == "image_embeds":
return image_input["data"]

assert self.vision_model is not None

image_embeds = self.extract_feature(image_input["pixel_values_flat"])

num_patches = image_input["num_patches"]
@@ -867,9 +867,6 @@ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
else:
self.visual_token_mask = None

def get_language_model(self) -> torch.nn.Module:
return self.language_model

def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: