PERF (DeepSeek): Multi-stream overlap of indexer wk+weights_proj with QKV-A for DeepSeek V3.2 NSA #39943

Open

alexm-redhat wants to merge 1 commit into main from indexer_multistream

Conversation

alexm-redhat (Collaborator) commented Apr 15, 2026

Summary

Overlaps the NSA (Native Sparse Attention) indexer's fused wk_weights_proj GEMM and k_norm with the QKV-A GEMM on a separate CUDA stream, hiding the indexer computation behind main-stream work in each transformer block for DeepSeek V3.2 FP8 inference.

Result: ~3% decode throughput improvement (68.4 → 70.5 tok/s) on 8×H200 GPUs at batch size 1, decode-only, with full correctness validated via lm_eval.

Motivation

Profiling shows that vLLM runs all NSA indexer operations sequentially on the main CUDA stream. The fused wk_weights_proj GEMM (which combines wk and weights_proj into a single GEMM)
plus k_norm depend only on hidden_states (the layer input), not on any intermediate result from QKV-A. This means they can start executing at the very beginning of the layer's
forward pass, concurrent with QKV-A on the main stream.

SGLang uses a similar multi-stream design to hide indexer work. However, a direct port to vLLM is not possible because vLLM uses torch.compile with fullgraph=True (PIECEWISE mode),
which cannot represent CUDA stream operations — any stream op must be wrapped in a custom op, but this removes the wrapped operations from the compiler's optimization scope.

By wrapping only the data-independent wk_weights_proj + k_norm in minimal custom ops, all other GEMMs (wq_b, q_b_proj) stay inside torch.compile's optimization scope.

Design

Two minimal custom ops (mla_wk_fork / mla_wk_join) implement a fork-join pattern extending the existing MoE shared expert streaming approach (default_moe_runner.py):

```
forward()
  → mla_wk_fork(hidden_states)       # Custom op: launch fused wk_weights_proj + k_norm on alt_stream
  → fused_qkv_a_proj(hidden_states)  # COMPILED on main stream (concurrent with fork)
  → q_a_layernorm(q_c)               # COMPILED
  → q_b_proj(q_c)                    # COMPILED
  → kv_preprocess + RoPE             # COMPILED
  → mla_wk_join()                    # Custom op: wait for alt_stream, return [k | weights]
  → Indexer.forward(precomputed_k=k,              # COMPILED: skips wk_weights_proj + k_norm,
                    precomputed_weights=weights)  #           runs wq_b inline

Alt stream (compiled):  wk_weights_proj fused GEMM + k_norm
Main stream (compiled): QKV-A + Q-A LN + Q-B proj + kv preprocess + RoPE
Alt < Main → fork is completely hidden
```
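
For reference, below is a minimal PyTorch sketch of the fork/join pattern above, assuming CUDA is available. The class name `WkForkJoin` and the `k_dim` / `n_heads` split sizes are illustrative only; in the PR the fork and join are registered as vLLM custom ops (`mla_wk_fork` / `mla_wk_join`) so the surrounding graph stays inside torch.compile, whereas plain method calls like these would not be representable in the compiled graph.

```python
import torch

class WkForkJoin:
    """Illustrative fork/join helper: run the data-independent
    wk_weights_proj + k_norm on a side stream so it overlaps with the
    main-stream QKV-A work, then synchronize at the join point."""

    def __init__(self, wk_weights_proj, k_norm, k_dim: int, n_heads: int):
        # alt_stream is None on non-CUDA platforms, making fork/join no-ops.
        self.alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.wk_weights_proj = wk_weights_proj
        self.k_norm = k_norm
        self.k_dim = k_dim        # hypothetical split size for the k output
        self.n_heads = n_heads    # hypothetical split size for the weights output
        self._result = None
        self._forked = False

    def fork(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.alt_stream is None:
            return hidden_states
        # The side stream must wait until hidden_states is ready on the main stream.
        self.alt_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.alt_stream):
            kw = self.wk_weights_proj(hidden_states)  # fused wk + weights_proj GEMM
            k, weights = kw.split([self.k_dim, self.n_heads], dim=-1)
            self._result = (self.k_norm(k), weights)
        self._forked = True
        return hidden_states  # main stream continues with QKV-A immediately

    def join(self):
        if not self._forked:
            return None
        # Main stream waits for the side-stream work before the indexer consumes k/weights.
        torch.cuda.current_stream().wait_stream(self.alt_stream)
        self._forked = False
        return self._result
```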

Novel extension over MoE pattern: The fork operations are wrapped in a separately torch.compile'd _WkForkModule, so they benefit from Inductor optimizations (FP8 quant+GEMM fusion,
kernel selection) even on the alt stream. The MoE pattern runs shared experts eagerly inside custom ops; this PR adds a second compilation unit with graceful fallback to eager if
compilation fails.
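
For the fallback mechanics, a rough sketch (not the PR's actual code) of wrapping a module in its own torch.compile unit with a permanent eager fallback if compilation fails on first use:

```python
import torch

def compile_with_eager_fallback(module: torch.nn.Module):
    """Compile a module as a separate compilation unit; if compilation fails
    at the first call, permanently fall back to running the module eagerly."""
    compiled = torch.compile(module, fullgraph=True)
    state = {"use_compiled": True}

    def run(*args, **kwargs):
        if state["use_compiled"]:
            try:
                return compiled(*args, **kwargs)
            except Exception:
                state["use_compiled"] = False  # graceful fallback to eager
        return module(*args, **kwargs)

    return run

# Usage sketch: wrap the fork module once at init time, then invoke the
# returned callable from inside the mla_wk_fork custom op on the alt stream.
# fork_fn = compile_with_eager_fallback(wk_fork_module)
```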

Key design decisions

| Aspect | Why |
| --- | --- |
| Only fused wk_weights_proj + k_norm on alt stream | These depend only on hidden_states, so they can start at layer entry, concurrent with QKV-A |
| Minimal custom op surface (fork + join only) | Wrapping more ops removes them from torch.compile scope, losing fusion benefits |
| Non-splitting ops | Splitting ops cause CUDA graph breaks |
| Separately compiled fork module | Eager fork ops lose FP8 fusion and kernel selection |
| needs_fixed_stride_order tag | Prevents Inductor stride conversion overhead |
| precomputed_k / precomputed_weights params | Indexer receives pre-computed results, skipping its own wk_weights_proj + k_norm |

Guards and fallbacks

  • Token threshold: Multi-stream disabled for batches > 1024 tokens (avoids SM contention)
  • Environment variable: VLLM_DISABLE_INDEXER_STREAM=1 disables the optimization
  • Non-V3.2 / non-CUDA: alt_stream is None, fork/join become no-ops
  • Compilation failure: Graceful fallback to eager _WkForkModule
  • Fork/join symmetry: _wk_forked flag ensures join only waits when fork actually launched on alt stream
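
For concreteness, the sketch below shows how these guards might combine. It is illustrative only: the real checks live in mla.py and read the environment variable through vllm.envs rather than os.environ; the 1024 threshold and the variable name come from the PR.

```python
import os

# From the PR: batches above this many tokens skip the overlap to avoid SM contention.
_INDEXER_STREAM_TOKEN_THRESHOLD = 1024

def use_indexer_alt_stream(num_tokens: int, alt_stream) -> bool:
    if alt_stream is None:
        # Non-V3.2 models and non-CUDA platforms never create an alt stream.
        return False
    if os.environ.get("VLLM_DISABLE_INDEXER_STREAM") == "1":
        # Explicit opt-out via environment variable.
        return False
    # Large batches run the indexer on the main stream to avoid SM contention.
    return num_tokens <= _INDEXER_STREAM_TOKEN_THRESHOLD
```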

Additional fix

Minor fix in cuda_graph.py: changed gc.collect and torch.accelerator.empty_cache patch lambdas to accept *args, **kwargs for compatibility when these functions are called with
arguments during CUDA graph capture.
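
Roughly, the change amounts to the following (a simplified sketch of the patch lambdas, assuming a PyTorch version that provides torch.accelerator.empty_cache; not the exact diff):

```python
from unittest.mock import patch

# Before: zero-argument lambdas break if gc.collect()/empty_cache() are called with arguments
# patch("gc.collect", lambda: None)

# After: accept and ignore any arguments passed during CUDA graph capture
gc_patch = patch("gc.collect", lambda *args, **kwargs: None)
cache_patch = patch("torch.accelerator.empty_cache", lambda *args, **kwargs: None)
```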

Files changed

| File | Changes | Purpose |
| --- | --- | --- |
| vllm/compilation/cuda_graph.py | +4/-2 | Fix lambda signatures in GC/cache patches for CUDA graph capture |
| vllm/envs.py | +5 | VLLM_DISABLE_INDEXER_STREAM env var |
| vllm/model_executor/layers/mla.py | +308/-10 | _WkForkModule, mla_wk_fork/mla_wk_join custom ops, fork-join logic in forward() |
| vllm/model_executor/models/deepseek_v2.py | +30/-4 | Alt stream creation, Indexer.forward() precomputed_k/precomputed_weights params |

Performance

| Metric | Value |
| --- | --- |
| Baseline decode throughput | 68.4 tok/s |
| Optimized decode throughput | 70.5 tok/s |
| Improvement | ~3% |
| Model | DeepSeek V3.2 FP8 |
| Hardware | 8×H200 GPUs |
| Batch size | 1 (decode-only) |
| Correctness | Full lm_eval validation passed |

Test plan

  • lm_eval correctness validation passed
  • CUDA graph capture/replay: no topology mismatch between warmup and capture
  • Regression: VLLM_DISABLE_INDEXER_STREAM=1 matches baseline performance
  • Edge cases: non-V3.2 models, non-CUDA platforms, large batch fallback
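
As a rough illustration of the regression toggle, the sketch below assumes vLLM's offline LLM API and the deepseek-ai/DeepSeek-V3.2-Exp checkpoint; it is not the exact setup used for the numbers above.

```python
import os

# Set before engine construction so worker processes inherit it; "1" disables
# the overlap and should reproduce baseline performance.
os.environ["VLLM_DISABLE_INDEXER_STREAM"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V3.2-Exp", tensor_parallel_size=8)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)
```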

alexm-redhat self-assigned this Apr 15, 2026
alexm-redhat changed the title from "perf(deepseek): Multi-stream overlap of indexer wk+weights_proj with …" to "PERF (DeepSeek): Multi-stream overlap of indexer wk+weights_proj with …" Apr 15, 2026
alexm-redhat changed the title from "PERF (DeepSeek): Multi-stream overlap of indexer wk+weights_proj with …" to "PERF (DeepSeek): Multi-stream overlap of indexer wk+weights_proj with QKV-A for DeepSeek V3.2 NSA" Apr 15, 2026
gemini-code-assist (Bot) left a comment


Code Review

This pull request implements multi-stream parallelism for the MLA indexer in DeepSeek-V2 models to overlap projection operations with the main forward pass. It introduces custom fork and join operations, utilizing a separately compiled module for side-stream execution. Feedback focuses on optimizing memory by clearing cached results after the join operation and removing unused dead code introduced in the model implementation.

Comment thread vllm/model_executor/layers/mla.py
Comment thread vllm/model_executor/models/deepseek_v2.py Outdated
mergify Bot added the deepseek (Related to DeepSeek models) label Apr 15, 2026
mergify Bot commented Apr 15, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot commented Apr 16, 2026

Hi @alexm-redhat, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
```
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
```

Commit: PERF (DeepSeek): Multi-stream overlap of indexer wk+weights_proj with QKV-A for DeepSeek V3.2 NSA

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
LopezCastroRoberto (Contributor) left a comment


Seems like we might want to wait for torch 2.12, which can handle multi-stream, so we don't need to wrap it as a custom op? See #39309 (comment), #39748.

I am still curious about the perf improvement at larger batch sizes (e.g. bs {32,64}). Could you bench that, please? Thanks!

Review thread on vllm/compilation/cuda_graph.py:

```diff
-stack.enter_context(patch("gc.collect", lambda: None))
-stack.enter_context(
-    patch("torch.accelerator.empty_cache", lambda: None)
+stack.enter_context(patch("gc.collect", lambda *args, **kwargs: None))
+stack.enter_context(
+    patch("torch.accelerator.empty_cache", lambda *args, **kwargs: None)
```
Contributor:
This seems somehow unrelated to this PR?

alexm-redhat (Author):
Yeah, it is a leftover from previous bugs. Will remove it.


```python
# Token threshold for multi-stream indexer overlap.
# Disables multi-stream for batches > 1024 to avoid SM contention.
_INDEXER_STREAM_TOKEN_THRESHOLD = 1024
```
Contributor:
Why 1024?

alexm-redhat (Author):
Will adjust

```python
# All other GEMMs stay INSIDE torch.compile scope.
hidden_states = torch.ops.vllm.mla_wk_fork(
    hidden_states,
    self.prefix,
```
Contributor:
Have you seen this PR? #39748

Seems like this is the exact pattern that caused the Qwen3 revert?

alexm-redhat (Author):
Good catch, I will fix it.

alexm-redhat (Author) commented Apr 17, 2026

@LopezCastroRoberto thanks for the review comments; these are good points. I will check the other PRs and also verify perf for batch sizes 32-64. The 1024 threshold comes from SGLang, but the right value for vLLM may be different.


Labels: deepseek (Related to DeepSeek models), nvidia
