Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions vllm/model_executor/models/qwen3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,16 @@ def forward(
# ============================================================
# Part 1: Input Projection
# ============================================================
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
self.in_proj_qkvz.weight.shape[0],
self.in_proj_ba.weight.shape[0],
self.prefix,
)
Comment on lines +183 to +188
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This regresses cold compile times by baking in a string into the compiled graph. We should really make a lint rule for this or something

Copy link
Copy Markdown
Contributor Author

@xyang16 xyang16 Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 Thanks for your comment. Looking into this. btw I was actually following the torch.ops.vllm.gdn_attention_core op in the same forward().

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.ops.vllm.gdn_attention_core is not included in the subgraph so it doesn't cause problems with compile times. I'm trying to figure out what to do with this. In theory we have a fix for this in PyTorch 2.11

Copy link
Copy Markdown
Contributor Author

@xyang16 xyang16 Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can revert using this torch.ops.vllm.gdn_in_proj op and wait for PyTorch 2.11. Please let me know what you think. Thanks!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xyang16 are you able to refactor this so that the gdn_in_proj op does NOT need to pass a string as an input? Basically we would avoid stashing state into a side table. How difficult do you think that would be?

Copy link
Copy Markdown
Contributor Author

@xyang16 xyang16 Mar 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 Sure, I will look into this today.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 I see your PR 38123. So it will fix this issue?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will fix the issue for PyTorch 2.11. But vLLM is going to do one more release (0.19.0, branch cut this Monday) without PyTorch 2.11.

If we can wait for the performance improvement in this PR, the easiest thing for us to do is just revert this PR and then re-merge it after #38123 lands and we upgrade to 2.11 (probably Tuesday)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 Thanks for the help. I have created #38152 to revert this PR. cc @benchislett

qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
z_size = self.value_dim // self.tp_size
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
z = z.reshape(z.size(0), -1, self.head_v_dim)
ba, _ = self.in_proj_ba(hidden_states)
b, a = ba.chunk(2, dim=-1)

b = b.contiguous()
Expand Down
64 changes: 61 additions & 3 deletions vllm/model_executor/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,11 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.utils.torch_utils import (
aux_stream,
direct_register_custom_op,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

Expand Down Expand Up @@ -419,6 +423,12 @@ def __init__(
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps
self.prefix = prefix
self.aux_stream = aux_stream()
self.events = (
[torch.cuda.Event(), torch.cuda.Event()]
if current_platform.is_cuda()
else [None, None]
)

self.config = config
self.model_config = model_config
Expand Down Expand Up @@ -647,8 +657,12 @@ def forward(
# ============================================================
# Part 1: Input Projection
# ============================================================
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
projected_states_ba, _ = self.in_proj_ba(hidden_states)
projected_states_qkvz, projected_states_ba = torch.ops.vllm.gdn_in_proj(
hidden_states,
self.in_proj_qkvz.weight.shape[0],
self.in_proj_ba.weight.shape[0],
self.prefix,
)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba
)
Expand Down Expand Up @@ -783,6 +797,18 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:

torch.accelerator.empty_cache()

def _forward_in_proj(
    self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the qkvz and ba input projections for *hidden_states*.

    The two projections are independent, so when an auxiliary CUDA stream
    is available they are overlapped via ``maybe_execute_in_parallel``;
    otherwise they run sequentially on the current stream.

    Returns:
        Tuple of (projected qkvz states, projected ba states).
    """

    # Each linear layer returns (output, bias-ish extra); only the
    # output tensor is needed here, hence the ``[0]``.
    def _project_qkvz() -> torch.Tensor:
        return self.in_proj_qkvz(hidden_states)[0]

    def _project_ba() -> torch.Tensor:
        return self.in_proj_ba(hidden_states)[0]

    # events[0]/events[1] order the auxiliary stream against the
    # current stream; aux_stream is None on non-CUDA platforms, which
    # makes this a plain sequential call.
    return maybe_execute_in_parallel(
        _project_qkvz,
        _project_ba,
        self.events[0],
        self.events[1],
        self.aux_stream,
    )

def _forward_core(
self,
mixed_qkv: torch.Tensor,
Expand Down Expand Up @@ -1670,6 +1696,32 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()


def gdn_in_proj(
    hidden_states: torch.Tensor,
    qkvz_output_size: int,
    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Custom op for the input projection.

    Looks up the GDN attention layer registered under *layer_name* in the
    current forward context and delegates to its ``_forward_in_proj``.
    The ``qkvz_output_size`` / ``ba_output_size`` arguments exist so the
    fake implementation can report correct output shapes to torch.compile.
    """
    forward_context: ForwardContext = get_forward_context()
    gdn_layer = forward_context.no_compile_layers[layer_name]
    return gdn_layer._forward_in_proj(hidden_states)
Comment on lines +1699 to +1710
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This indirection is pretty gross though. Could we avoid this somehow? I found it very confusing that you were passing in self.in_proj_qkvz.weight.shape[0] to this op instead of the module itself.

Also there is the concern of wrapping these MergedColumnParallelLinear modules that could be quantized - it seems we would lose the potential of torch.compile fusing the input quantization with previous ops or reaching inside of the linear op itself (less valid concern)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mgoin Thanks for review!

Actually this layer is already wrapped here https://github.com/vllm-project/vllm/blob/v0.18.0rc0/vllm/model_executor/models/qwen3_next.py#L1673-L1692. And I agree this should be improved once torch.compile supports multi stream.



def gdn_in_proj_fake(
    hidden_states: torch.Tensor,
    qkvz_output_size: int,
    ba_output_size: int,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile.

    Allocates empty outputs with the projected sizes so the compiler can
    trace shapes without running the real projection.
    """
    num_tokens = hidden_states.shape[0]
    qkvz_out = hidden_states.new_empty(num_tokens, qkvz_output_size)
    ba_out = hidden_states.new_empty(num_tokens, ba_output_size)
    return qkvz_out, ba_out


def gdn_attention_core(
mixed_qkv: torch.Tensor,
b: torch.Tensor,
Expand Down Expand Up @@ -1703,6 +1755,12 @@ def gdn_attention_core_fake(
return


# Register gdn_in_proj as a vLLM custom op so torch.compile treats the
# multi-stream input projection as an opaque call, using gdn_in_proj_fake
# for shape propagation during tracing.
direct_register_custom_op(
    op_name="gdn_in_proj",
    op_func=gdn_in_proj,
    fake_impl=gdn_in_proj_fake,
)

direct_register_custom_op(
op_name="gdn_attention_core",
op_func=gdn_attention_core,
Expand Down
48 changes: 48 additions & 0 deletions vllm/utils/multi_stream_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from typing import Any

import torch


def maybe_execute_in_parallel(
fn0: Callable[[], Any],
fn1: Callable[[], Any],
event0: torch.cuda.Event,
event1: torch.cuda.Event,
aux_stream: torch.cuda.Stream | None = None,
) -> tuple[Any, Any]:
"""Run two functions potentially in parallel on separate CUDA streams.

When aux_stream is provided, fn0 runs on the current (default) stream and
fn1 runs on aux_stream, synchronized via CUDA events. When aux_stream is
None, both functions execute sequentially on the current stream.
Comment on lines +10 to +21
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do like this utility as a pattern to apply generally


This design follows TensorRT-LLM's maybe_execute_in_parallel pattern
(tensorrt_llm/_torch/modules/multi_stream_utils.py).

Args:
fn0: Callable for the default stream.
fn1: Callable for the auxiliary stream.
event0: CUDA event recorded before fn0 so aux_stream can wait.
event1: CUDA event recorded after fn1 so default stream can wait.
aux_stream: The second CUDA stream for fn1.
Multi-stream is disabled when aux_stream is None.

Returns:
Tuple of (fn0_result, fn1_result).
"""
if aux_stream is not None:
event0.record()
result0 = fn0()
with torch.cuda.stream(aux_stream):
event0.wait()
result1 = fn1()
event1.record()
event1.wait()
else:
result0 = fn0()
result1 = fn1()
return (result0, result1)
Loading