[DSV4] Enable Multi-stream for Pre-Attn GEMM #41061
Changes from all commits: 0edd8b8, f1158ec, d14835b, 974f6bd, a042cc1, 7bf590c, bdcb5f2
```diff
@@ -4,8 +4,9 @@
 DeepseekV4 MLA Attention Layer
 """
 
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import torch
 import torch.nn as nn
@@ -16,6 +17,7 @@
     ReplicatedLinear,
 )
 from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
+from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
 from vllm.utils.deep_gemm import fp8_einsum
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.ops.deepseek_v4_ops import (
@@ -51,7 +53,10 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
 )
-from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
+from vllm.utils.multi_stream_utils import (
+    execute_in_parallel,
+    maybe_execute_in_parallel,
+)
 from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     DeepseekV4FlashMLASparseBackend,
@@ -94,7 +99,7 @@ class DeepseekV4MLAModules:
     indexer: torch.nn.Module | None
     indexer_rotary_emb: torch.nn.Module
     topk_indices_buffer: torch.Tensor | None
-    aux_stream: torch.cuda.Stream | None = None
+    aux_stream_list: list[torch.cuda.Stream] | None = None
 
 
 # --8<-- [start:multi_head_latent_attention]
@@ -217,8 +222,11 @@ def __init__(
             + 1  # 1B pad
         )
 
-        self.aux_stream = mla_modules.aux_stream
-        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
+        self.aux_stream_list = mla_modules.aux_stream_list
+        # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
+        # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
+        # before post-GEMM starts.
+        self.ln_events = [torch.cuda.Event() for _ in range(4)]
 
         assert cache_config is not None, "DeepseekV4 attention requires cache_config"
         self.swa_cache_layer = DeepseekV4SWACache(
```
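As a reading aid for the event-reuse note in the `__init__` hunk above, here is a minimal hypothetical sketch (not part of the PR; shapes, names, and work are placeholders) of why recycling the same `torch.cuda.Event` objects across two consecutive overlap phases is race-free once the default stream has joined on all of them:

```python
import torch

# Hypothetical timeline: the same events serve two overlap phases because the
# default stream waits on every "done" event before phase 2 records on them.
x = torch.randn(128, 128, device="cuda")
w_main = torch.randn(128, 128, device="cuda")
w_aux = [torch.randn(128, 64, device="cuda") for _ in range(3)]

start, done0, done1, done2 = (torch.cuda.Event() for _ in range(4))
default = torch.cuda.current_stream()
aux_streams = [torch.cuda.Stream() for _ in range(3)]

# Phase 1: pre-attention GEMM fan-out.
start.record(default)
aux_outs = []
for s, w, done in zip(aux_streams, w_aux, (done0, done1, done2)):
    s.wait_event(start)
    with torch.cuda.stream(s):
        aux_outs.append(x @ w)      # light input GEMM on an aux stream
        done.record(s)
main_out = x @ w_main               # heaviest GEMM stays on the default stream
for done in (done0, done1, done2):
    default.wait_event(done)        # full join: all aux work ordered before here

# Phase 2 may now reuse `start` and `done0` for a different overlap, because
# every prior use of them is already ordered before this point on `default`.
start.record(default)
aux_streams[0].wait_event(start)
```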
```diff
@@ -277,9 +285,6 @@ def forward(
         hidden_states: torch.Tensor,
         llama_4_scaling: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        qr_kv, _ = self.fused_wqa_wkv(hidden_states)
-        qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
-
         # Pre-allocate attention output with FlashMLA-padded head count.
         # The op writes into `o_padded`; we slice to n_local_heads after.
         num_tokens = hidden_states.shape[0]
@@ -292,8 +297,6 @@ def forward(
         # Attention (inside custom op for torch.compile boundary)
         torch.ops.vllm.deepseek_v4_attention(
             hidden_states,
-            qr,
-            kv,
             positions,
             o_padded,
             self.layer_name,
@@ -332,60 +335,132 @@ def forward(
 
         return self.wo_b(z.flatten(1))
 
+    def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
+        assert self.aux_stream_list is not None
+        assert len(self.aux_stream_list) >= 3
+
+        # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs
+        # on aux streams 0..2 when their owning module exists. ln_events[0]
+        # is the fan-out start event; ln_events[1..3] are per-aux done events.
+        aux_fns: list[Callable[[], Any] | None] = [None, None, None]
+
+        if self.compressor is not None:
+            # Local ref so the closure keeps a non-None type for mypy.
+            compressor = self.compressor
+
+            def compressor_kv_score() -> torch.Tensor:
+                return cublas_gemm_bf16_bf16_fp32(
+                    hidden_states, compressor.fused_wkv_wgate.weight
+                )
+
+            aux_fns[0] = compressor_kv_score
+
+        if self.indexer is not None:
+            indexer = self.indexer
+
+            def indexer_weights_proj() -> torch.Tensor:
+                # ReplicatedLinear returns (output, bias); bias is None.
+                weights, _ = indexer.weights_proj(hidden_states)
+                return weights
+
+            def indexer_compressor_kv_score() -> torch.Tensor:
+                return cublas_gemm_bf16_bf16_fp32(
+                    hidden_states, indexer.compressor.fused_wkv_wgate.weight
+                )
+
+            aux_fns[1] = indexer_weights_proj
+            aux_fns[2] = indexer_compressor_kv_score
+
+        def fused_wqa_wkv() -> torch.Tensor:
+            # MergedColumnParallelLinear returns (output, bias); bias is None.
+            qr_kv, _ = self.fused_wqa_wkv(hidden_states)
+            return qr_kv
+
+        qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel(
+            fused_wqa_wkv,
+            aux_fns,
+            self.ln_events[0],
+            self.ln_events[1:4],
+            self.aux_stream_list[:3],
+        )
+
+        return qr_kv, kv_score, indexer_kv_score, indexer_weights
```
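The `execute_in_parallel` helper that this new method relies on comes from `vllm.utils.multi_stream_utils` and is not shown in this diff. As a rough mental model only, with the signature and semantics inferred from the call site above rather than taken from the actual implementation, it plausibly behaves like the sketch below: the main callable runs on the default stream, each non-None aux callable runs on its own stream after the start event, and all done events are joined before results are returned.

```python
from collections.abc import Callable
from typing import Any

import torch


def execute_in_parallel_sketch(
    main_fn: Callable[[], Any],
    aux_fns: list[Callable[[], Any] | None],
    start_event: torch.cuda.Event,
    done_events: list[torch.cuda.Event],
    aux_streams: list[torch.cuda.Stream],
) -> tuple[Any, tuple[Any, ...]]:
    """Sketch only; the real helper in vllm.utils.multi_stream_utils may differ."""
    default = torch.cuda.current_stream()
    start_event.record(default)          # fan-out point on the default stream

    aux_results: list[Any] = [None] * len(aux_fns)
    used: list[int] = []
    for i, (fn, stream) in enumerate(zip(aux_fns, aux_streams)):
        if fn is None:
            continue                     # missing module: nothing to overlap
        stream.wait_event(start_event)
        with torch.cuda.stream(stream):
            aux_results[i] = fn()
            done_events[i].record(stream)
        used.append(i)

    main_result = main_fn()              # heaviest GEMM on the default stream

    for i in used:
        default.wait_event(done_events[i])   # fan-in before consuming results
    return main_result, tuple(aux_results)
```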
```diff
     def attention_impl(
         self,
         hidden_states: torch.Tensor,
-        qr: torch.Tensor,
-        kv: torch.Tensor,
         positions: torch.Tensor,
         out: torch.Tensor,  # [num_tokens, padded_heads, head_dim], written in place
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
 
+        qr_kv, kv_score, indexer_kv_score, indexer_weights = (
+            self.attn_gemm_parallel_execute(hidden_states)
+        )
+
+        qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
         qr, kv = fused_q_kv_rmsnorm(
             qr,
             kv,
             self.q_norm.weight.data,
             self.kv_norm.weight.data,
             self.eps,
         )
-        q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
 
-        # Overlap kv_insert with whichever of indexer/compressor is present.
-        # Indexer implies compressor; when both exist, compressor rides on the
-        # aux stream alongside kv_insert so the heavy indexer owns default.
+        # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
+        # on the default stream so q stays on its consumer stream (mla_attn
+        # downstream reads q on default). Indexer/compressor go on aux for
+        # overlap with default's GEMM + cache write.
         if self.indexer is not None:
+            assert self.aux_stream_list is not None
+            aux_stream = self.aux_stream_list[0]
             indexer = self.indexer
             # Local ref so the closure keeps a non-None type for mypy.
             assert self.compressor is not None
             compressor = self.compressor
 
-            def kv_insert_and_compress() -> None:
+            def wq_b_kv_insert_and_compress() -> torch.Tensor:
+                q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
                 self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
-                compressor(hidden_states, positions, self.rotary_emb)
-
-            maybe_execute_in_parallel(
-                lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb),
-                kv_insert_and_compress,
+                compressor(kv_score, positions, self.rotary_emb)
+                return q
+
+            q, _ = maybe_execute_in_parallel(
+                wq_b_kv_insert_and_compress,
+                lambda: indexer(
+                    hidden_states,
+                    qr,
+                    indexer_kv_score,
+                    indexer_weights,
+                    positions,
+                    self.indexer_rotary_emb,
+                ),
                 self.ln_events[0],
                 self.ln_events[1],
-                self.aux_stream,
+                aux_stream,
             )
         elif self.compressor is not None:
-            # Compressor on default, kv_insert on aux.
+            # wq_b + kv_insert on default, compressor on aux.
+            assert self.aux_stream_list is not None
+            aux_stream = self.aux_stream_list[0]
             compressor = self.compressor
-            maybe_execute_in_parallel(
-                lambda: compressor(hidden_states, positions, self.rotary_emb),
-                lambda: self._fused_qnorm_rope_kv_insert(
-                    q, kv, positions, attn_metadata
-                ),
+
+            def wq_b_kv_insert() -> torch.Tensor:
+                q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
```
Contributor: Similar to the previous comment.
Suggested change
```diff
+                self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
+                return q
+
+            q, _ = maybe_execute_in_parallel(
+                wq_b_kv_insert,
+                lambda: compressor(kv_score, positions, self.rotary_emb),
                 self.ln_events[0],
                 self.ln_events[1],
-                self.aux_stream,
+                aux_stream,
             )
         else:
             # SWA-only layer: no compressor, no overlap.
+            q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
             self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
 
         # Handle dummy run (no metadata).
```
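For the two-way overlap used above, `maybe_execute_in_parallel` (also imported from `vllm.utils.multi_stream_utils` and not shown in this diff) is called as `(fn0, fn1, event0, event1, aux_stream)` and unpacked as a pair of results. A hedged sketch of behavior consistent with that call site, under the assumption that it falls back to serial execution when no aux stream is available:

```python
from collections.abc import Callable
from typing import Any

import torch


def maybe_execute_in_parallel_sketch(
    fn0: Callable[[], Any],
    fn1: Callable[[], Any],
    event0: torch.cuda.Event,
    event1: torch.cuda.Event,
    aux_stream: torch.cuda.Stream | None,
) -> tuple[Any, Any]:
    """Plausible two-way overlap helper: fn0 on the current (default) stream,
    fn1 on the aux stream when one is provided, otherwise both run serially.
    Not the actual vLLM implementation."""
    if aux_stream is None:
        return fn0(), fn1()

    default = torch.cuda.current_stream()
    event0.record(default)              # order aux work after prior default work
    aux_stream.wait_event(event0)
    with torch.cuda.stream(aux_stream):
        out1 = fn1()
        event1.record(aux_stream)

    out0 = fn0()                        # overlaps with fn1 on the aux stream
    default.wait_event(event1)          # join before results are consumed
    return out0, out1
```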
```diff
@@ -455,21 +530,17 @@ def _fused_qnorm_rope_kv_insert(
 
 def deepseek_v4_attention(
     hidden_states: torch.Tensor,
-    qr: torch.Tensor,
-    kv: torch.Tensor,
     positions: torch.Tensor,
     out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.attention_impl(hidden_states, qr, kv, positions, out)
+    self.attention_impl(hidden_states, positions, out)
 
 
 def deepseek_v4_attention_fake(
     hidden_states: torch.Tensor,
-    qr: torch.Tensor,
-    kv: torch.Tensor,
     positions: torch.Tensor,
     out: torch.Tensor,
     layer_name: str,
@@ -1057,18 +1128,20 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         qr: torch.Tensor,
+        compressed_kv_score: torch.Tensor,
+        indexer_weights: torch.Tensor,
         positions: torch.Tensor,
         rotary_emb: nn.Module,
     ) -> torch.Tensor:
+        # ReplicatedLinear returns (output, bias); bias is None.
         q, _ = self.wq_b(qr)
         q = q.view(-1, self.n_head, self.head_dim)
-        k = self.compressor(hidden_states, positions, rotary_emb)
-        weights, _ = self.weights_proj(hidden_states)
+        k = self.compressor(compressed_kv_score, positions, rotary_emb)
         q_quant, weights = fused_indexer_q_rope_quant(
             positions,
             q,
             rotary_emb.cos_sin_cache,
-            weights,
+            indexer_weights,
             self.softmax_scale,
             self.n_head**-0.5,
             use_fp4=self.use_fp4_kv,
```
The `ColumnParallelLinear` layer `self.wq_b` typically returns a tuple `(output, bias)` in vLLM. Calling `.view()` directly on the result of `self.wq_b(qr)` will raise an `AttributeError` if it returns a tuple. Please ensure that you unpack the output tensor before calling `.view()`, similar to how it's done in `DeepseekV4Indexer.forward` at line 1137.
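To illustrate the fix the reviewer is asking for, here is a method-body fragment (a sketch only, not the exact change adopted in the PR), assuming `self.wq_b` behaves like other vLLM linear layers and returns `(output, output_bias)`:

```python
# Illustrative sketch of the requested fix, mirroring DeepseekV4Indexer.forward.
# Assumes self.wq_b returns (output, output_bias), as vLLM linear layers do.

# Problematic: .view() is called on the (output, bias) tuple -> AttributeError.
# q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)

# Unpack the output tensor first, then reshape.
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_local_heads, self.head_dim)
```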