From 0edd8b82525f9f3dd6564ad4801cfe7019e726bd Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 27 Apr 2026 17:21:04 +0000 Subject: [PATCH 1/7] init multi stream support Signed-off-by: Yongye Zhu --- .../layers/deepseek_compressor.py | 9 +- .../layers/deepseek_v4_attention.py | 121 ++++++++++++++---- vllm/model_executor/models/deepseek_v4.py | 18 +-- vllm/utils/multi_stream_utils.py | 64 +++++++++ 4 files changed, 166 insertions(+), 46 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index af2783f604da..cae80c35316a 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -14,7 +14,6 @@ from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, ) -from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32 from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import ( @@ -271,16 +270,12 @@ def __init__( def forward( self, - # [num_tokens, hidden_size] - x: torch.Tensor, + # [num_tokens, 2 * self.coff * self.head_dim] + kv_score: torch.Tensor, # [num_tokens] positions: torch.Tensor, rotary_emb, ) -> None: - num_tokens, _ = x.shape - # bf16 weights/activations but fp32 output for numerical stability of - # the downstream compressor math. - kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight) # Each of shape [num_tokens, coff * self.head_dim] # input bf16, output are fp32 kv, score = kv_score.split( diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 43242eddb5b2..e8c748c56ab4 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,8 +4,9 @@ DeepseekV4 MLA Attention Layer """ +from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast import torch import torch.nn as nn @@ -16,6 +17,7 @@ ReplicatedLinear, ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer +from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32 from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( @@ -51,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) -from vllm.utils.multi_stream_utils import maybe_execute_in_parallel +from vllm.utils.multi_stream_utils import ( + execute_in_parallel, + maybe_execute_in_parallel, +) from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -94,7 +99,7 @@ class DeepseekV4MLAModules: indexer: torch.nn.Module | None indexer_rotary_emb: torch.nn.Module topk_indices_buffer: torch.Tensor | None - aux_stream: torch.cuda.Stream | None = None + aux_stream_list: list[torch.cuda.Stream] | None = None # --8<-- [start:multi_head_latent_attention] @@ -217,8 +222,14 @@ def __init__( + 1 # 1B pad ) - self.aux_stream = mla_modules.aux_stream - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + self.aux_stream_list = mla_modules.aux_stream_list + # ln_events[0]: doubles as the GEMM-phase start event (default stream + # → fan-out to all aux streams) and as event0 for the subsequent + # maybe_execute_in_parallel call. ln_events[1..3] are the three + # GEMM-phase done events; ln_events[1] is also reused as event1 for + # maybe_execute_in_parallel since the GEMM phase has fully joined + # before that call. + self.ln_events = [torch.cuda.Event() for _ in range(4)] assert cache_config is not None, "DeepseekV4 attention requires cache_config" self.swa_cache_layer = DeepseekV4SWACache( @@ -277,9 +288,6 @@ def forward( hidden_states: torch.Tensor, llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: - qr_kv, _ = self.fused_wqa_wkv(hidden_states) - qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) - # Pre-allocate attention output with FlashMLA-padded head count. # The op writes into `o_padded`; we slice to n_local_heads after. num_tokens = hidden_states.shape[0] @@ -292,8 +300,6 @@ def forward( # Attention (inside custom op for torch.compile boundary) torch.ops.vllm.deepseek_v4_attention( hidden_states, - qr, - kv, positions, o_padded, self.layer_name, @@ -332,17 +338,71 @@ def forward( return self.wo_b(z.flatten(1)) + def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: + assert self.aux_stream_list is not None + assert len(self.aux_stream_list) >= 3 + + # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs + # on aux streams 0..2 when their owning module exists. ln_events[0] + # is the fan-out start event; ln_events[1..3] are per-aux done events. + aux_fns: list[Callable[[], Any] | None] = [None, None, None] + + if self.compressor is not None: + # Local ref so the closure keeps a non-None type for mypy. + compressor = self.compressor + + def compressor_kv_score() -> torch.Tensor: + return cublas_gemm_bf16_bf16_fp32( + hidden_states, compressor.fused_wkv_wgate.weight + ) + + aux_fns[0] = compressor_kv_score + + if self.indexer is not None: + indexer = self.indexer + + def indexer_weights_proj() -> torch.Tensor: + # ReplicatedLinear returns (output, bias); bias is None. + weights, _ = indexer.weights_proj(hidden_states) + return weights + + def indexer_compressor_kv_score() -> torch.Tensor: + return cublas_gemm_bf16_bf16_fp32( + hidden_states, indexer.compressor.fused_wkv_wgate.weight + ) + + aux_fns[1] = indexer_weights_proj + aux_fns[2] = indexer_compressor_kv_score + + def fused_wqa_wkv() -> torch.Tensor: + # MergedColumnParallelLinear returns (output, bias); bias is None. + qr_kv, _ = self.fused_wqa_wkv(hidden_states) + return qr_kv + + qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( + fused_wqa_wkv, + aux_fns, + self.ln_events[0], + self.ln_events[1:4], + self.aux_stream_list[:3], + ) + + return qr_kv, kv_score, indexer_kv_score, indexer_weights + def attention_impl( self, hidden_states: torch.Tensor, - qr: torch.Tensor, - kv: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + qr_kv, kv_score, indexer_kv_score, indexer_weights = ( + self.attn_gemm_parallel_execute(hidden_states) + ) + + qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) qr, kv = fused_q_kv_rmsnorm( qr, kv, @@ -356,6 +416,8 @@ def attention_impl( # Indexer implies compressor; when both exist, compressor rides on the # aux stream alongside kv_insert so the heavy indexer owns default. if self.indexer is not None: + assert self.aux_stream_list is not None + aux_stream = self.aux_stream_list[0] indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. assert self.compressor is not None @@ -363,26 +425,35 @@ def attention_impl( def kv_insert_and_compress() -> None: self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - compressor(hidden_states, positions, self.rotary_emb) + compressor(kv_score, positions, self.rotary_emb) maybe_execute_in_parallel( - lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb), + lambda: indexer( + hidden_states, + q, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + ), kv_insert_and_compress, self.ln_events[0], self.ln_events[1], - self.aux_stream, + aux_stream, ) elif self.compressor is not None: # Compressor on default, kv_insert on aux. + assert self.aux_stream_list is not None + aux_stream = self.aux_stream_list[0] compressor = self.compressor maybe_execute_in_parallel( - lambda: compressor(hidden_states, positions, self.rotary_emb), + lambda: compressor(kv_score, positions, self.rotary_emb), lambda: self._fused_qnorm_rope_kv_insert( q, kv, positions, attn_metadata ), self.ln_events[0], self.ln_events[1], - self.aux_stream, + aux_stream, ) else: # SWA-only layer: no compressor, no overlap. @@ -455,21 +526,17 @@ def _fused_qnorm_rope_kv_insert( def deepseek_v4_attention( hidden_states: torch.Tensor, - qr: torch.Tensor, - kv: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] - self.attention_impl(hidden_states, qr, kv, positions, out) + self.attention_impl(hidden_states, positions, out) def deepseek_v4_attention_fake( hidden_states: torch.Tensor, - qr: torch.Tensor, - kv: torch.Tensor, positions: torch.Tensor, out: torch.Tensor, layer_name: str, @@ -1056,19 +1123,19 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - qr: torch.Tensor, + q: torch.Tensor, + compressed_kv_score: torch.Tensor, + indexer_weights: torch.Tensor, positions: torch.Tensor, rotary_emb: nn.Module, ) -> torch.Tensor: - q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) - k = self.compressor(hidden_states, positions, rotary_emb) - weights, _ = self.weights_proj(hidden_states) + k = self.compressor(compressed_kv_score, positions, rotary_emb) q_quant, weights = fused_indexer_q_rope_quant( positions, q, rotary_emb.cos_sin_cache, - weights, + indexer_weights, self.softmax_scale, self.n_head**-0.5, use_fp4=self.use_fp4_kv, diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 7733252804b7..9f6b426dd28b 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -54,7 +54,6 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.triton_utils import tl, triton -from vllm.utils.multi_stream_utils import AuxStreamType from vllm.utils.torch_utils import direct_register_custom_op from .utils import ( @@ -872,7 +871,7 @@ def __init__( vllm_config: VllmConfig, prefix: str, topk_indices_buffer: torch.Tensor | None = None, - aux_stream: torch.cuda.Stream | None = None, + aux_stream_list: list[torch.cuda.Stream] | None = None, ): super().__init__() config = vllm_config.model_config.hf_config @@ -1005,7 +1004,7 @@ def __init__( indexer=self.indexer, indexer_rotary_emb=self.rotary_emb, topk_indices_buffer=topk_indices_buffer, - aux_stream=aux_stream, + aux_stream_list=aux_stream_list, ) self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper( hidden_size=self.hidden_size, @@ -1041,7 +1040,7 @@ def __init__( vllm_config, prefix, topk_indices_buffer: torch.Tensor | None = None, - aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream] | None = None, + aux_stream_list: list[torch.cuda.Stream] | None = None, ): super().__init__() config = vllm_config.model_config.hf_config @@ -1052,9 +1051,7 @@ def __init__( vllm_config, prefix=f"{prefix}.attn", topk_indices_buffer=topk_indices_buffer, - aux_stream=aux_stream_dict.get(AuxStreamType.Attention) - if aux_stream_dict is not None - else None, + aux_stream_list=aux_stream_list, ) self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") @@ -1182,10 +1179,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - aux_stream_list = [torch.cuda.Stream() for _ in range(1)] - self.aux_stream_dict = { - AuxStreamType.Attention: aux_stream_list[0], - } + aux_stream_list = [torch.cuda.Stream() for _ in range(3)] self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. @@ -1209,7 +1203,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config, prefix=prefix, topk_indices_buffer=self.topk_indices_buffer, - aux_stream_dict=self.aux_stream_dict, + aux_stream_list=aux_stream_list, ), prefix=f"{prefix}.layers", ) diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py index cc6bc6462449..c00f08f93329 100644 --- a/vllm/utils/multi_stream_utils.py +++ b/vllm/utils/multi_stream_utils.py @@ -56,3 +56,67 @@ def maybe_execute_in_parallel( result0 = fn0() result1 = fn1() return (result0, result1) + + +def execute_in_parallel( + default_fn: Callable[[], Any], + aux_fns: list[Callable[[], Any] | None], + start_event: torch.cuda.Event, + done_events: list[torch.cuda.Event], + aux_streams: list[torch.cuda.Stream] | None = None, +) -> tuple[Any, list[Any]]: + """Run default_fn on the current stream and aux_fns concurrently on + aux_streams. + + Generalizes maybe_execute_in_parallel to N aux callables. Slots where + aux_fns[i] is None are skipped (no stream switch, no event record); their + corresponding entry in the returned aux_results list is None. + + start_event fans out from the current stream to every launched aux stream; + done_events[i] is recorded after aux_fns[i] so the current stream joins + before returning. When aux_streams is None, all aux_fns run sequentially + on the current stream. + + Args: + default_fn: Callable for the default (current) stream. + aux_fns: Per-aux callables; entries may be None to skip. + start_event: CUDA event recorded on the current stream before + default_fn so each launched aux stream can wait on it. + done_events: One CUDA event per aux slot, recorded after the + corresponding aux_fn. Length must match aux_fns. + aux_streams: Per-aux CUDA streams. Length must match aux_fns. + Multi-stream is disabled when None. + + Returns: + Tuple of (default_result, aux_results) where aux_results[i] is the + result of aux_fns[i] (or None when skipped). + """ + aux_results: list[Any] + if aux_streams is None: + default_result = default_fn() + aux_results = [fn() if fn is not None else None for fn in aux_fns] + return default_result, aux_results + + assert len(aux_fns) == len(aux_streams) == len(done_events), ( + "aux_fns, aux_streams, and done_events must be the same length" + ) + + aux_results = [None] * len(aux_fns) + pending: list[torch.cuda.Event] = [] + + start_event.record() + for i, fn in enumerate(aux_fns): + if fn is None: + continue + with torch.cuda.stream(aux_streams[i]): + start_event.wait() + aux_results[i] = fn() + done_events[i].record() + pending.append(done_events[i]) + + default_result = default_fn() + + for ev in pending: + ev.wait() + + return default_result, aux_results From f1158ec57febee4c291868de8937324d655e5d6a Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 27 Apr 2026 17:33:16 +0000 Subject: [PATCH 2/7] bug fix and more lint Signed-off-by: Yongye Zhu --- .../layers/deepseek_v4_attention.py | 36 +++++++++++-------- vllm/model_executor/models/deepseek_v4.py | 4 +++ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index e8c748c56ab4..b78207f27a59 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -410,11 +410,10 @@ def attention_impl( self.kv_norm.weight.data, self.eps, ) - q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - # Overlap kv_insert with whichever of indexer/compressor is present. - # Indexer implies compressor; when both exist, compressor rides on the - # aux stream alongside kv_insert so the heavy indexer owns default. + # Overlap wq_b + kv_insert (+ compressor when indexer is present) on + # the aux stream with indexer/compressor on default. q flows out via + # event1 sync. if self.indexer is not None: assert self.aux_stream_list is not None aux_stream = self.aux_stream_list[0] @@ -423,40 +422,47 @@ def attention_impl( assert self.compressor is not None compressor = self.compressor - def kv_insert_and_compress() -> None: + def wq_b_kv_insert_and_compress() -> torch.Tensor: + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) compressor(kv_score, positions, self.rotary_emb) + return q - maybe_execute_in_parallel( + _, q = maybe_execute_in_parallel( lambda: indexer( hidden_states, - q, + qr, indexer_kv_score, indexer_weights, positions, self.indexer_rotary_emb, ), - kv_insert_and_compress, + wq_b_kv_insert_and_compress, self.ln_events[0], self.ln_events[1], aux_stream, ) elif self.compressor is not None: - # Compressor on default, kv_insert on aux. + # Compressor on default, wq_b + kv_insert on aux. assert self.aux_stream_list is not None aux_stream = self.aux_stream_list[0] compressor = self.compressor - maybe_execute_in_parallel( + + def wq_b_kv_insert() -> torch.Tensor: + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + return q + + _, q = maybe_execute_in_parallel( lambda: compressor(kv_score, positions, self.rotary_emb), - lambda: self._fused_qnorm_rope_kv_insert( - q, kv, positions, attn_metadata - ), + wq_b_kv_insert, self.ln_events[0], self.ln_events[1], aux_stream, ) else: # SWA-only layer: no compressor, no overlap. + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) # Handle dummy run (no metadata). @@ -1123,12 +1129,14 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - q: torch.Tensor, + qr: torch.Tensor, compressed_kv_score: torch.Tensor, indexer_weights: torch.Tensor, positions: torch.Tensor, rotary_emb: nn.Module, ) -> torch.Tensor: + # ReplicatedLinear returns (output, bias); bias is None. + q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) k = self.compressor(compressed_kv_score, positions, rotary_emb) q_quant, weights = fused_indexer_q_rope_quant( diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 9f6b426dd28b..6b10532dc827 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1179,6 +1179,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps + # Three aux streams: one per non-default input GEMM in + # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute + # (compressor kv_score, indexer.weights_proj, indexer.compressor + # kv_score). fused_wqa_wkv stays on the default stream. aux_stream_list = [torch.cuda.Stream() for _ in range(3)] self.device = current_platform.device_type From d14835b7d7f73224d8e86d5376890d87adf37968 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 27 Apr 2026 19:20:04 +0000 Subject: [PATCH 3/7] use separate events Signed-off-by: Yongye Zhu --- .../layers/deepseek_v4_attention.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index b78207f27a59..31459d49ea81 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -223,13 +223,14 @@ def __init__( ) self.aux_stream_list = mla_modules.aux_stream_list - # ln_events[0]: doubles as the GEMM-phase start event (default stream - # → fan-out to all aux streams) and as event0 for the subsequent - # maybe_execute_in_parallel call. ln_events[1..3] are the three - # GEMM-phase done events; ln_events[1] is also reused as event1 for - # maybe_execute_in_parallel since the GEMM phase has fully joined - # before that call. - self.ln_events = [torch.cuda.Event() for _ in range(4)] + # gemm_events[0]: GEMM-phase start (default → fan-out to aux 0..2). + # gemm_events[1..3]: per-aux done events (aux i → default). + # ln_events[0..1]: post-GEMM maybe_execute_in_parallel event0/event1. + # Disjoint sets — re-recording the same event twice in a single + # forward (and hence a single CUDA graph capture) can cause + # graph-edge ambiguity and hangs on first launch. + self.gemm_events = [torch.cuda.Event() for _ in range(4)] + self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] assert cache_config is not None, "DeepseekV4 attention requires cache_config" self.swa_cache_layer = DeepseekV4SWACache( @@ -343,8 +344,8 @@ def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: assert len(self.aux_stream_list) >= 3 # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs - # on aux streams 0..2 when their owning module exists. ln_events[0] - # is the fan-out start event; ln_events[1..3] are per-aux done events. + # on aux streams 0..2 when their owning module exists. gemm_events[0] + # is the fan-out start event; gemm_events[1..3] are per-aux done events. aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None: @@ -382,8 +383,8 @@ def fused_wqa_wkv() -> torch.Tensor: qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( fused_wqa_wkv, aux_fns, - self.ln_events[0], - self.ln_events[1:4], + self.gemm_events[0], + self.gemm_events[1:4], self.aux_stream_list[:3], ) From 974f6bd0aa26c9e4dfa4ccd045128bdd970d8fca Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 27 Apr 2026 19:37:36 +0000 Subject: [PATCH 4/7] fix concurrency bug Signed-off-by: Yongye Zhu --- .../layers/deepseek_v4_attention.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 31459d49ea81..9104d13e22db 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -412,9 +412,10 @@ def attention_impl( self.eps, ) - # Overlap wq_b + kv_insert (+ compressor when indexer is present) on - # the aux stream with indexer/compressor on default. q flows out via - # event1 sync. + # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride + # on the default stream so q stays on its consumer stream (mla_attn + # downstream reads q on default). Indexer/compressor go on aux for + # overlap with default's GEMM + cache write. if self.indexer is not None: assert self.aux_stream_list is not None aux_stream = self.aux_stream_list[0] @@ -429,7 +430,8 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: compressor(kv_score, positions, self.rotary_emb) return q - _, q = maybe_execute_in_parallel( + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert_and_compress, lambda: indexer( hidden_states, qr, @@ -438,13 +440,12 @@ def wq_b_kv_insert_and_compress() -> torch.Tensor: positions, self.indexer_rotary_emb, ), - wq_b_kv_insert_and_compress, self.ln_events[0], self.ln_events[1], aux_stream, ) elif self.compressor is not None: - # Compressor on default, wq_b + kv_insert on aux. + # wq_b + kv_insert on default, compressor on aux. assert self.aux_stream_list is not None aux_stream = self.aux_stream_list[0] compressor = self.compressor @@ -454,9 +455,9 @@ def wq_b_kv_insert() -> torch.Tensor: self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) return q - _, q = maybe_execute_in_parallel( - lambda: compressor(kv_score, positions, self.rotary_emb), + q, _ = maybe_execute_in_parallel( wq_b_kv_insert, + lambda: compressor(kv_score, positions, self.rotary_emb), self.ln_events[0], self.ln_events[1], aux_stream, From a042cc16aa09a5065f5c1890182decf2101ce0e9 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 27 Apr 2026 19:45:04 +0000 Subject: [PATCH 5/7] switch using stream.wait Signed-off-by: Yongye Zhu --- .../layers/deepseek_v4_attention.py | 83 +++++++++---------- 1 file changed, 38 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 9104d13e22db..2c5ada05f33e 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,7 +4,6 @@ DeepseekV4 MLA Attention Layer """ -from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -53,10 +52,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) -from vllm.utils.multi_stream_utils import ( - execute_in_parallel, - maybe_execute_in_parallel, -) +from vllm.utils.multi_stream_utils import maybe_execute_in_parallel from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -223,13 +219,9 @@ def __init__( ) self.aux_stream_list = mla_modules.aux_stream_list - # gemm_events[0]: GEMM-phase start (default → fan-out to aux 0..2). - # gemm_events[1..3]: per-aux done events (aux i → default). # ln_events[0..1]: post-GEMM maybe_execute_in_parallel event0/event1. - # Disjoint sets — re-recording the same event twice in a single - # forward (and hence a single CUDA graph capture) can cause - # graph-edge ambiguity and hangs on first launch. - self.gemm_events = [torch.cuda.Event() for _ in range(4)] + # The GEMM phase synchronizes via stream.wait_stream() (no user- + # managed events) so it doesn't share event slots with this pair. self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] assert cache_config is not None, "DeepseekV4 attention requires cache_config" @@ -343,50 +335,51 @@ def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: assert self.aux_stream_list is not None assert len(self.aux_stream_list) >= 3 - # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs - # on aux streams 0..2 when their owning module exists. gemm_events[0] - # is the fan-out start event; gemm_events[1..3] are per-aux done events. - aux_fns: list[Callable[[], Any] | None] = [None, None, None] + # Run fused_wqa_wkv (heaviest) on the current (default) stream and the + # three lighter input GEMMs on aux streams 0..2 when their owning + # module exists. Sync via stream.wait_stream() rather than user-managed + # events: aux streams wait_stream(current) so they observe inputs + # produced upstream on default, then current.wait_stream(aux) joins + # the aux outputs back before returning. + current = torch.cuda.current_stream() + aux0, aux1, aux2 = ( + self.aux_stream_list[0], + self.aux_stream_list[1], + self.aux_stream_list[2], + ) - if self.compressor is not None: - # Local ref so the closure keeps a non-None type for mypy. - compressor = self.compressor + kv_score: torch.Tensor | None = None + indexer_weights: torch.Tensor | None = None + indexer_kv_score: torch.Tensor | None = None + used_aux: list[torch.cuda.Stream] = [] - def compressor_kv_score() -> torch.Tensor: - return cublas_gemm_bf16_bf16_fp32( - hidden_states, compressor.fused_wkv_wgate.weight + if self.compressor is not None: + aux0.wait_stream(current) + with torch.cuda.stream(aux0): + kv_score = cublas_gemm_bf16_bf16_fp32( + hidden_states, self.compressor.fused_wkv_wgate.weight ) - - aux_fns[0] = compressor_kv_score + used_aux.append(aux0) if self.indexer is not None: - indexer = self.indexer - - def indexer_weights_proj() -> torch.Tensor: + aux1.wait_stream(current) + with torch.cuda.stream(aux1): # ReplicatedLinear returns (output, bias); bias is None. - weights, _ = indexer.weights_proj(hidden_states) - return weights + indexer_weights, _ = self.indexer.weights_proj(hidden_states) + used_aux.append(aux1) - def indexer_compressor_kv_score() -> torch.Tensor: - return cublas_gemm_bf16_bf16_fp32( - hidden_states, indexer.compressor.fused_wkv_wgate.weight + aux2.wait_stream(current) + with torch.cuda.stream(aux2): + indexer_kv_score = cublas_gemm_bf16_bf16_fp32( + hidden_states, self.indexer.compressor.fused_wkv_wgate.weight ) + used_aux.append(aux2) - aux_fns[1] = indexer_weights_proj - aux_fns[2] = indexer_compressor_kv_score + # MergedColumnParallelLinear returns (output, bias); bias is None. + qr_kv, _ = self.fused_wqa_wkv(hidden_states) - def fused_wqa_wkv() -> torch.Tensor: - # MergedColumnParallelLinear returns (output, bias); bias is None. - qr_kv, _ = self.fused_wqa_wkv(hidden_states) - return qr_kv - - qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( - fused_wqa_wkv, - aux_fns, - self.gemm_events[0], - self.gemm_events[1:4], - self.aux_stream_list[:3], - ) + for aux in used_aux: + current.wait_stream(aux) return qr_kv, kv_score, indexer_kv_score, indexer_weights From 7bf590cc775f676af4118dfc0fa71007c9e9bd5d Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 27 Apr 2026 19:49:18 +0000 Subject: [PATCH 6/7] wait on all streams Signed-off-by: Yongye Zhu --- .../layers/deepseek_v4_attention.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 2c5ada05f33e..311ca6df1819 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -338,9 +338,9 @@ def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: # Run fused_wqa_wkv (heaviest) on the current (default) stream and the # three lighter input GEMMs on aux streams 0..2 when their owning # module exists. Sync via stream.wait_stream() rather than user-managed - # events: aux streams wait_stream(current) so they observe inputs - # produced upstream on default, then current.wait_stream(aux) joins - # the aux outputs back before returning. + # events: a single fan-out barrier (all aux wait_stream(current)) at + # the start and a single fan-in barrier (current.wait_stream(each aux)) + # at the end. current = torch.cuda.current_stream() aux0, aux1, aux2 = ( self.aux_stream_list[0], @@ -348,38 +348,37 @@ def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: self.aux_stream_list[2], ) + # Fan-out: every aux stream observes default's prior work. + aux0.wait_stream(current) + aux1.wait_stream(current) + aux2.wait_stream(current) + kv_score: torch.Tensor | None = None indexer_weights: torch.Tensor | None = None indexer_kv_score: torch.Tensor | None = None - used_aux: list[torch.cuda.Stream] = [] if self.compressor is not None: - aux0.wait_stream(current) with torch.cuda.stream(aux0): kv_score = cublas_gemm_bf16_bf16_fp32( hidden_states, self.compressor.fused_wkv_wgate.weight ) - used_aux.append(aux0) if self.indexer is not None: - aux1.wait_stream(current) with torch.cuda.stream(aux1): # ReplicatedLinear returns (output, bias); bias is None. indexer_weights, _ = self.indexer.weights_proj(hidden_states) - used_aux.append(aux1) - - aux2.wait_stream(current) with torch.cuda.stream(aux2): indexer_kv_score = cublas_gemm_bf16_bf16_fp32( hidden_states, self.indexer.compressor.fused_wkv_wgate.weight ) - used_aux.append(aux2) # MergedColumnParallelLinear returns (output, bias); bias is None. qr_kv, _ = self.fused_wqa_wkv(hidden_states) - for aux in used_aux: - current.wait_stream(aux) + # Fan-in: default observes every aux stream's outputs. + current.wait_stream(aux0) + current.wait_stream(aux1) + current.wait_stream(aux2) return qr_kv, kv_score, indexer_kv_score, indexer_weights From bdcb5f21e70e384d685bd9322cd619ca0818c724 Mon Sep 17 00:00:00 2001 From: Yongye Zhu Date: Mon, 27 Apr 2026 20:09:20 +0000 Subject: [PATCH 7/7] fall back to elegant approach Signed-off-by: Yongye Zhu --- .../layers/deepseek_v4_attention.py | 84 ++++++++++--------- 1 file changed, 44 insertions(+), 40 deletions(-) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 311ca6df1819..a968a06bb650 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -4,6 +4,7 @@ DeepseekV4 MLA Attention Layer """ +from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, cast @@ -52,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) -from vllm.utils.multi_stream_utils import maybe_execute_in_parallel +from vllm.utils.multi_stream_utils import ( + execute_in_parallel, + maybe_execute_in_parallel, +) from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata from vllm.v1.attention.backends.mla.flashmla_sparse import ( DeepseekV4FlashMLASparseBackend, @@ -219,10 +223,10 @@ def __init__( ) self.aux_stream_list = mla_modules.aux_stream_list - # ln_events[0..1]: post-GEMM maybe_execute_in_parallel event0/event1. - # The GEMM phase synchronizes via stream.wait_stream() (no user- - # managed events) so it doesn't share event slots with this pair. - self.ln_events = [torch.cuda.Event(), torch.cuda.Event()] + # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; + # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins + # before post-GEMM starts. + self.ln_events = [torch.cuda.Event() for _ in range(4)] assert cache_config is not None, "DeepseekV4 attention requires cache_config" self.swa_cache_layer = DeepseekV4SWACache( @@ -335,50 +339,50 @@ def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: assert self.aux_stream_list is not None assert len(self.aux_stream_list) >= 3 - # Run fused_wqa_wkv (heaviest) on the current (default) stream and the - # three lighter input GEMMs on aux streams 0..2 when their owning - # module exists. Sync via stream.wait_stream() rather than user-managed - # events: a single fan-out barrier (all aux wait_stream(current)) at - # the start and a single fan-in barrier (current.wait_stream(each aux)) - # at the end. - current = torch.cuda.current_stream() - aux0, aux1, aux2 = ( - self.aux_stream_list[0], - self.aux_stream_list[1], - self.aux_stream_list[2], - ) - - # Fan-out: every aux stream observes default's prior work. - aux0.wait_stream(current) - aux1.wait_stream(current) - aux2.wait_stream(current) - - kv_score: torch.Tensor | None = None - indexer_weights: torch.Tensor | None = None - indexer_kv_score: torch.Tensor | None = None + # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs + # on aux streams 0..2 when their owning module exists. ln_events[0] + # is the fan-out start event; ln_events[1..3] are per-aux done events. + aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None: - with torch.cuda.stream(aux0): - kv_score = cublas_gemm_bf16_bf16_fp32( - hidden_states, self.compressor.fused_wkv_wgate.weight + # Local ref so the closure keeps a non-None type for mypy. + compressor = self.compressor + + def compressor_kv_score() -> torch.Tensor: + return cublas_gemm_bf16_bf16_fp32( + hidden_states, compressor.fused_wkv_wgate.weight ) + aux_fns[0] = compressor_kv_score + if self.indexer is not None: - with torch.cuda.stream(aux1): + indexer = self.indexer + + def indexer_weights_proj() -> torch.Tensor: # ReplicatedLinear returns (output, bias); bias is None. - indexer_weights, _ = self.indexer.weights_proj(hidden_states) - with torch.cuda.stream(aux2): - indexer_kv_score = cublas_gemm_bf16_bf16_fp32( - hidden_states, self.indexer.compressor.fused_wkv_wgate.weight + weights, _ = indexer.weights_proj(hidden_states) + return weights + + def indexer_compressor_kv_score() -> torch.Tensor: + return cublas_gemm_bf16_bf16_fp32( + hidden_states, indexer.compressor.fused_wkv_wgate.weight ) - # MergedColumnParallelLinear returns (output, bias); bias is None. - qr_kv, _ = self.fused_wqa_wkv(hidden_states) + aux_fns[1] = indexer_weights_proj + aux_fns[2] = indexer_compressor_kv_score + + def fused_wqa_wkv() -> torch.Tensor: + # MergedColumnParallelLinear returns (output, bias); bias is None. + qr_kv, _ = self.fused_wqa_wkv(hidden_states) + return qr_kv - # Fan-in: default observes every aux stream's outputs. - current.wait_stream(aux0) - current.wait_stream(aux1) - current.wait_stream(aux2) + qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( + fused_wqa_wkv, + aux_fns, + self.ln_events[0], + self.ln_events[1:4], + self.aux_stream_list[:3], + ) return qr_kv, kv_score, indexer_kv_score, indexer_weights