python/sglang/srt/distributed/parallel_state.py (39 additions, 1 deletion)
@@ -1519,6 +1519,14 @@ def get_pp_group() -> GroupCoordinator:
    return _PP


_DCP: Optional[GroupCoordinator] = None


def get_dcp_group() -> GroupCoordinator:
    assert _DCP is not None, "decode context parallel group is not initialized"
    return _DCP


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
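For orientation (illustrative, not part of the diff): the new accessor follows the same module-global pattern as get_tp_group and get_pp_group, so call sites fetch the DCP communicator the same way. A minimal sketch, assuming initialize_model_parallel has already run and that the returned GroupCoordinator exposes the same world_size attribute the TP/PP groups do:

from sglang.srt.distributed.parallel_state import get_dcp_group

dcp = get_dcp_group()      # raises AssertionError if DCP was never initialized
dcp_size = dcp.world_size  # assumed attribute, mirroring TP/PP group usage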

@@ -1552,7 +1560,9 @@ def graph_capture(stream: Optional[torch.cuda.Stream] = None):
    """
    with get_tp_group().graph_capture(
        stream=stream
-    ) as context, get_pp_group().graph_capture(context):
+    ) as context, get_pp_group().graph_capture(context), get_dcp_group().graph_capture(
+        context
+    ):
        yield context
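The capture context now nests the TP, PP, and DCP groups so that collectives from all three are recorded into the same CUDA graph. The nesting can equivalently be written with contextlib.ExitStack; this is only a readability sketch (the change itself keeps the explicit with-statement), reusing the get_*_group helpers from this module:

from contextlib import ExitStack, contextmanager

@contextmanager
def graph_capture_sketch(stream=None):
    # Same semantics as the nested with-statement above.
    with ExitStack() as stack:
        context = stack.enter_context(get_tp_group().graph_capture(stream=stream))
        stack.enter_context(get_pp_group().graph_capture(context))
        stack.enter_context(get_dcp_group().graph_capture(context))
        yield context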


@@ -1665,6 +1675,7 @@ def initialize_model_parallel(
    attention_data_parallel_size: int = 1,
    attention_context_model_parallel_size: int = 1,
    moe_data_model_parallel_size: int = 1,
    decode_context_parallel_size: int = 1,
    backend: Optional[str] = None,
    duplicate_tp_group: bool = False,
) -> None:
@@ -1836,6 +1847,26 @@
        group_name="attention_tp",
    )

    # Build the decode context parallel groups.
    num_decode_context_parallel_groups: int = world_size // decode_context_parallel_size
    global _DCP
    assert _DCP is None, "decode context parallel group is already initialized"
    group_ranks = []
    for i in range(num_decode_context_parallel_groups):
        ranks = list(
            range(
                i * decode_context_parallel_size,
                (i + 1) * decode_context_parallel_size,
            )
        )
        group_ranks.append(ranks)
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        group_name="dcp",
    )

    moe_ep_size = expert_model_parallel_size
    moe_dp_size = moe_data_model_parallel_size
    moe_tp_size = tensor_model_parallel_size // moe_ep_size // moe_dp_size
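To make the grouping concrete (illustrative values, not from the diff): with world_size = 8 and decode_context_parallel_size = 2, consecutive ranks are packed into four DCP groups. A standalone sketch of the loop above:

# Standalone sketch of the DCP rank grouping; the input values are hypothetical.
world_size = 8
decode_context_parallel_size = 2

num_groups = world_size // decode_context_parallel_size
group_ranks = [
    list(range(i * decode_context_parallel_size, (i + 1) * decode_context_parallel_size))
    for i in range(num_groups)
]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

Note that the grouping only covers every rank when world_size is divisible by decode_context_parallel_size.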
@@ -1986,6 +2017,7 @@ def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    expert_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    decode_context_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
@@ -1998,6 +2030,7 @@
            tensor_model_parallel_size,
            expert_model_parallel_size,
            pipeline_model_parallel_size,
            decode_context_parallel_size,
            backend,
        )
        return
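A usage sketch of the updated helper; the sizes are illustrative and assume the process group spans a world size divisible by each parallel size:

# Hypothetical call with the new argument; all other setup (torch.distributed
# init, etc.) is assumed to have happened already.
ensure_model_parallel_initialized(
    tensor_model_parallel_size=4,
    expert_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    decode_context_parallel_size=2,  # new argument introduced by this change
)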
@@ -2140,6 +2173,11 @@ def destroy_model_parallel():
        _TP.destroy()
    _TP = None

    global _DCP
    if _DCP:
        _DCP.destroy()
    _DCP = None

    global _PP
    if _PP:
        _PP.destroy()
python/sglang/srt/entrypoints/engine.py (5 additions, 0 deletions)
@@ -838,6 +838,11 @@ def _set_envs_and_config(server_args: ServerArgs):
    os.environ["NCCL_NVLS_ENABLE"] = str(
        int(server_args.enable_nccl_nvls or server_args.enable_symm_mem)
    )
    if "NCCL_GRAPH_MIXING_SUPPORT" not in os.environ and server_args.dcp_size > 1:
        # NCCL_GRAPH_MIXING_SUPPORT=0 avoids unnecessary EVENT_WAIT and EVENT_RECORD
        # operations in CUDA graphs, which improves DCP performance by reducing bubbles.
        # https://discuss.pytorch.org/t/unexplained-gaps-in-execution-before-nccl-operations-when-using-cuda-graphs/197818/15
        os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0"
    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
    os.environ["CUDA_MODULE_LOADING"] = "AUTO"
python/sglang/srt/layers/attention/nsa/dequant_k_cache.py (4 additions, 1 deletion)
@@ -169,6 +169,7 @@ def dequantize_k_cache_paged(
    quant_k_cache: torch.Tensor,
    page_table_1_flattened: torch.Tensor,
    group_size: int = 128,
    dcp_size: int = 1,
) -> torch.Tensor:
    """
    De-quantize the k-cache with paged layout
@@ -226,6 +227,7 @@
        GROUP_SIZE=group_size,
        DIM_NOPE=dim_nope,
        DIM_ROPE=dim_rope,
        DCP_SIZE=dcp_size,
    )

    return output
@@ -246,9 +248,10 @@ def _dequantize_k_cache_paged_kernel(
    GROUP_SIZE: tl.constexpr,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
    DCP_SIZE: tl.constexpr,
):
    token_id = tl.program_id(0)
-    token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32)
+    token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32) // DCP_SIZE
    raw_block_id = tl.program_id(1)

    if raw_block_id < NUM_NOPE_BLOCKS:
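The new // DCP_SIZE suggests the paged KV cache is sharded across DCP ranks, with a global page-table slot g living at local slot g // dcp_size on its owning rank. A plain-Python sketch of that mapping; the interleaved layout is an assumption inferred from the index math, not something the diff states:

# Hedged sketch: assumes KV slots are interleaved round-robin across DCP ranks.
def owner_rank(global_slot: int, dcp_size: int) -> int:
    return global_slot % dcp_size  # which DCP rank holds this slot (assumed)

def local_kv_slot(global_slot: int, dcp_size: int) -> int:
    return global_slot // dcp_size  # the division the kernel performs

assert owner_rank(7, dcp_size=2) == 1
assert local_kv_slot(7, dcp_size=2) == 3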
python/sglang/srt/layers/attention/nsa/transform_index.py (13 additions, 1 deletion)
@@ -20,6 +20,7 @@ def transform_index_page_table_decode_kernel(
    result_ptr: torch.Tensor,
    page_size: tl.constexpr,
    max_seqlen_k: tl.constexpr,
    dcp_size: tl.constexpr,
):
    TOPK: tl.constexpr = 2048
    req_id = tl.program_id(0)
@@ -30,7 +31,9 @@
    offset = tl.arange(0, TOPK)  # topk should be 2048
    loaded_topk_indices = tl.load(topk_indices_ptr + offset)
    mask = loaded_topk_indices >= 0
-    loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask)
+    loaded_kv_indices = (
+        tl.load(page_table_ptr + loaded_topk_indices, mask=mask) // dcp_size
+    )
    tl.store(result_ptr + offset, loaded_kv_indices, mask=mask)
    tl.store(result_ptr + offset, -1, mask=~mask)

@@ -40,6 +43,7 @@ def transform_index_page_table_decode_fast(
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
    dcp_size: int = 1,
) -> torch.Tensor:
    """
    Transform the page table according to topk indices for sparse topk attention.
@@ -65,6 +69,7 @@
        result,
        page_size,
        max_seqlen_k=max_seqlen_k,
        dcp_size=dcp_size,
    )
    return result

@@ -74,6 +79,7 @@ def transform_index_page_table_prefill_fast(
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
    dcp_size: int = 1,
) -> torch.Tensor:
    # TODO(baizhou): can be implemented with another triton kernel
    assert page_size == 1
@@ -85,6 +91,7 @@
            page_table[i].unsqueeze(0).expand(l, -1),
            topk_indices[offset : offset + l],
            result=result[offset : offset + l],
            dcp_size=dcp_size,
        )
        offset += l
    assert offset == topk_indices.shape[0]
@@ -96,6 +103,7 @@ def transform_index_page_table_decode_ref(
    topk_indices: torch.Tensor,
    result: Optional[torch.Tensor] = None,
    page_size: int = 1,
    dcp_size: int = 1,
) -> torch.Tensor:
    assert page_size == 1
    assert page_table.shape[0] == topk_indices.shape[0]
@@ -108,6 +116,8 @@
        index=topk_indices.clamp(min=0),
        out=result,
    )
    if dcp_size > 1:
        result //= dcp_size
    result[topk_indices < 0] = -1
    return result
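A toy run of the reference transform above, with made-up values, showing the gather, the DCP division, and the -1 masking in order:

import torch

page_table = torch.tensor([[10, 11, 12, 13]])  # one request, global KV slots
topk_indices = torch.tensor([[2, 0, -1, 3]])   # -1 marks padded top-k entries
dcp_size = 2

result = torch.gather(page_table, 1, topk_indices.clamp(min=0))
result //= dcp_size            # global slot -> local slot on the DCP shard
result[topk_indices < 0] = -1  # restore the padding marker
print(result)                  # tensor([[ 6,  5, -1,  6]])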

@@ -117,6 +127,7 @@ def transform_index_page_table_prefill_ref(
    topk_indices: torch.Tensor,
    extend_lens_cpu: List[int],
    page_size: int = 1,
    dcp_size: int = 1,
) -> torch.Tensor:
    assert page_size == 1
    result = torch.empty_like(topk_indices, dtype=torch.int32)
@@ -127,6 +138,7 @@
            page_table[i].unsqueeze(0).expand(l, -1),
            topk_indices[offset : offset + l],
            result=result[offset : offset + l],
            dcp_size=dcp_size,
        )
        offset += l
    assert offset == topk_indices.shape[0]
python/sglang/srt/layers/attention/nsa/utils.py (4 additions, 0 deletions)
@@ -35,6 +35,10 @@ def is_nsa_enable_prefill_cp():
    return get_global_server_args().enable_nsa_prefill_context_parallel


def is_nsa_enable_decode_cp():
    return get_global_server_args().dcp_size > 1


def is_nsa_prefill_cp_in_seq_split():
    return (
        is_nsa_enable_prefill_cp()
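A sketch of how such a predicate is typically consumed at a call site; the caller below is hypothetical and not part of this diff:

from sglang.srt.distributed.parallel_state import get_dcp_group
from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_decode_cp

# Hypothetical call-site pattern: only touch the DCP group when decode
# context parallelism is enabled (assumes a running server context).
if is_nsa_enable_decode_cp():
    dcp_group = get_dcp_group()
    dcp_rank = dcp_group.rank_in_group  # assumed attribute, as on TP/PP groups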