PERF (DeepSeek): Multi-stream overlap of indexer wk+weights_proj with QKV-A for DeepSeek V3.2 NSA #39943

alexm-redhat wants to merge 1 commit into main
Conversation
Code Review
This pull request implements multi-stream parallelism for the MLA indexer in DeepSeek-V2 models to overlap projection operations with the main forward pass. It introduces custom fork and join operations, utilizing a separately compiled module for side-stream execution. Feedback focuses on optimizing memory by clearing cached results after the join operation and removing unused dead code introduced in the model implementation.
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @alexm-redhat, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
…QKV-A for DeepSeek V3.2 NSA
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
LopezCastroRoberto left a comment:
Seems like we might want to wait for torch 2.12, which can handle multi-stream, so we don't need to wrap it as a custom op? See #39309 (comment), #39748.
I am still curious about the perf improvement at larger batch sizes (e.g. bs {32,64}). Could you bench that, please? Thanks!
-    stack.enter_context(patch("gc.collect", lambda: None))
+    stack.enter_context(patch("gc.collect", lambda *args, **kwargs: None))
     stack.enter_context(
-        patch("torch.accelerator.empty_cache", lambda: None)
+        patch("torch.accelerator.empty_cache", lambda *args, **kwargs: None)
     )
This seems somehow unrelated to this PR?
Yeah, it is a leftover from previous bugs. Will remove it.
# Token threshold for multi-stream indexer overlap.
# Disables multi-stream for batches > 1024 to avoid SM contention.
_INDEXER_STREAM_TOKEN_THRESHOLD = 1024
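For context, a minimal sketch of how such a token-count gate could be applied; use_indexer_stream and num_tokens are illustrative names, not the PR's actual control flow:

_INDEXER_STREAM_TOKEN_THRESHOLD = 1024

def use_indexer_stream(num_tokens: int) -> bool:
    # Overlap only smaller batches; past the threshold the alt-stream GEMM
    # would compete with the main stream for SMs instead of hiding behind it.
    return num_tokens <= _INDEXER_STREAM_TOKEN_THRESHOLD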
# All other GEMMs stay INSIDE torch.compile scope.
hidden_states = torch.ops.vllm.mla_wk_fork(
    hidden_states,
    self.prefix,
Have you seen this PR? #39748
Seems like this is the exact pattern that caused the Qwen3 revert?
Good catch, I will fix it.
@LopezCastroRoberto thanks for the review comments, these are good points. I will check the other PRs and also verify perf for the 32-64 batch sizes. The 1024 threshold comes from SGLang, but it may need to be different in vLLM.
Summary
Overlaps the NSA (Nested Sparse Attention) indexer's fused wk_weights_proj GEMM and k_norm with the QKV-A GEMM on a separate CUDA stream, hiding indexer computation behind main-stream work per transformer block for DeepSeek V3.2 FP8 inference.
Result: ~3% TPOT improvement (68.4 → 70.5 tok/s) on 8×H200 GPUs at batch size 1 decode-only, with full correctness validated via lm_eval.

Motivation
Profiling shows that vLLM runs all NSA indexer operations sequentially on the main CUDA stream. The fused wk_weights_proj GEMM (which combines wk and weights_proj into a single GEMM) plus k_norm depend only on hidden_states (the layer input), not on any intermediate result from QKV-A. This means they can start executing at the very beginning of the layer's forward pass, concurrent with QKV-A on the main stream.
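As a tiny, CPU-runnable illustration of this dependence structure (the Linear layers below are stand-ins for the real fused GEMMs, with arbitrary shapes):

import torch

hidden_states = torch.randn(4, 1024)           # layer input
wk_weights_proj = torch.nn.Linear(1024, 192)   # stand-in for fused wk+weights_proj
fused_qkv_a_proj = torch.nn.Linear(1024, 2112) # stand-in for QKV-A

k_and_weights = wk_weights_proj(hidden_states)  # reads only hidden_states
qkv_a = fused_qkv_a_proj(hidden_states)         # reads only hidden_states
# Neither result feeds the other, so the two GEMMs can be overlapped on
# separate CUDA streams without changing the math.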
SGLang uses a similar multi-stream design to hide indexer work. However, a direct port to vLLM is not possible because vLLM uses torch.compile with fullgraph=True (PIECEWISE mode), which cannot represent CUDA stream operations: any stream op must be wrapped in a custom op, but wrapping removes those operations from the compiler's optimization scope.
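For reference, a minimal sketch of this wrapping pattern using the public torch.library API (assuming PyTorch 2.4+); the "mysketch::wk_fork" name and the op body are illustrative, not the actual PR ops:

import torch

@torch.library.custom_op("mysketch::wk_fork", mutates_args=())
def wk_fork(hidden_states: torch.Tensor) -> torch.Tensor:
    # Anything here (e.g. `with torch.cuda.stream(alt_stream): ...`) runs
    # eagerly and is invisible to the compiler, so stream ops are legal here,
    # but the body also receives none of Inductor's fusion optimizations.
    return hidden_states.clone()

@wk_fork.register_fake
def _(hidden_states):
    # Shape/dtype contract so fullgraph tracing can proceed around the op.
    return torch.empty_like(hidden_states)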
By wrapping only the data-independent wk_weights_proj + k_norm in minimal custom ops, all other GEMMs (wq_b, q_b_proj) stay inside torch.compile's optimization scope.

Design
Two minimal custom ops (mla_wk_fork / mla_wk_join) implement a fork-join pattern extending the existing MoE shared expert streaming approach (default_moe_runner.py):

forward()
→ mla_wk_fork(hidden_states) # Custom op: launch fused wk_weights_proj+k_norm on alt_stream
→ fused_qkv_a_proj(hidden_states) # COMPILED on main stream (concurrent with fork)
→ q_a_layernorm(q_c) # COMPILED
→ q_b_proj(q_c) # COMPILED
→ kv_preprocess + RoPE # COMPILED
→ mla_wk_join() # Custom op: wait for alt_stream, return [k | weights]
→ Indexer.forward(precomputed_k=k, # COMPILED: skips wk_weights_proj+k_norm,
precomputed_weights=weights) # runs wq_b inline
Alt stream (compiled): wk_weights_proj fused GEMM + k_norm
Main stream (compiled): QKV-A + Q-A LN + Q-B proj + kv preprocess + RoPE
Alt < Main → fork is completely hidden
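A rough eager sketch of the fork/join synchronization in the diagram above, assuming a CUDA device; alt_stream, wk_weights_proj, and _pending are illustrative stand-ins rather than the actual vLLM implementation:

import torch

alt_stream = torch.cuda.Stream()
_pending = {}

def mla_wk_fork_sketch(hidden_states, wk_weights_proj):
    # The alt stream must wait until hidden_states is ready on the main stream.
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        # Fused wk+weights_proj GEMM (k_norm would follow here as well).
        _pending["kw"] = wk_weights_proj(hidden_states)
    # Tell the caching allocator hidden_states is also used on the alt stream.
    hidden_states.record_stream(alt_stream)
    return hidden_states  # main stream continues with QKV-A immediately

def mla_wk_join_sketch():
    # Before the indexer consumes [k | weights], the main stream waits
    # for the alt-stream GEMM to finish.
    torch.cuda.current_stream().wait_stream(alt_stream)
    return _pending.pop("kw")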
Novel extension over the MoE pattern: the fork operations are wrapped in a separately torch.compile'd _WkForkModule, so they benefit from Inductor optimizations (FP8 quant+GEMM fusion, kernel selection) even on the alt stream. The MoE pattern runs shared experts eagerly inside custom ops; this PR adds a second compilation unit with graceful fallback to eager if compilation fails.
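A condensed sketch of this second-compilation-unit-with-eager-fallback idea; WkForkSketch, build_runner, and the layer shapes are illustrative, not the real _WkForkModule from vllm/model_executor/layers/mla.py:

import torch

class WkForkSketch(torch.nn.Module):
    def __init__(self, hidden: int = 1024, out: int = 192):
        super().__init__()
        self.wk_weights_proj = torch.nn.Linear(hidden, out)  # fused GEMM stand-in
        self.k_norm = torch.nn.LayerNorm(out)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.k_norm(self.wk_weights_proj(hidden_states))

def build_runner(module: torch.nn.Module, example: torch.Tensor):
    try:
        compiled = torch.compile(module)
        compiled(example)   # compilation is lazy; errors surface at first call
        return compiled
    except Exception:
        return module       # graceful fallback to eager execution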
Key design decisions
- Only the data-independent wk_weights_proj + k_norm run on the alt stream: they depend only on hidden_states, so they can start at layer entry, concurrent with QKV-A.
- All other GEMMs (wq_b, q_b_proj) stay inside the torch.compile scope; wrapping them in custom ops would pull them out of that scope, losing fusion benefits.
- The fork/join custom ops carry the needs_fixed_stride_order tag.
- Indexer.forward() gains precomputed_k / precomputed_weights params so it can skip its own wk_weights_proj + k_norm when the fork already computed them (sketched below).
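An abbreviated sketch of the Indexer.forward() signature change from the last bullet; the module shapes and the final scoring line are placeholders, not the real deepseek_v2.py Indexer:

from typing import Optional
import torch

class IndexerSketch(torch.nn.Module):
    def __init__(self, hidden: int = 1024, head: int = 128):
        super().__init__()
        self.wk = torch.nn.Linear(hidden, head)
        self.k_norm = torch.nn.LayerNorm(head)
        self.weights_proj = torch.nn.Linear(hidden, 1)
        self.wq_b = torch.nn.Linear(hidden, head)

    def forward(
        self,
        hidden_states: torch.Tensor,
        precomputed_k: Optional[torch.Tensor] = None,
        precomputed_weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if precomputed_k is None:
            k = self.k_norm(self.wk(hidden_states))          # original in-indexer path
            weights = self.weights_proj(hidden_states)
        else:
            k, weights = precomputed_k, precomputed_weights  # fork-join path
        q = self.wq_b(hidden_states)                         # wq_b always runs inline
        return (q @ k.transpose(-1, -2)) * weights           # placeholder index scores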
Guards and fallbacks

- VLLM_DISABLE_INDEXER_STREAM=1 disables the optimization.
- If alt_stream is None, fork/join become no-ops.
- _WkForkModule falls back to eager execution if its compilation fails.
- A _wk_forked flag ensures join only waits when fork actually launched on the alt stream (see the sketch below).
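A hedged sketch of the guard behavior listed above; the env-var name comes from this PR, while ForkState, alt_stream, and _wk_forked here are illustrative local names rather than the actual mla.py code:

import os
import torch

class ForkState:
    def __init__(self, alt_stream):
        self.alt_stream = alt_stream
        self._wk_forked = False

    def fork_enabled(self) -> bool:
        if os.environ.get("VLLM_DISABLE_INDEXER_STREAM") == "1":
            return False                    # explicit opt-out via env var
        return self.alt_stream is not None  # no alt stream: fork/join become no-ops

    def join(self):
        # Join only synchronizes when fork actually launched on the alt stream.
        if self._wk_forked:
            torch.cuda.current_stream().wait_stream(self.alt_stream)
            self._wk_forked = False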
Additional fix

Minor fix in cuda_graph.py: changed the gc.collect and torch.accelerator.empty_cache patch lambdas to accept *args, **kwargs for compatibility when these functions are called with arguments during CUDA graph capture.
Files changed
- vllm/compilation/cuda_graph.py: gc.collect / torch.accelerator.empty_cache patch lambda fix
- vllm/envs.py: new VLLM_DISABLE_INDEXER_STREAM env var
- vllm/model_executor/layers/mla.py: _WkForkModule, mla_wk_fork / mla_wk_join custom ops, fork-join logic in forward()
- vllm/model_executor/models/deepseek_v2.py: Indexer.forward() precomputed_k / precomputed_weights params

Performance
- lm_eval validation passed

Test plan
- lm_eval correctness validation passed
- With VLLM_DISABLE_INDEXER_STREAM=1, performance matches baseline