Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 12 additions & 169 deletions tools/agent/agent-loop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,172 +228,6 @@ void agent_loop::clear() {
permission_mgr_.clear_session();
}

// Parse a single function block: <function=name>...<parameter=key>value</parameter>...</function>
// Parse a single function block: <function=name>...<parameter=key>value</parameter>...</function>
// Fills tc.name and tc.arguments (a JSON object serialized to a string).
// Parameter values are type-inferred: "true"/"false" -> bool, numeric text -> integer/double,
// anything else stays a raw string. Returns false only if no <function= tag is present.
static bool parse_function_block(const std::string & block, common_chat_tool_call & tc) {
    // Parse function name: <function=name>
    size_t func_start = block.find("<function=");
    if (func_start == std::string::npos) return false;

    size_t func_name_start = func_start + 10; // strlen("<function=")
    size_t func_name_end = block.find(">", func_name_start);
    if (func_name_end == std::string::npos) return false;

    tc.name = block.substr(func_name_start, func_name_end - func_name_start);

    // Find function block end; tolerate a missing closing tag (truncated generation)
    size_t func_end = block.find("</function>", func_name_end);
    if (func_end == std::string::npos) func_end = block.size();

    std::string func_body = block.substr(func_name_end + 1, func_end - func_name_end - 1);

    // Parse parameters
    json args = json::object();
    size_t param_pos = 0;
    while ((param_pos = func_body.find("<parameter=", param_pos)) != std::string::npos) {
        size_t param_name_start = param_pos + 11; // strlen("<parameter=")
        size_t param_name_end = func_body.find(">", param_name_start);
        if (param_name_end == std::string::npos) break;

        std::string param_name = func_body.substr(param_name_start, param_name_end - param_name_start);

        // Parameter value runs from after '>' up to </parameter>, the next
        // <parameter= (missing close tag), or the end of the body.
        size_t value_start = param_name_end + 1;
        // Skip the conventional newline right after the opening tag
        if (value_start < func_body.size() && func_body[value_start] == '\n') {
            value_start++;
        }

        size_t param_end  = func_body.find("</parameter>", value_start);
        size_t next_param = func_body.find("<parameter=", value_start);

        size_t value_end;
        if (param_end != std::string::npos && (next_param == std::string::npos || param_end < next_param)) {
            value_end = param_end;
        } else if (next_param != std::string::npos) {
            value_end = next_param;
        } else {
            value_end = func_body.size();
        }

        std::string param_value = func_body.substr(value_start, value_end - value_start);
        // Trim trailing newline/CR that belong to the tag layout, not the value
        while (!param_value.empty() && (param_value.back() == '\n' || param_value.back() == '\r')) {
            param_value.pop_back();
        }

        // Whitespace-trimmed copy used only for type inference.
        // NOTE: <cctype> functions require an unsigned char (or EOF) argument;
        // passing a plain (possibly negative) char is undefined behavior on
        // non-ASCII bytes, which model output can easily contain.
        std::string trimmed = param_value;
        size_t lead = 0;
        while (lead < trimmed.size() && std::isspace((unsigned char) trimmed[lead])) lead++;
        trimmed.erase(0, lead);
        while (!trimmed.empty() && std::isspace((unsigned char) trimmed.back())) trimmed.pop_back();

        // Convert to appropriate JSON type
        std::string lower_trimmed = trimmed;
        for (auto & c : lower_trimmed) c = (char) std::tolower((unsigned char) c);

        if (lower_trimmed == "true") {
            args[param_name] = true;
        } else if (lower_trimmed == "false") {
            args[param_name] = false;
        } else {
            // Try to parse as a number: optional leading '-', digits, at most one '.'
            bool is_number = !trimmed.empty();
            bool has_dot = false;
            for (size_t i = 0; i < trimmed.size(); i++) {
                char c = trimmed[i];
                if (c == '-' && i == 0) continue;
                if (c == '.' && !has_dot) { has_dot = true; continue; }
                if (!std::isdigit((unsigned char) c)) { is_number = false; break; }
            }
            // "-", "." and "-." pass the loop above but are not valid numbers
            if (is_number && trimmed != "-" && trimmed != "." && trimmed != "-.") {
                try {
                    if (has_dot) {
                        args[param_name] = std::stod(trimmed);
                    } else {
                        args[param_name] = std::stoll(trimmed);
                    }
                } catch (const std::exception &) {
                    // out_of_range (e.g. an oversized integer literal): keep the raw string
                    // rather than crashing the agent loop on adversarial model output
                    args[param_name] = param_value;
                }
            } else {
                args[param_name] = param_value;
            }
        }
        param_pos = value_end;
    }

    tc.arguments = args.dump();
    return true;
}

// Parse tool calls from qwen3-coder/nemotron XML format
// Supports both:
// <tool_call><function=name>...</function></tool_call>
// <function=name>...</function> (without wrapper)
// Parse tool calls from qwen3-coder/nemotron XML format
// Supports both:
//   <tool_call><function=name>...</function></tool_call>
//   <function=name>...</function> (without wrapper)
// Text preceding the first tool call becomes msg.content (trailing whitespace
// trimmed); each parsed call gets a synthetic id "call_<index>".
static common_chat_msg parse_tool_calls_xml(const std::string & content) {
    common_chat_msg msg;
    msg.role = "assistant";

    std::string remaining = content;

    // First, try to find <tool_call> wrapped format
    size_t tool_call_start = remaining.find("<tool_call>");
    // If no <tool_call>, look for bare <function= tags
    size_t func_start = remaining.find("<function=");

    // Determine the earliest tool/function occurrence
    size_t first_tool = std::string::npos;
    bool has_wrapper = false;
    if (tool_call_start != std::string::npos && (func_start == std::string::npos || tool_call_start < func_start)) {
        first_tool = tool_call_start;
        has_wrapper = true;
    } else if (func_start != std::string::npos) {
        first_tool = func_start;
        has_wrapper = false;
    }

    // Extract content before any tool calls
    if (first_tool != std::string::npos) {
        msg.content = remaining.substr(0, first_tool);
        // Trim trailing whitespace from content; cast to unsigned char because
        // std::isspace on a negative (non-ASCII) char is undefined behavior
        while (!msg.content.empty() && std::isspace((unsigned char) msg.content.back())) {
            msg.content.pop_back();
        }
    } else {
        msg.content = content;
        return msg; // No tool calls
    }

    // Parse tool calls
    if (has_wrapper) {
        // Parse <tool_call>...<function=...>...</function>...</tool_call> format
        while ((tool_call_start = remaining.find("<tool_call>")) != std::string::npos) {
            size_t tool_call_end = remaining.find("</tool_call>", tool_call_start);
            if (tool_call_end == std::string::npos) break; // truncated: stop parsing

            // 11 = strlen("<tool_call>"), 12 = strlen("</tool_call>")
            std::string tool_block = remaining.substr(tool_call_start + 11, tool_call_end - tool_call_start - 11);
            remaining = remaining.substr(tool_call_end + 12);

            common_chat_tool_call tc;
            tc.id = "call_" + std::to_string(msg.tool_calls.size());
            if (parse_function_block(tool_block, tc)) {
                msg.tool_calls.push_back(tc);
            }
        }
    } else {
        // Parse bare <function=...>...</function> format
        while ((func_start = remaining.find("<function=")) != std::string::npos) {
            size_t func_end = remaining.find("</function>", func_start);
            if (func_end == std::string::npos) break; // truncated: stop parsing

            // 11 = strlen("</function>"); keep the closing tag in the block
            std::string func_block = remaining.substr(func_start, func_end - func_start + 11);
            remaining = remaining.substr(func_end + 11);

            common_chat_tool_call tc;
            tc.id = "call_" + std::to_string(msg.tool_calls.size());
            if (parse_function_block(func_block, tc)) {
                msg.tool_calls.push_back(tc);
            }
        }
    }

    return msg;
}

common_chat_msg agent_loop::generate_completion(result_timings & out_timings) {
server_response_reader rd = server_ctx_.get_response_reader();
{
Expand Down Expand Up @@ -479,7 +313,12 @@ common_chat_msg agent_loop::generate_completion(result_timings & out_timings) {
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
if (res_final) {
out_timings = std::move(res_final->timings);
// Use the raw content for our own parsing
// Use the server-parsed message which handles all chat template formats
// (Hermes 2 Pro, Qwen3-Coder, Llama 3.x, DeepSeek, etc.)
if (!res_final->oaicompat_msg.empty()) {
return res_final->oaicompat_msg;
}
// Fallback to raw content if no parsed message
if (!res_final->content.empty()) {
full_content = res_final->content;
}
Expand All @@ -501,8 +340,12 @@ common_chat_msg agent_loop::generate_completion(result_timings & out_timings) {
return msg;
}

// Parse tool calls ourselves using the qwen3-coder/nemotron XML format
return parse_tool_calls_xml(full_content);
// Fallback: return content without tool calls
// (Server should have parsed if parse_tool_calls=true, but handle edge cases)
common_chat_msg msg;
msg.role = "assistant";
msg.content = full_content;
return msg;
}

tool_result agent_loop::execute_tool_call(const common_chat_tool_call & call) {
Expand Down
14 changes: 8 additions & 6 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,7 @@ struct server_context_impl {
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
res->oaicompat_chat_syntax = slot.task->params.oaicompat_chat_syntax;

// populate res.probs_output
if (slot.task->params.sampling.n_probs > 0) {
Expand Down Expand Up @@ -1470,12 +1471,13 @@ struct server_context_impl {
res->stop = slot.stop;
res->post_sampling_probs = slot.task->params.post_sampling_probs;

res->verbose = slot.task->params.verbose;
res->stream = slot.task->params.stream;
res->include_usage = slot.task->params.include_usage;
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
res->verbose = slot.task->params.verbose;
res->stream = slot.task->params.stream;
res->include_usage = slot.task->params.include_usage;
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
res->oaicompat_chat_syntax = slot.task->params.oaicompat_chat_syntax;

// populate res.probs_output
if (slot.task->params.sampling.n_probs > 0) {
Expand Down
10 changes: 10 additions & 0 deletions tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ struct server_task_result_cmpl_final : server_task_result {
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg; // to be populated by update()

// Chat syntax for tool call parsing (synced from task params)
common_chat_syntax oaicompat_chat_syntax;

std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;

Expand All @@ -304,6 +307,8 @@ struct server_task_result_cmpl_final : server_task_result {

virtual void update(task_result_state & state) override {
is_updated = true;
// Sync chat syntax from server (may have been updated by tokenize_cli_input)
state.oaicompat_chat_syntax = oaicompat_chat_syntax;
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
}

Expand Down Expand Up @@ -341,6 +346,9 @@ struct server_task_result_cmpl_partial : server_task_result {
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;

// Chat syntax for tool call parsing (synced from task params)
common_chat_syntax oaicompat_chat_syntax;

// for Anthropic API: track if any reasoning content has been generated
bool anthropic_has_reasoning = false;
// Streaming state copied from task_result_state for this chunk
Expand All @@ -355,6 +363,8 @@ struct server_task_result_cmpl_partial : server_task_result {

virtual void update(task_result_state & state) override {
is_updated = true;
// Sync chat syntax from server (may have been updated by tokenize_cli_input)
state.oaicompat_chat_syntax = oaicompat_chat_syntax;
state.update_chat_msg(content, true, oaicompat_msg_diffs);
// track if the accumulated message has any reasoning content
anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty();
Expand Down