Skip to content

Commit 2de36f5

Browse files
committed
Fix grammar, hide tool_call from output
1 parent 5c7c5dd commit 2de36f5

File tree

2 files changed

+84
-72
lines changed

2 files changed

+84
-72
lines changed

common/chat-parser.cpp

Lines changed: 52 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -401,7 +401,7 @@ namespace {
401401
static constexpr size_t MAX_PARAMETER_COUNT = 100; // Maximum parameters per function
402402
static constexpr size_t MAX_TAG_NAME_LENGTH = 256; // Maximum tag name length
403403
static constexpr size_t MAX_ATTRIBUTE_LENGTH = 1024; // Maximum attribute length
404-
404+
405405
// Helper function to set error details
406406
void set_error(common_chat_msg_parser::XmlParseError & error,
407407
common_chat_msg_parser::XmlParseErrorType type,
@@ -413,16 +413,16 @@ namespace {
413413
error.context = context;
414414
error.message = message;
415415
}
416-
416+
417417
// Simple XML tag parser - safer than regex, using string_view for performance
418418
struct XmlTag {
419419
std::string name;
420420
std::string attribute;
421421
std::string content;
422-
size_t start_pos;
423-
size_t end_pos;
422+
size_t start_pos = 0;
423+
size_t end_pos = 0;
424424
};
425-
425+
426426
// Find XML tag with optional attribute - ITERATIVE implementation to avoid stack overflow
427427
std::optional<XmlTag> find_xml_tag(std::string_view text, std::string_view tag_name, size_t start_pos = 0,
428428
common_chat_msg_parser::XmlParseError * error = nullptr) {
@@ -436,7 +436,7 @@ namespace {
436436
}
437437
return std::nullopt;
438438
}
439-
439+
440440
if (tag_name.size() > MAX_TAG_NAME_LENGTH) {
441441
LOG_DBG("Tag name too long: %zu chars (max: %zu)\n", tag_name.size(), MAX_TAG_NAME_LENGTH);
442442
if (error) {
@@ -446,16 +446,16 @@ namespace {
446446
}
447447
return std::nullopt;
448448
}
449-
449+
450450
if (start_pos >= text.size()) {
451451
return std::nullopt;
452452
}
453-
453+
454454
// PERFORMANCE OPTIMIZATION: Use string_view to avoid allocations
455455
// Pre-compute tag patterns
456456
const std::string open_tag_start = std::string("<") + std::string(tag_name);
457457
const std::string close_tag = std::string("</") + std::string(tag_name) + ">";
458-
458+
459459
// ITERATIVE search to avoid recursion and potential stack overflow
460460
size_t search_pos = start_pos;
461461
while (search_pos < text.size()) {
@@ -464,7 +464,7 @@ namespace {
464464
if (open_pos == std::string::npos) {
465465
return std::nullopt;
466466
}
467-
467+
468468
// Validate that this is actually the start of our tag (not a substring)
469469
// Check that the character after tag name is either '>' or '=' or whitespace
470470
size_t check_pos = open_pos + open_tag_start.length();
@@ -477,16 +477,16 @@ namespace {
477477
continue;
478478
}
479479
}
480-
480+
481481
// Find the end of the opening tag
482482
size_t open_end = text.find('>', open_pos);
483483
if (open_end == std::string::npos) {
484484
return std::nullopt;
485485
}
486-
486+
487487
XmlTag tag;
488488
tag.start_pos = open_pos;
489-
489+
490490
// Extract attribute if present (for tags like <function=name> or <function = "name">)
491491
// PERFORMANCE: Use string_view for substring operations
492492
size_t tag_content_start = open_pos + 1 + tag_name.length();
@@ -499,15 +499,15 @@ namespace {
499499
while (attr_start < open_end && std::isspace(text[attr_start])) {
500500
attr_start++;
501501
}
502-
502+
503503
if (attr_start < open_end) {
504504
size_t attr_end = open_end;
505-
505+
506506
// Handle quoted attribute values
507507
if (text[attr_start] == '"' || text[attr_start] == '\'') {
508508
char quote_char = text[attr_start];
509509
attr_start++; // Skip opening quote
510-
510+
511511
// Find closing quote
512512
size_t quote_end = text.find(quote_char, attr_start);
513513
if (quote_end != std::string::npos && quote_end < open_end) {
@@ -522,7 +522,7 @@ namespace {
522522
attr_end--;
523523
}
524524
}
525-
525+
526526
if (attr_start < attr_end) {
527527
std::string_view attr_view = text.substr(attr_start, attr_end - attr_start);
528528
// Validate attribute length
@@ -541,37 +541,37 @@ namespace {
541541
}
542542
}
543543
}
544-
544+
545545
// Look for closing tag - PERFORMANCE: Search from after opening tag
546546
size_t close_pos = text.find(close_tag, open_end + 1);
547547
if (close_pos == std::string::npos) {
548-
return std::nullopt;
548+
return tag;
549549
}
550-
550+
551551
tag.end_pos = close_pos + close_tag.length();
552552
tag.name = std::string(tag_name);
553-
553+
554554
// PERFORMANCE: Use string_view for content extraction
555555
size_t content_start = open_end + 1;
556556
size_t content_length = close_pos - content_start;
557557
if (content_length > 0) {
558558
std::string_view content_view = text.substr(content_start, content_length);
559559
tag.content = std::string(content_view);
560560
}
561-
561+
562562
return tag;
563563
}
564-
564+
565565
return std::nullopt;
566566
}
567-
567+
568568
// Find all XML tags with a specific name and attribute pattern - with limits, using string_view
569569
std::vector<XmlTag> find_all_xml_tags(std::string_view text, std::string_view tag_name,
570570
common_chat_msg_parser::XmlParseError * error = nullptr) {
571571
std::vector<XmlTag> tags;
572572
size_t pos = 0;
573573
size_t tag_count = 0;
574-
574+
575575
while (pos < text.length() && tag_count < MAX_PARAMETER_COUNT) {
576576
auto tag = find_xml_tag(text, tag_name, pos, error);
577577
if (!tag) {
@@ -581,7 +581,7 @@ namespace {
581581
pos = tag->end_pos;
582582
++tag_count;
583583
}
584-
584+
585585
if (tag_count >= MAX_PARAMETER_COUNT) {
586586
LOG_DBG("Too many tags found: %zu (max: %zu)\n", tag_count, MAX_PARAMETER_COUNT);
587587
if (error) {
@@ -590,10 +590,10 @@ namespace {
590590
"Too many " + std::string(tag_name) + " tags found (max: " + std::to_string(MAX_PARAMETER_COUNT) + ")");
591591
}
592592
}
593-
593+
594594
return tags;
595595
}
596-
596+
597597
// Trim whitespace from string using string_view for performance
598598
std::string trim_whitespace(std::string_view str) {
599599
size_t start = str.find_first_not_of(" \t\n\r");
@@ -603,7 +603,7 @@ namespace {
603603
size_t end = str.find_last_not_of(" \t\n\r");
604604
return std::string(str.substr(start, end - start + 1));
605605
}
606-
606+
607607
// Safe integer parsing with overflow protection using string_view
608608
bool safe_parse_int(std::string_view str, int & result) {
609609
try {
@@ -619,7 +619,7 @@ namespace {
619619
return false;
620620
}
621621
}
622-
622+
623623
// Safe float parsing with overflow protection using string_view
624624
bool safe_parse_float(std::string_view str, float & result) {
625625
try {
@@ -634,14 +634,14 @@ namespace {
634634
return false;
635635
}
636636
}
637-
637+
638638
// Convert parameter value based on tool schema type - FIXED JSON injection vulnerability, using string_view
639639
std::string convert_qwen3_param_value(std::string_view param_value,
640640
std::string_view param_name,
641641
const nlohmann::json & param_config,
642642
std::string_view /* func_name */) {
643643
std::string trimmed_value = trim_whitespace(param_value);
644-
644+
645645
// Handle null value
646646
if (trimmed_value == "null") {
647647
return "null";
@@ -689,7 +689,7 @@ namespace {
689689
}
690690
}
691691
}
692-
692+
693693
// Without schema, try to infer type from value
694694
// First check if it's valid JSON (object or array)
695695
try {
@@ -698,23 +698,23 @@ namespace {
698698
} catch (...) {
699699
// Not valid JSON, continue with other type checks
700700
}
701-
701+
702702
// Check if it's a number
703703
int int_val;
704704
if (safe_parse_int(trimmed_value, int_val)) {
705705
return std::to_string(int_val); // It's an integer
706706
}
707-
707+
708708
float float_val;
709709
if (safe_parse_float(trimmed_value, float_val)) {
710710
return std::to_string(float_val); // It's a float
711711
}
712-
712+
713713
// Check if it's a boolean
714714
if (trimmed_value == "true" || trimmed_value == "false") {
715715
return trimmed_value;
716716
}
717-
717+
718718
// Default to string - SECURITY FIX: Use proper JSON escaping
719719
return json(trimmed_value).dump();
720720
}
@@ -752,7 +752,7 @@ bool common_chat_msg_parser::parse_qwen3_xml_tool_call(const std::string & conte
752752
XmlParseError & error) {
753753
// Clear any previous error
754754
error.clear();
755-
755+
756756
// Input validation for DoS protection
757757
if (content.size() > MAX_INPUT_SIZE) {
758758
LOG_DBG("XML content too large: %zu bytes (max: %zu)\n", content.size(), MAX_INPUT_SIZE);
@@ -761,7 +761,7 @@ bool common_chat_msg_parser::parse_qwen3_xml_tool_call(const std::string & conte
761761
"XML content exceeds maximum size limit of " + std::to_string(MAX_INPUT_SIZE) + " bytes");
762762
return false;
763763
}
764-
764+
765765
// Validate tools vector size
766766
if (tools.size() > MAX_PARAMETER_COUNT) {
767767
LOG_DBG("Too many tools provided: %zu (max: %zu)\n", tools.size(), MAX_PARAMETER_COUNT);
@@ -801,6 +801,10 @@ bool common_chat_msg_parser::parse_qwen3_xml_tool_call(const std::string & conte
801801
}
802802
}
803803

804+
if (!tool_call_tag->end_pos) {
805+
return true;
806+
}
807+
804808
// Find function tag within tool_call - use string_view for performance
805809
std::string_view tool_call_content_view(tool_call_tag->content);
806810
auto function_tag = find_xml_tag(tool_call_content_view, "function", 0, &error);
@@ -815,7 +819,7 @@ bool common_chat_msg_parser::parse_qwen3_xml_tool_call(const std::string & conte
815819
}
816820

817821
std::string function_name = trim_whitespace(function_tag->attribute);
818-
822+
819823
// Validate function name
820824
if (function_name.empty() || function_name.size() > MAX_TAG_NAME_LENGTH) {
821825
LOG_DBG("Invalid function name: '%s' (length: %zu, max: %zu)\n",
@@ -826,7 +830,7 @@ bool common_chat_msg_parser::parse_qwen3_xml_tool_call(const std::string & conte
826830
"Invalid function name: '" + function_name + "' (length: " + std::to_string(function_name.size()) + ", max: " + std::to_string(MAX_TAG_NAME_LENGTH) + ")");
827831
return false;
828832
}
829-
833+
830834
// PERFORMANCE OPTIMIZATION: Use hash set for O(1) function lookup instead of O(n) loop
831835
if (!tools.empty() && valid_functions.find(function_name) == valid_functions.end()) {
832836
LOG_DBG("Function '%s' not found in available tools\n", function_name.c_str());
@@ -836,20 +840,20 @@ bool common_chat_msg_parser::parse_qwen3_xml_tool_call(const std::string & conte
836840
"Function '" + function_name + "' not found in available tools");
837841
return false;
838842
}
839-
843+
840844
// Get parameter configuration for this function - use string_view
841845
auto param_config = get_param_config(std::string_view(function_name), tools);
842846

843847
// Parse parameters within function tag - use string_view for performance
844848
json arguments = json::object();
845849
std::string_view function_content_view(function_tag->content);
846850
auto parameter_tags = find_all_xml_tags(function_content_view, "parameter", &error);
847-
851+
848852
// Check if error occurred during parameter parsing
849853
if (error.has_error()) {
850854
return false;
851855
}
852-
856+
853857
// Limit parameter count for DoS protection
854858
size_t param_count = 0;
855859
for (const auto & param_tag : parameter_tags) {
@@ -862,22 +866,22 @@ bool common_chat_msg_parser::parse_qwen3_xml_tool_call(const std::string & conte
862866
"Too many parameters for function '" + function_name + "': " + std::to_string(param_count) + " (max: " + std::to_string(MAX_PARAMETER_COUNT) + ")");
863867
break;
864868
}
865-
869+
866870
if (param_tag.attribute.empty()) {
867871
LOG_DBG("Skipping parameter with empty attribute\n");
868872
continue; // Skip malformed parameter tags
869873
}
870-
874+
871875
std::string param_name = trim_whitespace(param_tag.attribute);
872876
std::string param_value = param_tag.content;
873-
877+
874878
// Validate parameter name
875879
if (param_name.empty() || param_name.size() > MAX_TAG_NAME_LENGTH) {
876880
LOG_DBG("Invalid parameter name: '%s' (length: %zu, max: %zu)\n",
877881
param_name.c_str(), param_name.size(), MAX_TAG_NAME_LENGTH);
878882
continue;
879883
}
880-
884+
881885
// Convert value based on schema type - use string_view for performance
882886
try {
883887
std::string converted_value = convert_qwen3_param_value(

0 commit comments

Comments
 (0)