Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/sglang/srt/configs/falcon_h1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -297,8 +296,10 @@ def linear_layer_ids(self):

@property
def mamba2_cache_params(self):
from sglang.srt.layers.dp_attention import get_attention_tp_size

shape = Mamba2StateShape.create(
tp_world_size=get_tensor_model_parallel_world_size(),
tp_world_size=get_attention_tp_size(),
intermediate_size=self.mamba_intermediate,
n_groups=self.mamba_n_groups,
num_heads=self.mamba_n_heads,
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/configs/nemotron_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -273,6 +272,8 @@ def full_attention_layer_ids(self):

@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
from sglang.srt.layers.dp_attention import get_attention_tp_size

shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/configs/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -277,6 +276,8 @@ def full_attention_layer_ids(self):

@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
from sglang.srt.layers.dp_attention import get_attention_tp_size

shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@


try:
if ops.use_vllm_custom_allreduce and not _is_hip:
# Use vLLM custom allreduce
ops.meta_size()
else:
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel # noqa: F401
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel # noqa: F401

custom_ar = True
except Exception:
except ImportError:
# For CPUs
custom_ar = False

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from contextlib import nullcontext

import torch
from packaging import version
Expand Down Expand Up @@ -29,12 +30,23 @@
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None
_cached_pool_snapshot = None


def is_symmetric_memory_enabled():
    """Return True when NCCL symmetric memory was requested via server args."""
    server_args = get_global_server_args()
    return server_args.enable_symm_mem


def is_symmetric_memory_tensor(tensor: torch.Tensor):
    """Return True if `tensor`'s storage was allocated from the cached
    symmetric-memory pool snapshot.

    The snapshot (`_cached_pool_snapshot`) holds only symmetric-memory
    segments, so the scan is over a short list, not the whole allocator state.
    """
    if not is_symmetric_memory_enabled():
        return False
    if _cached_pool_snapshot is None:
        # No pool snapshot captured yet, so nothing can be symmetric memory.
        return False
    storage_ptr = tensor.untyped_storage().data_ptr()
    return any(
        block["address"] == storage_ptr
        for segment in _cached_pool_snapshot
        for block in segment["blocks"]
    )


def set_graph_pool_id(graph_pool_id):
    # Record the CUDA graph memory pool id at module level so that later
    # allocator bookkeeping can route graph-capture allocations back to the
    # graph's own pool (presumably consumed during context exit — confirm).
    global _graph_pool_id
    _graph_pool_id = graph_pool_id
Expand Down Expand Up @@ -63,30 +75,18 @@ def get_nccl_mem_pool():
return _mem_pool


class use_symmetric_memory:
def __init__(self, group_coordinator: GroupCoordinator):
if not is_symmetric_memory_enabled():
self.group_coordinator = None
self._mem_pool_ctx = None
self.is_graph_capture = None
self.device = None
self.pre_2_8_0 = None
else:
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
class SymmetricMemoryContext:
def __init__(
self,
group_coordinator: GroupCoordinator,
):
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")

def __enter__(self):
if not is_symmetric_memory_enabled():
return self
assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
assert (
self.group_coordinator.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
if self.is_graph_capture:
assert (
_graph_pool_id is not None
Expand All @@ -101,17 +101,12 @@ def __enter__(self):
self._mem_pool_ctx.__enter__()
return self

def tag(self, tensor: torch.Tensor):
if not is_symmetric_memory_enabled():
return
tensor.symmetric_memory = True

def __exit__(self, exc_type, exc_val, exc_tb):
if not is_symmetric_memory_enabled():
return
global _cached_pool_snapshot
global _registered_base_addrs
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
for segment in get_nccl_mem_pool().snapshot():
_cached_pool_snapshot = get_nccl_mem_pool().snapshot()
for segment in _cached_pool_snapshot:
if segment["address"] not in _registered_base_addrs:
if segment["stream"] == 0 and self.pre_2_8_0:
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
Expand All @@ -131,3 +126,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id
)


def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):
    """Return a context manager that allocates tensors in the NCCL symmetric
    memory pool, or a no-op `nullcontext` when symmetric memory does not apply.

    Symmetric memory is skipped when explicitly disabled by the caller, when
    the feature is turned off in the server args, or when the group has a
    single rank (no collective communication to accelerate).
    """
    if disabled or not is_symmetric_memory_enabled() or group_coordinator.world_size == 1:
        return nullcontext()
    return SymmetricMemoryContext(group_coordinator)
107 changes: 84 additions & 23 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,27 @@ def reg_all_gather_into_tensor_fake(
fake_impl=reg_all_gather_into_tensor_fake,
)

def reg_reduce_scatter_tensor(
    output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
    """Custom-op implementation: reduce-scatter `input` into `output` on the
    group registered under `group_name`."""
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is None:
        # The weakref was cleared: the group has already been torn down.
        raise ValueError(f"Group {group_name} is destroyed.")
    coordinator._reduce_scatter_tensor(output, input)

def reg_reduce_scatter_tensor_fake(
    output: torch.Tensor, input: torch.Tensor, group_name: str
) -> None:
    # Fake (meta) implementation used for tracing/compilation: no-op.
    pass

direct_register_custom_op(
    op_name="reg_reduce_scatter_tensor",
    op_func=reg_reduce_scatter_tensor,
    # `output` is written in place by the real implementation.
    mutates_args=["output"],
    fake_impl=reg_reduce_scatter_tensor_fake,
)


class GroupCoordinator:
"""
Expand Down Expand Up @@ -311,10 +332,16 @@ def __init__(
from sglang.srt.distributed.device_communicators.pynccl import (
PyNcclCommunicator,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_tensor,
use_symmetric_memory,
)
from sglang.srt.distributed.device_communicators.symm_mem import (
SymmMemCommunicator,
)

self.is_symmetric_memory_tensor = is_symmetric_memory_tensor
self.use_symmetric_memory = use_symmetric_memory
if is_hip():
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce,
Expand Down Expand Up @@ -549,11 +576,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)

if (
self.pynccl_comm is not None
and hasattr(input_, "symmetric_memory")
and input_.symmetric_memory
):
if self.pynccl_comm is not None and self.is_symmetric_memory_tensor(input_):
with self.pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
):
Expand Down Expand Up @@ -628,15 +651,37 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
else:
torch.distributed.all_reduce(input_, group=self.device_group)

def reduce_scatter_tensor(
def _reduce_scatter_tensor(
self,
output: torch.Tensor,
input: torch.Tensor,
) -> None:
# TODO(ch-wan): support other backends
torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
) -> torch.Tensor:
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and (
not pynccl_comm.disabled
or (
self.is_symmetric_memory_tensor(output)
and self.is_symmetric_memory_tensor(input)
)
):
with pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
):
pynccl_comm.reduce_scatter(output, input)
else:
torch.distributed.reduce_scatter_tensor(
output, input, group=self.device_group
)
return output
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return output is unnecessary now.

Copy link
Copy Markdown
Collaborator Author

@nvcastet nvcastet Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept the code "as-is" to avoid unrelated changes. But yes we could clean-up those function signatures in another PR if needed.


def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor):
    """Reduce-scatter `input` into `output`, dispatching through the
    registered sglang custom op when the platform supports it."""
    must_call_directly = _is_npu or not supports_custom_op()
    if must_call_directly:
        self._reduce_scatter_tensor(output, input)
        return
    torch.ops.sglang.reg_reduce_scatter_tensor(
        output, input, group_name=self.unique_name
    )

def reduce_scatter(
self,
output: torch.Tensor,
Expand Down Expand Up @@ -683,8 +728,17 @@ def reduce_scatterv(

def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.all_gather(output, input)
if pynccl_comm is not None and (
not pynccl_comm.disabled
or (
self.is_symmetric_memory_tensor(output)
and self.is_symmetric_memory_tensor(input)
)
):
with pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
):
pynccl_comm.all_gather(output, input)
else:
torch.distributed.all_gather_into_tensor(
output, input, group=self.device_group
Expand Down Expand Up @@ -746,9 +800,10 @@ def all_gather(
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
with self.use_symmetric_memory(self):
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)

# All-gather.
if input_.is_cpu:
Expand Down Expand Up @@ -788,7 +843,7 @@ def all_gatherv(
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for all_gatherv"

def _all_gather_single(
def _all_gather_allocate_output(
input_: torch.Tensor, sizes: Optional[List[int]] = None
):
input_size = input_.size()
Expand All @@ -802,19 +857,25 @@ def _all_gather_single(
else:
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
return output_tensor
with self.use_symmetric_memory(self, disabled=sizes is not None):
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
return output_tensor, sizes

if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)
input_ = [input_]

output_list = []
pynccl_comm.group_start()
size_list = []
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes)
output_list.append(output_tensor)
size_list.append(s)

pynccl_comm.group_start()
for i, inp in enumerate(input_):
pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i])
pynccl_comm.group_end()

return output_list
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ class Envs:
SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False)

# vLLM dependencies (TODO: they have been deprecated, we can remove them safely)
USE_VLLM_CUSTOM_ALLREDUCE = EnvBool(False)
USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False)

USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False)
Expand Down
11 changes: 10 additions & 1 deletion python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@

from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
Expand Down Expand Up @@ -540,7 +544,12 @@ def _gather_hidden_states_and_residual(
use_layer_norm_before_gather = context.attn_tp_size == 1
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
residual = hidden_states
hidden_states = layernorm(hidden_states)
with use_symmetric_memory(
get_tp_group(),
disabled=not forward_batch.dp_padding_mode.is_max_len(),
):
Comment on lines +547 to +550
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try to cache as much variables as possible?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those 2 variables are already cached: tp group and is_max_len for this current batch.

hidden_states = layernorm(hidden_states)

hidden_states, local_hidden_states = (
get_global_dp_buffer(),
hidden_states,
Expand Down
Loading