Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/types/src/providers/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export const ollamaDefaultModelInfo: ModelInfo = {
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0,
outputPrice: 0,
cacheWritesPrice: 0,
Expand Down
278 changes: 270 additions & 8 deletions src/api/providers/__tests__/native-ollama.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import { NativeOllamaHandler } from "../native-ollama"
import { ApiHandlerOptions } from "../../../shared/api"
import { getOllamaModels } from "../fetchers/ollama"

// Mock the ollama package
const mockChat = vitest.fn()
Expand All @@ -16,22 +17,27 @@ vitest.mock("ollama", () => {

// Mock the getOllamaModels function
vitest.mock("../fetchers/ollama", () => ({
getOllamaModels: vitest.fn().mockResolvedValue({
llama2: {
contextWindow: 4096,
maxTokens: 4096,
supportsImages: false,
supportsPromptCache: false,
},
}),
getOllamaModels: vitest.fn(),
}))

const mockGetOllamaModels = vitest.mocked(getOllamaModels)

describe("NativeOllamaHandler", () => {
let handler: NativeOllamaHandler

beforeEach(() => {
vitest.clearAllMocks()

// Default mock for getOllamaModels
mockGetOllamaModels.mockResolvedValue({
llama2: {
contextWindow: 4096,
maxTokens: 4096,
supportsImages: false,
supportsPromptCache: false,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama2",
ollamaModelId: "llama2",
Expand Down Expand Up @@ -257,4 +263,260 @@ describe("NativeOllamaHandler", () => {
expect(model.info).toBeDefined()
})
})

describe("tool calling", () => {
it("should include tools when model supports native tools", async () => {
// Mock model with native tool support
mockGetOllamaModels.mockResolvedValue({
"llama3.2": {
contextWindow: 128000,
maxTokens: 4096,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama3.2",
ollamaModelId: "llama3.2",
ollamaBaseUrl: "http://localhost:11434",
}

handler = new NativeOllamaHandler(options)

// Mock the chat response
mockChat.mockImplementation(async function* () {
yield { message: { content: "I will use the tool" } }
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather for a location",
parameters: {
type: "object",
properties: {
location: { type: "string", description: "The city name" },
},
required: ["location"],
},
},
},
]

const stream = handler.createMessage(
"System",
[{ role: "user" as const, content: "What's the weather?" }],
{ taskId: "test", tools },
)

// Consume the stream
for await (const _ of stream) {
// consume stream
}

// Verify tools were passed to the API
expect(mockChat).toHaveBeenCalledWith(
expect.objectContaining({
tools: [
{
type: "function",
function: {
name: "get_weather",
description: "Get the weather for a location",
parameters: {
type: "object",
properties: {
location: { type: "string", description: "The city name" },
},
required: ["location"],
},
},
},
],
}),
)
})

it("should not include tools when model does not support native tools", async () => {
// Mock model without native tool support
mockGetOllamaModels.mockResolvedValue({
llama2: {
contextWindow: 4096,
maxTokens: 4096,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: false,
},
})

// Mock the chat response
mockChat.mockImplementation(async function* () {
yield { message: { content: "Response without tools" } }
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather",
parameters: { type: "object", properties: {} },
},
},
]

const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
taskId: "test",
tools,
})

// Consume the stream
for await (const _ of stream) {
// consume stream
}

// Verify tools were NOT passed
expect(mockChat).toHaveBeenCalledWith(
expect.not.objectContaining({
tools: expect.anything(),
}),
)
})

it("should not include tools when toolProtocol is xml", async () => {
// Mock model with native tool support
mockGetOllamaModels.mockResolvedValue({
"llama3.2": {
contextWindow: 128000,
maxTokens: 4096,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama3.2",
ollamaModelId: "llama3.2",
ollamaBaseUrl: "http://localhost:11434",
}

handler = new NativeOllamaHandler(options)

// Mock the chat response
mockChat.mockImplementation(async function* () {
yield { message: { content: "Response" } }
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather",
parameters: { type: "object", properties: {} },
},
},
]

const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
taskId: "test",
tools,
toolProtocol: "xml",
})

// Consume the stream
for await (const _ of stream) {
// consume stream
}

// Verify tools were NOT passed (XML protocol forces XML format)
expect(mockChat).toHaveBeenCalledWith(
expect.not.objectContaining({
tools: expect.anything(),
}),
)
})

it("should yield tool_call_partial when model returns tool calls", async () => {
// Mock model with native tool support
mockGetOllamaModels.mockResolvedValue({
"llama3.2": {
contextWindow: 128000,
maxTokens: 4096,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama3.2",
ollamaModelId: "llama3.2",
ollamaBaseUrl: "http://localhost:11434",
}

handler = new NativeOllamaHandler(options)

// Mock the chat response with tool calls
mockChat.mockImplementation(async function* () {
yield {
message: {
content: "",
tool_calls: [
{
function: {
name: "get_weather",
arguments: { location: "San Francisco" },
},
},
],
},
}
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather for a location",
parameters: {
type: "object",
properties: {
location: { type: "string" },
},
required: ["location"],
},
},
},
]

const stream = handler.createMessage(
"System",
[{ role: "user" as const, content: "What's the weather in SF?" }],
{ taskId: "test", tools },
)

const results = []
for await (const chunk of stream) {
results.push(chunk)
}

// Should yield a tool_call_partial chunk
const toolCallChunk = results.find((r) => r.type === "tool_call_partial")
expect(toolCallChunk).toBeDefined()
expect(toolCallChunk).toEqual({
type: "tool_call_partial",
index: 0,
id: "ollama-tool-0",
name: "get_weather",
arguments: JSON.stringify({ location: "San Francisco" }),
})
})
})
})
Loading
Loading