
Use NCCL symmetric memory for DP (includes allgather, fp4 allgatherv, and reducescatter)#9358

Closed
nvcastet wants to merge 10 commits into sgl-project:main from nvcastet:fp4_allgather

Conversation

@nvcastet
Collaborator

@nvcastet nvcastet commented Aug 19, 2025

Motivation

7.4% end-to-end perf gain (best perf with NCCL 2.28, which was just released).

After this PR:

python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4 --trust-remote-code --quantization modelopt_fp4 --tp 8 --enable-flashinfer-cutlass-moe --enable-ep-moe --ep-size 8 --dp 8 --enable-dp-attention --chunked-prefill-size 16384 --mem-fraction-static 0.85 --max-running-requests 4096 --stream-interval 5 --enable-dp-lm-head --attention-backend trtllm_mla --cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 --disable-radix-cache  --enable-symm-mem

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 1024 --random-input 1024 --random-output 2048 --random-range-ratio 1 --warmup-request 1024 --max-concurrency 1024
============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1024
Successful requests:                     1024
Benchmark duration (s):                  97.55
Total input tokens:                      1048576
Total generated tokens:                  2097152
Total generated tokens (retokenized):    2083551
Request throughput (req/s):              10.50
Input token throughput (tok/s):          10749.27
Output token throughput (tok/s):         21498.54
Total token throughput (tok/s):          32247.81
Concurrency:                             1022.43
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   97398.95
Median E2E Latency (ms):                 97407.47
---------------Time to First Token----------------
Mean TTFT (ms):                          8711.54
Median TTFT (ms):                        8592.25
P99 TTFT (ms):                           15501.95
---------------Inter-Token Latency----------------
Mean ITL (ms):                           43.33
Median ITL (ms):                         39.74
P95 ITL (ms):                            43.96
P99 ITL (ms):                            46.83
Max ITL (ms):                            2868.54
==================================================

Before:

============ Serving Benchmark Result ============
Backend:                                 sglang
Traffic request rate:                    inf
Max request concurrency:                 1024
Successful requests:                     1024
Benchmark duration (s):                  104.79
Total input tokens:                      1048576
Total generated tokens:                  2097152
Total generated tokens (retokenized):    2089689
Request throughput (req/s):              9.77
Input token throughput (tok/s):          10006.11
Output token throughput (tok/s):         20012.21
Total token throughput (tok/s):          30018.32
Concurrency:                             1022.57
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   104647.36
Median E2E Latency (ms):                 104659.58
---------------Time to First Token----------------
Mean TTFT (ms):                          9131.11
Median TTFT (ms):                        9061.90
P99 TTFT (ms):                           16866.24
---------------Inter-Token Latency----------------
Mean ITL (ms):                           46.66
Median ITL (ms):                         42.67
P95 ITL (ms):                            46.43
P99 ITL (ms):                            48.57
Max ITL (ms):                            3148.65
==================================================

Modifications

Accuracy Tests

Benchmarking and Profiling

Checklist

@kushanam
Collaborator

kushanam commented Sep 3, 2025

This PR includes the changes from (and is a replacement for) #8934.

@nvcastet
Collaborator Author

nvcastet commented Sep 3, 2025

> This PR includes the changes from (and is a replacement for) #8934.

This PR builds on #8934 (first commit of the PR).
We can either merge #8934 and rebase this one or just merge this one in one shot.

@nvcastet nvcastet requested a review from trevor-m September 5, 2025 18:44
@nvcastet nvcastet changed the title from "Register Fp4 allgather with NCCL symmetric memory" to "Use NCCL symmetric memory for DP (includes allgather, fp4 allgatherv, and reducescatter)" on Sep 5, 2025
@nvcastet nvcastet force-pushed the fp4_allgather branch 3 times, most recently from de77ab0 to 68e1d72, on September 5, 2025 22:36
@zhyncs
Collaborator

zhyncs commented Sep 8, 2025

@nvcastet @kushanam please fix the conflicts. thanks

@nvcastet nvcastet force-pushed the fp4_allgather branch 3 times, most recently from ad23fb4 to 9e24242, on September 11, 2025 01:48
@fzyzcjy
Collaborator

fzyzcjy commented Sep 17, 2025

Hi, could you please rebase it? Thanks.

fzyzcjy added a commit to fzyzcjy/sglang that referenced this pull request Sep 20, 2025
Comment on lines +520 to +523
with use_symmetric_memory(
get_tp_group(),
disabled=not forward_batch.dp_padding_mode.is_max_len(),
):
Contributor

Can you try to cache as many variables as possible?

Collaborator Author

Those two variables are already cached: the TP group and is_max_len for the current batch.

dtype=cls._dtype,
device=cls._device,
)
with use_symmetric_memory(get_tp_group()):
Contributor

So does it always use symmetric memory?

Collaborator Author

Yes, when the server flag --enable-symm-mem is used. The tensors here are always symmetric, so there is no case where it needs to be disabled.

Comment on lines +376 to +378
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
Contributor

Can we further simplify the code?

Our principal: When symmetric_memory is not enabled, the added overhead in the model forward pass should just be < 2 if/else and function calls.

For example, you can

  • and hopefully cache the result of is_allocation_symmetric
  • move the function get_tp_group into use_symmetric_memory, when it is disabled, even do not call get_tp_group.
  • In the best case, even do not create a context when it is disabled.

Collaborator Author

is_allocation_symmetric is batch-specific, so it is cached in dp_attention.py for the scope of a batch.
get_tp_group returns the communicator spanning the symmetric group; depending on the case, it is not necessarily the generic TP communicator. It could be the TP attention communicator, etc.
Besides those two read-accessor calls and use_symmetric_memory, only a nullcontext is created when symmetric memory is disabled.
During CUDA graph replays, this host code is not re-executed for the lifetime of the application.

Collaborator Author

For both the communicator arg and the disabled arg, we could pass the function itself or a lambda, deferring the function calls until after the "disabled" check. Those functions/lambdas would then not be called when symmetric memory is globally disabled.
@merrymercy Would that address your concerns?

Comment on lines +1430 to +1432
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
Contributor

It seems you call with use_symmetric_memory( twice here, which adds overhead. Can you try to cache and simplify things so that the overhead is minimal, with only some if/else when symmetric memory is not used?

Ideally, when symmetric memory is not used, it should be just some simple if/else, without context creation or redundant function calls (e.g., get_tp_group).
In the worst case, we can create an empty context to avoid code duplication.

Collaborator

Agree. The pervasive use of use_symmetric_memory compromises the readability of the code.

Collaborator Author

@nvcastet nvcastet Oct 20, 2025

The first use annotates the input used for the "quantize-before-comm" allgatherv.
The second use annotates the output buffer of the fused MoE, since it is not an in-place op. This buffer is then used for the reducescatter or allreduce comm.

> In the worst case, we can create an empty context to avoid code duplication.

@merrymercy use_symmetric_memory uses a nullcontext when it is disabled.

There is no alternative for achieving "zero-copy" between compute and communication buffers other than annotating the tensors that will be used for communication.
PyTorch put this context-based mechanism in place, and the PyTorch team is moving forward to expand its support into torch.compile internals.

torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
final_hidden_states += shared_output
Contributor

This is good!

torch.distributed.reduce_scatter_tensor(
output, input, group=self.device_group
)
return output
Collaborator

return output is unnecessary now.

Collaborator Author

@nvcastet nvcastet Oct 20, 2025

I kept the code "as-is" to avoid unrelated changes. But yes, we could clean up those function signatures in another PR if needed.

@ch-wan ch-wan self-assigned this Oct 15, 2025
@nvcastet
Collaborator Author

@merrymercy I will wait to rebase and resolve merge conflicts until you tell me it is OK to move forward with this PR.

To be honest, it has been 2 months since I opened the PR. We have made a lot of effort to improve it and make it as clean as possible while still delivering a significant e2e perf gain. I don't want any of us, including you, to spend more time here if there is no clear end to the back and forth in sight.

I hear that you are concerned about Python-side perf when this feature is disabled. I have tried to address that as best as possible while keeping the code modular. I don't believe there is currently any overhead in critical code paths that is visible in an e2e run. I also believe the host code will only be executed once per batch size when CUDA graphs are used.

Please, let us know your thoughts.

@nvcastet
Collaborator Author

@merrymercy Let me know if implementing the logic in comment #9358 (comment) would address your Python perf concerns.

def is_symmetric_memory_tensor(tensor: torch.Tensor):
if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
return False
for segment in _cached_pool_snapshot:
Contributor

@merrymercy merrymercy Oct 31, 2025

This runs a for loop over a long list; will it be slow?
I suspect it is even slower than the old sm.tag approach.

Collaborator Author

@nvcastet nvcastet Oct 31, 2025

I measured the perf: it was around 2-3 us. _cached_pool_snapshot holds only the symmetric memory segments, not the full memory used by the app.
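For context, the membership check being discussed can be sketched roughly as follows. The segment layout and names here are illustrative assumptions; the real code operates on a torch.Tensor's data_ptr() and the CUDA pool snapshot:

```python
# Illustrative sketch only: a linear scan over cached symmetric-memory
# segments. The snapshot holds just a handful of (address, size) entries,
# which is why the scan only costs microseconds.
_cached_pool_snapshot = [
    {"address": 0x1000, "total_size": 0x400},
    {"address": 0x9000, "total_size": 0x800},
]

def is_symmetric_memory_ptr(data_ptr: int) -> bool:
    # The real helper takes a torch.Tensor and checks tensor.data_ptr().
    for segment in _cached_pool_snapshot:
        start = segment["address"]
        if start <= data_ptr < start + segment["total_size"]:
            return True
    return False
```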



def is_allocation_symmetric() -> bool:
return not is_dp_attention_enabled() or is_dp_max_padding()
Contributor

Why can't we use symmetric memory for sum padding?

Collaborator Author

Because in "sum padding", the input tensors are not allocated with the same size (batch size) on each GPU, which breaks the requirement for symmetric allocation.
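A toy illustration of the distinction (function names are illustrative, not sglang's exact API): with max-len padding every DP rank allocates the same buffer size, which is what symmetric registration requires, while without it each rank's allocation tracks its own local batch and the sizes diverge across GPUs.

```python
# Toy sketch: per-rank token counts for a DP attention batch (made-up values).
local_tokens = [3, 7, 5, 7]

def alloc_sizes_max_padding(sizes):
    # Every rank pads to the global max -> identical allocation on each GPU,
    # satisfying the symmetric-memory requirement.
    pad = max(sizes)
    return [pad] * len(sizes)

def alloc_sizes_local(sizes):
    # Each rank allocates only its local size -> sizes differ across GPUs,
    # so the buffers cannot be registered as symmetric memory.
    return list(sizes)
```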

@nvcastet
Collaborator Author

nvcastet commented Nov 6, 2025

Merged with #12572
