Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ncclUniqueId,
)
from sglang.srt.distributed.utils import StatelessProcessGroup
from sglang.srt.utils.common import get_current_device_stream_fast

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -137,7 +138,7 @@ def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
if stream is not None:
return stream
if self.use_current_stream:
return torch.cuda.current_stream()
return get_current_device_stream_fast()
return self.stream

def all_reduce(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,30 @@
import os
import tempfile

import torch
from packaging import version
from torch.cuda.memory import CUDAPluggableAllocator

from sglang.srt.distributed.parallel_state import GroupCoordinator
from sglang.srt.server_args import get_global_server_args

nccl_allocator_source = """
#include <nccl.h>

extern "C" {

void* nccl_alloc_plug(size_t size, int device, void* stream) {
void* ptr;
ncclResult_t err = ncclMemAlloc(&ptr, size);
return ptr;

const char *str_val = getenv("SGLANG_TMP_NCCL_COMM_VALUE");
char *endptr;
void* int_val = (void *)strtoull(str_val, &endptr, 0);

ncclComm_t comm = (ncclComm_t)(int_val);
ncclWindow_t win;
ncclResult_t err2 = ncclCommWindowRegister(comm, ptr, size, &win, NCCL_WIN_COLL_SYMMETRIC);

return ptr;
}

void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
Expand All @@ -27,8 +36,8 @@

_allocator = None
_mem_pool = None
_registered_base_addrs = set()
_graph_pool_id = None
_cur_device = None


def is_symmetric_memory_enabled():
Expand All @@ -41,7 +50,7 @@ def set_graph_pool_id(graph_pool_id):


def get_nccl_mem_pool():
global _allocator, _mem_pool
global _allocator, _mem_pool, _cur_device
if _mem_pool is None:
out_dir = tempfile.gettempdir()
nccl_allocator_libname = "nccl_allocator"
Expand All @@ -60,74 +69,67 @@ def get_nccl_mem_pool():
"nccl_free_plug",
).allocator()
_mem_pool = torch.cuda.MemPool(_allocator)
_cur_device = torch.cuda.current_device()
return _mem_pool


class use_symmetric_memory:
"""
Context manager for using symmetric memory with pynccl.

To Utilize the symmetric memory feature in NCCL, the buffers need to be allocated
by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce
this context manager. All tensors created under this context will be correctly
allocated and registered with a custom allocator.

In addition, developers need to manually tag the tensors that will be used as the input/output
of NCCL collectives with `tag(tensor)`.
"""

def __init__(self, group_coordinator: GroupCoordinator):
if not is_symmetric_memory_enabled():
self.group_coordinator = None
self._mem_pool_ctx = None
self.is_graph_capture = None
self.device = None
self.pre_2_8_0 = None
else:
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.device = torch.cuda.current_device()
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
self.enabled = is_symmetric_memory_enabled()

if not self.enabled:
return

self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self.is_graph_capture = torch.cuda.is_current_stream_capturing()

def __enter__(self):
if not is_symmetric_memory_enabled():
if not self.enabled:
return self

assert (
self.group_coordinator.pynccl_comm is not None
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
assert (
self.group_coordinator.pynccl_comm.nccl_version >= 22703
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"

if self.is_graph_capture:
assert (
_graph_pool_id is not None
), "graph_pool_id is not set under graph capture"
# Pause graph memory pool to use symmetric memory with cuda graph
if self.pre_2_8_0:
torch._C._cuda_endAllocateCurrentStreamToPool(
self.device, _graph_pool_id
)
else:
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
torch._C._cuda_endAllocateToPool(_cur_device, _graph_pool_id)

self._mem_pool_ctx.__enter__()
return self

def tag(self, tensor: torch.Tensor):
if not is_symmetric_memory_enabled():
return
tensor.symmetric_memory = True
# Set the env var to pass this argument to the C functions.
os.environ["SGLANG_TMP_NCCL_COMM_VALUE"] = str(
self.group_coordinator.pynccl_comm.comm.value
)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if not is_symmetric_memory_enabled():
if not self.enabled:
return
global _registered_base_addrs

self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
for segment in get_nccl_mem_pool().snapshot():
if segment["address"] not in _registered_base_addrs:
if segment["stream"] == 0 and self.pre_2_8_0:
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
# See https://github.com/pytorch/pytorch/issues/152861
# Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
# WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
continue
self.group_coordinator.pynccl_comm.register_comm_window_raw(
segment["address"], segment["total_size"]
)
_registered_base_addrs.add(segment["address"])

if self.is_graph_capture:
if self.pre_2_8_0:
torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
else:
torch._C._cuda_beginAllocateCurrentThreadToPool(
self.device, _graph_pool_id
)
torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)

def tag(self, tensor: torch.Tensor):
if not self.enabled:
return

tensor.symmetric_memory = True
17 changes: 7 additions & 10 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_current_device_stream_fast,
get_int_env_var,
get_local_ip_auto,
is_cpu,
Expand Down Expand Up @@ -466,7 +467,7 @@ def graph_capture(

# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = self.device_module.current_stream()
curr_stream = get_current_device_stream_fast()
if curr_stream != stream:
stream.wait_stream(curr_stream)

Expand Down Expand Up @@ -500,7 +501,7 @@ def graph_capture(
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
)

pymscclpp_comm = self.pymscclpp_comm
Expand Down Expand Up @@ -551,13 +552,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
if self.npu_communicator is not None and not self.npu_communicator.disabled:
return self.npu_communicator.all_reduce(input_)

if (
self.pynccl_comm is not None
and hasattr(input_, "symmetric_memory")
and input_.symmetric_memory
):
if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
with self.pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
):
self.pynccl_comm.all_reduce(input_)
return input_
Expand Down Expand Up @@ -658,7 +655,7 @@ def reduce_scatterv(
pynccl_comm = self.pynccl_comm

with pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
Expand Down Expand Up @@ -784,7 +781,7 @@ def all_gatherv(
pynccl_comm = self.pynccl_comm

with pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
enable=True, stream=get_current_device_stream_fast()
):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
Expand Down
12 changes: 9 additions & 3 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,10 +677,16 @@ def update_weights_from_ipc(
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
if "NCCL_CUMEM_ENABLE" not in os.environ:
if "NCCL_CUMEM_ENABLE" not in os.environ or server_args.enable_symm_mem:
os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
if not server_args.enable_symm_mem:
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
if (
"NCCL_NVLS_ENABLE" not in os.environ
or server_args.enable_nccl_nvls
or server_args.enable_symm_mem
):
os.environ["NCCL_NVLS_ENABLE"] = str(
int(server_args.enable_nccl_nvls or server_args.enable_symm_mem)
)
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
get_tp_group,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
Expand Down Expand Up @@ -1372,7 +1372,7 @@ def forward(self, input_, skip_all_reduce=False):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
with use_symmetric_memory(get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)

Expand Down
11 changes: 8 additions & 3 deletions python/sglang/srt/layers/moe/cutlass_moe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""CUTLASS based Fused MoE kernels."""

from typing import Optional

import torch

from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
Expand Down Expand Up @@ -40,6 +42,7 @@ def cutlass_fused_experts_fp8(
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
use_fp8_blockscale: bool = True,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.

Expand Down Expand Up @@ -200,9 +203,11 @@ def cutlass_fused_experts_fp8(
workspace,
)

result = torch.empty((m, k), device=device, dtype=out_dtype)
apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
return result
if output is None:
output = torch.empty((m, k), device=device, dtype=out_dtype)

apply_shuffle_mul_sum(c2, output, c_map, topk_weights.to(out_dtype))
return output


FLOAT4_E2M1_MAX = 6.0
Expand Down
36 changes: 26 additions & 10 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe import (
MoeRunnerConfig,
Expand Down Expand Up @@ -55,11 +58,6 @@
if is_flashinfer_available():
from flashinfer import RoutingMethodType, fp4_quantize

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()


# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
if should_use_flashinfer_trtllm_moe():
Expand All @@ -68,6 +66,10 @@
except ImportError:
trtllm_fp4_block_scale_moe = None

_is_hip = is_hip()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -839,12 +841,16 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs
dispatch_output=dispatch_output,
**kwargs,
)
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)

# TODO: should we add some conditions here?
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
].contiguous()
with use_symmetric_memory(get_tp_group()) as sm:
final_hidden_states = self.dispatcher.combine(combine_input=combine_input)

# TODO: should we add some conditions here?
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
].contiguous()

sm.tag(final_hidden_states)

if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
Expand Down Expand Up @@ -980,6 +986,11 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
),
)

# NOTE for symmetric memory tagging:
# We do not create the context in this function.
# Instead, we create the context and tagging inside each FusedMoEMethodBase
# This can allow fine-grained tagging.

if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

Expand Down Expand Up @@ -1040,6 +1051,10 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):

router_logits = router_logits.to(torch.float32)

with use_symmetric_memory(get_tp_group()) as sm:
symm_output = torch.empty_like(hidden_states)
sm.tag(symm_output)

result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
Expand Down Expand Up @@ -1072,6 +1087,7 @@ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
tile_tokens_dim=None,
routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True,
output=symm_output,
)[0]

return result
Loading
Loading