Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/types/src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ export const modelInfoSchema = z.object({
supportsNativeTools: z.boolean().optional(),
// Default tool protocol preferred by this model (if not specified, falls back to capability/provider defaults)
defaultToolProtocol: z.enum(["xml", "native"]).optional(),
// Exclude specific native tools from being available (only applies to native protocol)
// These tools will be removed from the set of tools available to the model
excludedTools: z.array(z.string()).optional(),
// Include specific native tools (only applies to native protocol)
// These tools will be added if they belong to an allowed group in the current mode
// Cannot force-add tools from groups the mode doesn't allow
includedTools: z.array(z.string()).optional(),
/**
* Service tiers with pricing information.
* Each tier can have a name (for OpenAI service tiers) and pricing overrides.
Expand Down
265 changes: 263 additions & 2 deletions src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { describe, it, expect } from "vitest"
import type OpenAI from "openai"
import type { ModeConfig } from "@roo-code/types"
import { filterNativeToolsForMode, filterMcpToolsForMode } from "../filter-tools-for-mode"
import type { ModeConfig, ModelInfo } from "@roo-code/types"
import { filterNativeToolsForMode, filterMcpToolsForMode, applyModelToolCustomization } from "../filter-tools-for-mode"

describe("filterNativeToolsForMode", () => {
const mockNativeTools: OpenAI.Chat.ChatCompletionTool[] = [
Expand Down Expand Up @@ -467,4 +467,265 @@ describe("filterMcpToolsForMode", () => {
// Should include MCP tools since default mode has mcp group
expect(filtered.length).toBeGreaterThan(0)
})

describe("applyModelToolCustomization", () => {
const codeMode: ModeConfig = {
slug: "code",
name: "Code",
roleDefinition: "Test",
groups: ["read", "edit", "browser", "command", "mcp"] as const,
}

const architectMode: ModeConfig = {
slug: "architect",
name: "Architect",
roleDefinition: "Test",
groups: ["read", "browser", "mcp"] as const,
}

it("should return original tools when modelInfo is undefined", () => {
const tools = new Set(["read_file", "write_to_file", "apply_diff"])
const result = applyModelToolCustomization(tools, codeMode, undefined)
expect(result).toEqual(tools)
})

it("should exclude tools specified in excludedTools", () => {
const tools = new Set(["read_file", "write_to_file", "apply_diff"])
const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
excludedTools: ["apply_diff"],
}
const result = applyModelToolCustomization(tools, codeMode, modelInfo)
expect(result.has("read_file")).toBe(true)
expect(result.has("write_to_file")).toBe(true)
expect(result.has("apply_diff")).toBe(false)
})

it("should exclude multiple tools", () => {
const tools = new Set(["read_file", "write_to_file", "apply_diff", "execute_command"])
const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
excludedTools: ["apply_diff", "write_to_file"],
}
const result = applyModelToolCustomization(tools, codeMode, modelInfo)
expect(result.has("read_file")).toBe(true)
expect(result.has("execute_command")).toBe(true)
expect(result.has("write_to_file")).toBe(false)
expect(result.has("apply_diff")).toBe(false)
})

it("should include tools only if they belong to allowed groups", () => {
const tools = new Set(["read_file"])
const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
includedTools: ["write_to_file", "apply_diff"], // Both in edit group
}
const result = applyModelToolCustomization(tools, codeMode, modelInfo)
expect(result.has("read_file")).toBe(true)
expect(result.has("write_to_file")).toBe(true)
expect(result.has("apply_diff")).toBe(true)
})

it("should NOT include tools from groups not allowed by mode", () => {
const tools = new Set(["read_file"])
const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
includedTools: ["write_to_file", "apply_diff"], // Edit group tools
}
// Architect mode doesn't have edit group
const result = applyModelToolCustomization(tools, architectMode, modelInfo)
expect(result.has("read_file")).toBe(true)
expect(result.has("write_to_file")).toBe(false) // Not in allowed groups
expect(result.has("apply_diff")).toBe(false) // Not in allowed groups
})

it("should apply both exclude and include operations", () => {
const tools = new Set(["read_file", "write_to_file", "apply_diff"])
const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
excludedTools: ["apply_diff"],
includedTools: ["insert_content"], // Another edit tool
}
const result = applyModelToolCustomization(tools, codeMode, modelInfo)
expect(result.has("read_file")).toBe(true)
expect(result.has("write_to_file")).toBe(true)
expect(result.has("apply_diff")).toBe(false) // Excluded
expect(result.has("insert_content")).toBe(true) // Included
})

it("should handle empty excludedTools and includedTools arrays", () => {
const tools = new Set(["read_file", "write_to_file"])
const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
excludedTools: [],
includedTools: [],
}
const result = applyModelToolCustomization(tools, codeMode, modelInfo)
expect(result).toEqual(tools)
})

it("should ignore excluded tools that are not in the original set", () => {
const tools = new Set(["read_file", "write_to_file"])
const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
excludedTools: ["apply_diff", "nonexistent_tool"],
}
const result = applyModelToolCustomization(tools, codeMode, modelInfo)
expect(result.has("read_file")).toBe(true)
expect(result.has("write_to_file")).toBe(true)
expect(result.size).toBe(2)
})
})

describe("filterNativeToolsForMode with model customization", () => {
const mockNativeTools: OpenAI.Chat.ChatCompletionTool[] = [
{
type: "function",
function: {
name: "read_file",
description: "Read files",
parameters: {},
},
},
{
type: "function",
function: {
name: "write_to_file",
description: "Write files",
parameters: {},
},
},
{
type: "function",
function: {
name: "apply_diff",
description: "Apply diff",
parameters: {},
},
},
{
type: "function",
function: {
name: "insert_content",
description: "Insert content",
parameters: {},
},
},
{
type: "function",
function: {
name: "execute_command",
description: "Execute command",
parameters: {},
},
},
]

it("should exclude tools when model specifies excludedTools", () => {
const codeMode: ModeConfig = {
slug: "code",
name: "Code",
roleDefinition: "Test",
groups: ["read", "edit", "browser", "command", "mcp"] as const,
}

const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
excludedTools: ["apply_diff"],
}

const filtered = filterNativeToolsForMode(mockNativeTools, "code", [codeMode], {}, undefined, {
modelInfo,
})

const toolNames = filtered.map((t) => ("function" in t ? t.function.name : ""))

expect(toolNames).toContain("read_file")
expect(toolNames).toContain("write_to_file")
expect(toolNames).toContain("insert_content")
expect(toolNames).not.toContain("apply_diff") // Excluded by model
})

it("should include tools when model specifies includedTools from allowed groups", () => {
const modeWithOnlyRead: ModeConfig = {
slug: "limited",
name: "Limited",
roleDefinition: "Test",
groups: ["read", "edit"] as const,
}

const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
includedTools: ["insert_content"], // Edit group tool
}

const filtered = filterNativeToolsForMode(mockNativeTools, "limited", [modeWithOnlyRead], {}, undefined, {
modelInfo,
})

const toolNames = filtered.map((t) => ("function" in t ? t.function.name : ""))

expect(toolNames).toContain("insert_content") // Included by model
})

it("should NOT include tools from groups not allowed by mode", () => {
const architectMode: ModeConfig = {
slug: "architect",
name: "Architect",
roleDefinition: "Test",
groups: ["read", "browser"] as const, // No edit group
}

const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
includedTools: ["write_to_file", "apply_diff"], // Edit group tools
}

const filtered = filterNativeToolsForMode(mockNativeTools, "architect", [architectMode], {}, undefined, {
modelInfo,
})

const toolNames = filtered.map((t) => ("function" in t ? t.function.name : ""))

expect(toolNames).toContain("read_file")
expect(toolNames).not.toContain("write_to_file") // Not in mode's allowed groups
expect(toolNames).not.toContain("apply_diff") // Not in mode's allowed groups
})

it("should combine excludedTools and includedTools", () => {
const codeMode: ModeConfig = {
slug: "code",
name: "Code",
roleDefinition: "Test",
groups: ["read", "edit", "browser", "command", "mcp"] as const,
}

const modelInfo: ModelInfo = {
contextWindow: 100000,
supportsPromptCache: false,
excludedTools: ["apply_diff"],
includedTools: ["insert_content"],
}

const filtered = filterNativeToolsForMode(mockNativeTools, "code", [codeMode], {}, undefined, {
modelInfo,
})

const toolNames = filtered.map((t) => ("function" in t ? t.function.name : ""))

expect(toolNames).toContain("write_to_file")
expect(toolNames).toContain("insert_content") // Included
expect(toolNames).not.toContain("apply_diff") // Excluded
})
})
})
69 changes: 65 additions & 4 deletions src/core/prompts/tools/filter-tools-for-mode.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,78 @@
import type OpenAI from "openai"
import type { ModeConfig, ToolName, ToolGroup } from "@roo-code/types"
import type { ModeConfig, ToolName, ToolGroup, ModelInfo } from "@roo-code/types"
import { getModeBySlug, getToolsForMode, isToolAllowedForMode } from "../../../shared/modes"
import { TOOL_GROUPS, ALWAYS_AVAILABLE_TOOLS } from "../../../shared/tools"
import { defaultModeSlug } from "../../../shared/modes"
import type { CodeIndexManager } from "../../../services/code-index/manager"
import type { McpHub } from "../../../services/mcp/McpHub"

/**
* Filters native tools based on mode restrictions.
* Apply model-specific tool customization to a set of allowed tools.
*
* This function filters tools based on model configuration:
* 1. Removes tools specified in modelInfo.excludedTools
* 2. Adds tools from modelInfo.includedTools (only if they belong to allowed groups)
*
* @param allowedTools - Set of tools already allowed by mode configuration
* @param modeConfig - Current mode configuration to check tool groups
* @param modelInfo - Model configuration with tool customization
* @returns Modified set of tools after applying model customization
*/
export function applyModelToolCustomization(
allowedTools: Set<string>,
modeConfig: ModeConfig,
modelInfo?: ModelInfo,
): Set<string> {
if (!modelInfo) {
return allowedTools
}

const result = new Set(allowedTools)

// Apply excluded tools (remove from allowed set)
if (modelInfo.excludedTools && modelInfo.excludedTools.length > 0) {
modelInfo.excludedTools.forEach((tool) => {
result.delete(tool)
})
}

// Apply included tools (add to allowed set, but only if they belong to an allowed group)
if (modelInfo.includedTools && modelInfo.includedTools.length > 0) {
// Build a map of tool -> group for all tools in TOOL_GROUPS
const toolToGroup = new Map<string, ToolGroup>()
for (const [groupName, groupConfig] of Object.entries(TOOL_GROUPS)) {
groupConfig.tools.forEach((tool) => {
toolToGroup.set(tool, groupName as ToolGroup)
})
}

// Get the list of allowed groups for this mode
const allowedGroups = new Set(
modeConfig.groups.map((groupEntry) => (Array.isArray(groupEntry) ? groupEntry[0] : groupEntry)),
)

// Add included tools only if they belong to an allowed group
modelInfo.includedTools.forEach((tool) => {
const toolGroup = toolToGroup.get(tool)
if (toolGroup && allowedGroups.has(toolGroup)) {
result.add(tool)
}
})
}

return result
}

/**
* Filters native tools based on mode restrictions and model customization.
* This ensures native tools are filtered the same way XML tools are filtered in the system prompt.
*
* @param nativeTools - Array of all available native tools
* @param mode - Current mode slug
* @param customModes - Custom mode configurations
* @param experiments - Experiment flags
* @param codeIndexManager - Code index manager for codebase_search feature check
* @param settings - Additional settings for tool filtering
* @param settings - Additional settings for tool filtering (includes modelInfo for model-specific customization)
* @param mcpHub - MCP hub for checking available resources
* @returns Filtered array of tools allowed for the mode
*/
Expand Down Expand Up @@ -43,7 +100,7 @@ export function filterNativeToolsForMode(
const allToolsForMode = getToolsForMode(modeConfig.groups)

// Filter to only tools that pass permission checks
const allowedToolNames = new Set(
let allowedToolNames = new Set(
allToolsForMode.filter((tool) =>
isToolAllowedForMode(
tool as ToolName,
Expand All @@ -56,6 +113,10 @@ export function filterNativeToolsForMode(
),
)

// Apply model-specific tool customization
const modelInfo = settings?.modelInfo as ModelInfo | undefined
allowedToolNames = applyModelToolCustomization(allowedToolNames, modeConfig, modelInfo)

// Conditionally exclude codebase_search if feature is disabled or not configured
if (
!codeIndexManager ||
Expand Down
1 change: 1 addition & 0 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3479,6 +3479,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
apiConfiguration,
maxReadFileLine: state?.maxReadFileLine ?? -1,
browserToolEnabled: state?.browserToolEnabled ?? true,
modelInfo,
})
}

Expand Down
Loading
Loading