Skip to content
This repository was archived by the owner on May 15, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ export interface ApiHandlerCreateMessageMetadata {
* Only applies when toolProtocol is "native".
*/
parallelToolCalls?: boolean
/**
* Optional array of tool names that the model is allowed to call.
* When provided, all tool definitions are passed to the model (so it can reference
* historical tool calls), but only the specified tools can actually be invoked.
* This is used when switching modes to prevent model errors from missing tool
* definitions while still restricting callable tools to the current mode's permissions.
* Only applies to providers that support function calling restrictions (e.g., Gemini).
*/
allowedFunctionNames?: string[]
}

export interface ApiHandler {
Expand Down
149 changes: 149 additions & 0 deletions src/api/providers/__tests__/gemini-handler.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { t } from "i18next"
import { FunctionCallingConfigMode } from "@google/genai"

import { GeminiHandler } from "../gemini"
import type { ApiHandlerOptions } from "../../../shared/api"
Expand Down Expand Up @@ -141,4 +142,152 @@ describe("GeminiHandler backend support", () => {
}).rejects.toThrow(t("common:errors.gemini.generate_stream", { error: "API rate limit exceeded" }))
})
})

describe("allowedFunctionNames support", () => {
const testTools = [
{
type: "function" as const,
function: {
name: "read_file",
description: "Read a file",
parameters: { type: "object", properties: {} },
},
},
{
type: "function" as const,
function: {
name: "write_to_file",
description: "Write to a file",
parameters: { type: "object", properties: {} },
},
},
{
type: "function" as const,
function: {
name: "execute_command",
description: "Execute a command",
parameters: { type: "object", properties: {} },
},
},
]

it("should pass allowedFunctionNames to toolConfig when provided", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
allowedFunctionNames: ["read_file", "write_to_file"],
})
.next()

const config = stub.mock.calls[0][0].config
expect(config.toolConfig).toEqual({
functionCallingConfig: {
mode: FunctionCallingConfigMode.ANY,
allowedFunctionNames: ["read_file", "write_to_file"],
},
})
})

it("should include all tools but restrict callable functions via allowedFunctionNames", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
allowedFunctionNames: ["read_file"],
})
.next()

const config = stub.mock.calls[0][0].config
// All tools should be passed to the model
expect(config.tools[0].functionDeclarations).toHaveLength(3)
// But only read_file should be allowed to be called
expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
})

it("should take precedence over tool_choice when allowedFunctionNames is provided", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
tool_choice: "auto",
allowedFunctionNames: ["read_file"],
})
.next()

const config = stub.mock.calls[0][0].config
// allowedFunctionNames should take precedence - mode should be ANY, not AUTO
expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.ANY)
expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
})

it("should fall back to tool_choice when allowedFunctionNames is empty", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
tool_choice: "auto",
allowedFunctionNames: [],
})
.next()

const config = stub.mock.calls[0][0].config
// Empty allowedFunctionNames should fall back to tool_choice behavior
expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.AUTO)
expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toBeUndefined()
})

it("should not set toolConfig when allowedFunctionNames is undefined and no tool_choice", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
})
.next()

const config = stub.mock.calls[0][0].config
// No toolConfig should be set when neither allowedFunctionNames nor tool_choice is provided
expect(config.toolConfig).toBeUndefined()
})
})
})
14 changes: 13 additions & 1 deletion src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,19 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
...(tools.length > 0 ? { tools } : {}),
}

if (metadata?.tool_choice) {
// Handle allowedFunctionNames for mode-restricted tool access.
// When provided, all tool definitions are passed to the model (so it can reference
// historical tool calls in conversation), but only the specified tools can be invoked.
// This takes precedence over tool_choice to ensure mode restrictions are honored.
if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0) {
config.toolConfig = {
functionCallingConfig: {
// Use ANY mode to allow calling any of the allowed functions
mode: FunctionCallingConfigMode.ANY,
allowedFunctionNames: metadata.allowedFunctionNames,
},
}
} else if (metadata?.tool_choice) {
const choice = metadata.tool_choice
let mode: FunctionCallingConfigMode
let allowedFunctionNames: string[] | undefined
Expand Down
24 changes: 21 additions & 3 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ import { sanitizeToolUseId } from "../../utils/tool-id"
// prompts
import { formatResponse } from "../prompts/responses"
import { SYSTEM_PROMPT } from "../prompts/system"
import { buildNativeToolsArray } from "./build-tools"
import { buildNativeToolsArrayWithRestrictions } from "./build-tools"

// core modules
import { ToolRepetitionDetector } from "../tools/ToolRepetitionDetector"
Expand Down Expand Up @@ -4091,15 +4091,27 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
const taskProtocol = this._taskToolProtocol ?? "xml"
const shouldIncludeTools = taskProtocol === TOOL_PROTOCOL.NATIVE && (modelInfo.supportsNativeTools ?? false)

// Build complete tools array: native tools + dynamic MCP tools, filtered by mode restrictions
// Build complete tools array: native tools + dynamic MCP tools
// When includeAllToolsWithRestrictions is true, returns all tools but provides
// allowedFunctionNames for providers (like Gemini) that need to see all tool
// definitions in history while restricting callable tools for the current mode.
// Only Gemini currently supports this - other providers filter tools normally.
let allTools: OpenAI.Chat.ChatCompletionTool[] = []
let allowedFunctionNames: string[] | undefined

// Gemini requires all tool definitions to be present for history compatibility,
// but uses allowedFunctionNames to restrict which tools can be called.
// Other providers (Anthropic, OpenAI, etc.) don't support this feature yet,
// so they continue to receive only the filtered tools for the current mode.
const supportsAllowedFunctionNames = apiConfiguration?.apiProvider === "gemini"

if (shouldIncludeTools) {
const provider = this.providerRef.deref()
if (!provider) {
throw new Error("Provider reference lost during tool building")
}

allTools = await buildNativeToolsArray({
const toolsResult = await buildNativeToolsArrayWithRestrictions({
provider,
cwd: this.cwd,
mode,
Expand All @@ -4111,7 +4123,10 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
browserToolEnabled: state?.browserToolEnabled ?? true,
modelInfo,
diffEnabled: this.diffEnabled,
includeAllToolsWithRestrictions: supportsAllowedFunctionNames,
})
allTools = toolsResult.tools
allowedFunctionNames = toolsResult.allowedFunctionNames
}

// Parallel tool calls are disabled - feature is on hold
Expand All @@ -4129,6 +4144,9 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
tool_choice: "auto",
toolProtocol: taskProtocol,
parallelToolCalls: parallelToolCallsEnabled,
// When mode restricts tools, provide allowedFunctionNames so providers
// like Gemini can see all tools in history but only call allowed ones
...(allowedFunctionNames ? { allowedFunctionNames } : {}),
}
: {}),
}
Expand Down
70 changes: 69 additions & 1 deletion src/core/task/build-tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,35 @@ interface BuildToolsOptions {
browserToolEnabled: boolean
modelInfo?: ModelInfo
diffEnabled: boolean
/**
* If true, returns all tools without mode filtering, but also includes
* the list of allowed tool names for use with allowedFunctionNames.
* This enables providers that support function call restrictions (e.g., Gemini)
* to pass all tool definitions while restricting callable tools.
*/
includeAllToolsWithRestrictions?: boolean
}

interface BuildToolsResult {
/**
* The tools to pass to the model.
* If includeAllToolsWithRestrictions is true, this includes ALL tools.
* Otherwise, it includes only mode-filtered tools.
*/
tools: OpenAI.Chat.ChatCompletionTool[]
/**
* The names of tools that are allowed to be called based on mode restrictions.
* Only populated when includeAllToolsWithRestrictions is true.
* Use this with allowedFunctionNames in providers that support it.
*/
allowedFunctionNames?: string[]
}

/**
* Extracts the function name from a tool definition.
*/
function getToolName(tool: OpenAI.Chat.ChatCompletionTool): string {
return (tool as OpenAI.Chat.ChatCompletionFunctionTool).function.name
}

/**
Expand All @@ -33,6 +62,23 @@ interface BuildToolsOptions {
* @returns Array of filtered native and MCP tools
*/
export async function buildNativeToolsArray(options: BuildToolsOptions): Promise<OpenAI.Chat.ChatCompletionTool[]> {
const result = await buildNativeToolsArrayWithRestrictions(options)
return result.tools
}

/**
* Builds the complete tools array for native protocol requests with optional mode restrictions.
* When includeAllToolsWithRestrictions is true, returns ALL tools but also provides
* the list of allowed tool names for use with allowedFunctionNames.
*
* This enables providers like Gemini to pass all tool definitions to the model
* (so it can reference historical tool calls) while restricting which tools
* can actually be invoked via allowedFunctionNames in toolConfig.
*
* @param options - Configuration options for building the tools
* @returns BuildToolsResult with tools array and optional allowedFunctionNames
*/
export async function buildNativeToolsArrayWithRestrictions(options: BuildToolsOptions): Promise<BuildToolsResult> {
const {
provider,
cwd,
Expand All @@ -45,6 +91,7 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
browserToolEnabled,
modelInfo,
diffEnabled,
includeAllToolsWithRestrictions,
} = options

const mcpHub = provider.getMcpHub()
Expand Down Expand Up @@ -102,5 +149,26 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
}
}

return [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]
// Combine filtered tools (for backward compatibility and for allowedFunctionNames)
const filteredTools = [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]

// If includeAllToolsWithRestrictions is true, return ALL tools but provide
// allowed names based on mode filtering
if (includeAllToolsWithRestrictions) {
// Combine ALL tools (unfiltered native + all MCP + custom)
const allTools = [...nativeTools, ...mcpTools, ...nativeCustomTools]

// Extract names of tools that are allowed based on mode filtering
const allowedFunctionNames = filteredTools.map(getToolName)
Comment thread
hannesrudolph marked this conversation as resolved.
Outdated

return {
tools: allTools,
allowedFunctionNames,
}
}

// Default behavior: return only filtered tools
return {
tools: filteredTools,
}
}
Loading