From 072b9f055f8e6e58ec94e342762026a32de8c284 Mon Sep 17 00:00:00 2001 From: yuanheng Date: Wed, 25 Feb 2026 15:13:55 +0800 Subject: [PATCH 1/3] fix: apply utils fa flash_attn_varlen_func Signed-off-by: yuanheng --- .../qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py index c1c3c987a0f..5e2589b5b14 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py @@ -23,19 +23,7 @@ import torch.nn.functional as F from torch import Tensor, nn -try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func -except ImportError: - try: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func - except ImportError: - print( - "\n********\nWarning: flash-attn is not installed. " - "Will only run the manual PyTorch version. " - "Please install flash-attn for faster inference.\n********\n " - ) - flash_attn_varlen_func = None - +from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func N_FFT = 400 HOP_LENGTH = 160 From 2abe9953c4a53d147a14f2e0453df1d14d83cf5a Mon Sep 17 00:00:00 2001 From: yuanheng Date: Wed, 25 Feb 2026 15:24:14 +0800 Subject: [PATCH 2/3] fix branching Signed-off-by: yuanheng --- .../tokenizer_25hz/vq/whisper_encoder.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py index 5e2589b5b14..e9193115834 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py @@ -23,7 +23,7 @@ import torch.nn.functional as F from torch import Tensor, nn -from vllm_omni.diffusion.attention.backends.utils.fa import flash_attn_varlen_func +from vllm_omni.diffusion.attention.backends.utils.fa import HAS_FLASH_ATTN, flash_attn_varlen_func N_FFT = 400 HOP_LENGTH = 160 @@ -140,7 +140,7 @@ def forward(self, x: Tensor) -> Tensor: class MultiHeadAttention(nn.Module): - def __init__(self, n_state: int, n_head: int): + def __init__(self, n_state: int, n_head: int, use_flash_attention: bool = True): super().__init__() self.n_head = n_head self.query = Linear(n_state, n_state) @@ -148,7 +148,7 @@ def __init__(self, n_state: int, n_head: int): self.value = Linear(n_state, n_state) self.out = Linear(n_state, n_state) - self.use_flash_attention = True + self.use_flash_attention = use_flash_attention and HAS_FLASH_ATTN def forward( self, @@ -159,15 +159,8 @@ def forward( k = self.key(x) v = self.value(x) - if self.use_flash_attention: - if flash_attn_varlen_func is None: - x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens) - else: - if q.dtype not in [torch.float16, torch.bfloat16]: - x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens) - self.use_flash_attention = False - else: - x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens) + if self.use_flash_attention and q.dtype in [torch.float16, torch.bfloat16]: + x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens) else: x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens) From eb363aee94f2d7b5d9c89390d88dcb9657d9fda9 Mon Sep 17 00:00:00 2001 From: yuanheng Date: Wed, 25 Feb 2026 17:12:59 +0800 Subject: [PATCH 3/3] rm qwen3 tts modeling legacy arg Signed-off-by: yuanheng --- .../configuration_qwen3_tts_tokenizer_v1.py | 4 ---- .../modeling_qwen3_tts_tokenizer_v1.py | 2 -- .../qwen3_tts/tokenizer_25hz/vq/speech_vq.py | 4 ---- .../tokenizer_25hz/vq/whisper_encoder.py | 24 ++----------------- 4 files changed, 2 insertions(+), 32 deletions(-) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py index 74272f936ce..74775591dc4 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py @@ -214,8 +214,6 @@ class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig): output_dim (`int`, *optional*, defaults to 3584): Output feature dimension produced by the encoder head (before/after projection, implementation-dependent). - grad_checkpointing (`bool`, *optional*, defaults to `False`): - Whether to enable gradient checkpointing to reduce memory usage during training. enable_mp (`bool`, *optional*, defaults to `False`): Whether to enable model parallel features (implementation-dependent). audio_sequence_parallel (`bool`, *optional*, defaults to `False`): @@ -246,7 +244,6 @@ def __init__( n_layer=32, n_window=100, output_dim=3584, - grad_checkpointing=False, enable_mp=False, audio_sequence_parallel=False, audio_vq_type="GRVQ", @@ -265,7 +262,6 @@ def __init__( self.n_layer = n_layer self.n_window = n_window self.output_dim = output_dim - self.grad_checkpointing = grad_checkpointing self.enable_mp = enable_mp self.audio_sequence_parallel = audio_sequence_parallel self.audio_vq_type = audio_vq_type diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py index bceafc98e39..cf4622fac57 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py @@ -1297,8 +1297,6 @@ def __init__(self, config: Qwen3TTSTokenizerV1EncoderConfig): n_layer=config.n_layer, n_window=config.n_window, output_dim=config.output_dim, - grad_checkpointing=config.grad_checkpointing, - enable_mp=config.enable_mp, audio_sequence_parallel=config.audio_sequence_parallel, audio_vq_type=config.audio_vq_type, audio_vq_layers=config.audio_vq_layers, diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py index 805feb81fea..de2c69702c5 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/speech_vq.py @@ -196,8 +196,6 @@ def __init__( n_layer: int, n_window: int = 1500, output_dim: int = 512, - grad_checkpointing: bool = False, - enable_mp: bool = False, audio_sequence_parallel: bool = False, audio_vq_layers: int = -1, audio_vq_type: str = "NULL", @@ -219,8 +217,6 @@ def __init__( n_layer, n_window, output_dim, - grad_checkpointing, - enable_mp, audio_sequence_parallel, ) diff --git a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py index e9193115834..e3bd6e1c3a3 100644 --- a/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py +++ b/vllm_omni/model_executor/models/qwen3_tts/tokenizer_25hz/vq/whisper_encoder.py @@ -230,7 +230,7 @@ def qkv_attention_manual(self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tens class ResidualAttentionBlock(nn.Module): - def __init__(self, n_state: int, n_head: int, enable_mp: bool = False, sequence_parallel: bool = False): + def __init__(self, n_state: int, n_head: int, sequence_parallel: bool = False): super().__init__() n_mlp = n_state * 4 self.attn_ln = nn.LayerNorm(n_state) @@ -255,8 +255,6 @@ def __init__( n_layer: int, n_window: int = 1500, output_dim: int = 512, - grad_checkpointing: bool = False, - enable_mp: bool = False, audio_sequence_parallel: bool = False, ): super().__init__() @@ -267,10 +265,7 @@ def __init__( self.n_mels = n_mels self.blocks = nn.ModuleList( - [ - ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel) - for _ in range(n_layer) - ] + [ResidualAttentionBlock(n_state, n_head, sequence_parallel=audio_sequence_parallel) for _ in range(n_layer)] ) self.ln_post = nn.LayerNorm(n_state) self.avg_pooler = nn.AvgPool1d(2, stride=2) @@ -280,8 +275,6 @@ def __init__( self.audio_bos_eos_token = nn.Embedding(2, output_dim) self.output_dim = output_dim - self.grad_checkpointing = grad_checkpointing - self.enable_mp = enable_mp self.n_head = n_head self.n_state = n_state self.n_window = n_window @@ -290,13 +283,6 @@ def __init__( self.tp_world_size = 1 - self.set_audio_sync() - - def set_audio_sync(self): - for name, param in self.named_parameters(): - if not name.startswith("blocks"): - setattr(param, "audio_sync", True) - def forward( self, x_list: list[Tensor], audio_mellens: list[int], audio_aftercnnlens: list[int], audio_seqlens: list[int] ): @@ -358,9 +344,3 @@ def forward( output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype) output[audio_tokens_mask] = x return output - - def lock(self, layers: int): - self.conv1.requires_grad_(False) - self.conv2.requires_grad_(False) - for i in range(min(layers, len(self.blocks))): - self.blocks[i].requires_grad_(False)