diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 7d19d0f90ba4..d4ce2271c7ca 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple @@ -23,6 +24,7 @@ import torch_npu from sglang.srt.hardware_backend.npu.utils import get_indexer_weight_stream +from sglang.srt.distributed.parallel_state import get_pp_group from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.nsa.utils import ( cp_all_gather_rerange_output, @@ -146,6 +148,10 @@ def __init__( if is_cuda(): self.sm_count = deep_gemm.get_num_sms() self.half_device_sm_count = ceil_align(self.sm_count // 2, 8) + pp_size = get_global_server_args().pp_size + self.logits_with_pp_recv = pp_size > 1 and not get_pp_group().is_last_rank + else: + self.logits_with_pp_recv = False self.wq_b = ReplicatedLinear( self.q_lora_rank, @@ -184,6 +190,21 @@ def __init__( self.scale_fmt = scale_fmt self.softmax_scale = self.head_dim**-0.5 + @contextlib.contextmanager + def _with_real_sm_count(self): + # When pipeline parallelism is enabled, each PP rank initiates a recv operation after the _pp_launch_batch + # request to receive the PP proxy tensor or output from the previous stage, occupying one SM resource. + # Model execution runs in parallel with the recv operation, so the SMs available to the indexer must be reduced + # by 1. Currently, the last rank starts the send result + recv request only after waiting for execution results. + if self.logits_with_pp_recv: + pp_recv_sm_count = 1 + with deep_gemm_wrapper.configure_deep_gemm_num_sms( + self.sm_count - pp_recv_sm_count + ): + yield + else: + yield + @torch.compile(dynamic=True) def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor): weights, _ = self.weights_proj(x.float()) @@ -333,7 +354,6 @@ def _get_topk_paged( ) assert len(weights.shape) == 3 weights = weights.squeeze(2) - logits = deep_gemm.fp8_paged_mqa_logits( q_fp8, kv_cache_fp8, @@ -432,14 +452,15 @@ def _get_topk_ragged( if not need_chunk: assert q_fp8[:q_offset].shape[0] != 0 - logits = deep_gemm.fp8_mqa_logits( - q_fp8[:q_offset], - kv_fp8, - weights[:q_offset], - ks, - ke, - clean_logits=False, - ) + with self._with_real_sm_count(): + logits = deep_gemm.fp8_mqa_logits( + q_fp8[:q_offset], + kv_fp8, + weights[:q_offset], + ks, + ke, + clean_logits=False, + ) assert logits.shape[0] == len(seq_lens_expanded) assert logits.shape[1] == k_offset @@ -468,14 +489,15 @@ def _get_topk_ragged( while start < q_offset: end = min(start + max_rows, q_offset) - logits_chunk = deep_gemm.fp8_mqa_logits( - q_fp8[start:end], - kv_fp8, - weights[start:end], - ks[start:end], - ke[start:end], - clean_logits=False, - ) + with self._with_real_sm_count(): + logits_chunk = deep_gemm.fp8_mqa_logits( + q_fp8[start:end], + kv_fp8, + weights[start:end], + ks[start:end], + ke[start:end], + clean_logits=False, + ) lengths_chunk = seq_lens_expanded[start:end] @@ -630,14 +652,15 @@ def _get_topk_ragged_with_cp( ke_offset = torch.cat(ke_offset_list, dim=0) ke = ks + ke_offset actual_seq_q = torch.cat(actual_seq_q_list, dim=0) - logits = deep_gemm.fp8_mqa_logits( - q_fp8, - kv_fp8, - weights, - ks, - ke, - clean_logits=False, - ) + with self._with_real_sm_count(): + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + kv_fp8, + weights, + ks, + ke, + clean_logits=False, + ) topk_result = metadata.topk_transform( logits, self.index_topk, @@ -675,14 +698,15 @@ def _get_topk_ragged_with_cp( ) ke = ks + ke_offset - logits = deep_gemm.fp8_mqa_logits( - q_fp8, - kv_fp8, - weights, - ks, - ke, - clean_logits=False, - ) + with self._with_real_sm_count(): + logits = deep_gemm.fp8_mqa_logits( + q_fp8, + kv_fp8, + weights, + ks, + ke, + clean_logits=False, + ) actual_seq_q = torch.tensor([actual_seq_q], dtype=torch.int32).to( device="cuda", non_blocking=True ) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index a1a58fe38435..6a3069868abe 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -504,6 +504,10 @@ def event_loop_pp_disagg_decode(self: Scheduler): def init_pp_loop_state(self: Scheduler): self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth + # In CP mode, attention weights are duplicated, eliminating the need for the attention TP all-gather operation. + self.require_attn_tp_allgather = ( + not self.server_args.enable_nsa_prefill_context_parallel + ) self.mbs = [None] * self.pp_loop_size self.last_mbs = [None] * self.pp_loop_size self.running_mbs = [ @@ -906,7 +910,9 @@ def _pp_send_dict_to_next_stage( p2p_work.extend( self.pp_group.send_tensor_dict( tensor_dict=tensor_dict, - all_gather_group=self.attn_tp_group, + all_gather_group=( + self.attn_tp_group if self.require_attn_tp_allgather else None + ), async_send=async_send, ) ) @@ -916,7 +922,11 @@ def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]: pp_proxy_tensors = None if not self.pp_group.is_first_rank: pp_proxy_tensors = PPProxyTensors( - self.pp_group.recv_tensor_dict(all_gather_group=self.attn_tp_group) + self.pp_group.recv_tensor_dict( + all_gather_group=( + self.attn_tp_group if self.require_attn_tp_allgather else None + ) + ) ) return pp_proxy_tensors @@ -924,7 +934,9 @@ def _pp_recv_dict_from_prev_stage( self: Scheduler, ) -> Dict[str, torch.Tensor]: res = self.pp_group.recv_tensor_dict( - all_gather_group=self.attn_tp_group, + all_gather_group=( + self.attn_tp_group if self.require_attn_tp_allgather else None + ), ) return res