From 112fa95e45d023c99fd533849f1c14341b4221a1 Mon Sep 17 00:00:00 2001 From: John Liu BUAA Date: Sat, 17 Jan 2026 15:53:00 +0800 Subject: [PATCH] Replace ColumnParallelLinear with nn.Linear in talker Swapped out the custom ColumnParallelLinear layer for a standard nn.Linear in Qwen2_5OmniTalkerForConditionalGeneration. Updated the forward pass to match the new layer's output signature, simplifying the projection step. Signed-off-by: John Liu BUAA --- .buildkite/pipeline.yml | 2 +- .../models/qwen2_5_omni/qwen2_5_omni_talker.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index f18aef6177..d4cbbab2b8 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -94,7 +94,7 @@ steps: - "/fsx/hf_cache:/fsx/hf_cache" - label: "Diffusion Parallelism Test" - timeout_in_minutes: 20 + timeout_in_minutes: 25 depends_on: image-build commands: - pytest -s -v tests/e2e/offline_inference/test_sequence_parallel.py diff --git a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py index 7b0b443091..927bc55257 100644 --- a/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py +++ b/vllm_omni/model_executor/models/qwen2_5_omni/qwen2_5_omni_talker.py @@ -9,7 +9,6 @@ # from vllm.attention import AttentionMetadata # unused import from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from vllm.model_executor.models.qwen2_5_omni_thinker import ( Qwen2_5OmniThinkerDummyInputsBuilder, @@ -68,13 +67,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else: self.config = config - self.thinker_to_talker_proj = ColumnParallelLinear( + self.thinker_to_talker_proj = nn.Linear( self.config.embedding_size, self.config.hidden_size, - bias=True, - gather_output=True, - skip_bias_add=False, - quant_config=quant_config, ) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, @@ -145,7 +140,7 @@ def forward( input_ids = None # projection - inputs_embeds, _ = self.thinker_to_talker_proj(inputs_embeds) + inputs_embeds = self.thinker_to_talker_proj(inputs_embeds) hidden_states = self.language_model.model( input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds