73 changes: 71 additions & 2 deletions vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -56,10 +56,12 @@
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
from vllm.triton_utils import tl, triton
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
_resolve_layer_name,
aux_stream,
direct_register_custom_op,
)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
@@ -275,6 +277,19 @@ def __init__(
self.forward_xpu if current_platform.is_xpu() else self.forward_cuda
)

# Dual-stream dispatch for the two input projections (in_proj_qkvz and
# in_proj_ba). Re-enabled via the LayerName opaque type (PR #38123) so
# torch.compile no longer regresses cold compile times.
self.aux_stream = aux_stream()
self.events = (
[
torch.cuda.Event(enable_timing=False),
torch.cuda.Event(enable_timing=False),
]
if self.aux_stream is not None
else [None, None]
)
Comment on lines +284 to +291
Contributor

high

The initialization of self.events uses current_platform.is_cuda(), which returns False on ROCm (HIP) platforms. However, self.aux_stream is initialized using aux_stream(), which returns a valid stream on ROCm (as it checks is_cuda_alike()). This mismatch will cause a crash in maybe_execute_in_parallel when it attempts to call .record() on a None event.

Additionally, for synchronization events where timing is not required, it is recommended to use enable_timing=False to reduce overhead.

Suggested change
-        self.events = (
-            [torch.cuda.Event(), torch.cuda.Event()]
-            if current_platform.is_cuda()
-            else [None, None]
-        )
+        self.events = (
+            [torch.cuda.Event(enable_timing=False),
+             torch.cuda.Event(enable_timing=False)]
+            if self.aux_stream is not None else [None, None]
+        )
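For reference, here is a minimal sketch (an assumption, not vLLM's actual implementation) of the dual-stream pattern a helper like `maybe_execute_in_parallel` follows; it also illustrates the sequential fallback described in `_forward_in_proj` below. The `event_a.record()` call on the main stream is where a `None` event would crash:

```python
import torch

def dual_stream_sketch(fn_a, fn_b, event_a, event_b, aux_stream):
    # Sequential fallback when no side stream is available (non-CUDA-alike platforms).
    if aux_stream is None:
        return fn_a(), fn_b()
    main_stream = torch.cuda.current_stream()
    # Mark the point the side stream must wait for; this is the .record() that
    # fails if the event is None while aux_stream is not.
    event_a.record(main_stream)
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_a)
        out_b = fn_b()
        event_b.record(aux_stream)
    out_a = fn_a()  # launched on the main stream, overlapping with fn_b
    # The main stream must not consume out_b before the side stream finishes.
    main_stream.wait_event(event_b)
    return out_a, out_b
```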


# QKV
self.conv_dim = self.key_dim * 2 + self.value_dim
self.conv1d = ColumnParallelLinear(
@@ -541,8 +556,12 @@ def forward_cuda(
b = b.contiguous()
a = a.contiguous()
else:
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
ba, _ = self.in_proj_ba(hidden_states)
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
self.in_proj_qkvz.weight.shape[0],
self.in_proj_ba.weight.shape[0],
_encode_layer_name(self.prefix),
)

if self.gqa_interleaved_layout:
# Qwen3-Next: unpack the interleaved GQA layout
@@ -594,6 +613,23 @@
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)

def _forward_in_proj(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run in_proj_qkvz and in_proj_ba in parallel on two CUDA streams.

Falls back to sequential execution when aux_stream is unavailable
(e.g. non-CUDA platforms).
"""
projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
lambda: self.in_proj_qkvz(hidden_states)[0],
lambda: self.in_proj_ba(hidden_states)[0],
self.events[0],
self.events[1],
self.aux_stream,
)
return projected_states_qkvz, projected_states_ba

def forward_xpu(
self,
hidden_states: torch.Tensor,
@@ -1063,6 +1099,39 @@ def _forward_core_decode_non_spec(
return


def gdn_in_proj(
hidden_states: torch.Tensor,
qkvz_output_size: int,
ba_output_size: int,
layer_name: LayerNameType,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Custom op wrapping the dual-stream input projection."""
layer_name = _resolve_layer_name(layer_name)
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self._forward_in_proj(hidden_states)


def gdn_in_proj_fake(
hidden_states: torch.Tensor,
qkvz_output_size: int,
ba_output_size: int,
layer_name: LayerNameType,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fake implementation for torch.compile."""
return (
hidden_states.new_empty(hidden_states.shape[0], qkvz_output_size),
hidden_states.new_empty(hidden_states.shape[0], ba_output_size),
)


direct_register_custom_op(
op_name="gdn_in_proj",
op_func=gdn_in_proj,
fake_impl=gdn_in_proj_fake,
)
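The fake implementation only has to report output shapes and dtypes: under torch.compile the registered op is traced with fake tensors, which is why the two output-size arguments are threaded through even though the eager path ignores them and instead resolves the layer from the forward context. A quick shape-only illustration (the sizes below are made up, and `layer_name` is unused by the fake implementation, so `None` stands in here):

```python
import torch

hs = torch.empty(16, 2048)
qkvz, ba = gdn_in_proj_fake(hs, qkvz_output_size=8192, ba_output_size=96, layer_name=None)
assert qkvz.shape == (16, 8192) and ba.shape == (16, 96)
```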


def gdn_attention_core(
mixed_qkv: torch.Tensor,
b: torch.Tensor,