Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def create_moe_runner(
):
self.moe_runner_config = moe_runner_config

def apply(
def apply_weights(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
Expand Down
30 changes: 29 additions & 1 deletion python/sglang/srt/parser/reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,34 @@ def __init__(
)


class KimiK2Detector(BaseReasoningFormatDetector):
    """
    Reasoning-format detector for Kimi K2 models.

    Expected reasoning format:
        (<think>)*(.*)</think>

    Kimi K2 may jump from the reasoning section directly into a
    tool-call section by emitting `<|tool_calls_section_begin|>` without
    first closing with `</think>`, so that token is registered as the
    tool-start interrupt marker.
    """

    def __init__(
        self,
        stream_reasoning: bool = True,
        force_reasoning: bool = False,
        continue_final_message: bool = False,
        previous_content: str = "",
    ):
        # All detection logic lives in the base detector; Kimi K2 only
        # differs in its token vocabulary.
        super().__init__(
            "<think>",
            "</think>",
            stream_reasoning=stream_reasoning,
            force_reasoning=force_reasoning,
            tool_start_token="<|tool_calls_section_begin|>",
            previous_content=previous_content,
            continue_final_message=continue_final_message,
        )


class Glm45Detector(BaseReasoningFormatDetector):
"""
Detector for GLM-4.5 models.
Expand Down Expand Up @@ -431,7 +459,7 @@ class ReasoningParser:
"glm45": Glm45Detector,
"gpt-oss": GptOssDetector,
"kimi": KimiDetector,
"kimi_k2": Qwen3Detector,
"kimi_k2": KimiK2Detector,
"qwen3": Qwen3Detector,
"qwen3-thinking": Qwen3Detector,
"minimax": Qwen3Detector,
Expand Down
80 changes: 80 additions & 0 deletions test/registered/parser/test_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
DeepSeekR1Detector,
Glm45Detector,
KimiDetector,
KimiK2Detector,
Qwen3Detector,
ReasoningParser,
StreamingParseResult,
Expand Down Expand Up @@ -314,6 +315,51 @@ def test_streaming_kimi_format(self):
self.assertEqual(result.normal_text, "answer")


class TestKimiK2Detector(CustomTestCase):
    """Tests for the Kimi-K2 reasoning detector, including tool-section interrupts."""

    def setUp(self):
        self.detector = KimiK2Detector()

    def test_init(self):
        """The detector should be configured with the Kimi-K2 token vocabulary."""
        det = self.detector
        self.assertEqual(det.think_start_token, "<think>")
        self.assertEqual(det.think_end_token, "</think>")
        self.assertEqual(det.tool_start_token, "<|tool_calls_section_begin|>")
        self.assertFalse(det._in_reasoning)
        self.assertTrue(det.stream_reasoning)

    def test_detect_and_parse_tool_interrupt(self):
        """Non-streaming parse: the tool-section token ends the reasoning span."""
        parsed = self.detector.detect_and_parse(
            "<think>thinking<|tool_calls_section_begin|><|tool_call_begin|>"
        )
        self.assertEqual(parsed.reasoning_text, "thinking")
        self.assertEqual(
            parsed.normal_text, "<|tool_calls_section_begin|><|tool_call_begin|>"
        )

    def test_streaming_tool_interrupt(self):
        """Streaming parse: a tool-section chunk flips output to normal text."""
        self.detector.parse_streaming_increment("<think>")
        before_interrupt = self.detector.parse_streaming_increment("reasoning")
        self.assertEqual(before_interrupt.reasoning_text, "reasoning")
        self.assertEqual(before_interrupt.normal_text, "")

        at_interrupt = self.detector.parse_streaming_increment(
            "<|tool_calls_section_begin|>"
        )
        self.assertEqual(at_interrupt.reasoning_text, "")
        self.assertEqual(at_interrupt.normal_text, "<|tool_calls_section_begin|>")

    def test_streaming_after_interrupt_is_normal(self):
        """Once interrupted, every later chunk is emitted as normal text."""
        self.detector.parse_streaming_increment("<think>")
        self.detector.parse_streaming_increment("reasoning<|tool_calls_section_begin|>")
        after_interrupt = self.detector.parse_streaming_increment("<|tool_call_begin|>")
        self.assertEqual(after_interrupt.reasoning_text, "")
        self.assertEqual(after_interrupt.normal_text, "<|tool_call_begin|>")


class TestGlm45Detector(CustomTestCase):
"""Test cases for GLM45 detector with tool interruption support."""

Expand Down Expand Up @@ -478,6 +524,9 @@ def test_init_valid_model(self):
parser = ReasoningParser("kimi")
self.assertIsInstance(parser.detector, KimiDetector)

parser = ReasoningParser("kimi_k2")
self.assertIsInstance(parser.detector, KimiK2Detector)

parser = ReasoningParser("glm45")
self.assertIsInstance(parser.detector, Glm45Detector)

Expand Down Expand Up @@ -565,6 +614,37 @@ def test_glm45_tool_interruption(self):
self.assertEqual(all_reasoning, "reasoning")
self.assertEqual(all_normal, "<tool_call>tool args")

def test_kimik2_tool_interruption(self):
    """Exercise Kimi-K2 tool interruption through the public ReasoningParser API."""
    # Non-streaming path: tool section ends the reasoning span.
    parser = ReasoningParser("kimi_k2")
    reasoning, normal = parser.parse_non_stream(
        "<think>thinking<|tool_calls_section_begin|><|tool_call_begin|>"
    )
    self.assertEqual(reasoning, "thinking")
    self.assertEqual(normal, "<|tool_calls_section_begin|><|tool_call_begin|>")

    # Streaming path: feed the same content chunk by chunk and
    # accumulate what the parser classifies as reasoning vs. normal.
    parser = ReasoningParser("kimi_k2")
    collected_reasoning = []
    collected_normal = []
    for chunk in (
        "<think>",
        "reasoning",
        "<|tool_calls_section_begin|>",
        "<|tool_call_begin|>",
    ):
        reasoning, normal = parser.parse_stream_chunk(chunk)
        if reasoning:
            collected_reasoning.append(reasoning)
        if normal:
            collected_normal.append(normal)

    self.assertEqual("".join(collected_reasoning), "reasoning")
    self.assertEqual(
        "".join(collected_normal),
        "<|tool_calls_section_begin|><|tool_call_begin|>",
    )

Comment thread
JustinTong0323 marked this conversation as resolved.

class TestIntegrationScenarios(CustomTestCase):
"""Integration tests for realistic usage scenarios."""
Expand Down
Loading