diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index 6180e64d63c8..f485c24c2df0 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -19,6 +19,7 @@ ncclUniqueId, ) from sglang.srt.distributed.utils import StatelessProcessGroup +from sglang.srt.utils.common import get_current_device_stream_fast logger = logging.getLogger(__name__) @@ -137,7 +138,7 @@ def _resolve_stream(self, stream: Optional[torch.cuda.Stream]): if stream is not None: return stream if self.use_current_stream: - return torch.cuda.current_stream() + return get_current_device_stream_fast() return self.stream def all_reduce( diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py index 9ce1c1c20f1c..5a5e9ebe6dc0 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_allocator.py @@ -1,7 +1,7 @@ +import os import tempfile import torch -from packaging import version from torch.cuda.memory import CUDAPluggableAllocator from sglang.srt.distributed.parallel_state import GroupCoordinator @@ -9,13 +9,22 @@ nccl_allocator_source = """ #include + extern "C" { void* nccl_alloc_plug(size_t size, int device, void* stream) { void* ptr; ncclResult_t err = ncclMemAlloc(&ptr, size); - return ptr; + const char *str_val = getenv("SGLANG_TMP_NCCL_COMM_VALUE"); + char *endptr; + void* int_val = (void *)strtoull(str_val, &endptr, 0); + + ncclComm_t comm = (ncclComm_t)(int_val); + ncclWindow_t win; + ncclResult_t err2 = ncclCommWindowRegister(comm, ptr, size, &win, NCCL_WIN_COLL_SYMMETRIC); + + return ptr; } void nccl_free_plug(void* ptr, size_t size, int device, void* stream) { @@ -27,8 +36,8 @@ _allocator = None _mem_pool = None -_registered_base_addrs = set() _graph_pool_id = None +_cur_device = None def is_symmetric_memory_enabled(): @@ -41,7 +50,7 @@ def set_graph_pool_id(graph_pool_id): def get_nccl_mem_pool(): - global _allocator, _mem_pool + global _allocator, _mem_pool, _cur_device if _mem_pool is None: out_dir = tempfile.gettempdir() nccl_allocator_libname = "nccl_allocator" @@ -60,74 +69,67 @@ def get_nccl_mem_pool(): "nccl_free_plug", ).allocator() _mem_pool = torch.cuda.MemPool(_allocator) + _cur_device = torch.cuda.current_device() return _mem_pool class use_symmetric_memory: + """ + Context manager for using symmetric memory with pynccl. + + To Utilize the symmetric memory feature in NCCL, the buffers need to be allocated + by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce + this context manager. All tensors created under this context will be correctly + allocated and registered with a custom allocator. + + In addition, developers need to manually tag the tensors that will be used as the input/output + of NCCL collectives with `tag(tensor)`. + """ + def __init__(self, group_coordinator: GroupCoordinator): - if not is_symmetric_memory_enabled(): - self.group_coordinator = None - self._mem_pool_ctx = None - self.is_graph_capture = None - self.device = None - self.pre_2_8_0 = None - else: - self.group_coordinator = group_coordinator - self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) - self.is_graph_capture = torch.cuda.is_current_stream_capturing() - self.device = torch.cuda.current_device() - self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0") + self.enabled = is_symmetric_memory_enabled() + + if not self.enabled: + return + + self.group_coordinator = group_coordinator + self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) + self.is_graph_capture = torch.cuda.is_current_stream_capturing() def __enter__(self): - if not is_symmetric_memory_enabled(): + if not self.enabled: return self + assert ( self.group_coordinator.pynccl_comm is not None ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'" - assert ( - self.group_coordinator.pynccl_comm.nccl_version >= 22703 - ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" + if self.is_graph_capture: assert ( _graph_pool_id is not None ), "graph_pool_id is not set under graph capture" # Pause graph memory pool to use symmetric memory with cuda graph - if self.pre_2_8_0: - torch._C._cuda_endAllocateCurrentStreamToPool( - self.device, _graph_pool_id - ) - else: - torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id) + torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id) + self._mem_pool_ctx.__enter__() - return self - def tag(self, tensor: torch.Tensor): - if not is_symmetric_memory_enabled(): - return - tensor.symmetric_memory = True + # Set the env var to pass this argument to the C functions. + os.environ["SGLANG_TMP_NCCL_COMM_VALUE"] = str( + self.group_coordinator.pynccl_comm.comm.value + ) + return self def __exit__(self, exc_type, exc_val, exc_tb): - if not is_symmetric_memory_enabled(): + if not self.enabled: return - global _registered_base_addrs + self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) - for segment in get_nccl_mem_pool().snapshot(): - if segment["address"] not in _registered_base_addrs: - if segment["stream"] == 0 and self.pre_2_8_0: - # PyTorch version < 2.8.0 has a multi-thread MemPool bug - # See https://github.com/pytorch/pytorch/issues/152861 - # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b - # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream - continue - self.group_coordinator.pynccl_comm.register_comm_window_raw( - segment["address"], segment["total_size"] - ) - _registered_base_addrs.add(segment["address"]) if self.is_graph_capture: - if self.pre_2_8_0: - torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id) - else: - torch._C._cuda_beginAllocateCurrentThreadToPool( - self.device, _graph_pool_id - ) + torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id) + + def tag(self, tensor: torch.Tensor): + if not self.enabled: + return + + tensor.symmetric_memory = True diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 314a3ab77a29..16c00d17b271 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -43,6 +43,7 @@ from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, + get_current_device_stream_fast, get_int_env_var, get_local_ip_auto, is_cpu, @@ -466,7 +467,7 @@ def graph_capture( # ensure all initialization operations complete before attempting to # capture the graph on another stream - curr_stream = self.device_module.current_stream() + curr_stream = get_current_device_stream_fast() if curr_stream != stream: stream.wait_stream(curr_stream) @@ -500,7 +501,7 @@ def graph_capture( maybe_pynccl_context = nullcontext() else: maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=torch.get_device_module().current_stream() + enable=True, stream=get_current_device_stream_fast() ) pymscclpp_comm = self.pymscclpp_comm @@ -551,13 +552,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.npu_communicator is not None and not self.npu_communicator.disabled: return self.npu_communicator.all_reduce(input_) - if ( - self.pynccl_comm is not None - and hasattr(input_, "symmetric_memory") - and input_.symmetric_memory - ): + if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False): with self.pynccl_comm.change_state( - enable=True, stream=torch.get_device_module().current_stream() + enable=True, stream=get_current_device_stream_fast() ): self.pynccl_comm.all_reduce(input_) return input_ @@ -658,7 +655,7 @@ def reduce_scatterv( pynccl_comm = self.pynccl_comm with pynccl_comm.change_state( - enable=True, stream=torch.get_device_module().current_stream() + enable=True, stream=get_current_device_stream_fast() ): assert ( pynccl_comm is not None and not pynccl_comm.disabled @@ -784,7 +781,7 @@ def all_gatherv( pynccl_comm = self.pynccl_comm with pynccl_comm.change_state( - enable=True, stream=torch.get_device_module().current_stream() + enable=True, stream=get_current_device_stream_fast() ): assert ( pynccl_comm is not None and not pynccl_comm.disabled diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 34b071486d49..3bc5c72ca214 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -677,10 +677,16 @@ def update_weights_from_ipc( def _set_envs_and_config(server_args: ServerArgs): # Set global environments os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - if "NCCL_CUMEM_ENABLE" not in os.environ: + if "NCCL_CUMEM_ENABLE" not in os.environ or server_args.enable_symm_mem: os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem)) - if not server_args.enable_symm_mem: - os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls)) + if ( + "NCCL_NVLS_ENABLE" not in os.environ + or server_args.enable_nccl_nvls + or server_args.enable_symm_mem + ): + os.environ["NCCL_NVLS_ENABLE"] = str( + int(server_args.enable_nccl_nvls or server_args.enable_symm_mem) + ) os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" os.environ["CUDA_MODULE_LOADING"] = "AUTO" diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 23b2635d2421..65003af581e9 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -13,7 +13,7 @@ divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, + get_tp_group, split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -1372,7 +1372,7 @@ def forward(self, input_, skip_all_reduce=False): # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group()) as sm: output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) sm.tag(output_parallel) diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 870749d4799e..368c42929bcc 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -1,5 +1,7 @@ """CUTLASS based Fused MoE kernels.""" +from typing import Optional + import torch from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams @@ -40,6 +42,7 @@ def cutlass_fused_experts_fp8( problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, use_fp8_blockscale: bool = True, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations. @@ -200,9 +203,11 @@ def cutlass_fused_experts_fp8( workspace, ) - result = torch.empty((m, k), device=device, dtype=out_dtype) - apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype)) - return result + if output is None: + output = torch.empty((m, k), device=device, dtype=out_dtype) + + apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype)) + return output FLOAT4_E2M1_MAX = 6.0 diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 5b527a386e07..b1f8d3af7feb 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -14,6 +14,9 @@ get_tp_group, tensor_model_parallel_all_reduce, ) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.layers.moe import ( MoeRunnerConfig, @@ -55,11 +58,6 @@ if is_flashinfer_available(): from flashinfer import RoutingMethodType, fp4_quantize -_is_hip = is_hip() -_is_cpu_amx_available = cpu_has_amx_support() -_is_cpu = is_cpu() - - # Try to import FP4 TRTLLM function if flashinfer is available trtllm_fp4_block_scale_moe = None if should_use_flashinfer_trtllm_moe(): @@ -68,6 +66,10 @@ except ImportError: trtllm_fp4_block_scale_moe = None +_is_hip = is_hip() +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() + logger = logging.getLogger(__name__) @@ -839,12 +841,16 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs dispatch_output=dispatch_output, **kwargs, ) - final_hidden_states = self.dispatcher.combine(combine_input=combine_input) - # TODO: should we add some conditions here? - final_hidden_states = final_hidden_states[ - ..., :origin_hidden_states_dim - ].contiguous() + with use_symmetric_memory(get_tp_group()) as sm: + final_hidden_states = self.dispatcher.combine(combine_input=combine_input) + + # TODO: should we add some conditions here? + final_hidden_states = final_hidden_states[ + ..., :origin_hidden_states_dim + ].contiguous() + + sm.tag(final_hidden_states) if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) @@ -980,6 +986,11 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): ), ) + # NOTE for symmetric memory tagging: + # We do not create the context in this function. + # Instead, we create the context and tagging inside each FusedMoEMethodBase + # This can allow fine-grained tagging. + if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) @@ -1040,6 +1051,10 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): router_logits = router_logits.to(torch.float32) + with use_symmetric_memory(get_tp_group()) as sm: + symm_output = torch.empty_like(hidden_states) + sm.tag(symm_output) + result = trtllm_fp4_block_scale_moe( routing_logits=router_logits, routing_bias=topk_config.correction_bias.to(hidden_states.dtype), @@ -1072,6 +1087,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): tile_tokens_dim=None, routing_method_type=RoutingMethodType.DeepSeekV3, do_finalize=True, + output=symm_output, )[0] return result diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 3fd5e9cf793c..624cf5ee3226 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -28,7 +28,10 @@ def dummy_func(*args, **kwargs): apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func -from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import get_tensor_model_parallel_world_size, get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo @@ -1025,6 +1028,10 @@ def apply( if self._should_use_cutlass_fused_experts(): from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 + with use_symmetric_memory(get_tp_group()) as sm: + symm_output = torch.empty_like(x) + sm.tag(symm_output) + topk_weights, topk_ids, _ = dispatch_output.topk_output output = cutlass_fused_experts_fp8( x, @@ -1048,6 +1055,7 @@ def apply( self.problem_sizes1, self.problem_sizes2, use_fp8_blockscale=True, + output=symm_output, ) return StandardCombineInput(hidden_states=output) @@ -1211,31 +1219,38 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - return trtllm_fp8_block_scale_moe( - routing_logits=router_logits.to(torch.float32), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=layer.w13_weight, - gemm1_weights_scale=layer.w13_weight_scale_inv, - gemm2_weights=layer.w2_weight, - gemm2_weights_scale=layer.w2_weight_scale_inv, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ), - tile_tokens_dim=get_tile_tokens_dim( - x.shape[0], topk_config.top_k, layer.num_experts - ), - routing_method_type=2, # DeepSeek-styled routing method - use_shuffled_weight=False, - ) + with use_symmetric_memory(get_tp_group()) as sm: + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. + # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 + # so we put the whole function under the ``use_symmetric_memory`` context manager. + # If the bug is fixed, we can only put the output tensor allocation under the context manager. + output = trtllm_fp8_block_scale_moe( + routing_logits=router_logits.to(torch.float32), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale_inv, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale_inv, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ), + tile_tokens_dim=get_tile_tokens_dim( + x.shape[0], topk_config.top_k, layer.num_experts + ), + routing_method_type=2, # DeepSeek-styled routing method + use_shuffled_weight=False, + ) + sm.tag(output) + return output def maybe_apply_hip_fused_experts( self, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 0017f740dba9..04423d7ee66b 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -8,6 +8,9 @@ from torch.nn.parameter import Parameter from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer from sglang.srt.layers.moe import ( MoeRunner, @@ -657,29 +660,37 @@ def apply( None if correction_bias is None else correction_bias.to(torch.bfloat16) ) - output = trtllm_fp8_per_tensor_scale_moe( - routing_logits=routing_logits_cast, - routing_bias=routing_bias_cast, - hidden_states=x_fp8, - gemm1_weights=layer.w13_weight, - output1_scales_scalar=layer.output1_scales_scalar, - output1_scales_gate_scalar=layer.output1_scales_gate_scalar, - gemm2_weights=layer.w2_weight, - output2_scales_scalar=layer.output2_scales_scalar, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=0, - topk_group=0, - intermediate_size=layer.w2_weight.shape[2], - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ), - use_routing_scales_on_input=use_routing_scales_on_input, - tile_tokens_dim=8, # TODO(brayden): use the FI tile calculation - routing_method_type=routing_method_type, - ) + with use_symmetric_memory(get_tp_group()) as sm: + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. + # It ignored the `output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 + # so we put the whole function under the ``use_symmetric_memory`` context manager. + # If the bug is fixed, we can only put the output tensor allocation under the context manager. + output = trtllm_fp8_per_tensor_scale_moe( + routing_logits=routing_logits_cast, + routing_bias=routing_bias_cast, + hidden_states=x_fp8, + gemm1_weights=layer.w13_weight, + output1_scales_scalar=layer.output1_scales_scalar, + output1_scales_gate_scalar=layer.output1_scales_gate_scalar, + gemm2_weights=layer.w2_weight, + output2_scales_scalar=layer.output2_scales_scalar, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=0, + topk_group=0, + intermediate_size=layer.w2_weight.shape[2], + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=( + routed_scaling_factor + if routed_scaling_factor is not None + else 1.0 + ), + use_routing_scales_on_input=use_routing_scales_on_input, + tile_tokens_dim=8, # TODO(brayden): use the FI tile calculation + routing_method_type=routing_method_type, + ) + sm.tag(output) from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput @@ -1585,6 +1596,12 @@ def apply( ) x_sf = nvfp4_block_scale_interleave(x_sf) + with use_symmetric_memory(get_tp_group()) as sm: + symm_output = torch.empty( + x.shape[0], x.shape[1], dtype=output_dtype, device=x.device + ) + sm.tag(symm_output) + output = flashinfer_cutlass_fused_moe( input=x, token_selected_experts=topk_ids.to(torch.int), @@ -1606,6 +1623,7 @@ def apply( tp_size=layer.moe_tp_size, tp_rank=layer.moe_tp_rank, tune_max_num_tokens=next_power_of_2(x.shape[0]), + output=symm_output, )[0] if should_use_flashinfer_cutlass_moe_fp4_allgather(): output, global_output = get_local_dp_buffer(), output diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 77f5f3f3567f..387fb53c5817 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -22,6 +22,10 @@ import torch from torch.nn.parameter import Parameter +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.moe.utils import get_moe_runner_backend @@ -70,14 +74,14 @@ if _is_hip: # import aiter try: - from aiter import ActivationType, QuantType, dtypes + from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe from aiter.ops.triton.quant import dynamic_mxfp4_quant from aiter.utility.fp4_utils import e8m0_shuffle except ImportError as err: - ActivationType = QuantType = dtypes = fused_moe = dynamic_mxfp4_quant = ( - e8m0_shuffle - ) = err + ActivationType = QuantType = fused_moe = dynamic_mxfp4_quant = e8m0_shuffle = ( + err + ) def _swizzle_mxfp4(quant_tensor, scale, num_warps): @@ -606,8 +610,6 @@ def apply( x = dispatch_output.hidden_states topk_output = dispatch_output.topk_output - moe_runner_config = self.moe_runner_config - if self.use_flashinfer: # When bf16 mode is enabled, we don't need to quantize the input, # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations, @@ -630,7 +632,7 @@ def apply( x_quant, x_scale = mxfp8_quantize(x, False, alignment=self.hidden_size) x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1) else: - raise NotImplementedError + raise NotImplementedError() assert x_quant.shape[-1] == self.hidden_size assert TopKOutputChecker.format_is_bypassed(topk_output) @@ -638,6 +640,10 @@ def apply( top_k = topk_output.topk_config.top_k router_logits = topk_output.router_logits + with use_symmetric_memory(get_tp_group()) as sm: + symm_output = torch.empty_like(x) + sm.tag(symm_output) + trtllm_gen_output = trtllm_fp4_block_scale_moe( router_logits.to(torch.bfloat16), None, # routing_bias @@ -666,6 +672,7 @@ def apply( None, # tile_tokens_dim 1, # routing_method_type, renormalize True, # do finalize + output=symm_output, )[0] return StandardCombineInput(hidden_states=trtllm_gen_output) diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 986babf2bc44..9a26922656b5 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -11,7 +11,7 @@ divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, + get_tp_group, tensor_model_parallel_all_reduce, ) from sglang.srt.distributed.device_communicators.pynccl_allocator import ( @@ -473,7 +473,7 @@ def forward(self, input_): else: masked_input = input_ # Get the embeddings. - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: + with use_symmetric_memory(get_tp_group()) as sm: output_parallel = self.quant_method.embedding(self, masked_input.long()) sm.tag(output_parallel) # Mask the output embedding. diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 64c72b9701a8..5943834f3b0d 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -39,12 +39,8 @@ get_moe_expert_parallel_world_size, get_pp_group, get_tensor_model_parallel_world_size, - parallel_state, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.environ import envs from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation @@ -760,12 +756,7 @@ def forward_normal_dual_stream( final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) - - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) + final_hidden_states += shared_output if ( self.tp_size > 1 and not should_allreduce_fusion @@ -824,11 +815,8 @@ def _forward_shared_experts_and_put_results(): # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor if shared_output is not None: - with use_symmetric_memory(parallel_state.get_tp_group()) as sm: - final_hidden_states_out = torch.empty_like(final_hidden_states) - torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) - final_hidden_states = final_hidden_states_out - sm.tag(final_hidden_states) + final_hidden_states += shared_output + if ( self.tp_size > 1 and not should_allreduce_fusion diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 4faed41d346b..4f8c816f1106 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3168,7 +3168,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-torch-symm-mem", action="store_true", - help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM10 supports world size 6, 8.", + help="Enable using torch symm mem for all-reduce kernel and fall back to NCCL. Only supports CUDA device SM90 and above. SM90 supports world size 4, 6, 8. SM100 supports world size 6, 8.", ) parser.add_argument( "--disable-overlap-schedule", diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 375b9b56d9f2..db8ca5d381b7 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -3602,3 +3602,13 @@ def calc_diff(x, y): denominator = (x * x + y * y).sum() sim = 2 * (x * y).sum() / denominator return 1 - sim + + +cached_device_index = -1 + + +def get_current_device_stream_fast(): + global cached_device_index + if cached_device_index == -1: + cached_device_index = torch.get_device_module().current_device() + return torch.get_device_module().current_stream(cached_device_index)