diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py index 8419efcb8b72..865f3de43849 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py @@ -287,7 +287,7 @@ def create_moe_runner( ): self.moe_runner_config = moe_runner_config - def apply( + def apply_weights( self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput, diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index 00c26196b5cb..a85bbb440f55 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -268,6 +268,34 @@ def __init__( ) + +class KimiK2Detector(BaseReasoningFormatDetector): + """ + Detector for Kimi K2 models. + Assumes reasoning format: + (◁think▷)*(.*)◁/think▷ + + Kimi K2 can switch from reasoning to tool-call section with + `<|tool_calls_section_begin|>` before emitting `◁/think▷`. + """ + + def __init__( + self, + stream_reasoning: bool = True, + force_reasoning: bool = False, + continue_final_message: bool = False, + previous_content: str = "", + ): + super().__init__( + "◁think▷", + "◁/think▷", + force_reasoning=force_reasoning, + stream_reasoning=stream_reasoning, + tool_start_token="<|tool_calls_section_begin|>", + continue_final_message=continue_final_message, + previous_content=previous_content, + ) + + class Glm45Detector(BaseReasoningFormatDetector): """ Detector for GLM-4.5 models. 
@@ -431,7 +459,7 @@ class ReasoningParser: "glm45": Glm45Detector, "gpt-oss": GptOssDetector, "kimi": KimiDetector, - "kimi_k2": Qwen3Detector, + "kimi_k2": KimiK2Detector, "qwen3": Qwen3Detector, "qwen3-thinking": Qwen3Detector, "minimax": Qwen3Detector, diff --git a/test/registered/parser/test_reasoning_parser.py b/test/registered/parser/test_reasoning_parser.py index b76f853423ce..c6ae17aee123 100644 --- a/test/registered/parser/test_reasoning_parser.py +++ b/test/registered/parser/test_reasoning_parser.py @@ -5,6 +5,7 @@ DeepSeekR1Detector, Glm45Detector, KimiDetector, + KimiK2Detector, Qwen3Detector, ReasoningParser, StreamingParseResult, @@ -314,6 +315,51 @@ def test_streaming_kimi_format(self): self.assertEqual(result.normal_text, "answer") + +class TestKimiK2Detector(CustomTestCase): + """Test cases for KimiK2 detector with tool interruption support.""" + + def setUp(self): + self.detector = KimiK2Detector() + + def test_init(self): + """Test KimiK2Detector initialization.""" + self.assertEqual(self.detector.think_start_token, "◁think▷") + self.assertEqual(self.detector.think_end_token, "◁/think▷") + self.assertEqual(self.detector.tool_start_token, "<|tool_calls_section_begin|>") + self.assertFalse(self.detector._in_reasoning) + self.assertTrue(self.detector.stream_reasoning) + + def test_detect_and_parse_tool_interrupt(self): + """Test parsing with Kimi-K2 tool-section interruption.""" + text = "◁think▷thinking<|tool_calls_section_begin|><|tool_call_begin|>" + result = self.detector.detect_and_parse(text) + self.assertEqual(result.reasoning_text, "thinking") + self.assertEqual( + result.normal_text, "<|tool_calls_section_begin|><|tool_call_begin|>" + ) + + def test_streaming_tool_interrupt(self): + """Test streaming parse interrupted by tool section.""" + self.detector.parse_streaming_increment("◁think▷") + result1 = self.detector.parse_streaming_increment("reasoning") + self.assertEqual(result1.reasoning_text, "reasoning") + self.assertEqual(result1.normal_text, "") + + result2 = 
self.detector.parse_streaming_increment( + "<|tool_calls_section_begin|>" + ) + self.assertEqual(result2.reasoning_text, "") + self.assertEqual(result2.normal_text, "<|tool_calls_section_begin|>") + + def test_streaming_after_interrupt_is_normal(self): + """After interruption, subsequent chunks should be normal text.""" + self.detector.parse_streaming_increment("◁think▷") + self.detector.parse_streaming_increment("reasoning<|tool_calls_section_begin|>") + result = self.detector.parse_streaming_increment("<|tool_call_begin|>") + self.assertEqual(result.reasoning_text, "") + self.assertEqual(result.normal_text, "<|tool_call_begin|>") + + class TestGlm45Detector(CustomTestCase): """Test cases for GLM45 detector with tool interruption support.""" @@ -478,6 +524,9 @@ def test_init_valid_model(self): parser = ReasoningParser("kimi") self.assertIsInstance(parser.detector, KimiDetector) + parser = ReasoningParser("kimi_k2") + self.assertIsInstance(parser.detector, KimiK2Detector) + parser = ReasoningParser("glm45") self.assertIsInstance(parser.detector, Glm45Detector) @@ -565,6 +614,37 @@ def test_glm45_tool_interruption(self): self.assertEqual(all_reasoning, "reasoning") self.assertEqual(all_normal, "tool args") + def test_kimik2_tool_interruption(self): + """Test Kimi-K2 tool interruption through ReasoningParser API.""" + parser = ReasoningParser("kimi_k2") + + # Non-streaming: tool interrupt + reasoning, normal = parser.parse_non_stream( + "◁think▷thinking<|tool_calls_section_begin|><|tool_call_begin|>" + ) + self.assertEqual(reasoning, "thinking") + self.assertEqual(normal, "<|tool_calls_section_begin|><|tool_call_begin|>") + + # Streaming: tool interrupt + parser = ReasoningParser("kimi_k2") + chunks = [ + "◁think▷", + "reasoning", + "<|tool_calls_section_begin|>", + "<|tool_call_begin|>", + ] + all_reasoning = "" + all_normal = "" + for chunk in chunks: + reasoning, normal = parser.parse_stream_chunk(chunk) + if reasoning: + all_reasoning += reasoning + if normal: + all_normal += normal 
+ + self.assertEqual(all_reasoning, "reasoning") + self.assertEqual(all_normal, "<|tool_calls_section_begin|><|tool_call_begin|>") + class TestIntegrationScenarios(CustomTestCase): """Integration tests for realistic usage scenarios."""