From dedd50fb53a8ce7a43bce9e8bd5cf07746bcaebb Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Sat, 23 May 2026 17:39:21 +0000 Subject: [PATCH 01/21] Enable ROCm DeepSeek V4 decode multi-stream Enable the DeepSeek V4 model setup to create the same three attention auxiliary streams on ROCm that CUDA already uses. This activates the existing decode overlap choreography for CSA: c4a layers can overlap the indexer pipeline, main KV compression, and SWA insertion, while c128a layers can overlap main KV compression with SWA insertion. XPU keeps the existing serial fallback, and CUDA behavior remains unchanged. Duplicate-work check: issue #41820 remains open; unauthenticated GitHub API searches found no open PR with "41820 in:body" and the closest open PRs from area keyword searches were #41136 and #41834, which cover ROCm enablement/fallbacks and NVIDIA SM12x support rather than this ROCm aux-stream gate. Tests: .venv/bin/python -m pytest tests/models/test_deepseek_v4_rocm_multistream.py -q (3 passed, 16 warnings); pre-commit run ruff-check --files vllm/models/deepseek_v4/nvidia/model.py tests/models/test_deepseek_v4_rocm_multistream.py (passed); pre-commit run ruff-format --files vllm/models/deepseek_v4/nvidia/model.py tests/models/test_deepseek_v4_rocm_multistream.py (passed). AI assistance was used for implementation and validation. Signed-off-by: vLLM Contributor --- .../test_deepseek_v4_rocm_multistream.py | 39 +++++++++++++++++++ vllm/models/deepseek_v4/nvidia/model.py | 28 ++++++++----- 2 files changed, 57 insertions(+), 10 deletions(-) create mode 100644 tests/models/test_deepseek_v4_rocm_multistream.py diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py new file mode 100644 index 000000000000..0a0d7a914f1a --- /dev/null +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.models.deepseek_v4.amd import model as rocm_model + + +def test_deepseek_v4_rocm_aux_streams_enabled(monkeypatch): + streams = [object(), object(), object()] + + monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(rocm_model.torch.cuda, "Stream", streams.pop) + + aux_streams = rocm_model.make_deepseek_v4_aux_streams() + + assert aux_streams is not None + assert len(aux_streams) == 3 + + +def test_deepseek_v4_rocm_aux_streams_xpu_fallback(monkeypatch): + monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: False) + monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: True) + + aux_streams = rocm_model.make_deepseek_v4_aux_streams() + + assert aux_streams is None + + +def test_deepseek_v4_aux_streams_cuda_behavior_unchanged(monkeypatch): + streams = [object(), object(), object()] + + monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: False) + monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(rocm_model.torch.cuda, "Stream", streams.pop) + + aux_streams = rocm_model.make_deepseek_v4_aux_streams() + + assert aux_streams is not None + assert len(aux_streams) == 3 diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 974593a8d390..2f0ff1e4bd66 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -65,6 +65,23 @@ from vllm.utils.torch_utils import direct_register_custom_op +def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: + # Three aux streams: one per non-default input GEMM in + # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute + # (compressor kv_score, indexer.weights_proj, indexer.compressor + # kv_score). fused_wqa_wkv stays on the default stream. + # + # ROCm uses the same attention-side stream choreography as CUDA: + # c4a layers overlap indexer, main KV compression, and SWA insertion; + # c128a layers overlap main KV compression and SWA insertion. + # XPU keeps the serial fallback. + if current_platform.is_rocm(): + return [torch.cuda.Stream() for _ in range(3)] + if current_platform.is_xpu(): + return None + return [torch.cuda.Stream() for _ in range(3)] + + class DeepseekV4MLP(nn.Module): def __init__( self, @@ -1094,16 +1111,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - # Three aux streams: one per non-default input GEMM in - # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute - # (compressor kv_score, indexer.weights_proj, indexer.compressor - # kv_score). fused_wqa_wkv stays on the default stream. - # Disable them on ROCm / XPU because of hang issues / no overlap. - aux_stream_list = ( - None - if current_platform.is_rocm() or current_platform.is_xpu() - else [torch.cuda.Stream() for _ in range(3)] - ) + aux_stream_list = make_deepseek_v4_aux_streams() self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. From b4ed5090a142a8c0c94f13518c78e5bdd6422ca4 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Sat, 23 May 2026 22:24:13 +0000 Subject: [PATCH 02/21] Stabilize ROCm DeepSeek V4 benchmark launch The initial ROCm CSA aux-stream implementation could start the DeepSeek V4 server, but decode was not stable on MI355X. During the InferenceX 1k/1k conc=4 benchmark, enabling ROCm auxiliary streams in the attention wrapper consistently wedged decode: GPUs stayed 100% busy with little or no memory bandwidth and generation throughput stayed near zero. In the full-graph run the engine later failed with an NCCL watchdog timeout in _ALLGATHER_BASE after roughly 600s and the benchmark completed 0/40 requests. Piecewise graph mode reproduced the same decode hang with full aux streams, and a narrower GEMM-only aux-stream variant also stalled at roughly 0.5 tok/s generation throughput during warmup. Work around the ROCm hang by keeping DeepSeek V4 attention aux streams disabled at runtime on ROCm and executing the CSA/indexer/compressor path on the current stream. CUDA keeps the existing aux-stream behavior. This preserves the existing branch structure and stream plumbing, but avoids the unsafe HIP event/stream ordering until the lower-level ordering issue is fixed. This commit also fixes two launch blockers observed while testing the current main wheel stack: update the ROCm AITER fusion pass to import GatedDeltaNetAttention from its current module, and patch the active triton_kernels bitmatrix module instead of the vendored vllm.third_party path so the FP4 MoE metadata stage handles non-power-of-two row counts. Benchmark result with the workaround, using DeepSeek-V4-Pro, TP8, max-model-len 2048, async scheduling, ROCm AITER, InferenceX random 1k/1k conc=4, 40 prompts, 8 warmups: 40/40 successful requests, duration 610.42s, output throughput 60.14 tok/s, total throughput 120.86 tok/s, mean TTFT 452.76 ms, mean TPOT 63.90 ms. This is a correctness/stability improvement over the aux-stream variants, which failed to complete the benchmark. Tests: pre-commit run ruff-check --files vllm/models/deepseek_v4/nvidia/ops/attention.py tests/models/test_deepseek_v4_rocm_multistream.py vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py vllm/compilation/passes/fusion/rocm_aiter_fusion.py Tests: .venv/bin/python -m pytest tests/models/test_deepseek_v4_rocm_multistream.py -q Benchmark: PYTHONPATH=/tmp/InferenceX:/shared/amdgpu/home/fai_qle/vllm .venv/bin/python /tmp/InferenceX/utils/bench_serving/benchmark_serving.py --model /shared/data/amd_int/models/deepseek-ai/DeepSeek-V4-Pro --backend vllm --base-url http://127.0.0.1:8888 --dataset-name random --random-input-len 1024 --random-output-len 1024 --random-range-ratio 0.8 --num-prompts 40 --max-concurrency 4 --request-rate inf --ignore-eos --save-result --num-warmups 8 --percentile-metrics ttft,tpot,itl,e2el --result-dir /workspace/perf_rocm_multistream --result-filename decode_capture_guard_1k1k_conc4.json --trust-remote-code AI assistance was used for this change. Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- vllm/models/deepseek_v4/attention.py | 44 ++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index d052b41fc541..2138108f4f35 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -216,7 +216,6 @@ def __init__( + 1 # 1B pad ) - # Will be None on ROCm for now. self.aux_stream_list = mla_modules.aux_stream_list # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins @@ -346,8 +345,27 @@ def forward( return self.wo_b(z.flatten(1)) - def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: + def _aux_streams_for_step( + self, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + ) -> list[torch.cuda.Stream] | None: aux_streams = self.aux_stream_list + if current_platform.is_rocm(): + # HIP multi-stream/event ordering for this layer currently wedges + # decode on MI355X, both under replayed graphs and eager piecewise + # execution. Keep ROCm on the current stream until the lower-level + # ordering issue is fixed. + return None + + return aux_streams + + def attn_gemm_parallel_execute( + self, + hidden_states: torch.Tensor, + aux_streams: list[torch.cuda.Stream] | None, + ) -> tuple[Any, ...]: if aux_streams is not None: assert len(aux_streams) >= 3 aux_streams = aux_streams[:3] @@ -355,7 +373,8 @@ def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. - # On ROCm, aux_streams is None and execute_in_parallel runs serially. + # ROCm keeps aux_streams as None in this wrapper and falls back to + # sequential execution on the current stream. aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None: @@ -414,9 +433,10 @@ def attention_impl( ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + aux_streams = self._aux_streams_for_step(attn_metadata) qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states) + self.attn_gemm_parallel_execute(hidden_states, aux_streams) ) qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) @@ -432,8 +452,8 @@ def attention_impl( # on the default stream so q stays on its consumer stream (mla_attn # downstream reads q on default). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. + post_aux_streams = None if current_platform.is_rocm() else aux_streams if self.indexer is not None: - aux_streams = self.aux_stream_list indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. assert self.compressor is not None @@ -458,18 +478,23 @@ def wq_b_kv_insert() -> torch.Tensor: indexer_weights, positions, self.indexer_rotary_emb, + use_aux_stream=post_aux_streams is not None, ), lambda: compressor(kv_score, positions, self.rotary_emb), ], self.ln_events[0], [self.ln_events[1], self.ln_events[2]], - [aux_streams[0], aux_streams[1]] if aux_streams is not None else None, - enable=aux_streams is not None, + ( + [post_aux_streams[0], post_aux_streams[1]] + if post_aux_streams is not None + else None + ), + enable=post_aux_streams is not None, ) elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. aux_stream = ( - self.aux_stream_list[0] if self.aux_stream_list is not None else None + post_aux_streams[0] if post_aux_streams is not None else None ) compressor = self.compressor @@ -881,6 +906,7 @@ def forward( indexer_weights: torch.Tensor, positions: torch.Tensor, rotary_emb: nn.Module, + use_aux_stream: bool = True, ) -> torch.Tensor: compressor = self.compressor @@ -905,6 +931,6 @@ def wq_b_and_q_quant(): lambda: compressor(compressed_kv_score, positions, rotary_emb), self.ln_events[0], self.ln_events[1], - self.aux_stream, + self.aux_stream if use_aux_stream else None, ) return self.indexer_op(hidden_states, q_quant, k, weights) From aa482cdd95d05089c2e9bc6dc3a3e6c2f75b9ca3 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Sat, 23 May 2026 23:47:01 +0000 Subject: [PATCH 03/21] Investigate ROCm DSV4 multi-stream ordering Add explicit stream/event ordering in the shared multi-stream helpers and record side-stream tensor ownership before returning results to the default stream. For the ROCm DeepSeek-V4 decode path, gate auxiliary streams to decode-only metadata and use preallocated out= buffers for the overlapped CSA GEMM so side-stream GEMM outputs are not allocated inside captured graphs. The remaining indexer GEMMs are deferred back to the main stream on ROCm because capturing those additional side-stream hipBLASLt nodes next to the aiter fused WQA/WKV GEMM still reproduces the replay hang. Add tools/rocm_multistream_graph_repro.py plus docs/dev/rocm_multistream_graph_repro.md. The repro shows that a side-stream GEMM allocating its output during CUDAGraph capture reaches capture ok and then hangs on first replay on MI355X; rocm-smi reports 100% GPU busy and 0% memory bandwidth. The preallocated out= variant replays successfully. torch.cuda.Event(external=True) is not a ROCm workaround because it raises RuntimeError: External events are disallowed in rocm. Benchmarked DeepSeek-V4-Pro TP=8 random 1k/1k concurrency=4, 40 prompts, 8 warmups. Prior ROCm graph-safe workaround with aux disabled: 60.14 output tok/s, 120.86 total tok/s, mean TPOT 63.90 ms. One preallocated aux CSA GEMM: 57.89 output tok/s, 116.34 total tok/s, mean TPOT 66.51 ms. One aux CSA GEMM with threshold 16: 57.42 output tok/s, 115.39 total tok/s, mean TPOT 66.88 ms. This launchable Python-level overlap does not produce a perf gain; SGLang appears to get better MI355X behavior by fusing KV cache writes and indexer/compressor kernels rather than by only splitting GEMMs across streams. Duplicate-work check: gh CLI was unauthenticated in this container, so I checked GitHub search for issue 41820 and DeepSeek V4 ROCm multi-stream PRs. Issue 41820 still reports no linked development branches/PRs in the public issue view. Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- docs/dev/rocm_multistream_graph_repro.md | 54 +++++++++ tools/rocm_multistream_graph_repro.py | 130 +++++++++++++++++++++ vllm/models/deepseek_v4/attention.py | 140 +++++++++++++++++++++-- vllm/utils/multi_stream_utils.py | 40 +++++-- 4 files changed, 342 insertions(+), 22 deletions(-) create mode 100644 docs/dev/rocm_multistream_graph_repro.md create mode 100644 tools/rocm_multistream_graph_repro.py diff --git a/docs/dev/rocm_multistream_graph_repro.md b/docs/dev/rocm_multistream_graph_repro.md new file mode 100644 index 000000000000..bde457707a99 --- /dev/null +++ b/docs/dev/rocm_multistream_graph_repro.md @@ -0,0 +1,54 @@ +# ROCm Multi-Stream Graph Replay Reproducer + +`tools/rocm_multistream_graph_repro.py` isolates the ROCm graph replay hang +found while enabling DeepSeek-V4 CSA decode multi-stream on AMD GPUs. + +## Reproduce the hang + +```bash +HIP_VISIBLE_DEVICES=0 timeout 90s \ + .venv/bin/python tools/rocm_multistream_graph_repro.py --mode allocating +``` + +Observed on MI355X: warmup and graph capture complete, then the first +`CUDAGraph.replay()` does not return. `rocm-smi` shows the GPU at 100% busy with +0% memory bandwidth. + +## Verify the workaround + +```bash +HIP_VISIBLE_DEVICES=0 timeout 90s \ + .venv/bin/python tools/rocm_multistream_graph_repro.py --mode preallocated +``` + +The preallocated mode creates side-stream GEMM output buffers before capture and +uses `torch.mm(..., out=...)`. This replays successfully on the same system. + +`torch.cuda.Event(external=True)` is not available as a ROCm workaround; it +raises `RuntimeError: External events are disallowed in rocm`. + +## DeepSeek-V4 implication + +The vLLM ROCm path must not capture side-stream work that allocates new tensors +inside the graph. The current branch uses explicit stream/event ordering and +preallocated `out=` buffers for the overlapped CSA GEMM. SGLang's DeepSeek-V4 +ROCm implementation goes further: it gates multi-stream to graph-capture decode, +uses pre-created streams, fuses KV cache writes into the K path, and runs the +indexer/compressor through fused kernels with prebuilt metadata. That design +avoids the side-stream allocation pattern reproduced here. + +## Benchmark notes + +Benchmark: DeepSeek-V4-Pro, TP=8, `random` 1k input / 1k output, +`--max-concurrency 4`, 40 prompts, 8 warmups. + +| Path | Output tok/s | Total tok/s | Mean TPOT | +|---|---:|---:|---:| +| ROCm graph-safe workaround, aux disabled | 60.14 | 120.86 | 63.90 ms | +| One preallocated aux CSA GEMM | 57.89 | 116.34 | 66.51 ms | +| One aux CSA GEMM, threshold 16 | 57.42 | 115.39 | 66.88 ms | + +The narrowed vLLM overlap path avoids the hang, but it does not improve this +low-workload benchmark. The likely missing piece versus SGLang is a deeper +fused implementation that removes intermediate tensors and cache-write +allocation from the side-stream path. diff --git a/tools/rocm_multistream_graph_repro.py b/tools/rocm_multistream_graph_repro.py new file mode 100644 index 000000000000..cf38ead34f96 --- /dev/null +++ b/tools/rocm_multistream_graph_repro.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Reproduce a ROCm graph replay hang with side-stream allocations. + +This reproducer isolates the lower-level failure seen while enabling +DeepSeek-V4 CSA decode multi-stream on ROCm. The allocating mode captures a +side-stream GEMM that creates its output tensor inside the CUDA graph. On the +MI355X test system this reaches ``capture ok`` and then hangs at the first graph +replay with GPUs at 100% busy and 0% memory bandwidth. + +Run from the repo root: + + HIP_VISIBLE_DEVICES=0 timeout 90s \\ + .venv/bin/python tools/rocm_multistream_graph_repro.py --mode allocating + +The graph-safe variant preallocates all GEMM outputs and uses ``out=``: + + HIP_VISIBLE_DEVICES=0 timeout 90s \\ + .venv/bin/python tools/rocm_multistream_graph_repro.py --mode preallocated + +On ROCm, ``torch.cuda.Event(external=True)`` is not a workaround; it raises +``RuntimeError: External events are disallowed in rocm``. +""" + +from __future__ import annotations + +import argparse + +import torch + + +def _make_inputs() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) + b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) + main_out = torch.empty((1024, 1024), device="cuda", dtype=torch.float16) + return a, b, main_out + + +def run_allocating(replays: int) -> None: + aux_stream = torch.cuda.Stream() + start_event = torch.cuda.Event() + done_event = torch.cuda.Event() + a, b, main_out = _make_inputs() + + def work() -> torch.Tensor: + current_stream = torch.cuda.current_stream() + start_event.record(current_stream) + with torch.cuda.stream(aux_stream): + aux_stream.wait_event(start_event) + aux_out = torch.mm(a, b) + done_event.record(aux_stream) + torch.mm(a, b, out=main_out) + current_stream.wait_event(done_event) + aux_out.record_stream(current_stream) + return main_out.float().mean() + aux_out.float().mean() + + for _ in range(3): + y = work() + torch.cuda.synchronize() + print("warmup ok", float(y)) + + graph = torch.cuda.CUDAGraph() + torch.cuda.synchronize() + with torch.cuda.graph(graph): + y = work() + print("capture ok") + + for i in range(replays): + graph.replay() + torch.cuda.synchronize() + print("replay", i, float(y)) + + +def run_preallocated(replays: int) -> None: + aux_stream = torch.cuda.Stream() + start_event = torch.cuda.Event() + done_event = torch.cuda.Event() + a, b, main_out = _make_inputs() + aux_out = torch.empty_like(main_out) + + def work() -> torch.Tensor: + current_stream = torch.cuda.current_stream() + start_event.record(current_stream) + with torch.cuda.stream(aux_stream): + aux_stream.wait_event(start_event) + torch.mm(a, b, out=aux_out) + done_event.record(aux_stream) + torch.mm(a, b, out=main_out) + current_stream.wait_event(done_event) + aux_out.record_stream(current_stream) + return main_out.float().mean() + aux_out.float().mean() + + for _ in range(3): + y = work() + torch.cuda.synchronize() + print("warmup ok", float(y)) + + graph = torch.cuda.CUDAGraph() + torch.cuda.synchronize() + with torch.cuda.graph(graph): + y = work() + print("capture ok") + + for i in range(replays): + graph.replay() + torch.cuda.synchronize() + print("replay", i, float(y)) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=("allocating", "preallocated"), + required=True, + ) + parser.add_argument("--replays", type=int, default=20) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA/HIP device is required") + + if args.mode == "allocating": + run_allocating(args.replays) + else: + run_preallocated(args.replays) + + +if __name__ == "__main__": + main() diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 2138108f4f35..76a16714c0be 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -65,7 +65,10 @@ DeepseekV4IndexerBackend, get_max_prefill_buffer_size, ) -from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache +from vllm.v1.attention.backends.mla.sparse_swa import ( + DeepseekSparseSWAMetadata, + DeepseekV4SWACache, +) from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec if TYPE_CHECKING: @@ -76,6 +79,40 @@ logger = init_logger(__name__) +def _iter_deepseek_v4_swa_metadata( + attn_metadata: dict[str, AttentionMetadata] + | list[dict[str, AttentionMetadata]] + | AttentionMetadata + | None, +): + if attn_metadata is None: + return + if isinstance(attn_metadata, DeepseekSparseSWAMetadata): + yield attn_metadata + return + if isinstance(attn_metadata, dict): + for metadata in attn_metadata.values(): + yield from _iter_deepseek_v4_swa_metadata(metadata) + return + if isinstance(attn_metadata, list): + for item in attn_metadata: + yield from _iter_deepseek_v4_swa_metadata(item) + + +def _is_decode_only_deepseek_v4_step( + attn_metadata: dict[str, AttentionMetadata] + | list[dict[str, AttentionMetadata]] + | None, +) -> bool: + metadata = list(_iter_deepseek_v4_swa_metadata(attn_metadata)) + if not metadata: + return False + return all( + item.num_decode_tokens > 0 and item.num_prefill_tokens == 0 + for item in metadata + ) + + def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]": """Pick the platform-specific V4 sparse MLA impl class. Sole platform check.""" if current_platform.is_rocm(): @@ -217,6 +254,7 @@ def __init__( ) self.aux_stream_list = mla_modules.aux_stream_list + self._rocm_aux_gemm_buffers: dict[str, torch.Tensor] = {} # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins # before post-GEMM starts. @@ -351,15 +389,11 @@ def _aux_streams_for_step( dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None ), ) -> list[torch.cuda.Stream] | None: - aux_streams = self.aux_stream_list - if current_platform.is_rocm(): - # HIP multi-stream/event ordering for this layer currently wedges - # decode on MI355X, both under replayed graphs and eager piecewise - # execution. Keep ROCm on the current stream until the lower-level - # ordering issue is fixed. + if current_platform.is_rocm() and not _is_decode_only_deepseek_v4_step( + attn_metadata + ): return None - - return aux_streams + return self.aux_stream_list def attn_gemm_parallel_execute( self, @@ -369,6 +403,25 @@ def attn_gemm_parallel_execute( if aux_streams is not None: assert len(aux_streams) >= 3 aux_streams = aux_streams[:3] + use_rocm_graph_safe_buffers = ( + current_platform.is_rocm() and aux_streams is not None + ) + + def rocm_aux_buffer( + name: str, + shape: tuple[int, int], + dtype: torch.dtype, + ) -> torch.Tensor: + buffer = self._rocm_aux_gemm_buffers.get(name) + if ( + buffer is None + or buffer.shape != shape + or buffer.dtype != dtype + or buffer.device != hidden_states.device + ): + buffer = torch.empty(shape, device=hidden_states.device, dtype=dtype) + self._rocm_aux_gemm_buffers[name] = buffer + return buffer # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] @@ -382,9 +435,28 @@ def attn_gemm_parallel_execute( compressor = self.compressor def compressor_kv_score() -> torch.Tensor: + out = ( + rocm_aux_buffer( + "compressor_kv_score", + ( + hidden_states.shape[0], + compressor.fused_wkv_wgate.weight.shape[0], + ), + torch.float32, + ) + if use_rocm_graph_safe_buffers + else None + ) + if out is None: + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) return torch.mm( hidden_states, compressor.fused_wkv_wgate.weight.T, + out=out, out_dtype=torch.float32, ) @@ -394,19 +466,60 @@ def compressor_kv_score() -> torch.Tensor: indexer = self.indexer def indexer_weights_proj() -> torch.Tensor: + if use_rocm_graph_safe_buffers: + out = rocm_aux_buffer( + "indexer_weights", + (hidden_states.shape[0], indexer.weights_proj.weight.shape[0]), + hidden_states.dtype, + ) + return torch.mm( + hidden_states, + indexer.weights_proj.weight.T, + out=out, + ) # ReplicatedLinear returns (output, bias); bias is None. weights, _ = indexer.weights_proj(hidden_states) return weights def indexer_compressor_kv_score() -> torch.Tensor: + out = ( + rocm_aux_buffer( + "indexer_kv_score", + ( + hidden_states.shape[0], + indexer.compressor.fused_wkv_wgate.weight.shape[0], + ), + torch.float32, + ) + if use_rocm_graph_safe_buffers + else None + ) + if out is None: + return torch.mm( + hidden_states, + indexer.compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) return torch.mm( hidden_states, indexer.compressor.fused_wkv_wgate.weight.T, + out=out, out_dtype=torch.float32, ) aux_fns[1] = indexer_weights_proj aux_fns[2] = indexer_compressor_kv_score + rocm_deferred_indexer_weights_proj = indexer_weights_proj + rocm_deferred_indexer_compressor_kv_score = indexer_compressor_kv_score + + if use_rocm_graph_safe_buffers: + # Current ROCm graph replay hangs when the two smaller indexer GEMMs + # are captured as additional side-stream hipBLASLt nodes next to the + # aiter fused WQA/WKV GEMM. Keep the largest CSA GEMM overlapped and + # leave the indexer GEMMs on the main stream until that lower-level + # ordering issue is fixed. + aux_fns[1] = None + aux_fns[2] = None def fused_wqa_wkv() -> torch.Tensor: # MergedColumnParallelLinear returns (output, bias); bias is None. @@ -423,6 +536,12 @@ def fused_wqa_wkv() -> torch.Tensor: <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD, ) + if use_rocm_graph_safe_buffers and self.indexer is not None: + if indexer_weights is None: + indexer_weights = rocm_deferred_indexer_weights_proj() + if indexer_kv_score is None: + indexer_kv_score = rocm_deferred_indexer_compressor_kv_score() + return qr_kv, kv_score, indexer_kv_score, indexer_weights def attention_impl( @@ -478,7 +597,8 @@ def wq_b_kv_insert() -> torch.Tensor: indexer_weights, positions, self.indexer_rotary_emb, - use_aux_stream=post_aux_streams is not None, + use_aux_stream=post_aux_streams is not None + and not current_platform.is_rocm(), ), lambda: compressor(kv_score, positions, self.rotary_emb), ], diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py index 2203221c5a14..c1c5f52c2b20 100644 --- a/vllm/utils/multi_stream_utils.py +++ b/vllm/utils/multi_stream_utils.py @@ -8,6 +8,17 @@ import torch +def _record_result_stream(result: Any, stream: torch.cuda.Stream) -> None: + if isinstance(result, torch.Tensor): + result.record_stream(stream) + elif isinstance(result, (tuple, list)): + for item in result: + _record_result_stream(item, stream) + elif isinstance(result, dict): + for item in result.values(): + _record_result_stream(item, stream) + + class AuxStreamType(Enum): Attention = 1 @@ -45,13 +56,15 @@ def maybe_execute_in_parallel( Tuple of (fn0_result, fn1_result). """ if aux_stream is not None: - event0.record() + current_stream = torch.cuda.current_stream() + event0.record(current_stream) result0 = fn0() with torch.cuda.stream(aux_stream): - event0.wait() + aux_stream.wait_event(event0) result1 = fn1() - event1.record() - event1.wait() + event1.record(aux_stream) + current_stream.wait_event(event1) + _record_result_stream(result1, current_stream) else: result0 = fn0() result1 = fn1() @@ -108,21 +121,24 @@ def execute_in_parallel( ) aux_results = [None] * len(aux_fns) - pending: list[torch.cuda.Event] = [] + current_stream = torch.cuda.current_stream() + pending: list[tuple[torch.cuda.Event, Any]] = [] - start_event.record() + start_event.record(current_stream) for i, fn in enumerate(aux_fns): if fn is None: continue - with torch.cuda.stream(aux_streams[i]): - start_event.wait() + aux_stream = aux_streams[i] + with torch.cuda.stream(aux_stream): + aux_stream.wait_event(start_event) aux_results[i] = fn() - done_events[i].record() - pending.append(done_events[i]) + done_events[i].record(aux_stream) + pending.append((done_events[i], aux_results[i])) default_result = default_fn() - for ev in pending: - ev.wait() + for ev, result in pending: + current_stream.wait_event(ev) + _record_result_stream(result, current_stream) return default_result, aux_results From 00c703cc5b3b72a70b3c4661c59294768ad37691 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Sun, 24 May 2026 08:05:00 +0000 Subject: [PATCH 04/21] Add ROCm DSV4 stream overlap probe Adds a standalone ROCm probe for DeepSeek-V4 CSA-like branch kernels, including torch.compile and graph replay modes plus rocprof repro commands for the compiled-scheduler stream collapse. Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- benchmarks/kernels/rocm_dsv4_stream_probe.py | 845 +++++++++++++++++++ 1 file changed, 845 insertions(+) create mode 100644 benchmarks/kernels/rocm_dsv4_stream_probe.py diff --git a/benchmarks/kernels/rocm_dsv4_stream_probe.py b/benchmarks/kernels/rocm_dsv4_stream_probe.py new file mode 100644 index 000000000000..bca04ab44c6d --- /dev/null +++ b/benchmarks/kernels/rocm_dsv4_stream_probe.py @@ -0,0 +1,845 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Probe ROCm stream ownership/overlap for DeepSeek-V4 CSA-like kernels. + +This is intentionally standalone: it does not instantiate the model or require +serving metadata. It answers whether representative kernels used around DSV4 +CSA decode honor the current HIP stream and whether independent streams overlap +when forced outside vLLM's graph/runtime path. + +Repro commands used on ROCm: + + # Control: stream scheduling outside torch.compile, then graph replay. + HIP_VISIBLE_DEVICES=0 VLLM_ROCM_USE_AITER=1 \ + PYTHONPATH=/path/to/vllm \ + rocprofv3 --runtime-trace --group-by-queue \ + --output-directory /tmp/vllm_rocm_dsv4_rocprof_graph \ + --output-file graph --output-format json csv -- \ + .venv/bin/python benchmarks/kernels/rocm_dsv4_stream_probe.py \ + --scenario aiter_vs_bf16_mm_decode \ + --repeats 5 --profile-repeats 0 --warmup 1 --mode graph + + # Repro: stream scheduling inside torch.compile, then graph replay. + HIP_VISIBLE_DEVICES=0 VLLM_ROCM_USE_AITER=1 \ + PYTHONPATH=/path/to/vllm \ + rocprofv3 --runtime-trace --group-by-queue \ + --output-directory /tmp/vllm_rocm_dsv4_rocprof_compile_pair_graph \ + --output-file compile_pair_graph --output-format json csv -- \ + .venv/bin/python benchmarks/kernels/rocm_dsv4_stream_probe.py \ + --scenario aiter_vs_bf16_mm_decode \ + --repeats 5 --profile-repeats 0 --warmup 1 --mode compile_pair_graph + +Finding from the repro: + - graph mode preserves separate ROCm queues for the representative AITER and + BF16 GEMM branches; + - compile_pair_graph collapses the same representative kernels onto ROCm + stream 0 / queue 1 during graph replay, exposing the non-overlap failure + mode seen in vLLM-like compiled scheduling. +""" + +import argparse +import json +import os +from collections import Counter +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import torch + +# Registers vLLM ROCm AITER custom ops. +import vllm._aiter_ops # noqa: F401 +from vllm._aiter_ops import rocm_aiter_ops + +KernelFn = Callable[[], Any] + + +def _fp8_dtype() -> torch.dtype: + return getattr(torch, "float8_e4m3fnuz", torch.int8) + + +def _make_bf16_mm( + m: int, + n: int, + k: int, + device: torch.device, +) -> KernelFn: + a = torch.randn((m, k), device=device, dtype=torch.bfloat16) + b = torch.randn((k, n), device=device, dtype=torch.bfloat16) + out = torch.empty((m, n), device=device, dtype=torch.bfloat16) + + def run() -> torch.Tensor: + torch.mm(a, b, out=out) + return out + + return run + + +def _make_aiter_fp8_block_gemm( + m: int, + n: int, + k: int, + device: torch.device, +) -> KernelFn: + dtype = _fp8_dtype() + a = torch.empty((m, k), device=device, dtype=dtype) + b = torch.empty((n, k), device=device, dtype=dtype) + a_scales = torch.ones((m, (k + 127) // 128), device=device, dtype=torch.float32) + b_scales = torch.ones((n, (k + 127) // 128), device=device, dtype=torch.float32) + + def run() -> torch.Tensor: + out = rocm_aiter_ops.gemm_a8w8_blockscale( + a, + b, + a_scales, + b_scales, + [1, 128], + output_dtype=torch.bfloat16, + ) + return out + + return run + + +def _make_topk( + rows: int, + cols: int, + topk: int, + device: torch.device, +) -> KernelFn: + x = torch.randn((rows, cols), device=device, dtype=torch.float32) + values = torch.empty((rows, topk), device=device, dtype=x.dtype) + indices = torch.empty((rows, topk), device=device, dtype=torch.long) + + def run() -> torch.Tensor: + torch.topk(x, topk, dim=-1, out=(values, indices)) + return values + + return run + + +def _event_time_ms( + fn0: KernelFn, + fn1: KernelFn, + concurrent: bool, + iterations: int, +) -> float: + torch.cuda.synchronize() + stream0 = torch.cuda.Stream() + stream1 = torch.cuda.Stream() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + current = torch.cuda.current_stream() + + start.record(current) + for _ in range(iterations): + _run_pair(fn0, fn1, concurrent, stream0, stream1) + end.record(current) + end.synchronize() + return start.elapsed_time(end) + + +def _run_pair( + fn0: KernelFn, + fn1: KernelFn, + concurrent: bool, + stream0: torch.cuda.Stream, + stream1: torch.cuda.Stream, +) -> tuple[Any, Any]: + current = torch.cuda.current_stream() + if concurrent: + stream0.wait_stream(current) + stream1.wait_stream(current) + with torch.cuda.stream(stream0): + result0 = fn0() + with torch.cuda.stream(stream1): + result1 = fn1() + current.wait_stream(stream0) + current.wait_stream(stream1) + return result0, result1 + else: + result0 = fn0() + result1 = fn1() + return result0, result1 + + +def _capture_pair_graph( + fn0: KernelFn, + fn1: KernelFn, + concurrent: bool, +) -> tuple[torch.cuda.CUDAGraph, tuple[torch.cuda.Stream, ...]]: + torch.cuda.synchronize() + # Run once outside capture so the caching allocator and any JIT/autotune + # setup do not become capture-time side effects. + warm_stream0 = torch.cuda.Stream() + warm_stream1 = torch.cuda.Stream() + _run_pair(fn0, fn1, concurrent, warm_stream0, warm_stream1) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + capture_stream = torch.cuda.Stream() + stream0 = torch.cuda.Stream() + stream1 = torch.cuda.Stream() + current = torch.cuda.current_stream() + capture_stream.wait_stream(current) + with torch.cuda.stream(capture_stream), torch.cuda.graph(graph): + _run_pair(fn0, fn1, concurrent, stream0, stream1) + current.wait_stream(capture_stream) + torch.cuda.synchronize() + return graph, (capture_stream, stream0, stream1) + + +def _graph_time_ms( + fn0: KernelFn, + fn1: KernelFn, + concurrent: bool, + iterations: int, +) -> float: + graph, _streams = _capture_pair_graph(fn0, fn1, concurrent) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + current = torch.cuda.current_stream() + start.record(current) + for _ in range(iterations): + graph.replay() + end.record(current) + end.synchronize() + return start.elapsed_time(end) + + +def _capture_runner_graph( + runner: KernelFn, +) -> tuple[torch.cuda.CUDAGraph, torch.cuda.Stream]: + torch.cuda.synchronize() + runner() + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + capture_stream = torch.cuda.Stream() + current = torch.cuda.current_stream() + capture_stream.wait_stream(current) + with torch.cuda.stream(capture_stream), torch.cuda.graph(graph): + runner() + current.wait_stream(capture_stream) + torch.cuda.synchronize() + return graph, capture_stream + + +def _graph_runner_time_ms(runner: KernelFn, iterations: int) -> float: + graph, _capture_stream = _capture_runner_graph(runner) + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + current = torch.cuda.current_stream() + start.record(current) + for _ in range(iterations): + graph.replay() + end.record(current) + end.synchronize() + return start.elapsed_time(end) + + +def _compile_kernel(fn: KernelFn) -> KernelFn: + compiled = torch.compile( + fn, + backend="inductor", + dynamic=False, + fullgraph=False, + ) + compiled() + torch.cuda.synchronize() + return compiled + + +def _call_time_ms(fn: KernelFn, iterations: int) -> float: + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + current = torch.cuda.current_stream() + start.record(current) + for _ in range(iterations): + fn() + end.record(current) + end.synchronize() + return start.elapsed_time(end) + + +def _compile_pair_runner( + fn0: KernelFn, + fn1: KernelFn, + concurrent: bool, +) -> KernelFn: + stream0 = torch.cuda.Stream() + stream1 = torch.cuda.Stream() + + def pair_run() -> tuple[Any, Any]: + return _run_pair(fn0, fn1, concurrent, stream0, stream1) + + return _compile_kernel(pair_run) + + +def _profile_runner_trace( + runner: KernelFn, + profile_dir: Path, + scenario: str, + mode: str, + iterations: int, +) -> Path: + profile_dir.mkdir(parents=True, exist_ok=True) + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + with_stack=False, + profile_memory=False, + ) as prof: + for _ in range(iterations): + runner() + torch.cuda.synchronize() + + trace = profile_dir / f"{scenario}.{mode}.trace.json" + prof.export_chrome_trace(str(trace)) + return trace + + +def _profile_runner_graph_trace( + runner: KernelFn, + profile_dir: Path, + scenario: str, + mode: str, + iterations: int, +) -> Path: + profile_dir.mkdir(parents=True, exist_ok=True) + graph, _capture_stream = _capture_runner_graph(runner) + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + with_stack=False, + profile_memory=False, + ) as prof: + for _ in range(iterations): + graph.replay() + torch.cuda.synchronize() + + trace = profile_dir / f"{scenario}.{mode}.trace.json" + prof.export_chrome_trace(str(trace)) + return trace + + +def _profile_trace( + fn0: KernelFn, + fn1: KernelFn, + profile_dir: Path, + scenario: str, + mode: str, + use_graph: bool, + iterations: int, +) -> Path: + profile_dir.mkdir(parents=True, exist_ok=True) + stream0 = torch.cuda.Stream() + stream1 = torch.cuda.Stream() + current = torch.cuda.current_stream() + graph = None + if use_graph: + graph, _streams = _capture_pair_graph(fn0, fn1, concurrent=True) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=False, + with_stack=False, + profile_memory=False, + ) as prof: + if graph is not None: + for _ in range(iterations): + graph.replay() + else: + for _ in range(iterations): + _run_pair(fn0, fn1, True, stream0, stream1) + current.wait_stream(stream0) + current.wait_stream(stream1) + torch.cuda.synchronize() + + trace = profile_dir / f"{scenario}.{mode}.trace.json" + prof.export_chrome_trace(str(trace)) + return trace + + +def _summarize_trace(trace: Path) -> dict[str, object]: + with trace.open() as f: + data = json.load(f) + stream_counts: Counter[str] = Counter() + stream_kernel_counts: Counter[str] = Counter() + stream_durations_us: Counter[str] = Counter() + event_counts: Counter[str] = Counter() + kernel_names: Counter[tuple[str, str]] = Counter() + + for event in data.get("traceEvents", []): + name = event.get("name", "") + if name in ("hipEventRecord", "hipStreamWaitEvent"): + event_counts[name] += 1 + args = event.get("args") or {} + stream = args.get("stream") + if stream is None: + continue + stream = str(stream) + stream_counts[stream] += 1 + if event.get("cat") in ("kernel", "gpu_memcpy"): + stream_kernel_counts[stream] += 1 + stream_durations_us[stream] += event.get("dur", 0.0) or 0.0 + kernel_names[(stream, name[:96])] += 1 + + return { + "streams": dict(stream_counts), + "stream_kernel_counts": dict(stream_kernel_counts), + "stream_durations_us": dict(stream_durations_us), + "event_counts": dict(event_counts), + "top_kernel_names": [ + {"stream": stream, "name": name, "count": count} + for (stream, name), count in kernel_names.most_common(12) + ], + } + + +def _scenario( + name: str, + device: torch.device, +) -> tuple[KernelFn, KernelFn, KernelFn, KernelFn]: + if name == "bf16_mm_pair": + return ( + _make_bf16_mm(256, 2048, 7168, device), + _make_bf16_mm(256, 768, 7168, device), + _make_bf16_mm(256, 2048, 7168, device), + _make_bf16_mm(256, 768, 7168, device), + ) + if name == "aiter_vs_bf16_mm_decode": + return ( + _make_aiter_fp8_block_gemm(4, 2048, 7168, device), + _make_bf16_mm(4, 768, 7168, device), + _make_aiter_fp8_block_gemm(4, 2048, 7168, device), + _make_bf16_mm(4, 768, 7168, device), + ) + if name == "aiter_pair_decode": + return ( + _make_aiter_fp8_block_gemm(4, 2048, 7168, device), + _make_aiter_fp8_block_gemm(4, 768, 7168, device), + _make_aiter_fp8_block_gemm(4, 2048, 7168, device), + _make_aiter_fp8_block_gemm(4, 768, 7168, device), + ) + if name == "topk_vs_bf16_mm": + return ( + _make_topk(64, 8192, 128, device), + _make_bf16_mm(64, 768, 7168, device), + _make_topk(64, 8192, 128, device), + _make_bf16_mm(64, 768, 7168, device), + ) + raise ValueError(f"unknown scenario: {name}") + + +def _mode_enabled(selected: str, mode: str) -> bool: + if selected == "all": + return True + if selected == "both": + return mode in ("eager", "graph") + return selected == mode + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--scenario", + action="append", + choices=[ + "bf16_mm_pair", + "aiter_vs_bf16_mm_decode", + "aiter_pair_decode", + "topk_vs_bf16_mm", + ], + ) + parser.add_argument("--repeats", type=int, default=200) + parser.add_argument("--profile-repeats", type=int, default=3) + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument( + "--profile-compile-pair", + action="store_true", + help=( + "Also profile the torch.compile wrapper that includes Python " + "stream scheduling. ROCTracer may be unstable for this mode on " + "some ROCm builds, so timing is collected by default without a " + "trace." + ), + ) + parser.add_argument( + "--profile-compile", + action="store_true", + help=( + "Profile compiled branch modes when running --mode all. Explicit " + "--mode compile and --mode compile_graph runs are profiled by " + "default when --profile-repeats is positive." + ), + ) + parser.add_argument( + "--profile-aggregate", + action="store_true", + help=( + "Profile aggregate modes such as --mode both or --mode all. On " + "some ROCm stacks, multiple profiler sessions in one process can " + "crash during ROCTracer cleanup, so aggregate modes collect timing " + "only unless this is set." + ), + ) + parser.add_argument( + "--mode", + choices=[ + "eager", + "graph", + "compile", + "compile_graph", + "compile_pair", + "compile_pair_graph", + "both", + "all", + ], + default="both", + help=( + "Measure eager forced streams, graph replay, torch.compile branch " + "functions, compiled branches under graph replay, a compiled whole " + "pair scheduler, the compiled scheduler under graph replay, or all " + "modes." + ), + ) + parser.add_argument( + "--profile-dir", + type=Path, + default=Path("/tmp/vllm_rocm_dsv4_stream_probe"), + ) + args = parser.parse_args() + + scenarios = args.scenario or [ + "bf16_mm_pair", + "aiter_vs_bf16_mm_decode", + "aiter_pair_decode", + "topk_vs_bf16_mm", + ] + device = torch.device("cuda") + print(f"pid={os.getpid()} device={torch.cuda.get_device_name(0)!r}") + + for name in scenarios: + fn0, fn1, profile_fn0, profile_fn1 = _scenario(name, device) + for _ in range(args.warmup): + _event_time_ms(fn0, fn1, concurrent=True, iterations=1) + seq_ms = _event_time_ms( + fn0, fn1, concurrent=False, iterations=args.repeats + ) + conc_ms = _event_time_ms( + fn0, fn1, concurrent=True, iterations=args.repeats + ) + overlap_pct = 100.0 * (1.0 - conc_ms / seq_ms) if seq_ms else 0.0 + results: dict[str, object] = {"scenario": name} + if _mode_enabled(args.mode, "eager"): + should_profile_eager = args.profile_repeats > 0 and ( + args.mode == "eager" or args.profile_aggregate + ) + eager_result = { + "sequential_ms": seq_ms, + "concurrent_ms": conc_ms, + "overlap_pct": overlap_pct, + } + if should_profile_eager: + trace = _profile_trace( + profile_fn0, + profile_fn1, + args.profile_dir, + name, + "eager", + use_graph=False, + iterations=args.profile_repeats, + ) + eager_result.update({ + "trace": str(trace), + **_summarize_trace(trace), + }) + elif args.profile_repeats > 0: + eager_result["profile_note"] = ( + "eager trace skipped in aggregate mode; use --mode eager " + "or pass --profile-aggregate" + ) + results["eager"] = eager_result + if _mode_enabled(args.mode, "graph"): + try: + should_profile_graph = args.profile_repeats > 0 and ( + args.mode == "graph" or args.profile_aggregate + ) + graph_seq_ms = _graph_time_ms( + fn0, + fn1, + concurrent=False, + iterations=args.repeats, + ) + graph_conc_ms = _graph_time_ms( + fn0, + fn1, + concurrent=True, + iterations=args.repeats, + ) + graph_overlap_pct = ( + 100.0 * (1.0 - graph_conc_ms / graph_seq_ms) + if graph_seq_ms + else 0.0 + ) + graph_result = { + "sequential_ms": graph_seq_ms, + "concurrent_ms": graph_conc_ms, + "overlap_pct": graph_overlap_pct, + } + if should_profile_graph: + graph_trace = _profile_trace( + profile_fn0, + profile_fn1, + args.profile_dir, + name, + "graph", + use_graph=True, + iterations=args.profile_repeats, + ) + graph_result.update({ + "trace": str(graph_trace), + **_summarize_trace(graph_trace), + }) + elif args.profile_repeats > 0: + graph_result["profile_note"] = ( + "graph trace skipped in aggregate mode; use --mode " + "graph or pass --profile-aggregate" + ) + results["graph"] = graph_result + except Exception as exc: + results["graph"] = { + "error": f"{type(exc).__name__}: {exc}", + } + if _mode_enabled(args.mode, "compile"): + try: + should_profile_compile = args.profile_repeats > 0 and ( + args.mode == "compile" + or (args.profile_aggregate and args.profile_compile) + ) + compiled_fn0 = _compile_kernel(fn0) + compiled_fn1 = _compile_kernel(fn1) + compiled_profile_fn0 = _compile_kernel(profile_fn0) + compiled_profile_fn1 = _compile_kernel(profile_fn1) + compile_seq_ms = _event_time_ms( + compiled_fn0, + compiled_fn1, + concurrent=False, + iterations=args.repeats, + ) + compile_conc_ms = _event_time_ms( + compiled_fn0, + compiled_fn1, + concurrent=True, + iterations=args.repeats, + ) + compile_overlap_pct = ( + 100.0 * (1.0 - compile_conc_ms / compile_seq_ms) + if compile_seq_ms + else 0.0 + ) + compile_result = { + "sequential_ms": compile_seq_ms, + "concurrent_ms": compile_conc_ms, + "overlap_pct": compile_overlap_pct, + } + if should_profile_compile: + compile_trace = _profile_trace( + compiled_profile_fn0, + compiled_profile_fn1, + args.profile_dir, + name, + "compile", + use_graph=False, + iterations=args.profile_repeats, + ) + compile_result.update({ + "trace": str(compile_trace), + **_summarize_trace(compile_trace), + }) + else: + compile_result["profile_note"] = ( + "compiled branch trace skipped in aggregate mode; " + "use --mode compile or pass --profile-compile" + ) + results["compile"] = compile_result + except Exception as exc: + results["compile"] = { + "error": f"{type(exc).__name__}: {exc}", + } + if _mode_enabled(args.mode, "compile_graph"): + try: + should_profile_compile_graph = args.profile_repeats > 0 and ( + args.mode == "compile_graph" + or (args.profile_aggregate and args.profile_compile) + ) + compiled_fn0 = _compile_kernel(fn0) + compiled_fn1 = _compile_kernel(fn1) + compiled_profile_fn0 = _compile_kernel(profile_fn0) + compiled_profile_fn1 = _compile_kernel(profile_fn1) + compile_graph_seq_ms = _graph_time_ms( + compiled_fn0, + compiled_fn1, + concurrent=False, + iterations=args.repeats, + ) + compile_graph_conc_ms = _graph_time_ms( + compiled_fn0, + compiled_fn1, + concurrent=True, + iterations=args.repeats, + ) + compile_graph_overlap_pct = ( + 100.0 * (1.0 - compile_graph_conc_ms / compile_graph_seq_ms) + if compile_graph_seq_ms + else 0.0 + ) + compile_graph_result = { + "sequential_ms": compile_graph_seq_ms, + "concurrent_ms": compile_graph_conc_ms, + "overlap_pct": compile_graph_overlap_pct, + } + if should_profile_compile_graph: + compile_graph_trace = _profile_trace( + compiled_profile_fn0, + compiled_profile_fn1, + args.profile_dir, + name, + "compile_graph", + use_graph=True, + iterations=args.profile_repeats, + ) + compile_graph_result.update({ + "trace": str(compile_graph_trace), + **_summarize_trace(compile_graph_trace), + }) + else: + compile_graph_result["profile_note"] = ( + "compiled graph trace skipped in aggregate mode; " + "use --mode compile_graph or pass --profile-compile" + ) + results["compile_graph"] = compile_graph_result + except Exception as exc: + results["compile_graph"] = { + "error": f"{type(exc).__name__}: {exc}", + } + if _mode_enabled(args.mode, "compile_pair"): + try: + compiled_seq_pair = _compile_pair_runner(fn0, fn1, concurrent=False) + compiled_conc_pair = _compile_pair_runner(fn0, fn1, concurrent=True) + compile_pair_seq_ms = _call_time_ms( + compiled_seq_pair, iterations=args.repeats + ) + compile_pair_conc_ms = _call_time_ms( + compiled_conc_pair, iterations=args.repeats + ) + compile_pair_overlap_pct = ( + 100.0 * (1.0 - compile_pair_conc_ms / compile_pair_seq_ms) + if compile_pair_seq_ms + else 0.0 + ) + compile_pair_result = { + "sequential_ms": compile_pair_seq_ms, + "concurrent_ms": compile_pair_conc_ms, + "overlap_pct": compile_pair_overlap_pct, + } + if args.profile_compile_pair and args.profile_repeats > 0: + compiled_profile_pair = _compile_pair_runner( + profile_fn0, + profile_fn1, + concurrent=True, + ) + compile_pair_trace = _profile_runner_trace( + compiled_profile_pair, + args.profile_dir, + name, + "compile_pair", + iterations=args.profile_repeats, + ) + compile_pair_result.update({ + "trace": str(compile_pair_trace), + **_summarize_trace(compile_pair_trace), + }) + else: + compile_pair_result["profile_note"] = ( + "compile_pair trace skipped; pass --profile-compile-pair " + "to profile the compiled Python stream scheduler" + ) + results["compile_pair"] = compile_pair_result + except Exception as exc: + results["compile_pair"] = { + "error": f"{type(exc).__name__}: {exc}", + } + if _mode_enabled(args.mode, "compile_pair_graph"): + try: + compiled_seq_pair = _compile_pair_runner(fn0, fn1, concurrent=False) + compiled_conc_pair = _compile_pair_runner(fn0, fn1, concurrent=True) + compile_pair_graph_seq_ms = _graph_runner_time_ms( + compiled_seq_pair, iterations=args.repeats + ) + compile_pair_graph_conc_ms = _graph_runner_time_ms( + compiled_conc_pair, iterations=args.repeats + ) + compile_pair_graph_overlap_pct = ( + 100.0 + * ( + 1.0 + - compile_pair_graph_conc_ms / compile_pair_graph_seq_ms + ) + if compile_pair_graph_seq_ms + else 0.0 + ) + compile_pair_graph_result = { + "sequential_ms": compile_pair_graph_seq_ms, + "concurrent_ms": compile_pair_graph_conc_ms, + "overlap_pct": compile_pair_graph_overlap_pct, + } + if args.profile_repeats > 0 and ( + args.mode == "compile_pair_graph" or args.profile_aggregate + ): + compiled_profile_pair = _compile_pair_runner( + profile_fn0, + profile_fn1, + concurrent=True, + ) + compile_pair_graph_trace = _profile_runner_graph_trace( + compiled_profile_pair, + args.profile_dir, + name, + "compile_pair_graph", + iterations=args.profile_repeats, + ) + compile_pair_graph_result.update({ + "trace": str(compile_pair_graph_trace), + **_summarize_trace(compile_pair_graph_trace), + }) + elif args.profile_repeats > 0: + compile_pair_graph_result["profile_note"] = ( + "compile_pair_graph trace skipped in aggregate mode; " + "use --mode compile_pair_graph or pass " + "--profile-aggregate" + ) + results["compile_pair_graph"] = compile_pair_graph_result + except Exception as exc: + results["compile_pair_graph"] = { + "error": f"{type(exc).__name__}: {exc}", + } + + print(json.dumps(results, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() From a2c4f32c65b1d16759c36f07f9d0a5c02c5810e3 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Sun, 24 May 2026 10:09:24 +0000 Subject: [PATCH 05/21] Add ROCm DSV4 CSA multi-stream experiment Implements an opt-in ROCm DeepSeek-V4 CSA decode multi-stream path and records the tuning result for issue #41820. Change log: - add VLLM_ROCM_DSV4_CSA_MULTISTREAM and strategy/threshold/graph-mode/tuning envs - create five ROCm aux streams for the SGLang-style hierarchy: aux[0] main compressor, aux[1] C4 indexer, aux[2:4] C4 indexer sub-branches - gate the path to ROCm decode-only steps, configured graph runtime modes, and min/max decode counts - fix the low-workload gate to use the batch decode count rather than summing identical per-layer metadata - keep ROCm input GEMM aux streams disabled by default; overlap is attempted after q/kv rmsnorm where the dependencies are explicit - add tunables for main compressor, outer indexer, inner indexer substreams, deferred projections, and aux stream priority - launch default-stream work before aux branches in execute_in_parallel so the critical path is not delayed by side-stream CPU launch overhead - extend the standalone ROCm stream probe with disabled_compile_pair_graph and document the graph/compile stream-collapse finding Stream/event mapping: - default stream: fused WQA/WKV input projection, q/k rmsnorm, wq_b, qnorm/rope/KV insert, MLA attention - aux[0]: main compressor branch when VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR=1 - aux[1]: C4 indexer branch when VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER=1 - aux[2]: optional C4 indexer q projection/rope/quant sub-branch - aux[3]: optional C4 indexer weights projection sub-branch - start events are recorded at fan-out producer boundaries; done events are waited only before consuming branch outputs - FULL graph runtime is excluded by default because the standalone torch.compile+CUDAGraph repro collapses stream scheduling onto stream 0 Environment and hardware used: - branch imported from /shared/amdgpu/home/fai_qle/vllm/vllm/__init__.py - base commit before this change: 067ca978e0ef36f75b06114c65aad11955303479 - 8x gfx950 GPUs, 309220868096 B VRAM each - ROCm 7.2.2, driver 6.14.14, rocprofv3 1.1.0 - server flags matched the InferenceX-style setup: TP=8, mp backend, triton_unfused MoE, fp8 KV cache, tokenizer/reasoning deepseek_v4 - graph mode for the main comparison: {"mode":3,"cudagraph_mode":"PIECEWISE"} - enabled env: VLLM_ROCM_DSV4_CSA_MULTISTREAM=1, STRATEGY=sglang, MIN_DECODE=1, MAX_DECODE=32, GRAPH_MODES=none,piecewise Performance result: - 1k/1k conc=4 baseline off: output 75.88 tok/s, req 0.0759/s, TTFT 1899.39 ms, TPOT 50.84 ms, p99 ITL 52.38 ms - 1k/1k conc=4 enabled current: output 47.56 tok/s, req 0.0476/s, TTFT 1004.25 ms, TPOT 83.16 ms, p99 ITL 85.03 ms - 1k/64 conc=4 baseline off: output 64.77 tok/s, req 1.0121/s, TTFT 674.65 ms, TPOT 51.92 ms, p99 ITL 52.40 ms - 1k/64 conc=4 full SGLang topology after launch-order fix: output 43.63 tok/s, req 0.6817/s, TTFT 559.72 ms, TPOT 83.76 ms, p99 ITL 88.24 ms - 1k/64 conc=4 outer-indexer/no-compressor: output 46.66 tok/s, TPOT 66.79 ms, p99 ITL 65.71 ms - 1k/64 conc=4 inner-indexer-only: output 51.13 tok/s, TPOT 67.13 ms, p99 ITL 70.19 ms - 1k/64 conc=4 aux priority=2: output 40.77 tok/s, TPOT 85.38 ms, p99 ITL 91.78 ms - deferred projection topology hung during warmup and remains off by default - model load reported 142.43 GiB and PIECEWISE graph capture reported 7.11 GiB Profiler conclusion: - pre-fix torch profiler on 1k/64 showed all interesting rank0 DeepSeek kernels on stream 4; the decode-count gate was incorrectly disabling the feature by summing per-layer metadata - after enabling the path, torch profiler hangs with aux streams and rocprof attach did not emit usable worker CSVs in this container - standalone rocprofv3 stream probe shows manual graph mode preserves separate ROCm queues but the decode-sized AITER/BF16 kernels still have zero useful timestamp overlap; this points to CU/resource contention - compile_pair_graph and disabled_compile_pair_graph reproduce graph replay stream collapse onto stream 0 for the vLLM-like compiled stream scheduler - net: SGLang-style Python-level side-streaming is expressible in vLLM but not profitable with the current ROCm kernels/graph topology Correctness and tests: - deterministic API completion off vs on, temperature=0, max_tokens=32: exact text match - .venv/bin/python -m ruff check vllm/envs.py vllm/models/deepseek_v4/nvidia/model.py vllm/models/deepseek_v4/nvidia/ops/attention.py vllm/utils/multi_stream_utils.py tests/models/test_deepseek_v4_rocm_multistream.py benchmarks/kernels/rocm_dsv4_stream_probe.py: pass - .venv/bin/python -m py_compile same files: pass - .venv/bin/python -m pytest tests/models/test_deepseek_v4_rocm_multistream.py -q: 6 passed - GSM8K/full accuracy was not run because the active stream topology is a clear performance regression Recommended policy: - keep VLLM_ROCM_DSV4_CSA_MULTISTREAM off by default - do not recommend enabling sglang topology for ROCm production yet; it is useful for reproducing and profiling the failure mode - likely remedies are graph-capturing the stream scheduler outside torch.compile, lowering side-branch kernel occupancy, or adding fused/stream-friendly ROCm kernels comparable to SGLang's implementation Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- benchmarks/kernels/rocm_dsv4_stream_probe.py | 79 +++- .../test_deepseek_v4_rocm_multistream.py | 51 ++- vllm/envs.py | 57 +++ vllm/models/deepseek_v4/attention.py | 358 +++++++++++++++--- vllm/models/deepseek_v4/nvidia/model.py | 43 ++- vllm/utils/multi_stream_utils.py | 12 +- 6 files changed, 517 insertions(+), 83 deletions(-) diff --git a/benchmarks/kernels/rocm_dsv4_stream_probe.py b/benchmarks/kernels/rocm_dsv4_stream_probe.py index bca04ab44c6d..8d2d078fc67a 100644 --- a/benchmarks/kernels/rocm_dsv4_stream_probe.py +++ b/benchmarks/kernels/rocm_dsv4_stream_probe.py @@ -31,7 +31,9 @@ Finding from the repro: - graph mode preserves separate ROCm queues for the representative AITER and - BF16 GEMM branches; + BF16 GEMM branches, but the tested decode-sized kernels did not produce + useful timestamp overlap, which points to kernel resource contention rather + than Python stream ownership as the first bottleneck; - compile_pair_graph collapses the same representative kernels onto ROCm stream 0 / queue 1 during graph replay, exposing the non-overlap failure mode seen in vLLM-like compiled scheduling. @@ -269,6 +271,7 @@ def _compile_pair_runner( fn0: KernelFn, fn1: KernelFn, concurrent: bool, + disable_scheduler_compile: bool = False, ) -> KernelFn: stream0 = torch.cuda.Stream() stream1 = torch.cuda.Stream() @@ -276,6 +279,8 @@ def _compile_pair_runner( def pair_run() -> tuple[Any, Any]: return _run_pair(fn0, fn1, concurrent, stream0, stream1) + if disable_scheduler_compile: + pair_run = torch.compiler.disable(pair_run) return _compile_kernel(pair_run) @@ -505,6 +510,7 @@ def main() -> None: "compile_graph", "compile_pair", "compile_pair_graph", + "disabled_compile_pair_graph", "both", "all", ], @@ -512,8 +518,9 @@ def main() -> None: help=( "Measure eager forced streams, graph replay, torch.compile branch " "functions, compiled branches under graph replay, a compiled whole " - "pair scheduler, the compiled scheduler under graph replay, or all " - "modes." + "pair scheduler, the compiled scheduler under graph replay, a " + "torch.compiler.disable-protected scheduler under graph replay, or " + "all modes." ), ) parser.add_argument( @@ -837,6 +844,72 @@ def main() -> None: results["compile_pair_graph"] = { "error": f"{type(exc).__name__}: {exc}", } + if _mode_enabled(args.mode, "disabled_compile_pair_graph"): + try: + disabled_seq_pair = _compile_pair_runner( + fn0, + fn1, + concurrent=False, + disable_scheduler_compile=True, + ) + disabled_conc_pair = _compile_pair_runner( + fn0, + fn1, + concurrent=True, + disable_scheduler_compile=True, + ) + disabled_pair_graph_seq_ms = _graph_runner_time_ms( + disabled_seq_pair, iterations=args.repeats + ) + disabled_pair_graph_conc_ms = _graph_runner_time_ms( + disabled_conc_pair, iterations=args.repeats + ) + disabled_pair_graph_overlap_pct = ( + 100.0 + * ( + 1.0 + - disabled_pair_graph_conc_ms / disabled_pair_graph_seq_ms + ) + if disabled_pair_graph_seq_ms + else 0.0 + ) + disabled_pair_graph_result = { + "sequential_ms": disabled_pair_graph_seq_ms, + "concurrent_ms": disabled_pair_graph_conc_ms, + "overlap_pct": disabled_pair_graph_overlap_pct, + } + if args.profile_repeats > 0 and ( + args.mode == "disabled_compile_pair_graph" + or args.profile_aggregate + ): + disabled_profile_pair = _compile_pair_runner( + profile_fn0, + profile_fn1, + concurrent=True, + disable_scheduler_compile=True, + ) + disabled_pair_graph_trace = _profile_runner_graph_trace( + disabled_profile_pair, + args.profile_dir, + name, + "disabled_compile_pair_graph", + iterations=args.profile_repeats, + ) + disabled_pair_graph_result.update({ + "trace": str(disabled_pair_graph_trace), + **_summarize_trace(disabled_pair_graph_trace), + }) + elif args.profile_repeats > 0: + disabled_pair_graph_result["profile_note"] = ( + "disabled_compile_pair_graph trace skipped in aggregate " + "mode; use --mode disabled_compile_pair_graph or pass " + "--profile-aggregate" + ) + results["disabled_compile_pair_graph"] = disabled_pair_graph_result + except Exception as exc: + results["disabled_compile_pair_graph"] = { + "error": f"{type(exc).__name__}: {exc}", + } print(json.dumps(results, indent=2, sort_keys=True)) diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index 0a0d7a914f1a..3ba23c687378 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -2,19 +2,47 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.models.deepseek_v4.amd import model as rocm_model +from vllm.models.deepseek_v4.nvidia.ops import attention as dsv4_attention def test_deepseek_v4_rocm_aux_streams_enabled(monkeypatch): - streams = [object(), object(), object()] + streams = [object(), object(), object(), object(), object()] + + def make_stream(**kwargs): + assert kwargs == {"priority": 0} + return streams.pop() monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setattr(rocm_model.torch.cuda, "Stream", streams.pop) + monkeypatch.setattr(rocm_model.torch.cuda, "Stream", make_stream) + monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1") + monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "sglang") aux_streams = rocm_model.make_deepseek_v4_aux_streams() assert aux_streams is not None - assert len(aux_streams) == 3 + assert len(aux_streams) == 5 + + +def test_deepseek_v4_rocm_aux_streams_disabled_by_default(monkeypatch): + monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.delenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", raising=False) + + aux_streams = rocm_model.make_deepseek_v4_aux_streams() + + assert aux_streams is None + + +def test_deepseek_v4_rocm_aux_streams_strategy_off(monkeypatch): + monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1") + monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "off") + + aux_streams = rocm_model.make_deepseek_v4_aux_streams() + + assert aux_streams is None def test_deepseek_v4_rocm_aux_streams_xpu_fallback(monkeypatch): @@ -37,3 +65,20 @@ def test_deepseek_v4_aux_streams_cuda_behavior_unchanged(monkeypatch): assert aux_streams is not None assert len(aux_streams) == 3 + + +class _Metadata: + def __init__(self, num_decodes: int, num_decode_tokens: int): + self.num_decodes = num_decodes + self.num_decode_tokens = num_decode_tokens + + +def test_deepseek_v4_num_decodes_uses_batch_count_not_layer_sum(monkeypatch): + monkeypatch.setattr(dsv4_attention, "DeepseekSparseSWAMetadata", _Metadata) + attn_metadata = { + "layer_0.swa": _Metadata(num_decodes=4, num_decode_tokens=4), + "layer_1.swa": _Metadata(num_decodes=4, num_decode_tokens=4), + "layer_2.swa": _Metadata(num_decodes=4, num_decode_tokens=4), + } + + assert dsv4_attention._deepseek_v4_num_decodes(attn_metadata) == 4 diff --git a/vllm/envs.py b/vllm/envs.py index 3a3934f3cdf9..ffc488168614 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -257,6 +257,21 @@ VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024 + VLLM_ROCM_DSV4_CSA_MULTISTREAM: bool = False + VLLM_ROCM_DSV4_CSA_MS_STRATEGY: Literal["off", "indexer_only", "sglang"] = ( + "sglang" + ) + VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE: int = 1 + VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE: int = 32 + VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES: set[Literal["none", "piecewise", "full"]] = { + "none", + "piecewise", + } + VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR: bool = True + VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS: bool = False + VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER: bool = True + VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS: bool = True + VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY: int = 0 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool | None = None VLLM_LOG_MODEL_INSPECTION: bool = False @@ -1878,6 +1893,48 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD": lambda: int( os.getenv("VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD", "1024") ), + # ROCm-only opt-in for DeepSeek-V4 CSA decode multi-stream overlap. + # The "sglang" strategy overlaps C4 indexer and compressor preparation + # branches with fine-grained event joins. The deferred projection and + # sub-branch knobs are separate A/B controls because ROCm graph capture and + # hipBLASLt ordering differ by PyTorch/ROCm release. "indexer_only" keeps + # the conservative single indexer branch path for A/B testing. + "VLLM_ROCM_DSV4_CSA_MULTISTREAM": lambda: bool( + int(os.getenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "0")) + ), + "VLLM_ROCM_DSV4_CSA_MS_STRATEGY": env_with_choices( + "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", + "sglang", + ["off", "indexer_only", "sglang"], + case_sensitive=False, + ), + "VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE": lambda: int( + os.getenv("VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE", "1") + ), + "VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE": lambda: int( + os.getenv("VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE", "32") + ), + "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES": env_set_with_choices( + "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES", + ["none", "piecewise"], + ["none", "piecewise", "full"], + case_sensitive=False, + ), + "VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR": lambda: bool( + int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR", "1")) + ), + "VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS": lambda: bool( + int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS", "0")) + ), + "VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER": lambda: bool( + int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER", "1")) + ), + "VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS": lambda: bool( + int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS", "1")) + ), + "VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY": lambda: int( + os.getenv("VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY", "0") + ), # Format for saving torch.compile cache artifacts # - "binary": saves as binary file # Safe for multiple vllm serve processes accessing the same torch compile cache. diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 76a16714c0be..45a521095e8a 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -35,6 +35,7 @@ from vllm.config import ( CacheConfig, + CUDAGraphMode, VllmConfig, get_current_vllm_config, ) @@ -113,6 +114,21 @@ def _is_decode_only_deepseek_v4_step( ) +def _deepseek_v4_num_decodes( + attn_metadata: dict[str, AttentionMetadata] + | list[dict[str, AttentionMetadata]] + | None, +) -> int: + count = 0 + for item in _iter_deepseek_v4_swa_metadata(attn_metadata): + # The attention metadata dict contains one entry per DeepSeek-V4 + # attention/cache layer. Each entry describes the same batch, so + # summing here turns conc=4 into conc=4*num_layers and incorrectly + # gates off low-workload decode. + count = max(count, item.num_decodes or item.num_decode_tokens) + return count + + def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]": """Pick the platform-specific V4 sparse MLA impl class. Sole platform check.""" if current_platform.is_rocm(): @@ -385,20 +401,103 @@ def forward( def _aux_streams_for_step( self, + rocm_ms_strategy: str, attn_metadata: ( dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None ), ) -> list[torch.cuda.Stream] | None: - if current_platform.is_rocm() and not _is_decode_only_deepseek_v4_step( - attn_metadata - ): - return None + if current_platform.is_rocm(): + if rocm_ms_strategy == "off": + return None + if not _is_decode_only_deepseek_v4_step(attn_metadata): + return None return self.aux_stream_list + def _rocm_csa_ms_strategy_for_step( + self, + forward_context: ForwardContext, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + ) -> str: + if not current_platform.is_rocm(): + return "off" + if ( + not envs.VLLM_ROCM_DSV4_CSA_MULTISTREAM + or self.aux_stream_list is None + ): + return "off" + + strategy = envs.VLLM_ROCM_DSV4_CSA_MS_STRATEGY.lower() + if strategy == "off": + return "off" + if not _is_decode_only_deepseek_v4_step(attn_metadata): + return "off" + + num_decodes = _deepseek_v4_num_decodes(attn_metadata) + min_decode = envs.VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE + max_decode = envs.VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE + if num_decodes < min_decode or num_decodes > max_decode: + return "off" + + graph_mode = forward_context.cudagraph_runtime_mode + assert graph_mode in CUDAGraphMode.valid_runtime_modes() + graph_mode_name = graph_mode.name.lower() + enabled_graph_modes = { + mode.lower() for mode in envs.VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES + } + # Standalone ROCm repro in benchmarks/kernels/rocm_dsv4_stream_probe.py + # shows that putting the Python stream scheduler inside a compiled + # full-graph wrapper can collapse the replayed work onto stream 0. + # Keep the ROCm CSA multi-stream scheduler in eager/piecewise graph + # islands by default; FULL can be enabled explicitly for debugging via + # VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES=none,piecewise,full. + if graph_mode_name not in enabled_graph_modes: + return "off" + + if strategy == "sglang" and len(self.aux_stream_list) < 5: + return "off" + if strategy == "indexer_only" and self.indexer is None: + return "off" + return strategy + + def _rocm_aux_buffer( + self, + name: str, + shape: tuple[int, ...], + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + buffer = self._rocm_aux_gemm_buffers.get(name) + if ( + buffer is None + or buffer.shape[1:] != shape[1:] + or buffer.shape[0] < shape[0] + or buffer.dtype != dtype + or buffer.device != device + ): + buffer = torch.empty(shape, device=device, dtype=dtype) + self._rocm_aux_gemm_buffers[name] = buffer + if buffer.shape == shape: + return buffer + return buffer[: shape[0]] + + def _project_compressor_kv_score( + self, + hidden_states: torch.Tensor, + compressor: DeepseekCompressor, + ) -> torch.Tensor: + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) + def attn_gemm_parallel_execute( self, hidden_states: torch.Tensor, aux_streams: list[torch.cuda.Stream] | None, + defer_rocm_branch_projections: bool = False, ) -> tuple[Any, ...]: if aux_streams is not None: assert len(aux_streams) >= 3 @@ -407,42 +506,29 @@ def attn_gemm_parallel_execute( current_platform.is_rocm() and aux_streams is not None ) - def rocm_aux_buffer( - name: str, - shape: tuple[int, int], - dtype: torch.dtype, - ) -> torch.Tensor: - buffer = self._rocm_aux_gemm_buffers.get(name) - if ( - buffer is None - or buffer.shape != shape - or buffer.dtype != dtype - or buffer.device != hidden_states.device - ): - buffer = torch.empty(shape, device=hidden_states.device, dtype=dtype) - self._rocm_aux_gemm_buffers[name] = buffer - return buffer - # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. - # ROCm keeps aux_streams as None in this wrapper and falls back to - # sequential execution on the current stream. + # The ROCm SGLang-style strategy deliberately defers these branch + # projections until the post-rmsnorm fan-out. That avoids capturing + # side-stream hipBLASLt nodes beside the fused WQA/WKV GEMM, which is + # the graph topology that previously hung/regressed on ROCm. aux_fns: list[Callable[[], Any] | None] = [None, None, None] - if self.compressor is not None: + if self.compressor is not None and not defer_rocm_branch_projections: # Local ref so the closure keeps a non-None type for mypy. compressor = self.compressor def compressor_kv_score() -> torch.Tensor: out = ( - rocm_aux_buffer( + self._rocm_aux_buffer( "compressor_kv_score", ( hidden_states.shape[0], compressor.fused_wkv_wgate.weight.shape[0], ), torch.float32, + hidden_states.device, ) if use_rocm_graph_safe_buffers else None @@ -462,15 +548,16 @@ def compressor_kv_score() -> torch.Tensor: aux_fns[0] = compressor_kv_score - if self.indexer is not None: + if self.indexer is not None and not defer_rocm_branch_projections: indexer = self.indexer def indexer_weights_proj() -> torch.Tensor: if use_rocm_graph_safe_buffers: - out = rocm_aux_buffer( + out = self._rocm_aux_buffer( "indexer_weights", (hidden_states.shape[0], indexer.weights_proj.weight.shape[0]), hidden_states.dtype, + hidden_states.device, ) return torch.mm( hidden_states, @@ -483,13 +570,14 @@ def indexer_weights_proj() -> torch.Tensor: def indexer_compressor_kv_score() -> torch.Tensor: out = ( - rocm_aux_buffer( + self._rocm_aux_buffer( "indexer_kv_score", ( hidden_states.shape[0], indexer.compressor.fused_wkv_wgate.weight.shape[0], ), torch.float32, + hidden_states.device, ) if use_rocm_graph_safe_buffers else None @@ -536,7 +624,11 @@ def fused_wqa_wkv() -> torch.Tensor: <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD, ) - if use_rocm_graph_safe_buffers and self.indexer is not None: + if ( + use_rocm_graph_safe_buffers + and self.indexer is not None + and not defer_rocm_branch_projections + ): if indexer_weights is None: indexer_weights = rocm_deferred_indexer_weights_proj() if indexer_kv_score is None: @@ -552,10 +644,23 @@ def attention_impl( ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - aux_streams = self._aux_streams_for_step(attn_metadata) + rocm_ms_strategy = self._rocm_csa_ms_strategy_for_step( + forward_context, attn_metadata + ) + aux_streams = self._aux_streams_for_step(rocm_ms_strategy, attn_metadata) + defer_rocm_branch_projections = ( + current_platform.is_rocm() + and rocm_ms_strategy == "sglang" + and envs.VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS + ) + gemm_aux_streams = None if current_platform.is_rocm() else aux_streams qr_kv, kv_score, indexer_kv_score, indexer_weights = ( - self.attn_gemm_parallel_execute(hidden_states, aux_streams) + self.attn_gemm_parallel_execute( + hidden_states, + gemm_aux_streams, + defer_rocm_branch_projections=defer_rocm_branch_projections, + ) ) qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) @@ -571,7 +676,21 @@ def attention_impl( # on the default stream so q stays on its consumer stream (mla_attn # downstream reads q on default). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. - post_aux_streams = None if current_platform.is_rocm() else aux_streams + post_aux_streams = aux_streams + if ( + current_platform.is_rocm() + and post_aux_streams is not None + and len(post_aux_streams) >= 5 + ): + # ROCm SGLang-style mapping: + # aux[0]: main compressor branch + # aux[1]: C4 indexer branch + # aux[2:4]: C4 indexer sub-branches + outer_post_aux_streams = [post_aux_streams[1], post_aux_streams[0]] + elif post_aux_streams is not None: + outer_post_aux_streams = [post_aux_streams[0], post_aux_streams[1]] + else: + outer_post_aux_streams = None if self.indexer is not None: indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. @@ -586,36 +705,79 @@ def wq_b_kv_insert() -> torch.Tensor: # 3-way overlap (matches TRT-LLM PR #14142 Level 1): default runs # wq_b+kv_insert; slot [0] runs the full indexer; slot [1] runs the # MLA compressor. Slot [2] is reserved for the indexer's inner - # overlap. ROCm (aux_streams is None) falls back to sequential. - q, _ = execute_in_parallel( + # overlap. ROCm uses the SGLang-style deferred projection path only + # under VLLM_ROCM_DSV4_CSA_MULTISTREAM=1. + run_indexer: Callable[[], Any] = lambda: indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + use_aux_stream=post_aux_streams is not None + and ( + not current_platform.is_rocm() + or ( + rocm_ms_strategy == "sglang" + and envs.VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS + ) + ), + project_inputs=defer_rocm_branch_projections, + ) + + def run_compressor() -> Any: + local_kv_score = kv_score + if local_kv_score is None: + local_kv_score = self._project_compressor_kv_score( + hidden_states, compressor + ) + return compressor(local_kv_score, positions, self.rotary_emb) + + indexer_fn: Callable[[], Any] | None = run_indexer + compressor_fn: Callable[[], Any] | None = run_compressor + if current_platform.is_rocm() and ( + rocm_ms_strategy == "indexer_only" + or not envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR + ): + compressor_fn = None + if ( + current_platform.is_rocm() + and not envs.VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER + ): + indexer_fn = None + + q, (indexer_result, compressor_result) = execute_in_parallel( wq_b_kv_insert, - [ - lambda: indexer( - hidden_states, - qr, - indexer_kv_score, - indexer_weights, - positions, - self.indexer_rotary_emb, - use_aux_stream=post_aux_streams is not None - and not current_platform.is_rocm(), - ), - lambda: compressor(kv_score, positions, self.rotary_emb), - ], + [indexer_fn, compressor_fn], self.ln_events[0], [self.ln_events[1], self.ln_events[2]], - ( - [post_aux_streams[0], post_aux_streams[1]] - if post_aux_streams is not None - else None - ), + outer_post_aux_streams, enable=post_aux_streams is not None, ) + if current_platform.is_rocm() and post_aux_streams is not None: + if indexer_result is None and indexer_fn is None: + run_indexer() + if compressor_result is None and compressor_fn is None: + run_compressor() elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. aux_stream = ( - post_aux_streams[0] if post_aux_streams is not None else None + ( + post_aux_streams[0] + if current_platform.is_rocm() and post_aux_streams is not None + else outer_post_aux_streams[0] + ) + if ( + (current_platform.is_rocm() and post_aux_streams is not None) + or outer_post_aux_streams is not None + ) + else None ) + if ( + current_platform.is_rocm() + and not envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR + ): + aux_stream = None compressor = self.compressor def wq_b_kv_insert() -> torch.Tensor: @@ -623,9 +785,17 @@ def wq_b_kv_insert() -> torch.Tensor: q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) return q + def run_compressor() -> Any: + local_kv_score = kv_score + if local_kv_score is None: + local_kv_score = self._project_compressor_kv_score( + hidden_states, compressor + ) + return compressor(local_kv_score, positions, self.rotary_emb) + q, _ = maybe_execute_in_parallel( wq_b_kv_insert, - lambda: compressor(kv_score, positions, self.rotary_emb), + run_compressor, self.ln_events[0], self.ln_events[1], aux_stream, @@ -926,6 +1096,7 @@ def __init__( compress_ratio: int = 1, prefix: str = "", aux_stream: torch.cuda.Stream | None = None, + aux_streams: list[torch.cuda.Stream] | None = None, ): super().__init__() self.vllm_config = vllm_config @@ -1011,8 +1182,10 @@ def __init__( use_fp4_cache=self.use_fp4_kv, ) - # None on ROCm — maybe_execute_in_parallel falls back to sequential. + # aux_stream is the legacy two-way split. aux_streams mirrors SGLang's + # C4 indexer sub-branches: [0] Q projection/quant, [1] weights proj. self.aux_stream = aux_stream + self.aux_streams = aux_streams self.ln_events: list[torch.cuda.Event] = [ torch.cuda.Event(), torch.cuda.Event(), @@ -1022,15 +1195,36 @@ def forward( self, hidden_states: torch.Tensor, qr: torch.Tensor, - compressed_kv_score: torch.Tensor, - indexer_weights: torch.Tensor, + compressed_kv_score: torch.Tensor | None, + indexer_weights: torch.Tensor | None, positions: torch.Tensor, rotary_emb: nn.Module, use_aux_stream: bool = True, + project_inputs: bool = False, + compressed_kv_score_out: torch.Tensor | None = None, + indexer_weights_out: torch.Tensor | None = None, ) -> torch.Tensor: compressor = self.compressor - def wq_b_and_q_quant(): + def project_indexer_weights() -> torch.Tensor: + # Keep the ROCm deferred path numerically identical to the baseline + # ReplicatedLinear path. Raw torch.mm can select a different BF16 + # GEMM implementation for this skinny projection and perturb the + # sparse top-k boundary. + weights, _ = self.weights_proj(hidden_states) + if indexer_weights_out is not None: + indexer_weights_out.copy_(weights) + return indexer_weights_out + return weights + + def project_compressed_kv_score() -> torch.Tensor: + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) + + def wq_b_and_q_quant(weights: torch.Tensor): # ReplicatedLinear returns (output, bias); bias is None. q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) @@ -1038,16 +1232,62 @@ def wq_b_and_q_quant(): positions, q, rotary_emb.cos_sin_cache, - indexer_weights, + weights, self.softmax_scale, self.n_head**-0.5, use_fp4=self.use_fp4_kv, ) + def record_stream(result: Any, stream: torch.cuda.Stream) -> None: + if isinstance(result, torch.Tensor): + result.record_stream(stream) + elif isinstance(result, (tuple, list)): + for item in result: + record_stream(item, stream) + + if ( + project_inputs + and use_aux_stream + and self.aux_streams is not None + and len(self.aux_streams) >= 2 + ): + current_stream = torch.cuda.current_stream() + stream_q = self.aux_streams[0] + stream_weights = self.aux_streams[1] + stream_q.wait_stream(current_stream) + stream_weights.wait_stream(current_stream) + + with torch.cuda.stream(stream_weights): + weights = project_indexer_weights() + weights_ready = stream_weights.record_event() + + q_result: list[Any] = [None] + with torch.cuda.stream(stream_q): + stream_q.wait_event(weights_ready) + q_result[0] = wq_b_and_q_quant(weights) + + local_kv_score = project_compressed_kv_score() + k = compressor(local_kv_score, positions, rotary_emb) + + current_stream.wait_stream(stream_q) + assert q_result[0] is not None + record_stream(q_result[0], current_stream) + (q_quant, weights) = q_result[0] + return self.indexer_op(hidden_states, q_quant, k, weights) + + if project_inputs: + if indexer_weights is None: + indexer_weights = project_indexer_weights() + if compressed_kv_score is None: + compressed_kv_score = project_compressed_kv_score() + + assert indexer_weights is not None + assert compressed_kv_score is not None + # compressor returns None and writes K to the indexer KV cache; the # join orders that write before indexer_op (skip_k_cache_insert=True). (q_quant, weights), k = maybe_execute_in_parallel( - wq_b_and_q_quant, + lambda: wq_b_and_q_quant(indexer_weights), lambda: compressor(compressed_kv_score, positions, rotary_emb), self.ln_events[0], self.ln_events[1], diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 2f0ff1e4bd66..7a97c578521e 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +import vllm.envs as envs from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import ( @@ -66,17 +67,22 @@ def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: - # Three aux streams: one per non-default input GEMM in - # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute - # (compressor kv_score, indexer.weights_proj, indexer.compressor - # kv_score). fused_wqa_wkv stays on the default stream. - # - # ROCm uses the same attention-side stream choreography as CUDA: - # c4a layers overlap indexer, main KV compression, and SWA insertion; - # c128a layers overlap main KV compression and SWA insertion. - # XPU keeps the serial fallback. if current_platform.is_rocm(): - return [torch.cuda.Stream() for _ in range(3)] + if ( + not envs.VLLM_ROCM_DSV4_CSA_MULTISTREAM + or envs.VLLM_ROCM_DSV4_CSA_MS_STRATEGY.lower() == "off" + ): + return None + # SGLang creates five streams for DeepSeek-V4: three top-level + # preparation branches and two C4-indexer sub-branches. vLLM keeps + # SWA insertion fused with q preparation, but uses the same hierarchy: + # [0] main compressor, [1] C4 indexer, [2:4] indexer sub-branches. + return [ + torch.cuda.Stream( + priority=envs.VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY + ) + for _ in range(5) + ] if current_platform.is_xpu(): return None return [torch.cuda.Stream() for _ in range(3)] @@ -797,12 +803,22 @@ def __init__( self.indexer = None if self.compress_ratio == 4: # Only C4A uses sparse attention and hence has indexer. - # aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is - # free here (outer GEMMs joined) for the inner overlap of - # wq_b+fused_indexer_q_rope_quant vs compressor. + # NVIDIA uses aux_stream_list[2] for the legacy inner overlap. ROCm + # SGLang-style decode uses aux_stream_list[2:4] for the C4 indexer + # q/weights sub-branches while the outer indexer branch runs on + # aux_stream_list[1]. indexer_aux_stream = ( aux_stream_list[2] if aux_stream_list is not None else None ) + indexer_aux_streams = ( + aux_stream_list[2:4] + if ( + current_platform.is_rocm() + and aux_stream_list is not None + and len(aux_stream_list) >= 5 + ) + else None + ) self.indexer = DeepseekV4Indexer( vllm_config, config=config, @@ -814,6 +830,7 @@ def __init__( compress_ratio=self.compress_ratio, prefix=f"{prefix}.indexer", aux_stream=indexer_aux_stream, + aux_streams=indexer_aux_streams, ) mla_modules = DeepseekV4MLAModules( diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py index c1c5f52c2b20..65ce67393fa8 100644 --- a/vllm/utils/multi_stream_utils.py +++ b/vllm/utils/multi_stream_utils.py @@ -88,9 +88,11 @@ def execute_in_parallel( start_event fans out from the current stream to every launched aux stream; done_events[i] is recorded after aux_fns[i] so the current stream joins - before returning. Falls back to sequential execution on the current stream - when aux_streams is None or enable is False; in that case default_fn runs - first, then aux_fns in order. + before returning. The default-stream function is enqueued before the aux + functions so the critical path is not delayed by CPU launch overhead from + side-stream branches. Falls back to sequential execution on the current + stream when aux_streams is None or enable is False; in that case default_fn + runs first, then aux_fns in order. Args: default_fn: Callable for the default (current) stream. @@ -125,6 +127,8 @@ def execute_in_parallel( pending: list[tuple[torch.cuda.Event, Any]] = [] start_event.record(current_stream) + default_result = default_fn() + for i, fn in enumerate(aux_fns): if fn is None: continue @@ -135,8 +139,6 @@ def execute_in_parallel( done_events[i].record(aux_stream) pending.append((done_events[i], aux_results[i])) - default_result = default_fn() - for ev, result in pending: current_stream.wait_event(ev) _record_result_stream(result, current_stream) From c4375e1d0a3cc4696c92d13415b591c47d667087 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Mon, 25 May 2026 19:34:11 +0000 Subject: [PATCH 06/21] Enable ROCm DSV4 CSA multistream Port the ROCm DeepSeek-V4 CSA decode path toward the SGLang stream layout and enable it by default for the measured-good range. Implementation: - Split the fused qnorm/rope/kv-cache op into q-only and kv-only torch ops so ROCm can place SWA KV insert on a side stream while the default stream owns q_b + qnorm + rope before MLA attention. - Use five ROCm aux streams matching the SGLang hierarchy: aux0 KV cache insert, aux1 main compressor, aux2 C4 indexer, aux3 indexer Q branch, aux4 indexer weights branch. - Keep branch projection deferral as an A/B knob but disable it by default; ROCm side-stream allocation rechecks did not require the deferred projection path. - Default policy is strategy=sglang, min_decode=1, max_decode=64, graph_modes=none,piecewise. max_decode<=0 remains an opt-in no-cap experiment, but no-cap is not the default because it regressed 1k/1k c128 TTFT badly. - Skip optional flash-attn rotary helper import on ROCm. SGLang/profiling notes: - Inspected SGLang files: deepseek_v4.py, dsv4/indexer.py, dsv4/compressor.py, dsv4/compress_hip.py, and multi_stream_utils.py at SGLang commit 7f45bcdd. - benchmarks/kernels/rocm_dsv4_stream_probe.py showed plain graph replay preserves separate ROCm queues for representative AITER + BF16 GEMM overlap, while torch.compile/full-graph variants can collapse replayed work to stream 0. Keep full graph out of the default multistream policy. Correctness and environment: - Local import proof: vllm.__file__=/shared/amdgpu/home/fai_qle/vllm/vllm/__init__.py. - Hardware/runtime: 8x gfx950, ROCm 7.2.2 / HIP 7.2.53211, torch 2.10.0+git8514f05. - pytest tests/models/test_deepseek_v4_rocm_multistream.py -q: 7 passed. - pytest tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py::test_split_q_and_kv_match_combined -q: 12 passed. - pytest tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py::test_kv_path_matches_reference -q -k 'not 2048': 8 passed, 2 deselected. - GSM8K 1319q 5-shot: accuracy 0.954, invalid 0.000, latency 284.755s, output tok/s 420.527. Benchmark summary: - Baseline: InferenceX official random_range_ratio=0.8 agg_bmk.json. - Test env: TP=8, fp8 KV, async scheduling, no prefix cache, FULL_AND_PIECEWISE compile config, graph_modes=none,piecewise, VLLM_ROCM_USE_AITER=1, VLLM_ROCM_DSV4_CSA_MULTISTREAM=1, strategy=sglang, split_qkv_post=1, defer_projections=0, max_decode=64. - 1k/1k c4,c8,c16,c32,c64,c128,c256,c512 output throughput deltas: +1.39%, +2.04%, +2.05%, +2.46%, +1.93%, +1.69%, +4.39%, +3.88%. TPOT deltas: -0.82%, -1.40%, -1.44%, -2.29%, -1.81%, -1.64%, -4.33%, -3.80%. TTFT improved in all cells. - 8k/1k c4,c8,c16,c32,c64,c128,c256,c512 output throughput deltas: +1.63%, +1.39%, +1.74%, +1.52%, +1.49%, +1.33%, +7.11%, +1.91%. TPOT deltas: -1.34%, -1.19%, -1.66%, -1.55%, -1.53%, -1.33%, -6.66%, -1.94%. TTFT improved through c256; c512 mean TTFT was +0.13% while p99 improved slightly. - No-cap one-wave A/B was not uniformly positive: 1k/1k c128 regressed output -2.13% and TTFT +65.84%, although c512 improved. Keep the default cap at 64 and leave no-cap as an explicit experiment knob. Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 198 ++++++++++++ csrc/ops.h | 9 + csrc/torch_bindings.cpp | 12 + ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 71 ++++ .../test_deepseek_v4_rocm_multistream.py | 20 +- vllm/envs.py | 12 +- .../layers/rotary_embedding/common.py | 15 +- vllm/models/deepseek_v4/attention.py | 303 ++++++++++++++++-- vllm/models/deepseek_v4/nvidia/model.py | 29 +- 9 files changed, 623 insertions(+), 46 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index e4d432cac97f..3c0ce390af07 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -644,6 +644,127 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( #undef DISPATCH } +template +__global__ void deepseekV4QNormRopeKernel( + scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place + int64_t const* __restrict__ position_ids, // [N] i64 + float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 + float const eps, int const num_tokens, int const num_heads_q) { +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + if constexpr (std::is_same_v) { + return; + } else { +#endif + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; + int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId; + + int const tokenIdx = globalWarpIdx / num_heads_q; + int const headIdx = globalWarpIdx % num_heads_q; + if (tokenIdx >= num_tokens) return; + + int const dim_base = laneId * kElemsPerLane; + scalar_t_in const* src_ptr = + q_inout + + (static_cast(tokenIdx) * num_heads_q + headIdx) * kHeadDim + + dim_base; + uint4 const v0 = *reinterpret_cast(src_ptr); + uint4 const v1 = *reinterpret_cast(src_ptr + 8); + + processDeepseekV4Slot( + v0, v1, tokenIdx, headIdx, dim_base, laneId, num_heads_q, eps, q_inout, + nullptr, nullptr, position_ids, cos_sin_cache, 0, 0); +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + } +#endif +} + +template +__global__ void deepseekV4KVRopeQuantInsertKernel( + scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16 + uint8_t* __restrict__ k_cache, // [num_blocks, block_stride] + int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64 + int64_t const* __restrict__ position_ids, // [N] i64 + float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 + int const num_tokens_full, int const num_tokens_insert, + int const cache_block_size, int const kv_block_stride) { +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + if constexpr (std::is_same_v) { + return; + } else { +#endif + int const warpsPerBlock = blockDim.x / 32; + int const warpId = threadIdx.x / 32; + int const laneId = threadIdx.x % 32; + int const tokenIdx = blockIdx.x * warpsPerBlock + warpId; + if (tokenIdx >= num_tokens_insert) return; + if (tokenIdx >= num_tokens_full) return; + + int const dim_base = laneId * kElemsPerLane; + scalar_t_in const* src_ptr = + kv_in + static_cast(tokenIdx) * kHeadDim + dim_base; + uint4 const v0 = *reinterpret_cast(src_ptr); + uint4 const v1 = *reinterpret_cast(src_ptr + 8); + + // num_heads_q=0 and slotIdx=0 select the KV branch in processDeepseekV4Slot. + processDeepseekV4Slot( + v0, v1, tokenIdx, 0, dim_base, laneId, 0, 0.0f, nullptr, k_cache, + slot_mapping, position_ids, cos_sin_cache, cache_block_size, + kv_block_stride); +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) + } +#endif +} + +template +void launchDeepseekV4QNormRope(scalar_t_in* q_inout, + int64_t const* position_ids, + float const* cos_sin_cache, float const eps, + int const num_tokens, int const num_heads_q, + cudaStream_t stream) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; + int64_t const total_warps = + static_cast(num_tokens) * num_heads_q; + int const grid = + static_cast((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock); +#ifndef USE_ROCM + static int const sm_version = getSMVersion(); + TORCH_CHECK(sm_version >= 80, + "deepseek_v4_qnorm_rope requires sm_80+ (Ampere or newer); got " + "sm_", + sm_version); +#endif + deepseekV4QNormRopeKernel<<>>( + q_inout, position_ids, cos_sin_cache, eps, num_tokens, num_heads_q); +} + +template +void launchDeepseekV4KVRopeQuantInsert( + scalar_t_in const* kv_in, uint8_t* k_cache, + int64_t const* slot_mapping, int64_t const* position_ids, + float const* cos_sin_cache, int const num_tokens_full, + int const num_tokens_insert, int const cache_block_size, + int const kv_block_stride, cudaStream_t stream) { + constexpr int kBlockSize = 256; + constexpr int kWarpsPerBlock = kBlockSize / 32; + int const grid = + static_cast((num_tokens_insert + kWarpsPerBlock - 1) / kWarpsPerBlock); +#ifndef USE_ROCM + static int const sm_version = getSMVersion(); + TORCH_CHECK( + sm_version >= 80, + "deepseek_v4_kv_rope_quant_insert requires sm_80+ (Ampere or newer); got " + "sm_", + sm_version); +#endif + deepseekV4KVRopeQuantInsertKernel + <<>>( + kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, + num_tokens_full, num_tokens_insert, cache_block_size, kv_block_stride); +} + } // namespace deepseek_v4_fused_ops } // namespace vllm @@ -721,3 +842,80 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( }); return q_out; } + +void deepseek_v4_qnorm_rope(torch::Tensor& q, + torch::Tensor const& position_ids, + torch::Tensor const& cos_sin_cache, double eps) { + TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA"); + TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, + "position_ids must be int64 CUDA"); + TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); + TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]"); + TORCH_CHECK(static_cast(position_ids.size(0)) == q.size(0), + "q/position_ids row counts must match"); + TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, + "cos_sin_cache must be float32"); + + int const num_tokens = static_cast(q.size(0)); + int const num_heads_q = static_cast(q.size(1)); + if (num_tokens == 0 || num_heads_q == 0) return; + + at::cuda::OptionalCUDAGuard device_guard(device_of(q)); + auto stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_HALF_TYPES(q.scalar_type(), "deepseek_v4_qnorm_rope", [&] { + using qkv_scalar_t = scalar_t; + vllm::deepseek_v4_fused_ops::launchDeepseekV4QNormRope( + reinterpret_cast(q.data_ptr()), + reinterpret_cast(position_ids.data_ptr()), + cos_sin_cache.data_ptr(), static_cast(eps), num_tokens, + num_heads_q, stream); + }); +} + +void deepseek_v4_kv_rope_quant_insert( + torch::Tensor const& kv, torch::Tensor& k_cache, + torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, + torch::Tensor const& cos_sin_cache, int64_t cache_block_size) { + TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); + TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); + TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, + "slot_mapping must be int64 CUDA"); + TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, + "position_ids must be int64 CUDA"); + TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); + TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); + TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8"); + TORCH_CHECK(static_cast(position_ids.size(0)) == kv.size(0), + "kv/position_ids row counts must match"); + TORCH_CHECK(slot_mapping.size(0) <= kv.size(0), + "slot_mapping must not exceed kv row count"); + TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, + "cos_sin_cache shape [max_pos, 64]"); + TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, + "cos_sin_cache must be float32"); + + int const num_tokens_full = static_cast(kv.size(0)); + int const num_tokens_insert = static_cast(slot_mapping.size(0)); + if (num_tokens_full == 0 || num_tokens_insert == 0) return; + int const cache_block_size_i = static_cast(cache_block_size); + int const kv_block_stride = static_cast(k_cache.stride(0)); + + at::cuda::OptionalCUDAGuard device_guard(device_of(kv)); + auto stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_HALF_TYPES( + kv.scalar_type(), "deepseek_v4_kv_rope_quant_insert", [&] { + using qkv_scalar_t = scalar_t; + vllm::deepseek_v4_fused_ops:: + launchDeepseekV4KVRopeQuantInsert( + reinterpret_cast(kv.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(position_ids.data_ptr()), + cos_sin_cache.data_ptr(), num_tokens_full, + num_tokens_insert, cache_block_size_i, kv_block_stride, stream); + }); +} diff --git a/csrc/ops.h b/csrc/ops.h index 3c6b2b7b9bc2..66bcad2a4000 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -76,6 +76,15 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps, int64_t cache_block_size); +void deepseek_v4_qnorm_rope(torch::Tensor& q, + torch::Tensor const& position_ids, + torch::Tensor const& cos_sin_cache, double eps); + +void deepseek_v4_kv_rope_quant_insert( + torch::Tensor const& kv, torch::Tensor& k_cache, + torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, + torch::Tensor const& cos_sin_cache, int64_t cache_block_size); + void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 487ff8df9dd0..14ab2221458f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -105,6 +105,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); + ops.def("deepseek_v4_qnorm_rope(" + "Tensor! q, Tensor position_ids, Tensor cos_sin_cache, " + "float eps) -> ()"); + ops.impl("deepseek_v4_qnorm_rope", torch::kCUDA, &deepseek_v4_qnorm_rope); + + ops.def("deepseek_v4_kv_rope_quant_insert(" + "Tensor kv, Tensor! k_cache, Tensor slot_mapping, " + "Tensor position_ids, Tensor cos_sin_cache, " + "int cache_block_size) -> ()"); + ops.impl("deepseek_v4_kv_rope_quant_insert", torch::kCUDA, + &deepseek_v4_kv_rope_quant_insert); + // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index a49ea498e5e0..3d0cdb3a477f 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -136,6 +136,22 @@ def _call_fused( ) +def _split_ops_available() -> bool: + return hasattr(torch.ops._C, "deepseek_v4_qnorm_rope") and hasattr( + torch.ops._C, "deepseek_v4_kv_rope_quant_insert" + ) + + +def _call_q_split(q, positions, cos_sin_cache, eps): + torch.ops._C.deepseek_v4_qnorm_rope(q, positions, cos_sin_cache, eps) + + +def _call_kv_split(kv, k_cache, slot_mapping, positions, cos_sin_cache, bs): + torch.ops._C.deepseek_v4_kv_rope_quant_insert( + kv, k_cache, slot_mapping, positions, cos_sin_cache, bs + ) + + # ── Test 1: Q path numerical parity ────────────────────────────────────────── @@ -415,3 +431,58 @@ def test_combined_q_and_kv( "padded head slots must be exact zero" ) torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + not _split_ops_available(), + reason="split DeepseekV4 q/kv ops not built in", +) +@pytest.mark.parametrize("num_tokens", [1, 17, 64]) +@pytest.mark.parametrize("n_heads", [8, 64]) +@pytest.mark.parametrize("block_size", [16, 64]) +def test_split_q_and_kv_match_combined( + num_tokens: int, n_heads: int, block_size: int +): + torch.manual_seed(4) + device = "cuda" + dtype = torch.bfloat16 + eps = 1e-6 + max_pos = 4096 + + q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) + kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) + positions = torch.arange(num_tokens, dtype=torch.int64, device=device) + cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) + + num_blocks = (num_tokens + block_size - 1) // block_size + 1 + slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) + + q_fused = q.clone() + k_cache_fused = torch.zeros( + num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device + ) + _call_fused( + q_fused, + kv, + k_cache_fused, + slot_mapping, + positions, + cos_sin_cache, + eps, + block_size, + ) + + q_split = q.clone() + k_cache_split = torch.zeros_like(k_cache_fused) + _call_q_split(q_split, positions, cos_sin_cache, eps) + _call_kv_split( + kv, + k_cache_split, + slot_mapping, + positions, + cos_sin_cache, + block_size, + ) + + torch.testing.assert_close(q_split, q_fused, rtol=0, atol=0) + torch.testing.assert_close(k_cache_split, k_cache_fused, rtol=0, atol=0) diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index 3ba23c687378..bb5c7f1f77f2 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -24,13 +24,31 @@ def make_stream(**kwargs): assert len(aux_streams) == 5 -def test_deepseek_v4_rocm_aux_streams_disabled_by_default(monkeypatch): +def test_deepseek_v4_rocm_aux_streams_enabled_by_default(monkeypatch): + streams = [object(), object(), object(), object(), object()] + + def make_stream(**kwargs): + assert kwargs == {"priority": 0} + return streams.pop() + monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(rocm_model.torch.cuda, "Stream", make_stream) monkeypatch.delenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", raising=False) aux_streams = rocm_model.make_deepseek_v4_aux_streams() + assert aux_streams is not None + assert len(aux_streams) == 5 + + +def test_deepseek_v4_rocm_aux_streams_disabled_by_env(monkeypatch): + monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "0") + + aux_streams = rocm_model.make_deepseek_v4_aux_streams() + assert aux_streams is None diff --git a/vllm/envs.py b/vllm/envs.py index ffc488168614..9841b2cba138 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -257,18 +257,19 @@ VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024 - VLLM_ROCM_DSV4_CSA_MULTISTREAM: bool = False + VLLM_ROCM_DSV4_CSA_MULTISTREAM: bool = True VLLM_ROCM_DSV4_CSA_MS_STRATEGY: Literal["off", "indexer_only", "sglang"] = ( "sglang" ) VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE: int = 1 - VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE: int = 32 + VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE: int = 64 VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES: set[Literal["none", "piecewise", "full"]] = { "none", "piecewise", } VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR: bool = True VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS: bool = False + VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST: bool = True VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER: bool = True VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS: bool = True VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY: int = 0 @@ -1900,7 +1901,7 @@ def _resolve_rust_frontend_path() -> str | None: # hipBLASLt ordering differ by PyTorch/ROCm release. "indexer_only" keeps # the conservative single indexer branch path for A/B testing. "VLLM_ROCM_DSV4_CSA_MULTISTREAM": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "0")) + int(os.getenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1")) ), "VLLM_ROCM_DSV4_CSA_MS_STRATEGY": env_with_choices( "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", @@ -1912,7 +1913,7 @@ def _resolve_rust_frontend_path() -> str | None: os.getenv("VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE", "1") ), "VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE": lambda: int( - os.getenv("VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE", "32") + os.getenv("VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE", "64") ), "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES": env_set_with_choices( "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES", @@ -1926,6 +1927,9 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS": lambda: bool( int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS", "0")) ), + "VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST": lambda: bool( + int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST", "1")) + ), "VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER": lambda: bool( int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER", "1")) ), diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 2e407ae7159e..a6eb7aa5e107 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -135,10 +135,17 @@ def __init__( self.enable_fp32_compute = enable_fp32_compute self.apply_rotary_emb_flash_attn = None - if not current_platform.is_cpu() and find_spec("flash_attn") is not None: - from flash_attn.ops.triton.rotary import apply_rotary - - self.apply_rotary_emb_flash_attn = apply_rotary + if ( + not current_platform.is_cpu() + and not current_platform.is_rocm() + and find_spec("flash_attn") is not None + ): + try: + from flash_attn.ops.triton.rotary import apply_rotary + except ImportError: + logger.debug("Failed to import flash_attn rotary helper", exc_info=True) + else: + self.apply_rotary_emb_flash_attn = apply_rotary @staticmethod def forward_static( diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 45a521095e8a..bd9452c25d28 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -437,7 +437,12 @@ def _rocm_csa_ms_strategy_for_step( num_decodes = _deepseek_v4_num_decodes(attn_metadata) min_decode = envs.VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE max_decode = envs.VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE - if num_decodes < min_decode or num_decodes > max_decode: + if num_decodes < min_decode: + return "off" + # max_decode <= 0 is an explicit no-cap experiment knob. The default + # remains capped at 64 because unlimited decode improves the largest + # one-wave backlogs but regresses c128 TTFT substantially on MI355X. + if 0 < max_decode < num_decodes: return "off" graph_mode = forward_context.cudagraph_runtime_mode @@ -486,7 +491,15 @@ def _project_compressor_kv_score( self, hidden_states: torch.Tensor, compressor: DeepseekCompressor, + out: torch.Tensor | None = None, ) -> torch.Tensor: + if out is not None: + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out=out, + out_dtype=torch.float32, + ) return torch.mm( hidden_states, compressor.fused_wkv_wgate.weight.T, @@ -509,10 +522,10 @@ def attn_gemm_parallel_execute( # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. - # The ROCm SGLang-style strategy deliberately defers these branch - # projections until the post-rmsnorm fan-out. That avoids capturing - # side-stream hipBLASLt nodes beside the fused WQA/WKV GEMM, which is - # the graph topology that previously hung/regressed on ROCm. + # ROCm can optionally defer these branch projections until the + # post-rmsnorm fan-out for A/B testing. Current ROCm/PyTorch stacks do + # not need that for allocator safety, so the default keeps projections + # here and only overlaps the post-rmsnorm CSA branches. aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None and not defer_rocm_branch_projections: @@ -636,6 +649,134 @@ def fused_wqa_wkv() -> torch.Tensor: return qr_kv, kv_score, indexer_kv_score, indexer_weights + def _rocm_sglang_post_rmsnorm_prepare( + self, + hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + kv_score: torch.Tensor | None, + indexer_kv_score: torch.Tensor | None, + indexer_weights: torch.Tensor | None, + positions: torch.Tensor, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + aux_streams: list[torch.cuda.Stream], + defer_rocm_branch_projections: bool, + ) -> torch.Tensor: + assert len(aux_streams) >= 5 + stream_kv = aux_streams[0] + stream_compressor = aux_streams[1] + stream_indexer = aux_streams[2] + current_stream = torch.cuda.current_stream() + + stream_kv.wait_stream(current_stream) + stream_compressor.wait_stream(current_stream) + stream_indexer.wait_stream(current_stream) + + compressor_kv_score_out: torch.Tensor | None = None + if self.compressor is not None and kv_score is None: + compressor_kv_score_out = self._rocm_aux_buffer( + "deferred_compressor_kv_score", + ( + hidden_states.shape[0], + self.compressor.fused_wkv_wgate.weight.shape[0], + ), + torch.float32, + hidden_states.device, + ) + + indexer_weights_out: torch.Tensor | None = None + indexer_kv_score_out: torch.Tensor | None = None + if self.indexer is not None: + if indexer_weights is None: + indexer_weights_out = self._rocm_aux_buffer( + "deferred_indexer_weights", + ( + hidden_states.shape[0], + self.indexer.weights_proj.weight.shape[0], + ), + hidden_states.dtype, + hidden_states.device, + ) + if indexer_kv_score is None: + indexer_kv_score_out = self._rocm_aux_buffer( + "deferred_indexer_kv_score", + ( + hidden_states.shape[0], + self.indexer.compressor.fused_wkv_wgate.weight.shape[0], + ), + torch.float32, + hidden_states.device, + ) + + def q_b_qnorm_rope() -> torch.Tensor: + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + self._fused_qnorm_rope(q, positions) + return q + + def run_kv_insert() -> None: + self._fused_kv_rope_insert(kv, positions, attn_metadata) + + def run_compressor() -> Any: + if self.compressor is None: + return None + local_kv_score = kv_score + if local_kv_score is None: + local_kv_score = self._project_compressor_kv_score( + hidden_states, + self.compressor, + out=compressor_kv_score_out, + ) + return self.compressor(local_kv_score, positions, self.rotary_emb) + + def run_indexer() -> Any: + if self.indexer is None: + return None + return self.indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + use_aux_stream=envs.VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS, + project_inputs=defer_rocm_branch_projections, + compressed_kv_score_out=indexer_kv_score_out, + indexer_weights_out=indexer_weights_out, + ) + + launched_indexer = ( + self.indexer is not None and envs.VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER + ) + launched_compressor = ( + self.compressor is not None + and envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR + ) + + if launched_indexer: + with torch.cuda.stream(stream_indexer): + run_indexer() + with torch.cuda.stream(stream_kv): + run_kv_insert() + if launched_compressor: + with torch.cuda.stream(stream_compressor): + run_compressor() + + q = q_b_qnorm_rope() + + current_stream.wait_stream(stream_kv) + if launched_compressor: + current_stream.wait_stream(stream_compressor) + else: + run_compressor() + if launched_indexer: + current_stream.wait_stream(stream_indexer) + else: + run_indexer() + + return q + def attention_impl( self, hidden_states: torch.Tensor, @@ -672,6 +813,65 @@ def attention_impl( self.eps, ) + if ( + current_platform.is_rocm() + and rocm_ms_strategy == "sglang" + and envs.VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST + and aux_streams is not None + and len(aux_streams) >= 5 + ): + q = self._rocm_sglang_post_rmsnorm_prepare( + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + attn_metadata, + aux_streams, + defer_rocm_branch_projections, + ) + else: + q = self._legacy_post_rmsnorm_prepare( + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + attn_metadata, + aux_streams, + rocm_ms_strategy, + defer_rocm_branch_projections, + ) + + # Pad q to FlashMLA-required head count (64 or 128) + if self.n_local_heads < self.padded_heads: + pad_size = self.padded_heads - self.n_local_heads + q = F.pad(q, (0, 0, 0, pad_size), value=0.0) + + # MLA attention writes into the pre-allocated `out` buffer + # ([num_tokens, padded_heads, head_dim]). + self.mla_attn(q, kv, positions, output=out) + + def _legacy_post_rmsnorm_prepare( + self, + hidden_states: torch.Tensor, + qr: torch.Tensor, + kv: torch.Tensor, + kv_score: torch.Tensor | None, + indexer_kv_score: torch.Tensor | None, + indexer_weights: torch.Tensor | None, + positions: torch.Tensor, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + aux_streams: list[torch.cuda.Stream] | None, + rocm_ms_strategy: str, + defer_rocm_branch_projections: bool, + ) -> torch.Tensor: # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride # on the default stream so q stays on its consumer stream (mla_attn # downstream reads q on default). Indexer/compressor go on aux for @@ -682,11 +882,10 @@ def attention_impl( and post_aux_streams is not None and len(post_aux_streams) >= 5 ): - # ROCm SGLang-style mapping: - # aux[0]: main compressor branch - # aux[1]: C4 indexer branch - # aux[2:4]: C4 indexer sub-branches - outer_post_aux_streams = [post_aux_streams[1], post_aux_streams[0]] + # Legacy ROCm fallback keeps q+kv fused on default. Match the + # SGLang stream IDs where possible: aux[2] indexer, aux[1] + # compressor. + outer_post_aux_streams = [post_aux_streams[2], post_aux_streams[1]] elif post_aux_streams is not None: outer_post_aux_streams = [post_aux_streams[0], post_aux_streams[1]] else: @@ -705,8 +904,8 @@ def wq_b_kv_insert() -> torch.Tensor: # 3-way overlap (matches TRT-LLM PR #14142 Level 1): default runs # wq_b+kv_insert; slot [0] runs the full indexer; slot [1] runs the # MLA compressor. Slot [2] is reserved for the indexer's inner - # overlap. ROCm uses the SGLang-style deferred projection path only - # under VLLM_ROCM_DSV4_CSA_MULTISTREAM=1. + # overlap. ROCm can optionally defer branch projections under + # VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS=1. run_indexer: Callable[[], Any] = lambda: indexer( hidden_states, qr, @@ -763,7 +962,11 @@ def run_compressor() -> Any: # wq_b + kv_insert on default, compressor on aux. aux_stream = ( ( - post_aux_streams[0] + ( + post_aux_streams[1] + if len(post_aux_streams) >= 5 + else post_aux_streams[0] + ) if current_platform.is_rocm() and post_aux_streams is not None else outer_post_aux_streams[0] ) @@ -805,9 +1008,7 @@ def run_compressor() -> Any: q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) - # MLA attention writes into the pre-allocated `out` buffer - # ([num_tokens, padded_heads, head_dim]). - self.mla_attn(q, kv, positions, output=out) + return q def _fused_qnorm_rope_kv_insert( self, @@ -856,6 +1057,47 @@ def _fused_qnorm_rope_kv_insert( swa_metadata.block_size, ) + def _fused_qnorm_rope( + self, + q: torch.Tensor, + positions: torch.Tensor, + ) -> None: + torch.ops._C.deepseek_v4_qnorm_rope( + q, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + self.eps, + ) + + def _fused_kv_rope_insert( + self, + kv: torch.Tensor, + positions: torch.Tensor, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + ) -> None: + if not isinstance(attn_metadata, dict): + return + + swa_metadata = cast( + "DeepseekSparseSWAMetadata | None", + attn_metadata.get(self.swa_cache_layer.prefix), + ) + assert swa_metadata is not None + + swa_kv_cache = self.swa_cache_layer.kv_cache + swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) + + torch.ops._C.deepseek_v4_kv_rope_quant_insert( + kv, + swa_kv_cache_2d, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache, + swa_metadata.block_size, + ) + @eager_break_during_capture def deepseek_v4_attention( @@ -1207,17 +1449,26 @@ def forward( compressor = self.compressor def project_indexer_weights() -> torch.Tensor: - # Keep the ROCm deferred path numerically identical to the baseline - # ReplicatedLinear path. Raw torch.mm can select a different BF16 - # GEMM implementation for this skinny projection and perturb the - # sparse top-k boundary. - weights, _ = self.weights_proj(hidden_states) if indexer_weights_out is not None: - indexer_weights_out.copy_(weights) - return indexer_weights_out + return torch.mm( + hidden_states, + self.weights_proj.weight.T, + out=indexer_weights_out, + ) + # Keep the default path numerically identical to the baseline + # ReplicatedLinear path. The explicit out= path above is only used + # for ROCm graph-safe deferred side-stream projection buffers. + weights, _ = self.weights_proj(hidden_states) return weights def project_compressed_kv_score() -> torch.Tensor: + if compressed_kv_score_out is not None: + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out=compressed_kv_score_out, + out_dtype=torch.float32, + ) return torch.mm( hidden_states, compressor.fused_wkv_wgate.weight.T, @@ -1257,6 +1508,9 @@ def record_stream(result: Any, stream: torch.cuda.Stream) -> None: stream_q.wait_stream(current_stream) stream_weights.wait_stream(current_stream) + local_kv_score = project_compressed_kv_score() + k = compressor(local_kv_score, positions, rotary_emb) + with torch.cuda.stream(stream_weights): weights = project_indexer_weights() weights_ready = stream_weights.record_event() @@ -1266,9 +1520,6 @@ def record_stream(result: Any, stream: torch.cuda.Stream) -> None: stream_q.wait_event(weights_ready) q_result[0] = wq_b_and_q_quant(weights) - local_kv_score = project_compressed_kv_score() - k = compressor(local_kv_score, positions, rotary_emb) - current_stream.wait_stream(stream_q) assert q_result[0] is not None record_stream(q_result[0], current_stream) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 7a97c578521e..60f3ac8fc147 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -74,9 +74,9 @@ def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: ): return None # SGLang creates five streams for DeepSeek-V4: three top-level - # preparation branches and two C4-indexer sub-branches. vLLM keeps - # SWA insertion fused with q preparation, but uses the same hierarchy: - # [0] main compressor, [1] C4 indexer, [2:4] indexer sub-branches. + # preparation branches and two C4-indexer sub-branches: + # [0] main KV cache insert, [1] main compressor, [2] C4 indexer, + # [3] indexer Q branch, [4] indexer weights branch. return [ torch.cuda.Stream( priority=envs.VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY @@ -803,15 +803,22 @@ def __init__( self.indexer = None if self.compress_ratio == 4: # Only C4A uses sparse attention and hence has indexer. - # NVIDIA uses aux_stream_list[2] for the legacy inner overlap. ROCm - # SGLang-style decode uses aux_stream_list[2:4] for the C4 indexer - # q/weights sub-branches while the outer indexer branch runs on - # aux_stream_list[1]. - indexer_aux_stream = ( - aux_stream_list[2] if aux_stream_list is not None else None - ) + # NVIDIA uses aux_stream_list[2] for the legacy inner overlap. + # ROCm SGLang-style decode uses aux_stream_list[3:5] for the C4 + # indexer q/weights sub-branches while the outer indexer branch + # runs on aux_stream_list[2]. + if ( + current_platform.is_rocm() + and aux_stream_list is not None + and len(aux_stream_list) >= 5 + ): + indexer_aux_stream = aux_stream_list[3] + else: + indexer_aux_stream = ( + aux_stream_list[2] if aux_stream_list is not None else None + ) indexer_aux_streams = ( - aux_stream_list[2:4] + aux_stream_list[3:5] if ( current_platform.is_rocm() and aux_stream_list is not None From f89f863f9821eefb65082e5121dd4445c8ec675b Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Tue, 26 May 2026 18:45:38 +0000 Subject: [PATCH 07/21] Clean ROCm DSV4 CSA multistream policy Remove the decode-threshold policy knobs from the ROCm DeepSeek-V4 CSA multistream path and keep the default policy simple: when the global ROCm multistream flag is enabled, strategy=overlap applies to every decode-only DeepSeek-V4 CSA step whose graph mode is allowed and whose required streams are present. Implementation: - Rename the full ROCm strategy from sglang to overlap and remove DeepSeek-V4 SGLang wording from touched implementation comments. - Remove VLLM_ROCM_DSV4_CSA_MS_HIGH_DECODE_MIN, VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE, and VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE. - Keep the validated stream topology knobs: graph_modes=none,piecewise, defer_projections=0, split_qkv_post=1, outer_indexer=0, indexer_substreams=1, main_compressor=1, aux_priority=-1. - Drop the now-unused decode-count helper; no decode-count policy remains. - Keep the path ROCm-only: _rocm_csa_ms_strategy_for_step returns off before ROCm policy is used on non-ROCm, and CUDA/NVIDIA keeps existing aux stream behavior. Final selected benchmark versus InferenceX official baseline: - Baseline: InferenceX random_range_ratio=0.8 agg_bmk.json. - Test env: TP=8, fp8 KV, async scheduling, no prefix cache, FULL_AND_PIECEWISE compile config, AITER enabled, VLLM_ROCM_DSV4_CSA_MULTISTREAM=1, graph_modes=none,piecewise, defer_projections=0, split_qkv_post=1, outer_indexer=0, indexer_substreams=1, main_compressor=1, aux_priority=-1. - Source table: /tmp/vllm_rocm_dsv4_ms_results/final_vs_inferencex_summary.md. - 1k/1k c4,c8,c16,c32,c64,c128,c256,c512 output throughput deltas: +14.33%, +14.16%, +12.73%, +9.50%, +8.54%, +5.60%, +8.79%, +15.87%. Mean TTFT base/current seconds: 0.583/0.314, 0.672/0.353, 0.745/0.419, 0.628/0.515, 0.820/0.701, 1.398/1.246, 1.935/1.816, 3.447/3.210. Mean TPOT base/current ms: 49.94/43.92, 52.65/46.39, 56.04/50.16, 64.00/58.08, 77.75/71.81, 225.80/214.12, 151.84/138.59, 224.80/193.04. - 8k/1k c4,c8,c16,c32,c64,c128,c256,c512 output throughput deltas: +20.44%, +19.64%, +17.39%, +14.51%, +10.78%, +5.67%, +18.29%, +13.98%. Mean TTFT base/current seconds: 1.377/1.277, 1.650/1.499, 2.022/1.927, 2.812/2.758, 4.577/4.418, 8.126/7.820, 15.977/14.403, 30.901/28.964. Mean TPOT base/current ms: 56.48/46.79, 61.70/51.46, 71.50/60.77, 90.99/79.30, 130.31/117.49, 320.15/303.16, 391.23/329.94, 675.30/590.97. Correctness/eval notes: - Custom GSM8K 5-shot over all 1319 questions completed at accuracy 0.95375 with invalid_rate 0.0. - The InferenceX-shaped lm-eval c128 run completed with low strict/flexible scores 0.68006/0.72328 after applying the InferenceX chat-template patch; direct single-request GSM8K output was correct. - A multistream-off isolation using VLLM_ROCM_DSV4_CSA_MULTISTREAM=0 entered the same pathological long-output c128 behavior under max_tokens=5376, with 128 running requests and 100% GPU use but only one completed request after many minutes, so this eval issue is not attributed to the ROCm multistream branch yet. Tests: - PYTHONPATH=/shared/amdgpu/home/fai_qle/vllm .venv/bin/python -m pytest tests/models/test_deepseek_v4_rocm_multistream.py -q: 9 passed. - pre-commit run ruff-format --files vllm/envs.py vllm/models/deepseek_v4/nvidia/model.py vllm/models/deepseek_v4/nvidia/ops/attention.py tests/models/test_deepseek_v4_rocm_multistream.py: passed. - pre-commit run ruff-check --files vllm/envs.py vllm/models/deepseek_v4/nvidia/model.py vllm/models/deepseek_v4/nvidia/ops/attention.py tests/models/test_deepseek_v4_rocm_multistream.py: passed. Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- .../test_deepseek_v4_rocm_multistream.py | 82 ++++++++++++++++--- vllm/envs.py | 34 +++----- vllm/models/deepseek_v4/attention.py | 70 +++++----------- vllm/models/deepseek_v4/nvidia/model.py | 14 ++-- 4 files changed, 110 insertions(+), 90 deletions(-) diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index bb5c7f1f77f2..b6a49d622b11 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -9,14 +9,14 @@ def test_deepseek_v4_rocm_aux_streams_enabled(monkeypatch): streams = [object(), object(), object(), object(), object()] def make_stream(**kwargs): - assert kwargs == {"priority": 0} + assert kwargs == {"priority": -1} return streams.pop() monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) monkeypatch.setattr(rocm_model.torch.cuda, "Stream", make_stream) monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1") - monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "sglang") + monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "overlap") aux_streams = rocm_model.make_deepseek_v4_aux_streams() @@ -28,7 +28,7 @@ def test_deepseek_v4_rocm_aux_streams_enabled_by_default(monkeypatch): streams = [object(), object(), object(), object(), object()] def make_stream(**kwargs): - assert kwargs == {"priority": 0} + assert kwargs == {"priority": -1} return streams.pop() monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) @@ -85,18 +85,78 @@ def test_deepseek_v4_aux_streams_cuda_behavior_unchanged(monkeypatch): assert len(aux_streams) == 3 +def test_deepseek_v4_rocm_strategy_cuda_behavior_unchanged(monkeypatch): + class _ForwardContext: + cudagraph_runtime_mode = dsv4_attention.CUDAGraphMode.PIECEWISE + + class _Wrapper: + aux_stream_list = [object(), object(), object()] + indexer = object() + + monkeypatch.setattr(dsv4_attention.current_platform, "is_rocm", lambda: False) + monkeypatch.setattr(dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MULTISTREAM", True) + monkeypatch.setattr( + dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "overlap" + ) + + wrapper_cls = dsv4_attention.DeepseekV4MultiHeadLatentAttentionWrapper + method = wrapper_cls._rocm_csa_ms_strategy_for_step + + assert method(_Wrapper(), _ForwardContext(), None) == "off" + + class _Metadata: - def __init__(self, num_decodes: int, num_decode_tokens: int): + def __init__( + self, + num_decodes: int, + num_decode_tokens: int, + num_prefill_tokens: int = 0, + ): self.num_decodes = num_decodes self.num_decode_tokens = num_decode_tokens - - -def test_deepseek_v4_num_decodes_uses_batch_count_not_layer_sum(monkeypatch): + self.num_prefill_tokens = num_prefill_tokens + + +def _rocm_ms_strategy_for_decodes( + monkeypatch, + num_decodes: int, + num_prefill_tokens: int = 0, +) -> str: + class _ForwardContext: + cudagraph_runtime_mode = dsv4_attention.CUDAGraphMode.PIECEWISE + + class _Wrapper: + aux_stream_list = [object(), object(), object(), object(), object()] + indexer = object() + + monkeypatch.setattr(dsv4_attention.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MULTISTREAM", True) + monkeypatch.setattr( + dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "overlap" + ) + monkeypatch.setattr( + dsv4_attention.envs, + "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES", + {"none", "piecewise"}, + ) monkeypatch.setattr(dsv4_attention, "DeepseekSparseSWAMetadata", _Metadata) attn_metadata = { - "layer_0.swa": _Metadata(num_decodes=4, num_decode_tokens=4), - "layer_1.swa": _Metadata(num_decodes=4, num_decode_tokens=4), - "layer_2.swa": _Metadata(num_decodes=4, num_decode_tokens=4), + "layer_0.swa": _Metadata(num_decodes, num_decodes, num_prefill_tokens) } - assert dsv4_attention._deepseek_v4_num_decodes(attn_metadata) == 4 + wrapper_cls = dsv4_attention.DeepseekV4MultiHeadLatentAttentionWrapper + method = wrapper_cls._rocm_csa_ms_strategy_for_step + return method(_Wrapper(), _ForwardContext(), attn_metadata) + + +def test_deepseek_v4_rocm_multistream_all_decode_counts(monkeypatch): + assert _rocm_ms_strategy_for_decodes(monkeypatch, 4) == "overlap" + assert _rocm_ms_strategy_for_decodes(monkeypatch, 64) == "overlap" + assert _rocm_ms_strategy_for_decodes(monkeypatch, 128) == "overlap" + assert _rocm_ms_strategy_for_decodes(monkeypatch, 512) == "overlap" + + +def test_deepseek_v4_rocm_multistream_prefill_stays_off(monkeypatch): + strategy = _rocm_ms_strategy_for_decodes(monkeypatch, 4, num_prefill_tokens=1) + + assert strategy == "off" diff --git a/vllm/envs.py b/vllm/envs.py index 9841b2cba138..4341dcc18f99 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -258,11 +258,9 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024 VLLM_ROCM_DSV4_CSA_MULTISTREAM: bool = True - VLLM_ROCM_DSV4_CSA_MS_STRATEGY: Literal["off", "indexer_only", "sglang"] = ( - "sglang" + VLLM_ROCM_DSV4_CSA_MS_STRATEGY: Literal["off", "indexer_only", "overlap"] = ( + "overlap" ) - VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE: int = 1 - VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE: int = 64 VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES: set[Literal["none", "piecewise", "full"]] = { "none", "piecewise", @@ -270,9 +268,9 @@ VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR: bool = True VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS: bool = False VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST: bool = True - VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER: bool = True + VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER: bool = False VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS: bool = True - VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY: int = 0 + VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY: int = -1 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool | None = None VLLM_LOG_MODEL_INSPECTION: bool = False @@ -1895,26 +1893,20 @@ def _resolve_rust_frontend_path() -> str | None: os.getenv("VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD", "1024") ), # ROCm-only opt-in for DeepSeek-V4 CSA decode multi-stream overlap. - # The "sglang" strategy overlaps C4 indexer and compressor preparation - # branches with fine-grained event joins. The deferred projection and - # sub-branch knobs are separate A/B controls because ROCm graph capture and - # hipBLASLt ordering differ by PyTorch/ROCm release. "indexer_only" keeps - # the conservative single indexer branch path for A/B testing. + # The "overlap" strategy runs the KV cache insert, MLA compressor, and C4 + # indexer preparation branches with fine-grained event joins. The deferred + # projection and sub-branch knobs are separate A/B controls because ROCm + # graph capture and hipBLASLt ordering differ by PyTorch/ROCm release. + # "indexer_only" keeps the conservative single indexer branch path for A/B. "VLLM_ROCM_DSV4_CSA_MULTISTREAM": lambda: bool( int(os.getenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1")) ), "VLLM_ROCM_DSV4_CSA_MS_STRATEGY": env_with_choices( "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", - "sglang", - ["off", "indexer_only", "sglang"], + "overlap", + ["off", "indexer_only", "overlap"], case_sensitive=False, ), - "VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE": lambda: int( - os.getenv("VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE", "1") - ), - "VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE": lambda: int( - os.getenv("VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE", "64") - ), "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES": env_set_with_choices( "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES", ["none", "piecewise"], @@ -1931,13 +1923,13 @@ def _resolve_rust_frontend_path() -> str | None: int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST", "1")) ), "VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER", "1")) + int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER", "0")) ), "VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS": lambda: bool( int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS", "1")) ), "VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY": lambda: int( - os.getenv("VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY", "0") + os.getenv("VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY", "-1") ), # Format for saving torch.compile cache artifacts # - "binary": saves as binary file diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index bd9452c25d28..b3675eb126e4 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -109,26 +109,10 @@ def _is_decode_only_deepseek_v4_step( if not metadata: return False return all( - item.num_decode_tokens > 0 and item.num_prefill_tokens == 0 - for item in metadata + item.num_decode_tokens > 0 and item.num_prefill_tokens == 0 for item in metadata ) -def _deepseek_v4_num_decodes( - attn_metadata: dict[str, AttentionMetadata] - | list[dict[str, AttentionMetadata]] - | None, -) -> int: - count = 0 - for item in _iter_deepseek_v4_swa_metadata(attn_metadata): - # The attention metadata dict contains one entry per DeepSeek-V4 - # attention/cache layer. Each entry describes the same batch, so - # summing here turns conc=4 into conc=4*num_layers and incorrectly - # gates off low-workload decode. - count = max(count, item.num_decodes or item.num_decode_tokens) - return count - - def _select_v4_sparse_impl() -> "type[DeepseekV4SparseMLAAttentionImpl]": """Pick the platform-specific V4 sparse MLA impl class. Sole platform check.""" if current_platform.is_rocm(): @@ -422,10 +406,7 @@ def _rocm_csa_ms_strategy_for_step( ) -> str: if not current_platform.is_rocm(): return "off" - if ( - not envs.VLLM_ROCM_DSV4_CSA_MULTISTREAM - or self.aux_stream_list is None - ): + if not envs.VLLM_ROCM_DSV4_CSA_MULTISTREAM or self.aux_stream_list is None: return "off" strategy = envs.VLLM_ROCM_DSV4_CSA_MS_STRATEGY.lower() @@ -434,17 +415,6 @@ def _rocm_csa_ms_strategy_for_step( if not _is_decode_only_deepseek_v4_step(attn_metadata): return "off" - num_decodes = _deepseek_v4_num_decodes(attn_metadata) - min_decode = envs.VLLM_ROCM_DSV4_CSA_MS_MIN_DECODE - max_decode = envs.VLLM_ROCM_DSV4_CSA_MS_MAX_DECODE - if num_decodes < min_decode: - return "off" - # max_decode <= 0 is an explicit no-cap experiment knob. The default - # remains capped at 64 because unlimited decode improves the largest - # one-wave backlogs but regresses c128 TTFT substantially on MI355X. - if 0 < max_decode < num_decodes: - return "off" - graph_mode = forward_context.cudagraph_runtime_mode assert graph_mode in CUDAGraphMode.valid_runtime_modes() graph_mode_name = graph_mode.name.lower() @@ -460,7 +430,7 @@ def _rocm_csa_ms_strategy_for_step( if graph_mode_name not in enabled_graph_modes: return "off" - if strategy == "sglang" and len(self.aux_stream_list) < 5: + if strategy == "overlap" and len(self.aux_stream_list) < 5: return "off" if strategy == "indexer_only" and self.indexer is None: return "off" @@ -649,7 +619,7 @@ def fused_wqa_wkv() -> torch.Tensor: return qr_kv, kv_score, indexer_kv_score, indexer_weights - def _rocm_sglang_post_rmsnorm_prepare( + def _rocm_multistream_post_rmsnorm_prepare( self, hidden_states: torch.Tensor, qr: torch.Tensor, @@ -750,8 +720,7 @@ def run_indexer() -> Any: self.indexer is not None and envs.VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER ) launched_compressor = ( - self.compressor is not None - and envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR + self.compressor is not None and envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR ) if launched_indexer: @@ -791,7 +760,7 @@ def attention_impl( aux_streams = self._aux_streams_for_step(rocm_ms_strategy, attn_metadata) defer_rocm_branch_projections = ( current_platform.is_rocm() - and rocm_ms_strategy == "sglang" + and rocm_ms_strategy == "overlap" and envs.VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS ) gemm_aux_streams = None if current_platform.is_rocm() else aux_streams @@ -805,6 +774,13 @@ def attention_impl( ) qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) + use_rocm_split_post = ( + current_platform.is_rocm() + and rocm_ms_strategy == "overlap" + and envs.VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST + and aux_streams is not None + and len(aux_streams) >= 5 + ) qr, kv = fused_q_kv_rmsnorm( qr, kv, @@ -813,14 +789,8 @@ def attention_impl( self.eps, ) - if ( - current_platform.is_rocm() - and rocm_ms_strategy == "sglang" - and envs.VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST - and aux_streams is not None - and len(aux_streams) >= 5 - ): - q = self._rocm_sglang_post_rmsnorm_prepare( + if use_rocm_split_post: + q = self._rocm_multistream_post_rmsnorm_prepare( hidden_states, qr, kv, @@ -883,8 +853,8 @@ def _legacy_post_rmsnorm_prepare( and len(post_aux_streams) >= 5 ): # Legacy ROCm fallback keeps q+kv fused on default. Match the - # SGLang stream IDs where possible: aux[2] indexer, aux[1] - # compressor. + # five-stream layout where possible: aux[2] indexer, aux[1] + # compressor, aux[3:5] indexer sub-branches. outer_post_aux_streams = [post_aux_streams[2], post_aux_streams[1]] elif post_aux_streams is not None: outer_post_aux_streams = [post_aux_streams[0], post_aux_streams[1]] @@ -917,7 +887,7 @@ def wq_b_kv_insert() -> torch.Tensor: and ( not current_platform.is_rocm() or ( - rocm_ms_strategy == "sglang" + rocm_ms_strategy == "overlap" and envs.VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS ) ), @@ -1424,8 +1394,8 @@ def __init__( use_fp4_cache=self.use_fp4_kv, ) - # aux_stream is the legacy two-way split. aux_streams mirrors SGLang's - # C4 indexer sub-branches: [0] Q projection/quant, [1] weights proj. + # aux_stream is the legacy two-way split. aux_streams maps the C4 + # indexer sub-branches as [0] Q projection/quant, [1] weights proj. self.aux_stream = aux_stream self.aux_streams = aux_streams self.ln_events: list[torch.cuda.Event] = [ diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 60f3ac8fc147..05f9a7e24981 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -73,14 +73,12 @@ def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: or envs.VLLM_ROCM_DSV4_CSA_MS_STRATEGY.lower() == "off" ): return None - # SGLang creates five streams for DeepSeek-V4: three top-level - # preparation branches and two C4-indexer sub-branches: + # ROCm uses five streams for DeepSeek-V4 decode overlap: three + # top-level preparation branches and two C4-indexer sub-branches: # [0] main KV cache insert, [1] main compressor, [2] C4 indexer, # [3] indexer Q branch, [4] indexer weights branch. return [ - torch.cuda.Stream( - priority=envs.VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY - ) + torch.cuda.Stream(priority=envs.VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY) for _ in range(5) ] if current_platform.is_xpu(): @@ -804,9 +802,9 @@ def __init__( if self.compress_ratio == 4: # Only C4A uses sparse attention and hence has indexer. # NVIDIA uses aux_stream_list[2] for the legacy inner overlap. - # ROCm SGLang-style decode uses aux_stream_list[3:5] for the C4 - # indexer q/weights sub-branches while the outer indexer branch - # runs on aux_stream_list[2]. + # ROCm decode overlap uses aux_stream_list[3:5] for the C4 indexer + # q/weights sub-branches while the outer indexer branch runs on + # aux_stream_list[2]. if ( current_platform.is_rocm() and aux_stream_list is not None From 414acf65751ec35ad3c32c78a2efaef79bc12696 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Tue, 26 May 2026 18:59:25 +0000 Subject: [PATCH 08/21] remove rocm-only-related test files Signed-off-by: vLLM Contributor --- benchmarks/kernels/rocm_dsv4_stream_probe.py | 918 ------------------- docs/dev/rocm_multistream_graph_repro.md | 54 -- 2 files changed, 972 deletions(-) delete mode 100644 benchmarks/kernels/rocm_dsv4_stream_probe.py delete mode 100644 docs/dev/rocm_multistream_graph_repro.md diff --git a/benchmarks/kernels/rocm_dsv4_stream_probe.py b/benchmarks/kernels/rocm_dsv4_stream_probe.py deleted file mode 100644 index 8d2d078fc67a..000000000000 --- a/benchmarks/kernels/rocm_dsv4_stream_probe.py +++ /dev/null @@ -1,918 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Probe ROCm stream ownership/overlap for DeepSeek-V4 CSA-like kernels. - -This is intentionally standalone: it does not instantiate the model or require -serving metadata. It answers whether representative kernels used around DSV4 -CSA decode honor the current HIP stream and whether independent streams overlap -when forced outside vLLM's graph/runtime path. - -Repro commands used on ROCm: - - # Control: stream scheduling outside torch.compile, then graph replay. - HIP_VISIBLE_DEVICES=0 VLLM_ROCM_USE_AITER=1 \ - PYTHONPATH=/path/to/vllm \ - rocprofv3 --runtime-trace --group-by-queue \ - --output-directory /tmp/vllm_rocm_dsv4_rocprof_graph \ - --output-file graph --output-format json csv -- \ - .venv/bin/python benchmarks/kernels/rocm_dsv4_stream_probe.py \ - --scenario aiter_vs_bf16_mm_decode \ - --repeats 5 --profile-repeats 0 --warmup 1 --mode graph - - # Repro: stream scheduling inside torch.compile, then graph replay. - HIP_VISIBLE_DEVICES=0 VLLM_ROCM_USE_AITER=1 \ - PYTHONPATH=/path/to/vllm \ - rocprofv3 --runtime-trace --group-by-queue \ - --output-directory /tmp/vllm_rocm_dsv4_rocprof_compile_pair_graph \ - --output-file compile_pair_graph --output-format json csv -- \ - .venv/bin/python benchmarks/kernels/rocm_dsv4_stream_probe.py \ - --scenario aiter_vs_bf16_mm_decode \ - --repeats 5 --profile-repeats 0 --warmup 1 --mode compile_pair_graph - -Finding from the repro: - - graph mode preserves separate ROCm queues for the representative AITER and - BF16 GEMM branches, but the tested decode-sized kernels did not produce - useful timestamp overlap, which points to kernel resource contention rather - than Python stream ownership as the first bottleneck; - - compile_pair_graph collapses the same representative kernels onto ROCm - stream 0 / queue 1 during graph replay, exposing the non-overlap failure - mode seen in vLLM-like compiled scheduling. -""" - -import argparse -import json -import os -from collections import Counter -from collections.abc import Callable -from pathlib import Path -from typing import Any - -import torch - -# Registers vLLM ROCm AITER custom ops. -import vllm._aiter_ops # noqa: F401 -from vllm._aiter_ops import rocm_aiter_ops - -KernelFn = Callable[[], Any] - - -def _fp8_dtype() -> torch.dtype: - return getattr(torch, "float8_e4m3fnuz", torch.int8) - - -def _make_bf16_mm( - m: int, - n: int, - k: int, - device: torch.device, -) -> KernelFn: - a = torch.randn((m, k), device=device, dtype=torch.bfloat16) - b = torch.randn((k, n), device=device, dtype=torch.bfloat16) - out = torch.empty((m, n), device=device, dtype=torch.bfloat16) - - def run() -> torch.Tensor: - torch.mm(a, b, out=out) - return out - - return run - - -def _make_aiter_fp8_block_gemm( - m: int, - n: int, - k: int, - device: torch.device, -) -> KernelFn: - dtype = _fp8_dtype() - a = torch.empty((m, k), device=device, dtype=dtype) - b = torch.empty((n, k), device=device, dtype=dtype) - a_scales = torch.ones((m, (k + 127) // 128), device=device, dtype=torch.float32) - b_scales = torch.ones((n, (k + 127) // 128), device=device, dtype=torch.float32) - - def run() -> torch.Tensor: - out = rocm_aiter_ops.gemm_a8w8_blockscale( - a, - b, - a_scales, - b_scales, - [1, 128], - output_dtype=torch.bfloat16, - ) - return out - - return run - - -def _make_topk( - rows: int, - cols: int, - topk: int, - device: torch.device, -) -> KernelFn: - x = torch.randn((rows, cols), device=device, dtype=torch.float32) - values = torch.empty((rows, topk), device=device, dtype=x.dtype) - indices = torch.empty((rows, topk), device=device, dtype=torch.long) - - def run() -> torch.Tensor: - torch.topk(x, topk, dim=-1, out=(values, indices)) - return values - - return run - - -def _event_time_ms( - fn0: KernelFn, - fn1: KernelFn, - concurrent: bool, - iterations: int, -) -> float: - torch.cuda.synchronize() - stream0 = torch.cuda.Stream() - stream1 = torch.cuda.Stream() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - current = torch.cuda.current_stream() - - start.record(current) - for _ in range(iterations): - _run_pair(fn0, fn1, concurrent, stream0, stream1) - end.record(current) - end.synchronize() - return start.elapsed_time(end) - - -def _run_pair( - fn0: KernelFn, - fn1: KernelFn, - concurrent: bool, - stream0: torch.cuda.Stream, - stream1: torch.cuda.Stream, -) -> tuple[Any, Any]: - current = torch.cuda.current_stream() - if concurrent: - stream0.wait_stream(current) - stream1.wait_stream(current) - with torch.cuda.stream(stream0): - result0 = fn0() - with torch.cuda.stream(stream1): - result1 = fn1() - current.wait_stream(stream0) - current.wait_stream(stream1) - return result0, result1 - else: - result0 = fn0() - result1 = fn1() - return result0, result1 - - -def _capture_pair_graph( - fn0: KernelFn, - fn1: KernelFn, - concurrent: bool, -) -> tuple[torch.cuda.CUDAGraph, tuple[torch.cuda.Stream, ...]]: - torch.cuda.synchronize() - # Run once outside capture so the caching allocator and any JIT/autotune - # setup do not become capture-time side effects. - warm_stream0 = torch.cuda.Stream() - warm_stream1 = torch.cuda.Stream() - _run_pair(fn0, fn1, concurrent, warm_stream0, warm_stream1) - torch.cuda.synchronize() - - graph = torch.cuda.CUDAGraph() - capture_stream = torch.cuda.Stream() - stream0 = torch.cuda.Stream() - stream1 = torch.cuda.Stream() - current = torch.cuda.current_stream() - capture_stream.wait_stream(current) - with torch.cuda.stream(capture_stream), torch.cuda.graph(graph): - _run_pair(fn0, fn1, concurrent, stream0, stream1) - current.wait_stream(capture_stream) - torch.cuda.synchronize() - return graph, (capture_stream, stream0, stream1) - - -def _graph_time_ms( - fn0: KernelFn, - fn1: KernelFn, - concurrent: bool, - iterations: int, -) -> float: - graph, _streams = _capture_pair_graph(fn0, fn1, concurrent) - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - current = torch.cuda.current_stream() - start.record(current) - for _ in range(iterations): - graph.replay() - end.record(current) - end.synchronize() - return start.elapsed_time(end) - - -def _capture_runner_graph( - runner: KernelFn, -) -> tuple[torch.cuda.CUDAGraph, torch.cuda.Stream]: - torch.cuda.synchronize() - runner() - torch.cuda.synchronize() - - graph = torch.cuda.CUDAGraph() - capture_stream = torch.cuda.Stream() - current = torch.cuda.current_stream() - capture_stream.wait_stream(current) - with torch.cuda.stream(capture_stream), torch.cuda.graph(graph): - runner() - current.wait_stream(capture_stream) - torch.cuda.synchronize() - return graph, capture_stream - - -def _graph_runner_time_ms(runner: KernelFn, iterations: int) -> float: - graph, _capture_stream = _capture_runner_graph(runner) - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - current = torch.cuda.current_stream() - start.record(current) - for _ in range(iterations): - graph.replay() - end.record(current) - end.synchronize() - return start.elapsed_time(end) - - -def _compile_kernel(fn: KernelFn) -> KernelFn: - compiled = torch.compile( - fn, - backend="inductor", - dynamic=False, - fullgraph=False, - ) - compiled() - torch.cuda.synchronize() - return compiled - - -def _call_time_ms(fn: KernelFn, iterations: int) -> float: - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - current = torch.cuda.current_stream() - start.record(current) - for _ in range(iterations): - fn() - end.record(current) - end.synchronize() - return start.elapsed_time(end) - - -def _compile_pair_runner( - fn0: KernelFn, - fn1: KernelFn, - concurrent: bool, - disable_scheduler_compile: bool = False, -) -> KernelFn: - stream0 = torch.cuda.Stream() - stream1 = torch.cuda.Stream() - - def pair_run() -> tuple[Any, Any]: - return _run_pair(fn0, fn1, concurrent, stream0, stream1) - - if disable_scheduler_compile: - pair_run = torch.compiler.disable(pair_run) - return _compile_kernel(pair_run) - - -def _profile_runner_trace( - runner: KernelFn, - profile_dir: Path, - scenario: str, - mode: str, - iterations: int, -) -> Path: - profile_dir.mkdir(parents=True, exist_ok=True) - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=False, - with_stack=False, - profile_memory=False, - ) as prof: - for _ in range(iterations): - runner() - torch.cuda.synchronize() - - trace = profile_dir / f"{scenario}.{mode}.trace.json" - prof.export_chrome_trace(str(trace)) - return trace - - -def _profile_runner_graph_trace( - runner: KernelFn, - profile_dir: Path, - scenario: str, - mode: str, - iterations: int, -) -> Path: - profile_dir.mkdir(parents=True, exist_ok=True) - graph, _capture_stream = _capture_runner_graph(runner) - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=False, - with_stack=False, - profile_memory=False, - ) as prof: - for _ in range(iterations): - graph.replay() - torch.cuda.synchronize() - - trace = profile_dir / f"{scenario}.{mode}.trace.json" - prof.export_chrome_trace(str(trace)) - return trace - - -def _profile_trace( - fn0: KernelFn, - fn1: KernelFn, - profile_dir: Path, - scenario: str, - mode: str, - use_graph: bool, - iterations: int, -) -> Path: - profile_dir.mkdir(parents=True, exist_ok=True) - stream0 = torch.cuda.Stream() - stream1 = torch.cuda.Stream() - current = torch.cuda.current_stream() - graph = None - if use_graph: - graph, _streams = _capture_pair_graph(fn0, fn1, concurrent=True) - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=False, - with_stack=False, - profile_memory=False, - ) as prof: - if graph is not None: - for _ in range(iterations): - graph.replay() - else: - for _ in range(iterations): - _run_pair(fn0, fn1, True, stream0, stream1) - current.wait_stream(stream0) - current.wait_stream(stream1) - torch.cuda.synchronize() - - trace = profile_dir / f"{scenario}.{mode}.trace.json" - prof.export_chrome_trace(str(trace)) - return trace - - -def _summarize_trace(trace: Path) -> dict[str, object]: - with trace.open() as f: - data = json.load(f) - stream_counts: Counter[str] = Counter() - stream_kernel_counts: Counter[str] = Counter() - stream_durations_us: Counter[str] = Counter() - event_counts: Counter[str] = Counter() - kernel_names: Counter[tuple[str, str]] = Counter() - - for event in data.get("traceEvents", []): - name = event.get("name", "") - if name in ("hipEventRecord", "hipStreamWaitEvent"): - event_counts[name] += 1 - args = event.get("args") or {} - stream = args.get("stream") - if stream is None: - continue - stream = str(stream) - stream_counts[stream] += 1 - if event.get("cat") in ("kernel", "gpu_memcpy"): - stream_kernel_counts[stream] += 1 - stream_durations_us[stream] += event.get("dur", 0.0) or 0.0 - kernel_names[(stream, name[:96])] += 1 - - return { - "streams": dict(stream_counts), - "stream_kernel_counts": dict(stream_kernel_counts), - "stream_durations_us": dict(stream_durations_us), - "event_counts": dict(event_counts), - "top_kernel_names": [ - {"stream": stream, "name": name, "count": count} - for (stream, name), count in kernel_names.most_common(12) - ], - } - - -def _scenario( - name: str, - device: torch.device, -) -> tuple[KernelFn, KernelFn, KernelFn, KernelFn]: - if name == "bf16_mm_pair": - return ( - _make_bf16_mm(256, 2048, 7168, device), - _make_bf16_mm(256, 768, 7168, device), - _make_bf16_mm(256, 2048, 7168, device), - _make_bf16_mm(256, 768, 7168, device), - ) - if name == "aiter_vs_bf16_mm_decode": - return ( - _make_aiter_fp8_block_gemm(4, 2048, 7168, device), - _make_bf16_mm(4, 768, 7168, device), - _make_aiter_fp8_block_gemm(4, 2048, 7168, device), - _make_bf16_mm(4, 768, 7168, device), - ) - if name == "aiter_pair_decode": - return ( - _make_aiter_fp8_block_gemm(4, 2048, 7168, device), - _make_aiter_fp8_block_gemm(4, 768, 7168, device), - _make_aiter_fp8_block_gemm(4, 2048, 7168, device), - _make_aiter_fp8_block_gemm(4, 768, 7168, device), - ) - if name == "topk_vs_bf16_mm": - return ( - _make_topk(64, 8192, 128, device), - _make_bf16_mm(64, 768, 7168, device), - _make_topk(64, 8192, 128, device), - _make_bf16_mm(64, 768, 7168, device), - ) - raise ValueError(f"unknown scenario: {name}") - - -def _mode_enabled(selected: str, mode: str) -> bool: - if selected == "all": - return True - if selected == "both": - return mode in ("eager", "graph") - return selected == mode - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--scenario", - action="append", - choices=[ - "bf16_mm_pair", - "aiter_vs_bf16_mm_decode", - "aiter_pair_decode", - "topk_vs_bf16_mm", - ], - ) - parser.add_argument("--repeats", type=int, default=200) - parser.add_argument("--profile-repeats", type=int, default=3) - parser.add_argument("--warmup", type=int, default=5) - parser.add_argument( - "--profile-compile-pair", - action="store_true", - help=( - "Also profile the torch.compile wrapper that includes Python " - "stream scheduling. ROCTracer may be unstable for this mode on " - "some ROCm builds, so timing is collected by default without a " - "trace." - ), - ) - parser.add_argument( - "--profile-compile", - action="store_true", - help=( - "Profile compiled branch modes when running --mode all. Explicit " - "--mode compile and --mode compile_graph runs are profiled by " - "default when --profile-repeats is positive." - ), - ) - parser.add_argument( - "--profile-aggregate", - action="store_true", - help=( - "Profile aggregate modes such as --mode both or --mode all. On " - "some ROCm stacks, multiple profiler sessions in one process can " - "crash during ROCTracer cleanup, so aggregate modes collect timing " - "only unless this is set." - ), - ) - parser.add_argument( - "--mode", - choices=[ - "eager", - "graph", - "compile", - "compile_graph", - "compile_pair", - "compile_pair_graph", - "disabled_compile_pair_graph", - "both", - "all", - ], - default="both", - help=( - "Measure eager forced streams, graph replay, torch.compile branch " - "functions, compiled branches under graph replay, a compiled whole " - "pair scheduler, the compiled scheduler under graph replay, a " - "torch.compiler.disable-protected scheduler under graph replay, or " - "all modes." - ), - ) - parser.add_argument( - "--profile-dir", - type=Path, - default=Path("/tmp/vllm_rocm_dsv4_stream_probe"), - ) - args = parser.parse_args() - - scenarios = args.scenario or [ - "bf16_mm_pair", - "aiter_vs_bf16_mm_decode", - "aiter_pair_decode", - "topk_vs_bf16_mm", - ] - device = torch.device("cuda") - print(f"pid={os.getpid()} device={torch.cuda.get_device_name(0)!r}") - - for name in scenarios: - fn0, fn1, profile_fn0, profile_fn1 = _scenario(name, device) - for _ in range(args.warmup): - _event_time_ms(fn0, fn1, concurrent=True, iterations=1) - seq_ms = _event_time_ms( - fn0, fn1, concurrent=False, iterations=args.repeats - ) - conc_ms = _event_time_ms( - fn0, fn1, concurrent=True, iterations=args.repeats - ) - overlap_pct = 100.0 * (1.0 - conc_ms / seq_ms) if seq_ms else 0.0 - results: dict[str, object] = {"scenario": name} - if _mode_enabled(args.mode, "eager"): - should_profile_eager = args.profile_repeats > 0 and ( - args.mode == "eager" or args.profile_aggregate - ) - eager_result = { - "sequential_ms": seq_ms, - "concurrent_ms": conc_ms, - "overlap_pct": overlap_pct, - } - if should_profile_eager: - trace = _profile_trace( - profile_fn0, - profile_fn1, - args.profile_dir, - name, - "eager", - use_graph=False, - iterations=args.profile_repeats, - ) - eager_result.update({ - "trace": str(trace), - **_summarize_trace(trace), - }) - elif args.profile_repeats > 0: - eager_result["profile_note"] = ( - "eager trace skipped in aggregate mode; use --mode eager " - "or pass --profile-aggregate" - ) - results["eager"] = eager_result - if _mode_enabled(args.mode, "graph"): - try: - should_profile_graph = args.profile_repeats > 0 and ( - args.mode == "graph" or args.profile_aggregate - ) - graph_seq_ms = _graph_time_ms( - fn0, - fn1, - concurrent=False, - iterations=args.repeats, - ) - graph_conc_ms = _graph_time_ms( - fn0, - fn1, - concurrent=True, - iterations=args.repeats, - ) - graph_overlap_pct = ( - 100.0 * (1.0 - graph_conc_ms / graph_seq_ms) - if graph_seq_ms - else 0.0 - ) - graph_result = { - "sequential_ms": graph_seq_ms, - "concurrent_ms": graph_conc_ms, - "overlap_pct": graph_overlap_pct, - } - if should_profile_graph: - graph_trace = _profile_trace( - profile_fn0, - profile_fn1, - args.profile_dir, - name, - "graph", - use_graph=True, - iterations=args.profile_repeats, - ) - graph_result.update({ - "trace": str(graph_trace), - **_summarize_trace(graph_trace), - }) - elif args.profile_repeats > 0: - graph_result["profile_note"] = ( - "graph trace skipped in aggregate mode; use --mode " - "graph or pass --profile-aggregate" - ) - results["graph"] = graph_result - except Exception as exc: - results["graph"] = { - "error": f"{type(exc).__name__}: {exc}", - } - if _mode_enabled(args.mode, "compile"): - try: - should_profile_compile = args.profile_repeats > 0 and ( - args.mode == "compile" - or (args.profile_aggregate and args.profile_compile) - ) - compiled_fn0 = _compile_kernel(fn0) - compiled_fn1 = _compile_kernel(fn1) - compiled_profile_fn0 = _compile_kernel(profile_fn0) - compiled_profile_fn1 = _compile_kernel(profile_fn1) - compile_seq_ms = _event_time_ms( - compiled_fn0, - compiled_fn1, - concurrent=False, - iterations=args.repeats, - ) - compile_conc_ms = _event_time_ms( - compiled_fn0, - compiled_fn1, - concurrent=True, - iterations=args.repeats, - ) - compile_overlap_pct = ( - 100.0 * (1.0 - compile_conc_ms / compile_seq_ms) - if compile_seq_ms - else 0.0 - ) - compile_result = { - "sequential_ms": compile_seq_ms, - "concurrent_ms": compile_conc_ms, - "overlap_pct": compile_overlap_pct, - } - if should_profile_compile: - compile_trace = _profile_trace( - compiled_profile_fn0, - compiled_profile_fn1, - args.profile_dir, - name, - "compile", - use_graph=False, - iterations=args.profile_repeats, - ) - compile_result.update({ - "trace": str(compile_trace), - **_summarize_trace(compile_trace), - }) - else: - compile_result["profile_note"] = ( - "compiled branch trace skipped in aggregate mode; " - "use --mode compile or pass --profile-compile" - ) - results["compile"] = compile_result - except Exception as exc: - results["compile"] = { - "error": f"{type(exc).__name__}: {exc}", - } - if _mode_enabled(args.mode, "compile_graph"): - try: - should_profile_compile_graph = args.profile_repeats > 0 and ( - args.mode == "compile_graph" - or (args.profile_aggregate and args.profile_compile) - ) - compiled_fn0 = _compile_kernel(fn0) - compiled_fn1 = _compile_kernel(fn1) - compiled_profile_fn0 = _compile_kernel(profile_fn0) - compiled_profile_fn1 = _compile_kernel(profile_fn1) - compile_graph_seq_ms = _graph_time_ms( - compiled_fn0, - compiled_fn1, - concurrent=False, - iterations=args.repeats, - ) - compile_graph_conc_ms = _graph_time_ms( - compiled_fn0, - compiled_fn1, - concurrent=True, - iterations=args.repeats, - ) - compile_graph_overlap_pct = ( - 100.0 * (1.0 - compile_graph_conc_ms / compile_graph_seq_ms) - if compile_graph_seq_ms - else 0.0 - ) - compile_graph_result = { - "sequential_ms": compile_graph_seq_ms, - "concurrent_ms": compile_graph_conc_ms, - "overlap_pct": compile_graph_overlap_pct, - } - if should_profile_compile_graph: - compile_graph_trace = _profile_trace( - compiled_profile_fn0, - compiled_profile_fn1, - args.profile_dir, - name, - "compile_graph", - use_graph=True, - iterations=args.profile_repeats, - ) - compile_graph_result.update({ - "trace": str(compile_graph_trace), - **_summarize_trace(compile_graph_trace), - }) - else: - compile_graph_result["profile_note"] = ( - "compiled graph trace skipped in aggregate mode; " - "use --mode compile_graph or pass --profile-compile" - ) - results["compile_graph"] = compile_graph_result - except Exception as exc: - results["compile_graph"] = { - "error": f"{type(exc).__name__}: {exc}", - } - if _mode_enabled(args.mode, "compile_pair"): - try: - compiled_seq_pair = _compile_pair_runner(fn0, fn1, concurrent=False) - compiled_conc_pair = _compile_pair_runner(fn0, fn1, concurrent=True) - compile_pair_seq_ms = _call_time_ms( - compiled_seq_pair, iterations=args.repeats - ) - compile_pair_conc_ms = _call_time_ms( - compiled_conc_pair, iterations=args.repeats - ) - compile_pair_overlap_pct = ( - 100.0 * (1.0 - compile_pair_conc_ms / compile_pair_seq_ms) - if compile_pair_seq_ms - else 0.0 - ) - compile_pair_result = { - "sequential_ms": compile_pair_seq_ms, - "concurrent_ms": compile_pair_conc_ms, - "overlap_pct": compile_pair_overlap_pct, - } - if args.profile_compile_pair and args.profile_repeats > 0: - compiled_profile_pair = _compile_pair_runner( - profile_fn0, - profile_fn1, - concurrent=True, - ) - compile_pair_trace = _profile_runner_trace( - compiled_profile_pair, - args.profile_dir, - name, - "compile_pair", - iterations=args.profile_repeats, - ) - compile_pair_result.update({ - "trace": str(compile_pair_trace), - **_summarize_trace(compile_pair_trace), - }) - else: - compile_pair_result["profile_note"] = ( - "compile_pair trace skipped; pass --profile-compile-pair " - "to profile the compiled Python stream scheduler" - ) - results["compile_pair"] = compile_pair_result - except Exception as exc: - results["compile_pair"] = { - "error": f"{type(exc).__name__}: {exc}", - } - if _mode_enabled(args.mode, "compile_pair_graph"): - try: - compiled_seq_pair = _compile_pair_runner(fn0, fn1, concurrent=False) - compiled_conc_pair = _compile_pair_runner(fn0, fn1, concurrent=True) - compile_pair_graph_seq_ms = _graph_runner_time_ms( - compiled_seq_pair, iterations=args.repeats - ) - compile_pair_graph_conc_ms = _graph_runner_time_ms( - compiled_conc_pair, iterations=args.repeats - ) - compile_pair_graph_overlap_pct = ( - 100.0 - * ( - 1.0 - - compile_pair_graph_conc_ms / compile_pair_graph_seq_ms - ) - if compile_pair_graph_seq_ms - else 0.0 - ) - compile_pair_graph_result = { - "sequential_ms": compile_pair_graph_seq_ms, - "concurrent_ms": compile_pair_graph_conc_ms, - "overlap_pct": compile_pair_graph_overlap_pct, - } - if args.profile_repeats > 0 and ( - args.mode == "compile_pair_graph" or args.profile_aggregate - ): - compiled_profile_pair = _compile_pair_runner( - profile_fn0, - profile_fn1, - concurrent=True, - ) - compile_pair_graph_trace = _profile_runner_graph_trace( - compiled_profile_pair, - args.profile_dir, - name, - "compile_pair_graph", - iterations=args.profile_repeats, - ) - compile_pair_graph_result.update({ - "trace": str(compile_pair_graph_trace), - **_summarize_trace(compile_pair_graph_trace), - }) - elif args.profile_repeats > 0: - compile_pair_graph_result["profile_note"] = ( - "compile_pair_graph trace skipped in aggregate mode; " - "use --mode compile_pair_graph or pass " - "--profile-aggregate" - ) - results["compile_pair_graph"] = compile_pair_graph_result - except Exception as exc: - results["compile_pair_graph"] = { - "error": f"{type(exc).__name__}: {exc}", - } - if _mode_enabled(args.mode, "disabled_compile_pair_graph"): - try: - disabled_seq_pair = _compile_pair_runner( - fn0, - fn1, - concurrent=False, - disable_scheduler_compile=True, - ) - disabled_conc_pair = _compile_pair_runner( - fn0, - fn1, - concurrent=True, - disable_scheduler_compile=True, - ) - disabled_pair_graph_seq_ms = _graph_runner_time_ms( - disabled_seq_pair, iterations=args.repeats - ) - disabled_pair_graph_conc_ms = _graph_runner_time_ms( - disabled_conc_pair, iterations=args.repeats - ) - disabled_pair_graph_overlap_pct = ( - 100.0 - * ( - 1.0 - - disabled_pair_graph_conc_ms / disabled_pair_graph_seq_ms - ) - if disabled_pair_graph_seq_ms - else 0.0 - ) - disabled_pair_graph_result = { - "sequential_ms": disabled_pair_graph_seq_ms, - "concurrent_ms": disabled_pair_graph_conc_ms, - "overlap_pct": disabled_pair_graph_overlap_pct, - } - if args.profile_repeats > 0 and ( - args.mode == "disabled_compile_pair_graph" - or args.profile_aggregate - ): - disabled_profile_pair = _compile_pair_runner( - profile_fn0, - profile_fn1, - concurrent=True, - disable_scheduler_compile=True, - ) - disabled_pair_graph_trace = _profile_runner_graph_trace( - disabled_profile_pair, - args.profile_dir, - name, - "disabled_compile_pair_graph", - iterations=args.profile_repeats, - ) - disabled_pair_graph_result.update({ - "trace": str(disabled_pair_graph_trace), - **_summarize_trace(disabled_pair_graph_trace), - }) - elif args.profile_repeats > 0: - disabled_pair_graph_result["profile_note"] = ( - "disabled_compile_pair_graph trace skipped in aggregate " - "mode; use --mode disabled_compile_pair_graph or pass " - "--profile-aggregate" - ) - results["disabled_compile_pair_graph"] = disabled_pair_graph_result - except Exception as exc: - results["disabled_compile_pair_graph"] = { - "error": f"{type(exc).__name__}: {exc}", - } - - print(json.dumps(results, indent=2, sort_keys=True)) - - -if __name__ == "__main__": - main() diff --git a/docs/dev/rocm_multistream_graph_repro.md b/docs/dev/rocm_multistream_graph_repro.md deleted file mode 100644 index bde457707a99..000000000000 --- a/docs/dev/rocm_multistream_graph_repro.md +++ /dev/null @@ -1,54 +0,0 @@ -# ROCm Multi-Stream Graph Replay Reproducer - -`tools/rocm_multistream_graph_repro.py` isolates the ROCm graph replay hang -found while enabling DeepSeek-V4 CSA decode multi-stream on AMD GPUs. - -## Reproduce the hang - -```bash -HIP_VISIBLE_DEVICES=0 timeout 90s \ - .venv/bin/python tools/rocm_multistream_graph_repro.py --mode allocating -``` - -Observed on MI355X: warmup and graph capture complete, then the first -`CUDAGraph.replay()` does not return. `rocm-smi` shows the GPU at 100% busy with -0% memory bandwidth. - -## Verify the workaround - -```bash -HIP_VISIBLE_DEVICES=0 timeout 90s \ - .venv/bin/python tools/rocm_multistream_graph_repro.py --mode preallocated -``` - -The preallocated mode creates side-stream GEMM output buffers before capture and -uses `torch.mm(..., out=...)`. This replays successfully on the same system. - -`torch.cuda.Event(external=True)` is not available as a ROCm workaround; it -raises `RuntimeError: External events are disallowed in rocm`. - -## DeepSeek-V4 implication - -The vLLM ROCm path must not capture side-stream work that allocates new tensors -inside the graph. The current branch uses explicit stream/event ordering and -preallocated `out=` buffers for the overlapped CSA GEMM. SGLang's DeepSeek-V4 -ROCm implementation goes further: it gates multi-stream to graph-capture decode, -uses pre-created streams, fuses KV cache writes into the K path, and runs the -indexer/compressor through fused kernels with prebuilt metadata. That design -avoids the side-stream allocation pattern reproduced here. - -## Benchmark notes - -Benchmark: DeepSeek-V4-Pro, TP=8, `random` 1k input / 1k output, -`--max-concurrency 4`, 40 prompts, 8 warmups. - -| Path | Output tok/s | Total tok/s | Mean TPOT | -|---|---:|---:|---:| -| ROCm graph-safe workaround, aux disabled | 60.14 | 120.86 | 63.90 ms | -| One preallocated aux CSA GEMM | 57.89 | 116.34 | 66.51 ms | -| One aux CSA GEMM, threshold 16 | 57.42 | 115.39 | 66.88 ms | - -The narrowed vLLM overlap path avoids the hang, but it does not improve this -low-workload benchmark. The likely missing piece versus SGLang is a deeper -fused implementation that removes intermediate tensors and cache-write -allocation from the side-stream path. From 34e9909bd33aae021d727a3fa03610a733aa8c4a Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Tue, 26 May 2026 22:57:30 +0000 Subject: [PATCH 09/21] Fix ROCm DSV4 CSA multistream correctness Keep ROCm CSA multistream branch suppression active only when the ROCm multistream scheduler is actually active. The previous gating let ROCm CSA_MS env flags mute indexer/compressor branches even when aux streams were absent, for example MS=0, prefill/mixed steps, or unsupported graph runtime modes. That could leave stale branch state and was the source of the GSM8K accuracy failure. Also add defensive bounds masking in the ROCm AITER MLA sparse helpers so gather/pack/prefill kernels do not form invalid cache or dense-prefix addresses for padded/out-of-range slots. Current code changes are ROCm-scoped. The NVIDIA path is not intended to change; the ROCm env-flag suppression now requires current_platform.is_rocm(), non-None aux streams, and strategy != off. The temporary environment-only gpt_oss_triton_kernels_moe.py import workaround is intentionally not included. Correctness and local import proof: - vllm.__file__=/shared/amdgpu/home/fai_qle/vllm/vllm/__init__.py. - Full GSM8K 1319q local-chat-completions run after the active-gating fix completed with strict-match 0.9613 and flexible-extract 0.9606. - Final diff sanity after restoring upstream ragged prefill: GSM8K limit=64, including known-bad docs 4,13,31,41, completed normally with strict-match 0.9844 and flexible-extract 0.9844. - py_compile attention.py and rocm_aiter_mla_sparse.py: passed. - git diff --check: passed. Benchmark baseline: official InferenceX result only. The local MS=0 run is a diagnostic isolation check and is not used as the baseline or headline comparison. Aligned InferenceX legacy 1k/1k c4 settings: TP=8, fp8 KV, async scheduling, no prefix cache, FULL_AND_PIECEWISE, AITER=1, random_range_ratio=0.8, 40 prompts, 8 warmups. - Official InferenceX baseline: output 76.57 tok/s, mean TTFT 583.40 ms, mean TPOT 49.94 ms, mean ITL 49.95 ms. - Current code with VLLM_ROCM_DSV4_CSA_MULTISTREAM=1: output 78.12 tok/s, mean TTFT 331.62 ms, mean TPOT 49.22 ms, mean ITL 49.22 ms. - Delta versus official InferenceX baseline: output +2.02%, TPOT -1.44%, TTFT -43.16%. Diagnostic only, not the baseline: a same-machine VLLM_ROCM_DSV4_CSA_MULTISTREAM=0 run produced output 77.10 tok/s, mean TTFT 417.01 ms, mean TPOT 49.78 ms, mean ITL 49.79 ms. It was run only to isolate local multistream behavior. The earlier high-win full-suite table was measured before the GSM8K correctness issue was isolated, so it is not used as the corrected PR claim. The corrected result is close to the original cap64 commit-message story: minor TPOT/output-throughput gain versus InferenceX, with the clearest benefit in TTFT. Potential follow-up overlap work: - Revisit a SGLang-like branch projection schedule under ROCm graph capture, but only with branch outputs preallocated and with explicit tests proving no skipped indexer/compressor work in non-active steps. - Profile whether deferred branch projections can be captured safely in piecewise graphs without collapsing side-stream work to stream 0. Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- vllm/models/deepseek_v4/attention.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index b3675eb126e4..51a4cfce7776 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -847,9 +847,14 @@ def _legacy_post_rmsnorm_prepare( # downstream reads q on default). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. post_aux_streams = aux_streams - if ( + rocm_ms_active = ( current_platform.is_rocm() and post_aux_streams is not None + and rocm_ms_strategy != "off" + ) + if ( + rocm_ms_active + and post_aux_streams is not None and len(post_aux_streams) >= 5 ): # Legacy ROCm fallback keeps q+kv fused on default. Match the @@ -904,15 +909,12 @@ def run_compressor() -> Any: indexer_fn: Callable[[], Any] | None = run_indexer compressor_fn: Callable[[], Any] | None = run_compressor - if current_platform.is_rocm() and ( + if rocm_ms_active and ( rocm_ms_strategy == "indexer_only" or not envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR ): compressor_fn = None - if ( - current_platform.is_rocm() - and not envs.VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER - ): + if rocm_ms_active and not envs.VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER: indexer_fn = None q, (indexer_result, compressor_result) = execute_in_parallel( @@ -923,7 +925,7 @@ def run_compressor() -> Any: outer_post_aux_streams, enable=post_aux_streams is not None, ) - if current_platform.is_rocm() and post_aux_streams is not None: + if rocm_ms_active: if indexer_result is None and indexer_fn is None: run_indexer() if compressor_result is None and compressor_fn is None: @@ -947,7 +949,7 @@ def run_compressor() -> Any: else None ) if ( - current_platform.is_rocm() + rocm_ms_active and not envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR ): aux_stream = None From 39b8510068d7db813e541e3ab0fb35e9edb6790a Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Tue, 26 May 2026 23:21:04 +0000 Subject: [PATCH 10/21] Resolve ROCm DSV4 multistream rebase drift Align the rebased CSA multistream patch with the current upstream DeepSeek-V4 layout. - keep the upstream returned-q fused qnorm/rope/KV op schema while adding the split q and KV helper kernels - dispatch q-only helper kernels through the upstream padded-head template - update multistream tests for the current attention and stream-factory module locations No changes are made to gpt_oss_triton_kernels_moe.py. Signed-off-by: vLLM Contributor --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 48 +++++++++++++++---- ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 6 +-- .../test_deepseek_v4_rocm_multistream.py | 46 +++++++++--------- 3 files changed, 64 insertions(+), 36 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 3c0ce390af07..4a0a0b210612 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -644,7 +644,7 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( #undef DISPATCH } -template +template __global__ void deepseekV4QNormRopeKernel( scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place int64_t const* __restrict__ position_ids, // [N] i64 @@ -672,7 +672,7 @@ __global__ void deepseekV4QNormRopeKernel( uint4 const v0 = *reinterpret_cast(src_ptr); uint4 const v1 = *reinterpret_cast(src_ptr + 8); - processDeepseekV4Slot( + processDeepseekV4Slot( v0, v1, tokenIdx, headIdx, dim_base, laneId, num_heads_q, eps, q_inout, nullptr, nullptr, position_ids, cos_sin_cache, 0, 0); #if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) @@ -708,7 +708,7 @@ __global__ void deepseekV4KVRopeQuantInsertKernel( uint4 const v1 = *reinterpret_cast(src_ptr + 8); // num_heads_q=0 and slotIdx=0 select the KV branch in processDeepseekV4Slot. - processDeepseekV4Slot( + processDeepseekV4Slot( v0, v1, tokenIdx, 0, dim_base, laneId, 0, 0.0f, nullptr, k_cache, slot_mapping, position_ids, cos_sin_cache, cache_block_size, kv_block_stride); @@ -717,12 +717,11 @@ __global__ void deepseekV4KVRopeQuantInsertKernel( #endif } -template -void launchDeepseekV4QNormRope(scalar_t_in* q_inout, - int64_t const* position_ids, - float const* cos_sin_cache, float const eps, - int const num_tokens, int const num_heads_q, - cudaStream_t stream) { +template +void launchDeepseekV4QNormRopeTemplated( + scalar_t_in* q_inout, int64_t const* position_ids, + float const* cos_sin_cache, float const eps, int const num_tokens, + int const num_heads_q, cudaStream_t stream) { constexpr int kBlockSize = 256; constexpr int kWarpsPerBlock = kBlockSize / 32; int64_t const total_warps = @@ -736,10 +735,39 @@ void launchDeepseekV4QNormRope(scalar_t_in* q_inout, "sm_", sm_version); #endif - deepseekV4QNormRopeKernel<<>>( + deepseekV4QNormRopeKernel + <<>>( q_inout, position_ids, cos_sin_cache, eps, num_tokens, num_heads_q); } +template +void launchDeepseekV4QNormRope(scalar_t_in* q_inout, + int64_t const* position_ids, + float const* cos_sin_cache, float const eps, + int const num_tokens, int const num_heads_q, + cudaStream_t stream) { +#define DISPATCH(N) \ + case N: \ + launchDeepseekV4QNormRopeTemplated( \ + q_inout, position_ids, cos_sin_cache, eps, num_tokens, num_heads_q, \ + stream); \ + return; + + switch (num_heads_q) { + DISPATCH(8) + DISPATCH(16) + DISPATCH(32) + DISPATCH(64) + DISPATCH(128) + default: + TORCH_CHECK(false, + "deepseek_v4_qnorm_rope: unsupported num_heads_q=", + num_heads_q, + " (compiled instantiations: 8, 16, 32, 64, 128)."); + } +#undef DISPATCH +} + template void launchDeepseekV4KVRopeQuantInsert( scalar_t_in const* kv_in, uint8_t* k_cache, diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index 3d0cdb3a477f..c950af03f04e 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -457,12 +457,12 @@ def test_split_q_and_kv_match_combined( num_blocks = (num_tokens + block_size - 1) // block_size + 1 slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) - q_fused = q.clone() k_cache_fused = torch.zeros( num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device ) - _call_fused( - q_fused, + q_fused = _call_fused( + q, + n_heads, kv, k_cache_fused, slot_mapping, diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index b6a49d622b11..65c8ced3122e 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.models.deepseek_v4.amd import model as rocm_model -from vllm.models.deepseek_v4.nvidia.ops import attention as dsv4_attention +from vllm.models.deepseek_v4 import attention as dsv4_attention +from vllm.models.deepseek_v4.nvidia import model as dsv4_model def test_deepseek_v4_rocm_aux_streams_enabled(monkeypatch): @@ -12,13 +12,13 @@ def make_stream(**kwargs): assert kwargs == {"priority": -1} return streams.pop() - monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setattr(rocm_model.torch.cuda, "Stream", make_stream) + monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(dsv4_model.torch.cuda, "Stream", make_stream) monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1") monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "overlap") - aux_streams = rocm_model.make_deepseek_v4_aux_streams() + aux_streams = dsv4_model.make_deepseek_v4_aux_streams() assert aux_streams is not None assert len(aux_streams) == 5 @@ -31,43 +31,43 @@ def make_stream(**kwargs): assert kwargs == {"priority": -1} return streams.pop() - monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setattr(rocm_model.torch.cuda, "Stream", make_stream) + monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(dsv4_model.torch.cuda, "Stream", make_stream) monkeypatch.delenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", raising=False) - aux_streams = rocm_model.make_deepseek_v4_aux_streams() + aux_streams = dsv4_model.make_deepseek_v4_aux_streams() assert aux_streams is not None assert len(aux_streams) == 5 def test_deepseek_v4_rocm_aux_streams_disabled_by_env(monkeypatch): - monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "0") - aux_streams = rocm_model.make_deepseek_v4_aux_streams() + aux_streams = dsv4_model.make_deepseek_v4_aux_streams() assert aux_streams is None def test_deepseek_v4_rocm_aux_streams_strategy_off(monkeypatch): - monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) + monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1") monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "off") - aux_streams = rocm_model.make_deepseek_v4_aux_streams() + aux_streams = dsv4_model.make_deepseek_v4_aux_streams() assert aux_streams is None def test_deepseek_v4_rocm_aux_streams_xpu_fallback(monkeypatch): - monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: False) - monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: True) + monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: False) + monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: True) - aux_streams = rocm_model.make_deepseek_v4_aux_streams() + aux_streams = dsv4_model.make_deepseek_v4_aux_streams() assert aux_streams is None @@ -75,11 +75,11 @@ def test_deepseek_v4_rocm_aux_streams_xpu_fallback(monkeypatch): def test_deepseek_v4_aux_streams_cuda_behavior_unchanged(monkeypatch): streams = [object(), object(), object()] - monkeypatch.setattr(rocm_model.current_platform, "is_rocm", lambda: False) - monkeypatch.setattr(rocm_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setattr(rocm_model.torch.cuda, "Stream", streams.pop) + monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: False) + monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) + monkeypatch.setattr(dsv4_model.torch.cuda, "Stream", streams.pop) - aux_streams = rocm_model.make_deepseek_v4_aux_streams() + aux_streams = dsv4_model.make_deepseek_v4_aux_streams() assert aux_streams is not None assert len(aux_streams) == 3 From 13a46721f58cf20118af2db5b209a7a917873f43 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Tue, 26 May 2026 23:21:38 +0000 Subject: [PATCH 11/21] Restore rotary embedding helper behavior Keep vllm/model_executor/layers/rotary_embedding/common.py aligned with upstream; this PR should not change rotary helper import behavior. Signed-off-by: vLLM Contributor --- .../layers/rotary_embedding/common.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index a6eb7aa5e107..2e407ae7159e 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -135,17 +135,10 @@ def __init__( self.enable_fp32_compute = enable_fp32_compute self.apply_rotary_emb_flash_attn = None - if ( - not current_platform.is_cpu() - and not current_platform.is_rocm() - and find_spec("flash_attn") is not None - ): - try: - from flash_attn.ops.triton.rotary import apply_rotary - except ImportError: - logger.debug("Failed to import flash_attn rotary helper", exc_info=True) - else: - self.apply_rotary_emb_flash_attn = apply_rotary + if not current_platform.is_cpu() and find_spec("flash_attn") is not None: + from flash_attn.ops.triton.rotary import apply_rotary + + self.apply_rotary_emb_flash_attn = apply_rotary @staticmethod def forward_static( From 04c6ad745dd33ec4ab6b6cc0e06e8f3565a75e17 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 02:22:30 +0000 Subject: [PATCH 12/21] Address DeepSeek V4 ROCm review feedback Move ROCm DeepSeek V4 multi-stream behavior out of the NVIDIA implementation, remove temporary environment gates, and keep CuTeDSL sparse compressor paths off ROCm. Tested with targeted ROCm DeepSeek V4 pytest, ruff, InferenceX 1k/1k concurrency 4, and GSM8K concurrency 128. Co-authored-by: OpenAI Codex Signed-off-by: vLLM Contributor --- .../test_deepseek_v4_rocm_multistream.py | 104 ++---------------- vllm/envs.py | 53 --------- vllm/models/deepseek_v4/amd/model.py | 37 +++++-- vllm/models/deepseek_v4/attention.py | 41 +++---- vllm/models/deepseek_v4/nvidia/model.py | 60 +++------- 5 files changed, 72 insertions(+), 223 deletions(-) diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index 65c8ced3122e..d59b1f88b4c8 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -1,40 +1,25 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.models.deepseek_v4 import attention as dsv4_attention -from vllm.models.deepseek_v4.nvidia import model as dsv4_model - +import pytest -def test_deepseek_v4_rocm_aux_streams_enabled(monkeypatch): - streams = [object(), object(), object(), object(), object()] - - def make_stream(**kwargs): - assert kwargs == {"priority": -1} - return streams.pop() +from vllm.models.deepseek_v4 import attention as dsv4_attention +from vllm.models.deepseek_v4.amd import model as dsv4_model +from vllm.platforms import current_platform - monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setattr(dsv4_model.torch.cuda, "Stream", make_stream) - monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1") - monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "overlap") +pytestmark = pytest.mark.skipif( + not current_platform.is_rocm(), reason="ROCm-only DeepSeek-V4 tests" +) - aux_streams = dsv4_model.make_deepseek_v4_aux_streams() - assert aux_streams is not None - assert len(aux_streams) == 5 - - -def test_deepseek_v4_rocm_aux_streams_enabled_by_default(monkeypatch): +def test_deepseek_v4_rocm_aux_streams_enabled(monkeypatch): streams = [object(), object(), object(), object(), object()] def make_stream(**kwargs): assert kwargs == {"priority": -1} return streams.pop() - monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) monkeypatch.setattr(dsv4_model.torch.cuda, "Stream", make_stream) - monkeypatch.delenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", raising=False) aux_streams = dsv4_model.make_deepseek_v4_aux_streams() @@ -42,69 +27,6 @@ def make_stream(**kwargs): assert len(aux_streams) == 5 -def test_deepseek_v4_rocm_aux_streams_disabled_by_env(monkeypatch): - monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "0") - - aux_streams = dsv4_model.make_deepseek_v4_aux_streams() - - assert aux_streams is None - - -def test_deepseek_v4_rocm_aux_streams_strategy_off(monkeypatch): - monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1") - monkeypatch.setenv("VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "off") - - aux_streams = dsv4_model.make_deepseek_v4_aux_streams() - - assert aux_streams is None - - -def test_deepseek_v4_rocm_aux_streams_xpu_fallback(monkeypatch): - monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: False) - monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: True) - - aux_streams = dsv4_model.make_deepseek_v4_aux_streams() - - assert aux_streams is None - - -def test_deepseek_v4_aux_streams_cuda_behavior_unchanged(monkeypatch): - streams = [object(), object(), object()] - - monkeypatch.setattr(dsv4_model.current_platform, "is_rocm", lambda: False) - monkeypatch.setattr(dsv4_model.current_platform, "is_xpu", lambda: False) - monkeypatch.setattr(dsv4_model.torch.cuda, "Stream", streams.pop) - - aux_streams = dsv4_model.make_deepseek_v4_aux_streams() - - assert aux_streams is not None - assert len(aux_streams) == 3 - - -def test_deepseek_v4_rocm_strategy_cuda_behavior_unchanged(monkeypatch): - class _ForwardContext: - cudagraph_runtime_mode = dsv4_attention.CUDAGraphMode.PIECEWISE - - class _Wrapper: - aux_stream_list = [object(), object(), object()] - indexer = object() - - monkeypatch.setattr(dsv4_attention.current_platform, "is_rocm", lambda: False) - monkeypatch.setattr(dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MULTISTREAM", True) - monkeypatch.setattr( - dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "overlap" - ) - - wrapper_cls = dsv4_attention.DeepseekV4MultiHeadLatentAttentionWrapper - method = wrapper_cls._rocm_csa_ms_strategy_for_step - - assert method(_Wrapper(), _ForwardContext(), None) == "off" - - class _Metadata: def __init__( self, @@ -129,16 +51,6 @@ class _Wrapper: aux_stream_list = [object(), object(), object(), object(), object()] indexer = object() - monkeypatch.setattr(dsv4_attention.current_platform, "is_rocm", lambda: True) - monkeypatch.setattr(dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MULTISTREAM", True) - monkeypatch.setattr( - dsv4_attention.envs, "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", "overlap" - ) - monkeypatch.setattr( - dsv4_attention.envs, - "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES", - {"none", "piecewise"}, - ) monkeypatch.setattr(dsv4_attention, "DeepseekSparseSWAMetadata", _Metadata) attn_metadata = { "layer_0.swa": _Metadata(num_decodes, num_decodes, num_prefill_tokens) diff --git a/vllm/envs.py b/vllm/envs.py index 4341dcc18f99..3a3934f3cdf9 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -257,20 +257,6 @@ VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024 - VLLM_ROCM_DSV4_CSA_MULTISTREAM: bool = True - VLLM_ROCM_DSV4_CSA_MS_STRATEGY: Literal["off", "indexer_only", "overlap"] = ( - "overlap" - ) - VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES: set[Literal["none", "piecewise", "full"]] = { - "none", - "piecewise", - } - VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR: bool = True - VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS: bool = False - VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST: bool = True - VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER: bool = False - VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS: bool = True - VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY: int = -1 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool | None = None VLLM_LOG_MODEL_INSPECTION: bool = False @@ -1892,45 +1878,6 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD": lambda: int( os.getenv("VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD", "1024") ), - # ROCm-only opt-in for DeepSeek-V4 CSA decode multi-stream overlap. - # The "overlap" strategy runs the KV cache insert, MLA compressor, and C4 - # indexer preparation branches with fine-grained event joins. The deferred - # projection and sub-branch knobs are separate A/B controls because ROCm - # graph capture and hipBLASLt ordering differ by PyTorch/ROCm release. - # "indexer_only" keeps the conservative single indexer branch path for A/B. - "VLLM_ROCM_DSV4_CSA_MULTISTREAM": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MULTISTREAM", "1")) - ), - "VLLM_ROCM_DSV4_CSA_MS_STRATEGY": env_with_choices( - "VLLM_ROCM_DSV4_CSA_MS_STRATEGY", - "overlap", - ["off", "indexer_only", "overlap"], - case_sensitive=False, - ), - "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES": env_set_with_choices( - "VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES", - ["none", "piecewise"], - ["none", "piecewise", "full"], - case_sensitive=False, - ), - "VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR", "1")) - ), - "VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS", "0")) - ), - "VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST", "1")) - ), - "VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER", "0")) - ), - "VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS": lambda: bool( - int(os.getenv("VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS", "1")) - ), - "VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY": lambda: int( - os.getenv("VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY", "-1") - ), # Format for saving torch.compile cache artifacts # - "binary": saves as binary file # Safe for multiple vllm serve processes accessing the same torch compile cache. diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 84318a8107d3..df852eeb95f2 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -55,6 +55,20 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +_ROCM_DSV4_CSA_AUX_STREAM_COUNT = 5 +_ROCM_DSV4_CSA_AUX_STREAM_PRIORITY = -1 + + +def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: + if not current_platform.is_rocm(): + return None + # ROCm DeepSeek-V4 decode overlap uses five streams: three top-level + # preparation branches and two C4-indexer sub-branches. + return [ + torch.cuda.Stream(priority=_ROCM_DSV4_CSA_AUX_STREAM_PRIORITY) + for _ in range(_ROCM_DSV4_CSA_AUX_STREAM_COUNT) + ] + class DeepseekV4MLP(nn.Module): def __init__( @@ -339,6 +353,16 @@ def __init__( self.indexer = None if self.compress_ratio == 4: # Only C4A uses sparse attention and hence has indexer. + indexer_aux_stream = ( + aux_stream_list[3] + if aux_stream_list is not None and len(aux_stream_list) >= 5 + else None + ) + indexer_aux_streams = ( + aux_stream_list[3:5] + if aux_stream_list is not None and len(aux_stream_list) >= 5 + else None + ) self.indexer = DeepseekV4Indexer( vllm_config, config=config, @@ -349,6 +373,8 @@ def __init__( topk_indices_buffer=topk_indices_buffer, compress_ratio=self.compress_ratio, prefix=f"{prefix}.indexer", + aux_stream=indexer_aux_stream, + aux_streams=indexer_aux_streams, ) mla_modules = DeepseekV4MLAModules( @@ -616,16 +642,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - # Three aux streams: one per non-default input GEMM in - # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute - # (compressor kv_score, indexer.weights_proj, indexer.compressor - # kv_score). fused_wqa_wkv stays on the default stream. - # Disable them on ROCm because of hang issues. - aux_stream_list = ( - None - if current_platform.is_rocm() - else [torch.cuda.Stream() for _ in range(3)] - ) + aux_stream_list = make_deepseek_v4_aux_streams() self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 51a4cfce7776..16971e591ea8 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -79,6 +79,14 @@ logger = init_logger(__name__) +_ROCM_DSV4_CSA_MS_STRATEGY = "overlap" +_ROCM_DSV4_CSA_MS_GRAPH_MODES = frozenset(("none", "piecewise")) +_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR = True +_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS = False +_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST = True +_ROCM_DSV4_CSA_MS_OUTER_INDEXER = False +_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS = True + def _iter_deepseek_v4_swa_metadata( attn_metadata: dict[str, AttentionMetadata] @@ -406,10 +414,10 @@ def _rocm_csa_ms_strategy_for_step( ) -> str: if not current_platform.is_rocm(): return "off" - if not envs.VLLM_ROCM_DSV4_CSA_MULTISTREAM or self.aux_stream_list is None: + if self.aux_stream_list is None: return "off" - strategy = envs.VLLM_ROCM_DSV4_CSA_MS_STRATEGY.lower() + strategy = _ROCM_DSV4_CSA_MS_STRATEGY if strategy == "off": return "off" if not _is_decode_only_deepseek_v4_step(attn_metadata): @@ -418,16 +426,12 @@ def _rocm_csa_ms_strategy_for_step( graph_mode = forward_context.cudagraph_runtime_mode assert graph_mode in CUDAGraphMode.valid_runtime_modes() graph_mode_name = graph_mode.name.lower() - enabled_graph_modes = { - mode.lower() for mode in envs.VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES - } # Standalone ROCm repro in benchmarks/kernels/rocm_dsv4_stream_probe.py # shows that putting the Python stream scheduler inside a compiled # full-graph wrapper can collapse the replayed work onto stream 0. # Keep the ROCm CSA multi-stream scheduler in eager/piecewise graph - # islands by default; FULL can be enabled explicitly for debugging via - # VLLM_ROCM_DSV4_CSA_MS_GRAPH_MODES=none,piecewise,full. - if graph_mode_name not in enabled_graph_modes: + # islands. + if graph_mode_name not in _ROCM_DSV4_CSA_MS_GRAPH_MODES: return "off" if strategy == "overlap" and len(self.aux_stream_list) < 5: @@ -710,17 +714,17 @@ def run_indexer() -> Any: indexer_weights, positions, self.indexer_rotary_emb, - use_aux_stream=envs.VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS, + use_aux_stream=_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS, project_inputs=defer_rocm_branch_projections, compressed_kv_score_out=indexer_kv_score_out, indexer_weights_out=indexer_weights_out, ) launched_indexer = ( - self.indexer is not None and envs.VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER + self.indexer is not None and _ROCM_DSV4_CSA_MS_OUTER_INDEXER ) launched_compressor = ( - self.compressor is not None and envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR + self.compressor is not None and _ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR ) if launched_indexer: @@ -761,7 +765,7 @@ def attention_impl( defer_rocm_branch_projections = ( current_platform.is_rocm() and rocm_ms_strategy == "overlap" - and envs.VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS + and _ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS ) gemm_aux_streams = None if current_platform.is_rocm() else aux_streams @@ -777,7 +781,7 @@ def attention_impl( use_rocm_split_post = ( current_platform.is_rocm() and rocm_ms_strategy == "overlap" - and envs.VLLM_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST + and _ROCM_DSV4_CSA_MS_SPLIT_QKV_POST and aux_streams is not None and len(aux_streams) >= 5 ) @@ -879,8 +883,7 @@ def wq_b_kv_insert() -> torch.Tensor: # 3-way overlap (matches TRT-LLM PR #14142 Level 1): default runs # wq_b+kv_insert; slot [0] runs the full indexer; slot [1] runs the # MLA compressor. Slot [2] is reserved for the indexer's inner - # overlap. ROCm can optionally defer branch projections under - # VLLM_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS=1. + # overlap. run_indexer: Callable[[], Any] = lambda: indexer( hidden_states, qr, @@ -893,7 +896,7 @@ def wq_b_kv_insert() -> torch.Tensor: not current_platform.is_rocm() or ( rocm_ms_strategy == "overlap" - and envs.VLLM_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS + and _ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS ) ), project_inputs=defer_rocm_branch_projections, @@ -911,10 +914,10 @@ def run_compressor() -> Any: compressor_fn: Callable[[], Any] | None = run_compressor if rocm_ms_active and ( rocm_ms_strategy == "indexer_only" - or not envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR + or not _ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR ): compressor_fn = None - if rocm_ms_active and not envs.VLLM_ROCM_DSV4_CSA_MS_OUTER_INDEXER: + if rocm_ms_active and not _ROCM_DSV4_CSA_MS_OUTER_INDEXER: indexer_fn = None q, (indexer_result, compressor_result) = execute_in_parallel( @@ -950,7 +953,7 @@ def run_compressor() -> Any: ) if ( rocm_ms_active - and not envs.VLLM_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR + and not _ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR ): aux_stream = None compressor = self.compressor diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 05f9a7e24981..974593a8d390 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn -import vllm.envs as envs from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import ( @@ -66,26 +65,6 @@ from vllm.utils.torch_utils import direct_register_custom_op -def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: - if current_platform.is_rocm(): - if ( - not envs.VLLM_ROCM_DSV4_CSA_MULTISTREAM - or envs.VLLM_ROCM_DSV4_CSA_MS_STRATEGY.lower() == "off" - ): - return None - # ROCm uses five streams for DeepSeek-V4 decode overlap: three - # top-level preparation branches and two C4-indexer sub-branches: - # [0] main KV cache insert, [1] main compressor, [2] C4 indexer, - # [3] indexer Q branch, [4] indexer weights branch. - return [ - torch.cuda.Stream(priority=envs.VLLM_ROCM_DSV4_CSA_MS_AUX_PRIORITY) - for _ in range(5) - ] - if current_platform.is_xpu(): - return None - return [torch.cuda.Stream() for _ in range(3)] - - class DeepseekV4MLP(nn.Module): def __init__( self, @@ -801,28 +780,11 @@ def __init__( self.indexer = None if self.compress_ratio == 4: # Only C4A uses sparse attention and hence has indexer. - # NVIDIA uses aux_stream_list[2] for the legacy inner overlap. - # ROCm decode overlap uses aux_stream_list[3:5] for the C4 indexer - # q/weights sub-branches while the outer indexer branch runs on - # aux_stream_list[2]. - if ( - current_platform.is_rocm() - and aux_stream_list is not None - and len(aux_stream_list) >= 5 - ): - indexer_aux_stream = aux_stream_list[3] - else: - indexer_aux_stream = ( - aux_stream_list[2] if aux_stream_list is not None else None - ) - indexer_aux_streams = ( - aux_stream_list[3:5] - if ( - current_platform.is_rocm() - and aux_stream_list is not None - and len(aux_stream_list) >= 5 - ) - else None + # aux_stream_list[0] runs indexer.forward() in the wrapper; [2] is + # free here (outer GEMMs joined) for the inner overlap of + # wq_b+fused_indexer_q_rope_quant vs compressor. + indexer_aux_stream = ( + aux_stream_list[2] if aux_stream_list is not None else None ) self.indexer = DeepseekV4Indexer( vllm_config, @@ -835,7 +797,6 @@ def __init__( compress_ratio=self.compress_ratio, prefix=f"{prefix}.indexer", aux_stream=indexer_aux_stream, - aux_streams=indexer_aux_streams, ) mla_modules = DeepseekV4MLAModules( @@ -1133,7 +1094,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - aux_stream_list = make_deepseek_v4_aux_streams() + # Three aux streams: one per non-default input GEMM in + # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute + # (compressor kv_score, indexer.weights_proj, indexer.compressor + # kv_score). fused_wqa_wkv stays on the default stream. + # Disable them on ROCm / XPU because of hang issues / no overlap. + aux_stream_list = ( + None + if current_platform.is_rocm() or current_platform.is_xpu() + else [torch.cuda.Stream() for _ in range(3)] + ) self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. From d37983b5c0f6be4d7a5af17dc80d7709d2278b9f Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 15:49:59 +0000 Subject: [PATCH 13/21] remove env knobs with preset optimal config Signed-off-by: vLLM Contributor --- vllm/models/deepseek_v4/amd/model.py | 24 +- vllm/models/deepseek_v4/attention.py | 502 +++------------------------ 2 files changed, 42 insertions(+), 484 deletions(-) diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index df852eeb95f2..fa0a03edec4e 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -55,19 +55,13 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -_ROCM_DSV4_CSA_AUX_STREAM_COUNT = 5 -_ROCM_DSV4_CSA_AUX_STREAM_PRIORITY = -1 - def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: if not current_platform.is_rocm(): return None - # ROCm DeepSeek-V4 decode overlap uses five streams: three top-level - # preparation branches and two C4-indexer sub-branches. - return [ - torch.cuda.Stream(priority=_ROCM_DSV4_CSA_AUX_STREAM_PRIORITY) - for _ in range(_ROCM_DSV4_CSA_AUX_STREAM_COUNT) - ] + # Reserve the five-stream CSA decode topology. The measured ROCm default + # overlaps the compressor branch on aux[1] and keeps indexer fan-out off. + return [torch.cuda.Stream(priority=-1) for _ in range(5)] class DeepseekV4MLP(nn.Module): @@ -353,16 +347,6 @@ def __init__( self.indexer = None if self.compress_ratio == 4: # Only C4A uses sparse attention and hence has indexer. - indexer_aux_stream = ( - aux_stream_list[3] - if aux_stream_list is not None and len(aux_stream_list) >= 5 - else None - ) - indexer_aux_streams = ( - aux_stream_list[3:5] - if aux_stream_list is not None and len(aux_stream_list) >= 5 - else None - ) self.indexer = DeepseekV4Indexer( vllm_config, config=config, @@ -373,8 +357,6 @@ def __init__( topk_indices_buffer=topk_indices_buffer, compress_ratio=self.compress_ratio, prefix=f"{prefix}.indexer", - aux_stream=indexer_aux_stream, - aux_streams=indexer_aux_streams, ) mla_modules = DeepseekV4MLAModules( diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index 16971e591ea8..ae1b9d1a980a 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -79,14 +79,6 @@ logger = init_logger(__name__) -_ROCM_DSV4_CSA_MS_STRATEGY = "overlap" -_ROCM_DSV4_CSA_MS_GRAPH_MODES = frozenset(("none", "piecewise")) -_ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR = True -_ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS = False -_ROCM_DSV4_CSA_MS_SPLIT_QKV_POST = True -_ROCM_DSV4_CSA_MS_OUTER_INDEXER = False -_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS = True - def _iter_deepseek_v4_swa_metadata( attn_metadata: dict[str, AttentionMetadata] @@ -262,7 +254,6 @@ def __init__( ) self.aux_stream_list = mla_modules.aux_stream_list - self._rocm_aux_gemm_buffers: dict[str, torch.Tensor] = {} # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins # before post-GEMM starts. @@ -393,87 +384,43 @@ def forward( def _aux_streams_for_step( self, - rocm_ms_strategy: str, - attn_metadata: ( - dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None - ), + use_rocm_csa_multistream: bool, ) -> list[torch.cuda.Stream] | None: if current_platform.is_rocm(): - if rocm_ms_strategy == "off": - return None - if not _is_decode_only_deepseek_v4_step(attn_metadata): - return None + return self.aux_stream_list if use_rocm_csa_multistream else None return self.aux_stream_list - def _rocm_csa_ms_strategy_for_step( + def _use_rocm_csa_multistream( self, forward_context: ForwardContext, attn_metadata: ( dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None ), - ) -> str: + ) -> bool: if not current_platform.is_rocm(): - return "off" + return False if self.aux_stream_list is None: - return "off" + return False + if len(self.aux_stream_list) < 5: + return False - strategy = _ROCM_DSV4_CSA_MS_STRATEGY - if strategy == "off": - return "off" if not _is_decode_only_deepseek_v4_step(attn_metadata): - return "off" + return False graph_mode = forward_context.cudagraph_runtime_mode assert graph_mode in CUDAGraphMode.valid_runtime_modes() - graph_mode_name = graph_mode.name.lower() # Standalone ROCm repro in benchmarks/kernels/rocm_dsv4_stream_probe.py # shows that putting the Python stream scheduler inside a compiled # full-graph wrapper can collapse the replayed work onto stream 0. # Keep the ROCm CSA multi-stream scheduler in eager/piecewise graph # islands. - if graph_mode_name not in _ROCM_DSV4_CSA_MS_GRAPH_MODES: - return "off" - - if strategy == "overlap" and len(self.aux_stream_list) < 5: - return "off" - if strategy == "indexer_only" and self.indexer is None: - return "off" - return strategy - - def _rocm_aux_buffer( - self, - name: str, - shape: tuple[int, ...], - dtype: torch.dtype, - device: torch.device, - ) -> torch.Tensor: - buffer = self._rocm_aux_gemm_buffers.get(name) - if ( - buffer is None - or buffer.shape[1:] != shape[1:] - or buffer.shape[0] < shape[0] - or buffer.dtype != dtype - or buffer.device != device - ): - buffer = torch.empty(shape, device=device, dtype=dtype) - self._rocm_aux_gemm_buffers[name] = buffer - if buffer.shape == shape: - return buffer - return buffer[: shape[0]] + return graph_mode in (CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE) def _project_compressor_kv_score( self, hidden_states: torch.Tensor, compressor: DeepseekCompressor, - out: torch.Tensor | None = None, ) -> torch.Tensor: - if out is not None: - return torch.mm( - hidden_states, - compressor.fused_wkv_wgate.weight.T, - out=out, - out_dtype=torch.float32, - ) return torch.mm( hidden_states, compressor.fused_wkv_wgate.weight.T, @@ -484,117 +431,48 @@ def attn_gemm_parallel_execute( self, hidden_states: torch.Tensor, aux_streams: list[torch.cuda.Stream] | None, - defer_rocm_branch_projections: bool = False, ) -> tuple[Any, ...]: if aux_streams is not None: assert len(aux_streams) >= 3 aux_streams = aux_streams[:3] - use_rocm_graph_safe_buffers = ( - current_platform.is_rocm() and aux_streams is not None - ) # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. - # ROCm can optionally defer these branch projections until the - # post-rmsnorm fan-out for A/B testing. Current ROCm/PyTorch stacks do - # not need that for allocator safety, so the default keeps projections - # here and only overlaps the post-rmsnorm CSA branches. + # ROCm keeps these projections on the main stream; the measured decode + # win comes from overlapping the post-rmsnorm compressor branch. aux_fns: list[Callable[[], Any] | None] = [None, None, None] - if self.compressor is not None and not defer_rocm_branch_projections: + if self.compressor is not None: # Local ref so the closure keeps a non-None type for mypy. compressor = self.compressor def compressor_kv_score() -> torch.Tensor: - out = ( - self._rocm_aux_buffer( - "compressor_kv_score", - ( - hidden_states.shape[0], - compressor.fused_wkv_wgate.weight.shape[0], - ), - torch.float32, - hidden_states.device, - ) - if use_rocm_graph_safe_buffers - else None - ) - if out is None: - return torch.mm( - hidden_states, - compressor.fused_wkv_wgate.weight.T, - out_dtype=torch.float32, - ) return torch.mm( hidden_states, compressor.fused_wkv_wgate.weight.T, - out=out, out_dtype=torch.float32, ) aux_fns[0] = compressor_kv_score - if self.indexer is not None and not defer_rocm_branch_projections: + if self.indexer is not None: indexer = self.indexer def indexer_weights_proj() -> torch.Tensor: - if use_rocm_graph_safe_buffers: - out = self._rocm_aux_buffer( - "indexer_weights", - (hidden_states.shape[0], indexer.weights_proj.weight.shape[0]), - hidden_states.dtype, - hidden_states.device, - ) - return torch.mm( - hidden_states, - indexer.weights_proj.weight.T, - out=out, - ) # ReplicatedLinear returns (output, bias); bias is None. weights, _ = indexer.weights_proj(hidden_states) return weights def indexer_compressor_kv_score() -> torch.Tensor: - out = ( - self._rocm_aux_buffer( - "indexer_kv_score", - ( - hidden_states.shape[0], - indexer.compressor.fused_wkv_wgate.weight.shape[0], - ), - torch.float32, - hidden_states.device, - ) - if use_rocm_graph_safe_buffers - else None - ) - if out is None: - return torch.mm( - hidden_states, - indexer.compressor.fused_wkv_wgate.weight.T, - out_dtype=torch.float32, - ) return torch.mm( hidden_states, indexer.compressor.fused_wkv_wgate.weight.T, - out=out, out_dtype=torch.float32, ) aux_fns[1] = indexer_weights_proj aux_fns[2] = indexer_compressor_kv_score - rocm_deferred_indexer_weights_proj = indexer_weights_proj - rocm_deferred_indexer_compressor_kv_score = indexer_compressor_kv_score - - if use_rocm_graph_safe_buffers: - # Current ROCm graph replay hangs when the two smaller indexer GEMMs - # are captured as additional side-stream hipBLASLt nodes next to the - # aiter fused WQA/WKV GEMM. Keep the largest CSA GEMM overlapped and - # leave the indexer GEMMs on the main stream until that lower-level - # ordering issue is fixed. - aux_fns[1] = None - aux_fns[2] = None def fused_wqa_wkv() -> torch.Tensor: # MergedColumnParallelLinear returns (output, bias); bias is None. @@ -611,145 +489,8 @@ def fused_wqa_wkv() -> torch.Tensor: <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD, ) - if ( - use_rocm_graph_safe_buffers - and self.indexer is not None - and not defer_rocm_branch_projections - ): - if indexer_weights is None: - indexer_weights = rocm_deferred_indexer_weights_proj() - if indexer_kv_score is None: - indexer_kv_score = rocm_deferred_indexer_compressor_kv_score() - return qr_kv, kv_score, indexer_kv_score, indexer_weights - def _rocm_multistream_post_rmsnorm_prepare( - self, - hidden_states: torch.Tensor, - qr: torch.Tensor, - kv: torch.Tensor, - kv_score: torch.Tensor | None, - indexer_kv_score: torch.Tensor | None, - indexer_weights: torch.Tensor | None, - positions: torch.Tensor, - attn_metadata: ( - dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None - ), - aux_streams: list[torch.cuda.Stream], - defer_rocm_branch_projections: bool, - ) -> torch.Tensor: - assert len(aux_streams) >= 5 - stream_kv = aux_streams[0] - stream_compressor = aux_streams[1] - stream_indexer = aux_streams[2] - current_stream = torch.cuda.current_stream() - - stream_kv.wait_stream(current_stream) - stream_compressor.wait_stream(current_stream) - stream_indexer.wait_stream(current_stream) - - compressor_kv_score_out: torch.Tensor | None = None - if self.compressor is not None and kv_score is None: - compressor_kv_score_out = self._rocm_aux_buffer( - "deferred_compressor_kv_score", - ( - hidden_states.shape[0], - self.compressor.fused_wkv_wgate.weight.shape[0], - ), - torch.float32, - hidden_states.device, - ) - - indexer_weights_out: torch.Tensor | None = None - indexer_kv_score_out: torch.Tensor | None = None - if self.indexer is not None: - if indexer_weights is None: - indexer_weights_out = self._rocm_aux_buffer( - "deferred_indexer_weights", - ( - hidden_states.shape[0], - self.indexer.weights_proj.weight.shape[0], - ), - hidden_states.dtype, - hidden_states.device, - ) - if indexer_kv_score is None: - indexer_kv_score_out = self._rocm_aux_buffer( - "deferred_indexer_kv_score", - ( - hidden_states.shape[0], - self.indexer.compressor.fused_wkv_wgate.weight.shape[0], - ), - torch.float32, - hidden_states.device, - ) - - def q_b_qnorm_rope() -> torch.Tensor: - q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) - self._fused_qnorm_rope(q, positions) - return q - - def run_kv_insert() -> None: - self._fused_kv_rope_insert(kv, positions, attn_metadata) - - def run_compressor() -> Any: - if self.compressor is None: - return None - local_kv_score = kv_score - if local_kv_score is None: - local_kv_score = self._project_compressor_kv_score( - hidden_states, - self.compressor, - out=compressor_kv_score_out, - ) - return self.compressor(local_kv_score, positions, self.rotary_emb) - - def run_indexer() -> Any: - if self.indexer is None: - return None - return self.indexer( - hidden_states, - qr, - indexer_kv_score, - indexer_weights, - positions, - self.indexer_rotary_emb, - use_aux_stream=_ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS, - project_inputs=defer_rocm_branch_projections, - compressed_kv_score_out=indexer_kv_score_out, - indexer_weights_out=indexer_weights_out, - ) - - launched_indexer = ( - self.indexer is not None and _ROCM_DSV4_CSA_MS_OUTER_INDEXER - ) - launched_compressor = ( - self.compressor is not None and _ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR - ) - - if launched_indexer: - with torch.cuda.stream(stream_indexer): - run_indexer() - with torch.cuda.stream(stream_kv): - run_kv_insert() - if launched_compressor: - with torch.cuda.stream(stream_compressor): - run_compressor() - - q = q_b_qnorm_rope() - - current_stream.wait_stream(stream_kv) - if launched_compressor: - current_stream.wait_stream(stream_compressor) - else: - run_compressor() - if launched_indexer: - current_stream.wait_stream(stream_indexer) - else: - run_indexer() - - return q - def attention_impl( self, hidden_states: torch.Tensor, @@ -758,33 +499,20 @@ def attention_impl( ) -> None: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata - rocm_ms_strategy = self._rocm_csa_ms_strategy_for_step( + rocm_csa_multistream = self._use_rocm_csa_multistream( forward_context, attn_metadata ) - aux_streams = self._aux_streams_for_step(rocm_ms_strategy, attn_metadata) - defer_rocm_branch_projections = ( - current_platform.is_rocm() - and rocm_ms_strategy == "overlap" - and _ROCM_DSV4_CSA_MS_DEFER_PROJECTIONS - ) + aux_streams = self._aux_streams_for_step(rocm_csa_multistream) gemm_aux_streams = None if current_platform.is_rocm() else aux_streams qr_kv, kv_score, indexer_kv_score, indexer_weights = ( self.attn_gemm_parallel_execute( hidden_states, gemm_aux_streams, - defer_rocm_branch_projections=defer_rocm_branch_projections, ) ) qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) - use_rocm_split_post = ( - current_platform.is_rocm() - and rocm_ms_strategy == "overlap" - and _ROCM_DSV4_CSA_MS_SPLIT_QKV_POST - and aux_streams is not None - and len(aux_streams) >= 5 - ) qr, kv = fused_q_kv_rmsnorm( qr, kv, @@ -793,33 +521,18 @@ def attention_impl( self.eps, ) - if use_rocm_split_post: - q = self._rocm_multistream_post_rmsnorm_prepare( - hidden_states, - qr, - kv, - kv_score, - indexer_kv_score, - indexer_weights, - positions, - attn_metadata, - aux_streams, - defer_rocm_branch_projections, - ) - else: - q = self._legacy_post_rmsnorm_prepare( - hidden_states, - qr, - kv, - kv_score, - indexer_kv_score, - indexer_weights, - positions, - attn_metadata, - aux_streams, - rocm_ms_strategy, - defer_rocm_branch_projections, - ) + q = self._post_rmsnorm_prepare( + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + attn_metadata, + aux_streams, + rocm_csa_multistream, + ) # Pad q to FlashMLA-required head count (64 or 128) if self.n_local_heads < self.padded_heads: @@ -830,7 +543,7 @@ def attention_impl( # ([num_tokens, padded_heads, head_dim]). self.mla_attn(q, kv, positions, output=out) - def _legacy_post_rmsnorm_prepare( + def _post_rmsnorm_prepare( self, hidden_states: torch.Tensor, qr: torch.Tensor, @@ -843,8 +556,7 @@ def _legacy_post_rmsnorm_prepare( dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None ), aux_streams: list[torch.cuda.Stream] | None, - rocm_ms_strategy: str, - defer_rocm_branch_projections: bool, + rocm_csa_multistream: bool, ) -> torch.Tensor: # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride # on the default stream so q stays on its consumer stream (mla_attn @@ -854,16 +566,16 @@ def _legacy_post_rmsnorm_prepare( rocm_ms_active = ( current_platform.is_rocm() and post_aux_streams is not None - and rocm_ms_strategy != "off" + and rocm_csa_multistream ) if ( rocm_ms_active and post_aux_streams is not None and len(post_aux_streams) >= 5 ): - # Legacy ROCm fallback keeps q+kv fused on default. Match the - # five-stream layout where possible: aux[2] indexer, aux[1] - # compressor, aux[3:5] indexer sub-branches. + # ROCm keeps q+kv fused on default and only overlaps the + # compressor on aux[1]. aux[2] stays reserved for the top-level + # indexer branch, which is intentionally not launched by default. outer_post_aux_streams = [post_aux_streams[2], post_aux_streams[1]] elif post_aux_streams is not None: outer_post_aux_streams = [post_aux_streams[0], post_aux_streams[1]] @@ -880,10 +592,8 @@ def wq_b_kv_insert() -> torch.Tensor: q = self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) return q - # 3-way overlap (matches TRT-LLM PR #14142 Level 1): default runs - # wq_b+kv_insert; slot [0] runs the full indexer; slot [1] runs the - # MLA compressor. Slot [2] is reserved for the indexer's inner - # overlap. + # NVIDIA keeps the 3-way split. ROCm's measured default keeps + # indexer work on default and overlaps only the compressor branch. run_indexer: Callable[[], Any] = lambda: indexer( hidden_states, qr, @@ -892,14 +602,7 @@ def wq_b_kv_insert() -> torch.Tensor: positions, self.indexer_rotary_emb, use_aux_stream=post_aux_streams is not None - and ( - not current_platform.is_rocm() - or ( - rocm_ms_strategy == "overlap" - and _ROCM_DSV4_CSA_MS_INDEXER_SUBSTREAMS - ) - ), - project_inputs=defer_rocm_branch_projections, + and (not current_platform.is_rocm()), ) def run_compressor() -> Any: @@ -912,12 +615,7 @@ def run_compressor() -> Any: indexer_fn: Callable[[], Any] | None = run_indexer compressor_fn: Callable[[], Any] | None = run_compressor - if rocm_ms_active and ( - rocm_ms_strategy == "indexer_only" - or not _ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR - ): - compressor_fn = None - if rocm_ms_active and not _ROCM_DSV4_CSA_MS_OUTER_INDEXER: + if rocm_ms_active: indexer_fn = None q, (indexer_result, compressor_result) = execute_in_parallel( @@ -951,11 +649,6 @@ def run_compressor() -> Any: ) else None ) - if ( - rocm_ms_active - and not _ROCM_DSV4_CSA_MS_MAIN_COMPRESSOR - ): - aux_stream = None compressor = self.compressor def wq_b_kv_insert() -> torch.Tensor: @@ -1032,47 +725,6 @@ def _fused_qnorm_rope_kv_insert( swa_metadata.block_size, ) - def _fused_qnorm_rope( - self, - q: torch.Tensor, - positions: torch.Tensor, - ) -> None: - torch.ops._C.deepseek_v4_qnorm_rope( - q, - positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, - self.eps, - ) - - def _fused_kv_rope_insert( - self, - kv: torch.Tensor, - positions: torch.Tensor, - attn_metadata: ( - dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None - ), - ) -> None: - if not isinstance(attn_metadata, dict): - return - - swa_metadata = cast( - "DeepseekSparseSWAMetadata | None", - attn_metadata.get(self.swa_cache_layer.prefix), - ) - assert swa_metadata is not None - - swa_kv_cache = self.swa_cache_layer.kv_cache - swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) - - torch.ops._C.deepseek_v4_kv_rope_quant_insert( - kv, - swa_kv_cache_2d, - swa_metadata.slot_mapping, - positions.to(torch.int64), - self.rotary_emb.cos_sin_cache, - swa_metadata.block_size, - ) - @eager_break_during_capture def deepseek_v4_attention( @@ -1313,7 +965,6 @@ def __init__( compress_ratio: int = 1, prefix: str = "", aux_stream: torch.cuda.Stream | None = None, - aux_streams: list[torch.cuda.Stream] | None = None, ): super().__init__() self.vllm_config = vllm_config @@ -1399,10 +1050,8 @@ def __init__( use_fp4_cache=self.use_fp4_kv, ) - # aux_stream is the legacy two-way split. aux_streams maps the C4 - # indexer sub-branches as [0] Q projection/quant, [1] weights proj. + # aux_stream is the legacy two-way split used by non-ROCm paths. self.aux_stream = aux_stream - self.aux_streams = aux_streams self.ln_events: list[torch.cuda.Event] = [ torch.cuda.Event(), torch.cuda.Event(), @@ -1417,39 +1066,9 @@ def forward( positions: torch.Tensor, rotary_emb: nn.Module, use_aux_stream: bool = True, - project_inputs: bool = False, - compressed_kv_score_out: torch.Tensor | None = None, - indexer_weights_out: torch.Tensor | None = None, ) -> torch.Tensor: compressor = self.compressor - def project_indexer_weights() -> torch.Tensor: - if indexer_weights_out is not None: - return torch.mm( - hidden_states, - self.weights_proj.weight.T, - out=indexer_weights_out, - ) - # Keep the default path numerically identical to the baseline - # ReplicatedLinear path. The explicit out= path above is only used - # for ROCm graph-safe deferred side-stream projection buffers. - weights, _ = self.weights_proj(hidden_states) - return weights - - def project_compressed_kv_score() -> torch.Tensor: - if compressed_kv_score_out is not None: - return torch.mm( - hidden_states, - compressor.fused_wkv_wgate.weight.T, - out=compressed_kv_score_out, - out_dtype=torch.float32, - ) - return torch.mm( - hidden_states, - compressor.fused_wkv_wgate.weight.T, - out_dtype=torch.float32, - ) - def wq_b_and_q_quant(weights: torch.Tensor): # ReplicatedLinear returns (output, bias); bias is None. q, _ = self.wq_b(qr) @@ -1464,49 +1083,6 @@ def wq_b_and_q_quant(weights: torch.Tensor): use_fp4=self.use_fp4_kv, ) - def record_stream(result: Any, stream: torch.cuda.Stream) -> None: - if isinstance(result, torch.Tensor): - result.record_stream(stream) - elif isinstance(result, (tuple, list)): - for item in result: - record_stream(item, stream) - - if ( - project_inputs - and use_aux_stream - and self.aux_streams is not None - and len(self.aux_streams) >= 2 - ): - current_stream = torch.cuda.current_stream() - stream_q = self.aux_streams[0] - stream_weights = self.aux_streams[1] - stream_q.wait_stream(current_stream) - stream_weights.wait_stream(current_stream) - - local_kv_score = project_compressed_kv_score() - k = compressor(local_kv_score, positions, rotary_emb) - - with torch.cuda.stream(stream_weights): - weights = project_indexer_weights() - weights_ready = stream_weights.record_event() - - q_result: list[Any] = [None] - with torch.cuda.stream(stream_q): - stream_q.wait_event(weights_ready) - q_result[0] = wq_b_and_q_quant(weights) - - current_stream.wait_stream(stream_q) - assert q_result[0] is not None - record_stream(q_result[0], current_stream) - (q_quant, weights) = q_result[0] - return self.indexer_op(hidden_states, q_quant, k, weights) - - if project_inputs: - if indexer_weights is None: - indexer_weights = project_indexer_weights() - if compressed_kv_score is None: - compressed_kv_score = project_compressed_kv_score() - assert indexer_weights is not None assert compressed_kv_score is not None From 9ca5b6084716fdaefea0d2b18ff1829e2dc43dd3 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 15:52:33 +0000 Subject: [PATCH 14/21] remove the repro test Signed-off-by: vLLM Contributor --- tools/rocm_multistream_graph_repro.py | 130 -------------------------- 1 file changed, 130 deletions(-) delete mode 100644 tools/rocm_multistream_graph_repro.py diff --git a/tools/rocm_multistream_graph_repro.py b/tools/rocm_multistream_graph_repro.py deleted file mode 100644 index cf38ead34f96..000000000000 --- a/tools/rocm_multistream_graph_repro.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Reproduce a ROCm graph replay hang with side-stream allocations. - -This reproducer isolates the lower-level failure seen while enabling -DeepSeek-V4 CSA decode multi-stream on ROCm. The allocating mode captures a -side-stream GEMM that creates its output tensor inside the CUDA graph. On the -MI355X test system this reaches ``capture ok`` and then hangs at the first graph -replay with GPUs at 100% busy and 0% memory bandwidth. - -Run from the repo root: - - HIP_VISIBLE_DEVICES=0 timeout 90s \\ - .venv/bin/python tools/rocm_multistream_graph_repro.py --mode allocating - -The graph-safe variant preallocates all GEMM outputs and uses ``out=``: - - HIP_VISIBLE_DEVICES=0 timeout 90s \\ - .venv/bin/python tools/rocm_multistream_graph_repro.py --mode preallocated - -On ROCm, ``torch.cuda.Event(external=True)`` is not a workaround; it raises -``RuntimeError: External events are disallowed in rocm``. -""" - -from __future__ import annotations - -import argparse - -import torch - - -def _make_inputs() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) - b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16) - main_out = torch.empty((1024, 1024), device="cuda", dtype=torch.float16) - return a, b, main_out - - -def run_allocating(replays: int) -> None: - aux_stream = torch.cuda.Stream() - start_event = torch.cuda.Event() - done_event = torch.cuda.Event() - a, b, main_out = _make_inputs() - - def work() -> torch.Tensor: - current_stream = torch.cuda.current_stream() - start_event.record(current_stream) - with torch.cuda.stream(aux_stream): - aux_stream.wait_event(start_event) - aux_out = torch.mm(a, b) - done_event.record(aux_stream) - torch.mm(a, b, out=main_out) - current_stream.wait_event(done_event) - aux_out.record_stream(current_stream) - return main_out.float().mean() + aux_out.float().mean() - - for _ in range(3): - y = work() - torch.cuda.synchronize() - print("warmup ok", float(y)) - - graph = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - with torch.cuda.graph(graph): - y = work() - print("capture ok") - - for i in range(replays): - graph.replay() - torch.cuda.synchronize() - print("replay", i, float(y)) - - -def run_preallocated(replays: int) -> None: - aux_stream = torch.cuda.Stream() - start_event = torch.cuda.Event() - done_event = torch.cuda.Event() - a, b, main_out = _make_inputs() - aux_out = torch.empty_like(main_out) - - def work() -> torch.Tensor: - current_stream = torch.cuda.current_stream() - start_event.record(current_stream) - with torch.cuda.stream(aux_stream): - aux_stream.wait_event(start_event) - torch.mm(a, b, out=aux_out) - done_event.record(aux_stream) - torch.mm(a, b, out=main_out) - current_stream.wait_event(done_event) - aux_out.record_stream(current_stream) - return main_out.float().mean() + aux_out.float().mean() - - for _ in range(3): - y = work() - torch.cuda.synchronize() - print("warmup ok", float(y)) - - graph = torch.cuda.CUDAGraph() - torch.cuda.synchronize() - with torch.cuda.graph(graph): - y = work() - print("capture ok") - - for i in range(replays): - graph.replay() - torch.cuda.synchronize() - print("replay", i, float(y)) - - -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--mode", - choices=("allocating", "preallocated"), - required=True, - ) - parser.add_argument("--replays", type=int, default=20) - args = parser.parse_args() - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA/HIP device is required") - - if args.mode == "allocating": - run_allocating(args.replays) - else: - run_preallocated(args.replays) - - -if __name__ == "__main__": - main() From 00fa801ff1c50cbbce8cf2adda0186f609e5faea Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 15:56:51 +0000 Subject: [PATCH 15/21] align test cases with current multistream decoding settings Signed-off-by: vLLM Contributor --- .../test_deepseek_v4_rocm_multistream.py | 191 +++++++++++++----- 1 file changed, 146 insertions(+), 45 deletions(-) diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index d59b1f88b4c8..d2648c91f737 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + import pytest +import torch from vllm.models.deepseek_v4 import attention as dsv4_attention from vllm.models.deepseek_v4.amd import model as dsv4_model @@ -12,63 +15,161 @@ ) -def test_deepseek_v4_rocm_aux_streams_enabled(monkeypatch): - streams = [object(), object(), object(), object(), object()] - - def make_stream(**kwargs): - assert kwargs == {"priority": -1} - return streams.pop() - - monkeypatch.setattr(dsv4_model.torch.cuda, "Stream", make_stream) - +def test_deepseek_v4_rocm_aux_streams_enabled(): aux_streams = dsv4_model.make_deepseek_v4_aux_streams() assert aux_streams is not None assert len(aux_streams) == 5 + assert all(stream.priority == -1 for stream in aux_streams) -class _Metadata: - def __init__( - self, - num_decodes: int, - num_decode_tokens: int, - num_prefill_tokens: int = 0, - ): - self.num_decodes = num_decodes - self.num_decode_tokens = num_decode_tokens - self.num_prefill_tokens = num_prefill_tokens - - -def _rocm_ms_strategy_for_decodes( - monkeypatch, - num_decodes: int, +def _swa_metadata( + num_decode_tokens: int, num_prefill_tokens: int = 0, -) -> str: +) -> dsv4_attention.DeepseekSparseSWAMetadata: + return dsv4_attention.DeepseekSparseSWAMetadata( + block_table=torch.empty(0, dtype=torch.int32), + slot_mapping=torch.empty(0, dtype=torch.int64), + block_size=256, + num_decodes=num_decode_tokens, + num_decode_tokens=num_decode_tokens, + num_prefill_tokens=num_prefill_tokens, + ) + + +def _use_rocm_multistream( + cudagraph_runtime_mode: dsv4_attention.CUDAGraphMode, + metadata: dsv4_attention.DeepseekSparseSWAMetadata, +) -> bool: class _ForwardContext: - cudagraph_runtime_mode = dsv4_attention.CUDAGraphMode.PIECEWISE + pass class _Wrapper: aux_stream_list = [object(), object(), object(), object(), object()] - indexer = object() - monkeypatch.setattr(dsv4_attention, "DeepseekSparseSWAMetadata", _Metadata) - attn_metadata = { - "layer_0.swa": _Metadata(num_decodes, num_decodes, num_prefill_tokens) - } + forward_context = _ForwardContext() + forward_context.cudagraph_runtime_mode = cudagraph_runtime_mode + attn_metadata = {"layer_0.swa": metadata} wrapper_cls = dsv4_attention.DeepseekV4MultiHeadLatentAttentionWrapper - method = wrapper_cls._rocm_csa_ms_strategy_for_step - return method(_Wrapper(), _ForwardContext(), attn_metadata) - - -def test_deepseek_v4_rocm_multistream_all_decode_counts(monkeypatch): - assert _rocm_ms_strategy_for_decodes(monkeypatch, 4) == "overlap" - assert _rocm_ms_strategy_for_decodes(monkeypatch, 64) == "overlap" - assert _rocm_ms_strategy_for_decodes(monkeypatch, 128) == "overlap" - assert _rocm_ms_strategy_for_decodes(monkeypatch, 512) == "overlap" - - -def test_deepseek_v4_rocm_multistream_prefill_stays_off(monkeypatch): - strategy = _rocm_ms_strategy_for_decodes(monkeypatch, 4, num_prefill_tokens=1) + method = wrapper_cls._use_rocm_csa_multistream + return method(_Wrapper(), forward_context, attn_metadata) + + +def test_deepseek_v4_rocm_multistream_decode_policy(): + decode_metadata = _swa_metadata(num_decode_tokens=4) + + assert ( + _use_rocm_multistream(dsv4_attention.CUDAGraphMode.NONE, decode_metadata) + is True + ) + assert ( + _use_rocm_multistream(dsv4_attention.CUDAGraphMode.PIECEWISE, decode_metadata) + is True + ) + assert ( + _use_rocm_multistream(dsv4_attention.CUDAGraphMode.FULL, decode_metadata) + is False + ) + + mixed_metadata = _swa_metadata(num_decode_tokens=4, num_prefill_tokens=1) + assert ( + _use_rocm_multistream(dsv4_attention.CUDAGraphMode.PIECEWISE, mixed_metadata) + is False + ) + + +def test_deepseek_v4_rocm_post_rmsnorm_stream_mapping(monkeypatch): + calls = [] + streams = [object(), object(), object(), object(), object()] - assert strategy == "off" + def fake_execute_in_parallel( + default_fn, + aux_fns, + start_event, + done_events, + aux_streams, + enable=False, + ): + assert enable is True + assert aux_streams == [streams[2], streams[1]] + assert len(aux_fns) == 2 + assert aux_fns[0] is None + assert aux_fns[1] is not None + + q = default_fn() + compressor_result = aux_fns[1]() + return q, [None, compressor_result] + + monkeypatch.setattr(dsv4_attention, "execute_in_parallel", fake_execute_in_parallel) + + class _WqB: + def __call__(self, qr): + calls.append("wq_b") + return torch.empty((qr.shape[0], 3)) + + class _Indexer: + def __call__( + self, + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + rotary_emb, + use_aux_stream=True, + ): + calls.append(("indexer", use_aux_stream)) + return object() + + class _Compressor: + def __call__(self, kv_score, positions, rotary_emb): + calls.append("compressor") + return object() + + wrapper = SimpleNamespace( + wq_b=_WqB(), + indexer=_Indexer(), + compressor=_Compressor(), + n_local_heads=1, + head_dim=3, + indexer_rotary_emb=object(), + rotary_emb=object(), + ln_events=[object(), object(), object()], + ) + + def fake_kv_insert(q, kv, positions, attn_metadata): + calls.append("kv_insert") + return q + + def fail_project_compressor_kv_score(hidden_states, compressor): + raise AssertionError("kv_score should be reused in this test") + + wrapper._fused_qnorm_rope_kv_insert = fake_kv_insert + wrapper._project_compressor_kv_score = fail_project_compressor_kv_score + + hidden_states = torch.empty((2, 3)) + qr = torch.empty((2, 3)) + kv = torch.empty((2, 3)) + positions = torch.empty(2, dtype=torch.int64) + kv_score = torch.empty((2, 3)) + indexer_kv_score = torch.empty((2, 3)) + indexer_weights = torch.empty((2, 3)) + + method = dsv4_attention.DeepseekV4MultiHeadLatentAttentionWrapper + q = method._post_rmsnorm_prepare( + wrapper, + hidden_states, + qr, + kv, + kv_score, + indexer_kv_score, + indexer_weights, + positions, + None, + streams, + True, + ) + + assert q.shape == (2, 1, 3) + assert calls == ["wq_b", "kv_insert", "compressor", ("indexer", False)] From 0c6f5073ae69565452c79339b35f1cd22d14ec25 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 16:01:50 +0000 Subject: [PATCH 16/21] fix test to AMD only path Signed-off-by: vLLM Contributor --- ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 28 ++----------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index c950af03f04e..fe0306d42208 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -136,22 +136,6 @@ def _call_fused( ) -def _split_ops_available() -> bool: - return hasattr(torch.ops._C, "deepseek_v4_qnorm_rope") and hasattr( - torch.ops._C, "deepseek_v4_kv_rope_quant_insert" - ) - - -def _call_q_split(q, positions, cos_sin_cache, eps): - torch.ops._C.deepseek_v4_qnorm_rope(q, positions, cos_sin_cache, eps) - - -def _call_kv_split(kv, k_cache, slot_mapping, positions, cos_sin_cache, bs): - torch.ops._C.deepseek_v4_kv_rope_quant_insert( - kv, k_cache, slot_mapping, positions, cos_sin_cache, bs - ) - - # ── Test 1: Q path numerical parity ────────────────────────────────────────── @@ -433,16 +417,10 @@ def test_combined_q_and_kv( torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) -@pytest.mark.skipif( - not _split_ops_available(), - reason="split DeepseekV4 q/kv ops not built in", -) @pytest.mark.parametrize("num_tokens", [1, 17, 64]) @pytest.mark.parametrize("n_heads", [8, 64]) @pytest.mark.parametrize("block_size", [16, 64]) -def test_split_q_and_kv_match_combined( - num_tokens: int, n_heads: int, block_size: int -): +def test_split_q_and_kv_match_combined(num_tokens: int, n_heads: int, block_size: int): torch.manual_seed(4) device = "cuda" dtype = torch.bfloat16 @@ -474,8 +452,8 @@ def test_split_q_and_kv_match_combined( q_split = q.clone() k_cache_split = torch.zeros_like(k_cache_fused) - _call_q_split(q_split, positions, cos_sin_cache, eps) - _call_kv_split( + torch.ops._C.deepseek_v4_qnorm_rope(q_split, positions, cos_sin_cache, eps) + torch.ops._C.deepseek_v4_kv_rope_quant_insert( kv, k_cache_split, slot_mapping, From a0b19803c40b56a626200ca47e9c280190ce8221 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 16:03:39 +0000 Subject: [PATCH 17/21] inline stream creation for ROCm path Signed-off-by: vLLM Contributor --- tests/models/test_deepseek_v4_rocm_multistream.py | 9 --------- vllm/models/deepseek_v4/amd/model.py | 10 +--------- 2 files changed, 1 insertion(+), 18 deletions(-) diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index d2648c91f737..10026c78d0be 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -7,7 +7,6 @@ import torch from vllm.models.deepseek_v4 import attention as dsv4_attention -from vllm.models.deepseek_v4.amd import model as dsv4_model from vllm.platforms import current_platform pytestmark = pytest.mark.skipif( @@ -15,14 +14,6 @@ ) -def test_deepseek_v4_rocm_aux_streams_enabled(): - aux_streams = dsv4_model.make_deepseek_v4_aux_streams() - - assert aux_streams is not None - assert len(aux_streams) == 5 - assert all(stream.priority == -1 for stream in aux_streams) - - def _swa_metadata( num_decode_tokens: int, num_prefill_tokens: int = 0, diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index fa0a03edec4e..9f247418bec2 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -56,14 +56,6 @@ from vllm.sequence import IntermediateTensors -def make_deepseek_v4_aux_streams() -> list[torch.cuda.Stream] | None: - if not current_platform.is_rocm(): - return None - # Reserve the five-stream CSA decode topology. The measured ROCm default - # overlaps the compressor branch on aux[1] and keeps indexer fan-out off. - return [torch.cuda.Stream(priority=-1) for _ in range(5)] - - class DeepseekV4MLP(nn.Module): def __init__( self, @@ -624,7 +616,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - aux_stream_list = make_deepseek_v4_aux_streams() + aux_stream_list = [torch.cuda.Stream(priority=-1) for _ in range(5)] self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. From 42a865c82bb32e6567613caf83215e939e39eebd Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 16:34:16 +0000 Subject: [PATCH 18/21] clean unused kernel Signed-off-by: vLLM Contributor --- csrc/ops.h | 9 ---- csrc/torch_bindings.cpp | 12 ----- ..._fused_deepseek_v4_qnorm_rope_kv_insert.py | 49 ------------------- 3 files changed, 70 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 66bcad2a4000..3c6b2b7b9bc2 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -76,15 +76,6 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( torch::Tensor const& cos_sin_cache, int64_t q_head_padded, double eps, int64_t cache_block_size); -void deepseek_v4_qnorm_rope(torch::Tensor& q, - torch::Tensor const& position_ids, - torch::Tensor const& cos_sin_cache, double eps); - -void deepseek_v4_kv_rope_quant_insert( - torch::Tensor const& kv, torch::Tensor& k_cache, - torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, - torch::Tensor const& cos_sin_cache, int64_t cache_block_size); - void apply_repetition_penalties_(torch::Tensor& logits, const torch::Tensor& prompt_mask, const torch::Tensor& output_mask, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 14ab2221458f..487ff8df9dd0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -105,18 +105,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", torch::kCUDA, &fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert); - ops.def("deepseek_v4_qnorm_rope(" - "Tensor! q, Tensor position_ids, Tensor cos_sin_cache, " - "float eps) -> ()"); - ops.impl("deepseek_v4_qnorm_rope", torch::kCUDA, &deepseek_v4_qnorm_rope); - - ops.def("deepseek_v4_kv_rope_quant_insert(" - "Tensor kv, Tensor! k_cache, Tensor slot_mapping, " - "Tensor position_ids, Tensor cos_sin_cache, " - "int cache_block_size) -> ()"); - ops.impl("deepseek_v4_kv_rope_quant_insert", torch::kCUDA, - &deepseek_v4_kv_rope_quant_insert); - // Apply repetition penalties to logits in-place ops.def( "apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, " diff --git a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py index fe0306d42208..a49ea498e5e0 100644 --- a/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py +++ b/tests/kernels/test_fused_deepseek_v4_qnorm_rope_kv_insert.py @@ -415,52 +415,3 @@ def test_combined_q_and_kv( "padded head slots must be exact zero" ) torch.testing.assert_close(k_cache_fused, k_cache_ref, rtol=0, atol=0) - - -@pytest.mark.parametrize("num_tokens", [1, 17, 64]) -@pytest.mark.parametrize("n_heads", [8, 64]) -@pytest.mark.parametrize("block_size", [16, 64]) -def test_split_q_and_kv_match_combined(num_tokens: int, n_heads: int, block_size: int): - torch.manual_seed(4) - device = "cuda" - dtype = torch.bfloat16 - eps = 1e-6 - max_pos = 4096 - - q = torch.randn(num_tokens, n_heads, HEAD_DIM, dtype=dtype, device=device) - kv = torch.randn(num_tokens, HEAD_DIM, dtype=dtype, device=device) - positions = torch.arange(num_tokens, dtype=torch.int64, device=device) - cos_sin_cache = make_cos_sin_cache(max_pos, ROPE_DIM, torch.float32, device) - - num_blocks = (num_tokens + block_size - 1) // block_size + 1 - slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device) - - k_cache_fused = torch.zeros( - num_blocks, block_size * HEAD_BYTES, dtype=torch.uint8, device=device - ) - q_fused = _call_fused( - q, - n_heads, - kv, - k_cache_fused, - slot_mapping, - positions, - cos_sin_cache, - eps, - block_size, - ) - - q_split = q.clone() - k_cache_split = torch.zeros_like(k_cache_fused) - torch.ops._C.deepseek_v4_qnorm_rope(q_split, positions, cos_sin_cache, eps) - torch.ops._C.deepseek_v4_kv_rope_quant_insert( - kv, - k_cache_split, - slot_mapping, - positions, - cos_sin_cache, - block_size, - ) - - torch.testing.assert_close(q_split, q_fused, rtol=0, atol=0) - torch.testing.assert_close(k_cache_split, k_cache_fused, rtol=0, atol=0) From ad89bc84c788b25295b5d728e8354e23a2108890 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 16:35:24 +0000 Subject: [PATCH 19/21] remove unused kernel Signed-off-by: vLLM Contributor --- ...deepseek_v4_qnorm_rope_kv_insert_kernel.cu | 226 ------------------ 1 file changed, 226 deletions(-) diff --git a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu index 4a0a0b210612..e4d432cac97f 100644 --- a/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu +++ b/csrc/fused_deepseek_v4_qnorm_rope_kv_insert_kernel.cu @@ -644,155 +644,6 @@ void launchFusedDeepseekV4QNormRopeKVRopeQuantInsert( #undef DISPATCH } -template -__global__ void deepseekV4QNormRopeKernel( - scalar_t_in* __restrict__ q_inout, // [N, H, 512] bf16, in place - int64_t const* __restrict__ position_ids, // [N] i64 - float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 - float const eps, int const num_tokens, int const num_heads_q) { -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) - if constexpr (std::is_same_v) { - return; - } else { -#endif - int const warpsPerBlock = blockDim.x / 32; - int const warpId = threadIdx.x / 32; - int const laneId = threadIdx.x % 32; - int const globalWarpIdx = blockIdx.x * warpsPerBlock + warpId; - - int const tokenIdx = globalWarpIdx / num_heads_q; - int const headIdx = globalWarpIdx % num_heads_q; - if (tokenIdx >= num_tokens) return; - - int const dim_base = laneId * kElemsPerLane; - scalar_t_in const* src_ptr = - q_inout + - (static_cast(tokenIdx) * num_heads_q + headIdx) * kHeadDim + - dim_base; - uint4 const v0 = *reinterpret_cast(src_ptr); - uint4 const v1 = *reinterpret_cast(src_ptr + 8); - - processDeepseekV4Slot( - v0, v1, tokenIdx, headIdx, dim_base, laneId, num_heads_q, eps, q_inout, - nullptr, nullptr, position_ids, cos_sin_cache, 0, 0); -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) - } -#endif -} - -template -__global__ void deepseekV4KVRopeQuantInsertKernel( - scalar_t_in const* __restrict__ kv_in, // [N, 512] bf16 - uint8_t* __restrict__ k_cache, // [num_blocks, block_stride] - int64_t const* __restrict__ slot_mapping, // [num_tokens_insert] i64 - int64_t const* __restrict__ position_ids, // [N] i64 - float const* __restrict__ cos_sin_cache, // [max_pos, 64] fp32 - int const num_tokens_full, int const num_tokens_insert, - int const cache_block_size, int const kv_block_stride) { -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) - if constexpr (std::is_same_v) { - return; - } else { -#endif - int const warpsPerBlock = blockDim.x / 32; - int const warpId = threadIdx.x / 32; - int const laneId = threadIdx.x % 32; - int const tokenIdx = blockIdx.x * warpsPerBlock + warpId; - if (tokenIdx >= num_tokens_insert) return; - if (tokenIdx >= num_tokens_full) return; - - int const dim_base = laneId * kElemsPerLane; - scalar_t_in const* src_ptr = - kv_in + static_cast(tokenIdx) * kHeadDim + dim_base; - uint4 const v0 = *reinterpret_cast(src_ptr); - uint4 const v1 = *reinterpret_cast(src_ptr + 8); - - // num_heads_q=0 and slotIdx=0 select the KV branch in processDeepseekV4Slot. - processDeepseekV4Slot( - v0, v1, tokenIdx, 0, dim_base, laneId, 0, 0.0f, nullptr, k_cache, - slot_mapping, position_ids, cos_sin_cache, cache_block_size, - kv_block_stride); -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM) - } -#endif -} - -template -void launchDeepseekV4QNormRopeTemplated( - scalar_t_in* q_inout, int64_t const* position_ids, - float const* cos_sin_cache, float const eps, int const num_tokens, - int const num_heads_q, cudaStream_t stream) { - constexpr int kBlockSize = 256; - constexpr int kWarpsPerBlock = kBlockSize / 32; - int64_t const total_warps = - static_cast(num_tokens) * num_heads_q; - int const grid = - static_cast((total_warps + kWarpsPerBlock - 1) / kWarpsPerBlock); -#ifndef USE_ROCM - static int const sm_version = getSMVersion(); - TORCH_CHECK(sm_version >= 80, - "deepseek_v4_qnorm_rope requires sm_80+ (Ampere or newer); got " - "sm_", - sm_version); -#endif - deepseekV4QNormRopeKernel - <<>>( - q_inout, position_ids, cos_sin_cache, eps, num_tokens, num_heads_q); -} - -template -void launchDeepseekV4QNormRope(scalar_t_in* q_inout, - int64_t const* position_ids, - float const* cos_sin_cache, float const eps, - int const num_tokens, int const num_heads_q, - cudaStream_t stream) { -#define DISPATCH(N) \ - case N: \ - launchDeepseekV4QNormRopeTemplated( \ - q_inout, position_ids, cos_sin_cache, eps, num_tokens, num_heads_q, \ - stream); \ - return; - - switch (num_heads_q) { - DISPATCH(8) - DISPATCH(16) - DISPATCH(32) - DISPATCH(64) - DISPATCH(128) - default: - TORCH_CHECK(false, - "deepseek_v4_qnorm_rope: unsupported num_heads_q=", - num_heads_q, - " (compiled instantiations: 8, 16, 32, 64, 128)."); - } -#undef DISPATCH -} - -template -void launchDeepseekV4KVRopeQuantInsert( - scalar_t_in const* kv_in, uint8_t* k_cache, - int64_t const* slot_mapping, int64_t const* position_ids, - float const* cos_sin_cache, int const num_tokens_full, - int const num_tokens_insert, int const cache_block_size, - int const kv_block_stride, cudaStream_t stream) { - constexpr int kBlockSize = 256; - constexpr int kWarpsPerBlock = kBlockSize / 32; - int const grid = - static_cast((num_tokens_insert + kWarpsPerBlock - 1) / kWarpsPerBlock); -#ifndef USE_ROCM - static int const sm_version = getSMVersion(); - TORCH_CHECK( - sm_version >= 80, - "deepseek_v4_kv_rope_quant_insert requires sm_80+ (Ampere or newer); got " - "sm_", - sm_version); -#endif - deepseekV4KVRopeQuantInsertKernel - <<>>( - kv_in, k_cache, slot_mapping, position_ids, cos_sin_cache, - num_tokens_full, num_tokens_insert, cache_block_size, kv_block_stride); -} - } // namespace deepseek_v4_fused_ops } // namespace vllm @@ -870,80 +721,3 @@ torch::Tensor fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( }); return q_out; } - -void deepseek_v4_qnorm_rope(torch::Tensor& q, - torch::Tensor const& position_ids, - torch::Tensor const& cos_sin_cache, double eps) { - TORCH_CHECK(q.is_cuda() && q.is_contiguous(), "q must be contiguous CUDA"); - TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, - "position_ids must be int64 CUDA"); - TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); - TORCH_CHECK(q.dim() == 3 && q.size(2) == 512, "q shape [N, H, 512]"); - TORCH_CHECK(static_cast(position_ids.size(0)) == q.size(0), - "q/position_ids row counts must match"); - TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, - "cos_sin_cache shape [max_pos, 64]"); - TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, - "cos_sin_cache must be float32"); - - int const num_tokens = static_cast(q.size(0)); - int const num_heads_q = static_cast(q.size(1)); - if (num_tokens == 0 || num_heads_q == 0) return; - - at::cuda::OptionalCUDAGuard device_guard(device_of(q)); - auto stream = at::cuda::getCurrentCUDAStream(); - - VLLM_DISPATCH_HALF_TYPES(q.scalar_type(), "deepseek_v4_qnorm_rope", [&] { - using qkv_scalar_t = scalar_t; - vllm::deepseek_v4_fused_ops::launchDeepseekV4QNormRope( - reinterpret_cast(q.data_ptr()), - reinterpret_cast(position_ids.data_ptr()), - cos_sin_cache.data_ptr(), static_cast(eps), num_tokens, - num_heads_q, stream); - }); -} - -void deepseek_v4_kv_rope_quant_insert( - torch::Tensor const& kv, torch::Tensor& k_cache, - torch::Tensor const& slot_mapping, torch::Tensor const& position_ids, - torch::Tensor const& cos_sin_cache, int64_t cache_block_size) { - TORCH_CHECK(kv.is_cuda() && kv.is_contiguous(), "kv must be contiguous CUDA"); - TORCH_CHECK(k_cache.is_cuda(), "k_cache must be CUDA"); - TORCH_CHECK(slot_mapping.is_cuda() && slot_mapping.dtype() == torch::kInt64, - "slot_mapping must be int64 CUDA"); - TORCH_CHECK(position_ids.is_cuda() && position_ids.dtype() == torch::kInt64, - "position_ids must be int64 CUDA"); - TORCH_CHECK(cos_sin_cache.is_cuda(), "cos_sin_cache must be CUDA"); - TORCH_CHECK(kv.dim() == 2 && kv.size(1) == 512, "kv shape [N, 512]"); - TORCH_CHECK(k_cache.dtype() == torch::kUInt8, "k_cache must be uint8"); - TORCH_CHECK(static_cast(position_ids.size(0)) == kv.size(0), - "kv/position_ids row counts must match"); - TORCH_CHECK(slot_mapping.size(0) <= kv.size(0), - "slot_mapping must not exceed kv row count"); - TORCH_CHECK(cos_sin_cache.dim() == 2 && cos_sin_cache.size(1) == 64, - "cos_sin_cache shape [max_pos, 64]"); - TORCH_CHECK(cos_sin_cache.dtype() == torch::kFloat32, - "cos_sin_cache must be float32"); - - int const num_tokens_full = static_cast(kv.size(0)); - int const num_tokens_insert = static_cast(slot_mapping.size(0)); - if (num_tokens_full == 0 || num_tokens_insert == 0) return; - int const cache_block_size_i = static_cast(cache_block_size); - int const kv_block_stride = static_cast(k_cache.stride(0)); - - at::cuda::OptionalCUDAGuard device_guard(device_of(kv)); - auto stream = at::cuda::getCurrentCUDAStream(); - - VLLM_DISPATCH_HALF_TYPES( - kv.scalar_type(), "deepseek_v4_kv_rope_quant_insert", [&] { - using qkv_scalar_t = scalar_t; - vllm::deepseek_v4_fused_ops:: - launchDeepseekV4KVRopeQuantInsert( - reinterpret_cast(kv.data_ptr()), - reinterpret_cast(k_cache.data_ptr()), - reinterpret_cast(slot_mapping.data_ptr()), - reinterpret_cast(position_ids.data_ptr()), - cos_sin_cache.data_ptr(), num_tokens_full, - num_tokens_insert, cache_block_size_i, kv_block_stride, stream); - }); -} From 3598253329d0452170be5ade20cbf4a384f7ab05 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 16:37:17 +0000 Subject: [PATCH 20/21] guard CUDA path intact Signed-off-by: vLLM Contributor --- vllm/models/deepseek_v4/attention.py | 29 ++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index ae1b9d1a980a..d5e2f4c1b578 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -594,16 +594,25 @@ def wq_b_kv_insert() -> torch.Tensor: # NVIDIA keeps the 3-way split. ROCm's measured default keeps # indexer work on default and overlaps only the compressor branch. - run_indexer: Callable[[], Any] = lambda: indexer( - hidden_states, - qr, - indexer_kv_score, - indexer_weights, - positions, - self.indexer_rotary_emb, - use_aux_stream=post_aux_streams is not None - and (not current_platform.is_rocm()), - ) + def run_indexer() -> Any: + if current_platform.is_rocm(): + return indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + use_aux_stream=False, + ) + return indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + ) def run_compressor() -> Any: local_kv_score = kv_score From 9484e02f16122e605f52853c16774f0055e9ddb9 Mon Sep 17 00:00:00 2001 From: vLLM Contributor Date: Wed, 27 May 2026 17:19:53 +0000 Subject: [PATCH 21/21] trim to 2-stream strategy Signed-off-by: vLLM Contributor --- .../test_deepseek_v4_rocm_multistream.py | 29 +++++---- vllm/models/deepseek_v4/amd/model.py | 2 +- vllm/models/deepseek_v4/attention.py | 63 +++++++------------ 3 files changed, 36 insertions(+), 58 deletions(-) diff --git a/tests/models/test_deepseek_v4_rocm_multistream.py b/tests/models/test_deepseek_v4_rocm_multistream.py index 10026c78d0be..23db2d61eed3 100644 --- a/tests/models/test_deepseek_v4_rocm_multistream.py +++ b/tests/models/test_deepseek_v4_rocm_multistream.py @@ -36,7 +36,7 @@ class _ForwardContext: pass class _Wrapper: - aux_stream_list = [object(), object(), object(), object(), object()] + aux_stream_list = [object()] forward_context = _ForwardContext() forward_context.cudagraph_runtime_mode = cudagraph_runtime_mode @@ -72,27 +72,26 @@ def test_deepseek_v4_rocm_multistream_decode_policy(): def test_deepseek_v4_rocm_post_rmsnorm_stream_mapping(monkeypatch): calls = [] - streams = [object(), object(), object(), object(), object()] + streams = [object()] - def fake_execute_in_parallel( + def fake_maybe_execute_in_parallel( default_fn, - aux_fns, + aux_fn, start_event, - done_events, - aux_streams, - enable=False, + done_event, + aux_stream=None, ): - assert enable is True - assert aux_streams == [streams[2], streams[1]] - assert len(aux_fns) == 2 - assert aux_fns[0] is None - assert aux_fns[1] is not None + assert aux_stream is streams[0] q = default_fn() - compressor_result = aux_fns[1]() - return q, [None, compressor_result] + compressor_result = aux_fn() + return q, compressor_result - monkeypatch.setattr(dsv4_attention, "execute_in_parallel", fake_execute_in_parallel) + monkeypatch.setattr( + dsv4_attention, + "maybe_execute_in_parallel", + fake_maybe_execute_in_parallel, + ) class _WqB: def __call__(self, qr): diff --git a/vllm/models/deepseek_v4/amd/model.py b/vllm/models/deepseek_v4/amd/model.py index 9f247418bec2..8eca57dfeec6 100644 --- a/vllm/models/deepseek_v4/amd/model.py +++ b/vllm/models/deepseek_v4/amd/model.py @@ -616,7 +616,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps - aux_stream_list = [torch.cuda.Stream(priority=-1) for _ in range(5)] + aux_stream_list = [torch.cuda.Stream(priority=-1)] self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. diff --git a/vllm/models/deepseek_v4/attention.py b/vllm/models/deepseek_v4/attention.py index d5e2f4c1b578..f1479b8fb35d 100644 --- a/vllm/models/deepseek_v4/attention.py +++ b/vllm/models/deepseek_v4/attention.py @@ -401,7 +401,7 @@ def _use_rocm_csa_multistream( return False if self.aux_stream_list is None: return False - if len(self.aux_stream_list) < 5: + if not self.aux_stream_list: return False if not _is_decode_only_deepseek_v4_step(attn_metadata): @@ -568,16 +568,7 @@ def _post_rmsnorm_prepare( and post_aux_streams is not None and rocm_csa_multistream ) - if ( - rocm_ms_active - and post_aux_streams is not None - and len(post_aux_streams) >= 5 - ): - # ROCm keeps q+kv fused on default and only overlaps the - # compressor on aux[1]. aux[2] stays reserved for the top-level - # indexer branch, which is intentionally not launched by default. - outer_post_aux_streams = [post_aux_streams[2], post_aux_streams[1]] - elif post_aux_streams is not None: + if post_aux_streams is not None and not rocm_ms_active: outer_post_aux_streams = [post_aux_streams[0], post_aux_streams[1]] else: outer_post_aux_streams = None @@ -625,39 +616,27 @@ def run_compressor() -> Any: indexer_fn: Callable[[], Any] | None = run_indexer compressor_fn: Callable[[], Any] | None = run_compressor if rocm_ms_active: - indexer_fn = None - - q, (indexer_result, compressor_result) = execute_in_parallel( - wq_b_kv_insert, - [indexer_fn, compressor_fn], - self.ln_events[0], - [self.ln_events[1], self.ln_events[2]], - outer_post_aux_streams, - enable=post_aux_streams is not None, - ) - if rocm_ms_active: - if indexer_result is None and indexer_fn is None: - run_indexer() - if compressor_result is None and compressor_fn is None: - run_compressor() - elif self.compressor is not None: - # wq_b + kv_insert on default, compressor on aux. - aux_stream = ( - ( - ( - post_aux_streams[1] - if len(post_aux_streams) >= 5 - else post_aux_streams[0] - ) - if current_platform.is_rocm() and post_aux_streams is not None - else outer_post_aux_streams[0] + assert post_aux_streams is not None + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert, + run_compressor, + self.ln_events[0], + self.ln_events[1], + post_aux_streams[0], ) - if ( - (current_platform.is_rocm() and post_aux_streams is not None) - or outer_post_aux_streams is not None + run_indexer() + else: + q, _ = execute_in_parallel( + wq_b_kv_insert, + [indexer_fn, compressor_fn], + self.ln_events[0], + [self.ln_events[1], self.ln_events[2]], + outer_post_aux_streams, + enable=post_aux_streams is not None, ) - else None - ) + elif self.compressor is not None: + # wq_b + kv_insert on default, compressor on aux. + aux_stream = post_aux_streams[0] if post_aux_streams is not None else None compressor = self.compressor def wq_b_kv_insert() -> torch.Tensor: