diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 5b8a82232b5b..c623a56d7b5b 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1519,6 +1519,14 @@ def get_pp_group() -> GroupCoordinator: return _PP +_DCP: Optional[GroupCoordinator] = None + + +def get_dcp_group() -> GroupCoordinator: + assert _DCP is not None, "decode context parallel group is not initialized" + return _DCP + + # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group @@ -1552,7 +1560,9 @@ def graph_capture(stream: Optional[torch.cuda.Stream] = None): """ with get_tp_group().graph_capture( stream=stream - ) as context, get_pp_group().graph_capture(context): + ) as context, get_pp_group().graph_capture(context), get_dcp_group().graph_capture( + context + ): yield context @@ -1665,6 +1675,7 @@ def initialize_model_parallel( attention_data_parallel_size: int = 1, attention_context_model_parallel_size: int = 1, moe_data_model_parallel_size: int = 1, + decode_context_parallel_size: int = 1, backend: Optional[str] = None, duplicate_tp_group: bool = False, ) -> None: @@ -1836,6 +1847,26 @@ def initialize_model_parallel( group_name="attention_tp", ) + # Build the decode context parallel groups. + num_decode_context_parallel_groups: int = world_size // decode_context_parallel_size + global _DCP + assert _DCP is None, "decode context parallel group is already initialized" + group_ranks = [] + for i in range(num_decode_context_parallel_groups): + ranks = list( + range( + i * decode_context_parallel_size, + (i + 1) * decode_context_parallel_size, + ) + ) + group_ranks.append(ranks) + _DCP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="dcp", + ) + moe_ep_size = expert_model_parallel_size moe_dp_size = moe_data_model_parallel_size moe_tp_size = tensor_model_parallel_size // moe_ep_size // moe_dp_size @@ -1986,6 +2017,7 @@ def ensure_model_parallel_initialized( tensor_model_parallel_size: int, expert_model_parallel_size: int, pipeline_model_parallel_size: int, + decode_context_parallel_size: int, backend: Optional[str] = None, ) -> None: """Helper to initialize model parallel groups if they are not initialized, @@ -1998,6 +2030,7 @@ def ensure_model_parallel_initialized( tensor_model_parallel_size, expert_model_parallel_size, pipeline_model_parallel_size, + decode_context_parallel_size, backend, ) return @@ -2140,6 +2173,11 @@ def destroy_model_parallel(): _TP.destroy() _TP = None + global _DCP + if _DCP: + _DCP.destroy() + _DCP = None + global _PP if _PP: _PP.destroy() diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 9c460f559410..c7caa83231aa 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -838,6 +838,11 @@ def _set_envs_and_config(server_args: ServerArgs): os.environ["NCCL_NVLS_ENABLE"] = str( int(server_args.enable_nccl_nvls or server_args.enable_symm_mem) ) + if "NCCL_GRAPH_MIXING_SUPPORT" not in os.environ and server_args.dcp_size > 1: + # NCCL_GRAPH_MIXING_SUPPORT=0 can avoid the unnecessary EVENT_WAIT and EVENT_RECORD in cuda graph. + # This is helpful for improving DCP performance because it reduces bubbles. 
+ # https://discuss.pytorch.org/t/unexplained-gaps-in-execution-before-nccl-operations-when-using-cuda-graphs/197818/15 + os.environ["NCCL_GRAPH_MIXING_SUPPORT"] = "0" os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" os.environ["CUDA_MODULE_LOADING"] = "AUTO" diff --git a/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py b/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py index 5c69d135bb24..0a1e4d9962ad 100644 --- a/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py +++ b/python/sglang/srt/layers/attention/nsa/dequant_k_cache.py @@ -169,6 +169,7 @@ def dequantize_k_cache_paged( quant_k_cache: torch.Tensor, page_table_1_flattened: torch.Tensor, group_size: int = 128, + dcp_size: int = 1, ) -> torch.Tensor: """ De-quantize the k-cache with paged layout @@ -226,6 +227,7 @@ def dequantize_k_cache_paged( GROUP_SIZE=group_size, DIM_NOPE=dim_nope, DIM_ROPE=dim_rope, + DCP_SIZE=dcp_size, ) return output @@ -246,9 +248,10 @@ def _dequantize_k_cache_paged_kernel( GROUP_SIZE: tl.constexpr, DIM_NOPE: tl.constexpr, DIM_ROPE: tl.constexpr, + DCP_SIZE: tl.constexpr, ): token_id = tl.program_id(0) - token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32) + token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32) // DCP_SIZE raw_block_id = tl.program_id(1) if raw_block_id < NUM_NOPE_BLOCKS: diff --git a/python/sglang/srt/layers/attention/nsa/transform_index.py b/python/sglang/srt/layers/attention/nsa/transform_index.py index 10b1068f5241..77f8856ae1cd 100644 --- a/python/sglang/srt/layers/attention/nsa/transform_index.py +++ b/python/sglang/srt/layers/attention/nsa/transform_index.py @@ -20,6 +20,7 @@ def transform_index_page_table_decode_kernel( result_ptr: torch.Tensor, page_size: tl.constexpr, max_seqlen_k: tl.constexpr, + dcp_size: tl.constexpr, ): TOPK: tl.constexpr = 2048 req_id = tl.program_id(0) @@ -30,7 +31,9 @@ def transform_index_page_table_decode_kernel( offset = tl.arange(0, TOPK) # topk should be 2048 loaded_topk_indices = tl.load(topk_indices_ptr + offset) mask = loaded_topk_indices >= 0 - loaded_kv_indices = tl.load(page_table_ptr + loaded_topk_indices, mask=mask) + loaded_kv_indices = ( + tl.load(page_table_ptr + loaded_topk_indices, mask=mask) // dcp_size + ) tl.store(result_ptr + offset, loaded_kv_indices, mask=mask) tl.store(result_ptr + offset, -1, mask=~mask) @@ -40,6 +43,7 @@ def transform_index_page_table_decode_fast( topk_indices: torch.Tensor, result: Optional[torch.Tensor] = None, page_size: int = 1, + dcp_size: int = 1, ) -> torch.Tensor: """ Transform the page table according to topk indices for sparse topk attention. 
@@ -65,6 +69,7 @@ def transform_index_page_table_decode_fast( result, page_size, max_seqlen_k=max_seqlen_k, + dcp_size=dcp_size, ) return result @@ -74,6 +79,7 @@ def transform_index_page_table_prefill_fast( topk_indices: torch.Tensor, extend_lens_cpu: List[int], page_size: int = 1, + dcp_size: int = 1, ) -> torch.Tensor: # TODO(baizhou): can be implemented with another triton kernel assert page_size == 1 @@ -85,6 +91,7 @@ def transform_index_page_table_prefill_fast( page_table[i].unsqueeze(0).expand(l, -1), topk_indices[offset : offset + l], result=result[offset : offset + l], + dcp_size=dcp_size, ) offset += l assert offset == topk_indices.shape[0] @@ -96,6 +103,7 @@ def transform_index_page_table_decode_ref( topk_indices: torch.Tensor, result: Optional[torch.Tensor] = None, page_size: int = 1, + dcp_size: int = 1, ) -> torch.Tensor: assert page_size == 1 assert page_table.shape[0] == topk_indices.shape[0] @@ -108,6 +116,8 @@ def transform_index_page_table_decode_ref( index=topk_indices.clamp(min=0), out=result, ) + if dcp_size > 1: + result //= dcp_size result[topk_indices < 0] = -1 return result @@ -117,6 +127,7 @@ def transform_index_page_table_prefill_ref( topk_indices: torch.Tensor, extend_lens_cpu: List[int], page_size: int = 1, + dcp_size: int = 1, ) -> torch.Tensor: assert page_size == 1 result = torch.empty_like(topk_indices, dtype=torch.int32) @@ -127,6 +138,7 @@ def transform_index_page_table_prefill_ref( page_table[i].unsqueeze(0).expand(l, -1), topk_indices[offset : offset + l], result=result[offset : offset + l], + dcp_size=dcp_size, ) offset += l assert offset == topk_indices.shape[0] diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py index 00ef96f9b8b3..79733362383f 100644 --- a/python/sglang/srt/layers/attention/nsa/utils.py +++ b/python/sglang/srt/layers/attention/nsa/utils.py @@ -35,6 +35,10 @@ def is_nsa_enable_prefill_cp(): return get_global_server_args().enable_nsa_prefill_context_parallel +def is_nsa_enable_decode_cp(): + return get_global_server_args().dcp_size > 1 + + def is_nsa_prefill_cp_in_seq_split(): return ( is_nsa_enable_prefill_cp() diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index 170f75e3edc6..6cd5c8320e96 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -7,6 +7,7 @@ import torch from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa +from sglang.srt.distributed.parallel_state import get_dcp_group from sglang.srt.environ import envs from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged @@ -28,6 +29,7 @@ from sglang.srt.layers.attention.nsa.utils import ( can_nsa_prefill_cp_round_robin_split, compute_nsa_seqlens, + is_nsa_enable_decode_cp, is_nsa_enable_prefill_cp, nsa_cp_round_robin_split_data, nsa_cp_round_robin_split_q_seqs, @@ -37,6 +39,7 @@ concat_mla_absorb_q_general, mla_quantize_and_rope_for_fp8, ) +from sglang.srt.layers.attention.utils import cp_lse_ag_out_rs from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_cuda, is_hip @@ -358,6 +361,9 @@ def __init__( else: self.workspace_buffer = None + self.dcp_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + def 
get_device_int32_arange(self, l: int) -> torch.Tensor: if l > len(self._arange_buf): next_pow_of_2 = 1 << (l - 1).bit_length() @@ -1291,6 +1297,33 @@ def init_forward_metadata_replay_cuda_graph_from_precomputed( self.forward_metadata = metadata + def _save_kv_cache( + self, + layer: RadixAttention, + forward_batch: ForwardBatch, + k: torch.Tensor, + k_rope: Optional[torch.Tensor], + ) -> None: + """Save KV cache to the token pool.""" + cache_loc = ( + forward_batch.out_cache_loc + if not layer.is_cross_attention + else forward_batch.encoder_out_cache_loc + ) + args = ( + layer, + cache_loc, + k, + k_rope, + ) + kwargs = {} + if self.dcp_size > 1: + kwargs["dcp_kv_mask"] = forward_batch.dcp_kv_mask + kwargs["dcp_size"] = self.dcp_size + forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore + *args, **kwargs + ) + def forward_extend( self, q: torch.Tensor, @@ -1341,17 +1374,7 @@ def forward_extend( if k is not None: assert v is not None if save_kv_cache: - cache_loc = ( - forward_batch.out_cache_loc - if not layer.is_cross_attention - else forward_batch.encoder_out_cache_loc - ) - forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore - layer, - cache_loc, - k, - k_rope, - ) + self._save_kv_cache(layer, forward_batch, k, k_rope) # Use MHA kernel if in MHA_ONE_SHOT mode if self.use_mha: @@ -1412,6 +1435,7 @@ def forward_extend( topk_indices=topk_indices, extend_lens_cpu=metadata.nsa_extend_seq_lens_list, page_size=1, + dcp_size=self.dcp_size, ) if nsa_impl == "tilelang": @@ -1435,23 +1459,27 @@ def forward_extend( ) assert page_table_1_flattened is not None kv_cache = dequantize_k_cache_paged( - kv_cache, page_table_1_flattened + kv_cache, page_table_1_flattened, dcp_size=self.dcp_size ) else: kv_cache = _cat([k, k_rope], dim=-1) page_table_1 = topk_indices - return self._forward_flashmla_sparse( + o, s = self._forward_flashmla_sparse( q_all=q_all, kv_cache=kv_cache, page_table_1=page_table_1, sm_scale=layer.scaling, v_head_dim=layer.v_head_dim, + return_softmax_lse=True, ) + if self.dcp_size > 1: + return cp_lse_ag_out_rs(o, s, get_dcp_group()) + return o elif nsa_impl == "flashmla_kv": if q_rope is not None: q_all = concat_mla_absorb_q_general(q_nope, q_rope) - return self._forward_flashmla_kv( + o, s = self._forward_flashmla_kv( q_all=q_all, kv_cache=kv_cache, sm_scale=layer.scaling, @@ -1460,9 +1488,13 @@ def forward_extend( layer=layer, metadata=metadata, page_table_1=page_table_1, + return_softmax_lse=True, ) + if self.dcp_size > 1: + return cp_lse_ag_out_rs(o, s, get_dcp_group()) + return o elif nsa_impl == "fa3": - return self._forward_fa3( + o, s = self._forward_fa3( q_rope=q_rope, kv_cache=kv_cache, v_head_dim=layer.v_head_dim, @@ -1475,7 +1507,11 @@ def forward_extend( sm_scale=layer.scaling, logit_cap=layer.logit_cap, page_size=1, + return_softmax_lse=True, ) + if self.dcp_size > 1: + return cp_lse_ag_out_rs(o, s, get_dcp_group()) + return o else: raise ValueError(f"Unsupported {nsa_impl = }") @@ -1520,17 +1556,7 @@ def forward_decode( if k is not None: assert v is not None if save_kv_cache: - cache_loc = ( - forward_batch.out_cache_loc - if not layer.is_cross_attention - else forward_batch.encoder_out_cache_loc - ) - forward_batch.token_to_kv_pool.set_mla_kv_buffer( # type: ignore - layer, - cache_loc, - k, - k_rope, - ) + self._save_kv_cache(layer, forward_batch, k, k_rope) # Do absorbed multi-latent attention kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) @@ -1555,22 +1581,27 @@ def forward_decode( page_table=metadata.page_table_1, 
             topk_indices=topk_indices,
             page_size=1,
+            dcp_size=self.dcp_size,
         )

         if self.nsa_decode_impl == "flashmla_sparse":
             if q_rope is not None:
                 q_all = concat_mla_absorb_q_general(q_nope, q_rope)
-            return self._forward_flashmla_sparse(
+            o, s = self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
+                return_softmax_lse=True,
             )
+            if self.dcp_size > 1:
+                return cp_lse_ag_out_rs(o, s, get_dcp_group())
+            return o
         elif self.nsa_decode_impl == "flashmla_kv":
             if q_rope is not None:
                 q_all = concat_mla_absorb_q_general(q_nope, q_rope)
-            return self._forward_flashmla_kv(
+            o, s = self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -1579,7 +1610,11 @@
                 layer=layer,
                 metadata=metadata,
                 page_table_1=page_table_1,
+                return_softmax_lse=True,
             )
+            if self.dcp_size > 1:
+                return cp_lse_ag_out_rs(o, s, get_dcp_group())
+            return o
         elif self.nsa_decode_impl == "tilelang":
             if q_rope is not None:
                 q_all = concat_mla_absorb_q_general(q_nope, q_rope)
@@ -1591,7 +1626,7 @@
                 v_head_dim=layer.v_head_dim,
             )
         elif self.nsa_decode_impl == "fa3":
-            return self._forward_fa3(
+            o, s = self._forward_fa3(
                 q_rope=q_rope,
                 kv_cache=kv_cache,
                 v_head_dim=layer.v_head_dim,
@@ -1604,7 +1639,11 @@
                 sm_scale=layer.scaling,
                 logit_cap=layer.logit_cap,
                 page_size=1,
+                return_softmax_lse=True,
             )
+            if self.dcp_size > 1:
+                return cp_lse_ag_out_rs(o, s, get_dcp_group())
+            return o
         elif self.nsa_decode_impl == "aiter":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
@@ -1634,13 +1673,14 @@ def _forward_fa3(
         sm_scale: float,
         logit_cap: float,
         page_size: int,
+        return_softmax_lse: bool = False,
     ) -> torch.Tensor:
         k_rope_cache = kv_cache[:, :, v_head_dim:]
         c_kv_cache = kv_cache[:, :, :v_head_dim]
         qk_rope_dim = k_rope_cache.shape[-1]
         k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
         c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
-        o = flash_attn_with_kvcache(
+        out = flash_attn_with_kvcache(
             q=q_rope,
             k_cache=k_rope_cache,
             v_cache=c_kv_cache,
@@ -1653,10 +1693,13 @@ def _forward_fa3(
             softmax_scale=sm_scale,
             causal=True,
             softcap=logit_cap,
-            return_softmax_lse=False,
+            return_softmax_lse=return_softmax_lse,
             num_splits=self.num_splits,
         )
-        return o  # type: ignore
+        if return_softmax_lse:
+            out, lse, *rest = out
+            return out, lse
+        return out  # type: ignore

     def _forward_flashmla_sparse(
         self,
@@ -1665,7 +1708,11 @@
         v_head_dim: int,
         page_table_1: torch.Tensor,
         sm_scale: float,
+        return_softmax_lse: bool = False,
     ) -> torch.Tensor:
+        if self.dcp_size > 1:
+            q_all = get_dcp_group().all_gather(q_all.contiguous(), dim=1)
+
         from sgl_kernel.flash_mla import flash_mla_sparse_fwd

         # FlashMLA sparse kernel requires num_heads to be a multiple of 64 (Hopper) or 128 (Blackwell)
@@ -1693,7 +1740,7 @@

         # indices shape must be (s_q, h_kv=1, topk), keep h_kv=1 unchanged
         indices_input = page_table_1.unsqueeze(1)
-        o, _, _ = flash_mla_sparse_fwd(
+        o, _, lse = flash_mla_sparse_fwd(
             q=q_input,
             kv=kv_cache,
             indices=indices_input,
@@ -1704,8 +1751,10 @@
         # Trim output back to original num_heads if we padded
         if need_padding:
             o = o[:, :num_heads, :]
+        if self.dcp_size > 1:
+            o = o.contiguous()

-        return o
+        return (o, lse) if return_softmax_lse else o

     def _forward_flashmla_kv(
         self,
@@ -1716,13 +1765,17 @@
         layer,
         metadata: NSAMetadata,
         page_table_1,
+        return_softmax_lse: bool = False,
     ) -> torch.Tensor:
+        if self.dcp_size > 1:
+            q_all = get_dcp_group().all_gather(q_all.contiguous(), dim=1)
+
         from sgl_kernel.flash_mla import flash_mla_with_kvcache

         cache_seqlens = metadata.nsa_cache_seqlens_int32
         # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
-        q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
+        q_all = q_all.view(-1, 1, layer.tp_q_head_num * self.dcp_size, layer.head_dim)
         kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)

         assert self.real_page_size == 64, "only page size 64 is supported"
@@ -1735,7 +1788,7 @@ def _forward_flashmla_kv(
             indices.shape[-1] == self.nsa_index_topk
         )  # requirement of FlashMLA decode kernel

-        o, _ = flash_mla_with_kvcache(
+        o, lse = flash_mla_with_kvcache(
             q=q_all,
             k_cache=kv_cache,
             cache_seqlens=cache_seqlens,
@@ -1750,7 +1803,7 @@ def _forward_flashmla_kv(
             ),
             is_fp8_kvcache=True,
         )
-        return o
+        return (o, lse) if return_softmax_lse else o

     def _forward_standard_mha(
         self,
@@ -2050,6 +2103,7 @@ def set_nsa_prefill_impl(self, forward_batch: Optional[ForwardBatch] = None):
                 and sum_seq_lens <= forward_batch.get_max_chunk_capacity()  # Fits in chunk
                 and (not is_nsa_enable_prefill_cp())  # CP not enabled
+                and (not is_nsa_enable_decode_cp())  # DCP not enabled
             )
         else:
             self.use_mha = False  # Decode/verify always use MLA
diff --git a/python/sglang/srt/layers/attention/utils.py b/python/sglang/srt/layers/attention/utils.py
index 44d5edaafa30..6d36958ea3f3 100644
--- a/python/sglang/srt/layers/attention/utils.py
+++ b/python/sglang/srt/layers/attention/utils.py
@@ -4,6 +4,8 @@

 from sglang.srt.utils import is_cuda

+from sglang.srt.distributed.parallel_state import GroupCoordinator
+
 _FLASHMLA_CREATE_KV_BLOCK_SIZE = 4096
 FLASHMLA_CREATE_KV_BLOCK_SIZE_TRITON = tl.constexpr(_FLASHMLA_CREATE_KV_BLOCK_SIZE)

@@ -411,3 +413,212 @@ def concat_mla_absorb_q_general(q_nope, q_rope):
         return concat_mla_absorb_q(q_nope, q_rope)
     else:
         return torch.cat([q_nope, q_rope], dim=-1)
+
+
+# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.12.0/vllm/attention/ops/common.py
+@triton.jit
+def _correct_attn_cp_out_kernel(
+    outputs_ptr,
+    new_output_ptr,
+    lses_ptr,
+    vlse_ptr,
+    outputs_stride_B,
+    outputs_stride_H,
+    outputs_stride_D,
+    lses_stride_N,
+    lses_stride_B,
+    lses_stride_H,
+    lse_idx,
+    HEAD_DIM: tl.constexpr,
+    N_ROUNDED: tl.constexpr,
+    IS_BASE_E: tl.constexpr,
+):
+    """
+    Apply the all-gathered lses to correct each local rank's attention
+    output. We still need to perform a cross-rank reduction to obtain the
+    final attention output.
+
+    Args:
+        outputs_ptr (triton.PointerType):
+            Pointer to input tensor of shape [ B, H, D ]
+        lses_ptr (triton.PointerType):
+            Pointer to input tensor of shape [ N, B, H ]
+        new_output_ptr (triton.PointerType):
+            Pointer to output tensor of shape [ B, H, D ]
+        vlse_ptr (triton.PointerType):
+            Pointer to output tensor of shape [ B, H ]
+    """
+    batch_idx = tl.program_id(axis=0).to(tl.int64)
+    head_idx = tl.program_id(axis=1).to(tl.int64)
+    d_offsets = tl.arange(0, HEAD_DIM)
+    num_n_offsets = tl.arange(0, N_ROUNDED)
+
+    # shape = [N]
+    lse_offsets = (
+        num_n_offsets * lses_stride_N
+        + batch_idx * lses_stride_B
+        + head_idx * lses_stride_H
+    )
+
+    # calc final lse
+    lse = tl.load(lses_ptr + lse_offsets)
+    lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
+    lse_max = tl.max(lse, axis=0)
+    lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
+    lse -= lse_max
+    if IS_BASE_E:
+        lse_exp = tl.exp(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log(lse_acc)
+    else:
+        lse_exp = tl.exp2(lse)
+        lse_acc = tl.sum(lse_exp, axis=0)
+        lse = tl.log2(lse_acc)
+    lse += lse_max
+
+    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
+    tl.store(vlse_ptr + lse_offsets, lse)
+
+    # shape = [D]
+    output_offsets = (
+        batch_idx * outputs_stride_B
+        + head_idx * outputs_stride_H
+        + d_offsets * outputs_stride_D
+    )
+
+    # correct output
+    lse_offset = (
+        lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
+    )
+    lse_tmp = tl.load(lses_ptr + lse_offset)
+    lse_finally = lse_tmp - lse
+    lse_finally = tl.where(
+        (lse_finally != lse_finally) | (lse_finally == float("inf")),
+        -float("inf"),
+        lse_finally,
+    )
+    factor = tl.exp(lse_finally) if IS_BASE_E else tl.exp2(lse_finally)
+    output = tl.load(outputs_ptr + output_offsets)
+    output = output * factor
+
+    tl.store(new_output_ptr + output_offsets, output)
+
+
+class CPTritonContext:
+    """The CPTritonContext is used to avoid recompilation of the Triton JIT."""
+
+    def __init__(self):
+        self.inner_kernel = None
+
+    def call_kernel(self, kernel, grid, *regular_args, **const_args):
+        if self.inner_kernel is None:
+            self.inner_kernel = kernel[grid](*regular_args, **const_args)
+        else:
+            self.inner_kernel[grid](*regular_args)
+
+
+def correct_attn_out(
+    out: torch.Tensor,
+    lses: torch.Tensor,
+    cp_rank: int,
+    ctx: CPTritonContext,
+    is_lse_base_on_e: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Correct the attention output using the all-gathered lses.
+
+    Args:
+        out: Tensor of shape [ B, H, D ]
+        lses: Tensor of shape [ N, B, H ]
+        cp_rank: Current rank in the context-parallel group
+        ctx: Triton context to avoid recompilation
+
+    Returns:
+        Tuple of (out, lse) with corrected attention and final log-sum-exp.
+    """
+    if ctx is None:
+        ctx = CPTritonContext()
+
+    # --- Normalize to 3D views ---
+    if out.ndim == 4 and out.shape[1] == 1:
+        out = out.squeeze(1)
+    assert out.ndim == 3, f"expected out [B,H,D] or [B,1,H,D], got {tuple(out.shape)}"
+
+    if lses.ndim == 4 and lses.shape[-1] == 1:
+        lses = lses.squeeze(-1)
+    if lses.ndim == 4 and lses.shape[1] == 1:
+        lses = lses.squeeze(1)
+    assert lses.ndim == 3, (
+        f"expected lses [N,B,H] (optionally with a 1-sized extra dim), "
+        f"got {tuple(lses.shape)}"
+    )
+
+    B, H, D = out.shape
+    N = lses.shape[0]
+
+    # Strides after we normalized shapes to 3-D views. The kernel computes
+    # offsets for `vlse_ptr` using lses_stride_B/H, so the output buffer must
+    # have the same B/H stride layout as a slice of `lses`.
+    o_sB, o_sH, o_sD = out.stride()
+    l_sN, l_sB, l_sH = lses.stride()
+
+    # Allocate LSE with the same B/H strides as `lses` so writes land correctly
+    # even when `lses` is a non-contiguous view (e.g., 4-D to 3-D squeeze).
+    lse = torch.empty_strided(
+        (B, H), (l_sB, l_sH), device=lses.device, dtype=lses.dtype
+    )
+
+    # Kernel launch config
+    grid = (B, H, 1)
+
+    regular_args = (
+        out,
+        out,
+        lses,
+        lse,
+        o_sB,
+        o_sH,
+        o_sD,
+        l_sN,
+        l_sB,
+        l_sH,
+        cp_rank,
+    )
+    const_args = {"HEAD_DIM": D, "N_ROUNDED": N, "IS_BASE_E": is_lse_base_on_e}
+
+    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
+    return out, lse
+
+
+def cp_lse_ag_out_rs(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext = None,
+    return_lse: bool = False,
+    is_lse_base_on_e: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    if ctx is None:
+        ctx = CPTritonContext()
+
+    assert cp_attn_lse.is_contiguous()
+    lses = cp_group.all_gather(cp_attn_lse, dim=0).view(
+        (cp_group.world_size,) + cp_attn_lse.shape
+    )
+    out, lse = correct_attn_out(
+        cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e
+    )
+    assert out.is_contiguous()
+    # All-reduce seems faster than `reduce_scatter_along_dim` as it avoids
+    # the extra `contiguous()` call.
+    out = cp_group.all_reduce(out)
+    cp_num_heads = lse.shape[1] // cp_group.world_size
+    cp_rank = cp_group.rank_in_group
+    out = out[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1), :]
+    if return_lse:
+        lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
+        return out, lse
+    return out
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index b885729c4746..9bb542747e20 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -287,6 +287,7 @@ def __init__(
         self.moe_dp_size = server_args.moe_dp_size
         self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
+        self.dcp_size = server_args.dcp_size
         self.moe_ep_size = server_args.ep_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
@@ -615,7 +616,8 @@ def init_model_worker(self):
             self.device, self.gpu_id, empty_cache=False
         )
         logger.info(
-            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_total_num_tokens={self.max_total_num_tokens * self.dcp_size}, "
+            f"{f'dcp_size={self.dcp_size}, ' if self.dcp_size > 1 else ''}"
             f"chunked_prefill_size={self.server_args.chunked_prefill_size}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
@@ -654,7 +656,7 @@ def init_cache_with_memory_pool(self):
             disable=server_args.disable_radix_cache,
             req_to_token_pool=self.req_to_token_pool,
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-            page_size=self.page_size,
+            page_size=self.page_size * self.dcp_size,
             is_eagle=self.spec_algorithm.is_eagle(),
             tp_cache_group=(
                 self.attn_tp_cpu_group
@@ -2031,7 +2033,7 @@ def _get_new_batch_prefill_raw(

         # Prefill policy
         adder = PrefillAdder(
-            self.page_size,
+            self.page_size * self.dcp_size,
             self.tree_cache,
             self.token_to_kv_pool_allocator,
             self.running_batch,
diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
index 484a949f5b23..4b182858b49e 100644
--- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
+++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py
@@ -20,8 +20,10 @@
class SchedulerRuntimeCheckerMixin: def _get_token_info(self: Scheduler): - available_size = self.token_to_kv_pool_allocator.available_size() - evictable_size = self.tree_cache.evictable_size() + available_size = ( + self.token_to_kv_pool_allocator.available_size() // self.dcp_size + ) + evictable_size = self.tree_cache.evictable_size() // self.dcp_size num_used = self.max_total_num_tokens - (available_size + evictable_size) token_usage = num_used / self.max_total_num_tokens return num_used, token_usage, available_size, evictable_size diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 5312b2edb4f3..1c51f295c42d 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -1519,6 +1519,8 @@ def set_mla_kv_buffer( loc: torch.Tensor, cache_k_nope: torch.Tensor, cache_k_rope: torch.Tensor, + dcp_kv_mask: Optional[torch.Tensor] = None, + dcp_size: Optional[int] = None, ): layer_id = layer.layer_id @@ -1538,6 +1540,8 @@ def set_mla_kv_buffer( loc, cache_k_nope_fp8, cache_k_rope_fp8, + dcp_kv_mask=dcp_kv_mask, + dcp_size=dcp_size, ) else: if cache_k_nope.dtype != self.dtype: @@ -1552,6 +1556,8 @@ def set_mla_kv_buffer( loc, cache_k_nope, cache_k_rope, + dcp_kv_mask=dcp_kv_mask, + dcp_size=dcp_size, ) def get_mla_kv_buffer( @@ -1751,6 +1757,7 @@ def __init__( index_head_dim: int, enable_memory_saver: bool, kv_cache_dim: int, + dcp_size: int = 1, start_layer: Optional[int] = None, end_layer: Optional[int] = None, ): @@ -1797,7 +1804,8 @@ def __init__( # * buf[i, :page_size * head_dim] for fp8 data # * buf[i, page_size * head_dim:].view(float32) for scale ( - (size + page_size + 1) // self.page_size, + # We do not shard the indexer K cache when DCP is enabled. 
+ (size * dcp_size + page_size + 1) // self.page_size, self.page_size * ( index_head_dim + index_head_dim // self.quant_block_size * 4 diff --git a/python/sglang/srt/mem_cache/utils.py b/python/sglang/srt/mem_cache/utils.py index 3aec3dd89b51..8eb91665e0f6 100644 --- a/python/sglang/srt/mem_cache/utils.py +++ b/python/sglang/srt/mem_cache/utils.py @@ -82,11 +82,81 @@ def set_mla_kv_buffer_kernel( tl.store(dst_ptr, src, mask=mask) +@triton.jit +def set_mla_kv_buffer_with_mask_kernel( + kv_buffer_ptr, + cache_k_nope_ptr, + cache_k_rope_ptr, + loc_ptr, + dcp_kv_mask_ptr, + dcp_size: tl.constexpr, + buffer_stride: tl.constexpr, + nope_stride: tl.constexpr, + rope_stride: tl.constexpr, + nope_dim: tl.constexpr, + rope_dim: tl.constexpr, + BLOCK: tl.constexpr, +): + pid_loc = tl.program_id(0) + pid_blk = tl.program_id(1) + + dcp_mask = tl.load(dcp_kv_mask_ptr + pid_loc) + + # Skip write if mask is 0 + if dcp_mask == 0: + return + + base = pid_blk * BLOCK + offs = base + tl.arange(0, BLOCK) + total_dim = nope_dim + rope_dim + mask = offs < total_dim + + loc = tl.load(loc_ptr + pid_loc).to(tl.int64) // dcp_size + dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs + + # Three-way branch to handle boundary correctly while preserving fast path + if base + BLOCK <= nope_dim: + # Fast path: entire block is in nope region + src = tl.load( + cache_k_nope_ptr + pid_loc * nope_stride + offs, + mask=mask, + ) + elif base >= nope_dim: + # Fast path: entire block is in rope region + offs_rope = offs - nope_dim + src = tl.load( + cache_k_rope_ptr + pid_loc * rope_stride + offs_rope, + mask=mask, + ) + else: + # Boundary case: block spans nope/rope boundary (e.g., FP8 with nope_dim=528) + # Handle each offset individually to avoid negative indexing + is_nope = offs < nope_dim + is_rope = (offs >= nope_dim) & (offs < (nope_dim + rope_dim)) + + src_nope = tl.load( + cache_k_nope_ptr + pid_loc * nope_stride + offs, + mask=mask & is_nope, + other=0, + ) + src_rope = tl.load( + cache_k_rope_ptr + pid_loc * rope_stride + (offs - nope_dim), + mask=mask & is_rope, + other=0, + ) + + src = tl.where(is_nope, src_nope, src_rope) + + tl.store(dst_ptr, src, mask=mask) + + def set_mla_kv_buffer_triton( kv_buffer: torch.Tensor, loc: torch.Tensor, cache_k_nope: torch.Tensor, cache_k_rope: torch.Tensor, + dcp_kv_mask: Optional[torch.Tensor] = None, + dcp_size: Optional[int] = None, ): nope_dim = cache_k_nope.shape[-1] rope_dim = cache_k_rope.shape[-1] @@ -95,18 +165,34 @@ def set_mla_kv_buffer_triton( n_loc = loc.numel() grid = (n_loc, triton.cdiv(total_dim, BLOCK)) - set_mla_kv_buffer_kernel[grid]( - kv_buffer, - cache_k_nope, - cache_k_rope, - loc, - kv_buffer.stride(0), - cache_k_nope.stride(0), - cache_k_rope.stride(0), - nope_dim, - rope_dim, - BLOCK=BLOCK, - ) + if dcp_kv_mask is None: + set_mla_kv_buffer_kernel[grid]( + kv_buffer, + cache_k_nope, + cache_k_rope, + loc, + kv_buffer.stride(0), + cache_k_nope.stride(0), + cache_k_rope.stride(0), + nope_dim, + rope_dim, + BLOCK=BLOCK, + ) + else: + set_mla_kv_buffer_with_mask_kernel[grid]( + kv_buffer, + cache_k_nope, + cache_k_rope, + loc, + dcp_kv_mask, + dcp_size, + kv_buffer.stride(0), + cache_k_nope.stride(0), + cache_k_rope.stride(0), + nope_dim, + rope_dim, + BLOCK=BLOCK, + ) @triton.jit diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9ff800a9f65d..b508371f0fb1 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py 
@@ -113,6 +113,7 @@ class DecodeInputBuffers(ForwardInputBuffers): global_num_tokens_for_logprob_gpu: torch.Tensor encoder_lens: Optional[torch.Tensor] pp_proxy_tensors: Optional[Dict[str, torch.Tensor]] + dcp_kv_mask: Optional[torch.Tensor] @classmethod def create( @@ -126,6 +127,7 @@ def create( dtype: torch.dtype, dp_size: int, pp_size: int, + dcp_size: int, is_encoder_decoder: bool, require_mlp_tp_gather: bool, seq_len_fill_value: int, @@ -168,6 +170,11 @@ def create( else: pp_proxy_tensors = None + if dcp_size > 1: + dcp_kv_mask = torch.zeros((max_num_token,), dtype=torch.bool) + else: + dcp_kv_mask = None + if is_encoder_decoder: encoder_lens = torch.full( (max_bs,), encoder_len_fill_value, dtype=torch.int32 @@ -210,6 +217,7 @@ def create( global_num_tokens_gpu=global_num_tokens_gpu, global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu, pp_proxy_tensors=pp_proxy_tensors, + dcp_kv_mask=dcp_kv_mask, ) def populate_from_forward_batch( @@ -285,6 +293,10 @@ def populate_from_forward_batch( dim = src.shape[0] buf[:dim].copy_(src) + # decode context parallel tensors. + if forward_batch.dcp_kv_mask is not None: + self.dcp_kv_mask[:raw_num_token].copy_(forward_batch.dcp_kv_mask) + # Detect whether the current forward pass is in capture mode is_capture_mode = False @@ -456,6 +468,7 @@ def __init__(self, model_runner: ModelRunner): model_runner.server_args.enable_profile_cuda_graph ) self.tp_size = model_runner.server_args.tp_size + self.dcp_size = model_runner.server_args.dcp_size self.dp_size = model_runner.server_args.dp_size self.pp_size = model_runner.server_args.pp_size self.enable_pdmux = model_runner.server_args.enable_pdmux @@ -542,6 +555,7 @@ def __init__(self, model_runner: ModelRunner): dtype=self.model_runner.model_config.dtype, dp_size=self.dp_size, pp_size=self.pp_size, + dcp_size=self.dcp_size, is_encoder_decoder=self.is_encoder_decoder, require_mlp_tp_gather=self.require_mlp_tp_gather, seq_len_fill_value=self.seq_len_fill_value, @@ -779,6 +793,11 @@ def capture_one_batch_size( {k: v[:num_tokens] for k, v in buffers.pp_proxy_tensors.items()} ) + if self.dcp_size > 1: + dcp_kv_mask = buffers.dcp_kv_mask[:num_tokens] + else: + dcp_kv_mask = None + if self.require_mlp_tp_gather: buffers.global_num_tokens_gpu.copy_( torch.tensor( @@ -876,6 +895,7 @@ def capture_one_batch_size( num_token_non_padded=buffers.num_token_non_padded, global_forward_mode=self.capture_forward_mode, lora_ids=lora_ids, + dcp_kv_mask=dcp_kv_mask, ) self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 0e84ec8aab27..8e05a03cf500 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -376,6 +376,9 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): # For dumper: request IDs for cross-step sequence tracking rids: Optional[List[str]] = None + # For decode context parallel + dcp_kv_mask: Optional[torch.Tensor] = None + @classmethod def init_new( cls, @@ -527,6 +530,13 @@ def init_new( model_runner.lora_manager.prepare_lora_batch(ret) + # For DCP + if model_runner.dcp_size > 1: + dcp_size = model_runner.dcp_size + dcp_rank = model_runner.dcp_rank + ret.dcp_kv_mask = ret.out_cache_loc % dcp_size == dcp_rank + # ret.out_cache_loc = ret.out_cache_loc // dcp_size + return ret def adjust_num_token_non_padded_for_attn_tp(self, server_args) -> None: diff --git 
a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d3ff835bbd8a..03b8539c227f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -307,6 +307,8 @@ def __init__( self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = tp_size + self.dcp_size = server_args.dcp_size + self.dcp_rank = self.tp_rank % self.dcp_size self.moe_ep_rank = moe_ep_rank self.moe_ep_size = moe_ep_size self.dp_size = server_args.dp_size if server_args.enable_dp_attention else 1 @@ -804,6 +806,7 @@ def _(data, dim): attention_data_parallel_size=self.dp_size, pipeline_model_parallel_size=self.pp_size, expert_model_parallel_size=self.moe_ep_size, + decode_context_parallel_size=self.dcp_size, attention_context_model_parallel_size=self.attn_cp_size, moe_data_model_parallel_size=self.moe_dp_size, duplicate_tp_group=self.server_args.enable_pdmux, @@ -1911,6 +1914,7 @@ def _dummy_run(self, batch_size: int, run_ctx=None): dtype=self.model_config.dtype, dp_size=self.server_args.dp_size, pp_size=self.server_args.pp_size, + dcp_size=self.server_args.dcp_size, is_encoder_decoder=self.model_config.is_encoder_decoder, require_mlp_tp_gather=require_mlp_tp_gather_, seq_len_fill_value=seq_len_fill_value, diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 2af60158f7cd..782fe76b5106 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -77,7 +77,10 @@ def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int: element_size = torch._utils._element_size( NSATokenToKVPool.index_k_with_scale_buffer_dtype ) - cell_size += indexer_size_per_token * num_layers * element_size + # We do not shard the indexer K cache when DCP is enabled. 
+ cell_size += ( + indexer_size_per_token * num_layers * element_size * self.dcp_size + ) else: if self.model_config.is_hybrid_swa: full_layers_num = len(self.model_config.full_attention_layer_ids) @@ -539,6 +542,7 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): start_layer=self.start_layer, end_layer=self.end_layer, index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config), + dcp_size=self.dcp_size, ) elif self.use_mla_backend and not self.mambaish_config: assert not is_nsa_model @@ -714,7 +718,7 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): need_sort=need_sort, ) else: - if self.page_size == 1: + if self.page_size == 1 and self.dcp_size == 1: self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( self.max_total_num_tokens, dtype=self.kv_cache_dtype, @@ -724,8 +728,8 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): ) else: self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( - self.max_total_num_tokens, - page_size=self.page_size, + self.max_total_num_tokens * self.dcp_size, + page_size=self.page_size * self.dcp_size, dtype=self.kv_cache_dtype, device=self.device, kvcache=self.token_to_kv_pool, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5348d6a27ab8..fd36b4b9b510 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -42,6 +42,7 @@ ) from sglang.srt.distributed import ( divide, + get_dcp_group, get_moe_expert_parallel_world_size, get_pp_group, get_tensor_model_parallel_world_size, @@ -1126,6 +1127,9 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.kv_cache_dtype = get_global_server_args().kv_cache_dtype + self.dcp_size = get_dcp_group().world_size + self.dcp_rank = get_dcp_group().rank_in_group + # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it if rope_scaling: rope_scaling["rope_type"] = "deepseek_yarn" @@ -1535,6 +1539,33 @@ def rebuild_cp_kv_cache(self, latent_cache, forward_batch, k_nope, k_pe): k_pe = latent_cache_output[..., self.kv_lora_rank :].unsqueeze(1) return k_nope, k_pe + def _filter_topk_indices_by_dcp(self, topk_indices: torch.Tensor) -> torch.Tensor: + """Filter topk_indices to keep only indices where index % dcp_size == dcp_rank. + + This is used for DCP (decode context parallel) to ensure each rank + only processes its assigned portion of the KV cache. 
+ + Args: + topk_indices: Tensor of shape [num_tokens, topk] containing topk indices + + Returns: + The same tensor (modified in-place), where indices not matching + the DCP rank are set to -1 + """ + if self.dcp_size <= 1: + return topk_indices + + # Create mask for indices that belong to this DCP rank + mask = (topk_indices % self.dcp_size == self.dcp_rank) & (topk_indices >= 0) + + if envs.SGLANG_NSA_FUSE_TOPK.get(): + topk_indices //= self.dcp_size + + # Set invalid indices to -1 in-place + topk_indices[~mask] = -1 + + return topk_indices + def forward_absorb_prepare( self, positions: torch.Tensor, @@ -1638,6 +1669,7 @@ def forward_absorb_prepare( forward_batch=forward_batch, layer_id=self.layer_id, ) + topk_indices = self._filter_topk_indices_by_dcp(topk_indices) current_stream.wait_stream(self.alt_stream) else: k_nope = k_nope.unsqueeze(1) @@ -1650,6 +1682,7 @@ def forward_absorb_prepare( forward_batch=forward_batch, layer_id=self.layer_id, ) + topk_indices = self._filter_topk_indices_by_dcp(topk_indices) else: q = self.q_proj(hidden_states)[0].view( -1, self.num_local_heads, self.qk_head_dim diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a250cf1feed2..eab97b86f572 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -349,6 +349,7 @@ class ServerArgs: # Runtime options device: Optional[str] = None tp_size: int = 1 + dcp_size: int = 1 pp_size: int = 1 pp_max_micro_batch_size: Optional[int] = None pp_async_batch_depth: int = 0 @@ -3348,6 +3349,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tp_size, help="The tensor parallelism size.", ) + parser.add_argument( + "--decode-context-parallel-size", + "--dcp-size", + type=int, + default=ServerArgs.dcp_size, + help="The decode context parallel size.", + ) parser.add_argument( "--attention-context-parallel-size", "--attn-cp-size", @@ -5156,6 +5164,7 @@ def add_cli_args(parser: argparse.ArgumentParser): @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size + args.dcp_size = args.decode_context_parallel_size args.pp_size = args.pipeline_parallel_size args.attn_cp_size = args.attention_context_parallel_size args.moe_dp_size = args.moe_data_parallel_size