Use NCCL symmetric memory for DP (includes allgather, fp4 allgatherv, and reducescatter) #9358
nvcastet wants to merge 10 commits into sgl-project:main
Conversation
This PR includes the changes from (and is a replacement for) #8934.
Hi, could you please rebase it? Thanks.
9e24242 to
6344142
Compare
```python
with use_symmetric_memory(
    get_tp_group(),
    disabled=not forward_batch.dp_padding_mode.is_max_len(),
):
```
Can you try to cache as many variables as possible?
Those 2 variables are already cached: the TP group and `is_max_len` for the current batch.
```python
    dtype=cls._dtype,
    device=cls._device,
)
with use_symmetric_memory(get_tp_group()):
```
So it always uses symmetric memory?
Yes, when the server flag `--enable-symm-mem` is used. The tensors here are always symmetric, so there is no case where it needs to be disabled.
```python
with use_symmetric_memory(
    get_tp_group(), disabled=not is_allocation_symmetric()
):
```
Can we further simplify the code?
Our principle: when symmetric memory is not enabled, the added overhead in the model forward pass should be no more than a couple of if/else checks and function calls.
For example, you can:
- cache the result of `is_allocation_symmetric`
- move the `get_tp_group` call into `use_symmetric_memory`; when it is disabled, do not call `get_tp_group` at all
- in the best case, do not even create a context when it is disabled
`is_allocation_symmetric` is batch-specific, so it is cached in dp_attention.py for the scope of a batch.
`get_tp_group` returns the communicator spanning the symmetric group; depending on the case, it is not necessarily the generic TP communicator (it could be the TP-attention communicator, etc.).
Besides those 2 read-accessor calls and `use_symmetric_memory`, only a nullcontext is created when symmetric memory is disabled.
During CUDA graph replays, this host code is not re-executed for the lifetime of the application.
For both the communicator arg and the disabled arg, we could pass the function itself or a lambda, so that the function calls are deferred until after the "disabled" check. Those functions/lambdas would then never be called when symmetric memory is globally disabled.
@merrymercy Would that address your concerns?
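The deferral idea above could look roughly like the following. This is a hypothetical sketch, not the PR's actual implementation: `_SYMM_MEM_ENABLED` stands in for the `--enable-symm-mem` server flag, `_symm_mem_context` for the real symmetric-memory pool context, and the callable-argument signature is an illustrative assumption.

```python
from contextlib import contextmanager, nullcontext

# Stand-in for the --enable-symm-mem server flag (illustrative only).
_SYMM_MEM_ENABLED = False


@contextmanager
def _symm_mem_context(group):
    # Placeholder for the real symmetric-memory allocation logic.
    yield


def use_symmetric_memory(group_fn, disabled_fn=lambda: False):
    """group_fn and disabled_fn are callables, invoked only when the flag is on."""
    if not _SYMM_MEM_ENABLED or disabled_fn():
        return nullcontext()
    return _symm_mem_context(group_fn())


# Usage: the lambdas defer get_tp_group()/is_allocation_symmetric(), so
# neither is called when symmetric memory is globally disabled.
with use_symmetric_memory(lambda: "tp_group", lambda: False):
    pass
```

With this shape, the disabled path costs one flag check and a `nullcontext()`, and neither `get_tp_group` nor `is_allocation_symmetric` is ever evaluated.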
```python
with use_symmetric_memory(
    get_tp_group(), disabled=not is_allocation_symmetric()
):
```
It seems you call `use_symmetric_memory(` twice here, which adds overhead. Can you try to cache and simplify things so that the overhead is minimal and reduces to some if/else checks when symmetric memory is not used?
Ideally, when symmetric memory is not used, it should just be some simple if/else, without context creation or redundant function calls (e.g., `get_tp_group`).
In the worst case, we can create an empty context to avoid code duplication.
Agree. The pervasive use of use_symmetric_memory compromises the readability of the code.
The first use annotates the input used for the "quantize-before-comm" allgatherv.
The second use annotates the output buffer of the fused MoE, since it is not an in-place op. This buffer is then used for the reduce-scatter or all-reduce communication.

> In the worst case, we can create an empty context to avoid code duplication.

@merrymercy `use_symmetric_memory` uses a nullcontext when it is disabled.
There is no alternative way to achieve "zero-copy" between compute and communication buffers other than annotating the tensors that will be used for communication.
PyTorch put this context-based mechanism in place, and the PyTorch team is moving forward to expand its support to torch.compile internals.
```diff
- final_hidden_states += shared_output
+ torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
+ final_hidden_states = final_hidden_states_out
+ sm.tag(final_hidden_states)
```
```python
torch.distributed.reduce_scatter_tensor(
    output, input, group=self.device_group
)
return output
```
`return output` is unnecessary now.
I kept the code as-is to avoid unrelated changes. But yes, we could clean up those function signatures in another PR if needed.
@merrymercy I will wait to rebase and resolve merge conflicts until you tell me it is ok to move forward with this PR. To be honest, it has been 2 months since I opened the PR; we have made a lot of effort to improve it and make it as clean as possible while still delivering a significant e2e perf gain. I don't want any of us, including you, to spend more time here if there is no clear end to the back-and-forth in sight. I hear that you have concerns about Python-side perf when this feature is disabled. I have tried to address them as best as possible while keeping the code modular. I don't believe there is currently any overhead in critical code paths that is visible in an e2e run. I also believe the host code will only be executed once per batch size when cuda-graph is used. Please let us know your thoughts.
@merrymercy Let me know if implementing the logic in comment #9358 (comment) would address your Python perf concerns.
```python
def is_symmetric_memory_tensor(tensor: torch.Tensor):
    if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None:
        return False
    for segment in _cached_pool_snapshot:
```
This runs a for loop over a long list; will it be slow?
I suspect it is even slower than the old `sm.tag` approach.
I measured the perf: it was around 2-3 µs. `_cached_pool_snapshot` contains only the symmetric memory segments, not the full memory used by the app.
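A segment check of the kind discussed above can be sketched in pure Python as follows. This is a hypothetical illustration, not the PR's actual data structure: the snapshot is modeled as a dict mapping segment base address to size in bytes, and an integer address stands in for `tensor.data_ptr()`.

```python
# Hypothetical snapshot: one 1 MiB symmetric-memory segment at a made-up
# base address. Names and values are illustrative only.
_cached_pool_snapshot = {0x7F0000000000: 1 << 20}  # {base_addr: num_bytes}


def is_symmetric_address(ptr: int) -> bool:
    """Return True if ptr lies inside any cached symmetric-memory segment."""
    if _cached_pool_snapshot is None:
        return False
    for base, size in _cached_pool_snapshot.items():
        if base <= ptr < base + size:
            return True
    return False
```

With only a handful of segments in the snapshot, the loop stays short, which is consistent with the 2-3 µs measurement mentioned above.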
```python
def is_allocation_symmetric() -> bool:
    return not is_dp_attention_enabled() or is_dp_max_padding()
```
Why can't we use symmetric memory for sum padding?
Because in "sum padding", the input tensors are not allocated with the same size on each GPU (the size depends on the local batch size), which breaks the requirement for symmetric allocation.
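The symmetric-allocation requirement can be illustrated with a toy sketch (hypothetical batch sizes, not taken from the PR) contrasting max-len padding with sum padding:

```python
# Hypothetical per-rank local batch sizes in a 4-GPU data-parallel group.
local_batch_sizes = [3, 5, 2, 4]

# "Max-len" padding: every rank pads its tensor to the group-wide maximum,
# so all ranks allocate buffers of identical size -> symmetric allocation
# is possible.
max_padded_sizes = [max(local_batch_sizes)] * len(local_batch_sizes)

# "Sum" padding: each rank keeps its own local size (gathered with an
# allgatherv), so per-rank allocation sizes differ -> the symmetric
# allocation requirement is broken.
sum_padded_sizes = list(local_batch_sizes)

print(max_padded_sizes)  # identical on every rank
print(sum_padded_sizes)  # differs across ranks
```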
Merged with #12572
Motivation
7.4% e2e perf gain
(Best perf with NCCL 2.28, which was just released)
After this PR:
Before:
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist