From 395fe471e995b60cf4b39093b0e2d5c9f9e0b980 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Wed, 6 Aug 2025 11:36:09 -0500 Subject: [PATCH 1/6] Register allgather/reducescatter buffers with symm memory --- .../device_communicators/pynccl_allocator.py | 30 +++--- .../sglang/srt/distributed/parallel_state.py | 95 +++++++++++++++---- python/sglang/srt/layers/communicator.py | 12 ++- python/sglang/srt/layers/dp_attention.py | 45 ++++++--- python/sglang/srt/layers/linear.py | 6 +- python/sglang/srt/layers/moe/cutlass_moe.py | 2 +- .../srt/layers/moe/fused_moe_triton/layer.py | 7 +- python/sglang/srt/layers/moe/topk.py | 30 ++++-- python/sglang/srt/layers/quantization/fp8.py | 19 ++-- .../srt/layers/quantization/modelopt_quant.py | 50 +++++----- .../sglang/srt/layers/quantization/mxfp4.py | 7 +- .../srt/layers/vocab_parallel_embedding.py | 3 +- .../srt/model_executor/cuda_graph_runner.py | 6 +- .../srt/model_executor/forward_batch_info.py | 4 +- python/sglang/srt/models/deepseek_v2.py | 8 +- python/sglang/srt/models/glm4_moe.py | 6 +- python/sglang/srt/operations.py | 2 + .../eagle_draft_cuda_graph_runner.py | 6 +- .../eagle_draft_extend_cuda_graph_runner.py | 6 +- 19 files changed, 237 insertions(+), 107 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index 6f29185caa1f..fcd35ae072f7 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -1,5 +1,6 @@ import os import tempfile +from contextlib import nullcontext import torch from torch.cuda.memory import CUDAPluggableAllocator @@ -92,7 +93,7 @@ def get_nccl_mem_pool(): return _mem_pool -class use_symmetric_memory: +class SymmetricMemoryContext: """ Context manager for using symmetric memory with pynccl. @@ -100,25 +101,17 @@ class use_symmetric_memory: by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce this context manager. All tensors created under this context will be correctly allocated and registered with a custom allocator. - - In addition, developers need to manually tag the tensors that will be used as the input/output - of NCCL collectives with `tag(tensor)`. 
""" - def __init__(self, group_coordinator: GroupCoordinator): - self.enabled = is_symmetric_memory_enabled() - - if not self.enabled: - return - + def __init__( + self, + group_coordinator: GroupCoordinator, + ): self.group_coordinator = group_coordinator self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) self.is_graph_capture = torch.cuda.is_current_stream_capturing() def __enter__(self): - if not self.enabled: - return self - assert ( self.group_coordinator.pynccl_comm is not None ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'" @@ -147,8 +140,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self.is_graph_capture: torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id) - def tag(self, tensor: torch.Tensor): - if not self.enabled: - return - tensor.symmetric_memory = True +def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False): + disabled = ( + not is_symmetric_memory_enabled() + or disabled + or group_coordinator.world_size == 1 + ) + return SymmetricMemoryContext(group_coordinator) if not disabled else nullcontext() diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 16c00d17b271..4060af3763f7 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -188,6 +188,27 @@ def reg_all_gather_into_tensor_fake( fake_impl=reg_all_gather_into_tensor_fake, ) + def reg_reduce_scatter_tensor( + output: torch.Tensor, input: torch.Tensor, group_name: str + ) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._reduce_scatter_tensor(output, input) + + def reg_reduce_scatter_tensor_fake( + output: torch.Tensor, input: torch.Tensor, group_name: str + ) -> None: + pass + + direct_register_custom_op( + op_name="reg_reduce_scatter_tensor", + op_func=reg_reduce_scatter_tensor, + mutates_args=["output"], + fake_impl=reg_reduce_scatter_tensor_fake, + ) + class GroupCoordinator: """ @@ -314,10 +335,16 @@ def __init__( from sglang.srt.distributed.device_communicators.pynccl import ( PyNcclCommunicator, ) + from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + is_symmetric_memory_enabled, + use_symmetric_memory, + ) from sglang.srt.distributed.device_communicators.torch_symm_mem import ( TorchSymmMemCommunicator, ) + self.is_symmetric_memory_enabled = is_symmetric_memory_enabled() + self.use_symmetric_memory = use_symmetric_memory if is_hip(): from sglang.srt.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce, @@ -552,7 +579,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) - if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False): + if self.pynccl_comm is not None and self.is_symmetric_memory_enabled: with self.pynccl_comm.change_state( enable=True, stream=get_current_device_stream_fast() ): @@ -627,15 +654,33 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None: else: torch.distributed.all_reduce(input_, group=self.device_group) - def reduce_scatter_tensor( + def _reduce_scatter_tensor( self, output: torch.Tensor, input: torch.Tensor, - ) -> None: - # TODO(ch-wan): support other backends - 
torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group) + ) -> torch.Tensor: + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and ( + not pynccl_comm.disabled or self.is_symmetric_memory_enabled + ): + with pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ): + pynccl_comm.reduce_scatter(output, input) + else: + torch.distributed.reduce_scatter_tensor( + output, input, group=self.device_group + ) return output + def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor): + if _is_npu or not supports_custom_op(): + self._reduce_scatter_tensor(output, input) + else: + torch.ops.sglang.reg_reduce_scatter_tensor( + output, input, group_name=self.unique_name + ) + def reduce_scatter( self, output: torch.Tensor, @@ -682,8 +727,13 @@ def reduce_scatterv( def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.all_gather(output, input) + if pynccl_comm is not None and ( + not pynccl_comm.disabled or self.is_symmetric_memory_enabled + ): + with pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ): + pynccl_comm.all_gather(output, input) else: torch.distributed.all_gather_into_tensor( output, input, group=self.device_group @@ -745,9 +795,10 @@ def all_gather( # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty( - output_size, dtype=input_.dtype, device=input_.device - ) + with self.use_symmetric_memory(self): + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) # All-gather. if input_.is_cpu: @@ -787,7 +838,7 @@ def all_gatherv( pynccl_comm is not None and not pynccl_comm.disabled ), "pynccl is required for all_gatherv" - def _all_gather_single( + def _all_gather_allocate_output( input_: torch.Tensor, sizes: Optional[List[int]] = None ): input_size = input_.size() @@ -801,19 +852,25 @@ def _all_gather_single( else: output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. 
- output_tensor = torch.empty( - output_size, dtype=input_.dtype, device=input_.device - ) - pynccl_comm.all_gather(output_tensor, input_, sizes=sizes) - return output_tensor + with self.use_symmetric_memory(self, disabled=sizes is not None): + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + return output_tensor, sizes if isinstance(input_, torch.Tensor): - return _all_gather_single(input_, sizes) + input_ = [input_] output_list = [] - pynccl_comm.group_start() + size_list = [] for inp in input_: - output_list.append(_all_gather_single(inp, sizes=sizes)) + output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes) + output_list.append(output_tensor) + size_list.append(s) + + pynccl_comm.group_start() + for i, inp in enumerate(input_): + pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i]) pynccl_comm.group_end() return output_list diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index eaeeac51def2..423bd2801471 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -21,8 +21,12 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather_into_tensor, attn_tp_reduce_scatter_tensor, @@ -34,6 +38,7 @@ get_attention_tp_size, get_global_dp_buffer, get_local_dp_buffer, + is_allocation_symmetric, is_dp_attention_enabled, ) from sglang.srt.layers.moe import ( @@ -540,7 +545,12 @@ def _gather_hidden_states_and_residual( use_layer_norm_before_gather = context.attn_tp_size == 1 if use_layer_norm_before_gather and hidden_states.shape[0] != 0: residual = hidden_states - hidden_states = layernorm(hidden_states) + with use_symmetric_memory( + get_tp_group(), + disabled=not is_allocation_symmetric(), + ): + hidden_states = layernorm(hidden_states) + hidden_states, local_hidden_states = ( get_global_dp_buffer(), hidden_states, diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 2b0f92a04474..4956d76ee4cf 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -17,6 +17,9 @@ get_tp_group, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.utils import get_bool_env_var, is_hip if TYPE_CHECKING: @@ -86,6 +89,7 @@ class _DpGatheredBufferWrapper: _device: torch.device _global_dp_buffer_len: int _local_dp_buffer_len: int + _dp_max_padding: bool _global_num_tokens: Optional[List[int]] _is_extend_in_batch: bool @@ -100,27 +104,33 @@ def set_dp_buffer_len( cls, global_dp_buffer_len: int, local_dp_buffer_len: int, + dp_max_padding: bool, global_num_tokens: Optional[List[int]] = None, ): cls._global_dp_buffer_len = global_dp_buffer_len cls._local_dp_buffer_len = local_dp_buffer_len + cls._dp_max_padding = dp_max_padding cls._global_num_tokens = global_num_tokens @classmethod def get_global_dp_buffer(cls) -> torch.Tensor: - return torch.empty( - (cls._global_dp_buffer_len, cls._hidden_size), - dtype=cls._dtype, - device=cls._device, - ) + with use_symmetric_memory(get_tp_group()): + buffer = torch.empty( + (cls._global_dp_buffer_len, cls._hidden_size), + dtype=cls._dtype, + device=cls._device, + ) + return buffer @classmethod def 
get_local_dp_buffer(cls) -> torch.Tensor: - return torch.empty( - (cls._local_dp_buffer_len, cls._hidden_size), - dtype=cls._dtype, - device=cls._device, - ) + with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding): + buffer = torch.empty( + (cls._local_dp_buffer_len, cls._hidden_size), + dtype=cls._dtype, + device=cls._device, + ) + return buffer @classmethod def get_global_dp_buffer_len(cls) -> int: @@ -154,14 +164,19 @@ def set_is_extend_in_batch(cls, is_extend_in_batch: bool): def get_is_extend_in_batch(cls) -> bool: return cls._is_extend_in_batch + @classmethod + def is_dp_max_padding(cls) -> bool: + return cls._dp_max_padding + def set_dp_buffer_len( global_dp_buffer_len: int, local_dp_buffer_len: int, + dp_max_padding: bool, global_num_tokens: Optional[List[int]] = None, ): _DpGatheredBufferWrapper.set_dp_buffer_len( - global_dp_buffer_len, local_dp_buffer_len, global_num_tokens + global_dp_buffer_len, local_dp_buffer_len, dp_max_padding, global_num_tokens ) @@ -205,6 +220,10 @@ def get_is_extend_in_batch() -> bool: return _DpGatheredBufferWrapper.get_is_extend_in_batch() +def is_dp_max_padding() -> bool: + return _DpGatheredBufferWrapper.is_dp_max_padding() + + def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): if not enable_dp_attention: return tp_rank, tp_size, 0 @@ -298,6 +317,10 @@ def is_dp_attention_enabled() -> bool: return _ENABLE_DP_ATTENTION_FLAG +def is_allocation_symmetric() -> bool: + return not is_dp_attention_enabled() or is_dp_max_padding() + + def get_attention_tp_group() -> GroupCoordinator: assert _ATTN_TP_GROUP is not None, "dp attention not initialized!" return _ATTN_TP_GROUP diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 65003af581e9..e1a97111debb 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -21,6 +21,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -1372,9 +1373,10 @@ def forward(self, input_, skip_all_reduce=False): # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) - sm.tag(output_parallel) if self.reduce_results and self.tp_size > 1 and not skip_all_reduce: output = tensor_model_parallel_all_reduce(output_parallel) diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 368c42929bcc..12d2c3991cdc 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -97,7 +97,7 @@ def cutlass_fused_experts_fp8( b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert. use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with block scaling. Currently, only `True` is supported. Defaults to `True`. - + output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created. Returns: torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`. 
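The `is_allocation_symmetric()` gate introduced above encodes a shape constraint: NCCL window registration is collective, so allocating from the symmetric pool is only safe when every rank allocates an identically shaped buffer. That holds when DP attention is off (pure TP, identical shapes on all ranks) or when DP padding is max-len (every rank pads to the same token count). A self-contained restatement of the predicate follows; this is plain Python rather than sglang code, and the rank-shape rationale is inferred from how the patch gates its call sites:

def allocation_is_symmetric(dp_attention_enabled: bool, dp_max_padding: bool) -> bool:
    # Mirrors is_allocation_symmetric(): symmetric (window-registered)
    # allocation is allowed only when buffer shapes match across ranks.
    return not dp_attention_enabled or dp_max_padding

assert allocation_is_symmetric(False, False)     # plain TP: identical shapes
assert allocation_is_symmetric(True, True)       # DP with max-len padding
assert not allocation_is_symmetric(True, False)  # ragged per-rank shapes
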
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 3212af3c4a50..35e15276f3fe 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -18,6 +18,7 @@ use_symmetric_memory, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import ( MoeRunnerConfig, get_deepep_mode, @@ -1048,10 +1049,10 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): router_logits = router_logits.to(torch.float32) - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): symm_output = torch.empty_like(hidden_states) - sm.tag(symm_output) - result = trtllm_fp4_block_scale_moe( routing_logits=router_logits, routing_bias=topk_config.correction_bias.to(hidden_states.dtype), diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 3a504320be8a..28636de23b05 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -32,12 +32,17 @@ import torch.nn.functional as F from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.eplb import expert_location_dispatch from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( ExpertLocationDispatchInfo, topk_ids_logical_to_physical, ) +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import get_moe_runner_backend from sglang.srt.utils import ( cpu_has_amx_support, @@ -279,13 +284,17 @@ def forward_cuda( ) else: self.topk_config.torch_native = False - return select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - topk_config=self.topk_config, - num_token_non_padded=num_token_non_padded, - expert_location_dispatch_info=expert_location_dispatch_info, - ) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + topk_output = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + return topk_output def forward_cpu( self, @@ -386,8 +395,11 @@ def forward_npu( def empty_topk_output(self, device: torch.device) -> TopKOutput: topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts - topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) - topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) + topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) # FIXME: router_logits should be of size (0, num_experts) router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) return StandardTopKOutput(topk_weights, topk_ids, router_logits) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 40bfb2910ade..37f15ae2c4c3 100644 --- 
a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -10,6 +10,12 @@ from torch.nn import Module from torch.nn.parameter import Parameter +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric + try: from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -1033,9 +1039,10 @@ def apply( if get_moe_runner_backend().is_cutlass(): from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): symm_output = torch.empty_like(x) - sm.tag(symm_output) topk_weights, topk_ids, _ = dispatch_output.topk_output output = cutlass_fused_experts_fp8( @@ -1208,12 +1215,14 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 # so we put the whole function under the ``use_symmetric_memory`` context manager. # If the bug is fixed, we can only put the output tensor allocation under the context manager. - output = trtllm_fp8_block_scale_moe( + return trtllm_fp8_block_scale_moe( routing_logits=router_logits.to(torch.float32), routing_bias=correction_bias, hidden_states=a_q, @@ -1238,8 +1247,6 @@ def apply_with_router_logits( routing_method_type=2, # DeepSeek-styled routing method use_shuffled_weight=False, ) - sm.tag(output) - return output def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 910f1d3c0a82..980a46c216ac 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -11,7 +11,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) -from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer +from sglang.srt.layers.dp_attention import ( + get_dp_global_num_tokens, + get_local_dp_buffer, + is_allocation_symmetric, +) from sglang.srt.layers.moe import ( MoeRunner, MoeRunnerBackend, @@ -663,7 +667,7 @@ def apply( None if correction_bias is None else correction_bias.to(torch.bfloat16) ) - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group()): # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 # so we put the whole function under the ``use_symmetric_memory`` context manager. 
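The hunks above all follow one pattern: allocate the kernel's output inside the pool context so the buffer is window-registered, let the kernel write into it, then hand that buffer to the subsequent all-reduce. Only `trtllm_fp8_block_scale_moe` is wrapped whole, because (per the FIXME in the diff) it ignores its `output` argument. A torch-only sketch of the pattern; `run_with_symm_output` and `kernel` are illustrative stand-ins, not sglang APIs:

import torch
from contextlib import nullcontext

def run_with_symm_output(x: torch.Tensor, kernel, symm_ctx=None):
    # Only the output allocation needs the pool context; the kernel then
    # fills the pre-registered buffer in place.
    with symm_ctx if symm_ctx is not None else nullcontext():
        out = torch.empty_like(x)
    kernel(x, out=out)
    return out

# Toy stand-in for e.g. cutlass_fused_experts_fp8:
y = run_with_symm_output(torch.ones(4), lambda x, out: out.copy_(2 * x))
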
@@ -693,7 +697,6 @@ def apply( tile_tokens_dim=None, routing_method_type=routing_method_type, ) - sm.tag(output) from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput @@ -1581,38 +1584,42 @@ def apply( topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids output_dtype = x.dtype + original_col = x.shape[1] x_sf = None + if should_use_flashinfer_cutlass_moe_fp4_allgather(): from flashinfer import nvfp4_block_scale_interleave # Quantize before comm, swizzle after. - if x.shape[0] > 0: - x, x_sf = fp4_quantize_flashinfer( - x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False - ) - else: - x_col = x.shape[1] - x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device) - x_sf = torch.zeros( - 0, x_col // 16, dtype=torch.uint8, device=x.device - ) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + if x.shape[0] > 0: + x, x_sf = fp4_quantize_flashinfer( + x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False + ) + else: + x_col = x.shape[1] + x = torch.zeros( + 0, x_col // 2, dtype=torch.uint8, device=x.device + ) + x_sf = torch.zeros( + 0, x_col // 16, dtype=torch.uint8, device=x.device + ) topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv( [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens() ) x_sf = nvfp4_block_scale_interleave(x_sf) - with use_symmetric_memory(get_tp_group()) as sm: - # The x might be packed in the case of fp4. So, use the output dim of the - # weight of the second GEMM. + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): symm_output = torch.empty( - x.shape[0], - layer.w2_weight.shape[1], - dtype=output_dtype, - device=x.device, + x.shape[0], original_col, dtype=output_dtype, device=x.device ) - sm.tag(symm_output) output = flashinfer_cutlass_fused_moe( + output=symm_output, input=x, token_selected_experts=topk_ids.to(torch.int), token_final_scales=topk_weights, @@ -1633,7 +1640,6 @@ def apply( tp_size=layer.moe_tp_size, tp_rank=layer.moe_tp_rank, tune_max_num_tokens=next_power_of_2(x.shape[0]), - output=symm_output, )[0] if should_use_flashinfer_cutlass_moe_fp4_allgather(): output, global_output = get_local_dp_buffer(), output diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 387fb53c5817..23dbda5e3d11 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -26,6 +26,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.moe.utils import get_moe_runner_backend @@ -640,10 +641,10 @@ def apply( top_k = topk_output.topk_config.top_k router_logits = topk_output.router_logits - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): symm_output = torch.empty_like(x) - sm.tag(symm_output) - trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 9a26922656b5..6c153b25051b 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ 
b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -473,9 +473,8 @@ def forward(self, input_): else: masked_input = input_ # Get the embeddings. - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group(), disabled=not self.enable_tp): output_parallel = self.quant_method.embedding(self, masked_input.long()) - sm.tag(output_parallel) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 80f8c65648cd..243b9a84b286 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -660,7 +660,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) set_is_extend_in_batch(False) kwargs = {} diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index dd42efc1680e..4f843726fa43 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -719,7 +719,9 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): num_tokens = global_num_tokens[0] self.global_dp_buffer_len = buffer_len - set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens) + set_dp_buffer_len( + buffer_len, num_tokens, dp_padding_mode.is_max_len(), global_num_tokens + ) set_is_extend_in_batch(self.is_extend_in_batch) bs = self.batch_size diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5f80f8cc5f8b..4140cf2c6d1c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -480,7 +480,8 @@ def forward( gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj( - x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter + x, + skip_all_reduce=should_allreduce_fusion or use_reduce_scatter, ) return x @@ -814,7 +815,6 @@ def _forward_shared_experts_and_put_results(): final_hidden_states *= self.routed_scaling_factor if shared_output is not None: final_hidden_states += shared_output - if ( self.tp_size > 1 and not should_allreduce_fusion @@ -883,7 +883,9 @@ def forward_cpu( return final_hidden_states def forward_deepep( - self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: shared_output = None if hidden_states.shape[0] > 0: diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 6051a5bb7c14..d8117672f728 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -482,12 +482,11 @@ def forward_normal_dual_stream( final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory(parallel_state.get_tp_group()): final_hidden_states_out = torch.empty_like(final_hidden_states) torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = 
final_hidden_states_out - sm.tag(final_hidden_states) if ( self.tp_size > 1 and not should_allreduce_fusion @@ -517,11 +516,10 @@ def forward_normal( # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory(parallel_state.get_tp_group()): final_hidden_states_out = torch.empty_like(final_hidden_states) torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) if ( self.tp_size > 1 and not should_allreduce_fusion diff --git a/python/sglang/srt/operations.py b/python/sglang/srt/operations.py index f8730cd77232..9d824587c4da 100644 --- a/python/sglang/srt/operations.py +++ b/python/sglang/srt/operations.py @@ -85,6 +85,7 @@ def __init__(self, debug_name: str, stages: List[Stage], inputs: dict): self._global_dp_buffer_len = forward_batch.global_dp_buffer_len self._local_dp_buffer_len = forward_batch.input_ids.shape[0] self._global_num_tokens = forward_batch.global_num_tokens_cpu + self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len() def next(self): assert not self.done @@ -95,6 +96,7 @@ def next(self): set_dp_buffer_len( self._global_dp_buffer_len, self._local_dp_buffer_len, + self._is_dp_max_padding, self._global_num_tokens, ) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 1b785443dd00..38ec9f46667f 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -263,7 +263,11 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) set_is_extend_in_batch(False) # Backup two fields, which will be modified in-place in `draft_forward`. diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index d4b5aeb27fbf..4571ac540f09 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -294,7 +294,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) set_is_extend_in_batch(False) # Backup two fields, which will be modified in-place in `draft_forward`. 
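Net effect of this first commit: the enable-flag-plus-`tag()` protocol is gone, replaced by a factory that makes one up-front decision and hands back either the real pool context or a `nullcontext()`. Below is a minimal stand-alone sketch of that dispatch; the pool context is stubbed, where the real `SymmetricMemoryContext` swaps torch's CUDA allocator to the ncclMemAlloc-backed pool on enter and restores it (re-binding the graph pool during capture) on exit:

from contextlib import AbstractContextManager, nullcontext

class _PoolCtxStub(AbstractContextManager):
    # Stand-in for SymmetricMemoryContext; enter/exit would switch the
    # allocator to the registered pool and back.
    def __exit__(self, exc_type, exc_val, exc_tb):
        return None

def use_symmetric_memory_sketch(world_size: int, enabled: bool, disabled: bool = False):
    # Same dispatch as the patch: collapse every skip condition into one
    # flag here instead of branching inside __enter__/__exit__.
    disabled = not enabled or disabled or world_size == 1
    return _PoolCtxStub() if not disabled else nullcontext()

with use_symmetric_memory_sketch(world_size=8, enabled=True):
    pass  # tensors created here would come from the registered pool
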
From 19ce59b9fafc91d10d7f69c7848a8a7d65713867 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 4 Nov 2025 09:32:45 -0600 Subject: [PATCH 2/6] Fix import --- .../srt/distributed/device_communicators/pynccl_allocator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index fcd35ae072f7..268320bc44d0 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -6,7 +6,6 @@ from torch.cuda.memory import CUDAPluggableAllocator from sglang.srt.distributed.parallel_state import GroupCoordinator -from sglang.srt.server_args import get_global_server_args nccl_allocator_source = """ @@ -61,6 +60,9 @@ def is_symmetric_memory_enabled(): + # Import here to avoid circular import + from sglang.srt.server_args import get_global_server_args + return get_global_server_args().enable_symm_mem From e5839ff7e24a9770454edcc8806f2090946d7cc5 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 4 Nov 2025 09:40:25 -0600 Subject: [PATCH 3/6] Fix typo in contextmanager --- .../srt/distributed/device_communicators/pynccl_allocator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index 268320bc44d0..e628eaf24778 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -134,9 +134,6 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - if not self.enabled: - return - self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) if self.is_graph_capture: From 288c6097aada9d7c6a9b8e925cb1f95b5ad550f5 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 4 Nov 2025 09:52:02 -0600 Subject: [PATCH 4/6] Fix leftover tag --- python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 35e15276f3fe..9c4509678ad3 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -842,7 +842,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs **kwargs, ) - with use_symmetric_memory(get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group()): final_hidden_states = self.dispatcher.combine(combine_input=combine_input) # TODO: should we add some conditions here? 
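Why the leftover tag removed in the next hunk had to go, shown in isolation: `SymmetricMemoryContext` no longer defines `tag()`, and when symmetric memory is disabled the factory returns a `nullcontext()`, which binds the `as` target to `None`:

from contextlib import nullcontext

with nullcontext() as sm:  # what use_symmetric_memory(...) yields when disabled
    pass
assert sm is None          # so any surviving sm.tag(t) raises AttributeError
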
@@ -850,8 +850,6 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs ..., :origin_hidden_states_dim ].contiguous() - sm.tag(final_hidden_states) - if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) From 430583b350a80283ab6e0b250c62cefb2fa424c1 Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 4 Nov 2025 13:25:26 -0600 Subject: [PATCH 5/6] Add missing disable args --- python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 4 +++- python/sglang/srt/layers/quantization/modelopt_quant.py | 4 +++- python/sglang/srt/models/glm4_moe.py | 9 +++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 9c4509678ad3..03e3cdf34112 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -842,7 +842,9 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs **kwargs, ) - with use_symmetric_memory(get_tp_group()): + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): final_hidden_states = self.dispatcher.combine(combine_input=combine_input) # TODO: should we add some conditions here? diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 980a46c216ac..d485f8ef7dc0 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -667,7 +667,9 @@ def apply( None if correction_bias is None else correction_bias.to(torch.bfloat16) ) - with use_symmetric_memory(get_tp_group()): + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 # so we put the whole function under the ``use_symmetric_memory`` context manager. 
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index d8117672f728..85398325e883 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -45,6 +45,7 @@ from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, + is_allocation_symmetric, is_dp_attention_enabled, ) from sglang.srt.layers.layernorm import RMSNorm @@ -482,7 +483,9 @@ def forward_normal_dual_stream( final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - with use_symmetric_memory(parallel_state.get_tp_group()): + with use_symmetric_memory( + parallel_state.get_tp_group(), disabled=not is_allocation_symmetric() + ): final_hidden_states_out = torch.empty_like(final_hidden_states) torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) @@ -516,7 +519,9 @@ def forward_normal( # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - with use_symmetric_memory(parallel_state.get_tp_group()): + with use_symmetric_memory( + parallel_state.get_tp_group(), disabled=not is_allocation_symmetric() + ): final_hidden_states_out = torch.empty_like(final_hidden_states) torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) final_hidden_states = final_hidden_states_out From cad89292674a4c330a63397ac92b29bef42baebe Mon Sep 17 00:00:00 2001 From: Nicolas Castet Date: Tue, 4 Nov 2025 14:47:52 -0600 Subject: [PATCH 6/6] Address review comments --- python/sglang/srt/distributed/parallel_state.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 4060af3763f7..c954d1e52d41 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -343,7 +343,7 @@ def __init__( TorchSymmMemCommunicator, ) - self.is_symmetric_memory_enabled = is_symmetric_memory_enabled() + self.is_symmetric_memory_enabled = is_symmetric_memory_enabled self.use_symmetric_memory = use_symmetric_memory if is_hip(): from sglang.srt.distributed.device_communicators.quick_all_reduce import ( @@ -579,7 +579,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) - if self.pynccl_comm is not None and self.is_symmetric_memory_enabled: + if self.pynccl_comm is not None and self.is_symmetric_memory_enabled(): with self.pynccl_comm.change_state( enable=True, stream=get_current_device_stream_fast() ): @@ -661,10 +661,10 @@ def _reduce_scatter_tensor( ) -> torch.Tensor: pynccl_comm = self.pynccl_comm if pynccl_comm is not None and ( - not pynccl_comm.disabled or self.is_symmetric_memory_enabled + not pynccl_comm.disabled or self.is_symmetric_memory_enabled() ): with pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream() + enable=True, stream=get_current_device_stream_fast() ): pynccl_comm.reduce_scatter(output, input) else: @@ -728,10 +728,10 @@ def reduce_scatterv( def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): pynccl_comm = self.pynccl_comm if pynccl_comm is not None and ( - not pynccl_comm.disabled or self.is_symmetric_memory_enabled + not pynccl_comm.disabled or self.is_symmetric_memory_enabled() ): with pynccl_comm.change_state( - enable=True, 
stream=torch.cuda.current_stream() + enable=True, stream=get_current_device_stream_fast() ): pynccl_comm.all_gather(output, input) else:
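Two review fixes land in this last commit: `_reduce_scatter_tensor` and `_all_gather_into_tensor` now pick their stream via `get_current_device_stream_fast()` like the existing all-reduce path, and `self.is_symmetric_memory_enabled` becomes a reference to the predicate rather than a value frozen at group construction (plausibly so it tracks the global server args; the commit message does not say). The lazy-predicate pattern in isolation, with sglang's coordinator reduced to a stub:

class CoordinatorSketch:
    def __init__(self, is_enabled_fn):
        # Store the predicate itself; note the missing "()".
        self.is_symmetric_memory_enabled = is_enabled_fn

    def pick_all_reduce_path(self) -> str:
        # Evaluated per call, so later flag changes are observed.
        return "pynccl" if self.is_symmetric_memory_enabled() else "torch.distributed"

flag = {"symm_mem": False}
coord = CoordinatorSketch(lambda: flag["symm_mem"])
assert coord.pick_all_reduce_path() == "torch.distributed"
flag["symm_mem"] = True
assert coord.pick_all_reduce_path() == "pynccl"
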