Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 57 additions & 33 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

Expand All @@ -23,6 +24,7 @@
import torch_npu
from sglang.srt.hardware_backend.npu.utils import get_indexer_weight_stream

from sglang.srt.distributed.parallel_state import get_pp_group
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.attention.nsa.utils import (
cp_all_gather_rerange_output,
Expand Down Expand Up @@ -146,6 +148,10 @@ def __init__(
if is_cuda():
self.sm_count = deep_gemm.get_num_sms()
self.half_device_sm_count = ceil_align(self.sm_count // 2, 8)
pp_size = get_global_server_args().pp_size
self.logits_with_pp_recv = pp_size > 1 and not get_pp_group().is_last_rank
else:
self.logits_with_pp_recv = False

self.wq_b = ReplicatedLinear(
self.q_lora_rank,
Expand Down Expand Up @@ -184,6 +190,21 @@ def __init__(
self.scale_fmt = scale_fmt
self.softmax_scale = self.head_dim**-0.5

@contextlib.contextmanager
def _with_real_sm_count(self):
    """Temporarily cap DeepGEMM's SM budget while a PP recv may be in flight.

    When pipeline parallelism is enabled, each non-last PP rank posts a recv
    right after the _pp_launch_batch request to receive the PP proxy tensors
    (or output) from the previous stage; that recv occupies one SM while
    model execution runs concurrently. The indexer's DeepGEMM logits kernels
    must therefore run with one fewer SM. The last rank only starts its
    send-result + recv-request after waiting for execution results, so it
    (and the non-PP / non-CUDA case, where ``logits_with_pp_recv`` is False)
    keeps the full SM count.

    Yields:
        None. On exit, DeepGEMM's configured SM count is restored by the
        underlying ``configure_deep_gemm_num_sms`` context manager.
    """
    if self.logits_with_pp_recv:
        # Reserve one SM for the concurrently running PP recv operation.
        pp_recv_sm_count = 1
        cm = deep_gemm_wrapper.configure_deep_gemm_num_sms(
            self.sm_count - pp_recv_sm_count
        )
    else:
        # No concurrent recv: leave DeepGEMM's SM configuration untouched.
        cm = contextlib.nullcontext()
    # Single yield point instead of duplicating `yield` in both branches.
    with cm:
        yield

@torch.compile(dynamic=True)
def _get_logits_head_gate(self, x: torch.Tensor, q_scale: torch.Tensor):
weights, _ = self.weights_proj(x.float())
Expand Down Expand Up @@ -333,7 +354,6 @@ def _get_topk_paged(
)
assert len(weights.shape) == 3
weights = weights.squeeze(2)

logits = deep_gemm.fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
Expand Down Expand Up @@ -432,14 +452,15 @@ def _get_topk_ragged(

if not need_chunk:
assert q_fp8[:q_offset].shape[0] != 0
logits = deep_gemm.fp8_mqa_logits(
q_fp8[:q_offset],
kv_fp8,
weights[:q_offset],
ks,
ke,
clean_logits=False,
)
with self._with_real_sm_count():
logits = deep_gemm.fp8_mqa_logits(
q_fp8[:q_offset],
kv_fp8,
weights[:q_offset],
ks,
ke,
clean_logits=False,
)
assert logits.shape[0] == len(seq_lens_expanded)
assert logits.shape[1] == k_offset

Expand Down Expand Up @@ -468,14 +489,15 @@ def _get_topk_ragged(
while start < q_offset:
end = min(start + max_rows, q_offset)

logits_chunk = deep_gemm.fp8_mqa_logits(
q_fp8[start:end],
kv_fp8,
weights[start:end],
ks[start:end],
ke[start:end],
clean_logits=False,
)
with self._with_real_sm_count():
logits_chunk = deep_gemm.fp8_mqa_logits(
q_fp8[start:end],
kv_fp8,
weights[start:end],
ks[start:end],
ke[start:end],
clean_logits=False,
)

lengths_chunk = seq_lens_expanded[start:end]

Expand Down Expand Up @@ -630,14 +652,15 @@ def _get_topk_ragged_with_cp(
ke_offset = torch.cat(ke_offset_list, dim=0)
ke = ks + ke_offset
actual_seq_q = torch.cat(actual_seq_q_list, dim=0)
logits = deep_gemm.fp8_mqa_logits(
q_fp8,
kv_fp8,
weights,
ks,
ke,
clean_logits=False,
)
with self._with_real_sm_count():
logits = deep_gemm.fp8_mqa_logits(
q_fp8,
kv_fp8,
weights,
ks,
ke,
clean_logits=False,
)
topk_result = metadata.topk_transform(
logits,
self.index_topk,
Expand Down Expand Up @@ -675,14 +698,15 @@ def _get_topk_ragged_with_cp(
)
ke = ks + ke_offset

logits = deep_gemm.fp8_mqa_logits(
q_fp8,
kv_fp8,
weights,
ks,
ke,
clean_logits=False,
)
with self._with_real_sm_count():
logits = deep_gemm.fp8_mqa_logits(
q_fp8,
kv_fp8,
weights,
ks,
ke,
clean_logits=False,
)
actual_seq_q = torch.tensor([actual_seq_q], dtype=torch.int32).to(
device="cuda", non_blocking=True
)
Expand Down
18 changes: 15 additions & 3 deletions python/sglang/srt/managers/scheduler_pp_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,10 @@ def event_loop_pp_disagg_decode(self: Scheduler):

def init_pp_loop_state(self: Scheduler):
self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth
# In CP mode, attention weights are duplicated, eliminating the need for the attention TP all-gather operation.
self.require_attn_tp_allgather = (
not self.server_args.enable_nsa_prefill_context_parallel
)
self.mbs = [None] * self.pp_loop_size
self.last_mbs = [None] * self.pp_loop_size
self.running_mbs = [
Expand Down Expand Up @@ -906,7 +910,9 @@ def _pp_send_dict_to_next_stage(
p2p_work.extend(
self.pp_group.send_tensor_dict(
tensor_dict=tensor_dict,
all_gather_group=self.attn_tp_group,
all_gather_group=(
self.attn_tp_group if self.require_attn_tp_allgather else None
),
async_send=async_send,
)
)
Expand All @@ -916,15 +922,21 @@ def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]:
pp_proxy_tensors = None
if not self.pp_group.is_first_rank:
pp_proxy_tensors = PPProxyTensors(
self.pp_group.recv_tensor_dict(all_gather_group=self.attn_tp_group)
self.pp_group.recv_tensor_dict(
all_gather_group=(
self.attn_tp_group if self.require_attn_tp_allgather else None
)
)
)
return pp_proxy_tensors

def _pp_recv_dict_from_prev_stage(
self: Scheduler,
) -> Dict[str, torch.Tensor]:
res = self.pp_group.recv_tensor_dict(
all_gather_group=self.attn_tp_group,
all_gather_group=(
self.attn_tp_group if self.require_attn_tp_allgather else None
),
)
return res

Expand Down
Loading