diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py
index af2783f604da..cae80c35316a 100644
--- a/vllm/model_executor/layers/deepseek_compressor.py
+++ b/vllm/model_executor/layers/deepseek_compressor.py
@@ -14,7 +14,6 @@
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
 )
-from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 from vllm.v1.attention.backend import (
@@ -271,16 +270,12 @@ def __init__(
 
     def forward(
         self,
-        # [num_tokens, hidden_size]
-        x: torch.Tensor,
+        # [num_tokens, 2 * self.coff * self.head_dim]
+        kv_score: torch.Tensor,
         # [num_tokens]
        positions: torch.Tensor,
         rotary_emb,
     ) -> None:
-        num_tokens, _ = x.shape
-        # bf16 weights/activations but fp32 output for numerical stability of
-        # the downstream compressor math.
-        kv_score = cublas_gemm_bf16_bf16_fp32(x, self.fused_wkv_wgate.weight)
         # Each of shape [num_tokens, coff * self.head_dim]
         # input bf16, output are fp32
         kv, score = kv_score.split(
diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py
index 43242eddb5b2..a968a06bb650 100644
--- a/vllm/model_executor/layers/deepseek_v4_attention.py
+++ b/vllm/model_executor/layers/deepseek_v4_attention.py
@@ -4,8 +4,9 @@
 DeepseekV4 MLA Attention Layer
 """
 
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import torch
 import torch.nn as nn
@@ -16,6 +17,7 @@
     ReplicatedLinear,
 )
 from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
+from vllm.model_executor.layers.utils import cublas_gemm_bf16_bf16_fp32
 from vllm.utils.deep_gemm import fp8_einsum
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.ops.deepseek_v4_ops import (
@@ -51,7 +53,10 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
 )
-from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
+from vllm.utils.multi_stream_utils import (
+    execute_in_parallel,
+    maybe_execute_in_parallel,
+)
 from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     DeepseekV4FlashMLASparseBackend,
@@ -94,7 +99,7 @@ class DeepseekV4MLAModules:
     indexer: torch.nn.Module | None
     indexer_rotary_emb: torch.nn.Module
     topk_indices_buffer: torch.Tensor | None
-    aux_stream: torch.cuda.Stream | None = None
+    aux_stream_list: list[torch.cuda.Stream] | None = None
 
 
 # --8<-- [start:multi_head_latent_attention]
@@ -217,8 +222,11 @@ def __init__(
             + 1  # 1B pad
         )
 
-        self.aux_stream = mla_modules.aux_stream
-        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]
+        self.aux_stream_list = mla_modules.aux_stream_list
+        # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
+        # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
+        # before post-GEMM starts.
+        self.ln_events = [torch.cuda.Event() for _ in range(4)]
 
         assert cache_config is not None, "DeepseekV4 attention requires cache_config"
         self.swa_cache_layer = DeepseekV4SWACache(
@@ -277,9 +285,6 @@ def forward(
         hidden_states: torch.Tensor,
         llama_4_scaling: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        qr_kv, _ = self.fused_wqa_wkv(hidden_states)
-        qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
-
         # Pre-allocate attention output with FlashMLA-padded head count.
         # The op writes into `o_padded`; we slice to n_local_heads after.
         num_tokens = hidden_states.shape[0]
@@ -292,8 +297,6 @@
         # Attention (inside custom op for torch.compile boundary)
         torch.ops.vllm.deepseek_v4_attention(
             hidden_states,
-            qr,
-            kv,
             positions,
             o_padded,
             self.layer_name,
@@ -332,17 +335,71 @@ def forward(
 
         return self.wo_b(z.flatten(1))
 
+    def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
+        assert self.aux_stream_list is not None
+        assert len(self.aux_stream_list) >= 3
+
+        # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs
+        # on aux streams 0..2 when their owning module exists. ln_events[0]
+        # is the fan-out start event; ln_events[1..3] are per-aux done events.
+        aux_fns: list[Callable[[], Any] | None] = [None, None, None]
+
+        if self.compressor is not None:
+            # Local ref so the closure keeps a non-None type for mypy.
+            compressor = self.compressor
+
+            def compressor_kv_score() -> torch.Tensor:
+                return cublas_gemm_bf16_bf16_fp32(
+                    hidden_states, compressor.fused_wkv_wgate.weight
+                )
+
+            aux_fns[0] = compressor_kv_score
+
+        if self.indexer is not None:
+            indexer = self.indexer
+
+            def indexer_weights_proj() -> torch.Tensor:
+                # ReplicatedLinear returns (output, bias); bias is None.
+                weights, _ = indexer.weights_proj(hidden_states)
+                return weights
+
+            def indexer_compressor_kv_score() -> torch.Tensor:
+                return cublas_gemm_bf16_bf16_fp32(
+                    hidden_states, indexer.compressor.fused_wkv_wgate.weight
+                )
+
+            aux_fns[1] = indexer_weights_proj
+            aux_fns[2] = indexer_compressor_kv_score
+
+        def fused_wqa_wkv() -> torch.Tensor:
+            # MergedColumnParallelLinear returns (output, bias); bias is None.
+            qr_kv, _ = self.fused_wqa_wkv(hidden_states)
+            return qr_kv
+
+        qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel(
+            fused_wqa_wkv,
+            aux_fns,
+            self.ln_events[0],
+            self.ln_events[1:4],
+            self.aux_stream_list[:3],
+        )
+
+        return qr_kv, kv_score, indexer_kv_score, indexer_weights
+
     def attention_impl(
         self,
         hidden_states: torch.Tensor,
-        qr: torch.Tensor,
-        kv: torch.Tensor,
         positions: torch.Tensor,
         out: torch.Tensor,  # [num_tokens, padded_heads, head_dim], written in place
     ) -> None:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
 
+        qr_kv, kv_score, indexer_kv_score, indexer_weights = (
+            self.attn_gemm_parallel_execute(hidden_states)
+        )
+
+        qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
         qr, kv = fused_q_kv_rmsnorm(
             qr,
             kv,
@@ -350,42 +407,60 @@
             self.kv_norm.weight.data,
             self.eps,
         )
-        q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
 
-        # Overlap kv_insert with whichever of indexer/compressor is present.
-        # Indexer implies compressor; when both exist, compressor rides on the
-        # aux stream alongside kv_insert so the heavy indexer owns default.
+        # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride
+        # on the default stream so q stays on its consumer stream (mla_attn
+        # downstream reads q on default). Indexer/compressor go on aux for
+        # overlap with default's GEMM + cache write.
         if self.indexer is not None:
+            assert self.aux_stream_list is not None
+            aux_stream = self.aux_stream_list[0]
             indexer = self.indexer
             # Local ref so the closure keeps a non-None type for mypy.
             assert self.compressor is not None
             compressor = self.compressor
 
-            def kv_insert_and_compress() -> None:
+            def wq_b_kv_insert_and_compress() -> torch.Tensor:
+                q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
                 self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
-                compressor(hidden_states, positions, self.rotary_emb)
-
-            maybe_execute_in_parallel(
-                lambda: indexer(hidden_states, qr, positions, self.indexer_rotary_emb),
-                kv_insert_and_compress,
+                compressor(kv_score, positions, self.rotary_emb)
+                return q
+
+            q, _ = maybe_execute_in_parallel(
+                wq_b_kv_insert_and_compress,
+                lambda: indexer(
+                    hidden_states,
+                    qr,
+                    indexer_kv_score,
+                    indexer_weights,
+                    positions,
+                    self.indexer_rotary_emb,
+                ),
                 self.ln_events[0],
                 self.ln_events[1],
-                self.aux_stream,
+                aux_stream,
             )
         elif self.compressor is not None:
-            # Compressor on default, kv_insert on aux.
+            # wq_b + kv_insert on default, compressor on aux.
+            assert self.aux_stream_list is not None
+            aux_stream = self.aux_stream_list[0]
             compressor = self.compressor
-            maybe_execute_in_parallel(
-                lambda: compressor(hidden_states, positions, self.rotary_emb),
-                lambda: self._fused_qnorm_rope_kv_insert(
-                    q, kv, positions, attn_metadata
-                ),
+
+            def wq_b_kv_insert() -> torch.Tensor:
+                q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
+                self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
+                return q
+
+            q, _ = maybe_execute_in_parallel(
+                wq_b_kv_insert,
+                lambda: compressor(kv_score, positions, self.rotary_emb),
                 self.ln_events[0],
                 self.ln_events[1],
-                self.aux_stream,
+                aux_stream,
             )
         else:
             # SWA-only layer: no compressor, no overlap.
+            q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
             self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
 
         # Handle dummy run (no metadata).
@@ -455,21 +530,17 @@ def _fused_qnorm_rope_kv_insert(
 
 def deepseek_v4_attention(
     hidden_states: torch.Tensor,
-    qr: torch.Tensor,
-    kv: torch.Tensor,
     positions: torch.Tensor,
     out: torch.Tensor,
     layer_name: str,
 ) -> None:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self.attention_impl(hidden_states, qr, kv, positions, out)
+    self.attention_impl(hidden_states, positions, out)
 
 
 def deepseek_v4_attention_fake(
     hidden_states: torch.Tensor,
-    qr: torch.Tensor,
-    kv: torch.Tensor,
     positions: torch.Tensor,
     out: torch.Tensor,
     layer_name: str,
@@ -1057,18 +1128,20 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         qr: torch.Tensor,
+        compressed_kv_score: torch.Tensor,
+        indexer_weights: torch.Tensor,
         positions: torch.Tensor,
         rotary_emb: nn.Module,
     ) -> torch.Tensor:
+        # ReplicatedLinear returns (output, bias); bias is None.
         q, _ = self.wq_b(qr)
         q = q.view(-1, self.n_head, self.head_dim)
-        k = self.compressor(hidden_states, positions, rotary_emb)
-        weights, _ = self.weights_proj(hidden_states)
+        k = self.compressor(compressed_kv_score, positions, rotary_emb)
         q_quant, weights = fused_indexer_q_rope_quant(
             positions,
             q,
             rotary_emb.cos_sin_cache,
-            weights,
+            indexer_weights,
             self.softmax_scale,
             self.n_head**-0.5,
             use_fp4=self.use_fp4_kv,
diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py
index 7733252804b7..6b10532dc827 100644
--- a/vllm/model_executor/models/deepseek_v4.py
+++ b/vllm/model_executor/models/deepseek_v4.py
@@ -54,7 +54,6 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.triton_utils import tl, triton
-from vllm.utils.multi_stream_utils import AuxStreamType
 from vllm.utils.torch_utils import direct_register_custom_op
 
 from .utils import (
@@ -872,7 +871,7 @@ def __init__(
         self,
         vllm_config: VllmConfig,
         prefix: str,
         topk_indices_buffer: torch.Tensor | None = None,
-        aux_stream: torch.cuda.Stream | None = None,
+        aux_stream_list: list[torch.cuda.Stream] | None = None,
     ):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1005,7 +1004,7 @@
             indexer=self.indexer,
             indexer_rotary_emb=self.rotary_emb,
             topk_indices_buffer=topk_indices_buffer,
-            aux_stream=aux_stream,
+            aux_stream_list=aux_stream_list,
         )
         self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
             hidden_size=self.hidden_size,
@@ -1041,7 +1040,7 @@ def __init__(
         self,
         vllm_config,
         prefix,
         topk_indices_buffer: torch.Tensor | None = None,
-        aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream] | None = None,
+        aux_stream_list: list[torch.cuda.Stream] | None = None,
     ):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1052,9 +1051,7 @@
             vllm_config,
             prefix=f"{prefix}.attn",
             topk_indices_buffer=topk_indices_buffer,
-            aux_stream=aux_stream_dict.get(AuxStreamType.Attention)
-            if aux_stream_dict is not None
-            else None,
+            aux_stream_list=aux_stream_list,
         )
 
         self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
@@ -1182,10 +1179,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.hc_dim = self.hc_mult * config.hidden_size
         self.rms_norm_eps = config.rms_norm_eps
 
-        aux_stream_list = [torch.cuda.Stream() for _ in range(1)]
-        self.aux_stream_dict = {
-            AuxStreamType.Attention: aux_stream_list[0],
-        }
+        # Three aux streams: one per non-default input GEMM in
+        # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
+        # (compressor kv_score, indexer.weights_proj, indexer.compressor
+        # kv_score). fused_wqa_wkv stays on the default stream.
+        aux_stream_list = [torch.cuda.Stream() for _ in range(3)]
         self.device = current_platform.device_type
 
         # Reserved topk indices buffer for all Indexer layers to reuse.
@@ -1209,7 +1207,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 vllm_config,
                 prefix=prefix,
                 topk_indices_buffer=self.topk_indices_buffer,
-                aux_stream_dict=self.aux_stream_dict,
+                aux_stream_list=aux_stream_list,
             ),
             prefix=f"{prefix}.layers",
         )
diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py
index cc6bc6462449..c00f08f93329 100644
--- a/vllm/utils/multi_stream_utils.py
+++ b/vllm/utils/multi_stream_utils.py
@@ -56,3 +56,67 @@ def maybe_execute_in_parallel(
     result0 = fn0()
     result1 = fn1()
     return (result0, result1)
+
+
+def execute_in_parallel(
+    default_fn: Callable[[], Any],
+    aux_fns: list[Callable[[], Any] | None],
+    start_event: torch.cuda.Event,
+    done_events: list[torch.cuda.Event],
+    aux_streams: list[torch.cuda.Stream] | None = None,
+) -> tuple[Any, list[Any]]:
+    """Run default_fn on the current stream and aux_fns concurrently on
+    aux_streams.
+
+    Generalizes maybe_execute_in_parallel to N aux callables. Slots where
+    aux_fns[i] is None are skipped (no stream switch, no event record); their
+    corresponding entry in the returned aux_results list is None.
+
+    start_event fans out from the current stream to every launched aux stream;
+    done_events[i] is recorded after aux_fns[i] so the current stream joins
+    before returning. When aux_streams is None, all aux_fns run sequentially
+    on the current stream.
+
+    Args:
+        default_fn: Callable for the default (current) stream.
+        aux_fns: Per-aux callables; entries may be None to skip.
+        start_event: CUDA event recorded on the current stream before
+            default_fn so each launched aux stream can wait on it.
+        done_events: One CUDA event per aux slot, recorded after the
+            corresponding aux_fn. Length must match aux_fns.
+        aux_streams: Per-aux CUDA streams. Length must match aux_fns.
+            Multi-stream is disabled when None.
+
+    Returns:
+        Tuple of (default_result, aux_results) where aux_results[i] is the
+        result of aux_fns[i] (or None when skipped).
+    """
+    aux_results: list[Any]
+    if aux_streams is None:
+        default_result = default_fn()
+        aux_results = [fn() if fn is not None else None for fn in aux_fns]
+        return default_result, aux_results
+
+    assert len(aux_fns) == len(aux_streams) == len(done_events), (
+        "aux_fns, aux_streams, and done_events must be the same length"
+    )
+
+    aux_results = [None] * len(aux_fns)
+    pending: list[torch.cuda.Event] = []
+
+    start_event.record()
+    for i, fn in enumerate(aux_fns):
+        if fn is None:
+            continue
+        with torch.cuda.stream(aux_streams[i]):
+            start_event.wait()
+            aux_results[i] = fn()
+            done_events[i].record()
+        pending.append(done_events[i])
+
+    default_result = default_fn()
+
+    for ev in pending:
+        ev.wait()
+
+    return default_result, aux_results
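
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new execute_in_parallel helper is meant to be driven, following the fan-out/fan-in protocol in its docstring: one heavy GEMM stays on the current stream, two lighter GEMMs are fanned out to aux streams via start_event, and the per-slot done_events are joined before the results are consumed. The tensor shapes, stream count, and variable names here are assumptions made for the example; only the execute_in_parallel signature and semantics come from the diff above.

import torch

from vllm.utils.multi_stream_utils import execute_in_parallel

if torch.cuda.is_available():
    device = torch.device("cuda")
    # Illustrative shapes only: one heavy GEMM and two lighter ones.
    x = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
    w_heavy = torch.randn(4096, 4096, device=device, dtype=torch.bfloat16)
    w_a = torch.randn(4096, 512, device=device, dtype=torch.bfloat16)
    w_b = torch.randn(4096, 256, device=device, dtype=torch.bfloat16)

    # One aux stream and one done event per aux slot; start_event fans out
    # from the current stream to every launched aux stream.
    aux_streams = [torch.cuda.Stream() for _ in range(3)]
    start_event = torch.cuda.Event()
    done_events = [torch.cuda.Event() for _ in range(3)]

    # Heavy GEMM runs on the current (default) stream; the two light GEMMs
    # run on aux streams 0 and 1; slot 2 is None and is skipped entirely.
    heavy_out, (out_a, out_b, skipped) = execute_in_parallel(
        lambda: x @ w_heavy,
        [lambda: x @ w_a, lambda: x @ w_b, None],
        start_event,
        done_events,
        aux_streams,
    )
    assert skipped is None
    torch.cuda.synchronize()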