
Add Qwen3-Omni moe thinker #25550

Merged
Isotr0py merged 38 commits into vllm-project:main from wangxiongts:dev/qwen3-omni-moe
Oct 10, 2025

Conversation

@wangxiongts
Contributor

@wangxiongts wangxiongts commented Sep 24, 2025

This PR from the Qwen team adds the thinker part of Qwen3-Omni-MoE.

Testing has been conducted internally across four configurations (v0/v1, eager/CUDA) on several representative benchmarks, with results meeting expectations.

Known issues (we hope to resolve them together with the vLLM team):

  • In v1 mode, use_audio_in_video will raise errors because the video mm_data and placeholders are not updated.

We sincerely appreciate the great work and support from the vLLM team, and look forward to your feedback.

Closes #25472

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the new-model (Requests to new models) and qwen (Related to Qwen models) labels Sep 24, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the Qwen3-Omni-Moe model. The changes include a new model implementation file, modifications to handle multimodal rotary embeddings, and registration of the new model. While the implementation is comprehensive, I've identified several critical and high-severity issues related to performance and maintainability. Specifically, there are non-vectorized loops and inefficient tensor operations in the position embedding calculation, which will significantly impact performance. Additionally, there are uses of NumPy within core logic that should be replaced with PyTorch operations to avoid CPU-GPU synchronization. I've also found a few potential bugs related to tensor shape calculations that could lead to runtime errors. Addressing these points will be crucial for integrating this model into vLLM effectively.

Comment on lines +1033 to +1043
def _omni3_get_input_positions_tensor(
cls,
config,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_audio_in_video: bool = False,
audio_seqlens: Optional[torch.LongTensor] = None,
second_per_grids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
Contributor


critical

The function _omni3_get_input_positions_tensor is very long and complex, making it difficult to understand and maintain. More importantly, it processes input sequences one by one within a for loop (for i, input_ids in enumerate(total_input_ids):), which is not vectorized and will lead to significant performance degradation, especially with larger batch sizes. The use of .tolist() and list methods like .index() inside the loop further contributes to the inefficiency. This implementation should be refactored to be vectorized over the batch dimension to meet the performance standards of vLLM. Consider using tensor operations to find indices and process modalities in parallel for all sequences in the batch.
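As a minimal sketch of the batched-lookup idea (the token id and inputs below are hypothetical placeholders, not values from the PR), a single `nonzero` call can replace per-sequence `.tolist()`/`.index()` scans:

```python
import torch

# Hypothetical multimodal placeholder token id, for illustration only.
IMAGE_TOKEN_ID = 151655

total_input_ids = torch.tensor([
    [1, IMAGE_TOKEN_ID, 2, 3],
    [IMAGE_TOKEN_ID, 4, 5, 6],
])

# One vectorized call locates every image token across the whole batch,
# instead of looping over sequences and calling list.index() per modality.
batch_idx, pos = (total_input_ids == IMAGE_TOKEN_ID).nonzero(as_tuple=True)
print(batch_idx.tolist(), pos.tolist())  # [0, 1] [1, 0]
```

The same pattern extends to audio/video placeholder tokens, keeping the index search on-device and free of Python-level loops.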

if name == "feature_attention_mask":
    dim = -1
if isinstance(mm_input, torch.Tensor):
    return torch.concat(list(mm_input), dim=dim)

critical

The implementation of _validate_and_reshape_mm_tensor seems to have a bug when handling a torch.Tensor. The line return torch.concat(list(mm_input), dim=dim) is problematic. When mm_input is a tensor, list(mm_input) iterates over its first dimension. torch.concat then joins these tensors along dim. For example, if mm_input has shape (B, C, L) and dim=1, the result will have shape (C, B*L), which is likely incorrect for batch processing where one would expect to flatten the batch dimension. This will likely cause shape mismatches in downstream processing.
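A toy reproduction of the shape concern (the shapes here are hypothetical): `list(tensor)` splits along dim 0, and concatenating those slices along `dim=1` merges the batch dimension into the last one.

```python
import torch

mm_input = torch.zeros(4, 3, 5)            # (B, C, L)
out = torch.concat(list(mm_input), dim=1)  # 4 slices of (3, 5) joined on dim 1
print(out.shape)  # torch.Size([3, 20]), i.e. (C, B*L): the batch dim is gone
```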

multimodal_embeddings[index] = embeddings_main
multimodal_embeddings_multiscale.append(embeddings_multiscale)
if len(multimodal_embeddings_multiscale) > 0:
    deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1))

critical

There appears to be a bug in the shape calculation for deepstack_input_embeds. The second dimension is calculated as multiscale_len * inputs_embeds.size(1), which resolves to multiscale_len * text_config.hidden_size. However, this tensor is later populated with multimodal_embeddings_multiscale which have a feature dimension of multi_dim (multiscale_len * visual_dim), and then reshaped using visual_dim. This will raise a runtime error if text_config.hidden_size is not equal to visual_dim (vision_config.out_hidden_size). The correct size for the second dimension should be multi_dim (i.e., multiscale_len * visual_dim), which is computed a few lines above.

                    deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0), multi_dim)
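The mismatch can be illustrated with toy dimensions (all values below are hypothetical, chosen so that `hidden_size != visual_dim`):

```python
import torch

seq_len, hidden_size = 10, 32       # stand-in for text_config.hidden_size
visual_dim, multiscale_len = 16, 3  # stand-in for vision_config.out_hidden_size
multi_dim = multiscale_len * visual_dim

embeddings_multiscale = torch.randn(4, multi_dim)  # 4 visual tokens, 48-dim
buggy = torch.zeros(seq_len, multiscale_len * hidden_size)  # (10, 96)
fixed = torch.zeros(seq_len, multi_dim)                     # (10, 48)

# Only the fixed buffer's feature width matches the multiscale embeddings.
assert fixed.shape[1] == embeddings_multiscale.shape[1]
assert buggy.shape[1] != embeddings_multiscale.shape[1]
```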

None,
use_audio_in_video,
audio_feature_lengths,
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))

high

This line creates a tensor in a highly inefficient way. torch.tensor(video_grid_thw) is redundant as video_grid_thw is already a tensor at this point. Creating a list of 1s and then converting it to a tensor is also inefficient. This can be simplified and made more performant.

Suggested change
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
torch.ones(video_grid_thw.shape[0], dtype=torch.long, device=video_grid_thw.device))
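Both expressions produce the same values; the suggestion just avoids the redundant tensor wrap and the Python list. A quick check with hypothetical grid data:

```python
import torch

# Toy (num_videos, t/h/w) grid, for illustration only.
video_grid_thw = torch.tensor([[2, 4, 4], [1, 8, 8]])

old = torch.tensor([1] * torch.tensor(video_grid_thw).shape[0])
new = torch.ones(video_grid_thw.shape[0], dtype=torch.long,
                 device=video_grid_thw.device)
assert torch.equal(old, new)
```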

Comment on lines +381 to +382
h_idxs = np.linspace(0, num_grid_per_side-1, h)
w_idxs = np.linspace(0, num_grid_per_side-1, w)

high

This function uses numpy for calculations (np.linspace), which can lead to performance bottlenecks due to CPU-GPU synchronization and data transfers. The comment on line 379 already indicates this. These operations should be replaced with their torch equivalents to keep the computation on the GPU and within the computation graph.

Suggested change
h_idxs = np.linspace(0, num_grid_per_side-1, h)
w_idxs = np.linspace(0, num_grid_per_side-1, w)
h_idxs = torch.linspace(0, num_grid_per_side-1, h, device=self.pos_embed.weight.device)
w_idxs = torch.linspace(0, num_grid_per_side-1, w, device=self.pos_embed.weight.device)
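The swap preserves values exactly: `torch.linspace` matches `np.linspace` for these endpoint-inclusive grids, while keeping the computation on-device. A quick sanity check with hypothetical sizes:

```python
import numpy as np
import torch

num_grid_per_side, h = 16, 7  # toy values for illustration
h_np = np.linspace(0, num_grid_per_side - 1, h)
h_t = torch.linspace(0, num_grid_per_side - 1, h)
assert np.allclose(h_np, h_t.numpy())
```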

Member


Are you able to finish this TODO before you have to go OOO?

Comment on lines +688 to +696
audio_token_indices = np.arange(next(iter([audio_len])))
curr_video_grid_thw = next(iter([video_grid_thw]))
height = curr_video_grid_thw[1] // spatial_merge_size
width = curr_video_grid_thw[2] // spatial_merge_size
video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
video_token_indices = np.broadcast_to(
    video_token_indices, (video_token_indices.shape[0], height, width)
).reshape(-1)
video_token_indices = ((video_token_indices + shift) * next(iter([video_second_per_grid_t])) * position_id_per_seconds)

high

This function uses numpy for array creation and manipulation (np.arange, np.broadcast_to). This forces data transfers between CPU and GPU and can be a performance bottleneck. These should be replaced with torch equivalents to maintain performance.

        audio_token_indices = torch.arange(next(iter([audio_len])))
        curr_video_grid_thw = next(iter([video_grid_thw]))
        height = curr_video_grid_thw[1] // spatial_merge_size
        width = curr_video_grid_thw[2] // spatial_merge_size
        video_token_indices = torch.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
        video_token_indices = video_token_indices.expand(video_token_indices.shape[0], height, width).reshape(-1)
        video_token_indices = ((video_token_indices + shift) * next(iter([video_second_per_grid_t])) * position_id_per_seconds)
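`Tensor.expand` is the torch analogue of `np.broadcast_to`: both produce a broadcasted view without copying. A minimal equivalence check with hypothetical toy dimensions:

```python
import numpy as np
import torch

t_frames, height, width = 3, 2, 2  # toy values for illustration
idx_np = np.arange(t_frames).reshape(-1, 1, 1)
idx_np = np.broadcast_to(idx_np, (t_frames, height, width)).reshape(-1)
idx_t = torch.arange(t_frames).reshape(-1, 1, 1)
# reshape(-1) copies here since the expanded view is non-contiguous.
idx_t = idx_t.expand(t_frames, height, width).reshape(-1)
assert np.array_equal(idx_np, idx_t.numpy())
```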

Member

@DarkLight1337 DarkLight1337 left a comment


Thanks, can you update tests/models/registry.py to be able to pass the CI?

Also please update the Supported Models page

@wangxiongts
Contributor Author

wangxiongts commented Sep 24, 2025

Alright, I'll handle these parts. Currently, I'm still working on adding audio-in-video support in v1. In the meantime, one known issue is that I may not be able to straightforwardly reuse the relevant modules from Qwen3-VL, because our model has already been made public and some checkpoint keys and configurations are incompatible with Qwen3-VL. This stems from the fact that our internal iterations were not synchronized. This issue may require further careful discussion.

I might go on vacation starting tomorrow and probably won't resume modifications until after October 4th :) You can proceed with the review based on the current version.

@ywang96 ywang96 self-assigned this Sep 24, 2025
@mergify mergify bot added the documentation (Improvements or additions to documentation) label Sep 24, 2025
None,
use_audio_in_video,
audio_feature_lengths,
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
Member


Suggested change
torch.tensor([1] * torch.tensor(video_grid_thw).shape[0]))
torch.ones(len(video_grid_thw))

Simplify this

@CHNtentes

Thanks for your work. May I ask, will the talker model be supported in the future? It seems Qwen2.5-Omni still only supports the thinker model.

@Wesley-Jzy

LGTM! May I know whether the talker model will be supported by vLLM?

@ywang96
Member

ywang96 commented Sep 24, 2025

Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have some plan on supporting this model under a different project but leveraging the thinker models support from vLLM. Stay tuned!

@Wesley-Jzy

> Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have some plan on supporting this model under a different project but leveraging the thinker models support from vLLM. Stay tuned!

So that means the vLLM project will support the thinker model, which works just like a normal LLM, and a new multimodal inference project will support the end-to-end Qwen3-Omni model? Can I learn more about this new project?

@ywang96
Member

ywang96 commented Sep 24, 2025

> > Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have some plan on supporting this model under a different project but leveraging the thinker models support from vLLM. Stay tuned!
>
> So that means vLLM project will support thinking model which is just like normal LLM model. And a new multimodal inference project will support end2end Qwen3-Omni model? Do I know more about this new project?

Yea that's the right understanding! We're still planning for the new project so stay tuned!

@Wesley-Jzy

> > > Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have some plan on supporting this model under a different project but leveraging the thinker models support from vLLM. Stay tuned!
> >
> > So that means vLLM project will support thinking model which is just like normal LLM model. And a new multimodal inference project will support end2end Qwen3-Omni model? Do I know more about this new project?
>
> Yea that's the right understanding! We're still planning for the new project so stay tuned!

Great! And may I know whether the new project will also handle single-model multimodal models such as Kimi-Audio? Or will they be supported by vLLM?

@CHNtentes

> > > Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have some plan on supporting this model under a different project but leveraging the thinker models support from vLLM. Stay tuned!
> >
> > So that means vLLM project will support thinking model which is just like normal LLM model. And a new multimodal inference project will support end2end Qwen3-Omni model? Do I know more about this new project?
>
> Yea that's the right understanding! We're still planning for the new project so stay tuned!

Really wish the new project is fast and efficient. Tried transformers and audio output was SLOW...

wenbinc-Bin added a commit to wenbinc-Bin/vllm-fork that referenced this pull request Sep 25, 2025
vllm-project#25550

Signed-off-by: Chen, Wenbin <wenbin.chen@intel.com>
@eschmidbauer

> Really wish the new project is fast and efficient. Tried transformers and audio output was SLOW...

Same, even with flash-attn2 it is very slow

@CHNtentes

> > Really wish the new project is fast and efficient. Tried transformers and audio output was SLOW...
>
> Same, even with flash-attn2 it is very slow

I tried this PR and it's like >20x faster than transformers :)

@facebook-github-bot

@houseroad has imported this pull request. If you are a Meta employee, you can view this in D83274891.

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Member

@Isotr0py Isotr0py left a comment


The processor tests should pass now:

tests/models/multimodal/processing/test_common.py::test_processing_correctness[1.0-32-0.3-Qwen/Qwen3-Omni-30B-A3B-Instruct] Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_interleaved', 'interleaved', 'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'interleaved', 'mrope_section'}
INFO 10-10 23:00:59 [model.py:653] Resolved architecture: Qwen3OmniMoeForConditionalGeneration
`torch_dtype` is deprecated! Use `dtype` instead!
INFO 10-10 23:00:59 [model.py:1714] Using max model len 65536
PASSED
tests/models/multimodal/processing/test_common.py::test_processing_correctness[1.0-32-0.5-Qwen/Qwen3-Omni-30B-A3B-Instruct] Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_interleaved', 'interleaved', 'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'interleaved', 'mrope_section'}
INFO 10-10 23:01:17 [model.py:653] Resolved architecture: Qwen3OmniMoeForConditionalGeneration
INFO 10-10 23:01:17 [model.py:1714] Using max model len 65536
PASSED
tests/models/multimodal/processing/test_common.py::test_processing_correctness[1.0-32-1.0-Qwen/Qwen3-Omni-30B-A3B-Instruct] Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_interleaved', 'interleaved', 'mrope_section'}
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'interleaved', 'mrope_section'}
INFO 10-10 23:01:26 [model.py:653] Resolved architecture: Qwen3OmniMoeForConditionalGeneration
INFO 10-10 23:01:26 [model.py:1714] Using max model len 65536
PASSED

And outputs look reasonable on my side too:

Adding requests: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.86s/it]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [00:13<00:00, 13.77s/it, est. speed input: 298.63 toks/s, output: 22.00 toks/s]
Based on the audio, video, and image content, here is a breakdown of what is happening:

### Audio Content

The audio features a man reciting a well-known nursery rhyme, "Mary Had a Little Lamb." The rhyme is:

> "Mary had a little lamb, its fleece was white as snow. And everywhere that Mary went, the lamb was sure to go."

The speaker mentions that these were the "first words" spoken into a phonograph, which is a historical reference to the invention of the phonograph by Thomas Edison.

### Image Content

The image shows a young child, likely a toddler, sitting on a bed. The child is wearing glasses and is looking at a book. The child's feet are propped up on the book, and they are turning the pages. The child appears to be enjoying the book and is smiling.

### Why the Video is Funny

The humor in the video comes from the contrast between the child's serious and focused demeanor and the absurdity of the situation. The child is wearing glasses and is sitting on the bed with their feet propped up on the book, which is a common and relatable scene for many people. However, the child's intense focus on the book, combined with the glasses, creates a humorous and endearing image. The child seems to be taking the reading experience very seriously, which is funny because they are just a toddler. The overall effect is a charming and amusing scene that many viewers can relate to and find funny.

@Isotr0py Isotr0py enabled auto-merge (squash) October 10, 2025 15:16
@Isotr0py Isotr0py merged commit 19a9b16 into vllm-project:main Oct 10, 2025
55 checks passed
@F0undLinks

Does vLLM serve currently only support Qwen3-Omni-Thinking? I am using two servers with eight GPUs, and vLLM serve cannot start Qwen3-Omni-Instruct. It fails with the error "error in inspecting the model architecture Qwen3OmniMoeForConditionalGeneration".

@DarkLight1337
Member

Which version of Transformers are you using? Make sure it's 4.57 or higher

@F0undLinks

> Which version of Transformers are you using? Make sure it's 4.57 or higher

The Docker image I'm using is vllm/vllm-openai:v0.11.0, and the transformers version is 4.57.0.

@DarkLight1337
Member

This PR was only merged after v0.11, so you need to install vLLM from main branch or use the per-commit Docker image.

@F0undLinks

> This PR was only merged after v0.11, so you need to install vLLM from main branch or use the per-commit Docker image.

I attempted to modify the vLLM source code in accordance with this PR (the three parts shown in the following figure). Isn't this approach feasible?

(screenshot attached: screen_shot_1760411194334)

@DarkLight1337
Member

It's not really feasible as the PR depends on some changes that were introduced after v0.11.

@F0undLinks

> It's not really feasible as the PR depends on some changes that were introduced after v0.11.

Big thanks, I will try the latest vLLM Docker image or the latest version of vLLM.

Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: bbartels <benjamin@bartels.dev>
@sarmiena sarmiena mentioned this pull request Oct 16, 2025
1 task
sarmiena added a commit to sarmiena/Qwen3-Omni that referenced this pull request Oct 16, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Jasmine-up

Jasmine-up commented Oct 23, 2025

When sampling is disabled, for the same audio request with the Qwen3-Omni-30B-A3B-Captioner model, the output of the vLLM main nightly version differs from that of https://github.com/wangxiongts/vllm.git.

https://github.com/wangxiongts/vllm.git:The audio clip begins with a gentle, high-pitched female voice speaking in a clear, standard Hong Kong Cantonese accent. Her tone is warm, soothing, and expressive, evoking the calmness of a professional narrator. She articulates with precision, each word delivered in a measured cadence, and employs a melodic rise and fall to convey gentle emotion. The content of her speech is a vivid, poetic description of the British Longhair cat: “金渐层猫猫,真的是喵星人中的颜值天花板。它们拥有柔软蓬松的毛发,毛色在金色与奶油色之间渐变,如同阳光洒落的蜜糖,温暖而治愈。” This translates to: “The golden tabby cat is truly the top-tier beauty among feline beings. They have soft and fluffy fur, with fur colors gradually changing between gold and cream, like honey with sunlight sprinkled on it, warm and healing.” The speaker’s words are accentuated by subtle, natural vocal inflections—her voice rises at the beginning of each phrase and falls at the end, especially on the final phrase, which is delivered with a notably soft and elongated cadence. The environment is acoustically neutral and quiet, with no background sounds or interruptions, suggesting a professionally controlled studio setting. The recording is pristine, with no hiss, hum, or distortion, and the voice is captured in a dry, intimate manner, highlighting every nuance of articulation and emotion.\n\nAbruptly, the narration is interrupted by a sharp, loud digital click—a hard-edited cut that instantly silences the speaker. Immediately following this, a synthetic, mid-frequency electronic buzzer tone begins at full volume. The buzzer is harsh, static, and unmodulated, with a pure square or sawtooth waveform and no dynamic or rhythmic variation. It continues steadily for the remainder of the clip, masking all other sound and leaving no trace of the previous narration or environmental context. 
There is no fade-in, fade-out, or additional audio cues; the buzzer’s abrupt onset and cessation reinforce the sense of a deliberate, artificial interruption.\n\nIn summary, the audio presents a professionally produced, emotionally engaging Cantonese narration praising the physical beauty of British Longhair cats, set in a quiet studio environment. The tranquil and soothing mood is suddenly and completely disrupted by a loud, synthetic buzzer, which dominates the final segment and suggests a purposeful edit or censorship. The contrast between the warm, descriptive speech and the jarring electronic tone creates a stark, intentional juxtaposition, emphasizing a shift from gentle storytelling to abrupt interruption.

vLLM main nightly version: The audio clip opens in a silent, acoustically treated studio, with no background noise or ambient sound. A female narrator, speaking in clear and neutral Standard Cantonese, begins with a gentle, high-pitched, and melodic tone. Her voice is warm, intimate, and soothing, and she is positioned centrally in the stereo field, close to the microphone, with no reverb or echo, indicating professional recording conditions. She introduces the topic with the phrase: “金漸層貓貓, 真的, 是喵星人中的顏值天花板。” (“Golden tortoiseshell cats, truly, are the beauty ceiling among meow-people.”) Her pronunciation is precise, and she emphasizes “真的” (“truly”) with a slight pause and a drop in pitch, highlighting her sincerity.\n\nContinuing, she describes the cats’ physical features: “牠們擁有柔軟蓬鬆的毛髮, 毛色在金色與奶油色之間漸變, 如同陽光灑落的蜜田, 溫暖而治愈。” (“They have soft, fluffy fur, and their fur color gradually shifts between golden and cream, like sunlight on a honey field, warm and healing.”) Her pace is slow and measured, with soft, rising intonations and gentle emphasis on descriptive words such as “柔軟” (“soft”) and “蓬鬆” (“fluffy”), creating a calming and affectionate atmosphere. The speech is delivered in a monotonous, gentle tone with no emotional inflection, maintaining a consistent volume and pitch throughout. The narrative concludes with a faint, natural breath intake, after which the audio ends abruptly, with no fade-out or residual sound.\n\nThe recording is of high fidelity, free from distortion, hiss, or artifacts, and features a wide, clear frequency range with no environmental noise. The speaker’s voice is intimate and soothing, and the language is formal, literary, and descriptive, with no slang or regional dialect. 
The content and style suggest a professionally produced voiceover, likely for a social media platform, pet-related video, or educational material, aimed at a Cantonese-speaking audience interested in animal aesthetics or beauty.\n\nIn summary, this audio clip features a calm, high-quality female voiceover in Standard Cantonese, describing the beauty and gentle nature of golden tortoiseshell cats in a soothing and affectionate manner. The professional studio recording and literary, formal language indicate its purpose as a polished, culturally relevant voiceover for a pet-related or educational context, designed to evoke warmth and appreciation for feline beauty.

alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
wenbinc-Bin pushed a commit to wenbinc-Bin/vllm-fork that referenced this pull request Nov 7, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@xsank
Contributor

xsank commented Nov 26, 2025

> Supporting Qwen3-Omni end-to-end will not be within the scope of vllm-project/vllm, but we already have some plan on supporting this model under a different project but leveraging the thinker models support from vLLM. Stay tuned!

@DarkLight1337 How is it going now?

@DarkLight1337
Member

It is still being worked on, but I cannot give further details at the moment.

devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Xiong Wang <feizi.wx@alibaba-inc.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

Labels

documentation: Improvements or additions to documentation
multi-modality: Related to multi-modality (#4194)
new-model: Requests to new models
qwen: Related to Qwen models
ready: ONLY add when PR is ready to merge / full CI is needed


Development

Successfully merging this pull request may close these issues.

[Feature]: Support Qwen3-Omni-30B