77 changes: 10 additions & 67 deletions vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -524,7 +524,6 @@ def fix_query_key_value_ordering(
 
         return query, key, value, z, b, a
 
-    @torch.compile(fullgraph=True)
     def prepare_gdn_attention_core_inputs(
         self,
         mixed_qkvz: torch.Tensor,
@@ -586,67 +585,23 @@ def prepare_gdn_attention_core_inputs(
         (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1)
         (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=-1)
 
-        mixed_qkv_logical = torch.cat(
+        mixed_qkv_out = torch.cat(
             [
                 query.reshape(num_tokens, -1),
                 key.reshape(num_tokens, -1),
                 value.reshape(num_tokens, -1),
             ],
             dim=-1,
         )
 
-        # The split above produces non-contiguous views into the interleaved
-        # buffer. Concatenating everything into a single flat tensor forces a
-        # contiguous copy, then slicing back out gives contiguous q/k/v/z/b/a
-        # tensors that downstream kernels require. Doing this in one cat+slice
-        # keeps torch.compile in a single Triton graph instead of emitting
-        # separate copy kernels per tensor. The original code used
-        # rearrange(...).contiguous() on each tensor individually.
-        fused = torch.cat(
-            [
-                mixed_qkv_logical.reshape(-1),
-                z.reshape(-1),
-                b.reshape(-1),
-                a.reshape(-1),
-            ],
-            dim=0,
-        )
-
-        curr = 0
-        qkv_numel = mixed_qkv_logical.numel()
-        z_numel = z.numel()
-        b_numel = b.numel()
-        a_numel = a.numel()
-
-        mixed_qkv_out = fused[curr : curr + qkv_numel].view(num_tokens, -1)
-        curr += qkv_numel
-
-        z_out = fused[curr : curr + z_numel].view(
-            num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim
-        )
-        curr += z_numel
-
-        b_out = fused[curr : curr + b_numel].view(
-            num_tokens, self.num_v_heads // self.tp_size
-        )
-        curr += b_numel
-
-        a_out = fused[curr : curr + a_numel].view(
-            num_tokens, self.num_v_heads // self.tp_size
-        )
+        num_v_per_tp = self.num_v_heads // self.tp_size
+        z_out = z.reshape(num_tokens, num_v_per_tp, self.head_v_dim)
+        b_out = b.reshape(num_tokens, num_v_per_tp)
+        a_out = a.reshape(num_tokens, num_v_per_tp)

         return mixed_qkv_out, z_out, b_out, a_out
 
-    @torch.compile(fullgraph=True)
     def rearrange_mixed_qkv(self, mixed_qkv):
-        """Split packed qkv into contiguous (1, seq, heads, dim) tensors.
-
-        The original code used ``rearrange(x, "l (h d) -> 1 l h d", d=...)``
-        followed by ``.contiguous()`` on each tensor. This version flattens
-        all three splits into a single buffer via ``torch.cat`` so that
-        torch.compile emits one Triton copy kernel instead of three separate
-        contiguous() calls.
-        """
+        """Split packed qkv into contiguous (1, seq, heads, dim) tensors."""
         if mixed_qkv is None:
             return None, None, None

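A note on the hunk above: the deleted comment describes how `torch.split` along the packed last dimension yields non-contiguous views, which is why the old code routed everything through a single `torch.cat` before slicing. A minimal toy sketch of that behavior (simplified packing and shapes, not the actual vLLM buffers):

```python
import torch

num_tokens, num_heads, head_dim = 4, 2, 8

# Toy packed [q | k | v | z] buffer. Splitting its last dim returns views
# whose rows stride across the full buffer, so they are not contiguous.
mixed_qkvz = torch.randn(num_tokens, 4 * num_heads * head_dim)
q, k, v, z = torch.split(mixed_qkvz, [num_heads * head_dim] * 4, dim=-1)
print(q.is_contiguous())  # False: a strided view into the packed buffer

# A single cat over the three views forces one contiguous copy, producing
# the flat (num_tokens, q_dim + k_dim + v_dim) layout passed downstream.
mixed_qkv = torch.cat(
    [
        q.reshape(num_tokens, -1),
        k.reshape(num_tokens, -1),
        v.reshape(num_tokens, -1),
    ],
    dim=-1,
)
print(mixed_qkv.is_contiguous())  # True
```

The new code keeps exactly this cat for q/k/v and drops the extra flatten-and-slice step, leaving z, b, and a as plain per-head reshapes.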
@@ -656,24 +611,12 @@ def rearrange_mixed_qkv(self, mixed_qkv):
         v_dim = self.value_dim // self.tp_size
 
         query, key, value = torch.split(mixed_qkv, [q_dim, k_dim, v_dim], dim=-1)
-
-        fused = torch.cat(
-            [query.reshape(-1), key.reshape(-1), value.reshape(-1)], dim=0
+        return (
+            query.contiguous().view(1, seq_len, -1, self.head_k_dim),
+            key.contiguous().view(1, seq_len, -1, self.head_k_dim),
+            value.contiguous().view(1, seq_len, -1, self.head_v_dim),
         )
-
-        q_size = seq_len * q_dim
-        k_size = seq_len * k_dim
-
-        q_contig = fused[0:q_size]
-        k_contig = fused[q_size : q_size + k_size]
-        v_contig = fused[q_size + k_size :]
-
-        query = q_contig.view(1, seq_len, -1, self.head_k_dim)
-        key = k_contig.view(1, seq_len, -1, self.head_k_dim)
-        value = v_contig.view(1, seq_len, -1, self.head_v_dim)
-
-        return query, key, value
 
     def forward(
         self,
         hidden_states: torch.Tensor,
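The deleted docstring mentions that an earlier version built these tensors with `rearrange(x, "l (h d) -> 1 l h d", d=...)` followed by `.contiguous()`; the new return path uses `contiguous().view(...)` instead. A toy check that the two spellings agree (assumed shapes; einops is imported only for the comparison and is not needed by the new code):

```python
import torch
from einops import rearrange

seq_len, heads, head_dim = 5, 2, 4
q_dim = heads * head_dim

# Slice a query the way torch.split would: a non-contiguous view of the pack.
packed = torch.randn(seq_len, 3 * q_dim)
query = packed[:, :q_dim]

via_rearrange = rearrange(query, "l (h d) -> 1 l h d", d=head_dim).contiguous()
via_view = query.contiguous().view(1, seq_len, -1, head_dim)

assert via_view.is_contiguous()
assert torch.equal(via_rearrange, via_view)  # same values, same (1, l, h, d) layout
```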