Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import tempfile
from contextlib import nullcontext

import torch
from torch.cuda.memory import CUDAPluggableAllocator

from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.server_args import get_global_server_args

nccl_allocator_source = """

Expand Down Expand Up @@ -60,6 +60,9 @@


def is_symmetric_memory_enabled():
    """Return True when NCCL symmetric-memory allocation is turned on in server args."""
    # Deferred import: server_args imports from this package, so importing at
    # module top level would be circular.
    from sglang.srt.server_args import get_global_server_args

    server_args = get_global_server_args()
    return server_args.enable_symm_mem


Expand Down Expand Up @@ -92,33 +95,25 @@ def get_nccl_mem_pool():
return _mem_pool


class use_symmetric_memory:
class SymmetricMemoryContext:
"""
Context manager for using symmetric memory with pynccl.

To Utilize the symmetric memory feature in NCCL, the buffers need to be allocated
by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce
this context manager. All tensors created under this context will be correctly
allocated and registered with a custom allocator.

In addition, developers need to manually tag the tensors that will be used as the input/output
of NCCL collectives with `tag(tensor)`.
"""

def __init__(self, group_coordinator: GroupCoordinator):
    """Prepare a symmetric-memory allocation context for `group_coordinator`.

    Tensors created while the context is active are allocated from the NCCL
    memory pool so they can be registered for symmetric-memory collectives.
    """
    self.group_coordinator = group_coordinator
    # All allocations inside the context are routed to the NCCL mem pool.
    self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
    # Record whether a CUDA graph capture is in flight; __exit__ must
    # re-attach this thread to the graph's allocation pool in that case.
    self.is_graph_capture = torch.cuda.is_current_stream_capturing()

def __enter__(self):
if not self.enabled:
return self

assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
Expand All @@ -139,16 +134,16 @@ def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave the NCCL mem-pool context, restoring graph-capture allocation if needed."""
    self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
    if self.is_graph_capture:
        # While capturing, allocations on this thread must go back to the
        # CUDA graph's memory pool.
        torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)

def tag(self, tensor: torch.Tensor):
    """Mark `tensor` as eligible for the symmetric-memory collective path.

    Collectives check this attribute to decide whether to use the
    symmetric-memory implementation.
    """
    tensor.symmetric_memory = True
def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):
    """Return a context manager for symmetric-memory tensor allocation.

    Yields a `SymmetricMemoryContext` bound to `group_coordinator`, or a no-op
    `nullcontext()` when the feature is globally off, explicitly disabled by
    the caller, or pointless (single-rank group).
    """
    if disabled or not is_symmetric_memory_enabled():
        return nullcontext()
    if group_coordinator.world_size == 1:
        # A one-rank group never benefits from symmetric registration.
        return nullcontext()
    return SymmetricMemoryContext(group_coordinator)
95 changes: 76 additions & 19 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,27 @@ def reg_all_gather_into_tensor_fake(
fake_impl=reg_all_gather_into_tensor_fake,
)

def reg_reduce_scatter_tensor(
    output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
    """Custom-op body: run reduce-scatter on the group registered as `group_name`.

    `_groups` maps names to weakrefs of GroupCoordinator instances; a dead
    weakref means the group was destroyed after registration.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group_ref = _groups[group_name]
    group = group_ref()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    group._reduce_scatter_tensor(output, input)

def reg_reduce_scatter_tensor_fake(
    output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
    # Fake (meta) implementation for tracing/compile: the real op only
    # mutates `output` in place, so the fake impl is a no-op.
    pass

# Register reduce-scatter as the custom op `sglang.reg_reduce_scatter_tensor`
# so it can be called through torch.ops (and traced); `output` is declared as
# the mutated argument, and the fake impl above serves meta/tracing.
direct_register_custom_op(
    op_name="reg_reduce_scatter_tensor",
    op_func=reg_reduce_scatter_tensor,
    mutates_args=["output"],
    fake_impl=reg_reduce_scatter_tensor_fake,
)


class GroupCoordinator:
"""
Expand Down Expand Up @@ -314,10 +335,16 @@ def __init__(
from sglang.srt.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled,
use_symmetric_memory,
)
from sglang.srt.distributed.device_communicators.torch_symm_mem import (
TorchSymmMemCommunicator,
)

self.is_symmetric_memory_enabled = is_symmetric_memory_enabled
self.use_symmetric_memory = use_symmetric_memory
if is_hip():
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,
Expand Down Expand Up @@ -552,7 +579,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)

if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():
with self.pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
Expand Down Expand Up @@ -627,15 +654,33 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
else:
torch.distributed.all_reduce(input_, group=self.device_group)

def _reduce_scatter_tensor(
    self,
    output: torch.Tensor,
    input: torch.Tensor,
) -> torch.Tensor:
    """Reduce-scatter `input` into `output` within this group.

    Prefers the pynccl communicator when it is usable (enabled, or forced on
    by symmetric memory); otherwise falls back to torch.distributed.
    Returns `output`.
    """
    comm = self.pynccl_comm
    use_pynccl = comm is not None and (
        not comm.disabled or self.is_symmetric_memory_enabled()
    )
    if use_pynccl:
        with comm.change_state(enable=True, stream=get_current_device_stream_fast()):
            comm.reduce_scatter(output, input)
    else:
        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
    return output

def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor):
    """Public reduce-scatter entry point.

    Dispatches through the registered custom op when custom ops are
    supported (keeps the call traceable); otherwise calls the internal
    implementation directly (NPU or unsupported torch).
    """
    if _is_npu or not supports_custom_op():
        self._reduce_scatter_tensor(output, input)
        return
    torch.ops.sglang.reg_reduce_scatter_tensor(
        output, input, group_name=self.unique_name
    )

def reduce_scatter(
self,
output: torch.Tensor,
Expand Down Expand Up @@ -682,8 +727,13 @@ def reduce_scatterv(

def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_gather(output, input)
if pynccl_comm is not None and (
not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
):
with pynccl_comm.change_state(
enable=True, stream=get_current_device_stream_fast()
):
pynccl_comm.all_gather(output, input)
else:
torch.distributed.all_gather_into_tensor(
output, input, group=self.device_group
Expand Down Expand Up @@ -745,9 +795,10 @@ def all_gather(
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
with self.use_symmetric_memory(self):
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)

# All-gather.
if input_.is_cpu:
Expand Down Expand Up @@ -787,7 +838,7 @@ def all_gatherv(
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for all_gatherv"

def _all_gather_single(
def _all_gather_allocate_output(
input_: torch.Tensor, sizes: Optional[List[int]] = None
):
input_size = input_.size()
Expand All @@ -801,19 +852,25 @@ def _all_gather_single(
else:
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
return output_tensor
with self.use_symmetric_memory(self, disabled=sizes is not None):
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
return output_tensor, sizes

if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)
input_ = [input_]

output_list = []
pynccl_comm.group_start()
size_list = []
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes)
output_list.append(output_tensor)
size_list.append(s)

pynccl_comm.group_start()
for i, inp in enumerate(input_):
pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i])
pynccl_comm.group_end()

return output_list
Expand Down
12 changes: 11 additions & 1 deletion python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@

from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
Expand All @@ -34,6 +38,7 @@
get_attention_tp_size,
get_global_dp_buffer,
get_local_dp_buffer,
is_allocation_symmetric,
is_dp_attention_enabled,
)
from sglang.srt.layers.moe import (
Expand Down Expand Up @@ -540,7 +545,12 @@ def _gather_hidden_states_and_residual(
use_layer_norm_before_gather = context.attn_tp_size == 1
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
residual = hidden_states
hidden_states = layernorm(hidden_states)
with use_symmetric_memory(
get_tp_group(),
disabled=not is_allocation_symmetric(),
):
hidden_states = layernorm(hidden_states)

hidden_states, local_hidden_states = (
get_global_dp_buffer(),
hidden_states,
Expand Down
45 changes: 34 additions & 11 deletions python/sglang/srt/layers/dp_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.utils import get_bool_env_var, is_hip

if TYPE_CHECKING:
Expand Down Expand Up @@ -86,6 +89,7 @@ class _DpGatheredBufferWrapper:
_device: torch.device
_global_dp_buffer_len: int
_local_dp_buffer_len: int
_dp_max_padding: bool
_global_num_tokens: Optional[List[int]]
_is_extend_in_batch: bool

@classmethod
def set_dp_buffer_len(
    cls,
    global_dp_buffer_len: int,
    local_dp_buffer_len: int,
    dp_max_padding: bool,
    global_num_tokens: Optional[List[int]] = None,
):
    """Record the DP-gather buffer geometry for the current batch.

    `dp_max_padding` indicates whether local buffers are padded to the max
    length across ranks (which makes allocations rank-symmetric).
    """
    cls._global_dp_buffer_len = global_dp_buffer_len
    cls._local_dp_buffer_len = local_dp_buffer_len
    cls._dp_max_padding = dp_max_padding
    cls._global_num_tokens = global_num_tokens

@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
with use_symmetric_memory(get_tp_group()):
buffer = torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
return buffer

@classmethod
def get_local_dp_buffer(cls) -> torch.Tensor:
    """Allocate this rank's local DP buffer.

    Symmetric-memory allocation is only valid when all ranks allocate the
    same shape, i.e. when max padding is in effect; otherwise it is disabled.
    """
    symm_disabled = not cls._dp_max_padding
    with use_symmetric_memory(get_tp_group(), disabled=symm_disabled):
        buffer = torch.empty(
            (cls._local_dp_buffer_len, cls._hidden_size),
            dtype=cls._dtype,
            device=cls._device,
        )
    return buffer

@classmethod
def get_global_dp_buffer_len(cls) -> int:
Expand Down Expand Up @@ -154,14 +164,19 @@ def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
def get_is_extend_in_batch(cls) -> bool:
return cls._is_extend_in_batch

@classmethod
def is_dp_max_padding(cls) -> bool:
    # Whether DP buffers are currently padded to the max length across ranks,
    # as recorded by set_dp_buffer_len for the current batch.
    return cls._dp_max_padding


def set_dp_buffer_len(
    global_dp_buffer_len: int,
    local_dp_buffer_len: int,
    dp_max_padding: bool,
    global_num_tokens: Optional[List[int]] = None,
):
    """Module-level forwarder to _DpGatheredBufferWrapper.set_dp_buffer_len."""
    _DpGatheredBufferWrapper.set_dp_buffer_len(
        global_dp_buffer_len,
        local_dp_buffer_len,
        dp_max_padding,
        global_num_tokens,
    )


Expand Down Expand Up @@ -205,6 +220,10 @@ def get_is_extend_in_batch() -> bool:
return _DpGatheredBufferWrapper.get_is_extend_in_batch()


def is_dp_max_padding() -> bool:
    # Module-level accessor for the per-batch max-padding flag stored on
    # _DpGatheredBufferWrapper.
    return _DpGatheredBufferWrapper.is_dp_max_padding()


def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
Expand Down Expand Up @@ -298,6 +317,10 @@ def is_dp_attention_enabled() -> bool:
return _ENABLE_DP_ATTENTION_FLAG


def is_allocation_symmetric() -> bool:
    """Whether tensor allocation is identical (symmetric) across ranks.

    Without DP attention every rank allocates the same shapes. With DP
    attention, shapes only match across ranks when max padding is active.
    """
    if is_dp_attention_enabled():
        return is_dp_max_padding()
    return True


def get_attention_tp_group() -> GroupCoordinator:
assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
return _ATTN_TP_GROUP
Expand Down
Loading
Loading