Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions src/api/providers/__tests__/minimax.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,13 @@ describe("MiniMaxHandler", () => {
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
// Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly
expect(firstChunk.value).toEqual({
type: "tool_call",
type: "tool_call_partial",
index: 0,
id: "tool-123",
name: "get_weather",
arguments: JSON.stringify({ city: "London" }),
arguments: undefined,
})
})
})
Expand Down
67 changes: 48 additions & 19 deletions src/api/providers/__tests__/openai.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -269,13 +269,31 @@ describe("OpenAiHandler", () => {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
// Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly
const toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
expect(toolCallPartialChunks).toHaveLength(3)
// First chunk has id and name
expect(toolCallPartialChunks[0]).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_1",
name: "test_tool",
arguments: '{"arg":"value"}',
arguments: "",
})
// Subsequent chunks have arguments
expect(toolCallPartialChunks[1]).toEqual({
type: "tool_call_partial",
index: 0,
id: undefined,
name: undefined,
arguments: '{"arg":',
})
expect(toolCallPartialChunks[2]).toEqual({
type: "tool_call_partial",
index: 0,
id: undefined,
name: undefined,
arguments: '"value"}',
})
})

Expand Down Expand Up @@ -318,11 +336,12 @@ describe("OpenAiHandler", () => {
chunks.push(chunk)
}

// Tool calls should still be yielded via the fallback mechanism
const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
// Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly
const toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
expect(toolCallPartialChunks).toHaveLength(1)
expect(toolCallPartialChunks[0]).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_fallback",
name: "fallback_tool",
arguments: '{"test":"fallback"}',
Expand Down Expand Up @@ -819,12 +838,21 @@ describe("OpenAiHandler", () => {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
// Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly
const toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
expect(toolCallPartialChunks).toHaveLength(2)
expect(toolCallPartialChunks[0]).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_1",
name: "test_tool",
arguments: "",
})
expect(toolCallPartialChunks[1]).toEqual({
type: "tool_call_partial",
index: 0,
id: undefined,
name: undefined,
arguments: "{}",
})
})
Expand Down Expand Up @@ -870,11 +898,12 @@ describe("OpenAiHandler", () => {
chunks.push(chunk)
}

// Tool calls should still be yielded via the fallback mechanism
const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0]).toEqual({
type: "tool_call",
// Provider now yields tool_call_partial chunks, NativeToolCallParser handles reassembly
const toolCallPartialChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
expect(toolCallPartialChunks).toHaveLength(1)
expect(toolCallPartialChunks[0]).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_o3_fallback",
name: "o3_fallback_tool",
arguments: '{"o3":"test"}',
Expand Down
80 changes: 49 additions & 31 deletions src/api/providers/__tests__/roo.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ describe("RooHandler", () => {
handler = new RooHandler(mockOptions)
})

it("should yield tool calls when finish_reason is tool_calls", async () => {
it("should yield raw tool call chunks when tool_calls present", async () => {
mockCreate.mockResolvedValueOnce({
[Symbol.asyncIterator]: async function* () {
yield {
Expand Down Expand Up @@ -689,14 +689,27 @@ describe("RooHandler", () => {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0].id).toBe("call_123")
expect(toolCallChunks[0].name).toBe("read_file")
expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts"}')
// Verify we get raw tool call chunks
const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")

expect(rawChunks).toHaveLength(2)
expect(rawChunks[0]).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_123",
name: "read_file",
arguments: '{"path":"',
})
expect(rawChunks[1]).toEqual({
type: "tool_call_partial",
index: 0,
id: undefined,
name: undefined,
arguments: 'test.ts"}',
})
})

it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => {
it("should yield raw tool call chunks even when finish_reason is not tool_calls", async () => {
mockCreate.mockResolvedValueOnce({
[Symbol.asyncIterator]: async function* () {
yield {
Expand All @@ -718,12 +731,11 @@ describe("RooHandler", () => {
},
],
}
// Stream ends without finish_reason being set to "tool_calls"
yield {
choices: [
{
delta: {},
finish_reason: "stop", // Different finish reason
finish_reason: "stop",
index: 0,
},
],
Expand All @@ -738,15 +750,19 @@ describe("RooHandler", () => {
chunks.push(chunk)
}

// Tool calls should still be yielded via the fallback mechanism
const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0].id).toBe("call_456")
expect(toolCallChunks[0].name).toBe("write_to_file")
expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts","content":"hello"}')
const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")

expect(rawChunks).toHaveLength(1)
expect(rawChunks[0]).toEqual({
type: "tool_call_partial",
index: 0,
id: "call_456",
name: "write_to_file",
arguments: '{"path":"test.ts","content":"hello"}',
})
})

it("should handle multiple tool calls", async () => {
it("should handle multiple tool calls with different indices", async () => {
mockCreate.mockResolvedValueOnce({
[Symbol.asyncIterator]: async function* () {
yield {
Expand Down Expand Up @@ -800,15 +816,16 @@ describe("RooHandler", () => {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(2)
expect(toolCallChunks[0].id).toBe("call_1")
expect(toolCallChunks[0].name).toBe("read_file")
expect(toolCallChunks[1].id).toBe("call_2")
expect(toolCallChunks[1].name).toBe("read_file")
const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")

expect(rawChunks).toHaveLength(2)
expect(rawChunks[0].index).toBe(0)
expect(rawChunks[0].id).toBe("call_1")
expect(rawChunks[1].index).toBe(1)
expect(rawChunks[1].id).toBe("call_2")
})

it("should accumulate tool call arguments across multiple chunks", async () => {
it("should emit raw chunks for streaming arguments", async () => {
mockCreate.mockResolvedValueOnce({
[Symbol.asyncIterator]: async function* () {
yield {
Expand Down Expand Up @@ -876,14 +893,15 @@ describe("RooHandler", () => {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(1)
expect(toolCallChunks[0].id).toBe("call_789")
expect(toolCallChunks[0].name).toBe("execute_command")
expect(toolCallChunks[0].arguments).toBe('{"command":"npm install"}')
const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")

expect(rawChunks).toHaveLength(3)
expect(rawChunks[0].arguments).toBe('{"command":"')
expect(rawChunks[1].arguments).toBe("npm install")
expect(rawChunks[2].arguments).toBe('"}')
})

it("should not yield empty tool calls when no tool calls present", async () => {
it("should not yield tool call chunks when no tool calls present", async () => {
mockCreate.mockResolvedValueOnce({
[Symbol.asyncIterator]: async function* () {
yield {
Expand All @@ -902,8 +920,8 @@ describe("RooHandler", () => {
chunks.push(chunk)
}

const toolCallChunks = chunks.filter((chunk) => chunk.type === "tool_call")
expect(toolCallChunks).toHaveLength(0)
const rawChunks = chunks.filter((chunk) => chunk.type === "tool_call_partial")
expect(rawChunks).toHaveLength(0)
})
})
})
45 changes: 5 additions & 40 deletions src/api/providers/base-openai-compatible-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}) as const,
)

const toolCallAccumulator = new Map<number, { id: string; name: string; arguments: string }>()

let lastUsage: OpenAI.CompletionUsage | undefined

for await (const chunk of stream) {
Expand All @@ -137,7 +135,6 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}

const delta = chunk.choices?.[0]?.delta
const finishReason = chunk.choices?.[0]?.finish_reason

if (delta?.content) {
for (const processedChunk of matcher.update(delta.content)) {
Expand All @@ -157,56 +154,24 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}
}

// Emit raw tool call chunks - NativeToolCallParser handles state management
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
const index = toolCall.index
const existing = toolCallAccumulator.get(index)

if (existing) {
if (toolCall.function?.arguments) {
existing.arguments += toolCall.function.arguments
}
} else {
toolCallAccumulator.set(index, {
id: toolCall.id || "",
name: toolCall.function?.name || "",
arguments: toolCall.function?.arguments || "",
})
}
}
}

if (finishReason === "tool_calls") {
for (const toolCall of toolCallAccumulator.values()) {
yield {
type: "tool_call",
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.name,
arguments: toolCall.arguments,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
toolCallAccumulator.clear()
}

if (chunk.usage) {
lastUsage = chunk.usage
}
}

// Fallback: If stream ends with accumulated tool calls that weren't yielded
// (e.g., finish_reason was 'stop' or 'length' instead of 'tool_calls')
if (toolCallAccumulator.size > 0) {
for (const toolCall of toolCallAccumulator.values()) {
yield {
type: "tool_call",
id: toolCall.id,
name: toolCall.name,
arguments: toolCall.arguments,
}
}
toolCallAccumulator.clear()
}

if (lastUsage) {
yield this.processUsageMetrics(lastUsage, this.getModel().info)
}
Expand Down
Loading
Loading