1 change: 1 addition & 0 deletions packages/types/src/providers/openai.ts
@@ -436,6 +436,7 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
supportsPromptCache: false,
inputPrice: 0,
outputPrice: 0,
supportsNativeTools: true,
}

// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
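Note: the one-line default change above enables native tool-calling for OpenAI-compatible endpoints. A minimal sketch of how a caller might gate tool attachment on the new flag, assuming a trimmed-down `ModelInfoLike` shape (`shouldAttachTools` is an illustrative helper, not part of this PR):

```ts
// Trimmed stand-in for the full ModelInfo interface in packages/types.
interface ModelInfoLike {
	supportsNativeTools?: boolean
}

// Only attach `tools` to a request when the model advertises native support.
function shouldAttachTools(info: ModelInfoLike, tools?: unknown[]): boolean {
	return info.supportsNativeTools === true && Array.isArray(tools) && tools.length > 0
}
```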
210 changes: 210 additions & 0 deletions src/api/providers/__tests__/openai.spec.ts
@@ -157,6 +157,55 @@ describe("OpenAiHandler", () => {
expect(usageChunk?.outputTokens).toBe(5)
})

it("should handle tool calls in non-streaming mode", async () => {
mockCreate.mockResolvedValueOnce({
choices: [
{
message: {
role: "assistant",
content: null,
tool_calls: [
{
id: "call_1",
type: "function",
function: {
name: "test_tool",
arguments: '{"arg":"value"}',
},
},
],
},
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
})

const handler = new OpenAiHandler({
...mockOptions,
openAiStreamingEnabled: false,
})

const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: '{"arg":"value"}',
})
})

it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
@@ -170,6 +219,66 @@ describe("OpenAiHandler", () => {
expect(textChunks[0].text).toBe("Test response")
})

it("should handle tool calls in streaming responses", async () => {
mockCreate.mockImplementation(async (options) => {
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_1",
function: { name: "test_tool", arguments: "" },
},
],
},
finish_reason: null,
},
],
}
yield {
choices: [
{
delta: {
tool_calls: [{ index: 0, function: { arguments: '{"arg":' } }],
},
finish_reason: null,
},
],
}
yield {
choices: [
{
delta: {
tool_calls: [{ index: 0, function: { arguments: '"value"}' } }],
},
finish_reason: "tool_calls",
},
],
}
},
}
})

const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: '{"arg":"value"}',
})
})

it("should include reasoning_effort when reasoning effort is enabled", async () => {
const reasoningOptions: ApiHandlerOptions = {
...mockOptions,
@@ -618,6 +727,58 @@ describe("OpenAiHandler", () => {
)
})

it("should handle tool calls with O3 model in streaming mode", async () => {
const o3Handler = new OpenAiHandler(o3Options)

mockCreate.mockImplementation(async (options) => {
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_1",
function: { name: "test_tool", arguments: "" },
},
],
},
finish_reason: null,
},
],
}
yield {
choices: [
{
delta: {
tool_calls: [{ index: 0, function: { arguments: "{}" } }],
},
finish_reason: "tool_calls",
},
],
}
},
}
})

const stream = o3Handler.createMessage("system", [])
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: "{}",
})
})

it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
@@ -705,6 +866,55 @@ describe("OpenAiHandler", () => {
expect(callArgs).not.toHaveProperty("stream")
})

it("should handle tool calls with O3 model in non-streaming mode", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
openAiStreamingEnabled: false,
})

mockCreate.mockResolvedValueOnce({
choices: [
{
message: {
role: "assistant",
content: null,
tool_calls: [
{
id: "call_1",
type: "function",
function: {
name: "test_tool",
arguments: "{}",
},
},
],
},
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
})

const stream = o3Handler.createMessage("system", [])
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: "{}",
})
})

it("should use default temperature of 0 when not specified for O3 models", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
36 changes: 36 additions & 0 deletions src/api/providers/base-openai-compatible-provider.ts
@@ -90,6 +90,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
}

try {
@@ -115,6 +117,8 @@
}) as const,
)

const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()

for await (const chunk of stream) {
// Check for provider-specific error responses (e.g., MiniMax base_resp)
const chunkAny = chunk as any
@@ -125,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}

const delta = chunk.choices?.[0]?.delta
const finishReason = chunk.choices?.[0]?.finish_reason

if (delta?.content) {
for (const processedChunk of matcher.update(delta.content)) {
@@ -139,6 +144,37 @@
}
}

if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
const index = toolCall.index
const existing = toolCallAccumulator.get(index)

if (existing) {
if (toolCall.function?.arguments) {
existing.arguments += toolCall.function.arguments
}
} else {
toolCallAccumulator.set(index, {
id: toolCall.id || "",
name: toolCall.function?.name || "",
arguments: toolCall.function?.arguments || "",
})
}
}
}

if (finishReason === "tool_calls") {
for (const toolCall of toolCallAccumulator.values()) {
yield {
type: "tool_call",
id: toolCall.id,
name: toolCall.name,
arguments: toolCall.arguments,
}
}
toolCallAccumulator.clear()
}

if (chunk.usage) {
yield {
type: "usage",
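The handler above accumulates streamed tool-call deltas by `index` and flushes them once the model reports `finish_reason === "tool_calls"`. A self-contained sketch of the same pattern — the delta types are simplified stand-ins for the OpenAI SDK's streaming shapes, and `assembleToolCalls` is illustrative rather than the provider's actual API:

```ts
interface ToolCallDelta {
	index: number
	id?: string
	function?: { name?: string; arguments?: string }
}

interface StreamChunk {
	toolCalls?: ToolCallDelta[]
	finishReason?: string | null
}

function* assembleToolCalls(chunks: StreamChunk[]) {
	const acc = new Map<number, { id: string; name: string; arguments: string }>()

	for (const chunk of chunks) {
		for (const delta of chunk.toolCalls ?? []) {
			const existing = acc.get(delta.index)
			if (existing) {
				// Later deltas for the same index only append argument fragments.
				if (delta.function?.arguments) existing.arguments += delta.function.arguments
			} else {
				acc.set(delta.index, {
					id: delta.id ?? "",
					name: delta.function?.name ?? "",
					arguments: delta.function?.arguments ?? "",
				})
			}
		}
		if (chunk.finishReason === "tool_calls") {
			// Flush assembled calls once the model signals it is done emitting them.
			yield* acc.values()
			acc.clear()
		}
	}
}

// Mirrors the streaming test above: three deltas assemble into a single call.
const calls = [
	...assembleToolCalls([
		{ toolCalls: [{ index: 0, id: "call_1", function: { name: "test_tool", arguments: "" } }] },
		{ toolCalls: [{ index: 0, function: { arguments: '{"arg":' } }] },
		{ toolCalls: [{ index: 0, function: { arguments: '"value"}' } }], finishReason: "tool_calls" },
	]),
]
// calls[0] => { id: "call_1", name: "test_tool", arguments: '{"arg":"value"}' }
```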
69 changes: 69 additions & 0 deletions src/api/providers/base-provider.ts
@@ -18,6 +18,75 @@ export abstract class BaseProvider implements ApiHandler {

abstract getModel(): { id: string; info: ModelInfo }

/**
* Converts an array of tools to be compatible with OpenAI's strict mode.
* Filters for function tools and applies schema conversion to their parameters.
*/
protected convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined {
if (!tools) {
return undefined
}

return tools.map((tool) =>
tool.type === "function"
? {
...tool,
function: {
...tool.function,
parameters: this.convertToolSchemaForOpenAI(tool.function.parameters),
},
}
: tool,
)
}

/**
* Converts tool schemas to be compatible with OpenAI's strict mode by:
* - Ensuring all properties are in the required array (strict mode requirement)
* - Converting nullable types (["type", "null"]) to non-nullable ("type")
* - Recursively processing nested objects and arrays
*
* This matches the behavior of ensureAllRequired in openai-native.ts
*/
protected convertToolSchemaForOpenAI(schema: any): any {
if (!schema || typeof schema !== "object" || schema.type !== "object") {
return schema
}

const result = { ...schema }

if (result.properties) {
const allKeys = Object.keys(result.properties)
// OpenAI strict mode requires ALL properties to be in required array
result.required = allKeys

// Recursively process nested objects and convert nullable types
const newProps = { ...result.properties }
for (const key of allKeys) {
const prop = newProps[key]

// Handle nullable types by removing null
if (prop && Array.isArray(prop.type) && prop.type.includes("null")) {
const nonNullTypes = prop.type.filter((t: string) => t !== "null")
prop.type = nonNullTypes.length === 1 ? nonNullTypes[0] : nonNullTypes
}

// Recursively process nested objects
if (prop && prop.type === "object") {
newProps[key] = this.convertToolSchemaForOpenAI(prop)
} else if (prop && prop.type === "array" && prop.items?.type === "object") {
newProps[key] = {
...prop,
items: this.convertToolSchemaForOpenAI(prop.items),
}
}
}
result.properties = newProps
}

return result
}

/**
* Default token counting implementation using tiktoken.
* Providers can override this to use their native token counting endpoints.
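To make `convertToolSchemaForOpenAI` concrete, here is a hedged before/after. Given a parameter schema with a nullable union type and a partially-required nested object (the `before` schema is invented for illustration), the method would produce the `after` shape: every property promoted into `required`, and `"null"` stripped from type unions:

```ts
const before = {
	type: "object",
	properties: {
		path: { type: "string" },
		limit: { type: ["number", "null"] }, // nullable union
		options: {
			type: "object",
			properties: { recursive: { type: "boolean" } },
			required: [], // nested object lists no required properties
		},
	},
	required: ["path"],
}

// Expected result of the conversion above (strict mode requires every
// property in `required`, at every nesting level).
const after = {
	type: "object",
	properties: {
		path: { type: "string" },
		limit: { type: "number" },
		options: {
			type: "object",
			properties: { recursive: { type: "boolean" } },
			required: ["recursive"],
		},
	},
	required: ["path", "limit", "options"],
}
```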