9 changes: 2 additions & 7 deletions vllm/model_executor/layers/deepseek_compressor.py
@@ -14,7 +14,6 @@
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
)
from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
@@ -271,16 +270,12 @@ def __init__(

def forward(
self,
# [num_tokens, hidden_size]
x: torch.Tensor,
# [num_tokens, 2 * self.coff * self.head_dim]
kv_score: torch.Tensor,
# [num_tokens]
positions: torch.Tensor,
rotary_emb,
) -> None:
num_tokens, _ = x.shape
# bf16 weights/activations but fp32 output for numerical stability of
# the downstream compressor math.
kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight)
# Each of shape [num_tokens, coff * self.head_dim]
# inputs are bf16, outputs are fp32
kv, score = kv_score.split(
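With the in-layer GEMM removed, callers now precompute kv_score (see attn_gemm_parallel_execute below) and pass it into the compressor. As a minimal sketch of the bf16-in / fp32-out contract that call relies on, assuming an nn.Linear-style [out_features, in_features] weight layout; gemm_bf16_bf16_fp32_reference is a hypothetical stand-in, not vLLM's fused cublas_gemm_bf16_bf16_fp32 kernel:

import torch

def gemm_bf16_bf16_fp32_reference(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, hidden_size] in bf16; weight: [out_features, hidden_size] in bf16.
    # Returns [num_tokens, out_features] in fp32, mirroring the "bf16 weights/activations
    # but fp32 output" behavior the downstream compressor math depends on.
    assert x.dtype == torch.bfloat16 and weight.dtype == torch.bfloat16
    return torch.matmul(x.float(), weight.float().t())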
149 changes: 111 additions & 38 deletions vllm/model_executor/layers/deepseek_v4_attention.py
@@ -4,8 +4,9 @@
DeepseekV4 MLA Attention Layer
"""

from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

import torch
import torch.nn as nn
@@ -16,6 +17,7 @@
ReplicatedLinear,
)
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
from vllm.utils.deep_gemm import fp8_einsum
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.ops.deepseek_v4_ops import (
@@ -51,7 +53,10 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.utils.multi_stream_utils import (
execute_in_parallel,
maybe_execute_in_parallel,
)
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backends.mla.flashmla_sparse import (
DeepseekV4FlashMLASparseBackend,
@@ -94,7 +99,7 @@ class DeepseekV4MLAModules:
indexer: torch.nn.Module | None
indexer_rotary_emb: torch.nn.Module
topk_indices_buffer: torch.Tensor | None
aux_stream: torch.cuda.Stream | None = None
aux_stream_list: list[torch.cuda.Stream] | None = None


# --8<-- [start:multi_head_latent_attention]
@@ -217,8 +222,11 @@ def __init__(
+ 1 # 1B pad
)

self.aux_stream = mla_modules.aux_stream
self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
self.aux_stream_list = mla_modules.aux_stream_list
# [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
# [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
# before post-GEMM starts.
self.ln_events = [torch.cuda.Event() for _ in range(4)]

assert cache_config is not None, "DeepseekV4 attention requires cache_config"
self.swa_cache_layer = DeepseekV4SWACache(
@@ -277,9 +285,6 @@ def forward(
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
qr_kv, _ = self.fused_wqa_wkv(hidden_states)
qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)

# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
@@ -292,8 +297,6 @@
# Attention (inside custom op for torch.compile boundary)
torch.ops.vllm.deepseek_v4_attention(
hidden_states,
qr,
kv,
positions,
o_padded,
self.layer_name,
@@ -332,60 +335,132 @@

return self.wo_b(z.flatten(1))

def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
assert self.aux_stream_list is not None
assert len(self.aux_stream_list) >= 3

# fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs
# on aux streams 0..2 when their owning module exists. ln_events[0]
# is the fan-out start event; ln_events[1..3] are per-aux done events.
aux_fns: list[Callable[[], Any] | None] = [None, None, None]

if self.compressor is not None:
# Local ref so the closure keeps a non-None type for mypy.
compressor = self.compressor

def compressor_kv_score() -> torch.Tensor:
return cublas_gemm_bf16_bf16_fp32(
hidden_states, compressor.fused_wkv_wgate.weight
)

aux_fns[0] = compressor_kv_score

if self.indexer is not None:
indexer = self.indexer

def indexer_weights_proj() -> torch.Tensor:
# ReplicatedLinear returns (output, bias); bias is None.
weights, _ = indexer.weights_proj(hidden_states)
return weights

def indexer_compressor_kv_score() -> torch.Tensor:
return cublas_gemm_bf16_bf16_fp32(
hidden_states, indexer.compressor.fused_wkv_wgate.weight
)

aux_fns[1] = indexer_weights_proj
aux_fns[2] = indexer_compressor_kv_score

def fused_wqa_wkv() -> torch.Tensor:
# MergedColumnParallelLinear returns (output, bias); bias is None.
qr_kv, _ = self.fused_wqa_wkv(hidden_states)
return qr_kv

qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel(
fused_wqa_wkv,
aux_fns,
self.ln_events[0],
self.ln_events[1:4],
self.aux_stream_list[:3],
)

return qr_kv, kv_score, indexer_kv_score, indexer_weights

def attention_impl(
self,
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place
) -> None:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata

qr_kv, kv_score, indexer_kv_score, indexer_weights = (
self.attn_gemm_parallel_execute(hidden_states)
)

qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
qr, kv = fused_q_kv_rmsnorm(
qr,
kv,
self.q_norm.weight.data,
self.kv_norm.weight.data,
self.eps,
)
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)

# Overlap kv_insert with whichever of indexer/compressor is present.
# Indexer implies compressor; when both exist, compressor rides on the
# aux stream alongside kv_insert so the heavy indexer owns default.
# wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
# on the default stream so q stays on its consumer stream (mla_attn
# downstream reads q on default). Indexer/compressor go on aux for
# overlap with default's GEMM + cache write.
if self.indexer is not None:
assert self.aux_stream_list is not None
aux_stream = self.aux_stream_list[0]
indexer = self.indexer
# Local ref so the closure keeps a non-None type for mypy.
assert self.compressor is not None
compressor = self.compressor

def kv_insert_and_compress() -> None:
def wq_b_kv_insert_and_compress() -> torch.Tensor:
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
Contributor review comment (severity: high):
The ColumnParallelLinear layer self.wq_b typically returns a tuple (output, bias) in vLLM. Calling .view() directly on the result of self.wq_b(qr) will raise an AttributeError if it returns a tuple. Please ensure that you unpack the output tensor before calling .view(), similar to how it's done in DeepseekV4Indexer.forward at line 1137.
Suggested change:
- q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
+ q, _ = self.wq_b(qr)
+ q = q.view(-1, self.n_local_heads, self.head_dim)
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
compressor(hidden_states, positions, self.rotary_emb)

maybe_execute_in_parallel(
lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb),
kv_insert_and_compress,
compressor(kv_score, positions, self.rotary_emb)
return q

q, _ = maybe_execute_in_parallel(
wq_b_kv_insert_and_compress,
lambda: indexer(
hidden_states,
qr,
indexer_kv_score,
indexer_weights,
positions,
self.indexer_rotary_emb,
),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
aux_stream,
)
elif self.compressor is not None:
# Compressor on default, kv_insert on aux.
# wq_b + kv_insert on default, compressor on aux.
assert self.aux_stream_list is not None
aux_stream = self.aux_stream_list[0]
compressor = self.compressor
maybe_execute_in_parallel(
lambda: compressor(hidden_states, positions, self.rotary_emb),
lambda: self._fused_qnorm_rope_kv_insert(
q, kv, positions, attn_metadata
),

def wq_b_kv_insert() -> torch.Tensor:
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
Contributor review comment (severity: high):
Similar to the previous comment, self.wq_b(qr) likely returns a tuple (output, bias). You should unpack the output tensor before calling .view() to avoid a potential AttributeError.
Suggested change:
- q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
+ q, _ = self.wq_b(qr)
+ q = q.view(-1, self.n_local_heads, self.head_dim)
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
return q

q, _ = maybe_execute_in_parallel(
wq_b_kv_insert,
lambda: compressor(kv_score, positions, self.rotary_emb),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
aux_stream,
)
else:
# SWA-only layer: no compressor, no overlap.
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
Contributor review comment (severity: high):
Unpack the output tensor from self.wq_b(qr) before calling .view() to ensure compatibility with the expected return type of ColumnParallelLinear.
Suggested change:
- q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
+ q, _ = self.wq_b(qr)
+ q = q.view(-1, self.n_local_heads, self.head_dim)
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)

# Handle dummy run (no metadata).
@@ -455,21 +530,17 @@ def _fused_qnorm_rope_kv_insert(

def deepseek_v4_attention(
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.attention_impl(hidden_states, qr, kv, positions, out)
self.attention_impl(hidden_states, positions, out)


def deepseek_v4_attention_fake(
hidden_states: torch.Tensor,
qr: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
out: torch.Tensor,
layer_name: str,
@@ -1057,18 +1128,20 @@ def forward(
self,
hidden_states: torch.Tensor,
qr: torch.Tensor,
compressed_kv_score: torch.Tensor,
indexer_weights: torch.Tensor,
positions: torch.Tensor,
rotary_emb: nn.Module,
) -> torch.Tensor:
# ReplicatedLinear returns (output, bias); bias is None.
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim)
k = self.compressor(hidden_states, positions, rotary_emb)
weights, _ = self.weights_proj(hidden_states)
k = self.compressor(compressed_kv_score, positions, rotary_emb)
q_quant, weights = fused_indexer_q_rope_quant(
positions,
q,
rotary_emb.cos_sin_cache,
weights,
indexer_weights,
self.softmax_scale,
self.n_head**-0.5,
use_fp4=self.use_fp4_kv,
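The attn_gemm_parallel_execute path above hinges on execute_in_parallel fanning the three lighter input GEMMs out onto aux streams while the heaviest GEMM (fused_wqa_wkv) stays on the default stream. The sketch below shows that fan-out/fan-in pattern with plain CUDA streams and events; it assumes the helper runs main_fn on the current stream and each non-None aux_fn on its own aux stream, and vLLM's actual multi_stream_utils implementation may differ in details.

from collections.abc import Callable
from typing import Any

import torch

def execute_in_parallel_sketch(
    main_fn: Callable[[], Any],
    aux_fns: list[Callable[[], Any] | None],
    start_event: torch.cuda.Event,
    done_events: list[torch.cuda.Event],
    aux_streams: list[torch.cuda.Stream],
) -> tuple[Any, list[Any]]:
    # Fan out: every aux stream waits for the default stream to reach this
    # point so the aux GEMMs see up-to-date inputs.
    start_event.record()
    aux_results: list[Any] = [None] * len(aux_fns)
    for i, fn in enumerate(aux_fns):
        if fn is None:
            continue
        with torch.cuda.stream(aux_streams[i]):
            aux_streams[i].wait_event(start_event)
            aux_results[i] = fn()
            done_events[i].record()
    # The heaviest work stays on the default stream and overlaps the aux GEMMs.
    main_result = main_fn()
    # Fan in: the default stream waits on every per-aux done event before any
    # consumer reads the aux results.
    for i, fn in enumerate(aux_fns):
        if fn is not None:
            torch.cuda.current_stream().wait_event(done_events[i])
    return main_result, aux_results

In the wrapper above, the default-stream result is qr_kv and the aux results come back as kv_score, indexer_weights, and indexer_kv_score, matching the unpacking in attn_gemm_parallel_execute.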
22 changes: 10 additions & 12 deletions vllm/model_executor/models/deepseek_v4.py
@@ -54,7 +54,6 @@
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.multi_stream_utils import AuxStreamType
from vllm.utils.torch_utils import direct_register_custom_op

from .utils import (
@@ -872,7 +871,7 @@ def __init__(
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream: torch.cuda.Stream | None = None,
aux_stream_list: list[torch.cuda.Stream] | None = None,
):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -1005,7 +1004,7 @@ def __init__(
indexer=self.indexer,
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream=aux_stream,
aux_stream_list=aux_stream_list,
)
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
@@ -1041,7 +1040,7 @@ def __init__(
vllm_config,
prefix,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream] | None = None,
aux_stream_list: list[torch.cuda.Stream] | None = None,
):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -1052,9 +1051,7 @@
vllm_config,
prefix=f"{prefix}.attn",
topk_indices_buffer=topk_indices_buffer,
aux_stream=aux_stream_dict.get(AuxStreamType.Attention)
if aux_stream_dict is not None
else None,
aux_stream_list=aux_stream_list,
)
self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")

@@ -1182,10 +1179,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.hc_dim = self.hc_mult * config.hidden_size
self.rms_norm_eps = config.rms_norm_eps

aux_stream_list = [torch.cuda.Stream() for _ in range(1)]
self.aux_stream_dict = {
AuxStreamType.Attention: aux_stream_list[0],
}
# Three aux streams: one per non-default input GEMM in
# DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
# (compressor kv_score, indexer.weights_proj, indexer.compressor
# kv_score). fused_wqa_wkv stays on the default stream.
aux_stream_list = [torch.cuda.Stream() for _ in range(3)]

self.device = current_platform.device_type
# Reserved topk indices buffer for all Indexer layers to reuse.
@@ -1209,7 +1207,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config,
prefix=prefix,
topk_indices_buffer=self.topk_indices_buffer,
aux_stream_dict=self.aux_stream_dict,
aux_stream_list=aux_stream_list,
),
prefix=f"{prefix}.layers",
)
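The three aux streams created here feed both overlap phases in the attention wrapper: the input-GEMM fan-out (attn_gemm_parallel_execute) uses all three streams with ln_events[0..3], and the post-GEMM phase reuses aux stream 0 with ln_events[0..1], which is safe because the fan-out fully joins on the default stream before the second phase starts. A hedged sketch of the two-way helper used in that second phase, under the same assumptions as the fan-out sketch above (the real maybe_execute_in_parallel may differ):

from collections.abc import Callable
from typing import Any

import torch

def maybe_execute_in_parallel_sketch(
    default_fn: Callable[[], Any],
    aux_fn: Callable[[], Any],
    start_event: torch.cuda.Event,
    done_event: torch.cuda.Event,
    aux_stream: torch.cuda.Stream | None,
) -> tuple[Any, Any]:
    # Without an aux stream this degrades to plain sequential execution.
    if aux_stream is None:
        return default_fn(), aux_fn()
    start_event.record()
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(start_event)
        aux_result = aux_fn()
        done_event.record()
    # The first callable (wq_b + kv_insert, plus the MLA compressor when an
    # indexer exists) stays on the default stream, so q remains on the stream
    # that mla_attn reads from downstream.
    default_result = default_fn()
    torch.cuda.current_stream().wait_event(done_event)
    return default_result, aux_result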