diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index d6ffbe235a72..e6930c841c20 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -153,6 +153,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--device` | The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified. | `None` | Type: str | | `--tensor-parallel-size`
`--tp-size` | The tensor parallelism size. | `1` | Type: int | | `--pipeline-parallel-size`
`--pp-size` | The pipeline parallelism size. | `1` | Type: int | +| `--attention-context-parallel-size`
`--attn-cp-size`| The attention context parallelism size. | `1` | Type: int| +| `--moe-data-parallel-size`
`--moe-dp-size`| The moe data parallelism size. | `1` | Type: int| | `--pp-max-micro-batch-size` | The maximum micro batch size in pipeline parallelism. | `None` | Type: int | | `--pp-async-batch-depth` | The async batch depth of pipeline parallelism. | `0` | Type: int | | `--stream-interval` | The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher | `1` | Type: int | diff --git a/docs/basic_usage/deepseek_v32.md b/docs/basic_usage/deepseek_v32.md index 581399f5f7e1..4389c2836d19 100644 --- a/docs/basic_usage/deepseek_v32.md +++ b/docs/basic_usage/deepseek_v32.md @@ -308,9 +308,7 @@ For context parallel in DeepSeek V3.2 model, we provide two different modes of s ### In sequence splitting -The first mode can be enabled by `--nsa-prefill-cp-mode in-seq-split`. This mode implements context parallel for DSA by splitting the sequence uniformly between context parallel ranks. At attention stage, each cp rank computes the indexer results of sharded sequence, and collects the whole kv cache through all gather operator. - -The communication group for context parallel reuses the one for attention tp, thus `cp_size` equals `atten_tp_size = tp_size / dp_size`. +The first mode can be enabled by `--nsa-prefill-cp-mode in-seq-split`. This mode implements context parallel for DSA by splitting the sequence uniformly between context parallel ranks. At attention stage, each cp rank computes the indexer results of sharded sequence, and collects the whole kv cache through all gather operator. Add `attn_cp_size` for communication group for context parallel. Note that in sequence splitting mode has the following restrictions: - The batch size is restricted to 1 for prefill batches @@ -323,7 +321,7 @@ For more details, please refer to PR https://github.com/sgl-project/sglang/pull/ Example: ```bash # In-seq splitting mode launched with EP + DP -python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --ep 8 --dp 2 --enable-dp-attention --enable-nsa-prefill-context-parallel --nsa-prefill-cp-mode in-seq-split --max-running-requests 32 +python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --ep 8 --dp 2 --enable-dp-attention --enable-nsa-prefill-context-parallel --attn-cp-size 4 --nsa-prefill-cp-mode in-seq-split --max-running-requests 32 ``` ### Round robin splitting (default setting) @@ -337,7 +335,7 @@ For more details, please refer to PR https://github.com/sgl-project/sglang/pull/ Example usage: ```bash # Launch with FusedMoe + CP8 -python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --enable-nsa-prefill-context-parallel --nsa-prefill-cp-mode round-robin-split --max-running-requests 32 +python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --enable-nsa-prefill-context-parallel --attn-cp-size 8 --nsa-prefill-cp-mode round-robin-split --max-running-requests 32 ``` ### Pipeline Parallel + Context Parallel (PP + CP) @@ -361,6 +359,7 @@ python3 -m sglang.launch_server \ --tp 8 --pp-size 2 \ --dp-size 1 --moe-dense-tp-size 1 \ --enable-nsa-prefill-context-parallel \ + --attn-cp-size 8 \ --nsa-prefill-cp-mode round-robin-split \ --trust-remote-code \ --disable-radix-cache \ @@ -384,6 +383,7 @@ python3 -m sglang.launch_server \ --tp 8 --pp-size 2 \ --dp-size 1 --moe-dense-tp-size 1 \ --enable-nsa-prefill-context-parallel \ + --attn-cp-size 8 \ --nsa-prefill-cp-mode round-robin-split \ --trust-remote-code \ --disable-radix-cache \ @@ -411,6 +411,7 @@ python -m sglang.launch_server \ --tp 8 --pp-size 2 \ --dp-size 1 --moe-dense-tp-size 1 \ --enable-nsa-prefill-context-parallel \ + --attn-cp-size 8 \ --nsa-prefill-cp-mode round-robin-split \ --disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \ --trust-remote-code \ @@ -436,6 +437,7 @@ python -m sglang.launch_server \ --tp 8 --pp-size 2 \ --dp-size 1 --moe-dense-tp-size 1 \ --enable-nsa-prefill-context-parallel \ + --attn-cp-size 8 \ --nsa-prefill-cp-mode round-robin-split \ --disaggregation-ib-device mlx5_bond_0,mlx5_bond_1,mlx5_bond_2,mlx5_bond_3 \ --trust-remote-code \ diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index df3499a6532c..d7d12c7441d8 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1351,6 +1351,7 @@ def init_model_parallel_group( group_ranks: List[List[int]], local_rank: int, backend: str, + use_pynccl: Optional[bool] = None, use_custom_allreduce: Optional[bool] = None, use_message_queue_broadcaster: bool = False, group_name: Optional[str] = None, @@ -1368,7 +1369,11 @@ def init_model_parallel_group( group_ranks=group_ranks, local_rank=local_rank, torch_distributed_backend=backend, - use_pynccl=not (_is_npu or _is_xpu or backend == "mooncake"), + use_pynccl=( + not (_is_npu or _is_xpu or backend == "mooncake") + if use_pynccl is None + else use_pynccl + ), use_pymscclpp=use_mscclpp_allreduce, use_custom_allreduce=use_custom_allreduce, use_torch_symm_mem_all_reduce=use_torch_symm_mem_allreduce, @@ -1382,6 +1387,8 @@ def init_model_parallel_group( _TP: Optional[GroupCoordinator] = None +_ATTN_TP: Optional[GroupCoordinator] = None +_ATTN_CP: Optional[GroupCoordinator] = None # duplicate GroupCoordinator for prefill in PD-Multiplexing _PDMUX_PREFILL_TP_GROUP: Optional[GroupCoordinator] = None @@ -1404,10 +1411,30 @@ def get_tp_group() -> GroupCoordinator: return _TP +def get_attn_tp_group() -> GroupCoordinator: + assert ( + _ATTN_TP is not None + ), "attention tensor model parallel group is not initialized" + return _ATTN_TP + + +def get_attn_cp_group() -> GroupCoordinator: + assert ( + _ATTN_CP is not None + ), "attention context model parallel group is not initialized" + return _ATTN_CP + + +_MOE_DP: Optional[GroupCoordinator] = None _MOE_EP: Optional[GroupCoordinator] = None _MOE_TP: Optional[GroupCoordinator] = None +def get_moe_dp_group() -> GroupCoordinator: + assert _MOE_DP is not None, "moe data parallel group is not initialized" + return _MOE_DP + + def get_moe_ep_group() -> GroupCoordinator: assert _MOE_EP is not None, "expert model parallel group is not initialized" return _MOE_EP @@ -1558,6 +1585,9 @@ def initialize_model_parallel( tensor_model_parallel_size: int = 1, expert_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + attention_data_parallel_size: int = 1, + attention_context_model_parallel_size: int = 1, + moe_data_model_parallel_size: int = 1, backend: Optional[str] = None, duplicate_tp_group: bool = False, ) -> None: @@ -1567,8 +1597,16 @@ def initialize_model_parallel( Arguments: tensor_model_parallel_size: number of GPUs used for tensor model parallelism. + expert_model_parallel_size: number of GPUs used for expert model + parallelism. pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism. + attention_data_parallel_size: number of GPUs used for attention data + parallelism. + attention_context_model_parallel_size: number of GPUs used for attention context + parallelism. + moe_data_model_parallel_size: number of GPUs used for moe data + parallelism. Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize @@ -1578,6 +1616,20 @@ def initialize_model_parallel( [g0, g1], [g2, g3], [g4, g5], [g6, g7] 2 pipeline model-parallel groups: [g0, g2, g4, g6], [g1, g3, g5, g7] + + Let's say we use 2 GPUs for attention context parallelism (attn_cp_size=2) and 4 GPUs for + attention tensor parallelism (attn_tp_size=4). As for MoE part, we use 2 GPUs for moe data + parallelism (moe_dp_size=2) and 4 GPUs for moe expert parallelism (moe_ep_size=4). The present + function will create the following groups: + 2 tensor model-parallel groups: + [g0, g1, g2, g3], [g4, g5, g6, g7] + 4 attention context-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7] + 2 moe expert-parallel groups: + [g0, g1, g2, g3], [g4, g5, g6, g7] + 4 moe data-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7] + Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example if we are using 2 DGX-1 boxes with a total of 16 GPUs, rank 0 to 7 belong to the first box and @@ -1600,9 +1652,12 @@ def initialize_model_parallel( global _TP assert _TP is None, "tensor model parallel group is already initialized" group_ranks = [] - for i in range(num_tensor_model_parallel_groups): + for tp_group_idx in range(num_tensor_model_parallel_groups): ranks = list( - range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + range( + tp_group_idx * tensor_model_parallel_size, + (tp_group_idx + 1) * tensor_model_parallel_size, + ) ) group_ranks.append(ranks) @@ -1637,8 +1692,98 @@ def initialize_model_parallel( _TP.pynccl_comm.disabled = False _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False + attn_dp_size = attention_data_parallel_size + attn_cp_size = attention_context_model_parallel_size + attn_tp_size = tensor_model_parallel_size // attn_cp_size // attn_dp_size + + global _ATTN_CP + assert ( + _ATTN_CP is None + ), "attention context model parallel group is already initialized" + if attn_cp_size == tensor_model_parallel_size: + _ATTN_CP = _TP + else: + group_ranks = [] + for tp_group_idx in range(num_tensor_model_parallel_groups): + for dp_idx in range(attn_dp_size): + for attn_tp_idx in range(attn_tp_size): + st = ( + tp_group_idx * tensor_model_parallel_size + + dp_idx * attn_tp_size * attn_cp_size + + attn_tp_idx + ) + en = ( + tp_group_idx * tensor_model_parallel_size + + (dp_idx + 1) * attn_tp_size * attn_cp_size + + attn_tp_idx + ) + ranks = list(range(st, en, attn_tp_size)) + group_ranks.append(ranks) + _ATTN_CP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="attn_cp", + ) + + from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP + + global _ATTN_TP + assert ( + _ATTN_TP is None + ), "attention tensor model parallel group is already initialized" + if attn_tp_size == tensor_model_parallel_size: + _ATTN_TP = _TP + else: + group_ranks = [] + for tp_group_idx in range(num_tensor_model_parallel_groups): + for cp_dp_combined_idx in range(attn_cp_size * attn_dp_size): + st = ( + tp_group_idx * tensor_model_parallel_size + + cp_dp_combined_idx * attn_tp_size + ) + en = ( + tp_group_idx * tensor_model_parallel_size + + (cp_dp_combined_idx + 1) * attn_tp_size + ) + ranks = list(range(st, en)) + group_ranks.append(ranks) + _ATTN_TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP, + use_mscclpp_allreduce=False, + use_custom_allreduce=False, + use_torch_symm_mem_allreduce=False, + group_name="attention_tp", + ) + moe_ep_size = expert_model_parallel_size - moe_tp_size = tensor_model_parallel_size // moe_ep_size + moe_dp_size = moe_data_model_parallel_size + moe_tp_size = tensor_model_parallel_size // moe_ep_size // moe_dp_size + + global _MOE_DP + assert _MOE_DP is None, "moe data parallel group is already initialized" + # gpus_per_pp_stage = tensor_model_parallel_size * attention_context_model_parallel_size + if moe_dp_size == tensor_model_parallel_size: + _MOE_DP = _TP + else: + group_ranks = [] + for tp_group_idx in range(num_tensor_model_parallel_groups): + for tp_ep_combined_idx in range(moe_tp_size * moe_ep_size): + st = tp_group_idx * tensor_model_parallel_size + tp_ep_combined_idx + en = ( + tp_group_idx + 1 + ) * tensor_model_parallel_size + tp_ep_combined_idx + ranks = list(range(st, en, moe_tp_size * moe_ep_size)) + group_ranks.append(ranks) + _MOE_DP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="moe_dp", + ) global _MOE_EP assert _MOE_EP is None, "expert model parallel group is already initialized" @@ -1647,12 +1792,17 @@ def initialize_model_parallel( else: # TODO(ch-wan): use split_group to save memory group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - for j in range(moe_tp_size): - st = i * tensor_model_parallel_size + j - en = (i + 1) * tensor_model_parallel_size + j - ranks = list(range(st, en, moe_tp_size)) - group_ranks.append(ranks) + for tp_group_idx in range(num_tensor_model_parallel_groups): + for moe_dp_idx in range(moe_dp_size): + for moe_tp_idx in range(moe_tp_size): + st = ( + tp_group_idx * tensor_model_parallel_size + + moe_dp_idx * moe_ep_size * moe_tp_size + + moe_tp_idx + ) + en = st + moe_ep_size * moe_tp_size + ranks = list(range(st, en, moe_tp_size)) + group_ranks.append(ranks) _MOE_EP = init_model_parallel_group( group_ranks, get_world_group().local_rank, @@ -1667,10 +1817,16 @@ def initialize_model_parallel( else: # TODO(ch-wan): use split_group to save memory group_ranks = [] - for i in range(num_tensor_model_parallel_groups): - for j in range(moe_ep_size): - st = i * tensor_model_parallel_size + j * moe_tp_size - en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size + for tp_group_idx in range(num_tensor_model_parallel_groups): + for ep_dp_combined_idx in range(moe_ep_size * moe_dp_size): + st = ( + tp_group_idx * tensor_model_parallel_size + + ep_dp_combined_idx * moe_tp_size + ) + en = ( + tp_group_idx * tensor_model_parallel_size + + (ep_dp_combined_idx + 1) * moe_tp_size + ) ranks = list(range(st, en)) group_ranks.append(ranks) _MOE_TP = init_model_parallel_group( @@ -1685,8 +1841,10 @@ def initialize_model_parallel( global _PP assert _PP is None, "pipeline model parallel group is already initialized" group_ranks = [] - for i in range(num_pipeline_model_parallel_groups): - ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + for pp_group_idx in range(num_pipeline_model_parallel_groups): + ranks = list( + range(pp_group_idx, world_size, num_pipeline_model_parallel_groups) + ) group_ranks.append(ranks) # pipeline parallel does not need custom allreduce _PP = init_model_parallel_group( @@ -1833,6 +1991,28 @@ def get_tensor_model_parallel_rank(): return get_tp_group().rank_in_group +# ATTN_TP +def get_attn_tensor_model_parallel_world_size(): + """Return world size for the attention tensor model parallel group.""" + return get_attn_tp_group().world_size + + +def get_attn_tensor_model_parallel_rank(): + """Return my rank for the attention tensor model parallel group.""" + return get_attn_tp_group().rank_in_group + + +# ATTN_CP +def get_attn_context_model_parallel_world_size(): + """Return world size for the attention context model parallel group.""" + return get_attn_cp_group().world_size + + +def get_attn_context_model_parallel_rank(): + """Return my rank for the attention context model parallel group.""" + return get_attn_cp_group().rank_in_group + + def get_pipeline_model_parallel_world_size(): """Return world size for the pipeline model parallel group.""" return get_pp_group().world_size @@ -1843,6 +2023,18 @@ def get_pipeline_model_parallel_rank(): return get_pp_group().rank_in_group +# MOE_DP +def get_moe_data_parallel_world_size(): + """Return world size for the moe data parallel group.""" + return get_moe_dp_group().world_size + + +def get_moe_data_parallel_rank(): + """Return my rank for the moe data parallel group.""" + return get_moe_dp_group().rank_in_group + + +# MOE_EP def get_moe_expert_parallel_world_size(): """Return world size for the moe expert parallel group.""" return get_moe_ep_group().world_size @@ -1853,6 +2045,7 @@ def get_moe_expert_parallel_rank(): return get_moe_ep_group().rank_in_group +# MOE_TP def get_moe_tensor_parallel_world_size(): """Return world size for the moe tensor parallel group.""" return get_moe_tp_group().world_size @@ -1885,6 +2078,16 @@ def destroy_model_parallel(): _MOE_TP.destroy() _MOE_TP = None + global _ATTN_CP + if _ATTN_CP: + _ATTN_CP.destroy() + _ATTN_CP = None + + global _MOE_DP + if _MOE_DP: + _MOE_DP.destroy() + _MOE_DP = None + global _PDMUX_PREFILL_TP_GROUP if _PDMUX_PREFILL_TP_GROUP: # type: ignore[union-attr] _PDMUX_PREFILL_TP_GROUP.destroy() diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 51e158b6e9b5..b701349a2a5c 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -938,7 +938,29 @@ def _launch_scheduler_processes( + ((pp_rank % pp_size_per_node) * tp_size_per_node) + (tp_rank % tp_size_per_node) * server_args.gpu_id_step ) - moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + attn_dp_size = ( + server_args.dp_size if server_args.enable_dp_attention else 1 + ) + + # Parallelism hierarchy (outermost to innermost): + # - Attention: Global(TP) -> DP -> ATTN_CP -> ATTN_TP (innermost) + # - MoE: Global(TP) -> MOE_DP -> EP -> MOE_TP (innermost) + attn_tp_size = ( + server_args.tp_size // attn_dp_size // server_args.attn_cp_size + ) + attn_cp_rank = (tp_rank // attn_tp_size) % server_args.attn_cp_size + moe_dp_rank = tp_rank // ( + server_args.tp_size // server_args.moe_dp_size + ) + moe_ep_rank = ( + tp_rank + % (server_args.tp_size // server_args.moe_dp_size) + // ( + server_args.tp_size + // server_args.moe_dp_size + // server_args.ep_size + ) + ) with maybe_reindex_device_id(gpu_id) as gpu_id: proc = mp.Process( @@ -948,6 +970,8 @@ def _launch_scheduler_processes( port_args, gpu_id, tp_rank, + attn_cp_rank, + moe_dp_rank, moe_ep_rank, pp_rank, None, diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 07c3d8f25475..8b48dfd91253 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -28,6 +28,10 @@ import torch_npu from sglang.srt.hardware_backend.npu.utils import get_indexer_weight_stream +from sglang.srt.distributed import ( + get_attn_context_model_parallel_rank, + get_attn_context_model_parallel_world_size, +) from sglang.srt.distributed.parallel_state import get_pp_group from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.nsa.utils import ( @@ -35,7 +39,6 @@ is_nsa_enable_prefill_cp, is_nsa_prefill_cp_in_seq_split, ) -from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.linear import ReplicatedLinear from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.rotary_embedding import get_rope_wrapper @@ -162,8 +165,8 @@ def __init__( self.alt_stream = alt_stream self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() if self.nsa_enable_prefill_cp: - self.cp_size = get_attention_tp_size() - self.cp_rank = get_attention_tp_rank() + self.cp_size = get_attn_context_model_parallel_world_size() + self.cp_rank = get_attn_context_model_parallel_rank() else: self.cp_size = None self.cp_rank = None diff --git a/python/sglang/srt/layers/attention/nsa/utils.py b/python/sglang/srt/layers/attention/nsa/utils.py index 0c6e21291102..57c281beedcb 100644 --- a/python/sglang/srt/layers/attention/nsa/utils.py +++ b/python/sglang/srt/layers/attention/nsa/utils.py @@ -13,11 +13,11 @@ ) from sglang.srt.layers.dp_attention import ( DpPaddingMode, - attn_tp_all_gather_into_tensor, + attn_cp_all_gather_into_tensor, + get_attention_cp_group, + get_attention_cp_rank, + get_attention_cp_size, get_attention_dp_rank, - get_attention_tp_group, - get_attention_tp_rank, - get_attention_tp_size, is_allocation_symmetric, ) from sglang.srt.server_args import get_global_server_args @@ -52,7 +52,7 @@ def is_nsa_prefill_cp_round_robin_split(): def can_nsa_prefill_cp_round_robin_split(forward_batch: "ForwardBatch"): if not forward_batch.forward_mode.is_context_parallel_extend(): return False - cp_size = get_attention_tp_size() + cp_size = get_attention_cp_size() seq_len = sum(forward_batch.extend_seq_lens_cpu) return is_nsa_prefill_cp_round_robin_split() and seq_len > 0 and cp_size > 1 @@ -70,8 +70,8 @@ def nsa_cp_round_robin_split_data(input_: Union[torch.Tensor, List]): | dp_atten_tp3: token3, token7, token11, token15, token19, ... | | +-------------------------+ """ - cp_size = get_attention_tp_size() - cp_rank = get_attention_tp_rank() + cp_size = get_attention_cp_size() + cp_rank = get_attention_cp_rank() if isinstance(input_, (tuple, list)): indices = range(cp_rank, len(input_), cp_size) return input_[indices] @@ -93,9 +93,9 @@ def cal_padded_tokens(forward_batch: "ForwardBatch"): # calculate the actual token length after padding when attn_tp_size > 1 or in the MAX_LEN padding mode. global_num_tokens = forward_batch.global_num_tokens_cpu.copy() sync_group_size = len(global_num_tokens) - attn_tp_size = get_attention_tp_size() + attn_cp_size = get_attention_cp_size() for i in range(sync_group_size): - global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_tp_size) + global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_cp_size) dp_padding_mode = DpPaddingMode.get_dp_padding_mode( forward_batch.is_extend_in_batch, global_num_tokens ) @@ -106,12 +106,13 @@ def cal_padded_tokens(forward_batch: "ForwardBatch"): else: tokens = global_num_tokens[0] if can_nsa_prefill_cp_round_robin_split(forward_batch): - tokens = ceil_div(tokens, attn_tp_size) + tokens = ceil_div(tokens, attn_cp_size) return tokens def pad_nsa_cache_seqlens(forward_batch: "ForwardBatch", nsa_cache_seqlens): - if forward_batch.global_num_tokens_cpu is None: + attn_cp_size = get_attention_cp_size() + if attn_cp_size == 1 or not can_nsa_prefill_cp_round_robin_split(forward_batch): return nsa_cache_seqlens tokens = cal_padded_tokens(forward_batch) pad_len = tokens - nsa_cache_seqlens.shape[0] @@ -170,7 +171,7 @@ def can_cp_split(seq_len: int, cp_size: int, use_nsa: bool, forward_batch): def cp_split_and_rebuild_data(forward_batch, input_: torch.Tensor): if is_nsa_prefill_cp_round_robin_split(): - cp_size = get_attention_tp_size() + cp_size = get_attention_cp_size() assert ( input_.shape[0] % cp_size == 0 ), f"Expect input shape 0 can divided by cp size, but got input shape {input_.shape}, cp size {cp_size}" @@ -187,7 +188,7 @@ def cp_split_and_rebuild_data(forward_batch, input_: torch.Tensor): def cp_split_and_rebuild_position(forward_batch, positions: torch.Tensor): if is_nsa_prefill_cp_round_robin_split(): - cp_size = get_attention_tp_size() + cp_size = get_attention_cp_size() assert positions.shape[0] % cp_size == 0, ( f"Expect positions shape 0 can divided by cp size, but got positions shape {positions.shape}, " f"cp size {cp_size}" @@ -227,8 +228,8 @@ def nsa_cp_round_robin_split_q_seqs_kernel( def nsa_cp_round_robin_split_q_seqs_cpu(extend_seqs): - cp_size = get_attention_tp_size() - cp_rank = get_attention_tp_rank() + cp_size = get_attention_cp_size() + cp_rank = get_attention_cp_rank() extra_seq = 0 q_seqs = [] for bs, cur_len in enumerate(extend_seqs): @@ -253,8 +254,8 @@ def nsa_cp_round_robin_split_q_seqs( bs_idx_cpu(List) and bs_idx(torch.Tensor): marks which sequences are ultimately selected, i.e., those with a partitioned length greater than zero. """ - cp_size = get_attention_tp_size() - cp_rank = get_attention_tp_rank() + cp_size = get_attention_cp_size() + cp_rank = get_attention_cp_rank() # len(ret_q_lens_cpu) == len(bs_idx_cpu) ret_q_lens_cpu, bs_idx_cpu = nsa_cp_round_robin_split_q_seqs_cpu(extend_seqs_cpu) ret_q_lens = torch.empty( @@ -299,7 +300,7 @@ def cp_attn_tp_all_gather_reorganazied_into_tensor( if pad_size > 0: input_ = F.pad(input_, (0, 0, 0, pad_size), mode="constant", value=0) with use_symmetric_memory( - get_attention_tp_group(), disabled=not is_allocation_symmetric() + get_attention_cp_group(), disabled=not is_allocation_symmetric() ): input_tensor_all = torch.empty( max_len * attn_tp_size, @@ -308,7 +309,7 @@ def cp_attn_tp_all_gather_reorganazied_into_tensor( dtype=input_.dtype, ) # step2 - get_attention_tp_group().cp_all_gather_into_tensor_async( + get_attention_cp_group().cp_all_gather_into_tensor_async( input_tensor_all, input_, stream_op ) # step3 @@ -356,12 +357,12 @@ def cp_all_gather_rerange_output(input_tensor, cp_size, forward_batch, stream): """ if is_nsa_prefill_cp_round_robin_split(): with use_symmetric_memory( - get_attention_tp_group(), disabled=not is_allocation_symmetric() + get_attention_cp_group(), disabled=not is_allocation_symmetric() ): output_tensor = input_tensor.new_empty( (input_tensor.shape[0] * cp_size, *input_tensor.shape[1:]), ) - attn_tp_all_gather_into_tensor( + attn_cp_all_gather_into_tensor( output_tensor, input_tensor, ) diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index ebd0d3bea5fc..64f2cf66242e 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -39,6 +39,8 @@ dp_gather_partial, dp_reduce_scatter_tensor, dp_scatter, + get_attention_cp_rank, + get_attention_cp_size, get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, @@ -611,6 +613,8 @@ class CommunicateContext: attn_tp_rank: int attn_tp_size: int attn_dp_size: int + attn_cp_rank: int + attn_cp_size: int tp_size: int cache = None tp_rank: int @@ -623,6 +627,8 @@ def init_new(cls): attn_tp_rank = get_attention_tp_rank() attn_tp_size = get_attention_tp_size() attn_dp_size = get_attention_dp_size() + attn_cp_size = get_attention_cp_size() + attn_cp_rank = get_attention_cp_rank() tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() process_group_sizes = { @@ -636,6 +642,8 @@ def init_new(cls): attn_tp_rank=attn_tp_rank, attn_tp_size=attn_tp_size, attn_dp_size=attn_dp_size, + attn_cp_rank=attn_cp_rank, + attn_cp_size=attn_cp_size, tp_size=tp_size, tp_rank=tp_rank, ) diff --git a/python/sglang/srt/layers/communicator_nsa_cp.py b/python/sglang/srt/layers/communicator_nsa_cp.py index d3f668edbc04..296d1456812a 100644 --- a/python/sglang/srt/layers/communicator_nsa_cp.py +++ b/python/sglang/srt/layers/communicator_nsa_cp.py @@ -32,8 +32,8 @@ ScatterMode, ) from sglang.srt.layers.dp_attention import ( - attn_tp_all_gather_into_tensor, - attn_tp_reduce_scatter_tensor, + attn_cp_all_gather_into_tensor, + attn_cp_reduce_scatter_tensor, get_local_dp_buffer, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -157,7 +157,7 @@ def _gather_hidden_states_and_residual( get_local_dp_buffer(), hidden_states, ) - attn_tp_all_gather_into_tensor( + attn_cp_all_gather_into_tensor( hidden_states, local_hidden_states, ) @@ -203,8 +203,8 @@ def _scatter_hidden_states( if nsa_use_prefill_cp(forward_batch): assert context.attn_dp_size == 1 input_hidden_states = hidden_states - hidden_states = hidden_states.tensor_split(context.attn_tp_size)[ - context.attn_tp_rank + hidden_states = hidden_states.tensor_split(context.attn_cp_size)[ + context.attn_cp_rank ] - attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states) + attn_cp_reduce_scatter_tensor(hidden_states, input_hidden_states) return hidden_states, residual diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index a215913e62c1..5bf5aa0c8800 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -12,6 +12,12 @@ from sglang.srt.distributed import ( GroupCoordinator, + get_attn_context_model_parallel_rank, + get_attn_context_model_parallel_world_size, + get_attn_cp_group, + get_attn_tensor_model_parallel_rank, + get_attn_tensor_model_parallel_world_size, + get_attn_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, @@ -31,9 +37,6 @@ if TYPE_CHECKING: from sglang.srt.model_executor.forward_batch_info import ForwardBatch -_ATTN_TP_GROUP: Optional[GroupCoordinator] = None -_ATTN_TP_RANK: Optional[int] = None -_ATTN_TP_SIZE: Optional[int] = None _ATTN_DP_RANK: Optional[int] = None _ATTN_DP_SIZE: Optional[int] = None _LOCAL_ATTN_DP_SIZE: Optional[int] = None @@ -224,14 +227,20 @@ def is_dp_max_padding() -> bool: return _DpGatheredBufferWrapper.is_dp_max_padding() -def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): - if not enable_dp_attention: - return tp_rank, tp_size, 0 - - attn_tp_size = tp_size // dp_size - attn_dp_rank = tp_rank // attn_tp_size +def compute_dp_attention_world_info( + enable_dp_attention, tp_rank, tp_size, dp_size, attn_cp_size: int = 1 +): + attn_dp_size = dp_size if enable_dp_attention else 1 + attn_tp_size = tp_size // attn_dp_size // attn_cp_size attn_tp_rank = tp_rank % attn_tp_size + if not enable_dp_attention: + attn_dp_rank = 0 + else: + # Rank layout is (dp, cp, tp) where tp is the fastest-changing dim: + # tp_rank = ((cp_rank * dp_size) + dp_rank) * attn_tp_size + attn_tp_rank + attn_dp_rank = tp_rank // (attn_tp_size * attn_cp_size) + return attn_tp_rank, attn_tp_size, attn_dp_rank @@ -256,23 +265,20 @@ def initialize_dp_attention( server_args: ServerArgs, model_config: ModelConfig, ): - global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE + global _ATTN_DP_RANK, _ATTN_DP_SIZE global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG - - from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP - enable_dp_attention = server_args.enable_dp_attention - tp_size = server_args.tp_size dp_size = server_args.dp_size moe_dense_tp_size = server_args.moe_dense_tp_size - pp_size = server_args.pp_size - - tp_rank = get_tensor_model_parallel_rank() + attn_cp_size = server_args.attn_cp_size _ENABLE_DP_ATTENTION_FLAG = enable_dp_attention - _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info( - enable_dp_attention, tp_rank, tp_size, dp_size + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + _, _, _ATTN_DP_RANK = compute_dp_attention_world_info( + enable_dp_attention, tp_rank, tp_size, dp_size, attn_cp_size ) _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info( enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size @@ -288,28 +294,6 @@ def initialize_dp_attention( _ATTN_DP_SIZE = 1 _LOCAL_ATTN_DP_SIZE = 1 - tp_group = get_tp_group() - # Trick to solve circular references - from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp - - use_pynccl = True if is_nsa_enable_prefill_cp() else SYNC_TOKEN_IDS_ACROSS_TP - _ATTN_TP_GROUP = GroupCoordinator( - [ - list(range(head, head + _ATTN_TP_SIZE)) - for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE) - ], - tp_group.local_rank, - torch.distributed.get_backend(tp_group.device_group), - use_pynccl=use_pynccl, - use_pymscclpp=False, - use_custom_allreduce=False, - use_torch_symm_mem_all_reduce=False, - use_hpu_communicator=False, - use_xpu_communicator=False, - use_npu_communicator=False, - group_name="attention_tp", - ) - _DpGatheredBufferWrapper.set_metadata( hidden_size=model_config.hidden_size, dtype=model_config.dtype, @@ -326,18 +310,27 @@ def is_allocation_symmetric() -> bool: def get_attention_tp_group() -> GroupCoordinator: - assert _ATTN_TP_GROUP is not None, "dp attention not initialized!" - return _ATTN_TP_GROUP + return get_attn_tp_group() def get_attention_tp_rank() -> int: - assert _ATTN_TP_RANK is not None, "dp attention not initialized!" - return _ATTN_TP_RANK + return get_attn_tensor_model_parallel_rank() def get_attention_tp_size() -> int: - assert _ATTN_TP_SIZE is not None, "dp attention not initialized!" - return _ATTN_TP_SIZE + return get_attn_tensor_model_parallel_world_size() + + +def get_attention_cp_group() -> GroupCoordinator: + return get_attn_cp_group() + + +def get_attention_cp_rank() -> int: + return get_attn_context_model_parallel_rank() + + +def get_attention_cp_size() -> int: + return get_attn_context_model_parallel_world_size() def get_attention_dp_rank() -> int: @@ -564,6 +557,10 @@ def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): return get_attention_tp_group().reduce_scatter_tensor(output, input) +def attn_cp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_cp_group().reduce_scatter_tensor(output, input) + + def attn_tp_all_reduce(input: torch.Tensor): return get_attention_tp_group().all_reduce(input) @@ -572,5 +569,9 @@ def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor): return get_attention_tp_group().all_gather_into_tensor(output, input) +def attn_cp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor): + return get_attention_cp_group().all_gather_into_tensor(output, input) + + def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor): return get_attention_tp_group().all_gather(input, output_tensor_list=output_list) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index eea20137aaee..ac4869e68352 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -418,6 +418,8 @@ def launch_tensor_parallel_group( tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1), ) + attn_cp_rank = 0 + moe_dp_rank = 0 for pp_rank in pp_rank_range: for tp_rank in tp_rank_range: rank_port_args = port_args @@ -429,6 +431,7 @@ def launch_tensor_parallel_group( tp_rank, server_args.tp_size, server_args.dp_size, + server_args.attn_cp_size, ) # compute zmq ports for this dp rank rank_port_args = PortArgs.init_new( @@ -445,7 +448,30 @@ def launch_tensor_parallel_group( + ((pp_rank % pp_size_per_node) * tp_size_per_node) + (tp_rank % tp_size_per_node) * server_args.gpu_id_step ) - moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size) + attn_dp_size = ( + server_args.dp_size if server_args.enable_dp_attention else 1 + ) + + # Parallelism hierarchy (outermost to innermost): + # - Attention: Global(TP) -> DP -> ATTN_CP -> ATTN_TP (innermost) + # - MoE: Global(TP) -> MOE_DP -> EP -> MOE_TP (innermost) + attn_tp_size = ( + server_args.tp_size // attn_dp_size // server_args.attn_cp_size + ) + attn_cp_rank = (tp_rank // attn_tp_size) % server_args.attn_cp_size + moe_dp_rank = tp_rank // ( + server_args.tp_size // server_args.moe_dp_size + ) + moe_ep_rank = ( + tp_rank + % (server_args.tp_size // server_args.moe_dp_size) + // ( + server_args.tp_size + // server_args.moe_dp_size + // server_args.ep_size + ) + ) + with self.env_lock, maybe_reindex_device_id(gpu_id) as gpu_id: proc = mp.Process( target=self.run_scheduler_process_func, @@ -454,6 +480,8 @@ def launch_tensor_parallel_group( rank_port_args, gpu_id, tp_rank, + attn_cp_rank, + moe_dp_rank, moe_ep_rank, pp_rank, dp_rank, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8bf5245d8e30..44a1af776832 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -63,6 +63,7 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.dp_attention import ( compute_dp_attention_world_info, + get_attention_cp_group, get_attention_tp_group, ) from sglang.srt.layers.moe import initialize_moe_config @@ -267,6 +268,8 @@ def __init__( tp_rank: int, moe_ep_rank: int, pp_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, dp_rank: Optional[int], ): self.is_initializing = True @@ -277,6 +280,10 @@ def __init__( self.tp_rank = tp_rank self.moe_ep_rank = moe_ep_rank self.pp_rank = pp_rank + self.attn_cp_rank = attn_cp_rank + self.attn_cp_size = server_args.attn_cp_size + self.moe_dp_rank = moe_dp_rank + self.moe_dp_size = server_args.moe_dp_size self.dp_rank = dp_rank self.tp_size = server_args.tp_size self.moe_ep_size = server_args.ep_size @@ -322,6 +329,7 @@ def __init__( self.tp_rank, self.tp_size, self.dp_size, + self.attn_cp_size, ) ) @@ -405,7 +413,7 @@ def init_ipc_channels(self, port_args: PortArgs): context = zmq.Context(2) self.idle_sleeper = None - if self.pp_rank == 0 and self.attn_tp_rank == 0: + if self.pp_rank == 0 and self.attn_tp_rank == 0 and self.attn_cp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) @@ -506,6 +514,8 @@ def init_tp_model_worker(self): tp_rank=self.tp_rank, moe_ep_rank=self.moe_ep_rank, pp_rank=self.pp_rank, + attn_cp_rank=self.attn_cp_rank, + moe_dp_rank=self.moe_dp_rank, dp_rank=self.dp_rank, nccl_port=self.nccl_port, ) @@ -524,6 +534,8 @@ def maybe_init_draft_worker(self): nccl_port=self.nccl_port, target_worker=self.tp_worker, dp_rank=self.dp_rank, + attn_cp_rank=self.attn_cp_rank, + moe_dp_rank=self.moe_dp_rank, ) if self.server_args.speculative_draft_load_format is not None: @@ -571,6 +583,8 @@ def init_model_worker(self): self.tp_cpu_group = self.tp_group.cpu_group self.attn_tp_group = get_attention_tp_group() self.attn_tp_cpu_group = self.attn_tp_group.cpu_group + self.attn_cp_group = get_attention_cp_group() + self.attn_cp_cpu_group = self.attn_cp_group.cpu_group self.pp_group = get_pp_group() self.world_group = get_world_group() @@ -1201,7 +1215,7 @@ def recv_requests( return [] if self.pp_rank == 0: - if self.attn_tp_rank == 0: + if self.attn_tp_rank == 0 and self.attn_cp_rank == 0: recv_reqs = [] while True: @@ -1225,7 +1239,7 @@ def recv_requests( else: recv_reqs = None else: - if self.attn_tp_rank == 0: + if self.attn_tp_rank == 0 and self.attn_cp_rank == 0: dp_offset = self.attn_dp_rank * self.attn_tp_size recv_reqs = point_to_point_pyobj( [], @@ -1241,7 +1255,7 @@ def recv_requests( recv_reqs = self.input_blocker.handle(recv_reqs) if self.server_args.enable_dp_attention: - if self.attn_tp_rank == 0: + if self.attn_tp_rank == 0 and self.attn_cp_rank == 0: work_reqs, control_reqs = self._split_work_and_control_reqs(recv_reqs) else: work_reqs = None @@ -1254,6 +1268,15 @@ def recv_requests( self.attn_tp_cpu_group, src=self.attn_tp_group.ranks[0], ) + + if self.attn_cp_size != 1: + work_reqs = broadcast_pyobj( + work_reqs, + self.attn_cp_group.rank, + self.attn_cp_cpu_group, + src=self.attn_cp_group.ranks[0], + ) + if self.tp_size != 1: control_reqs = broadcast_pyobj( control_reqs, @@ -3003,6 +3026,8 @@ def run_scheduler_process( port_args: PortArgs, gpu_id: int, tp_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, moe_ep_rank: int, pp_rank: int, dp_rank: Optional[int], @@ -3017,6 +3042,10 @@ def run_scheduler_process( prefix += f" DP{dp_rank}" if server_args.pp_size > 1: prefix += f" PP{pp_rank}" + if server_args.attn_cp_size > 1: + prefix += f" ATTN_CP{attn_cp_rank}" + if server_args.moe_dp_size > 1: + prefix += f" MOE_DP{moe_dp_rank}" if server_args.tp_size > 1: prefix += f" TP{tp_rank}" if server_args.ep_size > 1: @@ -3061,6 +3090,8 @@ def run_scheduler_process( tp_rank, moe_ep_rank, pp_rank, + attn_cp_rank, + moe_dp_rank, dp_rank, ) result_dict = { diff --git a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py index 3a772d035f88..656b6e7c2417 100644 --- a/python/sglang/srt/managers/scheduler_dp_attn_mixin.py +++ b/python/sglang/srt/managers/scheduler_dp_attn_mixin.py @@ -25,6 +25,7 @@ class MLPSyncBatchInfo: dp_size: int tp_size: int + cp_size: int num_tokens: int num_tokens_for_logprob: int @@ -72,7 +73,7 @@ def _get_fallback_tensor(self, device, dtype=torch.int64) -> torch.Tensor: def all_gather(self, device, group: torch.distributed.ProcessGroup): local_info_tensor = self._get_local_tensor(device=device) global_info_tensor = torch.empty( - (self.dp_size, self.tp_size, 6), + (self.dp_size, self.tp_size * self.cp_size, 6), dtype=torch.int64, device=device, ) @@ -88,7 +89,7 @@ def all_gather(self, device, group: torch.distributed.ProcessGroup): tp_active_ranks = get_tp_group().active_ranks # Set fallback values for inactive ranks - tp_info = global_info_tensor.view(self.dp_size * self.tp_size, 6) + tp_info = global_info_tensor.view(self.dp_size * self.tp_size * self.cp_size, 6) tp_info[tp_active_ranks == 0] = self._get_fallback_tensor(device=device) tp0_info = global_info_tensor[:, 0, :] @@ -129,6 +130,7 @@ def prepare_mlp_sync_batch_raw( local_batch: ScheduleBatch, dp_size: int, attn_tp_size: int, + attn_cp_size: int, tp_group: GroupCoordinator, get_idle_batch: Callable[[], ScheduleBatch], disable_cuda_graph: bool, @@ -185,6 +187,7 @@ def prepare_mlp_sync_batch_raw( mlp_sync_info = MLPSyncBatchInfo( dp_size=dp_size, tp_size=attn_tp_size, + cp_size=attn_cp_size, num_tokens=num_tokens, num_tokens_for_logprob=num_tokens_for_logprob, can_cuda_graph=can_cuda_graph, @@ -226,6 +229,7 @@ def prepare_mlp_sync_batch(self: Scheduler, local_batch: ScheduleBatch): local_batch, dp_size=self.server_args.dp_size, attn_tp_size=self.attn_tp_size, + attn_cp_size=self.attn_cp_size, tp_group=self.tp_group, get_idle_batch=self.get_idle_batch, disable_cuda_graph=self.server_args.disable_cuda_graph, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 7a334f6282cd..9dcfc9f3d9f7 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -212,6 +212,8 @@ def __init__( tp_rank: int, moe_ep_rank: int, pp_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, dp_rank: Optional[int], nccl_port: int, is_draft_worker: bool = False, @@ -234,6 +236,8 @@ def __init__( self.is_multi_layer_eagle = is_multi_layer_eagle self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.attn_cp_rank = attn_cp_rank + self.moe_dp_rank = moe_dp_rank # MTP model runners self.model_runner_list: List[ModelRunner] = [] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b0b2ede6dbde..f1ea1de9d2e2 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -43,6 +43,7 @@ from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp from sglang.srt.layers.dp_attention import ( DpPaddingMode, + get_attention_cp_size, get_attention_tp_rank, get_attention_tp_size, set_dp_buffer_len, @@ -204,6 +205,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner, num_tokens_per_bs=1): if require_gathered_buffer(server_args): mul_base *= get_attention_tp_size() + if mul_base % get_attention_cp_size() != 0: + mul_base *= get_attention_cp_size() + # Model input token count = bs * num_tokens_per_bs; must be a multiple of attn_tp_size. capture_bs = [bs for bs in capture_bs if bs * num_tokens_per_bs % mul_base == 0] diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 8d8e445cccde..a321226e34f1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -45,6 +45,7 @@ from sglang.srt.layers.attention.nsa.utils import NSAContextParallelMetadata from sglang.srt.layers.dp_attention import ( DpPaddingMode, + get_attention_cp_size, get_attention_dp_rank, get_attention_tp_rank, get_attention_tp_size, @@ -749,6 +750,11 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner): # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_tp_size) + # make sure that each rank has the same number of tokens to do collective communication. + attn_cp_size = get_attention_cp_size() + for i in range(sync_group_size): + global_num_tokens[i] = ceil_align(global_num_tokens[i], attn_cp_size) + dp_padding_mode = DpPaddingMode.get_dp_padding_mode( self.is_extend_in_batch, global_num_tokens ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7a93225b4e0c..0d6d64986dba 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -290,6 +290,8 @@ def __init__( nccl_port: int, server_args: ServerArgs, dp_rank: Optional[int] = None, + attn_cp_rank: Optional[int] = None, + moe_dp_rank: Optional[int] = None, is_draft_worker: bool = False, req_to_token_pool: Optional[ReqToTokenPool] = None, token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None, @@ -303,9 +305,13 @@ def __init__( self.tp_size = tp_size self.moe_ep_rank = moe_ep_rank self.moe_ep_size = moe_ep_size - self.dp_size = server_args.dp_size + self.dp_size = server_args.dp_size if server_args.enable_dp_attention else 1 self.pp_rank = pp_rank self.pp_size = pp_size + self.attn_cp_rank = attn_cp_rank + self.attn_cp_size = server_args.attn_cp_size + self.moe_dp_rank = moe_dp_rank + self.moe_dp_size = server_args.moe_dp_size self.model_config = model_config self.dist_port = nccl_port self.server_args = server_args @@ -586,8 +592,7 @@ def initialize(self, min_per_gpu_memory: float): ( self.max_total_num_tokens // 2 if server_args.max_running_requests is None - else server_args.max_running_requests - // (server_args.dp_size if server_args.enable_dp_attention else 1) + else server_args.max_running_requests // (self.dp_size) ), self.req_to_token_pool.size, ) @@ -797,8 +802,11 @@ def _(data, dim): ) initialize_model_parallel( tensor_model_parallel_size=self.tp_size, + attention_data_parallel_size=self.dp_size, pipeline_model_parallel_size=self.pp_size, expert_model_parallel_size=self.moe_ep_size, + attention_context_model_parallel_size=self.attn_cp_size, + moe_data_model_parallel_size=self.moe_dp_size, duplicate_tp_group=self.server_args.enable_pdmux, ) initialize_dp_attention( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e0583d308337..4b5eeb9517cd 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -72,6 +72,8 @@ ) from sglang.srt.layers.communicator_nsa_cp import NSACPLayerCommunicator from sglang.srt.layers.dp_attention import ( + get_attention_cp_rank, + get_attention_cp_size, get_attention_tp_rank, get_attention_tp_size, is_dp_attention_enabled, @@ -1097,9 +1099,7 @@ def __init__( assert self.use_nsa, "CP currently only supports deepseek v3.2 model" # cp reuse the attn_tp comm group but need to duplicate the weights if self.nsa_enable_prefill_cp and self.use_nsa: - attn_tp_rank = 0 - attn_tp_size = 1 - self.cp_size = get_attention_tp_size() + self.cp_size = get_attention_cp_size() self.num_heads = num_heads assert num_heads % attn_tp_size == 0 self.num_local_heads = num_heads // attn_tp_size @@ -2512,7 +2512,7 @@ def __init__( self.pp_group = get_pp_group() self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() if self.nsa_enable_prefill_cp: - self.cp_size = get_attention_tp_size() + self.cp_size = get_attention_cp_size() else: self.cp_size = None @@ -2827,8 +2827,8 @@ def __init__( self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() if self.nsa_enable_prefill_cp: - self.cp_rank = get_attention_tp_rank() - self.cp_size = get_attention_tp_size() + self.cp_rank = get_attention_cp_rank() + self.cp_size = get_attention_cp_size() else: self.cp_rank = self.cp_size = None diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index cf95d3e1b01c..ebf312c170ba 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -417,6 +417,9 @@ class ServerArgs: dp_size: int = 1 load_balance_method: str = "auto" + attn_cp_size: int = 1 + moe_dp_size: int = 1 + # Multi-node distributed serving dist_init_addr: Optional[str] = None nnodes: int = 1 @@ -745,6 +748,9 @@ def __post_init__(self): # Handle data parallelism. self._handle_data_parallelism() + # Handle context parallelism. + self._handle_context_parallelism() + # Handle MoE configurations. self._handle_moe_kernel_config() self._handle_a2a_moe() @@ -2042,6 +2048,32 @@ def _handle_grammar_backend(self): if self.grammar_backend is None: self.grammar_backend = "xgrammar" + def _handle_context_parallelism(self): + if self.attn_cp_size > 1: + # The tp_size is the world size, not the real tensor parallel size + assert ( + self.tp_size % self.attn_cp_size == 0 + ), "tp_size must be divisible by attn_cp_size" + assert ( + self.tp_size % (self.dp_size * self.attn_cp_size) == 0 + ), "tp_size must be divisible by dp_size * attn_cp_size" + assert self.pp_size == 1, "PP is not supported with context parallelism" + + if self.moe_dp_size > 1: + # The tp_size is the world size, not the real tensor parallel size + assert ( + self.tp_size % self.moe_dp_size == 0 + ), "tp_size must be divisible by moe_dp_size" + assert ( + self.ep_size * self.moe_dp_size <= self.tp_size + ), "ep_size * moe_dp_size must be less than or equal to tp_size" + assert self.pp_size == 1, "PP is not supported with context parallelism" + + if self.ep_size > 1: + assert ( + self.ep_size * self.moe_dp_size == self.tp_size + ), "ep_size * moe_dp_size must be equal to tp_size" + def _handle_data_parallelism(self): if self.dp_size == 1: self.enable_dp_attention = False @@ -3241,6 +3273,20 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.tp_size, help="The tensor parallelism size.", ) + parser.add_argument( + "--attention-context-parallel-size", + "--attn-cp-size", + type=int, + default=ServerArgs.attn_cp_size, + help="The attention context parallelism size.", + ) + parser.add_argument( + "--moe-data-parallel-size", + "--moe-dp-size", + type=int, + default=ServerArgs.moe_dp_size, + help="The moe data parallelism size.", + ) parser.add_argument( "--pipeline-parallel-size", "--pp-size", @@ -4989,6 +5035,8 @@ def add_cli_args(parser: argparse.ArgumentParser): def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.pp_size = args.pipeline_parallel_size + args.attn_cp_size = args.attention_context_parallel_size + args.moe_dp_size = args.moe_data_parallel_size args.dp_size = args.data_parallel_size args.ep_size = args.expert_parallel_size diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 6d3a76f8fcf3..32b3a520a0db 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -84,6 +84,8 @@ def __init__( tp_rank: int, dp_rank: Optional[int], moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -144,6 +146,8 @@ def __init__( pp_rank=0, # FIXME dp_rank=dp_rank, moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, nccl_port=nccl_port, is_draft_worker=True, req_to_token_pool=self.req_to_token_pool, diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index a47c48bd06c1..f4affc9690a6 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -86,6 +86,8 @@ def __init__( tp_rank: int, dp_rank: int, moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -97,6 +99,8 @@ def __init__( self.moe_ep_rank = moe_ep_rank self.nccl_port = nccl_port self.target_worker = target_worker + self.attn_cp_rank = attn_cp_rank + self.moe_dp_rank = moe_dp_rank # Args for easy access self.device = server_args.device @@ -134,6 +138,8 @@ def __init__( pp_rank=0, # FIXME dp_rank=dp_rank, moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, nccl_port=nccl_port, is_draft_worker=True, req_to_token_pool=self.req_to_token_pool, @@ -582,6 +588,8 @@ def __init__( tp_rank: int, dp_rank: Optional[int], moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -608,7 +616,15 @@ def __init__( server_args.context_length = target_worker.model_runner.model_config.context_len self._draft_worker = EagleDraftWorker( - server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker + server_args, + gpu_id, + tp_rank, + dp_rank, + moe_ep_rank, + attn_cp_rank, + moe_dp_rank, + nccl_port, + target_worker, ) # Some dummy tensors diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 9369396a7d3a..dfd98b943300 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -76,6 +76,8 @@ def __init__( tp_rank: int, dp_rank: Optional[int], moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -135,6 +137,8 @@ def __init__( pp_rank=0, # FIXME dp_rank=dp_rank, moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, nccl_port=nccl_port, is_draft_worker=True, req_to_token_pool=self.req_to_token_pool, diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py index dbbfa5bba486..7f660b9f085c 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py @@ -70,6 +70,8 @@ def __init__( tp_rank: int, dp_rank: int, moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -117,6 +119,8 @@ def __init__( pp_rank=0, # FIXME dp_rank=dp_rank, moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, nccl_port=nccl_port, is_draft_worker=True, req_to_token_pool=self.req_to_token_pool, @@ -532,6 +536,8 @@ def __init__( tp_rank: int, dp_rank: Optional[int], moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -557,7 +563,15 @@ def __init__( server_args.context_length = target_worker.model_runner.model_config.context_len self._draft_worker = MultiLayerEagleDraftWorker( - server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker + server_args, + gpu_id, + tp_rank, + dp_rank, + moe_ep_rank, + attn_cp_rank, + moe_dp_rank, + nccl_port, + target_worker, ) # Some dummy tensors diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 0a830fd95dbc..7f6277bb8fe4 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -30,6 +30,8 @@ def __init__( tp_rank: int, dp_rank: Optional[int], moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): diff --git a/python/sglang/srt/speculative/standalone_worker.py b/python/sglang/srt/speculative/standalone_worker.py index e1f331975bde..4d7ca30e3e80 100644 --- a/python/sglang/srt/speculative/standalone_worker.py +++ b/python/sglang/srt/speculative/standalone_worker.py @@ -30,6 +30,8 @@ def __init__( tp_rank: int, dp_rank: Optional[int], moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -79,6 +81,8 @@ def __init__( pp_rank=0, # FIXME dp_rank=dp_rank, moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, nccl_port=nccl_port, is_draft_worker=True, req_to_token_pool=self.req_to_token_pool, diff --git a/python/sglang/srt/speculative/standalone_worker_v2.py b/python/sglang/srt/speculative/standalone_worker_v2.py index da6e3523d0f6..26ef88548ee7 100644 --- a/python/sglang/srt/speculative/standalone_worker_v2.py +++ b/python/sglang/srt/speculative/standalone_worker_v2.py @@ -42,6 +42,8 @@ def __init__( tp_rank: int, dp_rank: int, moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -53,6 +55,8 @@ def __init__( self.moe_ep_rank = moe_ep_rank self.nccl_port = nccl_port self.target_worker = target_worker + self.attn_cp_rank = attn_cp_rank + self.moe_dp_rank = moe_dp_rank # Args for easy access self.device = server_args.device @@ -89,6 +93,8 @@ def __init__( pp_rank=0, # FIXME dp_rank=dp_rank, moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, nccl_port=nccl_port, is_draft_worker=True, req_to_token_pool=self.req_to_token_pool, @@ -131,6 +137,8 @@ def __init__( tp_rank: int, dp_rank: Optional[int], moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, nccl_port: int, target_worker: TpModelWorker, ): @@ -157,7 +165,15 @@ def __init__( # Create our custom draft worker that doesn't share embeddings/lm_head self._draft_worker = StandaloneDraftWorker( - server_args, gpu_id, tp_rank, dp_rank, moe_ep_rank, nccl_port, target_worker + server_args, + gpu_id, + tp_rank, + dp_rank, + moe_ep_rank, + attn_cp_rank, + moe_dp_rank, + nccl_port, + target_worker, ) # Some dummy tensors diff --git a/test/registered/8-gpu-models/test_deepseek_v32_cp_single_node.py b/test/registered/8-gpu-models/test_deepseek_v32_cp_single_node.py index f06252d24760..7c31b028a8e6 100644 --- a/test/registered/8-gpu-models/test_deepseek_v32_cp_single_node.py +++ b/test/registered/8-gpu-models/test_deepseek_v32_cp_single_node.py @@ -18,6 +18,7 @@ DP_ARGS = [ "--tp=8", "--dp=2", + "--attn-cp-size=4", "--enable-dp-attention", ] @@ -43,6 +44,7 @@ CP_ROUND_ROBIN_ARGS = [ "--enable-nsa-prefill-context-parallel", "--nsa-prefill-cp-mode=round-robin-split", + "--attn-cp-size=8", ] diff --git a/test/registered/distributed/test_parallel_state.py b/test/registered/distributed/test_parallel_state.py new file mode 100644 index 000000000000..6bfbb1073341 --- /dev/null +++ b/test/registered/distributed/test_parallel_state.py @@ -0,0 +1,290 @@ +""" +Test file to verify the correctness of parallel group calculations. + +This test validates that the parallel group initialization creates the correct +groups for different parallelism configurations including: +- Tensor parallelism (TP) +- Pipeline parallelism (PP) +- Attention context parallelism (attn_cp) +- Attention data parallelism (attn_dp) +- MoE expert parallelism (EP) +- MoE data parallelism (moe_dp) + +These tests call the ACTUAL initialize_model_parallel() function with mocked +distributed backend to verify the group construction logic. + +## How These Tests Work + +initialize_model_parallel() creates ALL groups for ALL ranks in a single call. +For example, when creating TP groups with tp_size=2 and world_size=8: + + group_ranks = [[0,1], [2,3], [4,5], [6,7]] # ALL groups created + _TP = init_model_parallel_group(group_ranks, local_rank, ...) + +ALL ranks call this function and get the same complete group structure. Each rank +then figures out which specific group(s) it belongs to. + +Our tests: +1. Mock the distributed backend (no real GPUs needed) +2. Mock init_model_parallel_group to capture the group_ranks parameter +3. Call the real initialize_model_parallel() +4. Verify group_ranks contains the expected complete group structure + +We only need to simulate rank 0 because we're testing the group creation logic, +not the per-rank group membership logic. +""" + +from __future__ import annotations + +import sys +from unittest.mock import Mock, patch + +import pytest + +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=5, suite="stage-b-test-small-1-gpu") + +# Import the actual parallel_state module +parallel_state = pytest.importorskip("sglang.srt.distributed.parallel_state") + + +def test_parallel_group_construction_tp8_attn_cp2(): + """ + Test parallel group construction for 8 GPU configuration with: + - tensor_model_parallel_size = 8 + - attention_context_model_parallel_size = 2 + + Expected groups based on docstring example: + 1 tensor model-parallel group: + [g0, g1, g2, g3, g4, g5, g6, g7] + 4 attention context-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7] + + This test calls the ACTUAL initialize_model_parallel() and verifies the groups. + + Note: We simulate only rank 0 here, but initialize_model_parallel() creates + ALL groups for ALL ranks in a single call. We capture these groups via mocking + and verify the complete group structure. + """ + world_size = 8 + + # Mock the distributed backend + # Note: get_rank() returns 0 because we're testing from a single process, + # but initialize_model_parallel() still creates all groups for all ranks + with patch.object(parallel_state, "_WORLD", None), patch.object( + parallel_state, "_TP", None + ), patch.object(parallel_state, "_ATTN_CP", None), patch.object( + parallel_state, "_ATTN_TP", None + ), patch.object( + parallel_state, "_PP", None + ), patch( + "torch.distributed.is_initialized", return_value=True + ), patch( + "torch.distributed.get_world_size", return_value=world_size + ), patch( + "torch.distributed.get_rank", return_value=0 + ), patch( + "torch.distributed.get_backend", return_value="nccl" + ): + + # Mock init_model_parallel_group to capture the groups being created + created_groups = {} + + def mock_init_model_parallel_group(group_ranks, local_rank, backend, **kwargs): + group_name = kwargs.get("group_name", "unknown") + created_groups[group_name] = group_ranks + + # Create a mock group object + mock_group = Mock() + mock_group.device_group = Mock() + return mock_group + + with patch.object( + parallel_state, + "init_model_parallel_group", + side_effect=mock_init_model_parallel_group, + ), patch.object(parallel_state, "get_world_group") as mock_world_group: + + # Mock world group + mock_world = Mock() + mock_world.device_group = Mock() + mock_world.local_rank = 0 + mock_world_group.return_value = mock_world + + # Call the actual function + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=8, + pipeline_model_parallel_size=1, + attention_context_model_parallel_size=2, + ) + + # Verify TP groups + tp_groups = created_groups.get("tp", []) + assert len(tp_groups) == 1, f"Expected 1 TP group, got {len(tp_groups)}" + assert tp_groups[0] == [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ], f"Wrong TP group: {tp_groups[0]}" + + # Verify ATTN_CP groups + attn_cp_groups = created_groups.get("attn_cp", []) + assert ( + len(attn_cp_groups) == 4 + ), f"Expected 4 ATTN_CP groups, got {len(attn_cp_groups)}" + expected_attn_cp = [ + [0, 4], + [1, 5], + [2, 6], + [3, 7], + ] + assert ( + attn_cp_groups == expected_attn_cp + ), f"Wrong ATTN_CP groups: {attn_cp_groups}" + + print("TP=8, Attn CP=2 group construction verified") + + # Cleanup + parallel_state.destroy_model_parallel() + + +def test_parallel_group_construction_tp8_moe_ep4_cp2(): + """ + Test parallel group construction for 8 GPU configuration with: + - tensor_model_parallel_size = 8 + - expert_model_parallel_size = 4 + - moe_data_model_parallel_size = 2 + + Expected groups: + 1 tensor model-parallel group: + [g0, g1, g2, g3, g4, g5, g6, g7] + 2 MoE expert-parallel groups: + [g0, g1, g2, g3], [g4, g5, g6, g7] + 4 MoE data-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7] + """ + world_size = 8 + + # Mock the distributed backend + with patch.object(parallel_state, "_WORLD", None), patch.object( + parallel_state, "_TP", None + ), patch.object(parallel_state, "_MOE_EP", None), patch.object( + parallel_state, "_MOE_DP", None + ), patch.object( + parallel_state, "_MOE_TP", None + ), patch.object( + parallel_state, "_PP", None + ), patch( + "torch.distributed.is_initialized", return_value=True + ), patch( + "torch.distributed.get_world_size", return_value=world_size + ), patch( + "torch.distributed.get_rank", return_value=0 + ), patch( + "torch.distributed.get_backend", return_value="nccl" + ): + + # Mock init_model_parallel_group to capture the groups being created + created_groups = {} + + def mock_init_model_parallel_group(group_ranks, local_rank, backend, **kwargs): + group_name = kwargs.get("group_name", "unknown") + created_groups[group_name] = group_ranks + + # Create a mock group object + mock_group = Mock() + mock_group.device_group = Mock() + return mock_group + + with patch.object( + parallel_state, + "init_model_parallel_group", + side_effect=mock_init_model_parallel_group, + ), patch.object(parallel_state, "get_world_group") as mock_world_group: + + # Mock world group + mock_world = Mock() + mock_world.device_group = Mock() + mock_world.local_rank = 0 + mock_world_group.return_value = mock_world + + # Call the actual function + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=8, + expert_model_parallel_size=4, + pipeline_model_parallel_size=1, + moe_data_model_parallel_size=2, + ) + + # Verify TP groups + tp_groups = created_groups.get("tp", []) + assert len(tp_groups) == 1, f"Expected 1 TP group, got {len(tp_groups)}" + assert tp_groups[0] == [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + ], f"Wrong TP group: {tp_groups[0]}" + + # Verify MOE_EP groups + moe_ep_groups = created_groups.get("moe_ep", []) + assert ( + len(moe_ep_groups) == 2 + ), f"Expected 2 MOE_EP groups, got {len(moe_ep_groups)}" + expected_moe_ep = [ + [0, 1, 2, 3], + [4, 5, 6, 7], + ] + assert ( + moe_ep_groups == expected_moe_ep + ), f"Wrong MOE_EP groups: {moe_ep_groups}" + + # Verify MOE_DP groups + moe_dp_groups = created_groups.get("moe_dp", []) + assert ( + len(moe_dp_groups) == 4 + ), f"Expected 4 MOE_DP groups, got {len(moe_dp_groups)}" + expected_moe_dp = [ + [0, 4], + [1, 5], + [2, 6], + [3, 7], + ] + assert ( + moe_dp_groups == expected_moe_dp + ), f"Wrong MOE_DP groups: {moe_dp_groups}" + + print("TP=8, MoE EP=4, MoE CP=2 group construction verified") + + # Cleanup + parallel_state.destroy_model_parallel() + + +if __name__ == "__main__": + # Run tests without requiring GPUs + import sys + + try: + test_parallel_group_construction_tp8_attn_cp2() + test_parallel_group_construction_tp8_moe_ep4_cp2() + + sys.exit(0) + except AssertionError as e: + print(f"\n Test failed: {e}") + sys.exit(1) + except Exception as e: + print(f"\n Unexpected error: {e}") + import traceback + + traceback.print_exc() + sys.exit(1)