diff --git a/packages/types/src/providers/openai.ts b/packages/types/src/providers/openai.ts
index 358a4c55b1..0bff8aea3f 100644
--- a/packages/types/src/providers/openai.ts
+++ b/packages/types/src/providers/openai.ts
@@ -436,6 +436,7 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
 	supportsPromptCache: false,
 	inputPrice: 0,
 	outputPrice: 0,
+	supportsNativeTools: true,
 }
 
 // https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts
index 6ec43a6dfa..452664e7dd 100644
--- a/src/api/providers/__tests__/openai.spec.ts
+++ b/src/api/providers/__tests__/openai.spec.ts
@@ -157,6 +157,55 @@ describe("OpenAiHandler", () => {
 		expect(usageChunk?.outputTokens).toBe(5)
 	})
 
+	it("should handle tool calls in non-streaming mode", async () => {
+		mockCreate.mockResolvedValueOnce({
+			choices: [
+				{
+					message: {
+						role: "assistant",
+						content: null,
+						tool_calls: [
+							{
+								id: "call_1",
+								type: "function",
+								function: {
+									name: "test_tool",
+									arguments: '{"arg":"value"}',
+								},
+							},
+						],
+					},
+					finish_reason: "tool_calls",
+				},
+			],
+			usage: {
+				prompt_tokens: 10,
+				completion_tokens: 5,
+				total_tokens: 15,
+			},
+		})
+
+		const handler = new OpenAiHandler({
+			...mockOptions,
+			openAiStreamingEnabled: false,
+		})
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: '{"arg":"value"}',
+		})
+	})
+
 	it("should handle streaming responses", async () => {
 		const stream = handler.createMessage(systemPrompt, messages)
 		const chunks: any[] = []
@@ -170,6 +219,66 @@ describe("OpenAiHandler", () => {
 		expect(textChunks[0].text).toBe("Test response")
 	})
 
+	it("should handle tool calls in streaming responses", async () => {
+		mockCreate.mockImplementation(async (options) => {
+			return {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_1",
+											function: { name: "test_tool", arguments: "" },
+										},
+									],
+								},
+								finish_reason: null,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [{ index: 0, function: { arguments: '{"arg":' } }],
+								},
+								finish_reason: null,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [{ index: 0, function: { arguments: '"value"}' } }],
+								},
+								finish_reason: "tool_calls",
+							},
+						],
+					}
+				},
+			}
+		})
+
+		const stream = handler.createMessage(systemPrompt, messages)
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: '{"arg":"value"}',
+		})
+	})
+
 	it("should include reasoning_effort when reasoning effort is enabled", async () => {
 		const reasoningOptions: ApiHandlerOptions = {
 			...mockOptions,
@@ -618,6 +727,58 @@ describe("OpenAiHandler", () => {
 		)
 	})
 
+	it("should handle tool calls with O3 model in streaming mode", async () => {
+		const o3Handler = new OpenAiHandler(o3Options)
+
+		mockCreate.mockImplementation(async (options) => {
+			return {
+				[Symbol.asyncIterator]: async function* () {
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [
+										{
+											index: 0,
+											id: "call_1",
+											function: { name: "test_tool", arguments: "" },
+										},
+									],
+								},
+								finish_reason: null,
+							},
+						],
+					}
+					yield {
+						choices: [
+							{
+								delta: {
+									tool_calls: [{ index: 0, function: { arguments: "{}" } }],
+								},
+								finish_reason: "tool_calls",
+							},
+						],
+					}
+				},
+			}
+		})
+
+		const stream = o3Handler.createMessage("system", [])
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: "{}",
+		})
+	})
+
 	it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
 		const o3Handler = new OpenAiHandler({
 			...o3Options,
@@ -705,6 +866,55 @@ describe("OpenAiHandler", () => {
 		expect(callArgs).not.toHaveProperty("stream")
 	})
 
+	it("should handle tool calls with O3 model in non-streaming mode", async () => {
+		const o3Handler = new OpenAiHandler({
+			...o3Options,
+			openAiStreamingEnabled: false,
+		})
+
+		mockCreate.mockResolvedValueOnce({
+			choices: [
+				{
+					message: {
+						role: "assistant",
+						content: null,
+						tool_calls: [
+							{
+								id: "call_1",
+								type: "function",
+								function: {
+									name: "test_tool",
+									arguments: "{}",
+								},
+							},
+						],
+					},
+					finish_reason: "tool_calls",
+				},
+			],
+			usage: {
+				prompt_tokens: 10,
+				completion_tokens: 5,
+				total_tokens: 15,
+			},
+		})
+
+		const stream = o3Handler.createMessage("system", [])
+		const chunks: any[] = []
+		for await (const chunk of stream) {
+			chunks.push(chunk)
+		}
+
+		const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
+		expect(toolCallChunks).toHaveLength(1)
+		expect(toolCallChunks[0]).toEqual({
+			type: "tool_call",
+			id: "call_1",
+			name: "test_tool",
+			arguments: "{}",
+		})
+	})
+
 	it("should use default temperature of 0 when not specified for O3 models", async () => {
 		const o3Handler = new OpenAiHandler({
 			...o3Options,
diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index 2a240510a2..51db85410e 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -90,6 +90,8 @@ export abstract class BaseOpenAiCompatibleProvider
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
+			...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+			...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 		}
 
 		try {
@@ -115,6 +117,8 @@ export abstract class BaseOpenAiCompatibleProvider
 				}) as const,
 			)
 
+		const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
+
 		for await (const chunk of stream) {
 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const chunkAny = chunk as any
@@ -125,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider
 			}
 
 			const delta = chunk.choices?.[0]?.delta
+			const finishReason = chunk.choices?.[0]?.finish_reason
 
 			if (delta?.content) {
 				for (const processedChunk of matcher.update(delta.content)) {
@@ -139,6 +144,37 @@ export abstract class BaseOpenAiCompatibleProvider
 				}
 			}
 
+			if (delta?.tool_calls) {
+				for (const toolCall of delta.tool_calls) {
+					const index = toolCall.index
+					const existing = toolCallAccumulator.get(index)
+
+					if (existing) {
+						if (toolCall.function?.arguments) {
+							existing.arguments += toolCall.function.arguments
+						}
+					} else {
+						toolCallAccumulator.set(index, {
+							id: toolCall.id || "",
+							name: toolCall.function?.name || "",
+							arguments: toolCall.function?.arguments || "",
+						})
+					}
+				}
+			}
+
+			if (finishReason === "tool_calls") {
+				for (const toolCall of toolCallAccumulator.values()) {
+					yield {
+						type: "tool_call",
+						id: toolCall.id,
+						name: toolCall.name,
+						arguments: toolCall.arguments,
+					}
+				}
+				toolCallAccumulator.clear()
+			}
+
 			if (chunk.usage) {
 				yield {
 					type: "usage",
diff --git a/src/api/providers/base-provider.ts b/src/api/providers/base-provider.ts
index 1abbf5f558..a0611a7b3f 100644
--- a/src/api/providers/base-provider.ts
+++ b/src/api/providers/base-provider.ts
@@ -18,6 +18,75 @@ export abstract class BaseProvider implements ApiHandler {
 
 	abstract getModel(): { id: string; info: ModelInfo }
 
+	/**
+	 * Converts an array of tools to be compatible with OpenAI's strict mode.
+	 * Filters for function tools and applies schema conversion to their parameters.
+	 */
+	protected convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined {
+		if (!tools) {
+			return undefined
+		}
+
+		return tools.map((tool) =>
+			tool.type === "function"
+				? {
+						...tool,
+						function: {
+							...tool.function,
+							parameters: this.convertToolSchemaForOpenAI(tool.function.parameters),
+						},
+					}
+				: tool,
+		)
+	}
+
+	/**
+	 * Converts tool schemas to be compatible with OpenAI's strict mode by:
+	 * - Ensuring all properties are in the required array (strict mode requirement)
+	 * - Converting nullable types (["type", "null"]) to non-nullable ("type")
+	 * - Recursively processing nested objects and arrays
+	 *
+	 * This matches the behavior of ensureAllRequired in openai-native.ts
+	 */
+	protected convertToolSchemaForOpenAI(schema: any): any {
+		if (!schema || typeof schema !== "object" || schema.type !== "object") {
+			return schema
+		}
+
+		const result = { ...schema }
+
+		if (result.properties) {
+			const allKeys = Object.keys(result.properties)
+			// OpenAI strict mode requires ALL properties to be in required array
+			result.required = allKeys
+
+			// Recursively process nested objects and convert nullable types
+			const newProps = { ...result.properties }
+			for (const key of allKeys) {
+				const prop = newProps[key]
+
+				// Handle nullable types by removing null
+				if (prop && Array.isArray(prop.type) && prop.type.includes("null")) {
+					const nonNullTypes = prop.type.filter((t: string) => t !== "null")
+					prop.type = nonNullTypes.length === 1 ? nonNullTypes[0] : nonNullTypes
+				}
+
+				// Recursively process nested objects
+				if (prop && prop.type === "object") {
+					newProps[key] = this.convertToolSchemaForOpenAI(prop)
+				} else if (prop && prop.type === "array" && prop.items?.type === "object") {
+					newProps[key] = {
+						...prop,
+						items: this.convertToolSchemaForOpenAI(prop.items),
+					}
+				}
+			}
+			result.properties = newProps
+		}
+
+		return result
+	}
+
 	/**
 	 * Default token counting implementation using tiktoken.
 	 * Providers can override this to use their native token counting endpoints.
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 6b847be2d0..79d65e82e2 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -95,7 +95,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const ark = modelUrl.includes(".volces.com")
 
 		if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
-			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, metadata)
 			return
 		}
 
@@ -164,6 +164,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				stream: true as const,
 				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 				...(reasoning && reasoning),
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// Add max_tokens if needed
@@ -189,9 +191,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			)
 
 			let lastUsage
+			const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
 
 			for await (const chunk of stream) {
 				const delta = chunk.choices?.[0]?.delta ?? {}
+				const finishReason = chunk.choices?.[0]?.finish_reason
 
 				if (delta.content) {
 					for (const chunk of matcher.update(delta.content)) {
@@ -205,6 +209,38 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 						text: (delta.reasoning_content as string | undefined) || "",
 					}
 				}
+
+				if (delta.tool_calls) {
+					for (const toolCall of delta.tool_calls) {
+						const index = toolCall.index
+						const existing = toolCallAccumulator.get(index)
+
+						if (existing) {
+							if (toolCall.function?.arguments) {
+								existing.arguments += toolCall.function.arguments
+							}
+						} else {
+							toolCallAccumulator.set(index, {
+								id: toolCall.id || "",
+								name: toolCall.function?.name || "",
+								arguments: toolCall.function?.arguments || "",
+							})
+						}
+					}
+				}
+
+				if (finishReason === "tool_calls") {
+					for (const toolCall of toolCallAccumulator.values()) {
+						yield {
+							type: "tool_call",
+							id: toolCall.id,
+							name: toolCall.name,
+							arguments: toolCall.arguments,
+						}
+					}
+					toolCallAccumulator.clear()
+				}
+
 				if (chunk.usage) {
 					lastUsage = chunk.usage
 				}
@@ -225,6 +261,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					: enabledLegacyFormat
 						? [systemMessage, ...convertToSimpleMessages(messages)]
 						: [systemMessage, ...convertToOpenAiMessages(messages)],
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// Add max_tokens if needed
@@ -240,9 +278,24 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				throw handleOpenAIError(error, this.providerName)
 			}
 
+			const message = response.choices?.[0]?.message
+
+			if (message?.tool_calls) {
+				for (const toolCall of message.tool_calls) {
+					if (toolCall.type === "function") {
+						yield {
+							type: "tool_call",
+							id: toolCall.id,
+							name: toolCall.function.name,
+							arguments: toolCall.function.arguments,
+						}
+					}
+				}
+			}
+
 			yield {
 				type: "text",
-				text: response.choices?.[0]?.message.content || "",
+				text: message?.content || "",
 			}
 
 			yield this.processUsageMetrics(response.usage, modelInfo)
@@ -304,6 +357,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		modelId: string,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
+		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
 		const modelInfo = this.getModel().info
 		const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
@@ -324,6 +378,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
 				reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined,
 				temperature: undefined,
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// O3 family models do not support the deprecated max_tokens parameter
@@ -354,6 +410,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				],
 				reasoning_effort: modelInfo.reasoningEffort as "low" | "medium" | "high" | undefined,
 				temperature: undefined,
+				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
 			}
 
 			// O3 family models do not support the deprecated max_tokens parameter
@@ -371,22 +429,73 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				throw handleOpenAIError(error, this.providerName)
 			}
 
+			const message = response.choices?.[0]?.message
+			if (message?.tool_calls) {
+				for (const toolCall of message.tool_calls) {
+					if (toolCall.type === "function") {
+						yield {
+							type: "tool_call",
+							id: toolCall.id,
+							name: toolCall.function.name,
+							arguments: toolCall.function.arguments,
+						}
+					}
+				}
+			}
+
 			yield {
 				type: "text",
-				text: response.choices?.[0]?.message.content || "",
+				text: message?.content || "",
 			}
 			yield this.processUsageMetrics(response.usage)
 		}
 	}
 
 	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+		const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()
+
 		for await (const chunk of stream) {
 			const delta = chunk.choices?.[0]?.delta
-			if (delta?.content) {
-				yield {
-					type: "text",
-					text: delta.content,
+			const finishReason = chunk.choices?.[0]?.finish_reason
+
+			if (delta) {
+				if (delta.content) {
+					yield {
+						type: "text",
+						text: delta.content,
+					}
+				}
+
+				if (delta.tool_calls) {
+					for (const toolCall of delta.tool_calls) {
+						const index = toolCall.index
+						const existing = toolCallAccumulator.get(index)
+
+						if (existing) {
+							if (toolCall.function?.arguments) {
+								existing.arguments += toolCall.function.arguments
+							}
+						} else {
+							toolCallAccumulator.set(index, {
+								id: toolCall.id || "",
+								name: toolCall.function?.name || "",
+								arguments: toolCall.function?.arguments || "",
+							})
+						}
+					}
+				}
+			}
+
+			if (finishReason === "tool_calls") {
+				for (const toolCall of toolCallAccumulator.values()) {
+					yield {
+						type: "tool_call",
+						id: toolCall.id,
+						name: toolCall.name,
+						arguments: toolCall.arguments,
+					}
 				}
+				toolCallAccumulator.clear()
 			}
 
 			if (chunk.usage) {
diff --git a/src/api/transform/model-params.ts b/src/api/transform/model-params.ts
index 5e9d9d844e..22b43ba8f5 100644
--- a/src/api/transform/model-params.ts
+++ b/src/api/transform/model-params.ts
@@ -42,6 +42,7 @@ type BaseModelParams = {
 	reasoningEffort: ReasoningEffortExtended | undefined
 	reasoningBudget: number | undefined
 	verbosity: VerbosityLevel | undefined
+	tools?: boolean
 }
 
 type AnthropicModelParams = {
@@ -160,6 +161,7 @@ export function getModelParams({
 			format,
 			...params,
 			reasoning: getOpenAiReasoning({ model, reasoningBudget, reasoningEffort, settings }),
+			tools: model.supportsNativeTools,
 		}
 	} else if (format === "gemini") {
 		return {
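
Note on consuming the new chunk type (illustrative only, not part of the patch): each handler above emits tool calls as { type: "tool_call", id, name, arguments }, where arguments is the JSON string accumulated from streamed deltas. A minimal sketch of a hypothetical consumer follows; the collectToolCalls helper and the loose AsyncIterable<any> typing are assumptions for illustration, not APIs defined by this diff.

	// Sketch: gather "tool_call" chunks from an ApiStream-like async iterable
	// (e.g. the stream returned by handler.createMessage in the tests above).
	async function collectToolCalls(stream: AsyncIterable<any>) {
		const calls: Array<{ id: string; name: string; args: unknown }> = []
		for await (const chunk of stream) {
			if (chunk.type === "tool_call") {
				// `arguments` arrives as a JSON string; parse it before dispatching to a tool.
				calls.push({ id: chunk.id, name: chunk.name, args: JSON.parse(chunk.arguments) })
			}
		}
		return calls
	}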