diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py new file mode 100644 index 000000000000..23db2d61eed3 --- /dev/null +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.models.deepseek_v4 import attention as dsv4_attention +from vllm.platforms import current_platform + +pytestmark = pytest.mark.skipif( + not current_platform.is_rocm(), reason="ROCm-only DeepSeek-V4 tests" +) + + +def _swa_metadata( + num_decode_tokens: int, + num_prefill_tokens: int = 0, +) -> dsv4_attention.DeepseekSparseSWAMetadata: + return dsv4_attention.DeepseekSparseSWAMetadata( + block_table=torch.empty(0, dtype=torch.int32), + slot_mapping=torch.empty(0, dtype=torch.int64), + block_size=256, + num_decodes=num_decode_tokens, + num_decode_tokens=num_decode_tokens, + num_prefill_tokens=num_prefill_tokens, + ) + + +def _use_rocm_multistream( + cudagraph_runtime_mode: dsv4_attention.CUDAGraphMode, + metadata: dsv4_attention.DeepseekSparseSWAMetadata, +) -> bool: + class _ForwardContext: + pass + + class _Wrapper: + aux_stream_list = [object()] + + forward_context = _ForwardContext() + forward_context.cudagraph_runtime_mode = cudagraph_runtime_mode + attn_metadata = {"layer_0.swa": metadata} + + wrapper_cls = dsv4_attention.DeepseekV4MultiHeadLatentAttentionWrapper + method = wrapper_cls._use_rocm_csa_multistream + return method(_Wrapper(), forward_context, attn_metadata) + + +def test_deepseek_v4_rocm_multistream_decode_policy(): + decode_metadata = _swa_metadata(num_decode_tokens=4) + + assert ( + _use_rocm_multistream(dsv4_attention.CUDAGraphMode.NONE, decode_metadata) + is True + ) + assert ( + _use_rocm_multistream(dsv4_attention.CUDAGraphMode.PIECEWISE, decode_metadata) + is True + ) + assert ( + 
_use_rocm_multistream(dsv4_attention.CUDAGraphMode.FULL, decode_metadata) + is False + ) + + mixed_metadata = _swa_metadata(num_decode_tokens=4, num_prefill_tokens=1) + assert ( + _use_rocm_multistream(dsv4_attention.CUDAGraphMode.PIECEWISE, mixed_metadata) + is False + ) + + +def test_deepseek_v4_rocm_post_rmsnorm_stream_mapping(monkeypatch): + calls = [] + streams = [object()] + + def fake_maybe_execute_in_parallel( + default_fn, + aux_fn, + start_event, + done_event, + aux_stream=None, + ): + assert aux_stream is streams[0] + + q = default_fn() + compressor_result = aux_fn() + return q, compressor_result + + monkeypatch.setattr( + dsv4_attention, + "maybe_execute_in_parallel", + fake_maybe_execute_in_parallel, + ) + + class _WqB: + def __call__(self, qr): + calls.append("wq_b") + return torch.empty((qr.shape[0], 3)) + + class _Indexer: + def __call__( + self, + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + rotary_emb, + use_aux_stream=True, + ): + calls.append(("indexer", use_aux_stream)) + return object() + + class _Compressor: + def __call__(self, kv_score, positions, rotary_emb): + calls.append("compressor") + return object() + + wrapper = SimpleNamespace( + wq_b=_WqB(), + indexer=_Indexer(), + compressor=_Compressor(), + n_local_heads=1, + head_dim=3, + indexer_rotary_emb=object(), + rotary_emb=object(), + ln_events=[object(), object(), object()], + ) + + def fake_kv_insert(q, kv, positions, attn_metadata): + calls.append("kv_insert") + return q + + def fail_project_compressor_kv_score(hidden_states, compressor): + raise AssertionError("kv_score should be reused in this test") + + wrapper._fused_qnorm_rope_kv_insert = fake_kv_insert + wrapper._project_compressor_kv_score = fail_project_compressor_kv_score + + hidden_states = torch.empty((2, 3)) + qr = torch.empty((2, 3)) + kv = torch.empty((2, 3)) + positions = torch.empty(2, dtype=torch.int64) + kv_score = torch.empty((2, 3)) + indexer_kv_score = torch.empty((2, 3)) + 
indexer_weights = torch.empty((2, 3)) + + method = dsv4_attention.DeepseekV4MultiHeadLatentAttentionWrapper + q = method._post_rmsnorm_prepare( + wrapper, + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + None, + streams, + True, + ) + + assert q.shape == (2, 1, 3) + assert calls == ["wq_b", "kv_insert", "compressor", ("indexer", False)] diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 84318a8107d3..8eca57dfeec6 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -616,16 +616,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - # Three aux streams: one per non-default input GEMM in - # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute - # (compressor kv_score, indexer.weights_proj, indexer.compressor - # kv_score). fused_wqa_wkv stays on the default stream. - # Disable them on ROCm because of hang issues. - aux_stream_list = ( - None - if current_platform.is_rocm() - else [torch.cuda.Stream() for _ in range(3)] - ) + # One high-priority aux stream on ROCm for the post-rmsnorm compressor + # overlap; other platforms keep the three input-GEMM aux streams. + aux_stream_list = ( + [torch.cuda.Stream(priority=-1)] + if current_platform.is_rocm() + else [torch.cuda.Stream() for _ in range(3)] + ) self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. 
diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index d052b41fc541..f1479b8fb35d 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -35,6 +35,7 @@ from vllm.config import ( CacheConfig, + CUDAGraphMode, VllmConfig, get_current_vllm_config, ) @@ -65,7 +66,10 @@ DeepseekV4IndexerBackend, get_max_prefill_buffer_size, ) -from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache +from vllm.v1.attention.backends.mla.sparse_swa import ( + DeepseekSparseSWAMetadata, + DeepseekV4SWACache, +) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec if TYPE_CHECKING: @@ -76,6 +80,39 @@ logger = init_logger(__name__) +def _iter_deepseek_v4_swa_metadata( + attn_metadata: dict[str, AttentionMetadata] + | list[dict[str, AttentionMetadata]] + | AttentionMetadata + | None, +): + if attn_metadata is None: + return + if isinstance(attn_metadata, DeepseekSparseSWAMetadata): + yield attn_metadata + return + if isinstance(attn_metadata, dict): + for metadata in attn_metadata.values(): + yield from _iter_deepseek_v4_swa_metadata(metadata) + return + if isinstance(attn_metadata, list): + for item in attn_metadata: + yield from _iter_deepseek_v4_swa_metadata(item) + + +def _is_decode_only_deepseek_v4_step( + attn_metadata: dict[str, AttentionMetadata] + | list[dict[str, AttentionMetadata]] + | None, +) -> bool: + metadata = list(_iter_deepseek_v4_swa_metadata(attn_metadata)) + if not metadata: + return False + return all( + item.num_decode_tokens > 0 and item.num_prefill_tokens == 0 for item in metadata + ) + + def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]": """Pick the platform-specific V4 sparse MLA impl class. Sole platform check.""" if current_platform.is_rocm(): @@ -216,7 +253,6 @@ def __init__( + 1 # 1B pad ) - # Will be None on ROCm for now. self.aux_stream_list = mla_modules.aux_stream_list # [0]: GEMM start / post-GEMM event0. 
[1..3]: GEMM done events; # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins @@ -346,8 +382,56 @@ def forward( return self.wo_b(z.flatten(1)) - def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: - aux_streams = self.aux_stream_list + def _aux_streams_for_step( + self, + use_rocm_csa_multistream: bool, + ) -> list[torch.cuda.Stream] | None: + if current_platform.is_rocm(): + return self.aux_stream_list if use_rocm_csa_multistream else None + return self.aux_stream_list + + def _use_rocm_csa_multistream( + self, + forward_context: ForwardContext, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + ) -> bool: + if not current_platform.is_rocm(): + return False + if self.aux_stream_list is None: + return False + if not self.aux_stream_list: + return False + + if not _is_decode_only_deepseek_v4_step(attn_metadata): + return False + + graph_mode = forward_context.cudagraph_runtime_mode + assert graph_mode in CUDAGraphMode.valid_runtime_modes() + # Standalone ROCm repro in benchmarks/kernels/rocm_dsv4_stream_probe.py + # shows that putting the Python stream scheduler inside a compiled + # full-graph wrapper can collapse the replayed work onto stream 0. + # Keep the ROCm CSA multi-stream scheduler in eager/piecewise graph + # islands. 
+ return graph_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE) + + def _project_compressor_kv_score( + self, + hidden_states: torch.Tensor, + compressor: DeepseekCompressor, + ) -> torch.Tensor: + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) + + def attn_gemm_parallel_execute( + self, + hidden_states: torch.Tensor, + aux_streams: list[torch.cuda.Stream] | None, + ) -> tuple[Any, ...]: if aux_streams is not None: assert len(aux_streams) >= 3 aux_streams = aux_streams[:3] @@ -355,7 +439,8 @@ def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. - # On ROCm, aux_streams is None and execute_in_parallel runs serially. + # ROCm keeps these projections on the main stream; the measured decode + # win comes from overlapping the post-rmsnorm compressor branch. 
aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None: @@ -414,9 +499,17 @@ def attention_impl( ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + rocm_csa_multistream = self._use_rocm_csa_multistream( + forward_context, attn_metadata + ) + aux_streams = self._aux_streams_for_step(rocm_csa_multistream) + gemm_aux_streams = None if current_platform.is_rocm() else aux_streams qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states) + self.attn_gemm_parallel_execute( + hidden_states, + gemm_aux_streams, + ) ) qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) @@ -428,12 +521,58 @@ def attention_impl( self.eps, ) + q = self._post_rmsnorm_prepare( + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + attn_metadata, + aux_streams, + rocm_csa_multistream, + ) + + # Pad q to FlashMLA-required head count (64 or 128) + if self.n_local_heads < self.padded_heads: + pad_size = self.padded_heads - self.n_local_heads + q = F.pad(q, (0, 0, 0, pad_size), value=0.0) + + # MLA attention writes into the pre-allocated `out` buffer + # ([num_tokens, padded_heads, head_dim]). + self.mla_attn(q, kv, positions, output=out) + + def _post_rmsnorm_prepare( + self, + hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + kv_score: torch.Tensor | None, + indexer_kv_score: torch.Tensor | None, + indexer_weights: torch.Tensor | None, + positions: torch.Tensor, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + aux_streams: list[torch.cuda.Stream] | None, + rocm_csa_multistream: bool, + ) -> torch.Tensor: # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride # on the default stream so q stays on its consumer stream (mla_attn # downstream reads q on default). 
Indexer/compressor go on aux for # overlap with default's GEMM + cache write. + post_aux_streams = aux_streams + rocm_ms_active = ( + current_platform.is_rocm() + and post_aux_streams is not None + and rocm_csa_multistream + ) + if post_aux_streams is not None and not rocm_ms_active: + outer_post_aux_streams = [post_aux_streams[0], post_aux_streams[1]] + else: + outer_post_aux_streams = None if self.indexer is not None: - aux_streams = self.aux_stream_list indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. assert self.compressor is not None @@ -444,33 +583,60 @@ def wq_b_kv_insert() -> torch.Tensor: q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) return q - # 3-way overlap (matches TRT-LLM PR #14142 Level 1): default runs - # wq_b+kv_insert; slot [0] runs the full indexer; slot [1] runs the - # MLA compressor. Slot [2] is reserved for the indexer's inner - # overlap. ROCm (aux_streams is None) falls back to sequential. - q, _ = execute_in_parallel( - wq_b_kv_insert, - [ - lambda: indexer( + # NVIDIA keeps the 3-way split. ROCm (per measurements) keeps the indexer + # on the default stream and overlaps only the compressor branch. 
+ def run_indexer() -> Any: + if current_platform.is_rocm(): + return indexer( hidden_states, qr, indexer_kv_score, indexer_weights, positions, self.indexer_rotary_emb, - ), - lambda: compressor(kv_score, positions, self.rotary_emb), - ], - self.ln_events[0], - [self.ln_events[1], self.ln_events[2]], - [aux_streams[0], aux_streams[1]] if aux_streams is not None else None, - enable=aux_streams is not None, - ) + use_aux_stream=False, + ) + return indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + ) + + def run_compressor() -> Any: + local_kv_score = kv_score + if local_kv_score is None: + local_kv_score = self._project_compressor_kv_score( + hidden_states, compressor + ) + return compressor(local_kv_score, positions, self.rotary_emb) + + indexer_fn: Callable[[], Any] | None = run_indexer + compressor_fn: Callable[[], Any] | None = run_compressor + if rocm_ms_active: + assert post_aux_streams is not None + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert, + run_compressor, + self.ln_events[0], + self.ln_events[1], + post_aux_streams[0], + ) + run_indexer() + else: + q, _ = execute_in_parallel( + wq_b_kv_insert, + [indexer_fn, compressor_fn], + self.ln_events[0], + [self.ln_events[1], self.ln_events[2]], + outer_post_aux_streams, + enable=post_aux_streams is not None, + ) elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. 
- aux_stream = ( - self.aux_stream_list[0] if self.aux_stream_list is not None else None - ) + aux_stream = post_aux_streams[0] if post_aux_streams is not None else None compressor = self.compressor def wq_b_kv_insert() -> torch.Tensor: @@ -478,9 +644,17 @@ def wq_b_kv_insert() -> torch.Tensor: q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) return q + def run_compressor() -> Any: + local_kv_score = kv_score + if local_kv_score is None: + local_kv_score = self._project_compressor_kv_score( + hidden_states, compressor + ) + return compressor(local_kv_score, positions, self.rotary_emb) + q, _ = maybe_execute_in_parallel( wq_b_kv_insert, - lambda: compressor(kv_score, positions, self.rotary_emb), + run_compressor, self.ln_events[0], self.ln_events[1], aux_stream, @@ -490,9 +664,7 @@ def wq_b_kv_insert() -> torch.Tensor: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - # MLA attention writes into the pre-allocated `out` buffer - # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q, kv, positions, output=out) + return q def _fused_qnorm_rope_kv_insert( self, @@ -866,7 +1038,7 @@ def __init__( use_fp4_cache=self.use_fp4_kv, ) - # None on ROCm — maybe_execute_in_parallel falls back to sequential. + # Side stream for forward()'s two-way overlap; ROCm opts out per call. self.aux_stream = aux_stream self.ln_events: list[torch.cuda.Event] = [ torch.cuda.Event(), @@ -877,14 +1049,15 @@ def forward( self, hidden_states: torch.Tensor, qr: torch.Tensor, - compressed_kv_score: torch.Tensor, - indexer_weights: torch.Tensor, + compressed_kv_score: torch.Tensor | None, + indexer_weights: torch.Tensor | None, positions: torch.Tensor, rotary_emb: nn.Module, + use_aux_stream: bool = True, ) -> torch.Tensor: compressor = self.compressor - def wq_b_and_q_quant(): + def wq_b_and_q_quant(weights: torch.Tensor): # ReplicatedLinear returns (output, bias); bias is None. 
q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) @@ -892,19 +1065,22 @@ def wq_b_and_q_quant(): positions, q, rotary_emb.cos_sin_cache, - indexer_weights, + weights, self.softmax_scale, self.n_head**-0.5, use_fp4=self.use_fp4_kv, ) + assert indexer_weights is not None + assert compressed_kv_score is not None + # compressor returns None and writes K to the indexer KV cache; the # join orders that write before indexer_op (skip_k_cache_insert=True). (q_quant, weights), k = maybe_execute_in_parallel( - wq_b_and_q_quant, + lambda: wq_b_and_q_quant(indexer_weights), lambda: compressor(compressed_kv_score, positions, rotary_emb), self.ln_events[0], self.ln_events[1], - self.aux_stream, + self.aux_stream if use_aux_stream else None, ) return self.indexer_op(hidden_states, q_quant, k, weights) diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py index 2203221c5a14..65ce67393fa8 100644 --- a/vllm/utils/multi_stream_utils.py +++ b/vllm/utils/multi_stream_utils.py @@ -8,6 +8,17 @@ import torch +def _record_result_stream(result: Any, stream: torch.cuda.Stream) -> None: + if isinstance(result, torch.Tensor): + result.record_stream(stream) + elif isinstance(result, (tuple, list)): + for item in result: + _record_result_stream(item, stream) + elif isinstance(result, dict): + for item in result.values(): + _record_result_stream(item, stream) + + class AuxStreamType(Enum): Attention = 1 @@ -45,13 +56,15 @@ def maybe_execute_in_parallel( Tuple of (fn0_result, fn1_result). 
""" if aux_stream is not None: - event0.record() + current_stream = torch.cuda.current_stream() + event0.record(current_stream) result0 = fn0() with torch.cuda.stream(aux_stream): - event0.wait() + aux_stream.wait_event(event0) result1 = fn1() - event1.record() - event1.wait() + event1.record(aux_stream) + current_stream.wait_event(event1) + _record_result_stream(result1, current_stream) else: result0 = fn0() result1 = fn1() @@ -75,9 +88,11 @@ def execute_in_parallel( start_event fans out from the current stream to every launched aux stream; done_events[i] is recorded after aux_fns[i] so the current stream joins - before returning. Falls back to sequential execution on the current stream - when aux_streams is None or enable is False; in that case default_fn runs - first, then aux_fns in order. + before returning. The default-stream function is enqueued before the aux + functions so the critical path is not delayed by CPU launch overhead from + side-stream branches. Falls back to sequential execution on the current + stream when aux_streams is None or enable is False; in that case default_fn + runs first, then aux_fns in order. Args: default_fn: Callable for the default (current) stream. 
@@ -108,21 +123,24 @@ def execute_in_parallel( ) aux_results = [None] * len(aux_fns) - pending: list[torch.cuda.Event] = [] + current_stream = torch.cuda.current_stream() + pending: list[tuple[torch.cuda.Event, Any]] = [] + + start_event.record(current_stream) + default_result = default_fn() - start_event.record() for i, fn in enumerate(aux_fns): if fn is None: continue - with torch.cuda.stream(aux_streams[i]): - start_event.wait() + aux_stream = aux_streams[i] + with torch.cuda.stream(aux_stream): + aux_stream.wait_event(start_event) aux_results[i] = fn() - done_events[i].record() - pending.append(done_events[i]) - - default_result = default_fn() + done_events[i].record(aux_stream) + pending.append((done_events[i], aux_results[i])) - for ev in pending: - ev.wait() + for ev, result in pending: + current_stream.wait_event(ev) + _record_result_stream(result, current_stream) return default_result, aux_results