
[Performance] DeepSeek V3.2 multi-stream indexer overlap#35968

Open
haosdent wants to merge 2 commits into vllm-project:main from haosdent:fix-35226

Conversation

@haosdent
Contributor

@haosdent haosdent commented Mar 4, 2026

Purpose

Overlap weights_proj with wk + k_norm in the DeepSeek V3.2 Indexer forward pass using a secondary CUDA stream. The weights_proj GEMM is small (hidden_size → n_head, i.e. 7168→64) and underutilizes GPU SMs, so it can run concurrently with wk + k_norm on the auxiliary stream, removing them from the critical path.
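As a back-of-envelope illustration of the underutilization (the token count and tile size below are illustrative assumptions, not measured values):

```python
# Illustrative tile count for the weights_proj GEMM (7168 -> 64).
hidden_size, n_head = 7168, 64     # weights_proj: hidden_size -> n_head
num_tokens = 256                   # hypothetical decode batch
tile = 128                         # typical GEMM output tile


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


# Independent output tiles (roughly one CTA, i.e. one SM, each):
tiles = ceil_div(num_tokens, tile) * ceil_div(n_head, tile)
# Only 2 tiles of work for a GPU with 100+ SMs: most of the chip idles
# while weights_proj runs alone, so overlapping it with wk + k_norm on
# an auxiliary stream is nearly free.
```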

torch.compile compatibility

The dual-stream execution is wrapped in a custom op (torch.ops.vllm.indexer_weights_and_k_proj) registered via direct_register_custom_op in deepseek_v2.py, following the same pattern as gdn_in_proj in PR #36795. This makes stream/event operations opaque to torch.compile, preventing graph breaks.

The custom op returns only (weights, k) — both contiguous tensors. The torch.split that produces k_pe/k_nope happens outside the op boundary in Indexer.forward(), where torch.compile traces it natively with correct strides. This avoids stride mismatch issues where non-contiguous torch.split views would leak across the op boundary.

Uses the global aux_stream() singleton (from vllm.utils.torch_utils) and maybe_execute_in_parallel (from vllm.utils.multi_stream_utils), consistent with existing vLLM conventions.
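A minimal sketch of the dual-stream pattern (the helper name and signature here are hypothetical; vLLM's actual maybe_execute_in_parallel in vllm.utils.multi_stream_utils differs in detail):

```python
from typing import Callable, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def run_in_parallel(
    fn_main: Callable[[], T],
    fn_aux: Callable[[], U],
    aux_stream=None,
) -> tuple:
    """Run fn_aux on a secondary CUDA stream so it overlaps fn_main on the
    current stream; fall back to plain sequential execution when no aux
    stream is given (CPU, or multi-stream disabled)."""
    if aux_stream is None:
        # Single-stream fallback: plain sequential execution.
        return fn_main(), fn_aux()

    import torch  # only needed on the CUDA path

    # The aux stream must wait until the inputs produced on the current
    # stream are ready before launching its work.
    ready = torch.cuda.Event()
    ready.record(torch.cuda.current_stream())
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(ready)
        out_aux = fn_aux()       # e.g. the small weights_proj GEMM
    out_main = fn_main()         # e.g. wk + k_norm on the main stream
    # The current stream must not consume aux results before they exist.
    torch.cuda.current_stream().wait_stream(aux_stream)
    return out_main, out_aux
```

Wrapping this whole sequence inside a custom op is what keeps the stream and event calls opaque to torch.compile.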

Closes #35226

Test Plan

  • FakeTensorMode tests for the custom op in tests/utils_/test_indexer_dual_stream.py validating output shapes and contiguous strides across multiple dimension combinations

Test Result

tests/utils_/test_indexer_dual_stream.py::TestIndexerWeightsAndKProjOp::test_fake_output_shapes_and_strides PASSED
tests/utils_/test_indexer_dual_stream.py::TestIndexerWeightsAndKProjOp::test_fake_output_shapes_parametrized[1-64-128] PASSED
tests/utils_/test_indexer_dual_stream.py::TestIndexerWeightsAndKProjOp::test_fake_output_shapes_parametrized[16-64-128] PASSED
tests/utils_/test_indexer_dual_stream.py::TestIndexerWeightsAndKProjOp::test_fake_output_shapes_parametrized[128-64-128] PASSED
tests/utils_/test_indexer_dual_stream.py::TestIndexerWeightsAndKProjOp::test_fake_output_shapes_parametrized[256-32-64] PASSED

======================== 5 passed, 6 warnings in 0.82s =========================

@mergify mergify bot added the deepseek Related to DeepSeek models label Mar 4, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for the DeepSeek V3.2 Indexer by using a secondary CUDA stream to overlap computations. The overall approach is sound and correctly implemented. I've suggested one improvement to further enhance parallelism and better align with the stated goal of the PR, which should lead to better performance.

Note: Security Review did not run due to the size of the PR.

@benchislett
Collaborator

I prefer the TRTLLM style of maybe_execute_in_parallel, if you think it's feasible to implement something similar here. Having to maintain two code paths for single-stream and multi-stream is bound to cause issues and duplicate work. See: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/modules/attention.py#L1393

@benchislett
Collaborator

There are concerns that multi-stream in this naive way will break torch.compile and custom-ops would be required to avoid breaking the graph. Have you observed this? Do the decodes still run in a single full graph with compilation active?

@jhaotingc
Contributor

jhaotingc commented Mar 5, 2026

#33505

There's a similar implementation of dual-stream hiding there, but it seems to need a fake op to bypass the torch.compile error, or else the dual stream won't show up in the actual run.

Does this implementation have this issue?

Also, relative discussion here: #32828 (comment)
I guess torch.compile optimizes some GEMMs aggressively, but if they're inside a custom op that bypasses torch.compile, they'll run slower.

(I also vote for TRTLLM's maybe_execute_in_parallel)

@haosdent
Contributor Author

haosdent commented Mar 5, 2026

Got it, let me research how TensorRT-LLM works @benchislett @jhaotingc

@haosdent
Contributor Author

haosdent commented Mar 5, 2026

There are concerns that multi-stream in this naive way will break torch.compile and custom-ops would be required to avoid breaking the graph. Have you observed this? Do the decodes still run in a single full graph with compilation active?

Yes, we did need to use custom ops, the same way TensorRT-LLM does. @benchislett

@haosdent
Contributor Author

haosdent commented Mar 5, 2026

@jhaotingc thanks a lot for your useful references!

Does this implementation have this issue?

Not sure. Because weights_proj and wk + k_norm are relatively small in our case, I guess the impact is limited, but only a test could prove this.

@benchislett
Collaborator

I get this error when running DSV3.2 NVFP4 on 8xB200 in TP8:

(Worker pid=3157294) (Worker_TP5 pid=3157294) ERROR 03-11 00:34:25 [multiproc_executor.py:932]   File "/tmp/torchinductor_root/pe/cpet3wrv5gxzezciaqngn3apgvewozoe67jpmwcakyztnyeuaduh.py", line 1699, in call
(Worker pid=3157294) (Worker_TP5 pid=3157294) ERROR 03-11 00:34:25 [multiproc_executor.py:932]     assert_size_stride(buf10, (s72, 64), (64, 1), 'torch.ops.vllm.indexer_dual_stream.default')
(Worker pid=3157294) (Worker_TP5 pid=3157294) ERROR 03-11 00:34:25 [multiproc_executor.py:932] AssertionError: expected size 16384==16384, stride 128==64 at dim=0
(Worker pid=3157294) (Worker_TP5 pid=3157294) ERROR 03-11 00:34:25 [multiproc_executor.py:932] Error in op: torch.ops.vllm.indexer_dual_stream.default
(Worker pid=3157294) (Worker_TP5 pid=3157294) ERROR 03-11 00:34:25 [multiproc_executor.py:932] This error most often comes from a incorrect fake (aka meta) kernel for a custom op.
(Worker pid=3157294) (Worker_TP5 pid=3157294) ERROR 03-11 00:34:25 [multiproc_executor.py:932] Use torch.library.opcheck to test your custom op.

Please include tests for the custom op in addition to testing the maybe_execute_in_parallel helper.

@benchislett
Collaborator

I got it working using this fake implementation instead:

def _indexer_dual_stream_fake(
    hidden_states: torch.Tensor,
    layer_name: str,
    n_head: int,
    head_dim: int,
    rope_dim: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake implementation for torch.compile shape inference."""
    num_tokens = hidden_states.shape[0]
    dtype = hidden_states.dtype
    device = hidden_states.device
    
    # weights: contiguous, shape (N, 64) -> stride (64, 1)
    weights = torch.empty_strided(
        (num_tokens, n_head), (n_head, 1), dtype=dtype, device=device
    )
    
    # k: contiguous, shape (N, 128) -> stride (128, 1)
    k_stride = (head_dim, 1)
    k = torch.empty_strided(
        (num_tokens, head_dim), k_stride, dtype=dtype, device=device
    )
    
    # k_pe and k_nope: shape (N, 64), but inherit k's stride -> (128, 1)
    k_pe = torch.empty_strided(
        (num_tokens, rope_dim), k_stride, dtype=dtype, device=device
    )
    k_nope = torch.empty_strided(
        (num_tokens, head_dim - rope_dim), k_stride, dtype=dtype, device=device
    )
    
    return weights, k, k_pe, k_nope
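The stride mismatch in the error above can be reproduced without torch: a fresh (N, 64) allocation is C-contiguous with stride (64, 1), while a (N, 64) view split out of a (N, 128) tensor inherits the parent's row stride of 128. A small framework-free sketch (the contiguous_strides helper is hypothetical, for illustration):

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


# What Inductor's assert_size_stride expects for a fresh (16384, 64) tensor:
expected = contiguous_strides((16384, 64))    # (64, 1)

# What a (16384, 64) view split out of a (16384, 128) parent actually has:
parent = contiguous_strides((16384, 128))     # (128, 1)
actual = (parent[0], 1)                       # the view keeps row stride 128

# This is the "stride 128==64 at dim=0" assertion: the fake kernel promised
# contiguous outputs, but the real op returned non-contiguous split views.
```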

@haosdent
Contributor Author

(Worker pid=3157294) (Worker_TP5 pid=3157294) ERROR 03-11 00:34:25 [multiproc_executor.py:932] AssertionError: expected size 16384==16384, stride 128==64 at dim=0
I got it working using this fake implementation instead:

On second thought, I changed the function to return k only and do the split outside the op, to work around the issue.

@haosdent
Contributor Author

haosdent commented Mar 16, 2026

@benchislett Could you help review again when you're available? Thank you in advance.

Then I can run a benchmark once you think the direction of this change is correct.

@haosdent haosdent marked this pull request as ready for review March 16, 2026 10:26
@haosdent haosdent changed the title [WIP] [Performance] DeepSeek V3.2 multi-stream indexer overlap [Performance] DeepSeek V3.2 multi-stream indexer overlap Mar 16, 2026
@benchislett
Collaborator

The change seems reasonable. I have no qualms other than those stated in my review of #36795.

@benchislett
Collaborator

Note, though, that this is just one of many opportunities for multi-streaming in the DSV3.2 indexer. Hopefully, once torch.compile + multi-stream becomes standard, we can expand to many more cases.

@benchislett
Collaborator

@haosdent maybe_execute_in_parallel has been added in #36795 which just merged. Please update and resolve the conflict (use the existing implementation)

@haosdent haosdent force-pushed the fix-35226 branch 2 times, most recently from 367c11c to 415566d Compare March 19, 2026 04:12
@haosdent
Copy link
Contributor Author

Thanks @benchislett, I have rebased and pushed.

Collaborator

@benchislett benchislett left a comment


A few more changes please; see comments.

Overlap `weights_proj` with `wk + k_norm` in the Indexer forward pass
using a secondary CUDA stream. Wrapped as a custom op
(`torch.ops.vllm.indexer_weights_and_k_proj`) so stream/event
operations are opaque to torch.compile and do not cause graph breaks.

The custom op returns `(weights, k)` — both contiguous tensors.
`torch.split` to produce `k_pe`/`k_nope` happens outside the op
boundary where torch.compile traces it natively with correct strides.

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Ben Chislett <chislett.ben@gmail.com>
@haosdent
Contributor Author

Thanks @benchislett, I have addressed the comments, except one that needs clarification.

@benchislett benchislett added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 23, 2026
@LucasWilkinson
Collaborator

Do we have any numbers on the performance benefits?


def _indexer_weights_and_k_proj_impl(
    hidden_states: torch.Tensor,
    layer_name: str,
Collaborator


Is there a way to avoid passing the layer_name as a string? This will regress cold compile times (something like 4x usually). Alternatively, is it possible for this to wait until after vLLM upgrades to PyTorch 2.11? (next Monday/Tuesday probably)

Collaborator


To be clear, I need to ship a quick PR to vLLM after the PyTorch 2.11 update, and then we should be able to use something like the layer_name as an input to the custom operator.

Contributor Author


Thanks @zou3519, then I'll rebase after your PR is merged.

Collaborator


This is the PR, btw: #38123

@haosdent
Contributor Author

Do we have any numbers on the performance benefits?

@LucasWilkinson Sorry, I haven't finished the benchmark yet. I ran into issues finding GPUs and setting up the environment recently. I'll post the results once it finishes.

@benchislett
Collaborator

We will need to see benchmark numbers before moving forward.

I am also exploring an optimization which fuses the weights_proj and the wk projection, which involves upcasting the wk matrix to FP32. We will need to compare these to see which one gives more speedup


Labels

deepseek Related to DeepSeek models ready ONLY add when PR is ready to merge/full CI is needed


Development

Successfully merging this pull request may close these issues.

[Performance]: DeepSeek 3.2 Multi-stream indexer

6 participants