-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Use NCCL symmetric memory for DP (includes allgather, fp4 allgatherv, and reducescatter) #9358
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9bba92b
14cc17a
6344142
f231557
d9eed9c
8c4caed
e17af00
cb51899
41c5682
2c92c38
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -187,6 +187,27 @@ def reg_all_gather_into_tensor_fake( | |
| fake_impl=reg_all_gather_into_tensor_fake, | ||
| ) | ||
|
|
||
| def reg_reduce_scatter_tensor( | ||
| output: torch.Tensor, input: torch.Tensor, group_name: str | ||
| ) -> None: | ||
| assert group_name in _groups, f"Group {group_name} is not found." | ||
| group = _groups[group_name]() | ||
| if group is None: | ||
| raise ValueError(f"Group {group_name} is destroyed.") | ||
| group._reduce_scatter_tensor(output, input) | ||
|
|
||
| def reg_reduce_scatter_tensor_fake( | ||
| output: torch.Tensor, input: torch.Tensor, group_name: str | ||
| ) -> None: | ||
| pass | ||
|
|
||
| direct_register_custom_op( | ||
| op_name="reg_reduce_scatter_tensor", | ||
| op_func=reg_reduce_scatter_tensor, | ||
| mutates_args=["output"], | ||
| fake_impl=reg_reduce_scatter_tensor_fake, | ||
| ) | ||
|
|
||
|
|
||
| class GroupCoordinator: | ||
| """ | ||
|
|
@@ -311,10 +332,16 @@ def __init__( | |
| from sglang.srt.distributed.device_communicators.pynccl import ( | ||
| PyNcclCommunicator, | ||
| ) | ||
| from sglang.srt.distributed.device_communicators.pynccl_allocator import ( | ||
| is_symmetric_memory_tensor, | ||
| use_symmetric_memory, | ||
| ) | ||
| from sglang.srt.distributed.device_communicators.symm_mem import ( | ||
| SymmMemCommunicator, | ||
| ) | ||
|
|
||
| self.is_symmetric_memory_tensor = is_symmetric_memory_tensor | ||
| self.use_symmetric_memory = use_symmetric_memory | ||
| if is_hip(): | ||
| from sglang.srt.distributed.device_communicators.quick_all_reduce import ( | ||
| QuickAllReduce, | ||
|
|
@@ -549,11 +576,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: | |
| if self.npu_communicator is not None and not self.npu_communicator.disabled: | ||
| return self.npu_communicator.all_reduce(input_) | ||
|
|
||
| if ( | ||
| self.pynccl_comm is not None | ||
| and hasattr(input_, "symmetric_memory") | ||
| and input_.symmetric_memory | ||
| ): | ||
| if self.pynccl_comm is not None and self.is_symmetric_memory_tensor(input_): | ||
| with self.pynccl_comm.change_state( | ||
| enable=True, stream=torch.get_device_module().current_stream() | ||
| ): | ||
|
|
@@ -628,15 +651,37 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None: | |
| else: | ||
| torch.distributed.all_reduce(input_, group=self.device_group) | ||
|
|
||
| def reduce_scatter_tensor( | ||
| def _reduce_scatter_tensor( | ||
| self, | ||
| output: torch.Tensor, | ||
| input: torch.Tensor, | ||
| ) -> None: | ||
| # TODO(ch-wan): support other backends | ||
| torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group) | ||
| ) -> torch.Tensor: | ||
| pynccl_comm = self.pynccl_comm | ||
| if pynccl_comm is not None and ( | ||
| not pynccl_comm.disabled | ||
| or ( | ||
| self.is_symmetric_memory_tensor(output) | ||
| and self.is_symmetric_memory_tensor(input) | ||
| ) | ||
| ): | ||
| with pynccl_comm.change_state( | ||
| enable=True, stream=torch.cuda.current_stream() | ||
| ): | ||
| pynccl_comm.reduce_scatter(output, input) | ||
| else: | ||
| torch.distributed.reduce_scatter_tensor( | ||
| output, input, group=self.device_group | ||
| ) | ||
| return output | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The `return output` is unnecessary now.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I kept the code "as-is" to avoid unrelated changes. But yes, we could clean up those function signatures in another PR if needed. |
||
|
|
||
| def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor): | ||
| if _is_npu or not supports_custom_op(): | ||
| self._reduce_scatter_tensor(output, input) | ||
| else: | ||
| torch.ops.sglang.reg_reduce_scatter_tensor( | ||
| output, input, group_name=self.unique_name | ||
| ) | ||
|
|
||
| def reduce_scatter( | ||
| self, | ||
| output: torch.Tensor, | ||
|
|
@@ -683,8 +728,17 @@ def reduce_scatterv( | |
|
|
||
| def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): | ||
| pynccl_comm = self.pynccl_comm | ||
| if pynccl_comm is not None and not pynccl_comm.disabled: | ||
| pynccl_comm.all_gather(output, input) | ||
| if pynccl_comm is not None and ( | ||
| not pynccl_comm.disabled | ||
| or ( | ||
| self.is_symmetric_memory_tensor(output) | ||
| and self.is_symmetric_memory_tensor(input) | ||
| ) | ||
| ): | ||
| with pynccl_comm.change_state( | ||
| enable=True, stream=torch.cuda.current_stream() | ||
| ): | ||
| pynccl_comm.all_gather(output, input) | ||
| else: | ||
| torch.distributed.all_gather_into_tensor( | ||
| output, input, group=self.device_group | ||
|
|
@@ -746,9 +800,10 @@ def all_gather( | |
| # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 | ||
| output_size = (input_size[0] * world_size,) + input_size[1:] | ||
| # Allocate output tensor. | ||
| output_tensor = torch.empty( | ||
| output_size, dtype=input_.dtype, device=input_.device | ||
| ) | ||
| with self.use_symmetric_memory(self): | ||
| output_tensor = torch.empty( | ||
| output_size, dtype=input_.dtype, device=input_.device | ||
| ) | ||
|
|
||
| # All-gather. | ||
| if input_.is_cpu: | ||
|
|
@@ -788,7 +843,7 @@ def all_gatherv( | |
| pynccl_comm is not None and not pynccl_comm.disabled | ||
| ), "pynccl is required for all_gatherv" | ||
|
|
||
| def _all_gather_single( | ||
| def _all_gather_allocate_output( | ||
| input_: torch.Tensor, sizes: Optional[List[int]] = None | ||
| ): | ||
| input_size = input_.size() | ||
|
|
@@ -802,19 +857,25 @@ def _all_gather_single( | |
| else: | ||
| output_size = (input_size[0] * world_size,) + input_size[1:] | ||
| # Allocate output tensor. | ||
| output_tensor = torch.empty( | ||
| output_size, dtype=input_.dtype, device=input_.device | ||
| ) | ||
| pynccl_comm.all_gather(output_tensor, input_, sizes=sizes) | ||
| return output_tensor | ||
| with self.use_symmetric_memory(self, disabled=sizes is not None): | ||
| output_tensor = torch.empty( | ||
| output_size, dtype=input_.dtype, device=input_.device | ||
| ) | ||
| return output_tensor, sizes | ||
|
|
||
| if isinstance(input_, torch.Tensor): | ||
| return _all_gather_single(input_, sizes) | ||
| input_ = [input_] | ||
|
|
||
| output_list = [] | ||
| pynccl_comm.group_start() | ||
| size_list = [] | ||
| for inp in input_: | ||
| output_list.append(_all_gather_single(inp, sizes=sizes)) | ||
| output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes) | ||
| output_list.append(output_tensor) | ||
| size_list.append(s) | ||
|
|
||
| pynccl_comm.group_start() | ||
| for i, inp in enumerate(input_): | ||
| pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i]) | ||
| pynccl_comm.group_end() | ||
|
|
||
| return output_list | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,8 +21,12 @@ | |
|
|
||
| from sglang.srt.distributed import ( | ||
| get_tensor_model_parallel_world_size, | ||
| get_tp_group, | ||
| tensor_model_parallel_all_reduce, | ||
| ) | ||
| from sglang.srt.distributed.device_communicators.pynccl_allocator import ( | ||
| use_symmetric_memory, | ||
| ) | ||
| from sglang.srt.layers.dp_attention import ( | ||
| attn_tp_all_gather_into_tensor, | ||
| attn_tp_reduce_scatter_tensor, | ||
|
|
@@ -540,7 +544,12 @@ def _gather_hidden_states_and_residual( | |
| use_layer_norm_before_gather = context.attn_tp_size == 1 | ||
| if use_layer_norm_before_gather and hidden_states.shape[0] != 0: | ||
| residual = hidden_states | ||
| hidden_states = layernorm(hidden_states) | ||
| with use_symmetric_memory( | ||
| get_tp_group(), | ||
| disabled=not forward_batch.dp_padding_mode.is_max_len(), | ||
| ): | ||
|
Comment on lines
+547
to
+550
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you try to cache as many variables as possible?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Those 2 variables are already cached: tp group and is_max_len for this current batch. |
||
| hidden_states = layernorm(hidden_states) | ||
|
|
||
| hidden_states, local_hidden_states = ( | ||
| get_global_dp_buffer(), | ||
| hidden_states, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This runs a for loop over a long list; will it be slow?
I suspect it is even slower than the old approach.
sm.tag
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I measured the perf: it was around 2-3us. The
_cached_pool_snapshot is just a dictionary with only symmetric memory segments, not the full memory used by the app.