diff --git a/Libraries/MLXLMCommon/Tool/Parsers/PythonToolCallParser.swift b/Libraries/MLXLMCommon/Tool/Parsers/PythonToolCallParser.swift new file mode 100644 index 00000000..e44c22d8 --- /dev/null +++ b/Libraries/MLXLMCommon/Tool/Parsers/PythonToolCallParser.swift @@ -0,0 +1,281 @@ +// Copyright © 2025 Apple Inc. + +import Foundation + +/// Parser for Python-style function calls: `[func(arg="value", arg2=123)]` +/// Reference: https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/tool_parsers +public struct PythonToolCallParser: ToolCallParser, Sendable { + public let startTag: String? + public let endTag: String? + + public init(startTag: String, endTag: String) { + self.startTag = startTag + self.endTag = endTag + } + + public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? { + var text = content + + // Strip tags if present + if let start = startTag, let range = text.range(of: start) { + text = String(text[range.upperBound...]) + } + if let end = endTag, let range = text.range(of: end) { + text = String(text[.. ToolCall? 
{ + // Skip leading brackets, commas, and whitespace + var i = buffer.startIndex + while i < buffer.endIndex { + let ch = buffer[i] + if ch == "[" || ch == "]" || ch == "," || ch.isWhitespace { + i = buffer.index(after: i) + continue + } + break + } + + if i >= buffer.endIndex { + /* Only separators remained: clear the buffer and report that no call was found. */ + buffer = "" + return nil + } + + // Read function name (this and the later guard failures return nil without modifying the buffer) + guard let nameEnd = readIdentifier(buffer, from: i) else { + return nil + } + let name = String(buffer[i ..< nameEnd]) + + // Find opening paren + var j = nameEnd + skipWhitespace(buffer, &j) + guard j < buffer.endIndex, buffer[j] == "(" else { return nil } + + // Find matching closing paren + guard let closeIdx = findMatchingParen(in: buffer, openIndex: j) else { + return nil + } + + // Parse arguments + let argsBody = String(buffer[buffer.index(after: j) ..< closeIdx]) + let arguments = parseArgs(argsBody) + + // Update buffer: consume the call, at most one trailing comma, then any closing brackets and whitespace + var k = buffer.index(after: closeIdx) + skipWhitespace(buffer, &k) + if k < buffer.endIndex, buffer[k] == "," { + k = buffer.index(after: k) + } + while k < buffer.endIndex, buffer[k] == "]" || buffer[k].isWhitespace { + k = buffer.index(after: k) + } + buffer = String(buffer[k...]) + + return ToolCall(function: ToolCall.Function(name: name, arguments: arguments)) + } + + /** Scans an identifier in `s` beginning at `start`: a letter or underscore followed by any run of letters, numbers, or underscores. Returns the index just past the identifier, or nil when `start` is at the end of `s` or the character there cannot begin an identifier. */ + private func readIdentifier(_ s: String, from start: String.Index) -> String.Index? { + var i = start + guard i < s.endIndex, s[i].isLetter || s[i] == "_" else { return nil } + i = s.index(after: i) + while i < s.endIndex, s[i].isLetter || s[i].isNumber || s[i] == "_" { + i = s.index(after: i) + } + return i + } + + /** Advances `i` past consecutive whitespace characters in `s`, stopping at the first non-whitespace character or at the end of `s`. */ + private func skipWhitespace(_ s: String, _ i: inout String.Index) { + while i < s.endIndex, s[i].isWhitespace { + i = s.index(after: i) + } + } + + /** Returns the index of the closing parenthesis that balances the one at `openIndex`, or nil when the end of `s` is reached first. Parentheses inside single- or double-quoted text are not counted, and inside quotes a backslash escapes the character that follows it. */ + private func findMatchingParen(in s: String, openIndex: String.Index) -> String.Index? { + var i = s.index(after: openIndex) + var depth = 1 + var quote: Character? 
+ var escape = false + + while i < s.endIndex { + let ch = s[i] + if let q = quote { + if escape { + escape = false + } else if ch == "\\" { + escape = true + } else if ch == q { + quote = nil + } + } else { + switch ch { + case "'", "\"": quote = ch + case "(": depth += 1 + case ")": + depth -= 1 + if depth == 0 { return i } + default: break + } + } + i = s.index(after: i) + } + return nil + } + + /** Parses the text found between the parentheses of a call into keyword arguments. The body is split on top-level commas; any part that is empty after trimming or has no top-level equals sign is skipped. */ + private func parseArgs(_ body: String) -> [String: any Sendable] { + var result: [String: any Sendable] = [:] + let parts = splitTopLevel(body, on: ",") + + for part in parts { + let trimmed = part.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty, + let eqIdx = indexOfTopLevelEquals(in: trimmed) + else { continue } + + let key = String(trimmed[.. any Sendable { + /* NOTE(review): text appears to be missing from this copy between the truncated key slice above and this brace (presumably the rest of parseArgs and the header of the value-converting function) - confirm against the original file. An empty input yields an empty string. */ + guard let first = s.first else { return "" } + + // Quoted string + if first == "\"" || first == "'" { + return unquoteString(s) + } + + // Boolean and null literals (compared case-insensitively; none and null both map to NSNull) + let lower = s.lowercased() + if lower == "true" { return true } + if lower == "false" { return false } + if lower == "none" || lower == "null" { return NSNull() } + + // Number: Int is tried first, then Double + if let intVal = Int(s) { return intVal } + if let dblVal = Double(s) { return dblVal } + + /* Anything unrecognized is returned unchanged. */ + return s + } + + /** Removes the surrounding quotes from `s` and resolves backslash escapes: n, t and r become newline, tab and carriage return, and any other escaped character is kept literally. A lone backslash at the very end is dropped. Returns `s` unchanged when it does not start with a quote character and end with the same one. */ + private func unquoteString(_ s: String) -> String { + guard let q = s.first, q == "\"" || q == "'", s.last == q else { return s } + let inner = s.dropFirst().dropLast() + var result = "" + var escape = false + for ch in inner { + if escape { + switch ch { + case "n": result.append("\n") + case "t": result.append("\t") + case "r": result.append("\r") + case "\\": result.append("\\") + case "\"": result.append("\"") + case "'": result.append("'") + default: result.append(ch) + } + escape = false + } else if ch == "\\" { + escape = true + } else { + result.append(ch) + } + } + return result + } + + /** Splits `s` on `sep` wherever the separator occurs outside quoted text and outside any round, square, or curly bracket nesting. Every piece is trimmed of whitespace and newlines; empty pieces that precede a separator are kept, while an empty final piece is dropped. */ + private func splitTopLevel(_ s: String, on sep: Character) -> [String] { + var result: [String] = [] + var current = "" + var depth = 0 + var quote: Character? 
+ var escape = false + + for ch in s { + if let q = quote { + /* Quoted text is copied through verbatim, escape backslashes included. */ + current.append(ch) + if escape { + escape = false + } else if ch == "\\" { + escape = true + } else if ch == q { + quote = nil + } + } else { + switch ch { + case "'", "\"": + quote = ch + current.append(ch) + case "(", "[", "{": + depth += 1 + current.append(ch) + case ")", "]", "}": + /* Clamp at zero so an unmatched closing bracket cannot drive the depth negative. */ + depth = max(0, depth - 1) + current.append(ch) + default: + if ch == sep && depth == 0 { + result.append(current.trimmingCharacters(in: .whitespacesAndNewlines)) + current = "" + } else { + current.append(ch) + } + } + } + } + + let final = current.trimmingCharacters(in: .whitespacesAndNewlines) + if !final.isEmpty { + result.append(final) + } + return result + } + + /** Returns the index of the first equals sign in `s` that lies outside quoted text and outside any round, square, or curly bracket nesting, or nil when there is none. Each bracket kind is counted separately and unmatched closing brackets are ignored. */ + private func indexOfTopLevelEquals(in s: String) -> String.Index? { + var i = s.startIndex + var depthParen = 0 + var depthBrace = 0 + var depthBracket = 0 + var quote: Character? + var escape = false + + while i < s.endIndex { + let ch = s[i] + if let q = quote { + if escape { + escape = false + } else if ch == "\\" { + escape = true + } else if ch == q { + quote = nil + } + } else { + switch ch { + case "'", "\"": quote = ch + case "(": depthParen += 1 + case ")": if depthParen > 0 { depthParen -= 1 } + case "[": depthBracket += 1 + case "]": if depthBracket > 0 { depthBracket -= 1 } + case "{": depthBrace += 1 + case "}": if depthBrace > 0 { depthBrace -= 1 } + case "=": + if depthParen == 0, depthBrace == 0, depthBracket == 0 { + return i + } + default: break + } + } + i = s.index(after: i) + } + return nil + } +} diff --git a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift index 3b39bf60..101d19c4 100644 --- a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift +++ b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift @@ -42,8 +42,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { /// Example: `{"name": "func", "arguments": {...}}` case json - /// LFM2 JSON format with model-specific tags. 
- /// Example: `<|tool_call_start|>{"name": "func", "arguments": {...}}<|tool_call_end|>` + /// LFM2 Python-style format with model-specific tags. + /// Example: `<|tool_call_start|>[name(parameter="value")]<|tool_call_end|>` case lfm2 /// XML function format used by Qwen3 Coder. @@ -75,7 +75,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable { case .json: return JSONToolCallParser(startTag: "", endTag: "") case .lfm2: - return JSONToolCallParser(startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + return PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") case .xmlFunction: return XMLFunctionParser() case .glm4: diff --git a/Tests/MLXLMTests/ToolTests.swift b/Tests/MLXLMTests/ToolTests.swift index b2b312b8..3c03b57a 100644 --- a/Tests/MLXLMTests/ToolTests.swift +++ b/Tests/MLXLMTests/ToolTests.swift @@ -100,33 +100,6 @@ struct ToolTests { #expect(toolCall.function.arguments["location"] == .string("Paris")) } - @Test("Test JSON Tool Call Parser - LFM2 Tags") - func testJSONParserLFM2Tags() throws { - let parser = JSONToolCallParser( - startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") - let content = - "<|tool_call_start|>{\"name\": \"search\", \"arguments\": {\"query\": \"swift programming\"}}<|tool_call_end|>" - - let toolCall = try #require(parser.parse(content: content, tools: nil)) - - #expect(toolCall.function.name == "search") - #expect(toolCall.function.arguments["query"] == .string("swift programming")) - } - - @Test("Test LFM2 Format via ToolCallProcessor") - func testLFM2FormatProcessor() throws { - let processor = ToolCallProcessor(format: .lfm2) - let content = - "<|tool_call_start|>{\"name\": \"calculator\", \"arguments\": {\"expression\": \"2+2\"}}<|tool_call_end|>" - - _ = processor.processChunk(content) - - #expect(processor.toolCalls.count == 1) - let toolCall = try #require(processor.toolCalls.first) - #expect(toolCall.function.name == "calculator") - 
#expect(toolCall.function.arguments["expression"] == .string("2+2")) - } - // MARK: - XML Function Format Tests (Qwen3 Coder) @Test("Test XML Function Parser - Qwen3 Coder Format") @@ -168,6 +141,228 @@ struct ToolTests { #expect(toolCall.function.arguments["enabled"] == .bool(true)) } + // MARK: - Python Function Format Tests + + @Test("Test Python Tool Call Parser - Basic") + func testPythonToolCallParserBasic() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_weather") + #expect(toolCall.function.arguments["location"] == .string("Paris")) + } + + @Test("Test Python Tool Call Parser - Multiple Arguments") + func testPythonToolCallParserMultipleArgs() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[get_candidate_status(candidate_id=\"12345\", include_history=true)]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_candidate_status") + #expect(toolCall.function.arguments["candidate_id"] == .string("12345")) + #expect(toolCall.function.arguments["include_history"] == .bool(true)) + } + + @Test("Test Python Tool Call Parser - Numeric Arguments") + func testPythonToolCallParserNumericArgs() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[calculate(x=42, y=3.14, enabled=false)]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "calculate") + #expect(toolCall.function.arguments["x"] == .int(42)) + #expect(toolCall.function.arguments["y"] == 
.double(3.14)) + #expect(toolCall.function.arguments["enabled"] == .bool(false)) + } + + @Test("Test Python Tool Call Parser - Escaped Strings") + func testPythonToolCallParserEscapedStrings() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[search(query=\"hello\\nworld\")]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "search") + #expect(toolCall.function.arguments["query"] == .string("hello\nworld")) + } + + @Test("Test Python Tool Call Parser - Single Quotes") + func testPythonToolCallParserSingleQuotes() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[search(query='single quoted')]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "search") + #expect(toolCall.function.arguments["query"] == .string("single quoted")) + } + + @Test("Test Python Tool Call Parser - None/Null Value") + func testPythonToolCallParserNullValue() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[set_value(key=\"test\", value=None)]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "set_value") + #expect(toolCall.function.arguments["key"] == .string("test")) + #expect(toolCall.function.arguments["value"] == .null) + } + + @Test("Test Python Tool Call Parser - Single Quotes With Escapes") + func testPythonToolCallParserSingleQuotesWithEscapes() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + #"<|tool_call_start|>[find(name='O\'Hara, "Alex"')]<|tool_call_end|>"# + + let 
toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "find") + #expect(toolCall.function.arguments["name"] == .string("O'Hara, \"Alex\"")) + } + + @Test("Test Python Tool Call Parser - Without Brackets") + func testPythonToolCallParserWithoutBrackets() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>get_latest_news(query=\"Apple\", sortBy=\"popularity\")<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_latest_news") + #expect(toolCall.function.arguments["query"] == .string("Apple")) + #expect(toolCall.function.arguments["sortBy"] == .string("popularity")) + } + + @Test("Test Python Tool Call Parser - Trailing Comma") + func testPythonToolCallParserTrailingComma() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[get_stock_price(symbol=\"AAPL\"), ]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_stock_price") + #expect(toolCall.function.arguments["symbol"] == .string("AAPL")) + } + + @Test("Test Python Tool Call Parser - Commas Inside Quoted String") + func testPythonToolCallParserCommasInString() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[foo(msg='a, b, c')]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "foo") + #expect(toolCall.function.arguments["msg"] == .string("a, b, c")) + } + + @Test("Test Python Tool Call Parser - Nested Parens In String") + func testPythonToolCallParserNestedParensInString() throws { + let parser = PythonToolCallParser( + 
startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + #"<|tool_call_start|>[fn(expr="(a, b), c")]<|tool_call_end|>"# + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "fn") + #expect(toolCall.function.arguments["expr"] == .string("(a, b), c")) + } + + @Test("Test Python Tool Call Parser - Whitespace And Noise") + func testPythonToolCallParserWhitespaceAndNoise() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|> [ , foo(a=1) , ] <|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "foo") + #expect(toolCall.function.arguments["a"] == .int(1)) + } + + @Test("Test Python Tool Call Parser - Location With Comma") + func testPythonToolCallParserLocationWithComma() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[get_current_weather(location=\"Paris, France\")]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.name == "get_current_weather") + #expect(toolCall.function.arguments["location"] == .string("Paris, France")) + } + + @Test("Test Python Tool Call Parser - All Value Types") + func testPythonToolCallParserAllValueTypes() throws { + let parser = PythonToolCallParser( + startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>") + let content = + "<|tool_call_start|>[fn(count=3, pi=3.14, ok=true, off=false, none=null)]<|tool_call_end|>" + + let toolCall = try #require(parser.parse(content: content, tools: nil)) + + #expect(toolCall.function.arguments["count"] == .int(3)) + #expect(toolCall.function.arguments["pi"] == .double(3.14)) + #expect(toolCall.function.arguments["ok"] == .bool(true)) + 
#expect(toolCall.function.arguments["off"] == .bool(false)) + #expect(toolCall.function.arguments["none"] == .null) + } + + // MARK: - LFM2 Format Tests + + @Test("Test LFM2 Format via ToolCallProcessor") + func testLFM2FormatProcessor() throws { + let processor = ToolCallProcessor(format: .lfm2) + let content = + "<|tool_call_start|>[get_status(id=\"abc123\")]<|tool_call_end|>" + + _ = processor.processChunk(content) + + #expect(processor.toolCalls.count == 1) + let toolCall = try #require(processor.toolCalls.first) + #expect(toolCall.function.name == "get_status") + #expect(toolCall.function.arguments["id"] == .string("abc123")) + } + + @Test("Test LFM2 Format via ToolCallProcessor - Calculator") + func testLFM2FormatProcessorCalculator() throws { + let processor = ToolCallProcessor(format: .lfm2) + let content = + "<|tool_call_start|>[calculator(expression=\"2+2\")]<|tool_call_end|>" + + _ = processor.processChunk(content) + + #expect(processor.toolCalls.count == 1) + let toolCall = try #require(processor.toolCalls.first) + #expect(toolCall.function.name == "calculator") + #expect(toolCall.function.arguments["expression"] == .string("2+2")) + } + // MARK: - GLM4 Format Tests @Test("Test GLM4 Tool Call Parser")