diff --git a/packages/core/src/custom-tools/custom-tool-registry.ts b/packages/core/src/custom-tools/custom-tool-registry.ts index cfcd356335..bd5d648f17 100644 --- a/packages/core/src/custom-tools/custom-tool-registry.ts +++ b/packages/core/src/custom-tools/custom-tool-registry.ts @@ -366,10 +366,18 @@ export class CustomToolRegistry { } // Add alias for @roo-code/types if we found it. - // Note: @roo-code/types is built with zod bundled in, so we don't need a separate zod alias. + // Use the packaged version to ensure parametersSchema instance consistency if (this.typesPackagePath) { - esbuildOptions.alias = { - "@roo-code/types": this.typesPackagePath, + // Look for the packaged ES modules version to maintain instance consistency + const packagedTypesPath = path.join(path.dirname(this.typesPackagePath), "dist", "index.js") + if (fs.existsSync(packagedTypesPath)) { + esbuildOptions.alias = { + "@roo-code/types": packagedTypesPath, + } + } else { + esbuildOptions.alias = { + "@roo-code/types": this.typesPackagePath, + } } } diff --git a/src/api/providers/__tests__/bedrock-native-tools.spec.ts b/src/api/providers/__tests__/bedrock-native-tools.spec.ts index 8325f94bfb..0396a81744 100644 --- a/src/api/providers/__tests__/bedrock-native-tools.spec.ts +++ b/src/api/providers/__tests__/bedrock-native-tools.spec.ts @@ -168,12 +168,13 @@ describe("AwsBedrockHandler Native Tool Calling", () => { expect(executeCommandSchema.properties.cwd.description).toBe("Working directory (optional)") // Second tool: line_ranges should be transformed from type: ["array", "null"] to anyOf + // with items moved inside the array variant (required by GPT-5-mini strict schema validation) const readFileSchema = bedrockTools[1].toolSpec.inputSchema.json as any const lineRanges = readFileSchema.properties.files.items.properties.line_ranges - expect(lineRanges.anyOf).toEqual([{ type: "array" }, { type: "null" }]) + expect(lineRanges.anyOf).toEqual([{ type: "array", items: { type: "integer" } }, { type: "null" }]) expect(lineRanges.type).toBeUndefined() - // items also gets additionalProperties: false from normalization - expect(lineRanges.items.type).toBe("integer") + // items should now be inside the array variant, not at root + expect(lineRanges.items).toBeUndefined() expect(lineRanges.description).toBe("Optional line ranges") }) diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index 32449bf69a..4b147c2f09 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -295,6 +295,10 @@ describe("OpenAiHandler", () => { name: undefined, arguments: '"value"}', }) + + // Verify tool_call_end event is emitted when finish_reason is "tool_calls" + const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + expect(toolCallEndChunks).toHaveLength(1) }) it("should yield tool calls even when finish_reason is not set (fallback behavior)", async () => { @@ -855,6 +859,10 @@ describe("OpenAiHandler", () => { name: undefined, arguments: "{}", }) + + // Verify tool_call_end event is emitted when finish_reason is "tool_calls" + const toolCallEndChunks = chunks.filter((chunk) => chunk.type === "tool_call_end") + expect(toolCallEndChunks).toHaveLength(1) }) it("should yield tool calls for O3 model even when finish_reason is not set (fallback behavior)", async () => { diff --git a/src/api/providers/base-provider.ts b/src/api/providers/base-provider.ts index 84c8cf6fe9..64d99b3f0c 100644 --- a/src/api/providers/base-provider.ts +++ b/src/api/providers/base-provider.ts @@ -5,6 +5,7 @@ import type { ModelInfo } from "@roo-code/types" import type { ApiHandler, ApiHandlerCreateMessageMetadata } from "../index" import { ApiStream } from "../transform/stream" import { countTokens } from "../../utils/countTokens" +import { isMcpTool } from "../../utils/mcp-name" /** * Base class for API providers that implements common functionality. @@ -28,18 +29,26 @@ export abstract class BaseProvider implements ApiHandler { return undefined } - return tools.map((tool) => - tool.type === "function" - ? { - ...tool, - function: { - ...tool.function, - strict: true, - parameters: this.convertToolSchemaForOpenAI(tool.function.parameters), - }, - } - : tool, - ) + return tools.map((tool) => { + if (tool.type !== "function") { + return tool + } + + // MCP tools use the 'mcp--' prefix - disable strict mode for them + // to preserve optional parameters from the MCP server schema + const isMcp = isMcpTool(tool.function.name) + + return { + ...tool, + function: { + ...tool.function, + strict: !isMcp, + parameters: isMcp + ? tool.function.parameters + : this.convertToolSchemaForOpenAI(tool.function.parameters), + }, + } + }) } /** diff --git a/src/api/providers/fetchers/__tests__/chutes.spec.ts b/src/api/providers/fetchers/__tests__/chutes.spec.ts index b68903fde5..79ed027383 100644 --- a/src/api/providers/fetchers/__tests__/chutes.spec.ts +++ b/src/api/providers/fetchers/__tests__/chutes.spec.ts @@ -212,4 +212,132 @@ describe("getChutesModels", () => { expect(models["test/no-tools-model"].supportsNativeTools).toBe(false) expect(models["test/no-tools-model"].defaultToolProtocol).toBeUndefined() }) + + it("should skip empty objects in API response and still process valid models", async () => { + const mockResponse = { + data: { + data: [ + { + id: "test/valid-model", + object: "model", + owned_by: "test", + created: 1234567890, + context_length: 128000, + max_model_len: 8192, + input_modalities: ["text"], + }, + {}, // Empty object - should be skipped + { + id: "test/another-valid-model", + object: "model", + context_length: 64000, + max_model_len: 4096, + }, + ], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + const models = await getChutesModels("test-api-key") + + // Valid models should be processed + expect(models["test/valid-model"]).toBeDefined() + expect(models["test/valid-model"].contextWindow).toBe(128000) + expect(models["test/another-valid-model"]).toBeDefined() + expect(models["test/another-valid-model"].contextWindow).toBe(64000) + }) + + it("should skip models without id field", async () => { + const mockResponse = { + data: { + data: [ + { + // Missing id field + object: "model", + context_length: 128000, + max_model_len: 8192, + }, + { + id: "test/valid-model", + context_length: 64000, + max_model_len: 4096, + }, + ], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + const models = await getChutesModels("test-api-key") + + // Only the valid model should be added + expect(models["test/valid-model"]).toBeDefined() + // Hardcoded models should still exist + expect(Object.keys(models).length).toBeGreaterThan(1) + }) + + it("should calculate maxTokens fallback when max_model_len is missing", async () => { + const mockResponse = { + data: { + data: [ + { + id: "test/no-max-len-model", + object: "model", + context_length: 100000, + // max_model_len is missing + input_modalities: ["text"], + }, + ], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + const models = await getChutesModels("test-api-key") + + // Should calculate maxTokens as 20% of contextWindow + expect(models["test/no-max-len-model"]).toBeDefined() + expect(models["test/no-max-len-model"].maxTokens).toBe(20000) // 100000 * 0.2 + expect(models["test/no-max-len-model"].contextWindow).toBe(100000) + }) + + it("should gracefully handle response with mixed valid and invalid items", async () => { + const consoleErrorSpy = vi.spyOn(console, "error").mockImplementation(() => {}) + + const mockResponse = { + data: { + data: [ + { + id: "test/valid-1", + context_length: 128000, + max_model_len: 8192, + }, + {}, // Empty - will be skipped + null, // Null - will be skipped + { + id: "", // Empty string id - will be skipped + context_length: 64000, + }, + { + id: "test/valid-2", + context_length: 256000, + max_model_len: 16384, + supported_features: ["tools"], + }, + ], + }, + } + + mockedAxios.get.mockResolvedValue(mockResponse) + + const models = await getChutesModels("test-api-key") + + // Both valid models should be processed + expect(models["test/valid-1"]).toBeDefined() + expect(models["test/valid-2"]).toBeDefined() + expect(models["test/valid-2"].supportsNativeTools).toBe(true) + + consoleErrorSpy.mockRestore() + }) }) diff --git a/src/api/providers/fetchers/chutes.ts b/src/api/providers/fetchers/chutes.ts index d5334aa59b..7a237a07bb 100644 --- a/src/api/providers/fetchers/chutes.ts +++ b/src/api/providers/fetchers/chutes.ts @@ -6,19 +6,22 @@ import { type ModelInfo, chutesModels } from "@roo-code/types" import { DEFAULT_HEADERS } from "../constants" // Chutes models endpoint follows OpenAI /models shape with additional fields. +// All fields are optional to allow graceful handling of incomplete API responses. const ChutesModelSchema = z.object({ - id: z.string(), + id: z.string().optional(), object: z.literal("model").optional(), owned_by: z.string().optional(), created: z.number().optional(), context_length: z.number().optional(), - max_model_len: z.number(), + max_model_len: z.number().optional(), input_modalities: z.array(z.string()).optional(), supported_features: z.array(z.string()).optional(), }) const ChutesModelsResponseSchema = z.object({ data: z.array(ChutesModelSchema) }) +type ChutesModelsResponse = z.infer + export async function getChutesModels(apiKey?: string): Promise> { const headers: Record = { ...DEFAULT_HEADERS } @@ -32,33 +35,51 @@ export async function getChutesModels(apiKey?: string): Promise = { ...chutesModels } try { - const response = await axios.get(url, { headers }) - const parsed = ChutesModelsResponseSchema.safeParse(response.data) - - if (parsed.success) { - for (const m of parsed.data.data) { - const contextWindow = m.context_length - - if (!contextWindow) { - continue - } - - const info: ModelInfo = { - maxTokens: m.max_model_len, - contextWindow, - supportsImages: (m.input_modalities || []).includes("image"), - supportsPromptCache: false, - supportsNativeTools: (m.supported_features || []).includes("tools"), - inputPrice: 0, - outputPrice: 0, - description: `Chutes AI model: ${m.id}`, - } - - // Union: dynamic models override hardcoded ones if they have the same ID. - models[m.id] = info + const response = await axios.get(url, { headers }) + const result = ChutesModelsResponseSchema.safeParse(response.data) + + // Graceful fallback: use parsed data if valid, otherwise fall back to raw response data. + // This mirrors the OpenRouter pattern for handling API responses with some invalid items. + const data = result.success ? result.data.data : response.data?.data + + if (!result.success) { + console.error(`Error parsing Chutes models response: ${JSON.stringify(result.error.format(), null, 2)}`) + } + + if (!data || !Array.isArray(data)) { + console.error("Chutes models response missing data array") + return models + } + + for (const m of data) { + // Skip items missing required fields (e.g., empty objects from API) + if (!m || typeof m.id !== "string" || !m.id) { + continue } - } else { - console.error(`Error parsing Chutes models: ${JSON.stringify(parsed.error.format(), null, 2)}`) + + const contextWindow = + typeof m.context_length === "number" && Number.isFinite(m.context_length) ? m.context_length : undefined + const maxModelLen = + typeof m.max_model_len === "number" && Number.isFinite(m.max_model_len) ? m.max_model_len : undefined + + // Skip models without valid context window information + if (!contextWindow) { + continue + } + + const info: ModelInfo = { + maxTokens: maxModelLen ?? Math.ceil(contextWindow * 0.2), + contextWindow, + supportsImages: (m.input_modalities || []).includes("image"), + supportsPromptCache: false, + supportsNativeTools: (m.supported_features || []).includes("tools"), + inputPrice: 0, + outputPrice: 0, + description: `Chutes AI model: ${m.id}`, + } + + // Union: dynamic models override hardcoded ones if they have the same ID. + models[m.id] = info } } catch (error) { console.error(`Error fetching Chutes models: ${error instanceof Error ? error.message : String(error)}`) diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts index e0f5b51786..8ba8c390d3 100644 --- a/src/api/providers/openai-native.ts +++ b/src/api/providers/openai-native.ts @@ -24,6 +24,7 @@ import { getModelParams } from "../transform/model-params" import { BaseProvider } from "./base-provider" import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" +import { isMcpTool } from "../../utils/mcp-name" export type OpenAiNativeModel = ReturnType @@ -291,13 +292,18 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio ...(metadata?.tools && { tools: metadata.tools .filter((tool) => tool.type === "function") - .map((tool) => ({ - type: "function", - name: tool.function.name, - description: tool.function.description, - parameters: ensureAllRequired(tool.function.parameters), - strict: true, - })), + .map((tool) => { + // MCP tools use the 'mcp--' prefix - disable strict mode for them + // to preserve optional parameters from the MCP server schema + const isMcp = isMcpTool(tool.function.name) + return { + type: "function", + name: tool.function.name, + description: tool.function.description, + parameters: isMcp ? tool.function.parameters : ensureAllRequired(tool.function.parameters), + strict: !isMcp, + } + }), }), ...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }), } diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index 08fcfed6ab..f44f120804 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -194,9 +194,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl ) let lastUsage + const activeToolCallIds = new Set() for await (const chunk of stream) { const delta = chunk.choices?.[0]?.delta ?? {} + const finishReason = chunk.choices?.[0]?.finish_reason if (delta.content) { for (const chunk of matcher.update(delta.content)) { @@ -211,17 +213,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + yield* this.processToolCalls(delta, finishReason, activeToolCallIds) if (chunk.usage) { lastUsage = chunk.usage @@ -447,8 +439,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } private async *handleStreamResponse(stream: AsyncIterable): ApiStream { + const activeToolCallIds = new Set() + for await (const chunk of stream) { const delta = chunk.choices?.[0]?.delta + const finishReason = chunk.choices?.[0]?.finish_reason if (delta) { if (delta.content) { @@ -458,18 +453,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } - // Emit raw tool call chunks - NativeToolCallParser handles state management - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + yield* this.processToolCalls(delta, finishReason, activeToolCallIds) } if (chunk.usage) { @@ -482,6 +466,46 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl } } + /** + * Helper generator to process tool calls from a stream chunk. + * Tracks active tool call IDs and yields tool_call_partial and tool_call_end events. + * @param delta - The delta object from the stream chunk + * @param finishReason - The finish_reason from the stream chunk + * @param activeToolCallIds - Set to track active tool call IDs (mutated in place) + */ + private *processToolCalls( + delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta | undefined, + finishReason: string | null | undefined, + activeToolCallIds: Set, + ): Generator< + | { type: "tool_call_partial"; index: number; id?: string; name?: string; arguments?: string } + | { type: "tool_call_end"; id: string } + > { + if (delta?.tool_calls) { + for (const toolCall of delta.tool_calls) { + if (toolCall.id) { + activeToolCallIds.add(toolCall.id) + } + yield { + type: "tool_call_partial", + index: toolCall.index, + id: toolCall.id, + name: toolCall.function?.name, + arguments: toolCall.function?.arguments, + } + } + } + + // Emit tool_call_end events when finish_reason is "tool_calls" + // This ensures tool calls are finalized even if the stream doesn't properly close + if (finishReason === "tool_calls" && activeToolCallIds.size > 0) { + for (const id of activeToolCallIds) { + yield { type: "tool_call_end", id } + } + activeToolCallIds.clear() + } + } + protected _getUrlHost(baseUrl?: string): string { try { return new URL(baseUrl ?? "").host diff --git a/src/api/providers/zgsm.ts b/src/api/providers/zgsm.ts index 2c5343b500..0929441178 100644 --- a/src/api/providers/zgsm.ts +++ b/src/api/providers/zgsm.ts @@ -462,6 +462,7 @@ export class ZgsmAiHandler extends BaseProvider implements SingleCompletionHandl ) let lastUsage + const activeToolCallIds = new Set() // Use content buffer to reduce matcher.update() calls const contentBuffer: string[] = [] @@ -481,12 +482,8 @@ export class ZgsmAiHandler extends BaseProvider implements SingleCompletionHandl // chunk for await (const chunk of stream) { - if (this.abortController?.signal.aborted) { - break - } - const delta = chunk.choices?.[0]?.delta ?? {} - + const finishReason = chunk.choices?.[0]?.finish_reason // Cache content for batch processing if (delta.content) { contentBuffer.push(delta.content) @@ -517,17 +514,7 @@ export class ZgsmAiHandler extends BaseProvider implements SingleCompletionHandl } } - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + yield* this.processToolCalls(delta, finishReason, activeToolCallIds) // Cache usage information if (chunk.usage) { @@ -554,6 +541,46 @@ export class ZgsmAiHandler extends BaseProvider implements SingleCompletionHandl } } + /** + * Helper generator to process tool calls from a stream chunk. + * Tracks active tool call IDs and yields tool_call_partial and tool_call_end events. + * @param delta - The delta object from the stream chunk + * @param finishReason - The finish_reason from the stream chunk + * @param activeToolCallIds - Set to track active tool call IDs (mutated in place) + */ + private *processToolCalls( + delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta | undefined, + finishReason: string | null | undefined, + activeToolCallIds: Set, + ): Generator< + | { type: "tool_call_partial"; index: number; id?: string; name?: string; arguments?: string } + | { type: "tool_call_end"; id: string } + > { + if (delta?.tool_calls) { + for (const toolCall of delta.tool_calls) { + if (toolCall.id) { + activeToolCallIds.add(toolCall.id) + } + yield { + type: "tool_call_partial", + index: toolCall.index, + id: toolCall.id, + name: toolCall.function?.name, + arguments: toolCall.function?.arguments, + } + } + } + + // Emit tool_call_end events when finish_reason is "tool_calls" + // This ensures tool calls are finalized even if the stream doesn't properly close + if (finishReason === "tool_calls" && activeToolCallIds.size > 0) { + for (const id of activeToolCallIds) { + yield { type: "tool_call_end", id } + } + activeToolCallIds.clear() + } + } + async updateModelInfo() { const id = this.options.zgsmModelId ?? zgsmDefaultModelId const info = @@ -740,11 +767,11 @@ export class ZgsmAiHandler extends BaseProvider implements SingleCompletionHandl } private async *handleStreamResponse(stream: AsyncIterable): ApiStream { + const activeToolCallIds = new Set() + for await (const chunk of stream) { - if (this.abortController?.signal.aborted) { - break - } const delta = chunk.choices?.[0]?.delta + const finishReason = chunk.choices?.[0]?.finish_reason if (delta) { if (delta.content) { @@ -754,18 +781,7 @@ export class ZgsmAiHandler extends BaseProvider implements SingleCompletionHandl } } - // Emit raw tool call chunks - NativeToolCallParser handles state management - if (delta.tool_calls) { - for (const toolCall of delta.tool_calls) { - yield { - type: "tool_call_partial", - index: toolCall.index, - id: toolCall.id, - name: toolCall.function?.name, - arguments: toolCall.function?.arguments, - } - } - } + yield* this.processToolCalls(delta, finishReason, activeToolCallIds) } if (chunk.usage) { diff --git a/src/core/assistant-message/AssistantMessageParser.ts b/src/core/assistant-message/AssistantMessageParser.ts index d7af914170..9194751996 100644 --- a/src/core/assistant-message/AssistantMessageParser.ts +++ b/src/core/assistant-message/AssistantMessageParser.ts @@ -1,6 +1,7 @@ import { type ToolName, toolNames } from "@roo-code/types" import { TextContent, ToolUse, ToolParamName, toolParamNames } from "../../shared/tools" import { AssistantMessageContent } from "./parseAssistantMessage" +import { parseXml } from "../../utils/xml" /** * Parser for assistant messages. Maintains state between chunks @@ -12,16 +13,22 @@ export class AssistantMessageParser { private currentTextContentStartIndex = 0 private currentToolUse: ToolUse | undefined = undefined private currentToolUseStartIndex = 0 - private currentParamName: ToolParamName | undefined = undefined + private currentParamName: string | undefined = undefined private currentParamValueStartIndex = 0 private readonly MAX_ACCUMULATOR_SIZE = 1024 * 1024 // 1MB limit private readonly MAX_PARAM_LENGTH = 1024 * 100 // 100KB per parameter limit private accumulator = "" + private allToolNames: readonly string[] + private customToolNames: Set /** * Initialize a new AssistantMessageParser instance. + * @param customToolNames - Optional array of custom tool names to recognize in addition to built-in tools */ - constructor() { + constructor(customToolNames: string[] = []) { + // Combine built-in tool names with custom tool names + this.allToolNames = [...toolNames, ...customToolNames] + this.customToolNames = new Set(customToolNames) this.reset() } @@ -78,7 +85,8 @@ export class AssistantMessageParser { // End of param value. // Do not trim content parameters to preserve newlines, but strip first and last newline only const paramValue = currentParamValue.slice(0, -paramClosingTag.length) - this.currentToolUse.params[this.currentParamName] = + // Use type assertion to support custom tool parameters + ;(this.currentToolUse.params as Record)[this.currentParamName] = this.currentParamName === "content" ? paramValue.replace(/^\n/, "").replace(/\n$/, "") : paramValue.trim() @@ -87,7 +95,8 @@ export class AssistantMessageParser { } else { // Partial param value is accumulating. // Write the currently accumulated param content in real time - this.currentToolUse.params[this.currentParamName] = currentParamValue + ;(this.currentToolUse.params as Record)[this.currentParamName] = + currentParamValue continue } } @@ -98,15 +107,30 @@ export class AssistantMessageParser { const currentToolValue = this.accumulator.slice(this.currentToolUseStartIndex) const toolUseClosingTag = `` if (currentToolValue.endsWith(toolUseClosingTag)) { - // End of a tool use. - this.currentToolUse.partial = false if ( this.currentToolUse.name === "attempt_completion" && !this.currentToolUse?.params?.result && currentToolValue.trim() ) { this.currentToolUse.params.result = currentToolValueExtract(currentToolValue.trim()) + } else if (this.customToolNames.has(this.currentToolUse.name)) { + // Custom tool use, extract the content + Object.assign( + this.currentToolUse.params, + ((text: string = "") => { + try { + return parseXml(text, []) + } catch (error) { + console.log( + `[${this.currentToolUse.name}] Invalid XML format: ${error instanceof Error ? error.message : String(error)}`, + ) + return {} + } + })(currentToolValue.trim()), + ) } + // End of a tool use. + this.currentToolUse.partial = false this.currentToolUse = undefined continue } else { @@ -161,7 +185,7 @@ export class AssistantMessageParser { // No currentToolUse. let didStartToolUse = false - const possibleToolUseOpeningTags = toolNames.map((name) => `<${name}>`) + const possibleToolUseOpeningTags = this.allToolNames.map((name) => `<${name}>`) for (const toolUseOpeningTag of possibleToolUseOpeningTags) { if (this.accumulator.endsWith(toolUseOpeningTag)) { @@ -169,7 +193,7 @@ export class AssistantMessageParser { const extractedToolName = toolUseOpeningTag.slice(1, -1) // Check if the extracted tool name is valid - if (!toolNames.includes(extractedToolName as ToolName)) { + if (!this.allToolNames.includes(extractedToolName)) { // Invalid tool name, treat as plain text and continue continue } diff --git a/src/core/assistant-message/__tests__/custom-tool-parsing.test.ts b/src/core/assistant-message/__tests__/custom-tool-parsing.test.ts new file mode 100644 index 0000000000..2f4abfaef2 --- /dev/null +++ b/src/core/assistant-message/__tests__/custom-tool-parsing.test.ts @@ -0,0 +1,81 @@ +import { describe, it, expect } from "vitest" +import { AssistantMessageParser } from "../AssistantMessageParser" + +describe("AssistantMessageParser with custom tools", () => { + it("should parse built-in tools without custom tools", () => { + const parser = new AssistantMessageParser() + const message = "\ntest.ts\n" + + const blocks = parser.processChunk(message) + parser.finalizeContentBlocks() + + const toolBlocks = blocks.filter((b) => b.type === "tool_use") + expect(toolBlocks).toHaveLength(1) + expect(toolBlocks[0]).toMatchObject({ + type: "tool_use", + name: "read_file", + partial: false, + }) + }) + + it("should recognize custom tools when custom tool names are provided", () => { + const parser = new AssistantMessageParser(["add_numbers"]) + const message = "\n5\n10\n" + + const blocks = parser.processChunk(message) + parser.finalizeContentBlocks() + + const toolBlocks = blocks.filter((b) => b.type === "tool_use") + // The key fix: custom tool is now recognized as tool_use, not text + // This prevents the "noToolsUsed" error + expect(toolBlocks).toHaveLength(1) + expect(toolBlocks[0]).toMatchObject({ + type: "tool_use", + name: "add_numbers", + partial: false, + }) + // Note: Custom tool parameter parsing is a separate concern + // and would require extending toolParamNames dynamically + }) + + it("should treat unknown tool names as text when not in custom tools list", () => { + const parser = new AssistantMessageParser() + const message = "\n5\n10\n" + + const blocks = parser.processChunk(message) + parser.finalizeContentBlocks() + + // Without custom tool names, add_numbers should be treated as text + const textBlocks = blocks.filter((b) => b.type === "text") + expect(textBlocks.length).toBeGreaterThan(0) + }) + + it("should recognize mix of built-in and custom tools", () => { + const parser = new AssistantMessageParser(["add_numbers", "custom_tool"]) + const message = `I'll read the file first. + +test.ts + + +Then I'll use the custom tool. + +5 +10 +` + + const blocks = parser.processChunk(message) + parser.finalizeContentBlocks() + + expect(blocks.length).toBeGreaterThanOrEqual(4) // text, tool, text, tool + + const toolBlocks = blocks.filter((b) => b.type === "tool_use") + // Both built-in and custom tools are recognized + expect(toolBlocks).toHaveLength(2) + expect(toolBlocks[0]).toMatchObject({ + name: "read_file", + }) + expect(toolBlocks[1]).toMatchObject({ + name: "add_numbers", + }) + }) +}) diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index d8fcc24bfe..5480cecab7 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -54,6 +54,7 @@ import { } from "@roo-code/types" // import { CloudService, BridgeOrchestrator } from "@roo-code/cloud" import { TelemetryService } from "@roo-code/telemetry" +import { customToolRegistry } from "@roo-code/core" import { resolveToolProtocol, detectToolProtocolFromHistory } from "../../utils/resolveToolProtocol" // api @@ -95,6 +96,7 @@ import { getWorkspacePath } from "../../utils/path" import { formatResponse } from "../prompts/responses" import { SYSTEM_PROMPT } from "../prompts/system" import { buildNativeToolsArray } from "./build-tools" +import { getRooDirectoriesForCwd } from "../../services/roo-config/index.js" // core modules import { ToolRepetitionDetector } from "../tools/ToolRepetitionDetector" @@ -526,7 +528,8 @@ export class Task extends EventEmitter implements TaskLike { // For history items without a persisted protocol, we default to XML parser // and will update it in resumeTaskFromHistory after detection. const effectiveProtocol = this._taskToolProtocol || "xml" - this.assistantMessageParser = effectiveProtocol !== "native" ? new AssistantMessageParser() : undefined + this.assistantMessageParser = + effectiveProtocol !== "native" ? new AssistantMessageParser(this.getCustomToolNames()) : undefined this.messageQueueService = new MessageQueueService() @@ -1825,7 +1828,7 @@ export class Task extends EventEmitter implements TaskLike { // Update parser state to match the detected/resolved protocol const shouldUseXmlParser = this._taskToolProtocol === "xml" if (shouldUseXmlParser && !this.assistantMessageParser) { - this.assistantMessageParser = new AssistantMessageParser() + this.assistantMessageParser = new AssistantMessageParser(this.getCustomToolNames()) } else if (!shouldUseXmlParser && this.assistantMessageParser) { this.assistantMessageParser.reset() this.assistantMessageParser = undefined @@ -3528,6 +3531,15 @@ export class Task extends EventEmitter implements TaskLike { return false } + /** + * Get the names of all loaded custom tools. + * This is a synchronous method that returns the currently loaded custom tool names. + * If custom tools experiment is not enabled, the registry will be empty. + */ + private getCustomToolNames(): string[] { + return customToolRegistry.list() + } + private async getSystemPrompt(): Promise { const { mcpEnabled } = (await this.providerRef.deref()?.getState()) ?? {} let mcpHub: McpHub | undefined diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts index 8eea4ace82..d8f381377d 100644 --- a/src/core/task/build-tools.ts +++ b/src/core/task/build-tools.ts @@ -84,7 +84,7 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise let nativeCustomTools: OpenAI.Chat.ChatCompletionFunctionTool[] = [] if (experiments?.customTools) { - const toolDirs = getRooDirectoriesForCwd(cwd).map((dir) => path.join(dir, "tools")) + const toolDirs = getRooDirectoriesForCwd(cwd, true).map((dir) => path.join(dir, "tools")) await customToolRegistry.loadFromDirectoriesIfStale(toolDirs) const customTools = customToolRegistry.getAllSerialized() diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 64bdac2e29..e2087bceb9 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -914,6 +914,9 @@ export const webviewMessageHandler = async ( const requestedProvider = message?.values?.provider const providerFilter = requestedProvider ? toRouterName(requestedProvider) : undefined + // Optional refresh flag to flush cache before fetching (useful for providers requiring credentials) + const shouldRefresh = message?.values?.refresh === true + const routerModels: Record = providerFilter ? ({} as Record) : { @@ -1021,6 +1024,12 @@ export const webviewMessageHandler = async ( ? candidates.filter(({ key }) => key === providerFilter) : candidates + // If refresh flag is set and we have a specific provider, flush its cache first + if (shouldRefresh && providerFilter && modelFetchPromises.length > 0) { + const targetCandidate = modelFetchPromises[0] + await flushModels(targetCandidate.options, true) + } + const results = await Promise.allSettled( modelFetchPromises.map(async ({ key, options }) => { const models = await safeGetModels(options) @@ -1859,7 +1868,7 @@ export const webviewMessageHandler = async ( } case "refreshCustomTools": { try { - const toolDirs = getRooDirectoriesForCwd(getCurrentCwd()).map((dir) => path.join(dir, "tools")) + const toolDirs = getRooDirectoriesForCwd(getCurrentCwd(), true).map((dir) => path.join(dir, "tools")) await customToolRegistry.loadFromDirectories(toolDirs) await provider.postMessageToWebview({ diff --git a/src/services/roo-config/index.ts b/src/services/roo-config/index.ts index 30bb1a9ed0..9805f56e81 100644 --- a/src/services/roo-config/index.ts +++ b/src/services/roo-config/index.ts @@ -148,7 +148,7 @@ export async function readFileIfExists(filePath: string): Promise * └── index.ts * ``` */ -export function getRooDirectoriesForCwd(cwd: string): string[] { +export function getRooDirectoriesForCwd(cwd: string, ignoreOpenspec = false): string[] { const directories: string[] = [] // Add global directory first @@ -156,7 +156,10 @@ export function getRooDirectoriesForCwd(cwd: string): string[] { // Add project-local directory second directories.push(getProjectRooDirectoryForCwd(cwd)) - directories.push(path.join(getProjectCostrictSpecDirectoryForCwd(cwd), "openspec")) + + if (!ignoreOpenspec) { + directories.push(path.join(getProjectCostrictSpecDirectoryForCwd(cwd), "openspec")) + } return directories } diff --git a/src/services/tree-sitter/__tests__/helpers.ts b/src/services/tree-sitter/__tests__/helpers.ts index 2697639fb7..dbb07c07cf 100644 --- a/src/services/tree-sitter/__tests__/helpers.ts +++ b/src/services/tree-sitter/__tests__/helpers.ts @@ -39,9 +39,13 @@ export async function initializeTreeSitter() { Language.load = async (wasmPath: string) => { const filename = path.basename(wasmPath) - const correctPath = path.join(process.cwd(), "dist", filename) - // console.log(`Redirecting WASM load from ${wasmPath} to ${correctPath}`) - return originalLoad(correctPath) + // Check if we're in the src directory and adjust path accordingly + const cwd = process.cwd() + const distPath = cwd.endsWith("/src") + ? path.join(cwd, "dist", filename) + : path.join(cwd, "src", "dist", filename) + // console.log(`Redirecting WASM load from ${wasmPath} to ${distPath}`) + return originalLoad(distPath) } initializedTreeSitter = { Parser, Language } @@ -85,7 +89,8 @@ export async function testParseSourceCodeDefinitions( const parser = new Parser() // Load language and configure parser - const wasmPath = path.join(process.cwd(), `dist/${wasmFile}`) + const cwd = process.cwd() + const wasmPath = cwd.endsWith("/src") ? path.join(cwd, `dist/${wasmFile}`) : path.join(cwd, `src/dist/${wasmFile}`) const lang = await Language.load(wasmPath) parser.setLanguage(lang) @@ -114,7 +119,10 @@ export async function testParseSourceCodeDefinitions( export async function inspectTreeStructure(content: string, language: string = "typescript"): Promise { const { Parser, Language } = await initializeTreeSitter() const parser = new Parser() - const wasmPath = path.join(process.cwd(), `dist/tree-sitter-${language}.wasm`) + const cwd = process.cwd() + const wasmPath = cwd.endsWith("/src") + ? path.join(cwd, `dist/tree-sitter-${language}.wasm`) + : path.join(cwd, `src/dist/tree-sitter-${language}.wasm`) const lang = await Language.load(wasmPath) parser.setLanguage(lang) diff --git a/src/utils/__tests__/json-schema.spec.ts b/src/utils/__tests__/json-schema.spec.ts index c53e0d7b86..5a1510be43 100644 --- a/src/utils/__tests__/json-schema.spec.ts +++ b/src/utils/__tests__/json-schema.spec.ts @@ -26,10 +26,10 @@ describe("normalizeToolSchema", () => { const result = normalizeToolSchema(input) - // additionalProperties should NOT be added to array or primitive types + // Array-specific properties (items) should be moved inside the array variant + // This is required by strict schema validators like GPT-5-mini expect(result).toEqual({ - anyOf: [{ type: "array" }, { type: "null" }], - items: { type: "string" }, + anyOf: [{ type: "array", items: { type: "string" } }, { type: "null" }], description: "Optional array", }) }) @@ -97,6 +97,7 @@ describe("normalizeToolSchema", () => { const result = normalizeToolSchema(input) // additionalProperties: false should ONLY be on object types + // Array-specific properties (items) should be moved inside the array variant expect(result).toEqual({ type: "array", items: { @@ -104,8 +105,7 @@ describe("normalizeToolSchema", () => { properties: { path: { type: "string" }, line_ranges: { - anyOf: [{ type: "array" }, { type: "null" }], - items: { type: "integer" }, + anyOf: [{ type: "array", items: { type: "integer" } }, { type: "null" }], }, }, additionalProperties: false, @@ -143,7 +143,11 @@ describe("normalizeToolSchema", () => { const properties = result.properties as Record> const filesItems = properties.files.items as Record const filesItemsProps = filesItems.properties as Record> - expect(filesItemsProps.line_ranges.anyOf).toEqual([{ type: "array" }, { type: "null" }]) + // Array-specific properties (items) should be moved inside the array variant + expect(filesItemsProps.line_ranges.anyOf).toEqual([ + { type: "array", items: { type: "array", items: { type: "integer" } } }, + { type: "null" }, + ]) }) it("should recursively transform anyOf arrays", () => { @@ -255,13 +259,26 @@ describe("normalizeToolSchema", () => { const result = normalizeToolSchema(input) - // Verify the line_ranges was transformed + // Verify the line_ranges was transformed with items inside the array variant const files = (result.properties as Record).files as Record const items = files.items as Record const props = items.properties as Record> - expect(props.line_ranges.anyOf).toEqual([{ type: "array" }, { type: "null" }]) - // Verify other properties are preserved - expect(props.line_ranges.items).toBeDefined() + // Array-specific properties (items, minItems, maxItems) should be moved inside the array variant + expect(props.line_ranges.anyOf).toEqual([ + { + type: "array", + items: { + type: "array", + items: { type: "integer" }, + minItems: 2, + maxItems: 2, + }, + }, + { type: "null" }, + ]) + // items should NOT be at root level anymore + expect(props.line_ranges.items).toBeUndefined() + // Other properties are preserved at root level expect(props.line_ranges.description).toBe("Optional line ranges") }) diff --git a/src/utils/__tests__/mcp-name.spec.ts b/src/utils/__tests__/mcp-name.spec.ts index b28c2e504c..5511893f79 100644 --- a/src/utils/__tests__/mcp-name.spec.ts +++ b/src/utils/__tests__/mcp-name.spec.ts @@ -1,4 +1,11 @@ -import { sanitizeMcpName, buildMcpToolName, parseMcpToolName, MCP_TOOL_SEPARATOR, MCP_TOOL_PREFIX } from "../mcp-name" +import { + sanitizeMcpName, + buildMcpToolName, + parseMcpToolName, + isMcpTool, + MCP_TOOL_SEPARATOR, + MCP_TOOL_PREFIX, +} from "../mcp-name" describe("mcp-name utilities", () => { describe("constants", () => { @@ -8,6 +15,29 @@ describe("mcp-name utilities", () => { }) }) + describe("isMcpTool", () => { + it("should return true for valid MCP tool names", () => { + expect(isMcpTool("mcp--server--tool")).toBe(true) + expect(isMcpTool("mcp--my_server--get_forecast")).toBe(true) + }) + + it("should return false for non-MCP tool names", () => { + expect(isMcpTool("server--tool")).toBe(false) + expect(isMcpTool("tool")).toBe(false) + expect(isMcpTool("read_file")).toBe(false) + expect(isMcpTool("")).toBe(false) + }) + + it("should return false for old underscore format", () => { + expect(isMcpTool("mcp_server_tool")).toBe(false) + }) + + it("should return false for partial prefix", () => { + expect(isMcpTool("mcp-server")).toBe(false) + expect(isMcpTool("mcp")).toBe(false) + }) + }) + describe("sanitizeMcpName", () => { it("should return underscore placeholder for empty input", () => { expect(sanitizeMcpName("")).toBe("_") diff --git a/src/utils/json-schema.ts b/src/utils/json-schema.ts index 180a51848b..8059c2ee0d 100644 --- a/src/utils/json-schema.ts +++ b/src/utils/json-schema.ts @@ -23,6 +23,28 @@ const OPENAI_SUPPORTED_FORMATS = new Set([ "uuid", ]) +/** + * Array-specific JSON Schema properties that must be nested inside array type variants + * when converting to anyOf format (JSON Schema draft 2020-12). + */ +const ARRAY_SPECIFIC_PROPERTIES = ["items", "minItems", "maxItems", "uniqueItems"] as const + +/** + * Applies array-specific properties from source to target object. + * Only copies properties that are defined in the source. + */ +function applyArrayProperties( + target: Record, + source: Record, +): Record { + for (const prop of ARRAY_SPECIFIC_PROPERTIES) { + if (source[prop] !== undefined) { + target[prop] = source[prop] + } + } + return target +} + /** * Zod schema for JSON Schema primitive types */ @@ -133,18 +155,42 @@ const NormalizedToolSchemaInternal: z.ZodType, z.ZodType }) .passthrough() .transform((schema) => { - const { type, required, properties, additionalProperties, format, ...rest } = schema + const { + type, + required, + properties, + additionalProperties, + format, + items, + minItems, + maxItems, + uniqueItems, + ...rest + } = schema const result: Record = { ...rest } // Determine if this schema represents an object type const isObjectType = type === "object" || (Array.isArray(type) && type.includes("object")) || properties !== undefined + // Collect array-specific properties for potential use in type handling + const arrayProps = { items, minItems, maxItems, uniqueItems } + // If type is an array, convert to anyOf format (JSON Schema 2020-12) + // Array-specific properties must be moved inside the array variant if (Array.isArray(type)) { - result.anyOf = type.map((t) => ({ type: t })) + result.anyOf = type.map((t) => { + if (t === "array") { + return applyArrayProperties({ type: t }, arrayProps) + } + return { type: t } + }) } else if (type !== undefined) { result.type = type + // For single "array" type, preserve array-specific properties at root + if (type === "array") { + applyArrayProperties(result, arrayProps) + } } // Strip unsupported format values for OpenAI compatibility diff --git a/src/utils/mcp-name.ts b/src/utils/mcp-name.ts index 55845d67ed..c81d5e770f 100644 --- a/src/utils/mcp-name.ts +++ b/src/utils/mcp-name.ts @@ -17,6 +17,16 @@ export const MCP_TOOL_SEPARATOR = "--" */ export const MCP_TOOL_PREFIX = "mcp" +/** + * Check if a tool name is an MCP tool (starts with the MCP prefix and separator). + * + * @param toolName - The tool name to check + * @returns true if the tool name starts with "mcp--", false otherwise + */ +export function isMcpTool(toolName: string): boolean { + return toolName.startsWith(`${MCP_TOOL_PREFIX}${MCP_TOOL_SEPARATOR}`) +} + /** * Sanitize a name to be safe for use in API function names. * This removes special characters and ensures the name starts correctly. diff --git a/webview-ui/src/App.tsx b/webview-ui/src/App.tsx index 119627e450..84818d68eb 100644 --- a/webview-ui/src/App.tsx +++ b/webview-ui/src/App.tsx @@ -30,6 +30,7 @@ import { cn } from "./lib/utils" import { ReauthConfirmationDialog } from "./components/chat/ReauthConfirmationDialog" import { ZgsmCodebaseDisableConfirmDialog } from "./components/settings/ZgsmCodebaseDisableConfirmDialog" import { useTranslation } from "react-i18next" +import { EXPERIMENT_IDS } from "@roo/experiments" // type Tab = "settings" | "history" | "mcp" | "modes" | "chat" | "marketplace" | "cloud" | "zgsm-account" | "codeReview" type Tab = "settings" | "history" | "chat" | "marketplace" | "cloud" | "zgsm-account" | "codeReview" @@ -89,6 +90,7 @@ const App = () => { telemetrySetting, telemetryKey, machineId, + experiments, // cloudUserInfo, // cloudIsAuthenticated, // cloudApiUrl, @@ -250,7 +252,11 @@ const App = () => { // Tell the extension that we are ready to receive messages. useEffect(() => vscode.postMessage({ type: "webviewDidLaunch" }), []) - + useEffect(() => { + if (experiments[EXPERIMENT_IDS.CUSTOM_TOOLS] ?? false) { + vscode.postMessage({ type: "refreshCustomTools" }) + } + }, [experiments]) // Initialize source map support for better error reporting useEffect(() => { // Initialize source maps for better error reporting in production diff --git a/webview-ui/src/components/settings/ApiOptions.tsx b/webview-ui/src/components/settings/ApiOptions.tsx index ef8eb404f7..ee8dac8587 100644 --- a/webview-ui/src/components/settings/ApiOptions.tsx +++ b/webview-ui/src/components/settings/ApiOptions.tsx @@ -478,7 +478,7 @@ const ApiOptions = ({ // 1. User preference (toolProtocol) - handled by the select value binding // 2. Model default - use if available // 3. Native fallback - const defaultProtocol = selectedModelInfo?.defaultToolProtocol || TOOL_PROTOCOL.NATIVE + const defaultProtocol = selectedModelInfo?.defaultToolProtocol || TOOL_PROTOCOL.XML // Show the tool protocol selector when model supports native tools. // For OpenAI Compatible providers we always show it so users can force XML/native explicitly. diff --git a/webview-ui/src/components/settings/providers/Requesty.tsx b/webview-ui/src/components/settings/providers/Requesty.tsx index 0285dbaed0..859d82d03e 100644 --- a/webview-ui/src/components/settings/providers/Requesty.tsx +++ b/webview-ui/src/components/settings/providers/Requesty.tsx @@ -30,7 +30,6 @@ export const Requesty = ({ apiConfiguration, setApiConfigurationField, routerModels, - refetchRouterModels, organizationAllowList, modelValidationError, uriScheme, @@ -127,8 +126,7 @@ export const Requesty = ({