diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 5591f160d79..b5bcdcfbbd2 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -739,6 +739,9 @@ importers: diff-match-patch: specifier: ^1.0.5 version: 1.0.5 + esbuild: + specifier: '>=0.25.0' + version: 0.25.9 exceljs: specifier: ^4.4.0 version: 4.4.0 @@ -974,9 +977,6 @@ importers: '@vscode/vsce': specifier: 3.3.2 version: 3.3.2 - esbuild: - specifier: '>=0.25.0' - version: 0.25.9 execa: specifier: ^9.5.2 version: 9.5.3 diff --git a/src/core/assistant-message/NativeToolCallParser.ts b/src/core/assistant-message/NativeToolCallParser.ts index 250afdc3890..98dfd654660 100644 --- a/src/core/assistant-message/NativeToolCallParser.ts +++ b/src/core/assistant-message/NativeToolCallParser.ts @@ -1,13 +1,16 @@ +import { parseJSON } from "partial-json" + import { type ToolName, toolNames, type FileEntry } from "@roo-code/types" +import { customToolRegistry } from "@roo-code/core" + import { type ToolUse, type McpToolUse, type ToolParamName, - toolParamNames, type NativeToolArgs, + toolParamNames, } from "../../shared/tools" import { resolveToolAlias } from "../prompts/tools/filter-tools-for-mode" -import { parseJSON } from "partial-json" import type { ApiStreamToolCallStartChunk, ApiStreamToolCallDeltaChunk, @@ -558,6 +561,7 @@ export class NativeToolCallParser { }): ToolUse | McpToolUse | null { // Check if this is a dynamic MCP tool (mcp--serverName--toolName) const mcpPrefix = MCP_TOOL_PREFIX + MCP_TOOL_SEPARATOR + if (typeof toolCall.name === "string" && toolCall.name.startsWith(mcpPrefix)) { return this.parseDynamicMcpTool(toolCall) } @@ -565,8 +569,8 @@ export class NativeToolCallParser { // Resolve tool alias to canonical name (e.g., "edit_file" -> "apply_diff", "temp_edit_file" -> "search_and_replace") const resolvedName = resolveToolAlias(toolCall.name as string) as TName - // Validate tool name (after alias resolution) - if (!toolNames.includes(resolvedName as ToolName)) { + // Validate tool name (after alias resolution). + if (!toolNames.includes(resolvedName as ToolName) && !customToolRegistry.has(resolvedName)) { console.error(`Invalid tool name: ${toolCall.name} (resolved: ${resolvedName})`) console.error(`Valid tool names:`, toolNames) return null @@ -574,7 +578,7 @@ export class NativeToolCallParser { try { // Parse the arguments JSON string - const args = JSON.parse(toolCall.arguments) + const args = toolCall.arguments === "" ? {} : JSON.parse(toolCall.arguments) // Build legacy params object for backward compatibility with XML protocol and UI. // Native execution path uses nativeArgs instead, which has proper typing. @@ -589,7 +593,7 @@ export class NativeToolCallParser { } // Validate parameter name - if (!toolParamNames.includes(key as ToolParamName)) { + if (!toolParamNames.includes(key as ToolParamName) && !customToolRegistry.has(resolvedName)) { console.warn(`Unknown parameter '${key}' for tool '${resolvedName}'`) console.warn(`Valid param names:`, toolParamNames) continue @@ -786,6 +790,12 @@ export class NativeToolCallParser { break default: + if (customToolRegistry.has(resolvedName)) { + nativeArgs = args as NativeArgsFor + } else { + console.error(`Unhandled tool: ${resolvedName}`) + } + break } diff --git a/src/core/assistant-message/presentAssistantMessage.ts b/src/core/assistant-message/presentAssistantMessage.ts index 2e8b791b349..d23120ee73d 100644 --- a/src/core/assistant-message/presentAssistantMessage.ts +++ b/src/core/assistant-message/presentAssistantMessage.ts @@ -4,6 +4,7 @@ import { Anthropic } from "@anthropic-ai/sdk" import type { ToolName, ClineAsk, ToolProgressStatus } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" +import { customToolRegistry } from "@roo-code/core" import { t } from "../../i18n" @@ -1045,9 +1046,8 @@ export async function presentAssistantMessage(cline: Task) { }) break default: { - // Handle unknown/invalid tool names + // Handle unknown/invalid tool names OR custom tools // This is critical for native protocol where every tool_use MUST have a tool_result - // Note: This case should rarely be reached since validateToolUse now checks for unknown tools // CRITICAL: Don't process partial blocks for unknown tools - just let them stream in. // If we try to show errors for partial blocks, we'd show the error on every streaming chunk, @@ -1056,6 +1056,45 @@ export async function presentAssistantMessage(cline: Task) { break } + const customTool = customToolRegistry.get(block.name) + + if (customTool) { + try { + console.log(`executing customTool -> ${JSON.stringify(customTool, null, 2)}`) + let customToolArgs + + if (customTool.parameters) { + try { + customToolArgs = customTool.parameters.parse(block.nativeArgs || block.params || {}) + console.log(`customToolArgs -> ${JSON.stringify(customToolArgs, null, 2)}`) + } catch (parseParamsError) { + const message = `Custom tool "${block.name}" argument validation failed: ${parseParamsError.message}` + console.error(message) + cline.consecutiveMistakeCount++ + await cline.say("error", message) + pushToolResult(formatResponse.toolError(message, toolProtocol)) + break + } + } + + console.log(`${customTool.name}.execute() -> ${JSON.stringify(customToolArgs, null, 2)}`) + + const result = await customTool.execute(customToolArgs, { + mode: mode ?? defaultModeSlug, + task: cline, + }) + + pushToolResult(result) + cline.consecutiveMistakeCount = 0 + } catch (executionError: any) { + cline.consecutiveMistakeCount++ + await handleError(`executing custom tool "${block.name}"`, executionError) + } + + break + } + + // Not a custom tool - handle as unknown tool error const errorMessage = `Unknown tool "${block.name}". This tool does not exist. Please use one of the available tools.` cline.consecutiveMistakeCount++ cline.recordToolError(block.name as ToolName, errorMessage) diff --git a/src/core/prompts/system.ts b/src/core/prompts/system.ts index 75b28bbb213..fe2d98504ce 100644 --- a/src/core/prompts/system.ts +++ b/src/core/prompts/system.ts @@ -1,9 +1,15 @@ import * as vscode from "vscode" import * as os from "os" -import type { ModeConfig, PromptComponent, CustomModePrompts, TodoItem } from "@roo-code/types" - -import type { SystemPromptSettings } from "./types" +import { + type ModeConfig, + type PromptComponent, + type CustomModePrompts, + type TodoItem, + getEffectiveProtocol, + isNativeProtocol, +} from "@roo-code/types" +import { customToolRegistry, formatXml } from "@roo-code/core" import { Mode, modes, defaultModeSlug, getModeBySlug, getGroupName, getModeSelection } from "../../shared/modes" import { DiffStrategy } from "../../shared/tools" @@ -15,8 +21,8 @@ import { CodeIndexManager } from "../../services/code-index/manager" import { PromptVariables, loadSystemPromptFile } from "./sections/custom-system-prompt" +import type { SystemPromptSettings } from "./types" import { getToolDescriptionsForMode } from "./tools" -import { getEffectiveProtocol, isNativeProtocol } from "@roo-code/types" import { getRulesSection, getSystemInfoSection, @@ -98,7 +104,7 @@ async function generatePrompt( ]) // Build tools catalog section only for XML protocol - const toolsCatalog = isNativeProtocol(effectiveProtocol) + const builtInToolsCatalog = isNativeProtocol(effectiveProtocol) ? "" : `\n\n${getToolDescriptionsForMode( mode, @@ -116,6 +122,18 @@ async function generatePrompt( modelId, )}` + let customToolsSection = "" + + if (!isNativeProtocol(effectiveProtocol)) { + const customTools = customToolRegistry.getAllSerialized() + + if (customTools.length > 0) { + customToolsSection = `\n\n${formatXml(customTools)}` + } + } + + const toolsCatalog = builtInToolsCatalog + customToolsSection + const basePrompt = `${roleDefinition} ${markdownFormattingSection()} diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts index 575b31580e6..b7861be3270 100644 --- a/src/core/task/build-tools.ts +++ b/src/core/task/build-tools.ts @@ -1,6 +1,12 @@ +import path from "path" + import type OpenAI from "openai" + import type { ProviderSettings, ModeConfig, ModelInfo } from "@roo-code/types" +import { customToolRegistry, formatNative } from "@roo-code/core" + import type { ClineProvider } from "../webview/ClineProvider" + import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools" import { filterNativeToolsForMode, filterMcpToolsForMode } from "../prompts/tools/filter-tools-for-mode" @@ -40,11 +46,11 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise const mcpHub = provider.getMcpHub() - // Get CodeIndexManager for feature checking + // Get CodeIndexManager for feature checking. const { CodeIndexManager } = await import("../../services/code-index/manager") const codeIndexManager = CodeIndexManager.getInstance(provider.context, cwd) - // Build settings object for tool filtering + // Build settings object for tool filtering. const filterSettings = { todoListEnabled: apiConfiguration?.todoListEnabled ?? true, browserToolEnabled: browserToolEnabled ?? true, @@ -52,13 +58,13 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise diffEnabled, } - // Determine if partial reads are enabled based on maxReadFileLine setting + // Determine if partial reads are enabled based on maxReadFileLine setting. const partialReadsEnabled = maxReadFileLine !== -1 - // Build native tools with dynamic read_file tool based on partialReadsEnabled + // Build native tools with dynamic read_file tool based on partialReadsEnabled. const nativeTools = getNativeTools(partialReadsEnabled) - // Filter native tools based on mode restrictions + // Filter native tools based on mode restrictions. const filteredNativeTools = filterNativeToolsForMode( nativeTools, mode, @@ -69,9 +75,18 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise mcpHub, ) - // Filter MCP tools based on mode restrictions + // Filter MCP tools based on mode restrictions. const mcpTools = getMcpServerTools(mcpHub) const filteredMcpTools = filterMcpToolsForMode(mcpTools, mode, customModes, experiments) - return [...filteredNativeTools, ...filteredMcpTools] + // Add custom tools if they are available. + await customToolRegistry.loadFromDirectoryIfStale(path.join(cwd, ".roo", "tools")) + const customTools = customToolRegistry.getAllSerialized() + let nativeCustomTools: OpenAI.Chat.ChatCompletionFunctionTool[] = [] + + if (customTools.length > 0) { + nativeCustomTools = customTools.map(formatNative) + } + + return [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools] } diff --git a/src/core/tools/validateToolUse.ts b/src/core/tools/validateToolUse.ts index d0570337414..de814f0b3c9 100644 --- a/src/core/tools/validateToolUse.ts +++ b/src/core/tools/validateToolUse.ts @@ -1,5 +1,6 @@ import type { ToolName, ModeConfig, ExperimentId, GroupOptions, GroupEntry } from "@roo-code/types" import { toolNames as validToolNames } from "@roo-code/types" +import { customToolRegistry } from "@roo-code/core" import { type Mode, FileRestrictionError, getModeBySlug, getGroupName } from "../../shared/modes" import { EXPERIMENT_IDS } from "../../shared/experiments" @@ -16,6 +17,10 @@ export function isValidToolName(toolName: string): toolName is ToolName { return true } + if (customToolRegistry.has(toolName)) { + return true + } + // Check if it's a dynamic MCP tool (mcp_serverName_toolName format). if (toolName.startsWith("mcp_")) { return true @@ -87,6 +92,12 @@ export function isToolAllowedForMode( return true } + // For now, allow all custom tools in any mode. + // As a follow-up we should expand the custom tool definition to include mode restrictions. + if (customToolRegistry.has(tool)) { + return true + } + // Check if this is a dynamic MCP tool (mcp_serverName_toolName) // These should be allowed if the mcp group is allowed for the mode const isDynamicMcpTool = tool.startsWith("mcp_") diff --git a/src/esbuild.mjs b/src/esbuild.mjs index f99b077e9f9..68298eb3de4 100644 --- a/src/esbuild.mjs +++ b/src/esbuild.mjs @@ -15,7 +15,7 @@ async function main() { const production = process.argv.includes("--production") const watch = process.argv.includes("--watch") const minify = production - const sourcemap = true // Always generate source maps for error handling + const sourcemap = true // Always generate source maps for error handling. /** * @type {import('esbuild').BuildOptions} @@ -100,7 +100,7 @@ async function main() { plugins, entryPoints: ["extension.ts"], outfile: "dist/extension.js", - external: ["vscode"], + external: ["vscode", "esbuild"], } /** diff --git a/src/package.json b/src/package.json index eb27df9a977..e5884238d2b 100644 --- a/src/package.json +++ b/src/package.json @@ -446,6 +446,7 @@ "@roo-code/telemetry": "workspace:^", "@roo-code/types": "workspace:^", "@vscode/codicons": "^0.0.36", + "esbuild": "^0.25.0", "async-mutex": "^0.5.0", "axios": "^1.12.0", "cheerio": "^1.0.0", @@ -535,7 +536,6 @@ "@types/vscode": "^1.84.0", "@vscode/test-electron": "^2.5.2", "@vscode/vsce": "3.3.2", - "esbuild": "^0.25.0", "execa": "^9.5.2", "glob": "^11.1.0", "mkdirp": "^3.0.1", diff --git a/src/utils/__tests__/autoImportSettings.spec.ts b/src/utils/__tests__/autoImportSettings.spec.ts index be0d769670f..f3911571d10 100644 --- a/src/utils/__tests__/autoImportSettings.spec.ts +++ b/src/utils/__tests__/autoImportSettings.spec.ts @@ -17,15 +17,33 @@ vi.mock("fs/promises", () => ({ readFile: vi.fn(), })) -vi.mock("path", () => ({ - join: vi.fn((...args: string[]) => args.join("/")), - isAbsolute: vi.fn((p: string) => p.startsWith("/")), - basename: vi.fn((p: string) => p.split("/").pop() || ""), -})) +vi.mock("path", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + default: { + ...actual, + join: vi.fn((...args: string[]) => args.join("/")), + isAbsolute: vi.fn((p: string) => p.startsWith("/")), + basename: vi.fn((p: string) => p.split("/").pop() || ""), + }, + join: vi.fn((...args: string[]) => args.join("/")), + isAbsolute: vi.fn((p: string) => p.startsWith("/")), + basename: vi.fn((p: string) => p.split("/").pop() || ""), + } +}) -vi.mock("os", () => ({ - homedir: vi.fn(() => "/home/user"), -})) +vi.mock("os", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + default: { + ...actual, + homedir: vi.fn(() => "/home/user"), + }, + homedir: vi.fn(() => "/home/user"), + } +}) vi.mock("../fs", () => ({ fileExistsAtPath: vi.fn(),