diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py
index daca52821e0f..1ae8aa75ab62 100644
--- a/vllm/model_executor/models/qwen3_5.py
+++ b/vllm/model_executor/models/qwen3_5.py
@@ -221,12 +221,8 @@ def forward(
             ba, _ = self.in_proj_ba(hidden_states)
             z, _ = self.in_proj_z(hidden_states)
         else:
-            mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
-                hidden_states,
-                sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
-                sum(self.in_proj_ba.output_sizes) // self.tp_size,
-                self.prefix,
-            )
+            mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
+            ba, _ = self.in_proj_ba(hidden_states)
         qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
         z_size = self.value_dim // self.tp_size
         mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 5dfcd677b9a1..251c93364a9f 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -81,11 +81,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
 from vllm.triton_utils import tl, triton
-from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
-from vllm.utils.torch_utils import (
-    aux_stream,
-    direct_register_custom_op,
-)
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
@@ -421,12 +417,6 @@ def __init__(
         self.act = ACT2FN[config.hidden_act]
         self.layer_norm_epsilon = config.rms_norm_eps
         self.prefix = prefix
-        self.aux_stream = aux_stream()
-        self.events = (
-            [torch.cuda.Event(), torch.cuda.Event()]
-            if current_platform.is_cuda_alike()
-            else [None, None]
-        )
 
         self.config = config
         self.model_config = vllm_config.model_config
@@ -659,12 +649,8 @@ def forward(
         # ============================================================
         # Part 1: Input Projection
         # ============================================================
-        projected_states_qkvz, projected_states_ba = torch.ops.vllm.gdn_in_proj(
-            hidden_states,
-            sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
-            sum(self.in_proj_ba.output_sizes) // self.tp_size,
-            self.prefix,
-        )
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        projected_states_ba, _ = self.in_proj_ba(hidden_states)
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba
         )
@@ -804,18 +790,6 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
 
         torch.accelerator.empty_cache()
 
-    def _forward_in_proj(
-        self, hidden_states: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
-            lambda: self.in_proj_qkvz(hidden_states)[0],
-            lambda: self.in_proj_ba(hidden_states)[0],
-            self.events[0],
-            self.events[1],
-            self.aux_stream,
-        )
-        return projected_states_qkvz, projected_states_ba
-
     def _forward_core(
         self,
         mixed_qkv: torch.Tensor,
@@ -1697,32 +1671,6 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
 
 
-def gdn_in_proj(
-    hidden_states: torch.Tensor,
-    qkvz_output_size: int,
-    ba_output_size: int,
-    layer_name: str,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Custom op for the input projection.
-    """
-    forward_context: ForwardContext = get_forward_context()
-    self = forward_context.no_compile_layers[layer_name]
-    return self._forward_in_proj(hidden_states)
-
-
-def gdn_in_proj_fake(
-    hidden_states: torch.Tensor,
-    qkvz_output_size: int,
-    ba_output_size: int,
-    layer_name: str,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """Fake implementation for torch.compile."""
-    return hidden_states.new_empty(
-        hidden_states.shape[0], qkvz_output_size
-    ), hidden_states.new_empty(hidden_states.shape[0], ba_output_size)
-
-
 def gdn_attention_core(
     mixed_qkv: torch.Tensor,
     b: torch.Tensor,
@@ -1756,12 +1704,6 @@ def gdn_attention_core_fake(
     return
 
 
-direct_register_custom_op(
-    op_name="gdn_in_proj",
-    op_func=gdn_in_proj,
-    fake_impl=gdn_in_proj_fake,
-)
-
 direct_register_custom_op(
     op_name="gdn_attention_core",
     op_func=gdn_attention_core,