8 changes: 2 additions & 6 deletions vllm/model_executor/models/qwen3_5.py
@@ -221,12 +221,8 @@ def forward(
             ba, _ = self.in_proj_ba(hidden_states)
             z, _ = self.in_proj_z(hidden_states)
         else:
-            mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
-                hidden_states,
-                sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
-                sum(self.in_proj_ba.output_sizes) // self.tp_size,
-                self.prefix,
-            )
+            mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
+            ba, _ = self.in_proj_ba(hidden_states)
         qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
         z_size = self.value_dim // self.tp_size
         mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
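Note: this hunk replaces the `torch.ops.vllm.gdn_in_proj` custom-op dispatch with direct calls to the two projection layers, after which `mixed_qkvz` is split into its qkv and z halves. A minimal sketch of that projection-and-split step, assuming plain `nn.Linear` stand-ins for vLLM's parallel layers and purely illustrative dimensions:

```python
# Minimal sketch of the inline projection + split, assuming plain nn.Linear
# stand-ins for vLLM's parallel layers; hidden_size, key_dim, value_dim,
# tp_size, and the ba width are illustrative values, not the model's.
import torch
import torch.nn as nn

hidden_size, key_dim, value_dim, tp_size = 64, 32, 48, 1
in_proj_qkvz = nn.Linear(hidden_size, (key_dim * 2 + value_dim * 2) // tp_size)
in_proj_ba = nn.Linear(hidden_size, 8 // tp_size)  # b and a, hypothetical width

hidden_states = torch.randn(5, hidden_size)
mixed_qkvz = in_proj_qkvz(hidden_states)  # vLLM's layers also return a bias term
ba = in_proj_ba(hidden_states)            # hence the ", _ =" unpacking above

qkv_size = (key_dim * 2 + value_dim) // tp_size  # q, k, v concatenated
z_size = value_dim // tp_size                    # gating branch
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
assert mixed_qkv.shape[-1] + z.shape[-1] == mixed_qkvz.shape[-1]
```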
64 changes: 3 additions & 61 deletions vllm/model_executor/models/qwen3_next.py
@@ -81,11 +81,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
 from vllm.triton_utils import tl, triton
-from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
-from vllm.utils.torch_utils import (
-    aux_stream,
-    direct_register_custom_op,
-)
+from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

@@ -421,12 +417,6 @@ def __init__(
         self.act = ACT2FN[config.hidden_act]
         self.layer_norm_epsilon = config.rms_norm_eps
         self.prefix = prefix
-        self.aux_stream = aux_stream()
-        self.events = (
-            [torch.cuda.Event(), torch.cuda.Event()]
-            if current_platform.is_cuda_alike()
-            else [None, None]
-        )
 
         self.config = config
         self.model_config = vllm_config.model_config
@@ -659,12 +649,8 @@ def forward(
         # ============================================================
         # Part 1: Input Projection
         # ============================================================
-        projected_states_qkvz, projected_states_ba = torch.ops.vllm.gdn_in_proj(
-            hidden_states,
-            sum(self.in_proj_qkvz.output_sizes) // self.tp_size,
-            sum(self.in_proj_ba.output_sizes) // self.tp_size,
-            self.prefix,
-        )
+        projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
+        projected_states_ba, _ = self.in_proj_ba(hidden_states)
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba
         )
@@ -804,18 +790,6 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
 
         torch.accelerator.empty_cache()
 
-    def _forward_in_proj(
-        self, hidden_states: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
-            lambda: self.in_proj_qkvz(hidden_states)[0],
-            lambda: self.in_proj_ba(hidden_states)[0],
-            self.events[0],
-            self.events[1],
-            self.aux_stream,
-        )
-        return projected_states_qkvz, projected_states_ba
-
     def _forward_core(
         self,
         mixed_qkv: torch.Tensor,
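Note: the `_forward_in_proj` helper deleted here overlapped the two projections on separate CUDA streams through `maybe_execute_in_parallel`, using the `aux_stream`/`events` fields removed from `__init__` above. A rough sketch of that two-stream pattern, assuming a hypothetical `run_in_parallel` helper rather than vLLM's own:

```python
# Rough sketch of the two-stream overlap that the deleted _forward_in_proj
# relied on; run_in_parallel is a hypothetical helper, not vLLM's
# maybe_execute_in_parallel.
import torch

def run_in_parallel(fn_a, fn_b, aux_stream: torch.cuda.Stream):
    main = torch.cuda.current_stream()
    ready = torch.cuda.Event()
    done = torch.cuda.Event()
    ready.record(main)                # inputs are produced on the main stream
    out_a = fn_a()                    # branch A stays on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(ready)  # don't read inputs before they exist
        out_b = fn_b()                # branch B overlaps on the aux stream
        done.record(aux_stream)
    main.wait_event(done)             # consumers wait for branch B to finish
    return out_a, out_b

# Usage in the spirit of the deleted code (proj_qkvz/proj_ba hypothetical):
# qkvz, ba = run_in_parallel(lambda: proj_qkvz(x), lambda: proj_ba(x),
#                            torch.cuda.Stream())
```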
@@ -1697,32 +1671,6 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
 
 
-def gdn_in_proj(
-    hidden_states: torch.Tensor,
-    qkvz_output_size: int,
-    ba_output_size: int,
-    layer_name: str,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Custom op for the input projection.
-    """
-    forward_context: ForwardContext = get_forward_context()
-    self = forward_context.no_compile_layers[layer_name]
-    return self._forward_in_proj(hidden_states)
-
-
-def gdn_in_proj_fake(
-    hidden_states: torch.Tensor,
-    qkvz_output_size: int,
-    ba_output_size: int,
-    layer_name: str,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """Fake implementation for torch.compile."""
-    return hidden_states.new_empty(
-        hidden_states.shape[0], qkvz_output_size
-    ), hidden_states.new_empty(hidden_states.shape[0], ba_output_size)
-
-
 def gdn_attention_core(
     mixed_qkv: torch.Tensor,
     b: torch.Tensor,
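Note: the deleted pair above follows the usual real/fake custom-op contract: the real op recovers the layer object from the forward context by name, while the fake impl only allocates correctly shaped outputs so torch.compile can trace past the opaque call. A sketch of the same pairing using the stock `torch.library.custom_op` API (vLLM's `direct_register_custom_op` performs an equivalent registration); the `demo::` namespace and zero-filled body are illustrative:

```python
# Sketch of the real-impl/fake-impl pairing via the stock torch.library API;
# demo:: and the zero-filled body are illustrative. The real vLLM op instead
# fetched the layer from the forward context and ran its projections.
from typing import Tuple

import torch

@torch.library.custom_op("demo::gdn_in_proj", mutates_args=())
def demo_gdn_in_proj(
    hidden_states: torch.Tensor, qkvz_output_size: int, ba_output_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in computation; shapes must match the fake impl below.
    return (
        hidden_states.new_zeros(hidden_states.shape[0], qkvz_output_size),
        hidden_states.new_zeros(hidden_states.shape[0], ba_output_size),
    )

@demo_gdn_in_proj.register_fake
def _(hidden_states, qkvz_output_size, ba_output_size):
    # Only shapes and dtypes matter while torch.compile traces the graph.
    return (
        hidden_states.new_empty(hidden_states.shape[0], qkvz_output_size),
        hidden_states.new_empty(hidden_states.shape[0], ba_output_size),
    )
```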
@@ -1756,12 +1704,6 @@ def gdn_attention_core_fake(
     return
 
 
-direct_register_custom_op(
-    op_name="gdn_in_proj",
-    op_func=gdn_in_proj,
-    fake_impl=gdn_in_proj_fake,
-)
-
 direct_register_custom_op(
     op_name="gdn_attention_core",
     op_func=gdn_attention_core,
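Note: only the `gdn_in_proj` registration is removed; `gdn_attention_core` stays registered, and since it writes into pre-allocated outputs, both its real and fake impls return nothing. A small sketch of registering and calling a mutating op of that shape, with a made-up `demo::scale_inplace` op:

```python
# Sketch of a mutating custom op whose fake impl just returns, mirroring the
# gdn_attention_core / gdn_attention_core_fake pair; demo::scale_inplace is
# made up for illustration.
import torch

@torch.library.custom_op("demo::scale_inplace", mutates_args=("x",))
def demo_scale_inplace(x: torch.Tensor, factor: float) -> None:
    x.mul_(factor)  # the result lands in the caller's pre-allocated tensor

@demo_scale_inplace.register_fake
def _(x, factor):
    return  # nothing to trace; the op stays opaque but correctly ordered

t = torch.ones(3)
torch.ops.demo.scale_inplace(t, 2.0)  # registered ops are called via torch.ops
assert torch.equal(t, torch.full((3,), 2.0))
```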