Skip to content

Commit 4e9c78c

Browse files
authored
Enable LLM function calls (#643)
1 parent 1b05210 commit 4e9c78c

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

examples/server/server.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,6 +1683,7 @@ struct server_context {
16831683
res.stop = true;
16841684
res.data = json {
16851685
{"content", !slot.params.stream ? slot.generated_text : ""},
1686+
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic
16861687
{"id_slot", slot.id},
16871688
{"stop", true},
16881689
{"model", params.model_alias},
@@ -2822,11 +2823,22 @@ static std::vector<json> format_partial_response_oaicompat(server_task_result ta
28222823
std::string content = json_value(result, "content", std::string(""));
28232824

28242825
std::string finish_reason;
2825-
if (stopped_word || stopped_eos) {
2826-
finish_reason = "stop";
2827-
}
28282826
if (stopped_limit) {
28292827
finish_reason = "length";
2828+
} else if (stopped_word || stopped_eos) {
2829+
// Following original llama.cpp pattern: finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls"
2830+
// Use generated_text (complete content) for finish_reason logic, not content (empty in streaming)
2831+
std::string generated_text = json_value(result, "generated_text", std::string(""));
2832+
ik_chat_msg final_msg = parse_chat_message_incremental(generated_text, false, modelname);
2833+
2834+
// Debug logging
2835+
LOG_INFO("DEBUG: Streaming finish_reason check", {
2836+
{"generated_text", generated_text},
2837+
{"model_name", modelname},
2838+
{"tool_calls_count", final_msg.tool_calls.size()}
2839+
});
2840+
2841+
finish_reason = final_msg.tool_calls.empty() ? "stop" : "tool_calls";
28302842
}
28312843

28322844
std::time_t t = std::time(0);

tests/test-function-calls.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,6 +2992,34 @@ int main() {
29922992
assert(extracted.find("<|tool▁calls▁begin|>") == std::string::npos);
29932993
std::cout << "✅ PASS: DeepSeek R1 content extraction works" << std::endl;
29942994

2995+
// Test streaming finish_reason logic (core of the fix)
2996+
std::cout << "\n🎯 Testing Streaming finish_reason Logic:" << std::endl;
2997+
2998+
// Test Case 1: Content with tool calls should lead to finish_reason="tool_calls"
2999+
std::string tool_call_content = "functions.get_weather:0{\"location\": \"Tokyo\"}";
3000+
ik_chat_msg msg_with_tools = parse_chat_message_incremental(tool_call_content, false, "kimi-k2");
3001+
bool should_be_tool_calls = !msg_with_tools.tool_calls.empty();
3002+
std::string finish_reason_with_tools = should_be_tool_calls ? "tool_calls" : "stop";
3003+
assert(finish_reason_with_tools == "tool_calls");
3004+
std::cout << "✅ PASS: Content with tool calls -> finish_reason='tool_calls'" << std::endl;
3005+
3006+
// Test Case 2: Content without tool calls should lead to finish_reason="stop"
3007+
std::string regular_content = "This is just regular text without any tool calls.";
3008+
ik_chat_msg msg_without_tools = parse_chat_message_incremental(regular_content, false, "kimi-k2");
3009+
bool should_be_stop = msg_without_tools.tool_calls.empty();
3010+
std::string finish_reason_without_tools = should_be_stop ? "stop" : "tool_calls";
3011+
assert(finish_reason_without_tools == "stop");
3012+
std::cout << "✅ PASS: Content without tool calls -> finish_reason='stop'" << std::endl;
3013+
3014+
// Test Case 3: Qwen3 XML format tool calls
3015+
std::string qwen3_content = "<tool_call>\n{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Tokyo\"}}\n</tool_call>";
3016+
ik_chat_msg qwen3_msg = parse_chat_message_incremental(qwen3_content, false, "qwen3-7b");
3017+
bool qwen3_should_be_tool_calls = !qwen3_msg.tool_calls.empty();
3018+
std::string qwen3_finish_reason = qwen3_should_be_tool_calls ? "tool_calls" : "stop";
3019+
assert(qwen3_finish_reason == "tool_calls");
3020+
std::cout << "✅ PASS: Qwen3 XML tool calls -> finish_reason='tool_calls'" << std::endl;
3021+
3022+
std::cout << "🎯 All streaming finish_reason tests passed!" << std::endl;
29953023
} catch (const std::exception& e) {
29963024
std::cout << std::endl;
29973025
std::cout << "❌ Test failed with exception: " << e.what() << std::endl;

0 commit comments

Comments
 (0)