diff --git a/packages/types/src/providers/fireworks.ts b/packages/types/src/providers/fireworks.ts index 1642424045..c9017c54cd 100644 --- a/packages/types/src/providers/fireworks.ts +++ b/packages/types/src/providers/fireworks.ts @@ -4,6 +4,7 @@ export type FireworksModelId = | "accounts/fireworks/models/kimi-k2-instruct" | "accounts/fireworks/models/kimi-k2-instruct-0905" | "accounts/fireworks/models/kimi-k2-thinking" + | "accounts/fireworks/models/kimi-k2p5" | "accounts/fireworks/models/minimax-m2" | "accounts/fireworks/models/minimax-m2p1" | "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507" @@ -60,6 +61,17 @@ export const fireworksModels = { description: "The kimi-k2-thinking model is a general-purpose agentic reasoning model developed by Moonshot AI. Thanks to its strength in deep reasoning and multi-turn tool use, it can solve even the hardest problems.", }, + "accounts/fireworks/models/kimi-k2p5": { + maxTokens: 16384, + contextWindow: 262144, + supportsImages: true, + supportsPromptCache: true, + inputPrice: 0.6, + outputPrice: 3.0, + cacheReadsPrice: 0.1, + description: + "Kimi K2.5 is Moonshot AI's flagship agentic model and a new SOTA open model. It unifies vision and text, thinking and non-thinking modes, and single-agent and multi-agent execution into one model. Fireworks enables users to control the reasoning behavior and inspect its reasoning history for greater transparency.", + }, "accounts/fireworks/models/minimax-m2": { maxTokens: 4096, contextWindow: 204800, diff --git a/packages/types/src/skills.ts b/packages/types/src/skills.ts index b50b4e6d47..3e856612bc 100644 --- a/packages/types/src/skills.ts +++ b/packages/types/src/skills.ts @@ -7,7 +7,17 @@ export interface SkillMetadata { description: string // Required: when to use this skill path: string // Absolute path to SKILL.md (or "" for built-in skills) source: "global" | "project" | "built-in" // Where the skill was discovered - mode?: string // If set, skill is only available in this mode + /** + * @deprecated Use modeSlugs instead. Kept for backward compatibility. + * If set, skill is only available in this mode. + */ + mode?: string + /** + * Mode slugs where this skill is available. + * - undefined or empty array means the skill is available in all modes ("Any mode"). + * - An array with one or more mode slugs restricts the skill to those modes. + */ + modeSlugs?: string[] } /** diff --git a/packages/types/src/vscode-extension-host.ts b/packages/types/src/vscode-extension-host.ts index 2f334cc896..79954f2127 100644 --- a/packages/types/src/vscode-extension-host.ts +++ b/packages/types/src/vscode-extension-host.ts @@ -685,6 +685,7 @@ export interface WebviewMessage { | "createSkill" | "deleteSkill" | "moveSkill" + | "updateSkillModes" | "openSkillFile" text?: string // costrict-start @@ -727,9 +728,15 @@ export interface WebviewMessage { payload?: WebViewMessagePayload source?: "global" | "project" | "built-in" skillName?: string // For skill operations (createSkill, deleteSkill, moveSkill, openSkillFile) + /** @deprecated Use skillModeSlugs instead */ skillMode?: string // For skill operations (current mode restriction) + /** @deprecated Use newSkillModeSlugs instead */ newSkillMode?: string // For moveSkill (target mode) skillDescription?: string // For createSkill (skill description) + /** Mode slugs for skill operations. undefined/empty = any mode */ + skillModeSlugs?: string[] // For skill operations (mode restrictions) + /** Target mode slugs for updateSkillModes */ + newSkillModeSlugs?: string[] // For updateSkillModes (new mode restrictions) requestId?: string ids?: string[] hasSystemPromptOverride?: boolean diff --git a/src/api/providers/__tests__/huggingface.spec.ts b/src/api/providers/__tests__/huggingface.spec.ts new file mode 100644 index 0000000000..e7682474c1 --- /dev/null +++ b/src/api/providers/__tests__/huggingface.spec.ts @@ -0,0 +1,553 @@ +// npx vitest run src/api/providers/__tests__/huggingface.spec.ts + +// Use vi.hoisted to define mock functions that can be referenced in hoisted vi.mock() calls +const { mockStreamText, mockGenerateText } = vi.hoisted(() => ({ + mockStreamText: vi.fn(), + mockGenerateText: vi.fn(), +})) + +vi.mock("ai", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + streamText: mockStreamText, + generateText: mockGenerateText, + } +}) + +vi.mock("@ai-sdk/openai-compatible", () => ({ + createOpenAICompatible: vi.fn(() => { + // Return a function that returns a mock language model + return vi.fn(() => ({ + modelId: "meta-llama/Llama-3.3-70B-Instruct", + provider: "huggingface", + })) + }), +})) + +// Mock the fetchers +vi.mock("../fetchers/huggingface", () => ({ + getHuggingFaceModels: vi.fn(() => Promise.resolve({})), + getCachedHuggingFaceModels: vi.fn(() => ({})), +})) + +import type { Anthropic } from "@anthropic-ai/sdk" + +import type { ApiHandlerOptions } from "../../../shared/api" + +import { HuggingFaceHandler } from "../huggingface" + +describe("HuggingFaceHandler", () => { + let handler: HuggingFaceHandler + let mockOptions: ApiHandlerOptions + + beforeEach(() => { + mockOptions = { + huggingFaceApiKey: "test-huggingface-api-key", + huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct", + } + handler = new HuggingFaceHandler(mockOptions) + vi.clearAllMocks() + }) + + describe("constructor", () => { + it("should initialize with provided options", () => { + expect(handler).toBeInstanceOf(HuggingFaceHandler) + expect(handler.getModel().id).toBe(mockOptions.huggingFaceModelId) + }) + + it("should use default model ID if not provided", () => { + const handlerWithoutModel = new HuggingFaceHandler({ + ...mockOptions, + huggingFaceModelId: undefined, + }) + expect(handlerWithoutModel.getModel().id).toBe("meta-llama/Llama-3.3-70B-Instruct") + }) + + it("should throw error if API key is not provided", () => { + expect(() => { + new HuggingFaceHandler({ + ...mockOptions, + huggingFaceApiKey: undefined, + }) + }).toThrow("Hugging Face API key is required") + }) + }) + + describe("getModel", () => { + it("should return default model when no model is specified", () => { + const handlerWithoutModel = new HuggingFaceHandler({ + huggingFaceApiKey: "test-huggingface-api-key", + }) + const model = handlerWithoutModel.getModel() + expect(model.id).toBe("meta-llama/Llama-3.3-70B-Instruct") + expect(model.info).toBeDefined() + }) + + it("should return specified model when valid model is provided", () => { + const testModelId = "mistralai/Mistral-7B-Instruct-v0.3" + const handlerWithModel = new HuggingFaceHandler({ + huggingFaceModelId: testModelId, + huggingFaceApiKey: "test-huggingface-api-key", + }) + const model = handlerWithModel.getModel() + expect(model.id).toBe(testModelId) + }) + + it("should include model parameters from getModelParams", () => { + const model = handler.getModel() + expect(model).toHaveProperty("temperature") + expect(model).toHaveProperty("maxTokens") + }) + + it("should return fallback info when model not in cache", () => { + const model = handler.getModel() + expect(model.info).toEqual( + expect.objectContaining({ + maxTokens: 8192, + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, + }), + ) + }) + }) + + describe("createMessage", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [ + { + type: "text" as const, + text: "Hello!", + }, + ], + }, + ] + + it("should handle streaming responses", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response from HuggingFace" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + const textChunks = chunks.filter((chunk) => chunk.type === "text") + expect(textChunks).toHaveLength(1) + expect(textChunks[0].text).toBe("Test response from HuggingFace") + }) + + it("should include usage information", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 20, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(10) + expect(usageChunks[0].outputTokens).toBe(20) + }) + + it("should handle cached tokens in usage data from providerMetadata", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + }) + + // HuggingFace provides cache metrics via providerMetadata for supported models + const mockProviderMetadata = Promise.resolve({ + huggingface: { + promptCacheHitTokens: 30, + promptCacheMissTokens: 70, + }, + }) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].inputTokens).toBe(100) + expect(usageChunks[0].outputTokens).toBe(50) + expect(usageChunks[0].cacheReadTokens).toBe(30) + expect(usageChunks[0].cacheWriteTokens).toBe(70) + }) + + it("should handle usage with details.cachedInputTokens when providerMetadata is not available", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test response" } + } + + const mockUsage = Promise.resolve({ + inputTokens: 100, + outputTokens: 50, + details: { + cachedInputTokens: 25, + }, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const usageChunks = chunks.filter((chunk) => chunk.type === "usage") + expect(usageChunks.length).toBeGreaterThan(0) + expect(usageChunks[0].cacheReadTokens).toBe(25) + expect(usageChunks[0].cacheWriteTokens).toBeUndefined() + }) + + it("should pass correct temperature (0.7 default) to streamText", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithDefaultTemp = new HuggingFaceHandler({ + huggingFaceApiKey: "test-key", + huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct", + }) + + const stream = handlerWithDefaultTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.7, + }), + ) + }) + + it("should use user-specified temperature over provider defaults", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Test" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const handlerWithCustomTemp = new HuggingFaceHandler({ + huggingFaceApiKey: "test-key", + huggingFaceModelId: "meta-llama/Llama-3.3-70B-Instruct", + modelTemperature: 0.7, + }) + + const stream = handlerWithCustomTemp.createMessage(systemPrompt, messages) + for await (const _ of stream) { + // consume stream + } + + // User-specified temperature should take precedence over everything + expect(mockStreamText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.7, + }), + ) + }) + + it("should handle stream with multiple chunks", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + yield { type: "text-delta", text: " world" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 5, outputTokens: 10 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const textChunks = chunks.filter((c) => c.type === "text") + expect(textChunks[0]).toEqual({ type: "text", text: "Hello" }) + expect(textChunks[1]).toEqual({ type: "text", text: " world" }) + + const usageChunks = chunks.filter((c) => c.type === "usage") + expect(usageChunks[0]).toMatchObject({ type: "usage", inputTokens: 5, outputTokens: 10 }) + }) + + it("should handle errors with handleAiSdkError", async () => { + async function* mockFullStream(): AsyncGenerator { + yield { type: "text-delta", text: "" } // Yield something before error to satisfy lint + throw new Error("API Error") + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 0, outputTokens: 0 }), + providerMetadata: Promise.resolve({}), + }) + + const stream = handler.createMessage(systemPrompt, messages) + + await expect(async () => { + for await (const _ of stream) { + // consume stream + } + }).rejects.toThrow("HuggingFace: API Error") + }) + }) + + describe("completePrompt", () => { + it("should complete a prompt using generateText", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion from HuggingFace", + }) + + const result = await handler.completePrompt("Test prompt") + + expect(result).toBe("Test completion from HuggingFace") + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + prompt: "Test prompt", + }), + ) + }) + + it("should use default temperature in completePrompt", async () => { + mockGenerateText.mockResolvedValue({ + text: "Test completion", + }) + + await handler.completePrompt("Test prompt") + + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ + temperature: 0.7, + }), + ) + }) + }) + + describe("processUsageMetrics", () => { + it("should correctly process usage metrics including cache information from providerMetadata", () => { + class TestHuggingFaceHandler extends HuggingFaceHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestHuggingFaceHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const providerMetadata = { + huggingface: { + promptCacheHitTokens: 20, + promptCacheMissTokens: 80, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage, providerMetadata) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBe(80) + expect(result.cacheReadTokens).toBe(20) + }) + + it("should handle missing cache metrics gracefully", () => { + class TestHuggingFaceHandler extends HuggingFaceHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestHuggingFaceHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.type).toBe("usage") + expect(result.inputTokens).toBe(100) + expect(result.outputTokens).toBe(50) + expect(result.cacheWriteTokens).toBeUndefined() + expect(result.cacheReadTokens).toBeUndefined() + }) + + it("should include reasoning tokens when provided", () => { + class TestHuggingFaceHandler extends HuggingFaceHandler { + public testProcessUsageMetrics(usage: any, providerMetadata?: any) { + return this.processUsageMetrics(usage, providerMetadata) + } + } + + const testHandler = new TestHuggingFaceHandler(mockOptions) + + const usage = { + inputTokens: 100, + outputTokens: 50, + details: { + reasoningTokens: 30, + }, + } + + const result = testHandler.testProcessUsageMetrics(usage) + + expect(result.reasoningTokens).toBe(30) + }) + }) + + describe("tool handling", () => { + const systemPrompt = "You are a helpful assistant." + const messages: Anthropic.Messages.MessageParam[] = [ + { + role: "user", + content: [{ type: "text" as const, text: "Hello!" }], + }, + ] + + it("should handle tool calls in streaming", async () => { + async function* mockFullStream() { + yield { + type: "tool-input-start", + id: "tool-call-1", + toolName: "read_file", + } + yield { + type: "tool-input-delta", + id: "tool-call-1", + delta: '{"path":"test.ts"}', + } + yield { + type: "tool-input-end", + id: "tool-call-1", + } + } + + const mockUsage = Promise.resolve({ + inputTokens: 10, + outputTokens: 5, + }) + + const mockProviderMetadata = Promise.resolve({}) + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: mockUsage, + providerMetadata: mockProviderMetadata, + }) + + const stream = handler.createMessage(systemPrompt, messages, { + taskId: "test-task", + tools: [ + { + type: "function", + function: { + name: "read_file", + description: "Read a file", + parameters: { + type: "object", + properties: { path: { type: "string" } }, + required: ["path"], + }, + }, + }, + ], + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + const toolCallStartChunks = chunks.filter((c) => c.type === "tool_call_start") + const toolCallDeltaChunks = chunks.filter((c) => c.type === "tool_call_delta") + const toolCallEndChunks = chunks.filter((c) => c.type === "tool_call_end") + + expect(toolCallStartChunks.length).toBe(1) + expect(toolCallStartChunks[0].id).toBe("tool-call-1") + expect(toolCallStartChunks[0].name).toBe("read_file") + + expect(toolCallDeltaChunks.length).toBe(1) + expect(toolCallDeltaChunks[0].delta).toBe('{"path":"test.ts"}') + + expect(toolCallEndChunks.length).toBe(1) + expect(toolCallEndChunks[0].id).toBe("tool-call-1") + }) + }) +}) diff --git a/src/api/providers/huggingface.ts b/src/api/providers/huggingface.ts index f6e54ec07e..ead3f432fa 100644 --- a/src/api/providers/huggingface.ts +++ b/src/api/providers/huggingface.ts @@ -1,22 +1,37 @@ -import OpenAI from "openai" import { Anthropic } from "@anthropic-ai/sdk" +import { createOpenAICompatible } from "@ai-sdk/openai-compatible" +import { streamText, generateText, ToolSet } from "ai" -import type { ModelRecord } from "@roo-code/types" +import type { ModelRecord, ModelInfo } from "@roo-code/types" import type { ApiHandlerOptions } from "../../shared/api" -import { ApiStream } from "../transform/stream" -import { convertToOpenAiMessages } from "../transform/openai-format" -import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" + +import { + convertToAiSdkMessages, + convertToolsForAiSdk, + processAiSdkStreamPart, + mapToolChoice, + handleAiSdkError, +} from "../transform/ai-sdk" +import { ApiStream, ApiStreamUsageChunk } from "../transform/stream" +import { getModelParams } from "../transform/model-params" + import { DEFAULT_HEADERS } from "./constants" import { BaseProvider } from "./base-provider" import { getHuggingFaceModels, getCachedHuggingFaceModels } from "./fetchers/huggingface" -import { handleOpenAIError } from "./utils/openai-error-handler" +import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index" + +const HUGGINGFACE_DEFAULT_TEMPERATURE = 0.7 +/** + * HuggingFace provider using @ai-sdk/openai-compatible for OpenAI-compatible API. + * Uses HuggingFace's OpenAI-compatible endpoint to enable tool message support. + * @see https://github.com/vercel/ai/issues/10766 - Workaround for tool messages not supported in @ai-sdk/huggingface + */ export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler { - private client: OpenAI - private options: ApiHandlerOptions + protected options: ApiHandlerOptions + protected provider: ReturnType private modelCache: ModelRecord | null = null - private readonly providerName = "HuggingFace" constructor(options: ApiHandlerOptions) { super() @@ -26,10 +41,14 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion throw new Error("Hugging Face API key is required") } - this.client = new OpenAI({ + // Create an OpenAI-compatible provider pointing to HuggingFace's /v1 endpoint + // This fixes "tool messages not supported" error - the HuggingFace SDK doesn't + // properly handle function_call_output format, but OpenAI SDK does + this.provider = createOpenAICompatible({ + name: "huggingface", baseURL: "https://router.huggingface.co/v1", apiKey: this.options.huggingFaceApiKey, - defaultHeaders: DEFAULT_HEADERS, + headers: DEFAULT_HEADERS, }) // Try to get cached models first @@ -47,94 +66,147 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion } } - override async *createMessage( - systemPrompt: string, - messages: Anthropic.Messages.MessageParam[], - metadata?: ApiHandlerCreateMessageMetadata, - ): ApiStream { - const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct" - const temperature = this.options.modelTemperature ?? 0.7 - - const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = { - model: modelId, - temperature, - messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)], - stream: true, - stream_options: { include_usage: true }, - } + override getModel(): { id: string; info: ModelInfo; maxTokens?: number; temperature?: number } { + const id = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct" - // Add max_tokens if specified - if (this.options.includeMaxTokens && this.options.modelMaxTokens) { - params.max_tokens = this.options.modelMaxTokens - } + // Try to get model info from cache + const cachedInfo = this.modelCache?.[id] - let stream - try { - stream = await this.client.chat.completions.create(params) - } catch (error) { - throw handleOpenAIError(error, this.providerName) + const info: ModelInfo = cachedInfo || { + maxTokens: 8192, + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, } - for await (const chunk of stream) { - const delta = chunk.choices[0]?.delta + const params = getModelParams({ + format: "openai", + modelId: id, + model: info, + settings: this.options, + defaultTemperature: HUGGINGFACE_DEFAULT_TEMPERATURE, + }) - if (delta?.content) { - yield { - type: "text", - text: delta.content, - } - } + return { id, info, ...params } + } - if (chunk.usage) { - yield { - type: "usage", - inputTokens: chunk.usage.prompt_tokens || 0, - outputTokens: chunk.usage.completion_tokens || 0, - } + /** + * Get the language model for the configured model ID. + */ + protected getLanguageModel() { + const { id } = this.getModel() + return this.provider(id) + } + + /** + * Process usage metrics from the AI SDK response. + */ + protected processUsageMetrics( + usage: { + inputTokens?: number + outputTokens?: number + details?: { + cachedInputTokens?: number + reasoningTokens?: number + } + }, + providerMetadata?: { + huggingface?: { + promptCacheHitTokens?: number + promptCacheMissTokens?: number } + }, + ): ApiStreamUsageChunk { + // Extract cache metrics from HuggingFace's providerMetadata if available + const cacheReadTokens = providerMetadata?.huggingface?.promptCacheHitTokens ?? usage.details?.cachedInputTokens + const cacheWriteTokens = providerMetadata?.huggingface?.promptCacheMissTokens + + return { + type: "usage", + inputTokens: usage.inputTokens || 0, + outputTokens: usage.outputTokens || 0, + cacheReadTokens, + cacheWriteTokens, + reasoningTokens: usage.details?.reasoningTokens, } } - async completePrompt(prompt: string, systemPrompt?: string, metadata?: any): Promise { - const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct" + /** + * Get the max tokens parameter to include in the request. + */ + protected getMaxOutputTokens(): number | undefined { + const { info } = this.getModel() + return this.options.modelMaxTokens || info.maxTokens || undefined + } - try { - const response = await this.client.chat.completions.create( - { - model: modelId, - messages: [{ role: "user", content: prompt }], - }, - { signal: metadata?.signal }, - ) - - return response.choices[0]?.message.content || "" - } catch (error) { - throw handleOpenAIError(error, this.providerName) + /** + * Create a message stream using the AI SDK. + */ + override async *createMessage( + systemPrompt: string, + messages: Anthropic.Messages.MessageParam[], + metadata?: ApiHandlerCreateMessageMetadata, + ): ApiStream { + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + // Convert messages to AI SDK format + const aiSdkMessages = convertToAiSdkMessages(messages) + + // Convert tools to OpenAI format first, then to AI SDK format + const openAiTools = this.convertToolsForOpenAI(metadata?.tools) + const aiSdkTools = convertToolsForAiSdk(openAiTools) as ToolSet | undefined + + // Build the request options + const requestOptions: Parameters[0] = { + model: languageModel, + system: systemPrompt, + messages: aiSdkMessages, + temperature: this.options.modelTemperature ?? temperature ?? HUGGINGFACE_DEFAULT_TEMPERATURE, + maxOutputTokens: this.getMaxOutputTokens(), + tools: aiSdkTools, + toolChoice: mapToolChoice(metadata?.tool_choice), } - } - override getModel() { - const modelId = this.options.huggingFaceModelId || "meta-llama/Llama-3.3-70B-Instruct" + // Use streamText for streaming responses + const result = streamText(requestOptions) - // Try to get model info from cache - const modelInfo = this.modelCache?.[modelId] + try { + // Process the full stream to get all events + for await (const part of result.fullStream) { + // Use the processAiSdkStreamPart utility to convert stream parts + for (const chunk of processAiSdkStreamPart(part)) { + yield chunk + } + } - if (modelInfo) { - return { - id: modelId, - info: modelInfo, + // Yield usage metrics at the end, including cache metrics from providerMetadata + const usage = await result.usage + const providerMetadata = await result.providerMetadata + if (usage) { + yield this.processUsageMetrics(usage, providerMetadata as any) } + } catch (error) { + // Handle AI SDK errors (AI_RetryError, AI_APICallError, etc.) + throw handleAiSdkError(error, "HuggingFace") } + } - // Fallback to default values if model not found in cache - return { - id: modelId, - info: { - maxTokens: 8192, - contextWindow: 131072, - supportsImages: false, - supportsPromptCache: false, - }, - } + /** + * Complete a prompt using the AI SDK generateText. + */ + async completePrompt(prompt: string, systemPrompt?: string, metadata?: any): Promise { + const { temperature } = this.getModel() + const languageModel = this.getLanguageModel() + + const { text } = await generateText({ + model: languageModel, + prompt, + maxOutputTokens: this.getMaxOutputTokens(), + temperature: this.options.modelTemperature ?? temperature ?? HUGGINGFACE_DEFAULT_TEMPERATURE, + abortSignal: metadata?.signal, + }) + + return text } } diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts index bfcf4e3be4..54c78c1eca 100644 --- a/src/api/providers/minimax.ts +++ b/src/api/providers/minimax.ts @@ -258,7 +258,7 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand content: typeof message.content === "string" ? [{ type: "text", text: message.content, cache_control: cacheControl }] - : message.content.map((content, contentIndex) => + : (message?.content || []).map((content, contentIndex) => contentIndex === message.content.length - 1 ? { ...content, cache_control: cacheControl } : content, diff --git a/src/core/condense/index.ts b/src/core/condense/index.ts index e4b33d9b83..2ac483dfcc 100644 --- a/src/core/condense/index.ts +++ b/src/core/condense/index.ts @@ -75,7 +75,7 @@ export function convertToolBlocksToText( return content } - return content.map((block) => { + return (content ?? []).map((block) => { if (block.type === "tool_use") { return { type: "text" as const, diff --git a/src/core/prompts/tools/native-tools/read_file.ts b/src/core/prompts/tools/native-tools/read_file.ts index a7cf9749b8..b175c3c749 100644 --- a/src/core/prompts/tools/native-tools/read_file.ts +++ b/src/core/prompts/tools/native-tools/read_file.ts @@ -3,10 +3,10 @@ import type OpenAI from "openai" // ─── Constants ──────────────────────────────────────────────────────────────── /** Default maximum lines to return per file (Codex-inspired predictable limit) */ -export const DEFAULT_LINE_LIMIT = 1500 +export const DEFAULT_LINE_LIMIT = 1600 /** Maximum characters per line before truncation */ -export const MAX_LINE_LENGTH = 1500 +export const MAX_LINE_LENGTH = 1600 /** Default indentation levels to include above anchor (0 = unlimited) */ export const DEFAULT_MAX_LEVELS = 0 diff --git a/src/core/webview/__tests__/skillsMessageHandler.spec.ts b/src/core/webview/__tests__/skillsMessageHandler.spec.ts index f26194ee81..cdc571282f 100644 --- a/src/core/webview/__tests__/skillsMessageHandler.spec.ts +++ b/src/core/webview/__tests__/skillsMessageHandler.spec.ts @@ -52,6 +52,7 @@ describe("skillsMessageHandler", () => { const mockDeleteSkill = vi.fn() const mockMoveSkill = vi.fn() const mockGetSkill = vi.fn() + const mockFindSkillByNameAndSource = vi.fn() const createMockProvider = (hasSkillsManager: boolean = true): ClineProvider => { const skillsManager = hasSkillsManager @@ -61,6 +62,7 @@ describe("skillsMessageHandler", () => { deleteSkill: mockDeleteSkill, moveSkill: mockMoveSkill, getSkill: mockGetSkill, + findSkillByNameAndSource: mockFindSkillByNameAndSource, } : undefined @@ -158,7 +160,7 @@ describe("skillsMessageHandler", () => { } as WebviewMessage) expect(result).toEqual(mockSkills) - expect(mockCreateSkill).toHaveBeenCalledWith("new-skill", "project", "New skill description", "code") + expect(mockCreateSkill).toHaveBeenCalledWith("new-skill", "project", "New skill description", ["code"]) }) it("returns undefined when required fields are missing", async () => { @@ -355,7 +357,7 @@ describe("skillsMessageHandler", () => { describe("handleOpenSkillFile", () => { it("opens a skill file successfully", async () => { const provider = createMockProvider(true) - mockGetSkill.mockReturnValue(mockSkills[0]) + mockFindSkillByNameAndSource.mockReturnValue(mockSkills[0]) await handleOpenSkillFile(provider, { type: "openSkillFile", @@ -363,13 +365,13 @@ describe("skillsMessageHandler", () => { source: "global", } as WebviewMessage) - expect(mockGetSkill).toHaveBeenCalledWith("test-skill", "global", undefined) + expect(mockFindSkillByNameAndSource).toHaveBeenCalledWith("test-skill", "global") expect(openFile).toHaveBeenCalledWith("/path/to/test-skill/SKILL.md") }) it("opens a skill file with mode restriction", async () => { const provider = createMockProvider(true) - mockGetSkill.mockReturnValue(mockSkills[1]) + mockFindSkillByNameAndSource.mockReturnValue(mockSkills[1]) await handleOpenSkillFile(provider, { type: "openSkillFile", @@ -378,7 +380,7 @@ describe("skillsMessageHandler", () => { skillMode: "code", } as WebviewMessage) - expect(mockGetSkill).toHaveBeenCalledWith("project-skill", "project", "code") + expect(mockFindSkillByNameAndSource).toHaveBeenCalledWith("project-skill", "project") expect(openFile).toHaveBeenCalledWith("/project/.roo/skills/project-skill/SKILL.md") }) @@ -416,7 +418,7 @@ describe("skillsMessageHandler", () => { it("shows error when skill is not found", async () => { const provider = createMockProvider(true) - mockGetSkill.mockReturnValue(undefined) + mockFindSkillByNameAndSource.mockReturnValue(undefined) await handleOpenSkillFile(provider, { type: "openSkillFile", diff --git a/src/core/webview/skillsMessageHandler.ts b/src/core/webview/skillsMessageHandler.ts index f09f22f58c..f5db0473fb 100644 --- a/src/core/webview/skillsMessageHandler.ts +++ b/src/core/webview/skillsMessageHandler.ts @@ -38,7 +38,8 @@ export async function handleCreateSkill( const skillName = message.skillName const source = message.source const skillDescription = message.skillDescription - const skillMode = message.skillMode + // Support new modeSlugs array or fall back to legacy skillMode + const modeSlugs = message.skillModeSlugs ?? (message.skillMode ? [message.skillMode] : undefined) if (!skillName || !source || !skillDescription) { throw new Error(t("skills:errors.missing_create_fields")) @@ -54,7 +55,7 @@ export async function handleCreateSkill( throw new Error(t("skills:errors.manager_unavailable")) } - const createdPath = await skillsManager.createSkill(skillName, source, skillDescription, skillMode) + const createdPath = await skillsManager.createSkill(skillName, source, skillDescription, modeSlugs) // Open the created file in the editor openFile(createdPath) @@ -81,7 +82,8 @@ export async function handleDeleteSkill( try { const skillName = message.skillName const source = message.source - const skillMode = message.skillMode + // Support new skillModeSlugs array or fall back to legacy skillMode + const skillMode = message.skillModeSlugs?.[0] ?? message.skillMode if (!skillName || !source) { throw new Error(t("skills:errors.missing_delete_fields")) @@ -152,6 +154,46 @@ export async function handleMoveSkill( } } +/** + * Handles the updateSkillModes message - updates the mode associations for a skill + */ +export async function handleUpdateSkillModes( + provider: ClineProvider, + message: WebviewMessage, +): Promise { + try { + const skillName = message.skillName + const source = message.source + const newModeSlugs = message.newSkillModeSlugs + + if (!skillName || !source) { + throw new Error(t("skills:errors.missing_update_modes_fields")) + } + + // Built-in skills cannot be modified + if (source === "built-in") { + throw new Error(t("skills:errors.cannot_modify_builtin")) + } + + const skillsManager = provider.getSkillsManager() + if (!skillsManager) { + throw new Error(t("skills:errors.manager_unavailable")) + } + + await skillsManager.updateSkillModes(skillName, source, newModeSlugs) + + // Send updated skills list + const skills = skillsManager.getSkillsMetadata() + await provider.postMessageToWebview({ type: "skills", skills }) + return skills + } catch (error) { + const errorMessage = error instanceof Error ? error.message : String(error) + provider.log(`Error updating skill modes: ${errorMessage}`) + vscode.window.showErrorMessage(`Failed to update skill modes: ${errorMessage}`) + return undefined + } +} + /** * Handles the openSkillFile message - opens a skill file in the editor */ @@ -159,7 +201,6 @@ export async function handleOpenSkillFile(provider: ClineProvider, message: Webv try { const skillName = message.skillName const source = message.source - const skillMode = message.skillMode if (!skillName || !source) { throw new Error(t("skills:errors.missing_delete_fields")) @@ -175,7 +216,8 @@ export async function handleOpenSkillFile(provider: ClineProvider, message: Webv throw new Error(t("skills:errors.manager_unavailable")) } - const skill = skillsManager.getSkill(skillName, source, skillMode) + // Find skill by name and source (skills may have modeSlugs arrays now) + const skill = skillsManager.findSkillByNameAndSource(skillName, source) if (!skill) { throw new Error(t("skills:errors.skill_not_found", { name: skillName })) } diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index 3c240cf079..17ee3049c4 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -41,6 +41,7 @@ import { handleCreateSkill, handleDeleteSkill, handleMoveSkill, + handleUpdateSkillModes, handleOpenSkillFile, } from "./skillsMessageHandler" import { changeLanguage, t } from "../../i18n" @@ -3276,6 +3277,10 @@ export const webviewMessageHandler = async ( await handleMoveSkill(provider, message) break } + case "updateSkillModes": { + await handleUpdateSkillModes(provider, message) + break + } case "openSkillFile": { await handleOpenSkillFile(provider, message) break diff --git a/src/i18n/locales/en/skills.json b/src/i18n/locales/en/skills.json index ef4d7e68e3..5b6dde45b9 100644 --- a/src/i18n/locales/en/skills.json +++ b/src/i18n/locales/en/skills.json @@ -8,6 +8,7 @@ "not_found": "Skill \"{{name}}\" not found in {{source}}{{modeInfo}}", "missing_create_fields": "Missing required fields: skillName, source, or skillDescription", "missing_move_fields": "Missing required fields: skillName or source", + "missing_update_modes_fields": "Missing required fields: skillName or source", "manager_unavailable": "Skills manager not available", "missing_delete_fields": "Missing required fields: skillName or source", "skill_not_found": "Skill \"{{name}}\" not found", diff --git a/src/i18n/locales/zh-CN/skills.json b/src/i18n/locales/zh-CN/skills.json index 566f583fee..719bc722a5 100644 --- a/src/i18n/locales/zh-CN/skills.json +++ b/src/i18n/locales/zh-CN/skills.json @@ -8,6 +8,7 @@ "not_found": "在 {{source}}{{modeInfo}} 中未找到技能 \"{{name}}\"", "missing_create_fields": "缺少必填字段:skillName、source 或 skillDescription", "missing_move_fields": "缺少必填字段:skillName 或 source", + "missing_update_modes_fields": "缺少必填字段:skillName 或 source", "manager_unavailable": "技能管理器不可用", "missing_delete_fields": "缺少必填字段:skillName 或 source", "skill_not_found": "未找到技能 \"{{name}}\"", diff --git a/src/i18n/locales/zh-TW/skills.json b/src/i18n/locales/zh-TW/skills.json index 633bb1a6b2..2d9a52be1e 100644 --- a/src/i18n/locales/zh-TW/skills.json +++ b/src/i18n/locales/zh-TW/skills.json @@ -8,6 +8,7 @@ "not_found": "在 {{source}}{{modeInfo}} 中找不到技能「{{name}}」", "missing_create_fields": "缺少必填欄位:skillName、source 或 skillDescription", "missing_move_fields": "缺少必填欄位:skillName 或 source", + "missing_update_modes_fields": "缺少必填欄位:skillName 或 source", "manager_unavailable": "技能管理器無法使用", "missing_delete_fields": "缺少必填欄位:skillName 或 source", "skill_not_found": "找不到技能「{{name}}」", diff --git a/src/integrations/claude-code/streaming-client.ts b/src/integrations/claude-code/streaming-client.ts index 1bd9706124..47b0214987 100644 --- a/src/integrations/claude-code/streaming-client.ts +++ b/src/integrations/claude-code/streaming-client.ts @@ -122,7 +122,7 @@ function addMessageCacheBreakpoints(messages: Anthropic.Messages.MessageParam[]) } // Handle array content - add cache_control to the last text block - const contentWithCache = message.content.map((block, blockIndex) => { + const contentWithCache = (message?.content || []).map((block, blockIndex) => { // Find the last text block index let lastTextIndex = -1 for (let i = message.content.length - 1; i >= 0; i--) { @@ -222,7 +222,7 @@ function prefixToolNamesInMessages(messages: Anthropic.Messages.MessageParam[]): return message } - const processedContent = message.content.map((block) => { + const processedContent = (message?.content || []).map((block) => { // Prefix tool_use block names if ((block as { type: string }).type === "tool_use") { const toolUseBlock = block as { type: "tool_use"; id: string; name: string; input: unknown } diff --git a/src/services/skills/SkillsManager.ts b/src/services/skills/SkillsManager.ts index 51345945cb..3435629ccd 100644 --- a/src/services/skills/SkillsManager.ts +++ b/src/services/skills/SkillsManager.ts @@ -143,15 +143,34 @@ export class SkillsManager { return } - // Create unique key combining name, source, and mode for override resolution - const skillKey = this.getSkillKey(effectiveSkillName, source, mode) + // Parse modeSlugs from frontmatter (new format) or fall back to directory-based mode + // Priority: frontmatter.modeSlugs > frontmatter.mode > directory mode + let modeSlugs: string[] | undefined + if (Array.isArray(frontmatter.modeSlugs)) { + modeSlugs = frontmatter.modeSlugs.filter((s: unknown) => typeof s === "string" && s.length > 0) + if (modeSlugs.length === 0) { + modeSlugs = undefined // Empty array means "any mode" + } + } else if (typeof frontmatter.mode === "string" && frontmatter.mode.length > 0) { + // Legacy single mode in frontmatter + modeSlugs = [frontmatter.mode] + } else if (mode) { + // Fall back to directory-based mode (skills-{mode}/) + modeSlugs = [mode] + } + + // Create unique key combining name, source, and modeSlugs for override resolution + // For backward compatibility, use first mode slug or undefined for the key + const primaryMode = modeSlugs?.[0] + const skillKey = this.getSkillKey(effectiveSkillName, source, primaryMode) this.skills.set(skillKey, { name: effectiveSkillName, description, path: skillMdPath, source, - mode, // undefined for generic skills, string for mode-specific + mode: primaryMode, // Deprecated: kept for backward compatibility + modeSlugs, // New: array of mode slugs, undefined = any mode }) } catch (error) { console.error(`Failed to load skill at ${skillDir}:`, error) @@ -174,8 +193,11 @@ export class SkillsManager { // Then, add discovered skills (will override built-in skills with same name) for (const skill of this.skills.values()) { - // Skip mode-specific skills that don't match current mode - if (skill.mode && skill.mode !== currentMode) continue + // Check if skill is available in current mode: + // - modeSlugs undefined or empty = available in all modes ("Any mode") + // - modeSlugs array with values = available only if currentMode is in the array + const isAvailableInMode = this.isSkillAvailableInMode(skill, currentMode) + if (!isAvailableInMode) continue const existingSkill = resolvedSkills.get(skill.name) @@ -194,6 +216,20 @@ export class SkillsManager { return Array.from(resolvedSkills.values()) } + /** + * Check if a skill is available in the given mode. + * - modeSlugs undefined or empty = available in all modes ("Any mode") + * - modeSlugs with values = available only if mode is in the array + */ + private isSkillAvailableInMode(skill: SkillMetadata, currentMode: string): boolean { + // No mode restrictions = available in all modes + if (!skill.modeSlugs || skill.modeSlugs.length === 0) { + return true + } + // Check if current mode is in the allowed modes + return skill.modeSlugs.includes(currentMode) + } + /** * Determine if newSkill should override existingSkill based on priority rules. * Priority: project > global > built-in, mode-specific > generic @@ -214,8 +250,11 @@ export class SkillsManager { if (newPriority < existingPriority) return false // Same source: mode-specific overrides generic - if (newSkill.mode && !existing.mode) return true - if (!newSkill.mode && existing.mode) return false + // A skill with modeSlugs (restricted) is more specific than one without (any mode) + const existingHasModes = existing.modeSlugs && existing.modeSlugs.length > 0 + const newHasModes = newSkill.modeSlugs && newSkill.modeSlugs.length > 0 + if (newHasModes && !existingHasModes) return true + if (!newHasModes && existingHasModes) return false // Same source and same mode-specificity: keep existing (first wins) return false @@ -276,6 +315,19 @@ export class SkillsManager { return this.skills.get(skillKey) } + /** + * Find a skill by name and source (regardless of mode). + * Useful for opening/editing skills where the exact mode key may vary. + */ + findSkillByNameAndSource(name: string, source: "global" | "project"): SkillMetadata | undefined { + for (const skill of this.skills.values()) { + if (skill.name === name && skill.source === source) { + return skill + } + } + return undefined + } + /** * Validate skill name per agentskills.io spec using shared validation. * Converts error codes to user-friendly error messages. @@ -307,10 +359,15 @@ export class SkillsManager { * @param name - Skill name (must be valid per agentskills.io spec) * @param source - "global" or "project" * @param description - Skill description - * @param mode - Optional mode restriction (creates in skills-{mode}/ directory) + * @param modeSlugs - Optional mode restrictions (undefined/empty = any mode) * @returns Path to created SKILL.md file */ - async createSkill(name: string, source: "global" | "project", description: string, mode?: string): Promise { + async createSkill( + name: string, + source: "global" | "project", + description: string, + modeSlugs?: string[], + ): Promise { // Validate skill name const validation = this.validateSkillName(name) if (!validation.valid) { @@ -335,9 +392,8 @@ export class SkillsManager { baseDir = path.join(provider.cwd, ".roo") } - // Determine skills directory (with optional mode suffix) - const skillsDirName = mode ? `skills-${mode}` : "skills" - const skillsDir = path.join(baseDir, skillsDirName) + // Always use the generic skills directory (mode info stored in frontmatter now) + const skillsDir = path.join(baseDir, "skills") const skillDir = path.join(skillsDir, name) const skillMdPath = path.join(skillDir, "SKILL.md") @@ -355,9 +411,17 @@ export class SkillsManager { .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) .join(" ") + // Build frontmatter with optional modeSlugs + const frontmatterLines = [`name: ${name}`, `description: ${trimmedDescription}`] + if (modeSlugs && modeSlugs.length > 0) { + frontmatterLines.push(`modeSlugs:`) + for (const slug of modeSlugs) { + frontmatterLines.push(` - ${slug}`) + } + } + const skillContent = `--- -name: ${name} -description: ${trimmedDescription} +${frontmatterLines.join("\n")} --- # ${titleName} @@ -471,6 +535,49 @@ Add your skill instructions here. await this.discoverSkills() } + /** + * Update the mode associations for a skill by modifying its SKILL.md frontmatter. + * @param name - Skill name + * @param source - Where the skill is located ("global" or "project") + * @param newModeSlugs - New mode slugs (undefined/empty = any mode) + */ + async updateSkillModes(name: string, source: "global" | "project", newModeSlugs?: string[]): Promise { + // Find any skill with this name and source (regardless of current mode) + let skill: SkillMetadata | undefined + for (const s of this.skills.values()) { + if (s.name === name && s.source === source) { + skill = s + break + } + } + + if (!skill) { + throw new Error(t("skills:errors.not_found", { name, source, modeInfo: "" })) + } + + // Read the current SKILL.md file + const fileContent = await fs.readFile(skill.path, "utf-8") + const { data: frontmatter, content: body } = matter(fileContent) + + // Update the frontmatter with new modeSlugs + if (newModeSlugs && newModeSlugs.length > 0) { + frontmatter.modeSlugs = newModeSlugs + // Remove legacy mode field if present + delete frontmatter.mode + } else { + // Empty/undefined = any mode, remove mode restrictions + delete frontmatter.modeSlugs + delete frontmatter.mode + } + + // Serialize back to SKILL.md format + const newContent = matter.stringify(body, frontmatter) + await fs.writeFile(skill.path, newContent, "utf-8") + + // Refresh skills list + await this.discoverSkills() + } + /** * Get all skills directories to scan, including mode-specific directories. */ diff --git a/src/services/skills/__tests__/SkillsManager.spec.ts b/src/services/skills/__tests__/SkillsManager.spec.ts index 780a295f1c..8e364b68db 100644 --- a/src/services/skills/__tests__/SkillsManager.spec.ts +++ b/src/services/skills/__tests__/SkillsManager.spec.ts @@ -1006,7 +1006,7 @@ Instructions`) expect(writeCall[1]).toContain("description: A new skill description") }) - it("should create a mode-specific skill", async () => { + it("should create a mode-specific skill with modeSlugs array", async () => { mockDirectoryExists.mockResolvedValue(false) mockRealpath.mockImplementation(async (p: string) => p) mockReaddir.mockResolvedValue([]) @@ -1014,9 +1014,15 @@ Instructions`) mockMkdir.mockResolvedValue(undefined) mockWriteFile.mockResolvedValue(undefined) - const createdPath = await skillsManager.createSkill("code-skill", "global", "A code skill", "code") + const createdPath = await skillsManager.createSkill("code-skill", "global", "A code skill", ["code"]) - expect(createdPath).toBe(p(GLOBAL_ROO_DIR, "skills-code", "code-skill", "SKILL.md")) + // Skills are always created in the generic skills directory now; mode info is in frontmatter + expect(createdPath).toBe(p(GLOBAL_ROO_DIR, "skills", "code-skill", "SKILL.md")) + + // Verify frontmatter contains modeSlugs + const writeCall = mockWriteFile.mock.calls[0] + expect(writeCall[1]).toContain("modeSlugs:") + expect(writeCall[1]).toContain("- code") }) it("should create a project skill", async () => { diff --git a/src/shared/skills.ts b/src/shared/skills.ts index ae35b8c387..cbcc71d7b7 100644 --- a/src/shared/skills.ts +++ b/src/shared/skills.ts @@ -7,7 +7,17 @@ export interface SkillMetadata { description: string // Required: when to use this skill path: string // Absolute path to SKILL.md (or "" for built-in skills) source: "global" | "project" | "built-in" // Where the skill was discovered - mode?: string // If set, skill is only available in this mode + /** + * @deprecated Use modeSlugs instead. Kept for backward compatibility. + * If set, skill is only available in this mode. + */ + mode?: string + /** + * Mode slugs where this skill is available. + * - undefined or empty array means the skill is available in all modes ("Any mode"). + * - An array with one or more mode slugs restricts the skill to those modes. + */ + modeSlugs?: string[] } /** diff --git a/webview-ui/src/components/chat/SlashCommandItem.tsx b/webview-ui/src/components/chat/SlashCommandItem.tsx deleted file mode 100644 index 3f375b2d76..0000000000 --- a/webview-ui/src/components/chat/SlashCommandItem.tsx +++ /dev/null @@ -1,86 +0,0 @@ -import React from "react" -import { Edit, Trash2 } from "lucide-react" - -import type { Command } from "@roo-code/types" - -import { useAppTranslation } from "@/i18n/TranslationContext" -import { Button, StandardTooltip } from "@/components/ui" -import { vscode } from "@/utils/vscode" -import { getJumpLine } from "@/utils/path-mentions" - -interface SlashCommandItemProps { - command: Command - onDelete: (command: Command) => void - onClick?: (command: Command) => void -} - -export const SlashCommandItem: React.FC = ({ command, onDelete, onClick }) => { - const { t } = useAppTranslation() - - // Built-in commands cannot be edited or deleted - const isBuiltIn = command.source === "built-in" - - const handleEdit = () => { - if (command.filePath) { - vscode.postMessage({ - type: "openFile", - text: command.filePath, - values: { line: getJumpLine(command)[0] || 0 }, - }) - } else { - // Fallback: request to open command file by name and source - vscode.postMessage({ - type: "openCommandFile", - text: command.name, - values: { source: command.source }, - }) - } - } - - const handleDelete = () => { - onDelete(command) - } - - return ( -
- {/* Command name - clickable */} -
onClick?.(command)}> -
- {command.name} - {command.description && ( -
- {command.description} -
- )} -
-
- - {/* Action buttons - only show for non-built-in commands */} - {!isBuiltIn && ( -
- - - - - - - -
- )} -
- ) -} diff --git a/webview-ui/src/components/settings/CreateSkillDialog.tsx b/webview-ui/src/components/settings/CreateSkillDialog.tsx index a4daa9989c..3a8def14ee 100644 --- a/webview-ui/src/components/settings/CreateSkillDialog.tsx +++ b/webview-ui/src/components/settings/CreateSkillDialog.tsx @@ -7,17 +7,20 @@ import { useAppTranslation } from "@/i18n/TranslationContext" import { useExtensionState } from "@/context/ExtensionStateContext" import { Button, + Checkbox, Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle, + Input, Select, SelectContent, SelectItem, SelectTrigger, SelectValue, + Textarea, } from "@/components/ui" import { vscode } from "@/utils/vscode" @@ -65,9 +68,6 @@ const validateDescription = (description: string): string | null => { return null } -// Sentinel value for "Any mode" since Radix Select doesn't allow empty string values -const MODE_ANY = "__any__" - export const CreateSkillDialog: React.FC = ({ open, onOpenChange, @@ -80,11 +80,14 @@ export const CreateSkillDialog: React.FC = ({ const [name, setName] = useState("") const [description, setDescription] = useState("") const [source, setSource] = useState<"global" | "project">(hasWorkspace ? "project" : "global") - const [mode, setMode] = useState(MODE_ANY) const [nameError, setNameError] = useState(null) const [descriptionError, setDescriptionError] = useState(null) - // Get available modes for the dropdown (built-in + custom modes) + // Multi-mode selection state (same pattern as SkillsSettings mode dialog) + const [selectedModes, setSelectedModes] = useState([]) + const [isAnyMode, setIsAnyMode] = useState(true) + + // Get available modes for the checkboxes (built-in + custom modes) const availableModes = useMemo(() => { return getAllModes(customModes).map((m) => ({ slug: m.slug, name: m.name })) }, [customModes]) @@ -93,7 +96,8 @@ export const CreateSkillDialog: React.FC = ({ setName("") setDescription("") setSource(hasWorkspace ? "project" : "global") - setMode(MODE_ANY) + setSelectedModes([]) + setIsAnyMode(true) setNameError(null) setDescriptionError(null) }, [hasWorkspace]) @@ -114,6 +118,33 @@ export const CreateSkillDialog: React.FC = ({ setDescriptionError(null) }, []) + // Handle "Any mode" toggle - mutually exclusive with specific modes + const handleAnyModeToggle = useCallback((checked: boolean) => { + if (checked) { + setIsAnyMode(true) + setSelectedModes([]) // Clear specific modes when "Any mode" is selected + } else { + setIsAnyMode(false) + } + }, []) + + // Handle specific mode toggle - unchecks "Any mode" when a specific mode is selected + const handleModeToggle = useCallback((modeSlug: string, checked: boolean) => { + if (checked) { + setIsAnyMode(false) // Uncheck "Any mode" when selecting a specific mode + setSelectedModes((prev) => [...prev, modeSlug]) + } else { + setSelectedModes((prev) => { + const newModes = prev.filter((m) => m !== modeSlug) + // If no modes selected, default back to "Any mode" + if (newModes.length === 0) { + setIsAnyMode(true) + } + return newModes + }) + } + }, []) + const handleCreate = useCallback(() => { // Validate fields const nameValidationError = validateSkillName(name) @@ -130,73 +161,64 @@ export const CreateSkillDialog: React.FC = ({ } // Send message to create skill - // Convert MODE_ANY sentinel value to undefined for the backend + // Convert to modeSlugs: undefined for "Any mode", or array of selected modes + const modeSlugs = isAnyMode ? undefined : selectedModes.length > 0 ? selectedModes : undefined vscode.postMessage({ type: "createSkill", skillName: name, source, skillDescription: description, - skillMode: mode === MODE_ANY ? undefined : mode, + skillModeSlugs: modeSlugs, }) // Close dialog and notify parent handleClose() onSkillCreated() - }, [name, description, source, mode, handleClose, onSkillCreated]) + }, [name, description, source, isAnyMode, selectedModes, handleClose, onSkillCreated]) return ( {t("settings:skills.createDialog.title")} - {t("settings:skills.createDialog.description")} + -
+
{/* Name Input */} -
+
- - - {t("settings:skills.createDialog.nameHint")} - {nameError && {t(nameError)}}
{/* Description Input */} -
- -