
Commit bc6fad1

mrubens and daniel-lxs authored
Add native tool calling support to OpenAI-compatible (#9369)
* Add native tool calling support to OpenAI-compatible

* Fix OpenAI strict mode schema validation by adding converter methods to BaseProvider
  - Add convertToolsForOpenAI() and convertToolSchemaForOpenAI() methods to BaseProvider
  - These methods ensure all properties are in the required array and convert nullable types
  - Remove line_ranges from the required array in the read_file tool (the converter handles it)
  - Update OpenAiHandler and BaseOpenAiCompatibleProvider to use the helper methods
  - Eliminates code duplication across multiple tool usage sites
  - Fixes: OpenAI completion error: 400 Invalid schema for function 'read_file'

---------

Co-authored-by: daniel-lxs <[email protected]>
1 parent b0c254c commit bc6fad1
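
For context on the strict-mode fix described above, the sketch below shows the kind of transformation the new converter performs. The schema is hypothetical (it is not the repo's actual read_file schema); it only illustrates the two rules from the commit message: every property ends up in the required array, and nullable union types drop their "null" entry.

// Hypothetical tool parameter schema, before conversion (illustration only).
const before = {
	type: "object",
	properties: {
		path: { type: "string" },
		line_ranges: { type: ["array", "null"], items: { type: "string" } },
	},
	required: ["path"], // line_ranges deliberately omitted, as in the read_file change
}

// Shape that convertToolSchemaForOpenAI(before) would return, per the BaseProvider code in this commit:
const after = {
	type: "object",
	properties: {
		path: { type: "string" },
		line_ranges: { type: "array", items: { type: "string" } }, // "null" stripped from the union
	},
	required: ["path", "line_ranges"], // strict mode: every property listed as required
}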

File tree

6 files changed: +434, -7 lines


packages/types/src/providers/openai.ts

Lines changed: 1 addition & 0 deletions
@@ -436,6 +436,7 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
 	supportsPromptCache: false,
 	inputPrice: 0,
 	outputPrice: 0,
+	supportsNativeTools: true,
 }
 
 // https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
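
The only change here is the new supportsNativeTools flag on openAiModelInfoSaneDefaults. How call sites consume the flag is outside this diff; a minimal hedged sketch of the usual pattern, with assumed names rather than the repo's actual identifiers, would be:

// Minimal sketch (assumed names): only attach native tool definitions when the model advertises support.
type ModelInfoLike = { supportsNativeTools?: boolean }

function toolsForRequest<T>(info: ModelInfoLike, tools: T[] | undefined): T[] | undefined {
	return info.supportsNativeTools ? tools : undefined
}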

src/api/providers/__tests__/openai.spec.ts

Lines changed: 210 additions & 0 deletions
@@ -157,6 +157,55 @@ describe("OpenAiHandler", () => {
 		expect(usageChunk?.outputTokens).toBe(5)
 	})
 
+	it("should handle tool calls in non-streaming mode", async () => {
+		mockCreate.mockResolvedValueOnce({
+			choices: [
+				{
+					message: {
+						role: "assistant",
+						content: null,
+						tool_calls: [
+							{
+								id: "call_1",
+								type: "function",
+								function: {
+									name: "test_tool",
+									arguments: '{"arg":"value"}',
+								},
+							},
+						],
+					},
+					finish_reason: "tool_calls",
+				},
+			],
+			usage: {
+				prompt_tokens: 10,
+				completion_tokens: 5,
+				total_tokens: 15,
+			},
+		})
+
+		const handler = new OpenAiHandler({
+			...mockOptions,
+			openAiStreamingEnabled: false,
+		})
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: '{"arg":"value"}',
+		})
+	})
+
 	it("should handle streaming responses", async () => {
 		const stream = handler.createMessage(systemPrompt, messages)
 		const chunks: any[] = []

@@ -170,6 +219,66 @@ describe("OpenAiHandler", () => {
 		expect(textChunks[0].text).toBe("Test response")
 	})
 
+	it("should handle tool calls in streaming responses", async () => {
+		mockCreate.mockImplementation(async (options) => {
+			return {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_1",
+											function: { name: "test_tool", arguments: "" },
+										},
+									],
+								},
+								finish_reason: null,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [{ index: 0, function: { arguments: '{"arg":' } }],
+								},
+								finish_reason: null,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [{ index: 0, function: { arguments: '"value"}' } }],
+								},
+								finish_reason: "tool_calls",
+							},
+						],
+					}
+				},
+			}
+		})
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: '{"arg":"value"}',
+		})
+	})
+
 	it("should include reasoning_effort when reasoning effort is enabled", async () => {
 		const reasoningOptions: ApiHandlerOptions = {
 			...mockOptions,

@@ -618,6 +727,58 @@ describe("OpenAiHandler", () => {
 		)
 	})
 
+	it("should handle tool calls with O3 model in streaming mode", async () => {
+		const o3Handler = new OpenAiHandler(o3Options)
+
+		mockCreate.mockImplementation(async (options) => {
+			return {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_1",
+											function: { name: "test_tool", arguments: "" },
+										},
+									],
+								},
+								finish_reason: null,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [{ index: 0, function: { arguments: "{}" } }],
+								},
+								finish_reason: "tool_calls",
+							},
+						],
+					}
+				},
+			}
+		})
+
+		const stream = o3Handler.createMessage("system", [])
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: "{}",
+		})
+	})
+
 	it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
 		const o3Handler = new OpenAiHandler({
 			...o3Options,

@@ -705,6 +866,55 @@ describe("OpenAiHandler", () => {
 		expect(callArgs).not.toHaveProperty("stream")
 	})
 
+	it("should handle tool calls with O3 model in non-streaming mode", async () => {
+		const o3Handler = new OpenAiHandler({
+			...o3Options,
+			openAiStreamingEnabled: false,
+		})
+
+		mockCreate.mockResolvedValueOnce({
+			choices: [
+				{
+					message: {
+						role: "assistant",
+						content: null,
+						tool_calls: [
+							{
+								id: "call_1",
+								type: "function",
+								function: {
+									name: "test_tool",
+									arguments: "{}",
+								},
+							},
+						],
+					},
+					finish_reason: "tool_calls",
+				},
+			],
+			usage: {
+				prompt_tokens: 10,
+				completion_tokens: 5,
+				total_tokens: 15,
+			},
+		})
+
+		const stream = o3Handler.createMessage("system", [])
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: "{}",
+		})
+	})
+
 	it("should use default temperature of 0 when not specified for O3 models", async () => {
 		const o3Handler = new OpenAiHandler({
 			...o3Options,

src/api/providers/base-openai-compatible-provider.ts

Lines changed: 36 additions & 0 deletions
@@ -90,6 +90,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
+			...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}
 
 		try {

@@ -115,6 +117,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}) as const,
 		)
 
+		const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
+
 		for await (const chunk of stream) {
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const chunkAny = chunk as any

@@ -125,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 			}
 
 			const delta = chunk.choices?.[0]?.delta
+			const finishReason = chunk.choices?.[0]?.finish_reason
 
 			if (delta?.content) {
 				for (const processedChunk of matcher.update(delta.content)) {

@@ -139,6 +144,37 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 				}
 			}
 
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					const index = toolCall.index
+					const existing = toolCallAccumulator.get(index)
+
+					if (existing) {
+						if (toolCall.function?.arguments) {
+							existing.arguments += toolCall.function.arguments
+						}
+					} else {
+						toolCallAccumulator.set(index, {
+							id: toolCall.id || "",
+							name: toolCall.function?.name || "",
+							arguments: toolCall.function?.arguments || "",
+						})
+					}
+				}
+			}
+
+			if (finishReason === "tool_calls") {
+				for (const toolCall of toolCallAccumulator.values()) {
+					yield {
+						type: "tool_call",
+						id: toolCall.id,
+						name: toolCall.name,
+						arguments: toolCall.arguments,
+					}
+				}
+				toolCallAccumulator.clear()
+			}
+
 			if (chunk.usage) {
 				yield {
 					type: "usage",

src/api/providers/base-provider.ts

Lines changed: 69 additions & 0 deletions
@@ -18,6 +18,75 @@ export abstract class BaseProvider implements ApiHandler {
 
 	abstract getModel(): { id: string; info: ModelInfo }
 
+	/**
+	 * Converts an array of tools to be compatible with OpenAI's strict mode.
+	 * Filters for function tools and applies schema conversion to their parameters.
+	 */
+	protected convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined {
+		if (!tools) {
+			return undefined
+		}
+
+		return tools.map((tool) =>
+			tool.type === "function"
+				? {
+						...tool,
+						function: {
+							...tool.function,
+							parameters: this.convertToolSchemaForOpenAI(tool.function.parameters),
+						},
+					}
+				: tool,
+		)
+	}
+
+	/**
+	 * Converts tool schemas to be compatible with OpenAI's strict mode by:
+	 * - Ensuring all properties are in the required array (strict mode requirement)
+	 * - Converting nullable types (["type", "null"]) to non-nullable ("type")
+	 * - Recursively processing nested objects and arrays
+	 *
+	 * This matches the behavior of ensureAllRequired in openai-native.ts
+	 */
+	protected convertToolSchemaForOpenAI(schema: any): any {
+		if (!schema || typeof schema !== "object" || schema.type !== "object") {
+			return schema
+		}
+
+		const result = { ...schema }
+
+		if (result.properties) {
+			const allKeys = Object.keys(result.properties)
+			// OpenAI strict mode requires ALL properties to be in required array
+			result.required = allKeys
+
+			// Recursively process nested objects and convert nullable types
+			const newProps = { ...result.properties }
+			for (const key of allKeys) {
+				const prop = newProps[key]
+
+				// Handle nullable types by removing null
+				if (prop && Array.isArray(prop.type) && prop.type.includes("null")) {
+					const nonNullTypes = prop.type.filter((t: string) => t !== "null")
+					prop.type = nonNullTypes.length === 1 ? nonNullTypes[0] : nonNullTypes
+				}
+
+				// Recursively process nested objects
+				if (prop && prop.type === "object") {
+					newProps[key] = this.convertToolSchemaForOpenAI(prop)
+				} else if (prop && prop.type === "array" && prop.items?.type === "object") {
+					newProps[key] = {
+						...prop,
+						items: this.convertToolSchemaForOpenAI(prop.items),
+					}
+				}
+			}
+			result.properties = newProps
+		}
+
+		return result
+	}
+
 	/**
 	 * Default token counting implementation using tiktoken.
 	 * Providers can override this to use their native token counting endpoints.
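
As a companion to the doc comment above, the hypothetical schema below (invented field names, not a real tool in the repo) shows what the recursive branch does for an array of objects: the nested object also gets the all-required treatment, and its nullable field loses "null".

// Hypothetical nested schema (illustration only).
const nested = {
	type: "object",
	properties: {
		filters: {
			type: "array",
			items: {
				type: "object",
				properties: {
					field: { type: "string" },
					value: { type: ["string", "null"] },
				},
				required: ["field"],
			},
		},
	},
	required: [],
}

// Per the method above, convertToolSchemaForOpenAI(nested) would return a schema where:
// - the top-level required array becomes ["filters"]
// - items.required becomes ["field", "value"]
// - items.properties.value.type collapses from ["string", "null"] to "string"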

0 commit comments
