diff --git a/python/sglang/srt/configs/falcon_h1.py b/python/sglang/srt/configs/falcon_h1.py index b8869b4ffa35..1f524b892d0d 100644 --- a/python/sglang/srt/configs/falcon_h1.py +++ b/python/sglang/srt/configs/falcon_h1.py @@ -19,7 +19,6 @@ from transformers.utils import logging from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape -from sglang.srt.layers.dp_attention import get_tensor_model_parallel_world_size logger = logging.get_logger(__name__) @@ -297,8 +296,10 @@ def linear_layer_ids(self): @property def mamba2_cache_params(self): + from sglang.srt.layers.dp_attention import get_attention_tp_size + shape = Mamba2StateShape.create( - tp_world_size=get_tensor_model_parallel_world_size(), + tp_world_size=get_attention_tp_size(), intermediate_size=self.mamba_intermediate, n_groups=self.mamba_n_groups, num_heads=self.mamba_n_heads, diff --git a/python/sglang/srt/configs/nemotron_h.py b/python/sglang/srt/configs/nemotron_h.py index 9e156f6a7faa..b73b146feb6e 100644 --- a/python/sglang/srt/configs/nemotron_h.py +++ b/python/sglang/srt/configs/nemotron_h.py @@ -20,7 +20,6 @@ from transformers.utils import logging from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape -from sglang.srt.layers.dp_attention import get_attention_tp_size logger = logging.get_logger(__name__) @@ -273,6 +272,8 @@ def full_attention_layer_ids(self): @property def mamba2_cache_params(self) -> Mamba2CacheParams: + from sglang.srt.layers.dp_attention import get_attention_tp_size + shape = Mamba2StateShape.create( tp_world_size=get_attention_tp_size(), intermediate_size=self.mamba_num_heads * self.mamba_head_dim, diff --git a/python/sglang/srt/configs/qwen3_next.py b/python/sglang/srt/configs/qwen3_next.py index 630227a2c625..cd1b6f1ea59a 100644 --- a/python/sglang/srt/configs/qwen3_next.py +++ b/python/sglang/srt/configs/qwen3_next.py @@ -21,7 +21,6 @@ from transformers.utils import logging from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape -from sglang.srt.layers.dp_attention import get_attention_tp_size logger = logging.get_logger(__name__) @@ -277,6 +276,8 @@ def full_attention_layer_ids(self): @property def mamba2_cache_params(self) -> Mamba2CacheParams: + from sglang.srt.layers.dp_attention import get_attention_tp_size + shape = Mamba2StateShape.create( tp_world_size=get_attention_tp_size(), intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads, diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 634bd4aa41fc..94703592658e 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -28,14 +28,11 @@ try: - if ops.use_vllm_custom_allreduce and not _is_hip: - # Use vLLM custom allreduce - ops.meta_size() - else: - # Use custom allreduce from sgl kernel (ROCM and TRT-LLM) - import sgl_kernel # noqa: F401 + # Use custom allreduce from sgl kernel (ROCM and TRT-LLM) + import sgl_kernel # noqa: F401 + custom_ar = True -except Exception: +except ImportError: # For CPUs custom_ar = False diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index 9ce1c1c20f1c..5f9644c0ce67 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -1,4 +1,5 @@ import tempfile +from contextlib import nullcontext import torch from packaging import version @@ -29,12 +30,23 @@ _mem_pool = None _registered_base_addrs = set() _graph_pool_id = None +_cached_pool_snapshot = None def is_symmetric_memory_enabled(): return get_global_server_args().enable_symm_mem +def is_symmetric_memory_tensor(tensor: torch.Tensor): + if not is_symmetric_memory_enabled() or _cached_pool_snapshot is None: + return False + for segment in _cached_pool_snapshot: + for block in segment["blocks"]: + if block["address"] == tensor.untyped_storage().data_ptr(): + return True + return False + + def set_graph_pool_id(graph_pool_id): global _graph_pool_id _graph_pool_id = graph_pool_id @@ -63,30 +75,18 @@ def get_nccl_mem_pool(): return _mem_pool -class use_symmetric_memory: - def __init__(self, group_coordinator: GroupCoordinator): - if not is_symmetric_memory_enabled(): - self.group_coordinator = None - self._mem_pool_ctx = None - self.is_graph_capture = None - self.device = None - self.pre_2_8_0 = None - else: - self.group_coordinator = group_coordinator - self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) - self.is_graph_capture = torch.cuda.is_current_stream_capturing() - self.device = torch.cuda.current_device() - self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0") +class SymmetricMemoryContext: + def __init__( + self, + group_coordinator: GroupCoordinator, + ): + self.group_coordinator = group_coordinator + self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) + self.is_graph_capture = torch.cuda.is_current_stream_capturing() + self.device = torch.cuda.current_device() + self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0") def __enter__(self): - if not is_symmetric_memory_enabled(): - return self - assert ( - self.group_coordinator.pynccl_comm is not None - ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'" - assert ( - self.group_coordinator.pynccl_comm.nccl_version >= 22703 - ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" if self.is_graph_capture: assert ( _graph_pool_id is not None @@ -101,17 +101,12 @@ def __enter__(self): self._mem_pool_ctx.__enter__() return self - def tag(self, tensor: torch.Tensor): - if not is_symmetric_memory_enabled(): - return - tensor.symmetric_memory = True - def __exit__(self, exc_type, exc_val, exc_tb): - if not is_symmetric_memory_enabled(): - return + global _cached_pool_snapshot global _registered_base_addrs self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) - for segment in get_nccl_mem_pool().snapshot(): + _cached_pool_snapshot = get_nccl_mem_pool().snapshot() + for segment in _cached_pool_snapshot: if segment["address"] not in _registered_base_addrs: if segment["stream"] == 0 and self.pre_2_8_0: # PyTorch version < 2.8.0 has a multi-thread MemPool bug @@ -131,3 +126,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): torch._C._cuda_beginAllocateCurrentThreadToPool( self.device, _graph_pool_id ) + + +def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False): + disabled = ( + disabled + or not is_symmetric_memory_enabled() + or group_coordinator.world_size == 1 + ) + return SymmetricMemoryContext(group_coordinator) if not disabled else nullcontext() diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 8d1e903d3de8..96c61a131624 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -187,6 +187,27 @@ def reg_all_gather_into_tensor_fake( fake_impl=reg_all_gather_into_tensor_fake, ) + def reg_reduce_scatter_tensor( + output: torch.Tensor, input: torch.Tensor, group_name: str + ) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._reduce_scatter_tensor(output, input) + + def reg_reduce_scatter_tensor_fake( + output: torch.Tensor, input: torch.Tensor, group_name: str + ) -> None: + pass + + direct_register_custom_op( + op_name="reg_reduce_scatter_tensor", + op_func=reg_reduce_scatter_tensor, + mutates_args=["output"], + fake_impl=reg_reduce_scatter_tensor_fake, + ) + class GroupCoordinator: """ @@ -311,10 +332,16 @@ def __init__( from sglang.srt.distributed.device_communicators.pynccl import ( PyNcclCommunicator, ) + from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + is_symmetric_memory_tensor, + use_symmetric_memory, + ) from sglang.srt.distributed.device_communicators.symm_mem import ( SymmMemCommunicator, ) + self.is_symmetric_memory_tensor = is_symmetric_memory_tensor + self.use_symmetric_memory = use_symmetric_memory if is_hip(): from sglang.srt.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce, @@ -549,11 +576,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) - if ( - self.pynccl_comm is not None - and hasattr(input_, "symmetric_memory") - and input_.symmetric_memory - ): + if self.pynccl_comm is not None and self.is_symmetric_memory_tensor(input_): with self.pynccl_comm.change_state( enable=True, stream=torch.get_device_module().current_stream() ): @@ -628,15 +651,37 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None: else: torch.distributed.all_reduce(input_, group=self.device_group) - def reduce_scatter_tensor( + def _reduce_scatter_tensor( self, output: torch.Tensor, input: torch.Tensor, - ) -> None: - # TODO(ch-wan): support other backends - torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group) + ) -> torch.Tensor: + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and ( + not pynccl_comm.disabled + or ( + self.is_symmetric_memory_tensor(output) + and self.is_symmetric_memory_tensor(input) + ) + ): + with pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ): + pynccl_comm.reduce_scatter(output, input) + else: + torch.distributed.reduce_scatter_tensor( + output, input, group=self.device_group + ) return output + def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor): + if _is_npu or not supports_custom_op(): + self._reduce_scatter_tensor(output, input) + else: + torch.ops.sglang.reg_reduce_scatter_tensor( + output, input, group_name=self.unique_name + ) + def reduce_scatter( self, output: torch.Tensor, @@ -683,8 +728,17 @@ def reduce_scatterv( def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor): pynccl_comm = self.pynccl_comm - if pynccl_comm is not None and not pynccl_comm.disabled: - pynccl_comm.all_gather(output, input) + if pynccl_comm is not None and ( + not pynccl_comm.disabled + or ( + self.is_symmetric_memory_tensor(output) + and self.is_symmetric_memory_tensor(input) + ) + ): + with pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream() + ): + pynccl_comm.all_gather(output, input) else: torch.distributed.all_gather_into_tensor( output, input, group=self.device_group @@ -746,9 +800,10 @@ def all_gather( # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty( - output_size, dtype=input_.dtype, device=input_.device - ) + with self.use_symmetric_memory(self): + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) # All-gather. if input_.is_cpu: @@ -788,7 +843,7 @@ def all_gatherv( pynccl_comm is not None and not pynccl_comm.disabled ), "pynccl is required for all_gatherv" - def _all_gather_single( + def _all_gather_allocate_output( input_: torch.Tensor, sizes: Optional[List[int]] = None ): input_size = input_.size() @@ -802,19 +857,25 @@ def _all_gather_single( else: output_size = (input_size[0] * world_size,) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty( - output_size, dtype=input_.dtype, device=input_.device - ) - pynccl_comm.all_gather(output_tensor, input_, sizes=sizes) - return output_tensor + with self.use_symmetric_memory(self, disabled=sizes is not None): + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + return output_tensor, sizes if isinstance(input_, torch.Tensor): - return _all_gather_single(input_, sizes) + input_ = [input_] output_list = [] - pynccl_comm.group_start() + size_list = [] for inp in input_: - output_list.append(_all_gather_single(inp, sizes=sizes)) + output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes) + output_list.append(output_tensor) + size_list.append(s) + + pynccl_comm.group_start() + for i, inp in enumerate(input_): + pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i]) pynccl_comm.group_end() return output_list diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 73be13ef942d..0130d3f01f19 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -211,7 +211,6 @@ class Envs: SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK = EnvBool(False) # vLLM dependencies (TODO: they have been deprecated, we can remove them safely) - USE_VLLM_CUSTOM_ALLREDUCE = EnvBool(False) USE_VLLM_CUTLASS_W8A8_FP8_KERNEL = EnvBool(False) USE_TRITON_W8A8_FP8_KERNEL = EnvBool(False) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index eaeeac51def2..8315f7289e09 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -21,8 +21,12 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, + get_tp_group, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.dp_attention import ( attn_tp_all_gather_into_tensor, attn_tp_reduce_scatter_tensor, @@ -540,7 +544,12 @@ def _gather_hidden_states_and_residual( use_layer_norm_before_gather = context.attn_tp_size == 1 if use_layer_norm_before_gather and hidden_states.shape[0] != 0: residual = hidden_states - hidden_states = layernorm(hidden_states) + with use_symmetric_memory( + get_tp_group(), + disabled=not forward_batch.dp_padding_mode.is_max_len(), + ): + hidden_states = layernorm(hidden_states) + hidden_states, local_hidden_states = ( get_global_dp_buffer(), hidden_states, diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 2b413b4467bd..a7a5344c7179 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -18,6 +18,9 @@ tensor_model_parallel_all_reduce, ) from sglang.srt.utils import get_bool_env_var, is_hip +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig @@ -86,6 +89,7 @@ class _DpGatheredBufferWrapper: _device: torch.device _global_dp_buffer_len: int _local_dp_buffer_len: int + _dp_max_padding: bool _global_num_tokens: Optional[List[int]] _is_extend_in_batch: bool @@ -100,27 +104,33 @@ def set_dp_buffer_len( cls, global_dp_buffer_len: int, local_dp_buffer_len: int, + dp_max_padding: bool, global_num_tokens: Optional[List[int]] = None, ): cls._global_dp_buffer_len = global_dp_buffer_len cls._local_dp_buffer_len = local_dp_buffer_len + cls._dp_max_padding = dp_max_padding cls._global_num_tokens = global_num_tokens @classmethod def get_global_dp_buffer(cls) -> torch.Tensor: - return torch.empty( - (cls._global_dp_buffer_len, cls._hidden_size), - dtype=cls._dtype, - device=cls._device, - ) + with use_symmetric_memory(get_tp_group()): + buffer = torch.empty( + (cls._global_dp_buffer_len, cls._hidden_size), + dtype=cls._dtype, + device=cls._device, + ) + return buffer @classmethod def get_local_dp_buffer(cls) -> torch.Tensor: - return torch.empty( - (cls._local_dp_buffer_len, cls._hidden_size), - dtype=cls._dtype, - device=cls._device, - ) + with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding): + buffer = torch.empty( + (cls._local_dp_buffer_len, cls._hidden_size), + dtype=cls._dtype, + device=cls._device, + ) + return buffer @classmethod def get_global_dp_buffer_len(cls) -> int: @@ -146,6 +156,10 @@ def get_dp_dtype(cls) -> torch.dtype: def get_dp_device(cls) -> torch.device: return cls._device + @classmethod + def is_dp_max_padding(cls) -> bool: + return cls._dp_max_padding + @classmethod def set_is_extend_in_batch(cls, is_extend_in_batch: bool): cls._is_extend_in_batch = is_extend_in_batch @@ -158,10 +172,11 @@ def get_is_extend_in_batch(cls) -> bool: def set_dp_buffer_len( global_dp_buffer_len: int, local_dp_buffer_len: int, + dp_max_padding: bool, global_num_tokens: Optional[List[int]] = None, ): _DpGatheredBufferWrapper.set_dp_buffer_len( - global_dp_buffer_len, local_dp_buffer_len, global_num_tokens + global_dp_buffer_len, local_dp_buffer_len, dp_max_padding, global_num_tokens ) @@ -197,6 +212,10 @@ def get_dp_device() -> torch.device: return _DpGatheredBufferWrapper.get_dp_device() +def is_dp_max_padding() -> bool: + return _DpGatheredBufferWrapper.is_dp_max_padding() + + def set_is_extend_in_batch(is_extend_in_batch: bool): _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch) @@ -298,6 +317,10 @@ def is_dp_attention_enabled() -> bool: return _ENABLE_DP_ATTENTION_FLAG +def is_allocation_symmetric() -> bool: + return not is_dp_attention_enabled() or is_dp_max_padding() + + def get_attention_tp_group() -> GroupCoordinator: assert _ATTN_TP_GROUP is not None, "dp attention not initialized!" return _ATTN_TP_GROUP diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 23b2635d2421..e1a97111debb 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -13,7 +13,7 @@ divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, + get_tp_group, split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -21,6 +21,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -1372,9 +1373,10 @@ def forward(self, input_, skip_all_reduce=False): # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) - sm.tag(output_parallel) if self.reduce_results and self.tp_size > 1 and not skip_all_reduce: output = tensor_model_parallel_all_reduce(output_parallel) diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 870749d4799e..b4c99d73dd98 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -40,6 +40,7 @@ def cutlass_fused_experts_fp8( problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, use_fp8_blockscale: bool = True, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations. @@ -94,7 +95,7 @@ def cutlass_fused_experts_fp8( b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert. use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with block scaling. Currently, only `True` is supported. Defaults to `True`. - + output (torch.Tensor, optional): Output tensor. If not provided, a new tensor will be created. Returns: torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`. @@ -200,9 +201,11 @@ def cutlass_fused_experts_fp8( workspace, ) - result = torch.empty((m, k), device=device, dtype=out_dtype) - apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype)) - return result + if output is None: + output = torch.empty((m, k), device=device, dtype=out_dtype) + + apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype)) + return output FLOAT4_E2M1_MAX = 6.0 diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 5b527a386e07..cd00a74d3358 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -15,6 +15,7 @@ tensor_model_parallel_all_reduce, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import ( MoeRunnerConfig, get_deepep_mode, @@ -1040,6 +1041,10 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): router_logits = router_logits.to(torch.float32) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + symm_output = torch.empty_like(hidden_states) result = trtllm_fp4_block_scale_moe( routing_logits=router_logits, routing_bias=topk_config.correction_bias.to(hidden_states.dtype), @@ -1072,6 +1077,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): tile_tokens_dim=None, routing_method_type=RoutingMethodType.DeepSeekV3, do_finalize=True, + output=symm_output, )[0] return result diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index f85ccd879a49..c2105aaf7834 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -32,12 +32,17 @@ import torch.nn.functional as F from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.eplb import expert_location_dispatch from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( ExpertLocationDispatchInfo, topk_ids_logical_to_physical, ) +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import ( get_moe_runner_backend, should_use_flashinfer_trtllm_moe, @@ -282,13 +287,17 @@ def forward_cuda( ) else: self.topk_config.torch_native = False - return select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - topk_config=self.topk_config, - num_token_non_padded=num_token_non_padded, - expert_location_dispatch_info=expert_location_dispatch_info, - ) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + topk_output = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + return topk_output def forward_cpu( self, @@ -389,9 +398,11 @@ def forward_npu( def empty_topk_output(self, device: torch.device) -> TopKOutput: topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts - topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) - topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) - # FIXME: router_logits should be of size (0, num_experts) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) + topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) return StandardTopKOutput(topk_weights, topk_ids, router_logits) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 91b54e1257ea..5113ce37448f 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -10,6 +10,12 @@ from torch.nn import Module from torch.nn.parameter import Parameter +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric + try: from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -1027,6 +1033,10 @@ def apply( from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 topk_weights, topk_ids, _ = dispatch_output.topk_output + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + symm_output = torch.empty_like(x) output = cutlass_fused_experts_fp8( x, layer.w13_weight.transpose(1, 2), @@ -1049,6 +1059,7 @@ def apply( self.problem_sizes1, self.problem_sizes2, use_fp8_blockscale=True, + output=symm_output, ) return StandardCombineInput(hidden_states=output) @@ -1212,31 +1223,34 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - return trtllm_fp8_block_scale_moe( - routing_logits=router_logits.to(torch.float32), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale_inv, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale_inv, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ), - tile_tokens_dim=get_tile_tokens_dim( - x.shape[0], topk_config.top_k, layer.num_experts - ), - routing_method_type=2, # DeepSeek-styled routing method - use_shuffled_weight=False, - ) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + return trtllm_fp8_block_scale_moe( + routing_logits=router_logits.to(torch.float32), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale_inv, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale_inv, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ), + tile_tokens_dim=get_tile_tokens_dim( + x.shape[0], topk_config.top_k, layer.num_experts + ), + routing_method_type=2, # DeepSeek-styled routing method + use_shuffled_weight=False, + ) def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 40fd2bb5b63a..35cb29670cd8 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -8,7 +8,14 @@ from torch.nn.parameter import Parameter from sglang.srt.distributed import get_tp_group -from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import ( + get_dp_global_num_tokens, + get_local_dp_buffer, + is_allocation_symmetric, +) from sglang.srt.layers.moe import ( MoeRunner, MoeRunnerBackend, @@ -1561,27 +1568,42 @@ def apply( topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids output_dtype = x.dtype + output_col = x.shape[1] x_sf = None + if should_use_flashinfer_cutlass_moe_fp4_allgather(): from flashinfer import nvfp4_block_scale_interleave # Quantize before comm, swizzle after. - if x.shape[0] > 0: - x, x_sf = fp4_quantize( - x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False - ) - else: - x_col = x.shape[1] - x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device) - x_sf = torch.zeros( - 0, x_col // 16, dtype=torch.uint8, device=x.device - ) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + if x.shape[0] > 0: + x, x_sf = fp4_quantize( + x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False + ) + else: + x_col = x.shape[1] + x = torch.zeros( + 0, x_col // 2, dtype=torch.uint8, device=x.device + ) + x_sf = torch.zeros( + 0, x_col // 16, dtype=torch.uint8, device=x.device + ) topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv( [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens() ) x_sf = nvfp4_block_scale_interleave(x_sf) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + symm_output = torch.empty( + x.shape[0], output_col, dtype=output_dtype, device=x.device + ) + output = flashinfer_cutlass_fused_moe( + output=symm_output, input=x, token_selected_experts=topk_ids.to(torch.int), token_final_scales=topk_weights, diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 77f5f3f3567f..3329692fc8d1 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -22,6 +22,11 @@ import torch from torch.nn.parameter import Parameter +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.distributed.parallel_state import get_tp_group +from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.moe.utils import get_moe_runner_backend @@ -638,6 +643,10 @@ def apply( top_k = topk_output.topk_config.top_k router_logits = topk_output.router_logits + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + symm_output = torch.empty_like(x) trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -666,6 +675,7 @@ def apply( None, # tile_tokens_dim 1, # routing_method_type, renormalize True, # do finalize + output=symm_output, )[0] return StandardCombineInput(hidden_states=trtllm_gen_output) diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 986babf2bc44..6c153b25051b 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -11,7 +11,7 @@ divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, + get_tp_group, tensor_model_parallel_all_reduce, ) from sglang.srt.distributed.device_communicators.pynccl_allocator import ( @@ -473,9 +473,8 @@ def forward(self, input_): else: masked_input = input_ # Get the embeddings. - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group(), disabled=not self.enable_tp): output_parallel = self.quant_method.embedding(self, masked_input.long()) - sm.tag(output_parallel) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 80f8c65648cd..243b9a84b286 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -660,7 +660,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) set_is_extend_in_batch(False) kwargs = {} diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 55ebf1a48a42..e358a1e22b1a 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -710,7 +710,9 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): num_tokens = global_num_tokens[0] self.global_dp_buffer_len = buffer_len - set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens) + set_dp_buffer_len( + buffer_len, num_tokens, dp_padding_mode.is_max_len(), global_num_tokens + ) set_is_extend_in_batch(self.is_extend_in_batch) bs = self.batch_size diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2b6378dbef1b..b687d19e61ef 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -39,12 +39,8 @@ get_moe_expert_parallel_world_size, get_pp_group, get_tensor_model_parallel_world_size, - parallel_state, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.environ import envs from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation @@ -484,7 +480,8 @@ def forward( gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj( - x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter + x, + skip_all_reduce=should_allreduce_fusion or use_reduce_scatter, ) return x @@ -760,12 +757,8 @@ def forward_normal_dual_stream( final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) + final_hidden_states += shared_output - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) if ( self.tp_size > 1 and not should_allreduce_fusion @@ -824,11 +817,7 @@ def _forward_shared_experts_and_put_results(): # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) + final_hidden_states += shared_output if ( self.tp_size > 1 and not should_allreduce_fusion @@ -897,7 +886,9 @@ def forward_cpu( return final_hidden_states def forward_deepep( - self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, ) -> torch.Tensor: shared_output = None if hidden_states.shape[0] > 0: diff --git a/python/sglang/srt/operations.py b/python/sglang/srt/operations.py index f8730cd77232..9d824587c4da 100644 --- a/python/sglang/srt/operations.py +++ b/python/sglang/srt/operations.py @@ -85,6 +85,7 @@ def __init__(self, debug_name: str, stages: List[Stage], inputs: dict): self._global_dp_buffer_len = forward_batch.global_dp_buffer_len self._local_dp_buffer_len = forward_batch.input_ids.shape[0] self._global_num_tokens = forward_batch.global_num_tokens_cpu + self._is_dp_max_padding = forward_batch.dp_padding_mode.is_max_len() def next(self): assert not self.done @@ -95,6 +96,7 @@ def next(self): set_dp_buffer_len( self._global_dp_buffer_len, self._local_dp_buffer_len, + self._is_dp_max_padding, self._global_num_tokens, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3f130e3ca6b4..4f38a5b23c84 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3156,7 +3156,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-torch-symm-mem", action="store_true", - help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM10 supports world size 6, 8.", + help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM100 supports world size 6, 8.", ) parser.add_argument( "--disable-overlap-schedule", diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 1b785443dd00..38ec9f46667f 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -263,7 +263,11 @@ def capture_one_batch_size(self, num_seqs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) set_is_extend_in_batch(False) # Backup two fields, which will be modified in-place in `draft_forward`. diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index d4b5aeb27fbf..4571ac540f09 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -294,7 +294,11 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) + set_dp_buffer_len( + global_dp_buffer_len, + num_tokens, + forward_batch.dp_padding_mode.is_max_len(), + ) set_is_extend_in_batch(False) # Backup two fields, which will be modified in-place in `draft_forward`. diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0c1a13b5f86d..e04585d0c627 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -297,6 +297,7 @@ def xpu_has_xmx_support(): return False +@lru_cache(maxsize=1) def is_flashinfer_available(): """ Check whether flashinfer is available. diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 13bb11be3a06..d4222a37fa4b 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -263,6 +263,10 @@ class FusedSetKVBufferArg: cache_loc: torch.Tensor +def _view_3d(x, head_size: int): + return x.view(x.shape[0], -1, head_size) + + def apply_rope_with_cos_sin_cache_inplace( positions: torch.Tensor, query: torch.Tensor, @@ -317,14 +321,11 @@ def apply_rope_with_cos_sin_cache_inplace( assert a.v_scale is None, "v_scale is not yet supported" assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}" - def _view_3d(x): - return x.view(x.shape[0], -1, head_size) - torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default( - _view_3d(query), - _view_3d(key), - _view_3d(query), - _view_3d(key), + _view_3d(query, head_size), + _view_3d(key, head_size), + _view_3d(query, head_size), + _view_3d(key, head_size), cos_sin_cache, positions.long(), (not is_neox),