1 change: 1 addition & 0 deletions packages/types/src/providers/ollama.ts
@@ -8,6 +8,7 @@ export const ollamaDefaultModelInfo: ModelInfo = {
contextWindow: 200_000,
supportsImages: true,
supportsPromptCache: true,
supportsNativeTools: true,
inputPrice: 0,
outputPrice: 0,
cacheWritesPrice: 0,
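For context, the change above adds a `supportsNativeTools` capability flag to the default Ollama `ModelInfo`. A minimal sketch of how a handler might gate on that flag (the `ModelInfo` shape is abbreviated, and `modelSupportsNativeTools` is a hypothetical helper for illustration, not part of this PR):

// Sketch only: abbreviated ModelInfo shape with the new capability flag.
interface ModelInfo {
	contextWindow: number
	supportsImages?: boolean
	supportsPromptCache?: boolean
	supportsNativeTools?: boolean
}

// Hypothetical helper: forward tool definitions to the provider only
// when the model advertises native tool support.
function modelSupportsNativeTools(info: ModelInfo): boolean {
	return info.supportsNativeTools === true
}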
278 changes: 270 additions & 8 deletions src/api/providers/__tests__/native-ollama.spec.ts
@@ -2,6 +2,7 @@

import { NativeOllamaHandler } from "../native-ollama"
import { ApiHandlerOptions } from "../../../shared/api"
import { getOllamaModels } from "../fetchers/ollama"

// Mock the ollama package
const mockChat = vitest.fn()
@@ -16,22 +17,27 @@ vitest.mock("ollama", () => {

// Mock the getOllamaModels function
vitest.mock("../fetchers/ollama", () => ({
getOllamaModels: vitest.fn().mockResolvedValue({
llama2: {
contextWindow: 4096,
maxTokens: 4096,
supportsImages: false,
supportsPromptCache: false,
},
}),
getOllamaModels: vitest.fn(),
}))

const mockGetOllamaModels = vitest.mocked(getOllamaModels)

describe("NativeOllamaHandler", () => {
let handler: NativeOllamaHandler

beforeEach(() => {
vitest.clearAllMocks()

// Default mock for getOllamaModels
mockGetOllamaModels.mockResolvedValue({
llama2: {
contextWindow: 4096,
maxTokens: 4096,
supportsImages: false,
supportsPromptCache: false,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama2",
ollamaModelId: "llama2",
@@ -257,4 +263,260 @@ describe("NativeOllamaHandler", () => {
expect(model.info).toBeDefined()
})
})

describe("tool calling", () => {
it("should include tools when model supports native tools", async () => {
// Mock model with native tool support
mockGetOllamaModels.mockResolvedValue({
"llama3.2": {
contextWindow: 128000,
maxTokens: 4096,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama3.2",
ollamaModelId: "llama3.2",
ollamaBaseUrl: "http://localhost:11434",
}

handler = new NativeOllamaHandler(options)

// Mock the chat response
mockChat.mockImplementation(async function* () {
yield { message: { content: "I will use the tool" } }
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather for a location",
parameters: {
type: "object",
properties: {
location: { type: "string", description: "The city name" },
},
required: ["location"],
},
},
},
]

const stream = handler.createMessage(
"System",
[{ role: "user" as const, content: "What's the weather?" }],
{ taskId: "test", tools },
)

// Consume the stream
for await (const _ of stream) {
// consume stream
}

// Verify tools were passed to the API
expect(mockChat).toHaveBeenCalledWith(
expect.objectContaining({
tools: [
{
type: "function",
function: {
name: "get_weather",
description: "Get the weather for a location",
parameters: {
type: "object",
properties: {
location: { type: "string", description: "The city name" },
},
required: ["location"],
},
},
},
],
}),
)
})

it("should not include tools when model does not support native tools", async () => {
// Mock model without native tool support
mockGetOllamaModels.mockResolvedValue({
llama2: {
contextWindow: 4096,
maxTokens: 4096,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: false,
},
})

// Mock the chat response
mockChat.mockImplementation(async function* () {
yield { message: { content: "Response without tools" } }
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather",
parameters: { type: "object", properties: {} },
},
},
]

const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
taskId: "test",
tools,
})

// Consume the stream
for await (const _ of stream) {
// consume stream
}

// Verify tools were NOT passed
expect(mockChat).toHaveBeenCalledWith(
expect.not.objectContaining({
tools: expect.anything(),
}),
)
})

it("should not include tools when toolProtocol is xml", async () => {
// Mock model with native tool support
mockGetOllamaModels.mockResolvedValue({
"llama3.2": {
contextWindow: 128000,
maxTokens: 4096,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama3.2",
ollamaModelId: "llama3.2",
ollamaBaseUrl: "http://localhost:11434",
}

handler = new NativeOllamaHandler(options)

// Mock the chat response
mockChat.mockImplementation(async function* () {
yield { message: { content: "Response" } }
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather",
parameters: { type: "object", properties: {} },
},
},
]

const stream = handler.createMessage("System", [{ role: "user" as const, content: "Test" }], {
taskId: "test",
tools,
toolProtocol: "xml",
})

// Consume the stream
for await (const _ of stream) {
// consume stream
}

// Verify tools were NOT passed (XML protocol forces XML format)
expect(mockChat).toHaveBeenCalledWith(
expect.not.objectContaining({
tools: expect.anything(),
}),
)
})

it("should yield tool_call_partial when model returns tool calls", async () => {
// Mock model with native tool support
mockGetOllamaModels.mockResolvedValue({
"llama3.2": {
contextWindow: 128000,
maxTokens: 4096,
supportsImages: true,
supportsPromptCache: false,
supportsNativeTools: true,
},
})

const options: ApiHandlerOptions = {
apiModelId: "llama3.2",
ollamaModelId: "llama3.2",
ollamaBaseUrl: "http://localhost:11434",
}

handler = new NativeOllamaHandler(options)

// Mock the chat response with tool calls
mockChat.mockImplementation(async function* () {
yield {
message: {
content: "",
tool_calls: [
{
function: {
name: "get_weather",
arguments: { location: "San Francisco" },
},
},
],
},
}
})

const tools = [
{
type: "function" as const,
function: {
name: "get_weather",
description: "Get the weather for a location",
parameters: {
type: "object",
properties: {
location: { type: "string" },
},
required: ["location"],
},
},
},
]

const stream = handler.createMessage(
"System",
[{ role: "user" as const, content: "What's the weather in SF?" }],
{ taskId: "test", tools },
)

const results = []
for await (const chunk of stream) {
results.push(chunk)
}

// Should yield a tool_call_partial chunk
const toolCallChunk = results.find((r) => r.type === "tool_call_partial")
expect(toolCallChunk).toBeDefined()
expect(toolCallChunk).toEqual({
type: "tool_call_partial",
index: 0,
id: "ollama-tool-0",
name: "get_weather",
arguments: JSON.stringify({ location: "San Francisco" }),
})
})
})
})
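Taken together, these tests pin down the tool-passing contract without showing the handler itself. A minimal sketch of the logic they imply, under stated assumptions: the helper names (`buildChatRequest`, `toToolCallPartial`) and the `toolProtocol` union are illustrative, not the handler's actual internals; only the behavior asserted above is taken from the tests.

// Illustrative types, abbreviated from the shapes the tests use.
interface ToolDef {
	type: "function"
	function: { name: string; description?: string; parameters: object }
}

interface CreateMessageMeta {
	taskId: string
	tools?: ToolDef[]
	toolProtocol?: "native" | "xml" // assumed union; only "xml" appears in the tests
}

// Hypothetical helper mirroring the first three tests: tools are included
// only when the model supports native tools AND the protocol is not XML.
function buildChatRequest(
	base: { model: string; messages: unknown[] },
	supportsNativeTools: boolean,
	meta: CreateMessageMeta,
) {
	const useNativeTools = supportsNativeTools && meta.toolProtocol !== "xml" && !!meta.tools?.length
	return useNativeTools ? { ...base, tools: meta.tools } : base
}

// Hypothetical mapping matching the last test: each Ollama tool call becomes
// a tool_call_partial chunk with a synthetic "ollama-tool-<index>" id and
// JSON-stringified arguments.
function toToolCallPartial(call: { function: { name: string; arguments: object } }, index: number) {
	return {
		type: "tool_call_partial" as const,
		index,
		id: `ollama-tool-${index}`,
		name: call.function.name,
		arguments: JSON.stringify(call.function.arguments),
	}
}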