Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ def _get_topk_ragged(
k_scale_list = []
ks_list = []
ke_list = []
# Token-to-batch mapping for PAGED chunk alignment
token_to_batch_idx: List[int] = []

q_offset = 0
k_offset = 0
Expand Down Expand Up @@ -401,6 +403,7 @@ def _get_topk_ragged(
ks_list.append(ks)
ke_list.append(ke)

token_to_batch_idx.extend([i] * extend_seq_len)
q_offset += extend_seq_len
k_offset += seq_len

Expand Down Expand Up @@ -473,6 +476,13 @@ def _get_topk_ragged(
(token_nums, self.index_topk), -1, device=device, dtype=torch.int32
)

# Only materialize batch index tensor when PAGED path needs it
token_to_batch_idx_tensor = None
if global_topk_offset is None:
token_to_batch_idx_tensor = torch.tensor(
token_to_batch_idx, dtype=torch.long, device=device
)

start = 0
while start < q_offset:
end = min(start + max_rows, q_offset)
Expand All @@ -488,17 +498,28 @@ def _get_topk_ragged(

lengths_chunk = seq_lens_expanded[start:end]

topk_offset_chunk = (
global_topk_offset[start:end]
if global_topk_offset is not None
else None
)
# RAGGED: use global offset; PAGED: construct local cu_seqlens_q per chunk
if global_topk_offset is not None:
# RAGGED path
topk_offset_chunk = global_topk_offset[start:end]
cu_seqlens_q_chunk = None
batch_idx_chunk = None
else:
# PAGED path: treat each token as a length-1 sequence
topk_offset_chunk = None
B_chunk = logits_chunk.shape[0]
cu_seqlens_q_chunk = torch.ones(
B_chunk, dtype=torch.int32, device=device
)
batch_idx_chunk = token_to_batch_idx_tensor[start:end]

raw_topk_chunk = metadata.topk_transform(
logits_chunk,
self.index_topk,
ks=ks[start:end],
cu_seqlens_q=cu_seqlens_q_chunk,
ke_offset=lengths_chunk,
batch_idx_list=batch_idx_chunk,
topk_indices_offset_override=topk_offset_chunk,
)
topk_result[start:end] = raw_topk_chunk
Expand Down
58 changes: 0 additions & 58 deletions test/nightly/test_deepseek_v32_nsabackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,64 +197,6 @@ def test_a_gsm8k(
self.assertGreater(metrics["accuracy"], 0.935)


@unittest.skip("Temporary skip pure TP test")
class TestDeepseekV32NasBackend_pure_tp(CustomTestCase):
"""Test DeepSeek V3.2 with pure TP mode (no DP attention)."""

@classmethod
def setUpClass(cls):
cls.model = DEEPSEEK_V32_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
# Pure TP configuration without --dp and --enable-dp-attention
other_args = [
"--trust-remote-code",
"--attention-backend",
"nsa",
"--nsa-prefill-backend",
"flashmla_sparse",
"--nsa-decode-backend",
"flashmla_kv",
"--tp",
"8",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_a_gsm8k(self):
"""Test GSM8K accuracy with pure TP mode."""
args = SimpleNamespace(
num_shots=20,
data_path=None,
num_questions=1400,
parallel=1400,
max_new_tokens=512,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"{metrics=}")

if is_in_ci():
TEST_RESULTS.append(
{
"variant": "pure_tp",
"prefill_backend": "flashmla_sparse",
"decode_backend": "flashmla_kv",
"kv_cache": "fp16",
"accuracy": metrics["accuracy"],
}
)
self.assertGreater(metrics["accuracy"], 0.935)


def _write_summary_table():
"""Write a markdown table with all test results."""
if not TEST_RESULTS:
Expand Down
169 changes: 169 additions & 0 deletions test/nightly/test_deepseek_v32_tp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
write_github_step_summary,
)

# Register this module in the nightly 8xH200 CUDA CI suite
# (est_time presumably seconds — confirm against ci_register).
register_cuda_ci(est_time=600, suite="nightly-8-gpu-h200", nightly=True)

DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2-Exp"

# Global list to collect results; each entry is a dict of
# variant/prefill_backend/decode_backend/kv_cache/accuracy.
TEST_RESULTS: list = []


class TestDeepseekV32_TP(CustomTestCase):
    """GSM8K accuracy check for DeepSeek V3.2 with the NSA backend in pure TP mode.

    Launches a server with tp=8 and no data parallelism, runs a 20-shot GSM8K
    eval, records the result for the CI summary table, and asserts accuracy.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEEPSEEK_V32_MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Pure TP: intentionally no --dp / --enable-dp-attention flags.
        launch_args = [
            "--trust-remote-code",
            "--attention-backend",
            "nsa",
            "--nsa-prefill-backend",
            "flashmla_sparse",
            "--nsa-decode-backend",
            "flashmla_kv",
            "--tp",
            "8",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    # Named test_a_* so it sorts first alphabetically and doubles as server warm-up.
    def test_a_gsm8k(self):
        eval_port = int(self.base_url.split(":")[-1])
        args = SimpleNamespace(
            num_shots=20,
            data_path=None,
            num_questions=1400,
            parallel=1400,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=eval_port,
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        if is_in_ci():
            TEST_RESULTS.append(
                dict(
                    variant="pure_tp",
                    prefill_backend="flashmla_sparse",
                    decode_backend="flashmla_kv",
                    kv_cache="fp16",
                    accuracy=metrics["accuracy"],
                )
            )
        self.assertGreater(metrics["accuracy"], 0.935)


class TestDeepseekV32_Partial_TP(CustomTestCase):
    """GSM8K accuracy check for DeepSeek V3.2 (NSA backend) with DP attention.

    Launches a server with tp=8, dp=4 and --enable-dp-attention, runs a 20-shot
    GSM8K eval, records the result for the CI summary table, and asserts accuracy.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEEPSEEK_V32_MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Partial TP configuration with dp=4 and dp-attention enabled
        other_args = [
            "--trust-remote-code",
            "--attention-backend",
            "nsa",
            "--nsa-prefill-backend",
            "flashmla_sparse",
            "--nsa-decode-backend",
            "flashmla_kv",
            "--tp",
            "8",
            "--dp",
            "4",
            "--enable-dp-attention",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )
        # BUGFIX: unittest discovers classes in alphabetical order, so this
        # class runs BEFORE TestDeepseekV32_TP ("Partial_TP" < "TP"). Calling
        # _write_summary_table() from inside this class's test wrote the table
        # before the pure_tp result existed and silently dropped that row.
        # Registering a module cleanup defers the write until after every test
        # in this module has finished (requires the unittest runner, Py>=3.8).
        unittest.addModuleCleanup(_write_summary_table)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_a_gsm8k(
        self,
    ):  # Append an "a" to make this test run first (alphabetically) to warm up the server
        args = SimpleNamespace(
            num_shots=20,
            data_path=None,
            num_questions=1400,
            parallel=1400,
            max_new_tokens=512,
            host="http://127.0.0.1",
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval_few_shot_gsm8k(args)
        print(f"{metrics=}")

        if is_in_ci():
            TEST_RESULTS.append(
                {
                    "variant": "partial_tp",
                    "prefill_backend": "flashmla_sparse",
                    "decode_backend": "flashmla_kv",
                    "kv_cache": "fp16",
                    "accuracy": metrics["accuracy"],
                }
            )
        self.assertGreater(metrics["accuracy"], 0.935)


def _write_summary_table():
    """Emit a markdown accuracy table for all collected results to the CI step summary."""
    if not TEST_RESULTS:
        # Nothing recorded (e.g. not running in CI) — skip the summary entirely.
        return

    gpu_config = os.getenv("GPU_CONFIG", "8-gpu-h200")

    # Assemble the table as a list of lines and join once at the end.
    parts = [
        f"### {DEEPSEEK_V32_MODEL_PATH} GSM8K Accuracy (TP Tests) [{gpu_config}]\n",
        "\n",
        "| Variant | Prefill Backend | Decode Backend | KV Cache | Accuracy |\n",
        "|---------|-----------------|----------------|----------|----------|\n",
    ]
    parts.extend(
        f"| {r['variant']} | {r['prefill_backend']} | "
        f"{r['decode_backend']} | {r['kv_cache']} | "
        f"{r['accuracy']:.3f} |\n"
        for r in TEST_RESULTS
    )

    write_github_step_summary("".join(parts))


# Allow running this test module directly: `python test_deepseek_v32_tp.py`.
if __name__ == "__main__":
    unittest.main()
Loading