Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/types/src/experiment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export const experimentIds = [
"imageGeneration",
"runSlashCommand",
"multipleNativeToolCalls",
"customTools",
] as const

export const experimentIdsSchema = z.enum(experimentIds)
Expand All @@ -30,6 +31,7 @@ export const experimentsSchema = z.object({
imageGeneration: z.boolean().optional(),
runSlashCommand: z.boolean().optional(),
multipleNativeToolCalls: z.boolean().optional(),
customTools: z.boolean().optional(),
})

export type Experiments = z.infer<typeof experimentsSchema>
Expand Down
10 changes: 5 additions & 5 deletions src/core/assistant-message/presentAssistantMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1081,17 +1081,15 @@ export async function presentAssistantMessage(cline: Task) {
break
}

const customTool = customToolRegistry.get(block.name)
const customTool = stateExperiments?.customTools ? customToolRegistry.get(block.name) : undefined

if (customTool) {
try {
console.log(`executing customTool -> ${JSON.stringify(customTool, null, 2)}`)
let customToolArgs

if (customTool.parameters) {
try {
customToolArgs = customTool.parameters.parse(block.nativeArgs || block.params || {})
console.log(`customToolArgs -> ${JSON.stringify(customToolArgs, null, 2)}`)
} catch (parseParamsError) {
const message = `Custom tool "${block.name}" argument validation failed: ${parseParamsError.message}`
console.error(message)
Expand All @@ -1102,13 +1100,15 @@ export async function presentAssistantMessage(cline: Task) {
}
}

console.log(`${customTool.name}.execute() -> ${JSON.stringify(customToolArgs, null, 2)}`)

const result = await customTool.execute(customToolArgs, {
mode: mode ?? defaultModeSlug,
task: cline,
})

console.log(
`${customTool.name}.execute(): ${JSON.stringify(customToolArgs)} -> ${JSON.stringify(result)}`,
)

pushToolResult(result)
cline.consecutiveMistakeCount = 0
} catch (executionError: any) {
Expand Down
2 changes: 1 addition & 1 deletion src/core/prompts/system.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async function generatePrompt(

let customToolsSection = ""

if (!isNativeProtocol(effectiveProtocol)) {
if (experiments?.customTools && !isNativeProtocol(effectiveProtocol)) {
const customTools = customToolRegistry.getAllSerialized()

if (customTools.length > 0) {
Expand Down
13 changes: 8 additions & 5 deletions src/core/task/build-tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,16 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
const mcpTools = getMcpServerTools(mcpHub)
const filteredMcpTools = filterMcpToolsForMode(mcpTools, mode, customModes, experiments)

// Add custom tools if they are available.
await customToolRegistry.loadFromDirectoryIfStale(path.join(cwd, ".roo", "tools"))
const customTools = customToolRegistry.getAllSerialized()
// Add custom tools if they are available and the experiment is enabled.
let nativeCustomTools: OpenAI.Chat.ChatCompletionFunctionTool[] = []

if (customTools.length > 0) {
nativeCustomTools = customTools.map(formatNative)
if (experiments?.customTools) {
await customToolRegistry.loadFromDirectoryIfStale(path.join(cwd, ".roo", "tools"))
const customTools = customToolRegistry.getAllSerialized()

if (customTools.length > 0) {
nativeCustomTools = customTools.map(formatNative)
}
}

return [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]
Expand Down
8 changes: 4 additions & 4 deletions src/core/tools/validateToolUse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ import { TOOL_GROUPS, ALWAYS_AVAILABLE_TOOLS } from "../../shared/tools"
* Note: This does NOT check if the tool is allowed for a specific mode,
* only that the tool actually exists.
*/
export function isValidToolName(toolName: string): toolName is ToolName {
export function isValidToolName(toolName: string, experiments?: Record<string, boolean>): toolName is ToolName {
// Check if it's a valid static tool
if ((validToolNames as readonly string[]).includes(toolName)) {
return true
}

if (customToolRegistry.has(toolName)) {
if (experiments?.customTools && customToolRegistry.has(toolName)) {
return true
}

Expand All @@ -40,7 +40,7 @@ export function validateToolUse(
): void {
// First, check if the tool name is actually a valid/known tool
// This catches completely invalid tool names like "edit_file" that don't exist
if (!isValidToolName(toolName)) {
if (!isValidToolName(toolName, experiments)) {
throw new Error(
`Unknown tool "${toolName}". This tool does not exist. Please use one of the available tools: ${validToolNames.join(", ")}.`,
)
Expand Down Expand Up @@ -94,7 +94,7 @@ export function isToolAllowedForMode(

// For now, allow all custom tools in any mode.
// As a follow-up we should expand the custom tool definition to include mode restrictions.
if (customToolRegistry.has(tool)) {
if (experiments?.customTools && customToolRegistry.has(tool)) {
return true
}

Expand Down
19 changes: 19 additions & 0 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import {
Experiments,
ExperimentId,
} from "@roo-code/types"
import { customToolRegistry } from "@roo-code/core"
import { CloudService } from "@roo-code/cloud"
import { TelemetryService } from "@roo-code/telemetry"

Expand Down Expand Up @@ -1725,6 +1726,24 @@ export const webviewMessageHandler = async (
}
break
}
case "refreshCustomTools": {
try {
await customToolRegistry.loadFromDirectory(path.join(getCurrentCwd(), ".roo", "tools"))

await provider.postMessageToWebview({
type: "customToolsResult",
tools: customToolRegistry.getAllSerialized(),
})
} catch (error) {
await provider.postMessageToWebview({
type: "customToolsResult",
tools: [],
error: error instanceof Error ? error.message : String(error),
})
}

break
}
case "saveApiConfiguration":
if (message.text && message.apiConfiguration) {
try {
Expand Down
3 changes: 3 additions & 0 deletions src/shared/ExtensionMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type {
OrganizationAllowList,
ShareVisibility,
QueuedMessage,
SerializedCustomToolDefinition,
} from "@roo-code/types"

import { GitCommit } from "../utils/git"
Expand Down Expand Up @@ -133,6 +134,7 @@ export interface ExtensionMessage {
| "browserSessionUpdate"
| "browserSessionNavigate"
| "claudeCodeRateLimits"
| "customToolsResult"
text?: string
payload?: any // Add a generic payload for now, can refine later
// Checkpoint warning message
Expand Down Expand Up @@ -218,6 +220,7 @@ export interface ExtensionMessage {
browserSessionMessages?: ClineMessage[] // For browser session panel updates
isBrowserSessionActive?: boolean // For browser session panel updates
stepIndex?: number // For browserSessionNavigate: the target step index to display
tools?: SerializedCustomToolDefinition[] // For customToolsResult
}

export type ExtensionState = Pick<
Expand Down
1 change: 1 addition & 0 deletions src/shared/WebviewMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ export interface WebviewMessage {
| "openDebugUiHistory"
| "downloadErrorDiagnostics"
| "requestClaudeCodeRateLimits"
| "refreshCustomTools"
text?: string
editedMessageContent?: string
tab?: "settings" | "history" | "mcp" | "modes" | "chat" | "marketplace" | "cloud"
Expand Down
3 changes: 3 additions & 0 deletions src/shared/__tests__/experiments.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ describe("experiments", () => {
imageGeneration: false,
runSlashCommand: false,
multipleNativeToolCalls: false,
customTools: false,
}
expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.POWER_STEERING)).toBe(false)
})
Expand All @@ -44,6 +45,7 @@ describe("experiments", () => {
imageGeneration: false,
runSlashCommand: false,
multipleNativeToolCalls: false,
customTools: false,
}
expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.POWER_STEERING)).toBe(true)
})
Expand All @@ -56,6 +58,7 @@ describe("experiments", () => {
imageGeneration: false,
runSlashCommand: false,
multipleNativeToolCalls: false,
customTools: false,
}
expect(Experiments.isEnabled(experiments, EXPERIMENT_IDS.POWER_STEERING)).toBe(false)
})
Expand Down
2 changes: 2 additions & 0 deletions src/shared/experiments.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export const EXPERIMENT_IDS = {
IMAGE_GENERATION: "imageGeneration",
RUN_SLASH_COMMAND: "runSlashCommand",
MULTIPLE_NATIVE_TOOL_CALLS: "multipleNativeToolCalls",
CUSTOM_TOOLS: "customTools",
} as const satisfies Record<string, ExperimentId>

type _AssertExperimentIds = AssertEqual<Equals<ExperimentId, Values<typeof EXPERIMENT_IDS>>>
Expand All @@ -24,6 +25,7 @@ export const experimentConfigsMap: Record<ExperimentKey, ExperimentConfig> = {
IMAGE_GENERATION: { enabled: false },
RUN_SLASH_COMMAND: { enabled: false },
MULTIPLE_NATIVE_TOOL_CALLS: { enabled: false },
CUSTOM_TOOLS: { enabled: false },
}

export const experimentDefault = Object.fromEntries(
Expand Down
173 changes: 173 additions & 0 deletions webview-ui/src/components/settings/CustomToolsSettings.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import { useState, useEffect, useCallback, useMemo } from "react"
import { useEvent } from "react-use"
import { VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react"
import { RefreshCw, Loader2 } from "lucide-react"

import type { SerializedCustomToolDefinition } from "@roo-code/types"

import { useAppTranslation } from "@/i18n/TranslationContext"

import { vscode } from "@/utils/vscode"

import { Button } from "@/components/ui"

interface ToolParameter {
name: string
type: string
description?: string
required: boolean
}

interface ProcessedTool {
name: string
description: string
parameters: ToolParameter[]
}

interface CustomToolsSettingsProps {
enabled: boolean
onChange: (enabled: boolean) => void
}

export const CustomToolsSettings = ({ enabled, onChange }: CustomToolsSettingsProps) => {
const { t } = useAppTranslation()
const [tools, setTools] = useState<SerializedCustomToolDefinition[]>([])
const [isRefreshing, setIsRefreshing] = useState(false)
const [refreshError, setRefreshError] = useState<string | null>(null)

useEffect(() => {
if (enabled) {
vscode.postMessage({ type: "refreshCustomTools" })
} else {
setTools([])
}
}, [enabled])

useEvent("message", (event: MessageEvent) => {
const message = event.data

if (message.type === "customToolsResult") {
setTools(message.tools || [])
setIsRefreshing(false)
setRefreshError(message.error ?? null)
}
})

const onRefresh = useCallback(() => {
setIsRefreshing(true)
setRefreshError(null)
vscode.postMessage({ type: "refreshCustomTools" })
}, [])

const processedTools = useMemo<ProcessedTool[]>(
() =>
tools.map((tool) => {
const params = tool.parameters
const properties = (params?.properties ?? {}) as Record<string, { type?: string; description?: string }>
const required = (params?.required as string[] | undefined) ?? []

return {
name: tool.name,
description: tool.description,
parameters: Object.entries(properties).map(([name, def]) => ({
name,
type: def.type ?? "any",
description: def.description,
required: required.includes(name),
})),
}
}),
[tools],
)

return (
<div className="space-y-4">
<div>
<div className="flex items-center gap-2">
<VSCodeCheckbox checked={enabled} onChange={(e: any) => onChange(e.target.checked)}>
<span className="font-medium">{t("settings:experimental.CUSTOM_TOOLS.name")}</span>
</VSCodeCheckbox>
</div>
<p className="text-vscode-descriptionForeground text-sm mt-0">
{t("settings:experimental.CUSTOM_TOOLS.description")}
</p>
</div>

{enabled && (
<div className="ml-2 space-y-3">
<div className="flex items-center justify-between gap-4">
<label className="block font-medium">
{t("settings:experimental.CUSTOM_TOOLS.toolsHeader")}
</label>
<Button variant="outline" onClick={onRefresh} disabled={isRefreshing}>
<div className="flex items-center gap-2">
{isRefreshing ? (
<Loader2 className="w-4 h-4 animate-spin" />
) : (
<RefreshCw className="w-4 h-4" />
)}
{isRefreshing
? t("settings:experimental.CUSTOM_TOOLS.refreshing")
: t("settings:experimental.CUSTOM_TOOLS.refreshButton")}
</div>
</Button>
</div>

{refreshError && (
<div className="p-2 bg-vscode-inputValidation-errorBackground text-vscode-errorForeground rounded text-sm border border-vscode-inputValidation-errorBorder">
{t("settings:experimental.CUSTOM_TOOLS.refreshError")}: {refreshError}
</div>
)}

{processedTools.length === 0 ? (
<p className="text-vscode-descriptionForeground text-sm italic">
{t("settings:experimental.CUSTOM_TOOLS.noTools")}
</p>
) : (
<div className="space-y-2">
{processedTools.map((tool) => (
<div
key={tool.name}
className="p-3 bg-vscode-editor-background border border-vscode-panel-border rounded">
<div className="font-medium text-vscode-foreground">{tool.name}</div>
<p className="text-vscode-descriptionForeground text-sm mt-1">{tool.description}</p>
{tool.parameters.length > 0 && (
<div className="mt-3">
<div className="text-xs font-medium text-vscode-foreground mb-2">
{t("settings:experimental.CUSTOM_TOOLS.toolParameters")}:
</div>
<div className="space-y-1">
{tool.parameters.map((param) => (
<div
key={param.name}
className="flex items-start gap-2 text-xs pl-2 py-1 border-l-2 border-vscode-panel-border">
<code className="text-vscode-textLink-foreground font-mono">
{param.name}
</code>
<span className="text-vscode-descriptionForeground">
({param.type})
</span>
{param.required && (
<span className="text-vscode-errorForeground text-[10px] uppercase">
required
</span>
)}
{param.description && (
<span className="text-vscode-descriptionForeground">
— {param.description}
</span>
)}
</div>
))}
</div>
</div>
)}
</div>
))}
</div>
)}
</div>
)}
</div>
)
}
Loading