Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 119 additions & 6 deletions tests/tool_parsers/test_kimi_k2_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,11 @@ def test_empty_tool_section(kimi_k2_tool_parser):
assert kimi_k2_tool_parser.in_tool_section is False


def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
def test_large_tool_args_no_forced_exit(kimi_k2_tool_parser):
"""
Test that the parser recovers from a malformed tool section
that never closes properly.
Test that the parser does NOT force-exit a tool section for payloads
within the configurable safety-valve limit (default 512 KB).
This ensures large file outputs via tool calls are supported.
"""
kimi_k2_tool_parser.reset_streaming_state()

Expand All @@ -523,9 +524,51 @@ def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
)
assert kimi_k2_tool_parser.in_tool_section is True

# Simulate a lot of text without proper tool calls or section end
# This should trigger the error recovery mechanism
large_text = "x" * 10000 # Exceeds max_section_chars
# Simulate a 10 KB payload -- well within the 512 KB default limit.
# The parser should NOT force-exit; it stays in tool section.
large_text = "x" * 10000

_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="<|tool_calls_section_begin|>",
current_text="<|tool_calls_section_begin|>" + large_text,
delta_text=large_text,
previous_token_ids=[section_begin_id],
current_token_ids=[section_begin_id] + list(range(100, 100 + len(large_text))),
delta_token_ids=list(range(100, 100 + len(large_text))),
request=None,
)

# Parser should still be in tool section (no forced exit for 10 KB)
assert kimi_k2_tool_parser.in_tool_section is True


def test_malformed_tool_section_safety_valve(kimi_k2_tool_parser):
"""
Test that the configurable safety valve forces exit when a tool
section exceeds the limit. Uses a small override to avoid
allocating 512 KB in a unit test.
"""
kimi_k2_tool_parser.reset_streaming_state()
# Override the safety valve to a small value for testing
original_max = kimi_k2_tool_parser.max_section_chars
kimi_k2_tool_parser.max_section_chars = 5000

section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")

# Enter tool section
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text="<|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=[],
current_token_ids=[section_begin_id],
delta_token_ids=[section_begin_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is True

# Simulate exceeding the safety valve limit
large_text = "x" * 10000

result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="<|tool_calls_section_begin|>",
Expand All @@ -543,6 +586,9 @@ def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
assert result2 is not None
assert result2.content == large_text

# Restore original max
kimi_k2_tool_parser.max_section_chars = original_max


def test_state_reset(kimi_k2_tool_parser):
"""Test that reset_streaming_state() properly clears all state."""
Expand All @@ -552,6 +598,7 @@ def test_state_reset(kimi_k2_tool_parser):
kimi_k2_tool_parser.current_tool_id = 5
kimi_k2_tool_parser.prev_tool_call_arr = [{"id": "test"}]
kimi_k2_tool_parser.section_char_count = 1000
kimi_k2_tool_parser._current_tool_args = '{"key": "value"}'

# Reset
kimi_k2_tool_parser.reset_streaming_state()
Expand All @@ -564,6 +611,7 @@ def test_state_reset(kimi_k2_tool_parser):
assert kimi_k2_tool_parser.section_char_count == 0
assert kimi_k2_tool_parser.current_tool_name_sent is False
assert kimi_k2_tool_parser.streamed_args_for_tool == []
assert kimi_k2_tool_parser._current_tool_args == ""


def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser):
Expand Down Expand Up @@ -923,3 +971,68 @@ def test_streaming_multiple_tool_calls_not_leaked(kimi_k2_tool_parser):

# Legitimate content preserved
assert "compare" in full_content.lower() or len(all_content) > 0


def test_complete_tool_call_single_delta(kimi_k2_tool_parser):
    """
    Verify that a tool call whose begin/name/args/end tokens all arrive
    in ONE streaming delta still yields both the function name and the
    arguments.

    Guards against a regression where Phase A was skipped (because
    cur_start == cur_end) while Phase B's _handle_call_end returned None
    (because current_tool_id had never been set up), silently dropping
    the call.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    vocab = kimi_k2_tool_parser.vocab
    section_begin_tok = vocab.get("<|tool_calls_section_begin|>")
    call_begin_tok = vocab.get("<|tool_call_begin|>")
    call_end_tok = vocab.get("<|tool_call_end|>")

    # Step 1: open the tool-calls section on its own.
    run_streaming_sequence(
        kimi_k2_tool_parser,
        [("<|tool_calls_section_begin|>", [section_begin_tok])],
    )

    # Step 2: deliver an entire tool call (begin through end) as one delta.
    one_shot_delta = (
        "<|tool_call_begin|>functions.get_weather:0 "
        '<|tool_call_argument_begin|> {"city": "Paris"} '
        "<|tool_call_end|>"
    )

    prev_text = "<|tool_calls_section_begin|>"
    prev_ids = [section_begin_tok]
    new_ids = [call_begin_tok, 10, 11, 12, call_end_tok]

    result = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text=prev_text,
        current_text=prev_text + one_shot_delta,
        delta_text=one_shot_delta,
        previous_token_ids=prev_ids,
        current_token_ids=prev_ids + new_ids,
        delta_token_ids=new_ids,
        request=None,
    )

    # The tool call must NOT be silently dropped
    assert result is not None, (
        "Complete tool call in single delta was dropped (returned None)"
    )
    assert result.tool_calls is not None and len(result.tool_calls) > 0, (
        "No tool_calls emitted for complete tool call in single delta"
    )

    # The emitted call must carry a function name. Depending on how
    # DeltaToolCall reconstructs it, the function field may come back as
    # a dict (from model_dump) or as a pydantic model — handle both.
    first_tc = result.tool_calls[0]
    assert first_tc.function is not None
    func = first_tc.function
    if isinstance(func, dict):
        has_name = func.get("name") is not None
    else:
        has_name = getattr(func, "name", None) is not None
    assert has_name, f"Function name not emitted for complete tool call: {first_tc}"
Loading
Loading