Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ export const openAiModelInfoSaneDefaults: ModelInfo = {
supportsPromptCache: false,
inputPrice: 0,
outputPrice: 0,
supportsNativeTools: true,
}

// https://learn.microsoft.com/en-us/azure/ai-services/openai/api-version-deprecation
Expand Down
210 changes: 210 additions & 0 deletions src/api/providers/__tests__/openai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,55 @@ describe("OpenAiHandler", () => {
expect(usageChunk?.outputTokens).toBe(5)
})

it("should handle tool calls in non-streaming mode", async () => {
mockCreate.mockResolvedValueOnce({
choices: [
{
message: {
role: "assistant",
content: null,
tool_calls: [
{
id: "call_1",
type: "function",
function: {
name: "test_tool",
arguments: '{"arg":"value"}',
},
},
],
},
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
})

const handler = new OpenAiHandler({
...mockOptions,
openAiStreamingEnabled: false,
})

const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: '{"arg":"value"}',
})
})

it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
Expand All @@ -170,6 +219,66 @@ describe("OpenAiHandler", () => {
expect(textChunks[0].text).toBe("Test response")
})

it("should handle tool calls in streaming responses", async () => {
mockCreate.mockImplementation(async (options) => {
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_1",
function: { name: "test_tool", arguments: "" },
},
],
},
finish_reason: null,
},
],
}
yield {
choices: [
{
delta: {
tool_calls: [{ index: 0, function: { arguments: '{"arg":' } }],
},
finish_reason: null,
},
],
}
yield {
choices: [
{
delta: {
tool_calls: [{ index: 0, function: { arguments: '"value"}' } }],
},
finish_reason: "tool_calls",
},
],
}
},
}
})

const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: '{"arg":"value"}',
})
})

it("should include reasoning_effort when reasoning effort is enabled", async () => {
const reasoningOptions: ApiHandlerOptions = {
...mockOptions,
Expand Down Expand Up @@ -618,6 +727,58 @@ describe("OpenAiHandler", () => {
)
})

it("should handle tool calls with O3 model in streaming mode", async () => {
const o3Handler = new OpenAiHandler(o3Options)

mockCreate.mockImplementation(async (options) => {
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: "call_1",
function: { name: "test_tool", arguments: "" },
},
],
},
finish_reason: null,
},
],
}
yield {
choices: [
{
delta: {
tool_calls: [{ index: 0, function: { arguments: "{}" } }],
},
finish_reason: "tool_calls",
},
],
}
},
}
})

const stream = o3Handler.createMessage("system", [])
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: "{}",
})
})

it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
Expand Down Expand Up @@ -705,6 +866,55 @@ describe("OpenAiHandler", () => {
expect(callArgs).not.toHaveProperty("stream")
})

it("should handle tool calls with O3 model in non-streaming mode", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
openAiStreamingEnabled: false,
})

mockCreate.mockResolvedValueOnce({
choices: [
{
message: {
role: "assistant",
content: null,
tool_calls: [
{
id: "call_1",
type: "function",
function: {
name: "test_tool",
arguments: "{}",
},
},
],
},
finish_reason: "tool_calls",
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
})

const stream = o3Handler.createMessage("system", [])
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
id: "call_1",
name: "test_tool",
arguments: "{}",
})
})

it("should use default temperature of 0 when not specified for O3 models", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
Expand Down
36 changes: 36 additions & 0 deletions src/api/providers/base-openai-compatible-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
...(metadata?.tools && { tools: metadata.tools }),
...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
}

try {
Expand All @@ -115,6 +117,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}) as const,
)

const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()

for await (const chunk of stream) {
// Check for provider-specific error responses (e.g., MiniMax base_resp)
const chunkAny = chunk as any
Expand All @@ -125,6 +129,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}

const delta = chunk.choices?.[0]?.delta
const finishReason = chunk.choices?.[0]?.finish_reason

if (delta?.content) {
for (const processedChunk of matcher.update(delta.content)) {
Expand All @@ -139,6 +144,37 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}
}

if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
const index = toolCall.index
const existing = toolCallAccumulator.get(index)

if (existing) {
if (toolCall.function?.arguments) {
existing.arguments += toolCall.function.arguments
}
} else {
toolCallAccumulator.set(index, {
id: toolCall.id || "",
name: toolCall.function?.name || "",
arguments: toolCall.function?.arguments || "",
})
}
}
}

if (finishReason === "tool_calls") {
for (const toolCall of toolCallAccumulator.values()) {
yield {
type: "tool_call",
id: toolCall.id,
name: toolCall.name,
arguments: toolCall.arguments,
}
}
toolCallAccumulator.clear()
}

if (chunk.usage) {
yield {
type: "usage",
Expand Down
Loading
Loading