Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/types/src/experiment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export const experimentIds = [
"imageGeneration",
"runSlashCommand",
"multipleNativeToolCalls",
"customTools",
] as const

export const experimentIdsSchema = z.enum(experimentIds)
Expand All @@ -30,6 +31,7 @@ export const experimentsSchema = z.object({
imageGeneration: z.boolean().optional(),
runSlashCommand: z.boolean().optional(),
multipleNativeToolCalls: z.boolean().optional(),
customTools: z.boolean().optional(),
})

export type Experiments = z.infer<typeof experimentsSchema>
Expand Down
10 changes: 5 additions & 5 deletions src/core/assistant-message/presentAssistantMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1081,17 +1081,15 @@ export async function presentAssistantMessage(cline: Task) {
break
}

const customTool = customToolRegistry.get(block.name)
const customTool = stateExperiments?.customTools ? customToolRegistry.get(block.name) : undefined

if (customTool) {
try {
console.log(`executing customTool -> ${JSON.stringify(customTool, null, 2)}`)
let customToolArgs

if (customTool.parameters) {
try {
customToolArgs = customTool.parameters.parse(block.nativeArgs || block.params || {})
console.log(`customToolArgs -> ${JSON.stringify(customToolArgs, null, 2)}`)
} catch (parseParamsError) {
const message = `Custom tool "${block.name}" argument validation failed: ${parseParamsError.message}`
console.error(message)
Expand All @@ -1102,13 +1100,15 @@ export async function presentAssistantMessage(cline: Task) {
}
}

console.log(`${customTool.name}.execute() -> ${JSON.stringify(customToolArgs, null, 2)}`)

const result = await customTool.execute(customToolArgs, {
mode: mode ?? defaultModeSlug,
task: cline,
})

console.log(
`${customTool.name}.execute(): ${JSON.stringify(customToolArgs)} -> ${JSON.stringify(result)}`,
)

pushToolResult(result)
cline.consecutiveMistakeCount = 0
} catch (executionError: any) {
Expand Down
2 changes: 1 addition & 1 deletion src/core/prompts/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async function generatePrompt(

let customToolsSection = ""

if (!isNativeProtocol(effectiveProtocol)) {
if (experiments?.customTools && !isNativeProtocol(effectiveProtocol)) {
const customTools = customToolRegistry.getAllSerialized()

if (customTools.length > 0) {
Expand Down
13 changes: 8 additions & 5 deletions src/core/task/build-tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,16 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
const mcpTools = getMcpServerTools(mcpHub)
const filteredMcpTools = filterMcpToolsForMode(mcpTools, mode, customModes, experiments)

// Add custom tools if they are available.
await customToolRegistry.loadFromDirectoryIfStale(path.join(cwd, ".roo", "tools"))
const customTools = customToolRegistry.getAllSerialized()
// Add custom tools if they are available and the experiment is enabled.
let nativeCustomTools: OpenAI.Chat.ChatCompletionFunctionTool[] = []

if (customTools.length > 0) {
nativeCustomTools = customTools.map(formatNative)
if (experiments?.customTools) {
await customToolRegistry.loadFromDirectoryIfStale(path.join(cwd, ".roo", "tools"))
const customTools = customToolRegistry.getAllSerialized()

if (customTools.length > 0) {
nativeCustomTools = customTools.map(formatNative)
}
}

return [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]
Expand Down
8 changes: 4 additions & 4 deletions src/core/tools/validateToolUse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import { TOOL_GROUPS, ALWAYS_AVAILABLE_TOOLS } from "../../shared/tools"
* Note: This does NOT check if the tool is allowed for a specific mode,
* only that the tool actually exists.
*/
export function isValidToolName(toolName: string): toolName is ToolName {
export function isValidToolName(toolName: string, experiments?: Record<string, boolean>): toolName is ToolName {
// Check if it's a valid static tool
if ((validToolNames as readonly string[]).includes(toolName)) {
return true
}

if (customToolRegistry.has(toolName)) {
if (experiments?.customTools && customToolRegistry.has(toolName)) {
return true
}

Expand All @@ -40,7 +40,7 @@ export function validateToolUse(
): void {
// First, check if the tool name is actually a valid/known tool
// This catches completely invalid tool names like "edit_file" that don't exist
if (!isValidToolName(toolName)) {
if (!isValidToolName(toolName, experiments)) {
throw new Error(
`Unknown tool "${toolName}". This tool does not exist. Please use one of the available tools: ${validToolNames.join(", ")}.`,
)
Expand Down Expand Up @@ -94,7 +94,7 @@ export function isToolAllowedForMode(

// For now, allow all custom tools in any mode.
// As a follow-up we should expand the custom tool definition to include mode restrictions.
if (customToolRegistry.has(tool)) {
if (experiments?.customTools && customToolRegistry.has(tool)) {
return true
}

Expand Down
19 changes: 19 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
Experiments,
ExperimentId,
} from "@roo-code/types"
import { customToolRegistry } from "@roo-code/core"
import { CloudService } from "@roo-code/cloud"
import { TelemetryService } from "@roo-code/telemetry"

Expand Down Expand Up @@ -1725,6 +1726,24 @@ export const webviewMessageHandler = async (
}
break
}
case "refreshCustomTools": {
try {
await customToolRegistry.loadFromDirectory(path.join(getCurrentCwd(), ".roo", "tools"))

await provider.postMessageToWebview({
type: "customToolsResult",
tools: customToolRegistry.getAllSerialized(),
})
} catch (error) {
await provider.postMessageToWebview({
type: "customToolsResult",
tools: [],
error: error instanceof Error ? error.message : String(error),
})
}

break
}
case "saveApiConfiguration":
if (message.text && message.apiConfiguration) {
try {
Expand Down
3 changes: 3 additions & 0 deletions src/shared/ExtensionMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type {
OrganizationAllowList,
ShareVisibility,
QueuedMessage,
SerializedCustomToolDefinition,
} from "@roo-code/types"

import { GitCommit } from "../utils/git"
Expand Down Expand Up @@ -133,6 +134,7 @@ export interface ExtensionMessage {
| "browserSessionUpdate"
| "browserSessionNavigate"
| "claudeCodeRateLimits"
| "customToolsResult"
text?: string
payload?: any // Add a generic payload for now, can refine later
// Checkpoint warning message
Expand Down Expand Up @@ -218,6 +220,7 @@ export interface ExtensionMessage {
browserSessionMessages?: ClineMessage[] // For browser session panel updates
isBrowserSessionActive?: boolean // For browser session panel updates
stepIndex?: number // For browserSessionNavigate: the target step index to display
tools?: SerializedCustomToolDefinition[] // For customToolsResult
}

export type ExtensionState = Pick<
Expand Down
1 change: 1 addition & 0 deletions src/shared/WebviewMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ export interface WebviewMessage {
| "openDebugUiHistory"
| "downloadErrorDiagnostics"
| "requestClaudeCodeRateLimits"
| "refreshCustomTools"
text?: string
editedMessageContent?: string
tab?: "settings" | "history" | "mcp" | "modes" | "chat" | "marketplace" | "cloud"
Expand Down
3 changes: 3 additions & 0 deletions src/shared/__tests__/experiments.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ describe("experiments", () => {
imageGeneration: false,
runSlashCommand: false,
multipleNativeToolCalls: false,
customTools: false,
}
expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.POWER_STEERING)).toBe(false)
})
Expand All @@ -44,6 +45,7 @@ describe("experiments", () => {
imageGeneration: false,
runSlashCommand: false,
multipleNativeToolCalls: false,
customTools: false,
}
expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.POWER_STEERING)).toBe(true)
})
Expand All @@ -56,6 +58,7 @@ describe("experiments", () => {
imageGeneration: false,
runSlashCommand: false,
multipleNativeToolCalls: false,
customTools: false,
}
expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.POWER_STEERING)).toBe(false)
})
Expand Down
2 changes: 2 additions & 0 deletions src/shared/experiments.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export const EXPERIMENT_IDS = {
IMAGE_GENERATION: "imageGeneration",
RUN_SLASH_COMMAND: "runSlashCommand",
MULTIPLE_NATIVE_TOOL_CALLS: "multipleNativeToolCalls",
CUSTOM_TOOLS: "customTools",
} as const satisfies Record<string, ExperimentId>

type _AssertExperimentIds = AssertEqual<Equals<ExperimentId, Values<typeof EXPERIMENT_IDS>>>
Expand All @@ -24,6 +25,7 @@ export const experimentConfigsMap: Record<ExperimentKey, ExperimentConfig> = {
IMAGE_GENERATION: { enabled: false },
RUN_SLASH_COMMAND: { enabled: false },
MULTIPLE_NATIVE_TOOL_CALLS: { enabled: false },
CUSTOM_TOOLS: { enabled: false },
}

export const experimentDefault = Object.fromEntries(
Expand Down
173 changes: 173 additions & 0 deletions webview-ui/src/components/settings/CustomToolsSettings.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import { useState, useEffect, useCallback, useMemo } from "react"
import { useEvent } from "react-use"
import { VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"
import { RefreshCw, Loader2 } from "lucide-react"

import type { SerializedCustomToolDefinition } from "@roo-code/types"

import { useAppTranslation } from "@/i18n/TranslationContext"

import { vscode } from "@/utils/vscode"

import { Button } from "@/components/ui"

interface ToolParameter {
name: string
type: string
description?: string
required: boolean
}

interface ProcessedTool {
name: string
description: string
parameters: ToolParameter[]
}

interface CustomToolsSettingsProps {
enabled: boolean
onChange: (enabled: boolean) => void
}

export const CustomToolsSettings = ({ enabled, onChange }: CustomToolsSettingsProps) => {
const { t } = useAppTranslation()
const [tools, setTools] = useState<SerializedCustomToolDefinition[]>([])
const [isRefreshing, setIsRefreshing] = useState(false)
const [refreshError, setRefreshError] = useState<string | null>(null)

useEffect(() => {
if (enabled) {
vscode.postMessage({ type: "refreshCustomTools" })
} else {
setTools([])
}
}, [enabled])

useEvent("message", (event: MessageEvent) => {
const message = event.data

if (message.type === "customToolsResult") {
setTools(message.tools || [])
setIsRefreshing(false)
setRefreshError(message.error ?? null)
}
})

const onRefresh = useCallback(() => {
setIsRefreshing(true)
setRefreshError(null)
vscode.postMessage({ type: "refreshCustomTools" })
}, [])

const processedTools = useMemo<ProcessedTool[]>(
() =>
tools.map((tool) => {
const params = tool.parameters
const properties = (params?.properties ?? {}) as Record<string, { type?: string; description?: string }>
const required = (params?.required as string[] | undefined) ?? []

return {
name: tool.name,
description: tool.description,
parameters: Object.entries(properties).map(([name, def]) => ({
name,
type: def.type ?? "any",
description: def.description,
required: required.includes(name),
})),
}
}),
[tools],
)

return (
<div className="space-y-4">
<div>
<div className="flex items-center gap-2">
<VSCodeCheckbox checked={enabled} onChange={(e: any) => onChange(e.target.checked)}>
<span className="font-medium">{t("settings:experimental.CUSTOM_TOOLS.name")}</span>
</VSCodeCheckbox>
</div>
<p className="text-vscode-descriptionForeground text-sm mt-0">
{t("settings:experimental.CUSTOM_TOOLS.description")}
</p>
</div>

{enabled && (
<div className="ml-2 space-y-3">
<div className="flex items-center justify-between gap-4">
<label className="block font-medium">
{t("settings:experimental.CUSTOM_TOOLS.toolsHeader")}
</label>
<Button variant="outline" onClick={onRefresh} disabled={isRefreshing}>
<div className="flex items-center gap-2">
{isRefreshing ? (
<Loader2 className="w-4 h-4 animate-spin" />
) : (
<RefreshCw className="w-4 h-4" />
)}
{isRefreshing
? t("settings:experimental.CUSTOM_TOOLS.refreshing")
: t("settings:experimental.CUSTOM_TOOLS.refreshButton")}
</div>
</Button>
</div>

{refreshError && (
<div className="p-2 bg-vscode-inputValidation-errorBackground text-vscode-errorForeground rounded text-sm border border-vscode-inputValidation-errorBorder">
{t("settings:experimental.CUSTOM_TOOLS.refreshError")}: {refreshError}
</div>
)}

{processedTools.length === 0 ? (
<p className="text-vscode-descriptionForeground text-sm italic">
{t("settings:experimental.CUSTOM_TOOLS.noTools")}
</p>
) : (
<div className="space-y-2">
{processedTools.map((tool) => (
<div
key={tool.name}
className="p-3 bg-vscode-editor-background border border-vscode-panel-border rounded">
<div className="font-medium text-vscode-foreground">{tool.name}</div>
<p className="text-vscode-descriptionForeground text-sm mt-1">{tool.description}</p>
{tool.parameters.length > 0 && (
<div className="mt-3">
<div className="text-xs font-medium text-vscode-foreground mb-2">
{t("settings:experimental.CUSTOM_TOOLS.toolParameters")}:
</div>
<div className="space-y-1">
{tool.parameters.map((param) => (
<div
key={param.name}
className="flex items-start gap-2 text-xs pl-2 py-1 border-l-2 border-vscode-panel-border">
<code className="text-vscode-textLink-foreground font-mono">
{param.name}
</code>
<span className="text-vscode-descriptionForeground">
({param.type})
</span>
{param.required && (
<span className="text-vscode-errorForeground text-[10px] uppercase">
required
</span>
)}
{param.description && (
<span className="text-vscode-descriptionForeground">
— {param.description}
</span>
)}
</div>
))}
</div>
</div>
)}
</div>
))}
</div>
)}
</div>
)}
</div>
)
}
Loading