Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
"""

_allocator = None
_mem_pool = None
_mem_pool_map = {}
_graph_pool_id = None
_cur_device = None
_active_symmetric_memory_context = None
Expand Down Expand Up @@ -99,9 +99,9 @@ def restore_symmetric_memory_context(saved_context):
saved_context.__enter__()


def get_nccl_mem_pool():
global _allocator, _mem_pool, _cur_device
if _mem_pool is None:
def init_pynccl_allocator():
global _allocator, _cur_device
if _allocator is None:
import torch.utils.cpp_extension

out_dir = tempfile.gettempdir()
Expand All @@ -120,9 +120,16 @@ def get_nccl_mem_pool():
"nccl_alloc_plug",
"nccl_free_plug",
).allocator()
_mem_pool = torch.cuda.MemPool(_allocator)
_cur_device = torch.cuda.current_device()
return _mem_pool


def get_nccl_mem_pool(group_coordinator: GroupCoordinator):
    """Return the NCCL-backed CUDA memory pool for ``group_coordinator``.

    Pools are created lazily, one per group (keyed by the group's
    ``unique_name``), and cached in ``_mem_pool_map`` for reuse.
    """
    key = group_coordinator.unique_name
    pool = _mem_pool_map.get(key)
    if pool is None:
        # First request for this group: make sure the pluggable NCCL
        # allocator has been compiled/loaded, then build a pool on it.
        init_pynccl_allocator()
        pool = torch.cuda.MemPool(_allocator)
        _mem_pool_map[key] = pool
    return pool


class SymmetricMemoryContext:
Expand All @@ -140,7 +147,9 @@ def __init__(
group_coordinator: GroupCoordinator,
):
self.group_coordinator = group_coordinator
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self._mem_pool_ctx = torch.cuda.use_mem_pool(
get_nccl_mem_pool(self.group_coordinator)
)
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
self.exited = False

Expand All @@ -163,7 +172,9 @@ def __enter__(self):

if self.exited:
# mempool ctx (@contextlib.contextmanager) is not re-entrant
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
self._mem_pool_ctx = torch.cuda.use_mem_pool(
get_nccl_mem_pool(self.group_coordinator)
)
self.exited = False
self._mem_pool_ctx.__enter__()

Expand Down
72 changes: 71 additions & 1 deletion python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,38 @@ def reduce_scatter(
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
return output

def reduce_scatter_along_dim(
    self, input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
    """Reduce-scatter ``input_`` across the group along dimension ``dim``.

    Every rank contributes its full tensor; each rank receives the
    element-wise reduction of one ``1/world_size`` slice of ``dim``.

    Args:
        input_: Tensor present on all ranks. ``input_.size(dim)`` must be
            divisible by the group's world size.
        dim: Dimension to scatter along; may be negative.

    Returns:
        Tensor whose ``size(dim)`` is ``input_.size(dim) // world_size``
        (the input itself when ``world_size == 1``).
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert (
        -input_.dim() <= dim < input_.dim()
    ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate the communication buffers via the symmetric-memory context
    # (presumably so NCCL can register them — see SymmetricMemoryContext).
    with self.use_symmetric_memory(self):
        # Bring the scatter dimension to the front: reduce_scatter_tensor
        # always splits along dim 0. NOTE: this must be movedim(dim, 0),
        # not movedim(0, dim) — the latter moves axis 0 to position `dim`
        # and, for dim >= 2, would split along axis 1 instead of `dim`.
        input_tensor = input_.movedim(dim, 0).contiguous()
    assert input_tensor.shape[0] % world_size == 0

    chunk_size = input_tensor.shape[0] // world_size
    output_shape = (chunk_size,) + input_tensor.shape[1:]
    with self.use_symmetric_memory(self):
        output_tensor = torch.empty(
            output_shape, dtype=input_tensor.dtype, device=input_tensor.device
        )

    # Perform reduce-scatter operation
    self.reduce_scatter_tensor(output_tensor, input_tensor)

    # movedim(0, dim) is the inverse of the movedim(dim, 0) applied above,
    # restoring the original axis order before returning.
    return output_tensor.movedim(0, dim).contiguous()

def reduce_scatterv(
self,
input_: torch.Tensor,
Expand Down Expand Up @@ -1421,6 +1453,14 @@ def get_pp_group() -> GroupCoordinator:
return _PP


# Decode-context-parallel (DCP) group coordinator; populated by
# initialize_model_parallel and torn down by destroy_model_parallel.
_DCP: Optional[GroupCoordinator] = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode context parallel group coordinator.

    Must only be called after initialize_model_parallel has built the
    DCP group.
    """
    group = _DCP
    assert group is not None, "decode context parallel group is not initialized"
    return group


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group

Expand All @@ -1442,7 +1482,9 @@ def graph_capture(stream: Optional[torch.cuda.Stream] = None):
"""
with get_tp_group().graph_capture(
stream=stream
) as context, get_pp_group().graph_capture(context):
) as context, get_pp_group().graph_capture(context), get_dcp_group().graph_capture(
context
):
yield context


Expand Down Expand Up @@ -1538,6 +1580,7 @@ def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
decode_context_parallel_size: int = 1,
backend: Optional[str] = None,
duplicate_tp_group: bool = False,
) -> None:
Expand Down Expand Up @@ -1617,6 +1660,26 @@ def initialize_model_parallel(
_TP.pynccl_comm.disabled = False
_PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

# Build the decode context parallel groups.
num_decode_context_parallel_groups: int = world_size // decode_context_parallel_size
global _DCP
assert _DCP is None, "decode context parallel group is already initialized"
group_ranks = []
for i in range(num_decode_context_parallel_groups):
ranks = list(
range(
i * decode_context_parallel_size,
(i + 1) * decode_context_parallel_size,
)
)
group_ranks.append(ranks)
_DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
backend,
group_name="dcp",
)

moe_ep_size = expert_model_parallel_size
moe_tp_size = tensor_model_parallel_size // moe_ep_size

Expand Down Expand Up @@ -1731,6 +1794,7 @@ def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
expert_model_parallel_size: int,
pipeline_model_parallel_size: int,
decode_context_parallel_size: int,
backend: Optional[str] = None,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
Expand All @@ -1743,6 +1807,7 @@ def ensure_model_parallel_initialized(
tensor_model_parallel_size,
expert_model_parallel_size,
pipeline_model_parallel_size,
decode_context_parallel_size,
backend,
)
return
Expand Down Expand Up @@ -1850,6 +1915,11 @@ def destroy_model_parallel():
_TP.destroy()
_TP = None

global _DCP
if _DCP:
_DCP.destroy()
_DCP = None

global _PP
if _PP:
_PP.destroy()
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,11 @@ def _set_envs_and_config(server_args: ServerArgs):
os.environ["NCCL_NVLS_ENABLE"] = str(
int(server_args.enable_nccl_nvls or server_args.enable_symm_mem)
)
if "NCCL_GRAPH_MIXING_SUPPORT" not in os.environ and server_args.dcp_size > 1:
# NCCL_GRAPH_MIXING_SUPPORT=0 can avoid the unnecessary EVENT_WAIT and EVENT_RECORD in cuda graph.
# This is helpful for improving DCP performance because it reduces bubbles.
# https://discuss.pytorch.org/t/unexplained-gaps-in-execution-before-nccl-operations-when-using-cuda-graphs/197818/15
os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
os.environ["CUDA_MODULE_LOADING"] = "AUTO"

Expand Down
Loading
Loading