diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index 181b29c8b32e..307f1979aba3 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -370,6 +370,8 @@ def _get_topk_ragged( k_scale_list = [] ks_list = [] ke_list = [] + # Token-to-batch mapping for PAGED chunk alignment + token_to_batch_idx: List[int] = [] q_offset = 0 k_offset = 0 @@ -401,6 +403,7 @@ def _get_topk_ragged( ks_list.append(ks) ke_list.append(ke) + token_to_batch_idx.extend([i] * extend_seq_len) q_offset += extend_seq_len k_offset += seq_len @@ -473,6 +476,13 @@ def _get_topk_ragged( (token_nums, self.index_topk), -1, device=device, dtype=torch.int32 ) + # Only materialize batch index tensor when PAGED path needs it + token_to_batch_idx_tensor = None + if global_topk_offset is None: + token_to_batch_idx_tensor = torch.tensor( + token_to_batch_idx, dtype=torch.long, device=device + ) + start = 0 while start < q_offset: end = min(start + max_rows, q_offset) @@ -488,17 +498,28 @@ def _get_topk_ragged( lengths_chunk = seq_lens_expanded[start:end] - topk_offset_chunk = ( - global_topk_offset[start:end] - if global_topk_offset is not None - else None - ) + # RAGGED: use global offset; PAGED: construct local cu_seqlens_q per chunk + if global_topk_offset is not None: + # RAGGED path + topk_offset_chunk = global_topk_offset[start:end] + cu_seqlens_q_chunk = None + batch_idx_chunk = None + else: + # PAGED path: treat each token as a length-1 sequence + topk_offset_chunk = None + B_chunk = logits_chunk.shape[0] + cu_seqlens_q_chunk = torch.ones( + B_chunk, dtype=torch.int32, device=device + ) + batch_idx_chunk = token_to_batch_idx_tensor[start:end] raw_topk_chunk = metadata.topk_transform( logits_chunk, self.index_topk, ks=ks[start:end], + cu_seqlens_q=cu_seqlens_q_chunk, ke_offset=lengths_chunk, + batch_idx_list=batch_idx_chunk, topk_indices_offset_override=topk_offset_chunk, ) topk_result[start:end] = raw_topk_chunk diff --git a/test/nightly/test_deepseek_v32_nsabackend.py b/test/nightly/test_deepseek_v32_nsabackend.py index 22916dac9ce9..45e2d665aadf 100644 --- a/test/nightly/test_deepseek_v32_nsabackend.py +++ b/test/nightly/test_deepseek_v32_nsabackend.py @@ -197,64 +197,6 @@ def test_a_gsm8k( self.assertGreater(metrics["accuracy"], 0.935) -@unittest.skip("Temporary skip pure TP test") -class TestDeepseekV32NasBackend_pure_tp(CustomTestCase): - """Test DeepSeek V3.2 with pure TP mode (no DP attention).""" - - @classmethod - def setUpClass(cls): - cls.model = DEEPSEEK_V32_MODEL_PATH - cls.base_url = DEFAULT_URL_FOR_TEST - # Pure TP configuration without --dp and --enable-dp-attention - other_args = [ - "--trust-remote-code", - "--attention-backend", - "nsa", - "--nsa-prefill-backend", - "flashmla_sparse", - "--nsa-decode-backend", - "flashmla_kv", - "--tp", - "8", - ] - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - - def test_a_gsm8k(self): - """Test GSM8K accuracy with pure TP mode.""" - args = SimpleNamespace( - num_shots=20, - data_path=None, - num_questions=1400, - parallel=1400, - max_new_tokens=512, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(f"{metrics=}") - - if is_in_ci(): - TEST_RESULTS.append( - { - "variant": "pure_tp", - "prefill_backend": "flashmla_sparse", - "decode_backend": "flashmla_kv", - "kv_cache": "fp16", - "accuracy": metrics["accuracy"], - } - ) - self.assertGreater(metrics["accuracy"], 0.935) - - def _write_summary_table(): """Write a markdown table with all test results.""" if not TEST_RESULTS: diff --git a/test/nightly/test_deepseek_v32_tp.py b/test/nightly/test_deepseek_v32_tp.py new file mode 100644 index 000000000000..3193046fd385 --- /dev/null +++ b/test/nightly/test_deepseek_v32_tp.py @@ -0,0 +1,169 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + write_github_step_summary, +) + +register_cuda_ci(est_time=600, suite="nightly-8-gpu-h200", nightly=True) + +DEEPSEEK_V32_MODEL_PATH = "deepseek-ai/DeepSeek-V3.2-Exp" + +# Global list to collect results +TEST_RESULTS = [] + + +class TestDeepseekV32_TP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEEPSEEK_V32_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + # Pure TP configuration without --dp and --enable-dp-attention + other_args = [ + "--trust-remote-code", + "--attention-backend", + "nsa", + "--nsa-prefill-backend", + "flashmla_sparse", + "--nsa-decode-backend", + "flashmla_kv", + "--tp", + "8", + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_a_gsm8k( + self, + ): # Append an "a" to make this test run first (alphabetically) to warm up the server + args = SimpleNamespace( + num_shots=20, + data_path=None, + num_questions=1400, + parallel=1400, + max_new_tokens=512, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + if is_in_ci(): + TEST_RESULTS.append( + { + "variant": "pure_tp", + "prefill_backend": "flashmla_sparse", + "decode_backend": "flashmla_kv", + "kv_cache": "fp16", + "accuracy": metrics["accuracy"], + } + ) + self.assertGreater(metrics["accuracy"], 0.935) + + +class TestDeepseekV32_Partial_TP(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEEPSEEK_V32_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + # Partial TP configuration with dp=4 and dp-attention enabled + other_args = [ + "--trust-remote-code", + "--attention-backend", + "nsa", + "--nsa-prefill-backend", + "flashmla_sparse", + "--nsa-decode-backend", + "flashmla_kv", + "--tp", + "8", + "--dp", + "4", + "--enable-dp-attention", + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_a_gsm8k( + self, + ): # Append an "a" to make this test run first (alphabetically) to warm up the server + args = SimpleNamespace( + num_shots=20, + data_path=None, + num_questions=1400, + parallel=1400, + max_new_tokens=512, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(f"{metrics=}") + + if is_in_ci(): + TEST_RESULTS.append( + { + "variant": "partial_tp", + "prefill_backend": "flashmla_sparse", + "decode_backend": "flashmla_kv", + "kv_cache": "fp16", + "accuracy": metrics["accuracy"], + } + ) + + # Write the summary table after all tests complete + _write_summary_table() + self.assertGreater(metrics["accuracy"], 0.935) + + +def _write_summary_table(): + """Write a markdown table with all test results.""" + if not TEST_RESULTS: + return + + gpu_config = os.getenv("GPU_CONFIG", "8-gpu-h200") + + # Build table header + summary = ( + f"### {DEEPSEEK_V32_MODEL_PATH} GSM8K Accuracy (TP Tests) [{gpu_config}]\n\n" + ) + summary += "| Variant | Prefill Backend | Decode Backend | KV Cache | Accuracy |\n" + summary += "|---------|-----------------|----------------|----------|----------|\n" + + # Add each result as a row + for result in TEST_RESULTS: + summary += ( + f"| {result['variant']} | {result['prefill_backend']} | " + f"{result['decode_backend']} | {result['kv_cache']} | " + f"{result['accuracy']:.3f} |\n" + ) + + write_github_step_summary(summary) + + +if __name__ == "__main__": + unittest.main()