diff --git a/Libraries/MLXLMCommon/Tool/Parsers/PythonToolCallParser.swift b/Libraries/MLXLMCommon/Tool/Parsers/PythonToolCallParser.swift
new file mode 100644
index 00000000..e44c22d8
--- /dev/null
+++ b/Libraries/MLXLMCommon/Tool/Parsers/PythonToolCallParser.swift
@@ -0,0 +1,281 @@
+// Copyright © 2025 Apple Inc.
+
+import Foundation
+
+/// Parser for Python-style function calls: `[func(arg="value", arg2=123)]`
+/// Reference: https://github.com/ml-explore/mlx-lm/tree/main/mlx_lm/tool_parsers
+public struct PythonToolCallParser: ToolCallParser, Sendable {
+ public let startTag: String?
+ public let endTag: String?
+
+ public init(startTag: String, endTag: String) {
+ self.startTag = startTag
+ self.endTag = endTag
+ }
+
+ public func parse(content: String, tools: [[String: any Sendable]]?) -> ToolCall? {
+ var text = content
+
+ // Strip tags if present
+ if let start = startTag, let range = text.range(of: start) {
+ text = String(text[range.upperBound...])
+ }
+ if let end = endTag, let range = text.range(of: end) {
+ text = String(text[.. ToolCall? {
+ // Skip leading brackets/whitespace
+ var i = buffer.startIndex
+ while i < buffer.endIndex {
+ let ch = buffer[i]
+ if ch == "[" || ch == "]" || ch == "," || ch.isWhitespace {
+ i = buffer.index(after: i)
+ continue
+ }
+ break
+ }
+
+ if i >= buffer.endIndex {
+ buffer = ""
+ return nil
+ }
+
+ // Read function name
+ guard let nameEnd = readIdentifier(buffer, from: i) else {
+ return nil
+ }
+ let name = String(buffer[i ..< nameEnd])
+
+ // Find opening paren
+ var j = nameEnd
+ skipWhitespace(buffer, &j)
+ guard j < buffer.endIndex, buffer[j] == "(" else { return nil }
+
+ // Find matching closing paren
+ guard let closeIdx = findMatchingParen(in: buffer, openIndex: j) else {
+ return nil
+ }
+
+ // Parse arguments
+ let argsBody = String(buffer[buffer.index(after: j) ..< closeIdx])
+ let arguments = parseArgs(argsBody)
+
+ // Update buffer (consume parsed content)
+ var k = buffer.index(after: closeIdx)
+ skipWhitespace(buffer, &k)
+ if k < buffer.endIndex, buffer[k] == "," {
+ k = buffer.index(after: k)
+ }
+ while k < buffer.endIndex, buffer[k] == "]" || buffer[k].isWhitespace {
+ k = buffer.index(after: k)
+ }
+ buffer = String(buffer[k...])
+
+ return ToolCall(function: ToolCall.Function(name: name, arguments: arguments))
+ }
+
+ private func readIdentifier(_ s: String, from start: String.Index) -> String.Index? {
+ var i = start
+ guard i < s.endIndex, s[i].isLetter || s[i] == "_" else { return nil }
+ i = s.index(after: i)
+ while i < s.endIndex, s[i].isLetter || s[i].isNumber || s[i] == "_" {
+ i = s.index(after: i)
+ }
+ return i
+ }
+
+ private func skipWhitespace(_ s: String, _ i: inout String.Index) {
+ while i < s.endIndex, s[i].isWhitespace {
+ i = s.index(after: i)
+ }
+ }
+
+ private func findMatchingParen(in s: String, openIndex: String.Index) -> String.Index? {
+ var i = s.index(after: openIndex)
+ var depth = 1
+ var quote: Character?
+ var escape = false
+
+ while i < s.endIndex {
+ let ch = s[i]
+ if let q = quote {
+ if escape {
+ escape = false
+ } else if ch == "\\" {
+ escape = true
+ } else if ch == q {
+ quote = nil
+ }
+ } else {
+ switch ch {
+ case "'", "\"": quote = ch
+ case "(": depth += 1
+ case ")":
+ depth -= 1
+ if depth == 0 { return i }
+ default: break
+ }
+ }
+ i = s.index(after: i)
+ }
+ return nil
+ }
+
+ private func parseArgs(_ body: String) -> [String: any Sendable] {
+ var result: [String: any Sendable] = [:]
+ let parts = splitTopLevel(body, on: ",")
+
+ for part in parts {
+ let trimmed = part.trimmingCharacters(in: .whitespacesAndNewlines)
+ guard !trimmed.isEmpty,
+ let eqIdx = indexOfTopLevelEquals(in: trimmed)
+ else { continue }
+
+ let key = String(trimmed[.. any Sendable {
+ guard let first = s.first else { return "" }
+
+ // Quoted string
+ if first == "\"" || first == "'" {
+ return unquoteString(s)
+ }
+
+ // Boolean
+ let lower = s.lowercased()
+ if lower == "true" { return true }
+ if lower == "false" { return false }
+ if lower == "none" || lower == "null" { return NSNull() }
+
+ // Number
+ if let intVal = Int(s) { return intVal }
+ if let dblVal = Double(s) { return dblVal }
+
+ return s
+ }
+
+ private func unquoteString(_ s: String) -> String {
+ guard let q = s.first, q == "\"" || q == "'", s.last == q else { return s }
+ let inner = s.dropFirst().dropLast()
+ var result = ""
+ var escape = false
+ for ch in inner {
+ if escape {
+ switch ch {
+ case "n": result.append("\n")
+ case "t": result.append("\t")
+ case "r": result.append("\r")
+ case "\\": result.append("\\")
+ case "\"": result.append("\"")
+ case "'": result.append("'")
+ default: result.append(ch)
+ }
+ escape = false
+ } else if ch == "\\" {
+ escape = true
+ } else {
+ result.append(ch)
+ }
+ }
+ return result
+ }
+
+ private func splitTopLevel(_ s: String, on sep: Character) -> [String] {
+ var result: [String] = []
+ var current = ""
+ var depth = 0
+ var quote: Character?
+ var escape = false
+
+ for ch in s {
+ if let q = quote {
+ current.append(ch)
+ if escape {
+ escape = false
+ } else if ch == "\\" {
+ escape = true
+ } else if ch == q {
+ quote = nil
+ }
+ } else {
+ switch ch {
+ case "'", "\"":
+ quote = ch
+ current.append(ch)
+ case "(", "[", "{":
+ depth += 1
+ current.append(ch)
+ case ")", "]", "}":
+ depth = max(0, depth - 1)
+ current.append(ch)
+ default:
+ if ch == sep && depth == 0 {
+ result.append(current.trimmingCharacters(in: .whitespacesAndNewlines))
+ current = ""
+ } else {
+ current.append(ch)
+ }
+ }
+ }
+ }
+
+ let final = current.trimmingCharacters(in: .whitespacesAndNewlines)
+ if !final.isEmpty {
+ result.append(final)
+ }
+ return result
+ }
+
+ private func indexOfTopLevelEquals(in s: String) -> String.Index? {
+ var i = s.startIndex
+ var depthParen = 0
+ var depthBrace = 0
+ var depthBracket = 0
+ var quote: Character?
+ var escape = false
+
+ while i < s.endIndex {
+ let ch = s[i]
+ if let q = quote {
+ if escape {
+ escape = false
+ } else if ch == "\\" {
+ escape = true
+ } else if ch == q {
+ quote = nil
+ }
+ } else {
+ switch ch {
+ case "'", "\"": quote = ch
+ case "(": depthParen += 1
+ case ")": if depthParen > 0 { depthParen -= 1 }
+ case "[": depthBracket += 1
+ case "]": if depthBracket > 0 { depthBracket -= 1 }
+ case "{": depthBrace += 1
+ case "}": if depthBrace > 0 { depthBrace -= 1 }
+ case "=":
+ if depthParen == 0, depthBrace == 0, depthBracket == 0 {
+ return i
+ }
+ default: break
+ }
+ }
+ i = s.index(after: i)
+ }
+ return nil
+ }
+}
diff --git a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift
index 3b39bf60..101d19c4 100644
--- a/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift
+++ b/Libraries/MLXLMCommon/Tool/ToolCallFormat.swift
@@ -42,8 +42,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
/// Example: `{"name": "func", "arguments": {...}}`
case json
- /// LFM2 JSON format with model-specific tags.
- /// Example: `<|tool_call_start|>{"name": "func", "arguments": {...}}<|tool_call_end|>`
+ /// LFM2 Python-style format with model-specific tags.
+ /// Example: `<|tool_call_start|>[name(parameter="value")]<|tool_call_end|>`
case lfm2
/// XML function format used by Qwen3 Coder.
@@ -75,7 +75,8 @@ public enum ToolCallFormat: String, Sendable, Codable, CaseIterable {
case .json:
return JSONToolCallParser(startTag: "", endTag: "")
case .lfm2:
- return JSONToolCallParser(startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ return PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
case .xmlFunction:
return XMLFunctionParser()
case .glm4:
diff --git a/Tests/MLXLMTests/ToolTests.swift b/Tests/MLXLMTests/ToolTests.swift
index b2b312b8..3c03b57a 100644
--- a/Tests/MLXLMTests/ToolTests.swift
+++ b/Tests/MLXLMTests/ToolTests.swift
@@ -100,33 +100,6 @@ struct ToolTests {
#expect(toolCall.function.arguments["location"] == .string("Paris"))
}
- @Test("Test JSON Tool Call Parser - LFM2 Tags")
- func testJSONParserLFM2Tags() throws {
- let parser = JSONToolCallParser(
- startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
- let content =
- "<|tool_call_start|>{\"name\": \"search\", \"arguments\": {\"query\": \"swift programming\"}}<|tool_call_end|>"
-
- let toolCall = try #require(parser.parse(content: content, tools: nil))
-
- #expect(toolCall.function.name == "search")
- #expect(toolCall.function.arguments["query"] == .string("swift programming"))
- }
-
- @Test("Test LFM2 Format via ToolCallProcessor")
- func testLFM2FormatProcessor() throws {
- let processor = ToolCallProcessor(format: .lfm2)
- let content =
- "<|tool_call_start|>{\"name\": \"calculator\", \"arguments\": {\"expression\": \"2+2\"}}<|tool_call_end|>"
-
- _ = processor.processChunk(content)
-
- #expect(processor.toolCalls.count == 1)
- let toolCall = try #require(processor.toolCalls.first)
- #expect(toolCall.function.name == "calculator")
- #expect(toolCall.function.arguments["expression"] == .string("2+2"))
- }
-
// MARK: - XML Function Format Tests (Qwen3 Coder)
@Test("Test XML Function Parser - Qwen3 Coder Format")
@@ -168,6 +141,228 @@ struct ToolTests {
#expect(toolCall.function.arguments["enabled"] == .bool(true))
}
+ // MARK: - Python Function Format Tests
+
+ @Test("Test Python Tool Call Parser - Basic")
+ func testPythonToolCallParserBasic() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[get_weather(location=\"Paris\")]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "get_weather")
+ #expect(toolCall.function.arguments["location"] == .string("Paris"))
+ }
+
+ @Test("Test Python Tool Call Parser - Multiple Arguments")
+ func testPythonToolCallParserMultipleArgs() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[get_candidate_status(candidate_id=\"12345\", include_history=true)]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "get_candidate_status")
+ #expect(toolCall.function.arguments["candidate_id"] == .string("12345"))
+ #expect(toolCall.function.arguments["include_history"] == .bool(true))
+ }
+
+ @Test("Test Python Tool Call Parser - Numeric Arguments")
+ func testPythonToolCallParserNumericArgs() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[calculate(x=42, y=3.14, enabled=false)]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "calculate")
+ #expect(toolCall.function.arguments["x"] == .int(42))
+ #expect(toolCall.function.arguments["y"] == .double(3.14))
+ #expect(toolCall.function.arguments["enabled"] == .bool(false))
+ }
+
+ @Test("Test Python Tool Call Parser - Escaped Strings")
+ func testPythonToolCallParserEscapedStrings() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[search(query=\"hello\\nworld\")]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "search")
+ #expect(toolCall.function.arguments["query"] == .string("hello\nworld"))
+ }
+
+ @Test("Test Python Tool Call Parser - Single Quotes")
+ func testPythonToolCallParserSingleQuotes() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[search(query='single quoted')]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "search")
+ #expect(toolCall.function.arguments["query"] == .string("single quoted"))
+ }
+
+ @Test("Test Python Tool Call Parser - None/Null Value")
+ func testPythonToolCallParserNullValue() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[set_value(key=\"test\", value=None)]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "set_value")
+ #expect(toolCall.function.arguments["key"] == .string("test"))
+ #expect(toolCall.function.arguments["value"] == .null)
+ }
+
+ @Test("Test Python Tool Call Parser - Single Quotes With Escapes")
+ func testPythonToolCallParserSingleQuotesWithEscapes() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ #"<|tool_call_start|>[find(name='O\'Hara, "Alex"')]<|tool_call_end|>"#
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "find")
+ #expect(toolCall.function.arguments["name"] == .string("O'Hara, \"Alex\""))
+ }
+
+ @Test("Test Python Tool Call Parser - Without Brackets")
+ func testPythonToolCallParserWithoutBrackets() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>get_latest_news(query=\"Apple\", sortBy=\"popularity\")<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "get_latest_news")
+ #expect(toolCall.function.arguments["query"] == .string("Apple"))
+ #expect(toolCall.function.arguments["sortBy"] == .string("popularity"))
+ }
+
+ @Test("Test Python Tool Call Parser - Trailing Comma")
+ func testPythonToolCallParserTrailingComma() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[get_stock_price(symbol=\"AAPL\"), ]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "get_stock_price")
+ #expect(toolCall.function.arguments["symbol"] == .string("AAPL"))
+ }
+
+ @Test("Test Python Tool Call Parser - Commas Inside Quoted String")
+ func testPythonToolCallParserCommasInString() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[foo(msg='a, b, c')]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "foo")
+ #expect(toolCall.function.arguments["msg"] == .string("a, b, c"))
+ }
+
+ @Test("Test Python Tool Call Parser - Nested Parens In String")
+ func testPythonToolCallParserNestedParensInString() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ #"<|tool_call_start|>[fn(expr="(a, b), c")]<|tool_call_end|>"#
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "fn")
+ #expect(toolCall.function.arguments["expr"] == .string("(a, b), c"))
+ }
+
+ @Test("Test Python Tool Call Parser - Whitespace And Noise")
+ func testPythonToolCallParserWhitespaceAndNoise() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|> [ , foo(a=1) , ] <|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "foo")
+ #expect(toolCall.function.arguments["a"] == .int(1))
+ }
+
+ @Test("Test Python Tool Call Parser - Location With Comma")
+ func testPythonToolCallParserLocationWithComma() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[get_current_weather(location=\"Paris, France\")]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.name == "get_current_weather")
+ #expect(toolCall.function.arguments["location"] == .string("Paris, France"))
+ }
+
+ @Test("Test Python Tool Call Parser - All Value Types")
+ func testPythonToolCallParserAllValueTypes() throws {
+ let parser = PythonToolCallParser(
+ startTag: "<|tool_call_start|>", endTag: "<|tool_call_end|>")
+ let content =
+ "<|tool_call_start|>[fn(count=3, pi=3.14, ok=true, off=false, none=null)]<|tool_call_end|>"
+
+ let toolCall = try #require(parser.parse(content: content, tools: nil))
+
+ #expect(toolCall.function.arguments["count"] == .int(3))
+ #expect(toolCall.function.arguments["pi"] == .double(3.14))
+ #expect(toolCall.function.arguments["ok"] == .bool(true))
+ #expect(toolCall.function.arguments["off"] == .bool(false))
+ #expect(toolCall.function.arguments["none"] == .null)
+ }
+
+ // MARK: - LFM2 Format Tests
+
+ @Test("Test LFM2 Format via ToolCallProcessor")
+ func testLFM2FormatProcessor() throws {
+ let processor = ToolCallProcessor(format: .lfm2)
+ let content =
+ "<|tool_call_start|>[get_status(id=\"abc123\")]<|tool_call_end|>"
+
+ _ = processor.processChunk(content)
+
+ #expect(processor.toolCalls.count == 1)
+ let toolCall = try #require(processor.toolCalls.first)
+ #expect(toolCall.function.name == "get_status")
+ #expect(toolCall.function.arguments["id"] == .string("abc123"))
+ }
+
+ @Test("Test LFM2 Format via ToolCallProcessor - Calculator")
+ func testLFM2FormatProcessorCalculator() throws {
+ let processor = ToolCallProcessor(format: .lfm2)
+ let content =
+ "<|tool_call_start|>[calculator(expression=\"2+2\")]<|tool_call_end|>"
+
+ _ = processor.processChunk(content)
+
+ #expect(processor.toolCalls.count == 1)
+ let toolCall = try #require(processor.toolCalls.first)
+ #expect(toolCall.function.name == "calculator")
+ #expect(toolCall.function.arguments["expression"] == .string("2+2"))
+ }
+
// MARK: - GLM4 Format Tests
@Test("Test GLM4 Tool Call Parser")