diff --git a/.roo/skills/evals-context/SKILL.md b/.roo/skills/evals-context/SKILL.md new file mode 100644 index 0000000000..985b788b94 --- /dev/null +++ b/.roo/skills/evals-context/SKILL.md @@ -0,0 +1,188 @@ +--- +name: evals-context +description: Provides context about the Roo Code evals system structure in this monorepo. Use when tasks mention "evals", "evaluation", "eval runs", "eval exercises", or working with the evals infrastructure. Helps distinguish between the evals execution system (packages/evals, apps/web-evals) and the public website evals display page (apps/web-roo-code/src/app/evals). +--- + +# Evals Codebase Context + +## When to Use This Skill + +Use this skill when the task involves: + +- Modifying or debugging the evals execution infrastructure +- Adding new eval exercises or languages +- Working with the evals web interface (apps/web-evals) +- Modifying the public evals display page on roocode.com +- Understanding where evals code lives in this monorepo + +## When NOT to Use This Skill + +Do NOT use this skill when: + +- Working on unrelated parts of the codebase (extension, webview-ui, etc.) +- The task is purely about the VS Code extension's core functionality +- Working on the main website pages that don't involve evals + +## Key Disambiguation: Two "Evals" Locations + +This monorepo has **two distinct evals-related locations** that can cause confusion: + +| Component | Path | Purpose | +| --------------------------- | -------------------------------------------------------------- | -------------------------------------------------------------- | +| **Evals Execution System** | `packages/evals/` | Core eval infrastructure: CLI, DB schema, Docker configs | +| **Evals Management UI** | `apps/web-evals/` | Next.js app for creating/monitoring eval runs (localhost:3446) | +| **Website Evals Page** | `apps/web-roo-code/src/app/evals/` | Public roocode.com page displaying eval results | +| **External Exercises Repo** | [Roo-Code-Evals](https://github.com/RooCodeInc/Roo-Code-Evals) | Actual coding exercises (NOT in this monorepo) | + +## Directory Structure Reference + +### `packages/evals/` - Core Evals Package + +``` +packages/evals/ +├── ARCHITECTURE.md # Detailed architecture documentation +├── ADDING-EVALS.md # Guide for adding new exercises/languages +├── README.md # Setup and running instructions +├── docker-compose.yml # Container orchestration +├── Dockerfile.runner # Runner container definition +├── Dockerfile.web # Web app container +├── drizzle.config.ts # Database ORM config +├── src/ +│ ├── index.ts # Package exports +│ ├── cli/ # CLI commands for running evals +│ │ ├── runEvals.ts # Orchestrates complete eval runs +│ │ ├── runTask.ts # Executes individual tasks in containers +│ │ ├── runUnitTest.ts # Validates task completion via tests +│ │ └── redis.ts # Redis pub/sub integration +│ ├── db/ +│ │ ├── schema.ts # Database schema (runs, tasks) +│ │ ├── queries/ # Database query functions +│ │ └── migrations/ # SQL migrations +│ └── exercises/ +│ └── index.ts # Exercise loading utilities +└── scripts/ + └── setup.sh # Local macOS setup script +``` + +### `apps/web-evals/` - Evals Management Web App + +``` +apps/web-evals/ +├── src/ +│ ├── app/ +│ │ ├── page.tsx # Home page (runs list) +│ │ ├── runs/ +│ │ │ ├── new/ # Create new eval run +│ │ │ └── [id]/ # View specific run status +│ │ └── api/runs/ # SSE streaming endpoint +│ ├── actions/ # Server actions +│ │ ├── runs.ts # Run CRUD operations +│ │ ├── tasks.ts # Task queries +│ │ ├── exercises.ts # Exercise listing +│ │ └── heartbeat.ts # Controller health checks +│ ├── hooks/ # React hooks (SSE, models, etc.) +│ └── lib/ # Utilities and schemas +``` + +### `apps/web-roo-code/src/app/evals/` - Public Website Evals Page + +``` +apps/web-roo-code/src/app/evals/ +├── page.tsx # Fetches and displays public eval results +├── evals.tsx # Main evals display component +├── plot.tsx # Visualization component +└── types.ts # EvalRun type (extends packages/evals types) +``` + +This page **displays** eval results on the public roocode.com website. It imports types from `@roo-code/evals` but does NOT run evals. + +## Architecture Overview + +The evals system is a distributed evaluation platform that runs AI coding tasks in isolated VS Code environments: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Web App (apps/web-evals) ──────────────────────────────── │ +│ │ │ +│ ▼ │ +│ PostgreSQL ◄────► Controller Container │ +│ │ │ │ +│ ▼ ▼ │ +│ Redis ◄───► Runner Containers (1-25 parallel) │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Key components:** + +- **Controller**: Orchestrates eval runs, spawns runners, manages task queue (p-queue) +- **Runner**: Isolated Docker container with VS Code + Roo Code extension + language runtimes +- **Redis**: Pub/sub for real-time events (NOT task queuing) +- **PostgreSQL**: Stores runs, tasks, metrics + +## Common Tasks Quick Reference + +### Adding a New Eval Exercise + +1. Add exercise to [Roo-Code-Evals](https://github.com/RooCodeInc/Roo-Code-Evals) repo (external) +2. See [`packages/evals/ADDING-EVALS.md`](packages/evals/ADDING-EVALS.md) for structure + +### Modifying Eval CLI Behavior + +Edit files in [`packages/evals/src/cli/`](packages/evals/src/cli/): + +- [`runEvals.ts`](packages/evals/src/cli/runEvals.ts) - Run orchestration +- [`runTask.ts`](packages/evals/src/cli/runTask.ts) - Task execution +- [`runUnitTest.ts`](packages/evals/src/cli/runUnitTest.ts) - Test validation + +### Modifying the Evals Web Interface + +Edit files in [`apps/web-evals/src/`](apps/web-evals/src/): + +- [`app/runs/new/new-run.tsx`](apps/web-evals/src/app/runs/new/new-run.tsx) - New run form +- [`actions/runs.ts`](apps/web-evals/src/actions/runs.ts) - Run server actions + +### Modifying the Public Evals Display Page + +Edit files in [`apps/web-roo-code/src/app/evals/`](apps/web-roo-code/src/app/evals/): + +- [`evals.tsx`](apps/web-roo-code/src/app/evals/evals.tsx) - Display component +- [`plot.tsx`](apps/web-roo-code/src/app/evals/plot.tsx) - Charts + +### Database Schema Changes + +1. Edit [`packages/evals/src/db/schema.ts`](packages/evals/src/db/schema.ts) +2. Generate migration: `cd packages/evals && pnpm drizzle-kit generate` +3. Apply migration: `pnpm drizzle-kit migrate` + +## Running Evals Locally + +```bash +# From repo root +pnpm evals + +# Opens web UI at http://localhost:3446 +``` + +**Ports (defaults):** + +- PostgreSQL: 5433 +- Redis: 6380 +- Web: 3446 + +## Testing + +```bash +# packages/evals tests +cd packages/evals && npx vitest run + +# apps/web-evals tests +cd apps/web-evals && npx vitest run +``` + +## Key Types/Exports from `@roo-code/evals` + +The package exports are defined in [`packages/evals/src/index.ts`](packages/evals/src/index.ts): + +- Database queries: `getRuns`, `getTasks`, `getTaskMetrics`, etc. +- Schema types: `Run`, `Task`, `TaskMetrics` +- Used by both `apps/web-evals` and `apps/web-roo-code` diff --git a/apps/vscode-nightly/esbuild.mjs b/apps/vscode-nightly/esbuild.mjs index b6ce4830ef..fc72c27a9b 100644 --- a/apps/vscode-nightly/esbuild.mjs +++ b/apps/vscode-nightly/esbuild.mjs @@ -35,8 +35,8 @@ async function main() { platform: "node", define: { "process.env.NODE_ENV": production ? '"production"' : '"development"', - "process.env.ZGSM_BASE_URL": JSON.stringify(process.env.ZGSM_BASE_URL || ""), - "process.env.ZGSM_PUBLIC_KEY": JSON.stringify(process.env.ZGSM_PUBLIC_KEY || ""), + "process.env.COSTRICT_BASE_URL": JSON.stringify(process.env.COSTRICT_BASE_URL || ""), + "process.env.COSTRICT_PUBLIC_KEY": JSON.stringify(process.env.COSTRICT_PUBLIC_KEY || process.env.ZGSM_PUBLIC_KEY || ""), "process.env.COSTRICT_PKG_NAME": '"roo-code-nightly"', "process.env.COSTRICT_PKG_VERSION": `"${overrideJson.version}"`, "process.env.COSTRICT_PKG_OUTPUT_CHANNEL": '"Roo-Code-Nightly"', diff --git a/apps/web-evals/src/app/runs/new/new-run.tsx b/apps/web-evals/src/app/runs/new/new-run.tsx index 28fb4abfd5..cea15c6ddd 100644 --- a/apps/web-evals/src/app/runs/new/new-run.tsx +++ b/apps/web-evals/src/app/runs/new/new-run.tsx @@ -1,6 +1,6 @@ "use client" -import { useCallback, useEffect, useMemo, useState } from "react" +import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { useRouter } from "next/navigation" import { z } from "zod" import { useQuery } from "@tanstack/react-query" @@ -48,6 +48,9 @@ import { } from "@/lib/schemas" import { cn } from "@/lib/utils" +import { loadRooLastModelSelection, saveRooLastModelSelection } from "@/lib/roo-last-model-selection" +import { normalizeCreateRunForSubmit } from "@/lib/normalize-create-run" + import { useOpenRouterModels } from "@/hooks/use-open-router-models" import { useRooCodeCloudModels } from "@/hooks/use-roo-code-cloud-models" @@ -103,6 +106,8 @@ type ConfigSelection = { export function NewRun() { const router = useRouter() + const modelSelectionsByProviderRef = useRef>({}) + const modelValueByProviderRef = useRef>({}) const [provider, setModelSource] = useState<"roo" | "openrouter" | "other">("other") const [executionMethod, setExecutionMethod] = useState("vscode") @@ -147,14 +152,43 @@ export function NewRun() { }) const { + register, setValue, clearErrors, watch, + getValues, formState: { isSubmitting }, } = form const [suite, settings] = watch(["suite", "settings", "concurrency"]) + const selectedModelIds = useMemo( + () => modelSelections.map((s) => s.model).filter((m) => m.length > 0), + [modelSelections], + ) + + const applyModelIds = useCallback( + (modelIds: string[]) => { + const unique = Array.from(new Set(modelIds.map((m) => m.trim()).filter((m) => m.length > 0))) + + if (unique.length === 0) { + setModelSelections([{ id: crypto.randomUUID(), model: "", popoverOpen: false }]) + setValue("model", "") + return + } + + setModelSelections(unique.map((model) => ({ id: crypto.randomUUID(), model, popoverOpen: false }))) + setValue("model", unique[0] ?? "") + }, + [setValue], + ) + + // Ensure the `exercises` field is registered so RHF always includes it in submit values. + useEffect(() => { + register("exercises") + }, [register]) + + // Load settings from localStorage on mount useEffect(() => { const savedConcurrency = localStorage.getItem("evals-concurrency") @@ -215,6 +249,51 @@ export function NewRun() { } }, [setValue]) + // Track previous provider to detect switches + const [prevProvider, setPrevProvider] = useState(provider) + + // Preserve selections per provider; avoids cross-contamination while keeping UX stable. + useEffect(() => { + if (provider === prevProvider) return + + modelSelectionsByProviderRef.current[prevProvider] = modelSelections + modelValueByProviderRef.current[prevProvider] = getValues("model") + + const nextModelSelections = + modelSelectionsByProviderRef.current[provider] ?? + ([{ id: crypto.randomUUID(), model: "", popoverOpen: false }] satisfies ModelSelection[]) + + setModelSelections(nextModelSelections) + + const nextModelValue = + modelValueByProviderRef.current[provider] ?? + nextModelSelections.find((s) => s.model.trim().length > 0)?.model ?? + (provider === "other" && importedSettings && configSelections[0]?.configName + ? (getModelId(importedSettings.apiConfigs[configSelections[0].configName] ?? {}) ?? "") + : "") + + setValue("model", nextModelValue) + setPrevProvider(provider) + }, [provider, prevProvider, modelSelections, setValue, getValues, importedSettings, configSelections]) + + // When switching to Roo provider, restore last-used selection if current selection is empty + useEffect(() => { + if (provider !== "roo") return + if (selectedModelIds.length > 0) return + + const last = loadRooLastModelSelection() + if (last.length > 0) { + applyModelIds(last) + } + }, [applyModelIds, provider, selectedModelIds.length]) + + // Persist last-used Roo provider model selection + useEffect(() => { + if (provider !== "roo") return + saveRooLastModelSelection(selectedModelIds) + }, [provider, selectedModelIds]) + + // Extract unique languages from exercises const languages = useMemo(() => { if (!exercises.data) { return [] @@ -337,7 +416,10 @@ export function NewRun() { const onSubmit = useCallback( async (values: CreateRun) => { try { - if (provider === "roo" && !values.jobToken?.trim()) { + const baseValues = normalizeCreateRunForSubmit(values, selectedExercises, suite) + + // Validate jobToken for Roo Code Cloud provider + if (provider === "roo" && !baseValues.jobToken?.trim()) { toast.error("Roo Code Cloud Token is required") return } @@ -374,8 +456,7 @@ export function NewRun() { await new Promise((resolve) => setTimeout(resolve, 20_000)) } - const runValues = { ...values } - runValues.executionMethod = executionMethod + const runValues = { ...baseValues } if (provider === "openrouter") { runValues.model = selection.model @@ -424,8 +505,9 @@ export function NewRun() { } }, [ + suite, + selectedExercises, provider, - executionMethod, modelSelections, configSelections, importedSettings, diff --git a/apps/web-evals/src/lib/__tests__/normalize-create-run.spec.ts b/apps/web-evals/src/lib/__tests__/normalize-create-run.spec.ts new file mode 100644 index 0000000000..947df31354 --- /dev/null +++ b/apps/web-evals/src/lib/__tests__/normalize-create-run.spec.ts @@ -0,0 +1,65 @@ +import { normalizeCreateRunForSubmit } from "../normalize-create-run" + +describe("normalizeCreateRunForSubmit", () => { + it("uses selectedExercises for partial suite", () => { + const result = normalizeCreateRunForSubmit( + { + model: "roo/model-a", + description: "", + suite: "partial", + exercises: [], + settings: undefined, + concurrency: 1, + timeout: 5, + iterations: 1, + jobToken: "", + executionMethod: "vscode", + }, + ["js/foo", "py/bar"], + ) + + expect(result.suite).toBe("partial") + expect(result.exercises).toEqual(["js/foo", "py/bar"]) + }) + + it("dedupes selectedExercises for partial suite", () => { + const result = normalizeCreateRunForSubmit( + { + model: "roo/model-a", + description: "", + suite: "partial", + exercises: [], + settings: undefined, + concurrency: 1, + timeout: 5, + iterations: 1, + jobToken: "", + executionMethod: "vscode", + }, + ["js/foo", "js/foo", "py/bar"], + ) + + expect(result.exercises).toEqual(["js/foo", "py/bar"]) + }) + + it("clears exercises for full suite", () => { + const result = normalizeCreateRunForSubmit( + { + model: "roo/model-a", + description: "", + suite: "full", + exercises: ["js/foo"], + settings: undefined, + concurrency: 1, + timeout: 5, + iterations: 1, + jobToken: "", + executionMethod: "vscode", + }, + ["js/foo"], + ) + + expect(result.suite).toBe("full") + expect(result.exercises).toEqual([]) + }) +}) diff --git a/apps/web-evals/src/lib/__tests__/roo-last-model-selection.spec.ts b/apps/web-evals/src/lib/__tests__/roo-last-model-selection.spec.ts new file mode 100644 index 0000000000..45879b4be5 --- /dev/null +++ b/apps/web-evals/src/lib/__tests__/roo-last-model-selection.spec.ts @@ -0,0 +1,78 @@ +import { + loadRooLastModelSelection, + ROO_LAST_MODEL_SELECTION_KEY, + saveRooLastModelSelection, +} from "../roo-last-model-selection" + +class LocalStorageMock implements Storage { + private store = new Map() + + get length(): number { + return this.store.size + } + + clear(): void { + this.store.clear() + } + + getItem(key: string): string | null { + return this.store.get(key) ?? null + } + + key(index: number): string | null { + return Array.from(this.store.keys())[index] ?? null + } + + removeItem(key: string): void { + this.store.delete(key) + } + + setItem(key: string, value: string): void { + this.store.set(key, value) + } +} + +beforeEach(() => { + Object.defineProperty(globalThis, "localStorage", { + value: new LocalStorageMock(), + configurable: true, + }) +}) + +describe("roo-last-model-selection", () => { + it("saves and loads (deduped + trimmed)", () => { + saveRooLastModelSelection([" roo/model-a ", "roo/model-a", "roo/model-b"]) + expect(loadRooLastModelSelection()).toEqual(["roo/model-a", "roo/model-b"]) + }) + + it("ignores invalid JSON", () => { + localStorage.setItem(ROO_LAST_MODEL_SELECTION_KEY, "{this is not json") + expect(loadRooLastModelSelection()).toEqual([]) + }) + + it("clears when empty", () => { + localStorage.setItem(ROO_LAST_MODEL_SELECTION_KEY, JSON.stringify(["roo/model-a"])) + saveRooLastModelSelection([]) + expect(localStorage.getItem(ROO_LAST_MODEL_SELECTION_KEY)).toBeNull() + }) + + it("does not throw if localStorage access fails", () => { + Object.defineProperty(globalThis, "localStorage", { + value: { + getItem: () => { + throw new Error("blocked") + }, + setItem: () => { + throw new Error("blocked") + }, + removeItem: () => { + throw new Error("blocked") + }, + }, + configurable: true, + }) + + expect(() => loadRooLastModelSelection()).not.toThrow() + expect(() => saveRooLastModelSelection(["roo/model-a"])).not.toThrow() + }) +}) diff --git a/apps/web-evals/src/lib/normalize-create-run.ts b/apps/web-evals/src/lib/normalize-create-run.ts new file mode 100644 index 0000000000..a5f21ba5ad --- /dev/null +++ b/apps/web-evals/src/lib/normalize-create-run.ts @@ -0,0 +1,20 @@ +import type { CreateRun } from "./schemas" + +/** + * The New Run UI keeps exercise selection in component state. + * This normalizer ensures we submit the *visible/selected* exercises when suite is partial. + */ +export function normalizeCreateRunForSubmit( + values: CreateRun, + selectedExercises: string[], + suiteOverride?: CreateRun["suite"], +): CreateRun { + const suite = suiteOverride ?? values.suite + const normalizedSelectedExercises = Array.from(new Set(selectedExercises)) + + return { + ...values, + suite, + exercises: suite === "partial" ? normalizedSelectedExercises : [], + } +} diff --git a/apps/web-evals/src/lib/roo-last-model-selection.ts b/apps/web-evals/src/lib/roo-last-model-selection.ts new file mode 100644 index 0000000000..b66d493172 --- /dev/null +++ b/apps/web-evals/src/lib/roo-last-model-selection.ts @@ -0,0 +1,76 @@ +import { z } from "zod" + +export const ROO_LAST_MODEL_SELECTION_KEY = "evals-roo-last-model-selection" + +const modelIdListSchema = z.array(z.string()) + +function hasLocalStorage(): boolean { + try { + return typeof localStorage !== "undefined" + } catch { + return false + } +} + +function safeGetItem(key: string): string | null { + try { + return localStorage.getItem(key) + } catch { + return null + } +} + +function safeSetItem(key: string, value: string): void { + try { + localStorage.setItem(key, value) + } catch { + // ignore + } +} + +function safeRemoveItem(key: string): void { + try { + localStorage.removeItem(key) + } catch { + // ignore + } +} + +function tryParseJson(raw: string | null): unknown { + if (raw === null) return undefined + try { + return JSON.parse(raw) + } catch { + return undefined + } +} + +function normalizeModelIds(modelIds: string[]): string[] { + const unique = new Set() + for (const id of modelIds) { + const trimmed = id.trim() + if (trimmed) unique.add(trimmed) + } + return Array.from(unique) +} + +export function loadRooLastModelSelection(): string[] { + if (!hasLocalStorage()) return [] + + const parsed = modelIdListSchema.safeParse(tryParseJson(safeGetItem(ROO_LAST_MODEL_SELECTION_KEY))) + if (!parsed.success) return [] + + return normalizeModelIds(parsed.data) +} + +export function saveRooLastModelSelection(modelIds: string[]): void { + if (!hasLocalStorage()) return + + const normalized = normalizeModelIds(modelIds) + if (normalized.length === 0) { + safeRemoveItem(ROO_LAST_MODEL_SELECTION_KEY) + return + } + + safeSetItem(ROO_LAST_MODEL_SELECTION_KEY, JSON.stringify(normalized)) +} diff --git a/apps/web-roo-code/src/components/homepage/features.tsx b/apps/web-roo-code/src/components/homepage/features.tsx index f2c89f7d90..67024563ea 100644 --- a/apps/web-roo-code/src/components/homepage/features.tsx +++ b/apps/web-roo-code/src/components/homepage/features.tsx @@ -39,7 +39,7 @@ export const features: Feature[] = [ icon: CheckCheck, title: "Granular auto-approval", description: - "Control each action and make CoStrict as autonomous as you want as you build confidence. Or go YOLO and let it rip.", + "Control each action and make CoStrict as autonomous as you want as you build confidence. Or go BRRR and let it rip.", }, { icon: Boxes, diff --git a/packages/types/src/history.ts b/packages/types/src/history.ts index d97884d216..b4d84cb9a5 100644 --- a/packages/types/src/history.ts +++ b/packages/types/src/history.ts @@ -29,6 +29,7 @@ export const historyItemSchema = z.object({ * This ensures task resumption works correctly even when NTC settings change. */ toolProtocol: z.enum(["xml", "native"]).optional(), + apiConfigName: z.string().optional(), // Provider profile name for sticky profile feature status: z.enum(["active", "completed", "delegated"]).optional(), delegatedToId: z.string().optional(), // Last child this parent delegated to childIds: z.array(z.string()).optional(), // All children spawned by this task diff --git a/packages/types/src/providers/bedrock.ts b/packages/types/src/providers/bedrock.ts index da40e98f43..19dfbf0b30 100644 --- a/packages/types/src/providers/bedrock.ts +++ b/packages/types/src/providers/bedrock.ts @@ -264,39 +264,6 @@ export const bedrockModels = { inputPrice: 0.25, outputPrice: 1.25, }, - "anthropic.claude-2-1-v1:0": { - maxTokens: 4096, - contextWindow: 100_000, - supportsImages: false, - supportsPromptCache: false, - supportsNativeTools: true, - defaultToolProtocol: "native", - inputPrice: 8.0, - outputPrice: 24.0, - description: "Claude 2.1", - }, - "anthropic.claude-2-0-v1:0": { - maxTokens: 4096, - contextWindow: 100_000, - supportsImages: false, - supportsPromptCache: false, - supportsNativeTools: true, - defaultToolProtocol: "native", - inputPrice: 8.0, - outputPrice: 24.0, - description: "Claude 2.0", - }, - "anthropic.claude-instant-v1:0": { - maxTokens: 4096, - contextWindow: 100_000, - supportsImages: false, - supportsPromptCache: false, - supportsNativeTools: true, - defaultToolProtocol: "native", - inputPrice: 0.8, - outputPrice: 2.4, - description: "Claude Instant", - }, "deepseek.r1-v1:0": { maxTokens: 32_768, contextWindow: 128_000, diff --git a/packages/types/src/providers/cerebras.ts b/packages/types/src/providers/cerebras.ts index 8e0c2f9413..37c063e83b 100644 --- a/packages/types/src/providers/cerebras.ts +++ b/packages/types/src/providers/cerebras.ts @@ -15,7 +15,19 @@ export const cerebrasModels = { defaultToolProtocol: "native", inputPrice: 0, outputPrice: 0, - description: "Highly intelligent general purpose model with up to 1,000 tokens/s", + description: "Fast general-purpose model on Cerebras (up to 1,000 tokens/s). To be deprecated soon.", + }, + "zai-glm-4.7": { + maxTokens: 16384, // Conservative default to avoid premature rate limiting (Cerebras reserves quota upfront) + contextWindow: 131072, + supportsImages: false, + supportsPromptCache: false, + supportsNativeTools: true, + defaultToolProtocol: "native", + inputPrice: 0, + outputPrice: 0, + description: + "Highly capable general-purpose model on Cerebras (up to 1,000 tokens/s), competitive with leading proprietary models on coding tasks.", }, "qwen-3-235b-a22b-instruct-2507": { maxTokens: 16384, // Conservative default to avoid premature rate limiting diff --git a/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts b/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts index 7fe7255f5b..fe16ea89eb 100644 --- a/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts +++ b/src/api/providers/__tests__/bedrock-invokedModelId.spec.ts @@ -122,7 +122,7 @@ describe("AwsBedrockHandler with invokedModelId", () => { trace: { promptRouter: { invokedModelId: - "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-2-1-v1:0", + "arn:aws:bedrock:us-west-2:699475926481:inference-profile/us.anthropic.claude-3-opus-20240229-v1:0", usage: { inputTokens: 150, outputTokens: 250, @@ -162,12 +162,12 @@ describe("AwsBedrockHandler with invokedModelId", () => { } // Verify that getModelById was called with the id, not the full arn - expect(getModelByIdSpy).toHaveBeenCalledWith("anthropic.claude-2-1-v1:0", "inference-profile") + expect(getModelByIdSpy).toHaveBeenCalledWith("anthropic.claude-3-opus-20240229-v1:0", "inference-profile") // Verify that getModel returns the updated model info const costModel = handler.getModel() //expect(costModel.id).toBe("anthropic.claude-3-5-sonnet-20240620-v1:0") - expect(costModel.info.inputPrice).toBe(8) + expect(costModel.info.inputPrice).toBe(15) // Verify that a usage event was emitted after updating the costModelConfig const usageEvents = events.filter((event) => event.type === "usage") diff --git a/src/api/providers/fetchers/zgsm.ts b/src/api/providers/fetchers/zgsm.ts index 1f5ce359a3..c421558407 100644 --- a/src/api/providers/fetchers/zgsm.ts +++ b/src/api/providers/fetchers/zgsm.ts @@ -6,6 +6,8 @@ import { readModels } from "./modelCache" import { ZgsmAuthService } from "../../../core/costrict/auth" export async function getZgsmModels(baseUrl?: string, apiKey?: string, openAiHeaders?: Record) { + const requestId = uuidv7() + try { if (!baseUrl) { return [] @@ -23,7 +25,7 @@ export async function getZgsmModels(baseUrl?: string, apiKey?: string, openAiHea const headers: Record = { ...COSTRICT_DEFAULT_HEADERS, ...(openAiHeaders || {}), - "X-Request-ID": uuidv7(), + "X-Request-ID": requestId, "x-user-id": id || "", } @@ -39,7 +41,10 @@ export async function getZgsmModels(baseUrl?: string, apiKey?: string, openAiHea const fullResponseData = response.data?.data || [] return fullResponseData as Array } catch (error) { - console.warn(`Error fetching zgsmModels from [${baseUrl}/ai-gateway/api/v1/models]:`, error.message) + console.warn( + `Error fetching zgsmModels from [${requestId}|${baseUrl}/ai-gateway/api/v1/models]:`, + error.message, + ) const modelCache = (await readModels("zgsm")) || {} return Object.keys(modelCache).map((key) => modelCache[key]) @@ -47,6 +52,7 @@ export async function getZgsmModels(baseUrl?: string, apiKey?: string, openAiHea } export async function fetchZgsmQuotaInfo(baseUrl?: string, apiKey?: string): Promise { + const requestId = uuidv7() try { if (!baseUrl || !apiKey) { return null @@ -58,11 +64,10 @@ export async function fetchZgsmQuotaInfo(baseUrl?: string, apiKey?: string): Pro if (!URL.canParse(trimmedBaseUrl)) { return null } - const config: Record = {} const headers: Record = { ...COSTRICT_DEFAULT_HEADERS, - "X-Request-ID": uuidv7(), + "X-Request-ID": requestId, Authorization: `Bearer ${apiKey}`, "Content-Type": "application/json", } @@ -74,12 +79,17 @@ export async function fetchZgsmQuotaInfo(baseUrl?: string, apiKey?: string): Pro return response?.data?.data as QuotaInfo } catch (error) { - console.warn(`Error fetching ZgsmQuotaInfo from [${baseUrl}/quota-manager/api/v1/quota]:`, error.message) + console.warn( + `Error fetching ZgsmQuotaInfo from [${requestId}|${baseUrl}/quota-manager/api/v1/quota]:`, + error.message, + ) return null } } export async function fetchZgsmInviteCode(baseUrl?: string, apiKey?: string): Promise { + const requestId = uuidv7() + try { if (!baseUrl || !apiKey) { return null @@ -95,7 +105,7 @@ export async function fetchZgsmInviteCode(baseUrl?: string, apiKey?: string): Pr const config: Record = {} const headers: Record = { ...COSTRICT_DEFAULT_HEADERS, - "X-Request-ID": uuidv7(), + "X-Request-ID": requestId, Authorization: `Bearer ${apiKey}`, "Content-Type": "application/json", } @@ -107,7 +117,10 @@ export async function fetchZgsmInviteCode(baseUrl?: string, apiKey?: string): Pr return response?.data?.data as InviteCodeInfo } catch (error) { - console.warn(`Error fetching ZgsmInviteCode from [${baseUrl}/quota-manager/api/v1/quota]:`, error.message) + console.warn( + `Error fetching ZgsmInviteCode from [${requestId}|${baseUrl}/quota-manager/api/v1/quota]:`, + error.message, + ) return null } } diff --git a/src/core/costrict/auth/authConfig.ts b/src/core/costrict/auth/authConfig.ts index 9a9c66c2b6..0cdb27bc6b 100644 --- a/src/core/costrict/auth/authConfig.ts +++ b/src/core/costrict/auth/authConfig.ts @@ -26,7 +26,7 @@ export class ZgsmAuthConfig { * Get default API base URL */ public getDefaultApiBaseUrl(): string { - return "https://zgsm.sangfor.com" + return process.env.COSTRICT_BASE_URL || "https://zgsm.sangfor.com" } /** diff --git a/src/core/costrict/base/common/constant.ts b/src/core/costrict/base/common/constant.ts index bebbc0cafd..1b31d37860 100644 --- a/src/core/costrict/base/common/constant.ts +++ b/src/core/costrict/base/common/constant.ts @@ -207,6 +207,6 @@ export const OPENAI_REQUEST_ABORTED = "Request was aborted" export const NOT_PROVIDERED = "not-provided" -export const ZGSM_API_KEY = "zgsmRefreshToken" // zgsmRefreshToken -export const ZGSM_BASE_URL = "zgsmBaseUrl" -export const ZGSM_COMPLETION_URL = "zgsmCompletionUrl" +export const COSTRICT_API_KEY = "zgsmRefreshToken" // zgsmRefreshToken +export const COSTRICT_BASE_URL = "zgsmBaseUrl" +export const COSTRICT_COMPLETION_URL = "zgsmCompletionUrl" diff --git a/src/core/costrict/codebase-index/client.ts b/src/core/costrict/codebase-index/client.ts index 070e23f0f3..5d27028483 100644 --- a/src/core/costrict/codebase-index/client.ts +++ b/src/core/costrict/codebase-index/client.ts @@ -68,7 +68,7 @@ export class CodebaseIndexClient { this.logger = createLogger(Package.outputChannel) this.config = { downloadTimeout: config.downloadTimeout || 30_000, - publicKey: config.publicKey || process.env.ZGSM_PUBLIC_KEY!, + publicKey: config.publicKey || process.env.COSTRICT_PUBLIC_KEY! || process.env.ZGSM_PUBLIC_KEY!, getLocalVersion: config.getLocalVersion, } diff --git a/src/core/mentions/__tests__/resolveImageMentions.spec.ts b/src/core/mentions/__tests__/resolveImageMentions.spec.ts new file mode 100644 index 0000000000..747c778819 --- /dev/null +++ b/src/core/mentions/__tests__/resolveImageMentions.spec.ts @@ -0,0 +1,193 @@ +import * as path from "path" + +import { resolveImageMentions } from "../resolveImageMentions" + +vi.mock("../../tools/helpers/imageHelpers", () => ({ + isSupportedImageFormat: vi.fn((ext: string) => + [".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg", ".bmp", ".ico", ".tiff", ".tif", ".avif"].includes( + ext.toLowerCase(), + ), + ), + readImageAsDataUrlWithBuffer: vi.fn(), + validateImageForProcessing: vi.fn(), + ImageMemoryTracker: vi.fn().mockImplementation(() => ({ + getTotalMemoryUsed: vi.fn().mockReturnValue(0), + addMemoryUsage: vi.fn(), + })), + DEFAULT_MAX_IMAGE_FILE_SIZE_MB: 5, + DEFAULT_MAX_TOTAL_IMAGE_SIZE_MB: 20, +})) + +import { validateImageForProcessing, readImageAsDataUrlWithBuffer } from "../../tools/helpers/imageHelpers" + +const mockReadImageAsDataUrl = vi.mocked(readImageAsDataUrlWithBuffer) +const mockValidateImage = vi.mocked(validateImageForProcessing) + +describe("resolveImageMentions", () => { + beforeEach(() => { + vi.clearAllMocks() + // Default: validation passes + mockValidateImage.mockResolvedValue({ isValid: true, sizeInMB: 0.1 }) + }) + + it("should append a data URL when a local png mention is present", async () => { + const dataUrl = `data:image/png;base64,${Buffer.from("png-bytes").toString("base64")}` + mockReadImageAsDataUrl.mockResolvedValue({ dataUrl, buffer: Buffer.from("png-bytes") }) + + const result = await resolveImageMentions({ + text: "Please look at @/assets/cat.png", + images: [], + cwd: "/workspace", + }) + + expect(mockValidateImage).toHaveBeenCalled() + expect(mockReadImageAsDataUrl).toHaveBeenCalledWith(path.resolve("/workspace", "assets/cat.png")) + expect(result.text).toBe("Please look at @/assets/cat.png") + expect(result.images).toEqual([dataUrl]) + }) + + it("should support gif images (matching read_file)", async () => { + const dataUrl = `data:image/gif;base64,${Buffer.from("gif-bytes").toString("base64")}` + mockReadImageAsDataUrl.mockResolvedValue({ dataUrl, buffer: Buffer.from("gif-bytes") }) + + const result = await resolveImageMentions({ + text: "See @/animation.gif", + images: [], + cwd: "/workspace", + }) + + expect(result.images).toEqual([dataUrl]) + }) + + it("should support svg images (matching read_file)", async () => { + const dataUrl = `data:image/svg+xml;base64,${Buffer.from("svg-bytes").toString("base64")}` + mockReadImageAsDataUrl.mockResolvedValue({ dataUrl, buffer: Buffer.from("svg-bytes") }) + + const result = await resolveImageMentions({ + text: "See @/icon.svg", + images: [], + cwd: "/workspace", + }) + + expect(result.images).toEqual([dataUrl]) + }) + + it("should ignore non-image mentions", async () => { + const result = await resolveImageMentions({ + text: "See @/src/index.ts", + images: [], + cwd: "/workspace", + }) + + expect(mockReadImageAsDataUrl).not.toHaveBeenCalled() + expect(result.images).toEqual([]) + }) + + it("should skip unreadable files (fail-soft)", async () => { + mockReadImageAsDataUrl.mockRejectedValue(new Error("ENOENT")) + + const result = await resolveImageMentions({ + text: "See @/missing.webp", + images: [], + cwd: "/workspace", + }) + + expect(result.images).toEqual([]) + }) + + it("should respect rooIgnoreController", async () => { + const dataUrl = `data:image/jpeg;base64,${Buffer.from("jpg-bytes").toString("base64")}` + mockReadImageAsDataUrl.mockResolvedValue({ dataUrl, buffer: Buffer.from("jpg-bytes") }) + const rooIgnoreController = { + validateAccess: vi.fn().mockReturnValue(false), + } + + const result = await resolveImageMentions({ + text: "See @/secret.jpg", + images: [], + cwd: "/workspace", + rooIgnoreController, + }) + + expect(rooIgnoreController.validateAccess).toHaveBeenCalledWith("secret.jpg") + expect(mockReadImageAsDataUrl).not.toHaveBeenCalled() + expect(result.images).toEqual([]) + }) + + it("should dedupe when mention repeats", async () => { + const dataUrl = `data:image/png;base64,${Buffer.from("png-bytes").toString("base64")}` + mockReadImageAsDataUrl.mockResolvedValue({ dataUrl, buffer: Buffer.from("png-bytes") }) + + const result = await resolveImageMentions({ + text: "@/a.png and again @/a.png", + images: [], + cwd: "/workspace", + }) + + expect(result.images).toHaveLength(1) + }) + + it("should skip images when supportsImages is false", async () => { + const dataUrl = `data:image/png;base64,${Buffer.from("png-bytes").toString("base64")}` + mockReadImageAsDataUrl.mockResolvedValue({ dataUrl, buffer: Buffer.from("png-bytes") }) + + const result = await resolveImageMentions({ + text: "See @/cat.png", + images: [], + cwd: "/workspace", + supportsImages: false, + }) + + expect(mockReadImageAsDataUrl).not.toHaveBeenCalled() + expect(result.images).toEqual([]) + }) + + it("should skip images that exceed size limits", async () => { + mockValidateImage.mockResolvedValue({ + isValid: false, + reason: "size_limit", + notice: "Image too large", + }) + + const result = await resolveImageMentions({ + text: "See @/huge.png", + images: [], + cwd: "/workspace", + }) + + expect(mockValidateImage).toHaveBeenCalled() + expect(mockReadImageAsDataUrl).not.toHaveBeenCalled() + expect(result.images).toEqual([]) + }) + + it("should skip images that would exceed memory limit", async () => { + mockValidateImage.mockResolvedValue({ + isValid: false, + reason: "memory_limit", + notice: "Would exceed memory limit", + }) + + const result = await resolveImageMentions({ + text: "See @/large.png", + images: [], + cwd: "/workspace", + }) + + expect(result.images).toEqual([]) + }) + + it("should pass custom size limits to validation", async () => { + const dataUrl = `data:image/png;base64,${Buffer.from("png-bytes").toString("base64")}` + mockReadImageAsDataUrl.mockResolvedValue({ dataUrl, buffer: Buffer.from("png-bytes") }) + + await resolveImageMentions({ + text: "See @/cat.png", + images: [], + cwd: "/workspace", + maxImageFileSize: 10, + maxTotalImageSize: 50, + }) + + expect(mockValidateImage).toHaveBeenCalledWith(expect.any(String), true, 10, 50, 0) + }) +}) diff --git a/src/core/mentions/index.ts b/src/core/mentions/index.ts index ac7e3f54ec..457cf48a63 100644 --- a/src/core/mentions/index.ts +++ b/src/core/mentions/index.ts @@ -288,7 +288,13 @@ async function getFileOrFolderContent( const stats = await fs.stat(absPath) if (stats.isFile()) { - if (rooIgnoreController && !rooIgnoreController.validateAccess(absPath)) { + // Avoid trying to include image binary content as text context. + // Image mentions are handled separately via image attachment flow. + const isBinary = await isBinaryFileWithEncodingDetection(absPath).catch(() => false) + if (isBinary) { + return `(Binary file ${mentionPath} omitted)` + } + if (rooIgnoreController && !rooIgnoreController.validateAccess(unescapedPath)) { return `(File ${mentionPath} is ignored by .rooignore)` } try { diff --git a/src/core/mentions/resolveImageMentions.ts b/src/core/mentions/resolveImageMentions.ts new file mode 100644 index 0000000000..0a0344348f --- /dev/null +++ b/src/core/mentions/resolveImageMentions.ts @@ -0,0 +1,145 @@ +import * as path from "path" + +import { mentionRegexGlobal, unescapeSpaces } from "../../shared/context-mentions" +import { + isSupportedImageFormat, + readImageAsDataUrlWithBuffer, + validateImageForProcessing, + ImageMemoryTracker, + DEFAULT_MAX_IMAGE_FILE_SIZE_MB, + DEFAULT_MAX_TOTAL_IMAGE_SIZE_MB, +} from "../tools/helpers/imageHelpers" + +const MAX_IMAGES_PER_MESSAGE = 20 + +export interface ResolveImageMentionsOptions { + text: string + images?: string[] + cwd: string + rooIgnoreController?: { validateAccess: (filePath: string) => boolean } + /** Whether the current model supports images. Defaults to true. */ + supportsImages?: boolean + /** Maximum size per image file in MB. Defaults to 5MB. */ + maxImageFileSize?: number + /** Maximum total size of all images in MB. Defaults to 20MB. */ + maxTotalImageSize?: number +} + +export interface ResolveImageMentionsResult { + text: string + images: string[] +} + +function isPathWithinCwd(absPath: string, cwd: string): boolean { + const rel = path.relative(cwd, absPath) + return rel !== "" && !rel.startsWith("..") && !path.isAbsolute(rel) +} + +function dedupePreserveOrder(values: string[]): string[] { + const seen = new Set() + const result: string[] = [] + for (const v of values) { + if (seen.has(v)) continue + seen.add(v) + result.push(v) + } + return result +} + +/** + * Resolves local image file mentions like `@/path/to/image.png` found in `text` into `data:image/...;base64,...` + * and appends them to the outgoing `images` array. + * + * Behavior matches the read_file tool: + * - Supports the same image formats: png, jpg, jpeg, gif, webp, svg, bmp, ico, tiff, avif + * - Respects per-file size limits (default 5MB) + * - Respects total memory limits (default 20MB) + * - Skips images if model doesn't support them + * - Respects `.rooignore` via `rooIgnoreController.validateAccess` when provided + */ +export async function resolveImageMentions({ + text, + images, + cwd, + rooIgnoreController, + supportsImages = true, + maxImageFileSize = DEFAULT_MAX_IMAGE_FILE_SIZE_MB, + maxTotalImageSize = DEFAULT_MAX_TOTAL_IMAGE_SIZE_MB, +}: ResolveImageMentionsOptions): Promise { + const existingImages = Array.isArray(images) ? images : [] + if (existingImages.length >= MAX_IMAGES_PER_MESSAGE) { + return { text, images: existingImages.slice(0, MAX_IMAGES_PER_MESSAGE) } + } + + // If model doesn't support images, skip image processing entirely + if (!supportsImages) { + return { text, images: existingImages } + } + + const mentions = Array.from(text.matchAll(mentionRegexGlobal)) + .map((m) => m[1]) + .filter(Boolean) + if (mentions.length === 0) { + return { text, images: existingImages } + } + + const imageMentions = mentions.filter((mention) => { + if (!mention.startsWith("/")) return false + const relPath = unescapeSpaces(mention.slice(1)) + const ext = path.extname(relPath).toLowerCase() + return isSupportedImageFormat(ext) + }) + + if (imageMentions.length === 0) { + return { text, images: existingImages } + } + + const imageMemoryTracker = new ImageMemoryTracker() + const newImages: string[] = [] + + for (const mention of imageMentions) { + if (existingImages.length + newImages.length >= MAX_IMAGES_PER_MESSAGE) { + break + } + + const relPath = unescapeSpaces(mention.slice(1)) + const absPath = path.resolve(cwd, relPath) + if (!isPathWithinCwd(absPath, cwd)) { + continue + } + + if (rooIgnoreController && !rooIgnoreController.validateAccess(relPath)) { + continue + } + + // Validate image size limits (matches read_file behavior) + try { + const validationResult = await validateImageForProcessing( + absPath, + supportsImages, + maxImageFileSize, + maxTotalImageSize, + imageMemoryTracker.getTotalMemoryUsed(), + ) + + if (!validationResult.isValid) { + // Skip this image due to size/memory limits, but continue processing others + continue + } + + const { dataUrl } = await readImageAsDataUrlWithBuffer(absPath) + newImages.push(dataUrl) + + // Track memory usage + if (validationResult.sizeInMB) { + imageMemoryTracker.addMemoryUsage(validationResult.sizeInMB) + } + } catch { + // Fail-soft: skip unreadable/missing files. + continue + } + } + + const merged = dedupePreserveOrder([...existingImages, ...newImages]).slice(0, MAX_IMAGES_PER_MESSAGE) + return { text, images: merged } +} diff --git a/src/core/prompts/__tests__/sections.spec.ts b/src/core/prompts/__tests__/sections.spec.ts index d8a002d8f5..011b279698 100644 --- a/src/core/prompts/__tests__/sections.spec.ts +++ b/src/core/prompts/__tests__/sections.spec.ts @@ -1,7 +1,8 @@ import { addCustomInstructions } from "../sections/custom-instructions" import { getCapabilitiesSection } from "../sections/capabilities" -import { getRulesSection } from "../sections/rules" +import { getRulesSection, getCommandChainOperator } from "../sections/rules" import { McpHub } from "../../../services/mcp/McpHub" +import * as shellUtils from "../../../utils/shell" describe("addCustomInstructions", () => { it("adds vscode language to custom instructions", async () => { @@ -114,3 +115,117 @@ describe("getRulesSection", () => { expect(result).not.toContain("Never reveal the vendor or company") }) }) + +describe("getCommandChainOperator", () => { + it("returns && for bash shell", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("/bin/bash") + expect(getCommandChainOperator()).toBe("&&") + }) + + it("returns && for zsh shell", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("/bin/zsh") + expect(getCommandChainOperator()).toBe("&&") + }) + + it("returns ; for PowerShell", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue( + "C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe", + ) + expect(getCommandChainOperator()).toBe(";") + }) + + it("returns ; for PowerShell Core (pwsh)", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("C:\\Program Files\\PowerShell\\7\\pwsh.exe") + expect(getCommandChainOperator()).toBe(";") + }) + + it("returns && for cmd.exe", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("C:\\Windows\\System32\\cmd.exe") + expect(getCommandChainOperator()).toBe("&&") + }) + + it("returns && for Git Bash on Windows", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("C:\\Program Files\\Git\\bin\\bash.exe") + expect(getCommandChainOperator()).toBe("&&") + }) + + it("returns && for WSL bash", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("/bin/bash") + expect(getCommandChainOperator()).toBe("&&") + }) +}) + +describe("getRulesSection shell-aware command chaining", () => { + const cwd = "/test/path" + + afterEach(() => { + vi.restoreAllMocks() + }) + + it("uses && for Unix shells in command chaining example", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("/bin/bash") + const result = getRulesSection(cwd) + + expect(result).toContain("cd (path to project) && (command") + expect(result).not.toContain("cd (path to project) ; (command") + expect(result).not.toContain("cd (path to project) & (command") + }) + + it("uses ; for PowerShell in command chaining example", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue( + "C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe", + ) + const result = getRulesSection(cwd) + + expect(result).toContain("cd (path to project) ; (command") + expect(result).toContain("Note: Using `;` for PowerShell command chaining") + }) + + it("uses && for cmd.exe in command chaining example", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("C:\\Windows\\System32\\cmd.exe") + const result = getRulesSection(cwd) + + expect(result).toContain("cd (path to project) && (command") + expect(result).toContain("Note: Using `&&` for cmd.exe command chaining") + }) + + it("includes Unix utility guidance for PowerShell", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue( + "C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe", + ) + const result = getRulesSection(cwd) + + expect(result).toContain("IMPORTANT: When using PowerShell, avoid Unix-specific utilities") + expect(result).toContain("`sed`, `grep`, `awk`, `cat`, `rm`, `cp`, `mv`") + expect(result).toContain("`Select-String` for grep") + expect(result).toContain("`Get-Content` for cat") + expect(result).toContain("PowerShell's `-replace` operator") + }) + + it("includes Unix utility guidance for cmd.exe", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("C:\\Windows\\System32\\cmd.exe") + const result = getRulesSection(cwd) + + expect(result).toContain("IMPORTANT: When using cmd.exe, avoid Unix-specific utilities") + expect(result).toContain("`sed`, `grep`, `awk`, `cat`, `rm`, `cp`, `mv`") + expect(result).toContain("`type` for cat") + expect(result).toContain("`del` for rm") + expect(result).toContain("`find`/`findstr` for grep") + }) + + it("does not include Unix utility guidance for Unix shells", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("/bin/bash") + const result = getRulesSection(cwd) + + expect(result).not.toContain("IMPORTANT: When using PowerShell") + expect(result).not.toContain("IMPORTANT: When using cmd.exe") + expect(result).not.toContain("`Select-String` for grep") + }) + + it("does not include note for Unix shells", () => { + vi.spyOn(shellUtils, "getShell").mockReturnValue("/bin/zsh") + const result = getRulesSection(cwd) + + expect(result).not.toContain("Note: Using") + }) +}) diff --git a/src/core/prompts/sections/rules.ts b/src/core/prompts/sections/rules.ts index 20f0897022..800fb430ef 100644 --- a/src/core/prompts/sections/rules.ts +++ b/src/core/prompts/sections/rules.ts @@ -1,6 +1,53 @@ import type { SystemPromptSettings } from "../types" import { getEffectiveProtocol, isNativeProtocol } from "@roo-code/types" +import { getShell } from "../../../utils/shell" + +/** + * Returns the appropriate command chaining operator based on the user's shell. + * - Unix shells (bash, zsh, etc.): `&&` (run next command only if previous succeeds) + * - PowerShell: `;` (semicolon for command separation) + * - cmd.exe: `&&` (conditional execution, same as Unix) + * @internal Exported for testing purposes + */ +export function getCommandChainOperator(): string { + const shell = getShell().toLowerCase() + + // Check for PowerShell (both Windows PowerShell and PowerShell Core) + if (shell.includes("powershell") || shell.includes("pwsh")) { + return ";" + } + + // Check for cmd.exe + if (shell.includes("cmd.exe")) { + return "&&" + } + + // Default to Unix-style && for bash, zsh, sh, and other shells + // This also covers Git Bash, WSL, and other Unix-like environments on Windows + return "&&" +} + +/** + * Returns a shell-specific note about command chaining syntax and platform-specific utilities. + */ +function getCommandChainNote(): string { + const shell = getShell().toLowerCase() + + // Check for PowerShell + if (shell.includes("powershell") || shell.includes("pwsh")) { + return "Note: Using `;` for PowerShell command chaining. For bash/zsh use `&&`, for cmd.exe use `&&`. IMPORTANT: When using PowerShell, avoid Unix-specific utilities like `sed`, `grep`, `awk`, `cat`, `rm`, `cp`, `mv`. Instead use PowerShell equivalents: `Select-String` for grep, `Get-Content` for cat, `Remove-Item` for rm, `Copy-Item` for cp, `Move-Item` for mv, and PowerShell's `-replace` operator or `[regex]` for sed." + } + + // Check for cmd.exe + if (shell.includes("cmd.exe")) { + return "Note: Using `&&` for cmd.exe command chaining (conditional execution). For bash/zsh use `&&`, for PowerShell use `;`. IMPORTANT: When using cmd.exe, avoid Unix-specific utilities like `sed`, `grep`, `awk`, `cat`, `rm`, `cp`, `mv`. Use built-in commands like `type` for cat, `del` for rm, `copy` for cp, `move` for mv, `find`/`findstr` for grep, or consider using PowerShell commands instead." + } + + // Unix shells + return "" +} + function getVendorConfidentialitySection(): string { return ` @@ -20,6 +67,10 @@ export function getRulesSection(cwd: string, settings?: SystemPromptSettings): s // Determine whether to use XML tool references based on protocol const effectiveProtocol = getEffectiveProtocol(settings?.toolProtocol) + // Get shell-appropriate command chaining operator + const chainOp = getCommandChainOperator() + const chainNote = getCommandChainNote() + return `==== RULES @@ -28,7 +79,7 @@ RULES - All file paths must be relative to this directory. However, commands may change directories in terminals, so respect working directory specified by the response to ${isNativeProtocol(effectiveProtocol) ? "execute_command" : ""}. - You cannot \`cd\` into a different directory to complete a task. You are stuck operating from '${cwd.toPosix()}', so be sure to pass in the correct 'path' parameter when using tools that require a path. - Do not use the ~ character or $HOME to refer to the home directory. -- Before using the execute_command tool, you must first think about the SYSTEM INFORMATION context provided to understand the user's environment and tailor your commands to ensure they are compatible with their system. You must also consider if the command you need to run should be executed in a specific directory outside of the current working directory '${cwd.toPosix()}', and if so prepend with \`cd\`'ing into that directory && then executing the command (as one command since you are stuck operating from '${cwd.toPosix()}'). For example, if you needed to run \`npm install\` in a project outside of '${cwd.toPosix()}', you would need to prepend with a \`cd\` i.e. pseudocode for this would be \`cd (path to project) && (command, in this case npm install)\`. +- Before using the execute_command tool, you must first think about the SYSTEM INFORMATION context provided to understand the user's environment and tailor your commands to ensure they are compatible with their system. You must also consider if the command you need to run should be executed in a specific directory outside of the current working directory '${cwd.toPosix()}', and if so prepend with \`cd\`'ing into that directory ${chainOp} then executing the command (as one command since you are stuck operating from '${cwd.toPosix()}'). For example, if you needed to run \`npm install\` in a project outside of '${cwd.toPosix()}', you would need to prepend with a \`cd\` i.e. pseudocode for this would be \`cd (path to project) ${chainOp} (command, in this case npm install)\`.${chainNote ? ` ${chainNote}` : ""} - Some modes have restrictions on which files they can edit. If you attempt to edit a restricted file, the operation will be rejected with a FileRestrictionError that will specify which file patterns are allowed for the current mode. - Be sure to consider the type of project (e.g. Python, JavaScript, web application) when determining the appropriate structure and files to include. Also consider what files may be most relevant to accomplishing the task, for example looking at a project's manifest file would help you understand the project's dependencies, which you could incorporate into any code you write. * For example, in architect mode trying to edit app.js would be rejected because architect mode can only edit files matching "\\.md$" diff --git a/src/core/task-persistence/taskMetadata.ts b/src/core/task-persistence/taskMetadata.ts index eb872a6f7e..cf8d9adb52 100644 --- a/src/core/task-persistence/taskMetadata.ts +++ b/src/core/task-persistence/taskMetadata.ts @@ -21,6 +21,8 @@ export type TaskMetadataOptions = { globalStoragePath: string workspace: string mode?: string + /** Provider profile name for the task (sticky profile feature) */ + apiConfigName?: string /** Initial status for the task (e.g., "active" for child tasks) */ initialStatus?: "active" | "delegated" | "completed" /** @@ -39,6 +41,7 @@ export async function taskMetadata({ globalStoragePath, workspace, mode, + apiConfigName, initialStatus, toolProtocol, }: TaskMetadataOptions) { @@ -116,6 +119,7 @@ export async function taskMetadata({ workspace, mode, ...(toolProtocol && { toolProtocol }), + ...(typeof apiConfigName === "string" && apiConfigName.length > 0 ? { apiConfigName } : {}), ...(initialStatus && { status: initialStatus }), } diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 25965a4da4..78f7f31117 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -259,6 +259,49 @@ export class Task extends EventEmitter implements TaskLike { */ private taskModeReady: Promise + /** + * The API configuration name (provider profile) associated with this task. + * Persisted across sessions to maintain the provider profile when reopening tasks from history. + * + * ## Lifecycle + * + * ### For new tasks: + * 1. Initially `undefined` during construction + * 2. Asynchronously initialized from provider state via `initializeTaskApiConfigName()` + * 3. Falls back to "default" if provider state is unavailable + * + * ### For history items: + * 1. Immediately set from `historyItem.apiConfigName` during construction + * 2. Falls back to undefined if not stored in history (for backward compatibility) + * + * ## Important + * If you need a non-`undefined` provider profile (e.g., for profile-dependent operations), + * wait for `taskApiConfigReady` first (or use `getTaskApiConfigName()`). + * The sync `taskApiConfigName` getter may return `undefined` for backward compatibility. + * + * @private + * @see {@link getTaskApiConfigName} - For safe async access + * @see {@link taskApiConfigName} - For sync access after initialization + */ + private _taskApiConfigName: string | undefined + + /** + * Promise that resolves when the task API config name has been initialized. + * This ensures async API config name initialization completes before the task is used. + * + * ## Purpose + * - Prevents race conditions when accessing task API config name + * - Ensures provider state is properly loaded before profile-dependent operations + * - Provides a synchronization point for async initialization + * + * ## Resolution timing + * - For history items: Resolves immediately (sync initialization) + * - For new tasks: Resolves after provider state is fetched (async initialization) + * + * @private + */ + private taskApiConfigReady: Promise + providerRef: WeakRef private readonly globalStoragePath: string abort: boolean = false @@ -537,21 +580,25 @@ export class Task extends EventEmitter implements TaskLike { this.taskNumber = taskNumber this.initialStatus = initialStatus - // Store the task's mode when it's created. - // For history items, use the stored mode; for new tasks, we'll set it + // Store the task's mode and API config name when it's created. + // For history items, use the stored values; for new tasks, we'll set them // after getting state. if (historyItem) { this._taskMode = historyItem.mode || defaultModeSlug + this._taskApiConfigName = historyItem.apiConfigName this.taskModeReady = Promise.resolve() + this.taskApiConfigReady = Promise.resolve() TelemetryService.instance.captureTaskRestarted(this.taskId) // For history items, use the persisted tool protocol if available. // If not available (old tasks), it will be detected in resumeTaskFromHistory. this._taskToolProtocol = historyItem.toolProtocol } else { - // For new tasks, don't set the mode yet - wait for async initialization. + // For new tasks, don't set the mode/apiConfigName yet - wait for async initialization. this._taskMode = undefined + this._taskApiConfigName = undefined this.taskModeReady = this.initializeTaskMode(provider) + this.taskApiConfigReady = this.initializeTaskApiConfigName(provider) TelemetryService.instance.captureTaskCreated(this.taskId) // For new tasks, resolve and lock the tool protocol immediately. @@ -676,6 +723,47 @@ export class Task extends EventEmitter implements TaskLike { } } + /** + * Initialize the task API config name from the provider state. + * This method handles async initialization with proper error handling. + * + * ## Flow + * 1. Attempts to fetch the current API config name from provider state + * 2. Sets `_taskApiConfigName` to the fetched name or "default" if unavailable + * 3. Handles errors gracefully by falling back to "default" + * 4. Logs any initialization errors for debugging + * + * ## Error handling + * - Network failures when fetching provider state + * - Provider not yet initialized + * - Invalid state structure + * + * All errors result in fallback to "default" to ensure task can proceed. + * + * @private + * @param provider - The ClineProvider instance to fetch state from + * @returns Promise that resolves when initialization is complete + */ + private async initializeTaskApiConfigName(provider: ClineProvider): Promise { + try { + const state = await provider.getState() + + // Avoid clobbering a newer value that may have been set while awaiting provider state + // (e.g., user switches provider profile immediately after task creation). + if (this._taskApiConfigName === undefined) { + this._taskApiConfigName = state?.currentApiConfigName ?? "default" + } + } catch (error) { + // If there's an error getting state, use the default profile (unless a newer value was set). + if (this._taskApiConfigName === undefined) { + this._taskApiConfigName = "default" + } + // Use the provider's log method for better error visibility + const errorMessage = `Failed to initialize task API config name: ${error instanceof Error ? error.message : String(error)}` + provider.log(errorMessage) + } + } + /** * Sets up a listener for provider profile changes to automatically update the parser state. * This ensures the XML/native protocol parser stays synchronized with the current model. @@ -796,6 +884,73 @@ export class Task extends EventEmitter implements TaskLike { return this._taskMode } + /** + * Wait for the task API config name to be initialized before proceeding. + * This method ensures that any operations depending on the task's provider profile + * will have access to the correct value. + * + * ## When to use + * - Before accessing provider profile-specific configurations + * - When switching between tasks with different provider profiles + * - Before operations that depend on the provider profile + * + * @returns Promise that resolves when the task API config name is initialized + * @public + */ + public async waitForApiConfigInitialization(): Promise { + return this.taskApiConfigReady + } + + /** + * Get the task API config name asynchronously, ensuring it's properly initialized. + * This is the recommended way to access the task's provider profile as it guarantees + * the value is available before returning. + * + * ## Async behavior + * - Internally waits for `taskApiConfigReady` promise to resolve + * - Returns the initialized API config name or undefined as fallback + * - Safe to call multiple times - subsequent calls return immediately if already initialized + * + * @returns Promise resolving to the task API config name string or undefined + * @public + */ + public async getTaskApiConfigName(): Promise { + await this.taskApiConfigReady + return this._taskApiConfigName + } + + /** + * Get the task API config name synchronously. This should only be used when you're certain + * that the value has already been initialized (e.g., after waitForApiConfigInitialization). + * + * ## When to use + * - In synchronous contexts where async/await is not available + * - After explicitly waiting for initialization via `waitForApiConfigInitialization()` + * - In event handlers or callbacks where API config name is guaranteed to be initialized + * + * Note: Unlike taskMode, this getter does not throw if uninitialized since the API config + * name can legitimately be undefined (backward compatibility with tasks created before + * this feature was added). + * + * @returns The task API config name string or undefined + * @public + */ + public get taskApiConfigName(): string | undefined { + return this._taskApiConfigName + } + + /** + * Update the task's API config name. This is called when the user switches + * provider profiles while a task is active, allowing the task to remember + * its new provider profile. + * + * @param apiConfigName - The new API config name to set + * @internal + */ + public setTaskApiConfigName(apiConfigName: string | undefined): void { + this._taskApiConfigName = apiConfigName + } + static create(options: TaskOptions): [Task, Promise] { const instance = new Task({ ...options, startTask: false }) const { images, task, historyItem } = options @@ -1065,6 +1220,10 @@ export class Task extends EventEmitter implements TaskLike { globalStoragePath: this.globalStoragePath, }) + if (this._taskApiConfigName === undefined) { + await this.taskApiConfigReady + } + const { historyItem, tokenUsage } = await taskMetadata({ taskId: this.taskId, rootTaskId: this.rootTaskId, @@ -1074,6 +1233,7 @@ export class Task extends EventEmitter implements TaskLike { globalStoragePath: this.globalStoragePath, workspace: this.cwd, mode: this._taskMode || defaultModeSlug, // Use the task's own mode, not the current provider mode. + apiConfigName: this._taskApiConfigName, // Use the task's own provider profile, not the current provider profile. initialStatus: this.initialStatus, toolProtocol: this._taskToolProtocol, // Persist the locked tool protocol. }) @@ -4412,7 +4572,7 @@ export class Task extends EventEmitter implements TaskLike { // Respect provider rate limit window let rateLimitDelay = 0 - const rateLimit = state?.apiConfiguration?.rateLimitSeconds ?? 1 + const rateLimit = (state?.apiConfiguration ?? this.apiConfiguration)?.rateLimitSeconds || 0 if (Task.lastGlobalApiRequestTime && rateLimit > 0) { const elapsed = performance.now() - Task.lastGlobalApiRequestTime rateLimitDelay = Math.ceil(Math.min(rateLimit, Math.max(0, rateLimit * 1000 - elapsed) / 1000)) diff --git a/src/core/task/__tests__/Task.sticky-profile-race.spec.ts b/src/core/task/__tests__/Task.sticky-profile-race.spec.ts new file mode 100644 index 0000000000..e78301541d --- /dev/null +++ b/src/core/task/__tests__/Task.sticky-profile-race.spec.ts @@ -0,0 +1,142 @@ +// npx vitest run core/task/__tests__/Task.sticky-profile-race.spec.ts + +import * as vscode from "vscode" + +import type { ProviderSettings } from "@roo-code/types" +import { Task } from "../Task" +import { ClineProvider } from "../../webview/ClineProvider" + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + hasInstance: vi.fn().mockReturnValue(true), + createInstance: vi.fn(), + get instance() { + return { + captureTaskCreated: vi.fn(), + captureTaskRestarted: vi.fn(), + captureModeSwitch: vi.fn(), + captureConversationMessage: vi.fn(), + captureLlmCompletion: vi.fn(), + captureConsecutiveMistakeError: vi.fn(), + captureCodeActionUsed: vi.fn(), + setProvider: vi.fn(), + } + }, + }, +})) + +vi.mock("vscode", () => { + const mockDisposable = { dispose: vi.fn() } + const mockEventEmitter = { event: vi.fn(), fire: vi.fn() } + const mockTextDocument = { uri: { fsPath: "/mock/workspace/path/file.ts" } } + const mockTextEditor = { document: mockTextDocument } + const mockTab = { input: { uri: { fsPath: "/mock/workspace/path/file.ts" } } } + const mockTabGroup = { tabs: [mockTab] } + + return { + TabInputTextDiff: vi.fn(), + CodeActionKind: { + QuickFix: { value: "quickfix" }, + RefactorRewrite: { value: "refactor.rewrite" }, + }, + window: { + createTextEditorDecorationType: vi.fn().mockReturnValue({ + dispose: vi.fn(), + }), + visibleTextEditors: [mockTextEditor], + tabGroups: { + all: [mockTabGroup], + close: vi.fn(), + onDidChangeTabs: vi.fn(() => ({ dispose: vi.fn() })), + }, + showErrorMessage: vi.fn(), + }, + workspace: { + getConfiguration: vi.fn(() => ({ get: (_k: string, d: any) => d })), + workspaceFolders: [ + { + uri: { fsPath: "/mock/workspace/path" }, + name: "mock-workspace", + index: 0, + }, + ], + createFileSystemWatcher: vi.fn(() => ({ + onDidCreate: vi.fn(() => mockDisposable), + onDidDelete: vi.fn(() => mockDisposable), + onDidChange: vi.fn(() => mockDisposable), + dispose: vi.fn(), + })), + fs: { + stat: vi.fn().mockResolvedValue({ type: 1 }), + }, + onDidSaveTextDocument: vi.fn(() => mockDisposable), + }, + env: { + uriScheme: "vscode", + language: "en", + }, + EventEmitter: vi.fn().mockImplementation(() => mockEventEmitter), + Disposable: { + from: vi.fn(), + }, + TabInputText: vi.fn(), + version: "1.85.0", + } +}) + +vi.mock("../../environment/getEnvironmentDetails", () => ({ + getEnvironmentDetails: vi.fn().mockResolvedValue(""), +})) + +vi.mock("../../ignore/RooIgnoreController") + +vi.mock("p-wait-for", () => ({ + default: vi.fn().mockImplementation(async () => Promise.resolve()), +})) + +vi.mock("delay", () => ({ + __esModule: true, + default: vi.fn().mockResolvedValue(undefined), +})) + +describe("Task - sticky provider profile init race", () => { + it("does not overwrite task apiConfigName if set during async initialization", async () => { + const apiConfig: ProviderSettings = { + apiProvider: "anthropic", + apiModelId: "claude-3-5-sonnet-20241022", + apiKey: "test-api-key", + } as any + + let resolveGetState: ((v: any) => void) | undefined + const getStatePromise = new Promise((resolve) => { + resolveGetState = resolve + }) + + const mockProvider = { + context: { + globalStorageUri: { fsPath: "/test/storage" }, + }, + getState: vi.fn().mockImplementation(() => getStatePromise), + log: vi.fn(), + on: vi.fn(), + off: vi.fn(), + postStateToWebview: vi.fn().mockResolvedValue(undefined), + updateTaskHistory: vi.fn().mockResolvedValue(undefined), + } as unknown as ClineProvider + + const task = new Task({ + provider: mockProvider, + apiConfiguration: apiConfig, + task: "test task", + startTask: false, + }) + + // Simulate a profile switch happening before provider.getState resolves. + task.setTaskApiConfigName("new-profile") + + resolveGetState?.({ currentApiConfigName: "old-profile" }) + await task.waitForApiConfigInitialization() + + expect(task.taskApiConfigName).toBe("new-profile") + }) +}) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 0bd0ccc2cf..89733782e0 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -1035,29 +1035,64 @@ export class ClineProvider await this.updateGlobalState("mode", historyItem.mode) // Load the saved API config for the restored mode if it exists. - const savedConfigId = await this.providerSettingsManager.getModeConfigId(historyItem.mode) - const listApiConfig = await this.providerSettingsManager.listConfig() + // Skip mode-based profile activation if historyItem.apiConfigName exists, + // since the task's specific provider profile will override it anyway. + if (!historyItem.apiConfigName) { + const savedConfigId = await this.providerSettingsManager.getModeConfigId(historyItem.mode) + const listApiConfig = await this.providerSettingsManager.listConfig() + + // Update listApiConfigMeta first to ensure UI has latest data. + await this.updateGlobalState("listApiConfigMeta", listApiConfig) + + // If this mode has a saved config, use it. + if (savedConfigId) { + const profile = listApiConfig.find(({ id }) => id === savedConfigId) + + if (profile?.name) { + try { + await this.activateProviderProfile({ name: profile.name }) + } catch (error) { + // Log the error but continue with task restoration. + this.log( + `Failed to restore API configuration for mode '${historyItem.mode}': ${ + error instanceof Error ? error.message : String(error) + }. Continuing with default configuration.`, + ) + // The task will continue with the current/default configuration. + } + } + } + } + } - // Update listApiConfigMeta first to ensure UI has latest data. + // If the history item has a saved API config name (provider profile), restore it. + // This overrides any mode-based config restoration above, because the task's + // specific provider profile takes precedence over mode defaults. + if (historyItem.apiConfigName) { + const listApiConfig = await this.providerSettingsManager.listConfig() + // Keep global state/UI in sync with latest profiles for parity with mode restoration above. await this.updateGlobalState("listApiConfigMeta", listApiConfig) + const profile = listApiConfig.find(({ name }) => name === historyItem.apiConfigName) - // If this mode has a saved config, use it. - if (savedConfigId) { - const profile = listApiConfig.find(({ id }) => id === savedConfigId) - - if (profile?.name) { - try { - await this.activateProviderProfile({ name: profile.name }) - } catch (error) { - // Log the error but continue with task restoration. - this.log( - `Failed to restore API configuration for mode '${historyItem.mode}': ${ - error instanceof Error ? error.message : String(error) - }. Continuing with default configuration.`, - ) - // The task will continue with the current/default configuration. - } + if (profile?.name) { + try { + await this.activateProviderProfile( + { name: profile.name }, + { persistModeConfig: false, persistTaskHistory: false }, + ) + } catch (error) { + // Log the error but continue with task restoration. + this.log( + `Failed to restore API configuration '${historyItem.apiConfigName}' for task: ${ + error instanceof Error ? error.message : String(error) + }. Continuing with current configuration.`, + ) } + } else { + // Profile no longer exists, log warning but continue + this.log( + `Provider profile '${historyItem.apiConfigName}' from history no longer exists. Using current configuration.`, + ) } } @@ -1560,6 +1595,9 @@ export class ClineProvider // Change the provider for the current task. // TODO: We should rename `buildApiHandler` for clarity (e.g. `getProviderClient`). this.updateTaskApiHandlerIfNeeded(providerSettings, { forceRebuild: true }) + + // Keep the current task's sticky provider profile in sync with the newly-activated profile. + await this.persistStickyProviderProfileToCurrentTask(name) } else { await this.updateGlobalState("listApiConfigMeta", await this.providerSettingsManager.listConfig()) } @@ -1599,9 +1637,42 @@ export class ClineProvider await this.postStateToWebview() } - async activateProviderProfile(args: { name: string } | { id: string }) { + private async persistStickyProviderProfileToCurrentTask(apiConfigName: string): Promise { + const task = this.getCurrentTask() + if (!task) { + return + } + + try { + // Update in-memory state immediately so sticky behavior works even before the task has + // been persisted into taskHistory (it will be captured on the next save). + task.setTaskApiConfigName(apiConfigName) + + const history = this.getGlobalState("taskHistory") ?? [] + const taskHistoryItem = history.find((item) => item.id === task.taskId) + + if (taskHistoryItem) { + await this.updateTaskHistory({ ...taskHistoryItem, apiConfigName }) + } + } catch (error) { + // If persistence fails, log the error but don't fail the profile switch. + this.log( + `Failed to persist provider profile switch for task ${task.taskId}: ${ + error instanceof Error ? error.message : String(error) + }`, + ) + } + } + + async activateProviderProfile( + args: { name: string } | { id: string }, + options?: { persistModeConfig?: boolean; persistTaskHistory?: boolean }, + ) { const { name, id, ...providerSettings } = await this.providerSettingsManager.activateProfile(args) + const persistModeConfig = options?.persistModeConfig ?? true + const persistTaskHistory = options?.persistTaskHistory ?? true + // See `upsertProviderProfile` for a description of what this is doing. await Promise.all([ this.contextProxy.setValue("listApiConfigMeta", await this.providerSettingsManager.listConfig()), @@ -1611,12 +1682,19 @@ export class ClineProvider const { mode } = await this.getState() - if (id) { + if (id && persistModeConfig) { await this.providerSettingsManager.setModeConfig(mode, id) } + // Change the provider for the current task. this.updateTaskApiHandlerIfNeeded(providerSettings, { forceRebuild: true }) + // Update the current task's sticky provider profile, unless this activation is + // being used purely as a non-persisting restoration (e.g., reopening a task from history). + if (persistTaskHistory) { + await this.persistStickyProviderProfileToCurrentTask(name) + } + await this.postStateToWebview() if (providerSettings.apiProvider) { diff --git a/src/core/webview/__tests__/ClineProvider.spec.ts b/src/core/webview/__tests__/ClineProvider.spec.ts index c465837a82..0db71efc3a 100644 --- a/src/core/webview/__tests__/ClineProvider.spec.ts +++ b/src/core/webview/__tests__/ClineProvider.spec.ts @@ -3120,7 +3120,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { expect(mockCline.overwriteClineMessages).toHaveBeenCalledWith([mockMessages[0]]) expect(mockCline.overwriteApiConversationHistory).toHaveBeenCalledWith([{ ts: 1000 }]) // Verify submitUserMessage was called with the edited content - expect(mockCline.submitUserMessage).toHaveBeenCalledWith("Edited message with preserved images", undefined) + expect(mockCline.submitUserMessage).toHaveBeenCalledWith("Edited message with preserved images", []) }) test("handles editing messages with file attachments", async () => { @@ -3173,7 +3173,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { }) expect(mockCline.overwriteClineMessages).toHaveBeenCalled() - expect(mockCline.submitUserMessage).toHaveBeenCalledWith("Edited message with file attachment", undefined) + expect(mockCline.submitUserMessage).toHaveBeenCalledWith("Edited message with file attachment", []) }) }) @@ -3704,7 +3704,7 @@ describe("ClineProvider - Comprehensive Edit/Delete Edge Cases", () => { await messageHandler({ type: "editMessageConfirm", messageTs: 2000, text: largeEditedContent }) expect(mockCline.overwriteClineMessages).toHaveBeenCalled() - expect(mockCline.submitUserMessage).toHaveBeenCalledWith(largeEditedContent, undefined) + expect(mockCline.submitUserMessage).toHaveBeenCalledWith(largeEditedContent, []) }) test("handles deleting messages with large payloads", async () => { diff --git a/src/core/webview/__tests__/ClineProvider.sticky-profile.spec.ts b/src/core/webview/__tests__/ClineProvider.sticky-profile.spec.ts new file mode 100644 index 0000000000..3df4408b71 --- /dev/null +++ b/src/core/webview/__tests__/ClineProvider.sticky-profile.spec.ts @@ -0,0 +1,883 @@ +// npx vitest run core/webview/__tests__/ClineProvider.sticky-profile.spec.ts + +import * as vscode from "vscode" +import { TelemetryService } from "@roo-code/telemetry" +import { ClineProvider } from "../ClineProvider" +import { ContextProxy } from "../../config/ContextProxy" +import type { HistoryItem } from "@roo-code/types" + +vi.mock("vscode", () => ({ + ExtensionContext: vi.fn(), + OutputChannel: vi.fn(), + WebviewView: vi.fn(), + Uri: { + joinPath: vi.fn(), + file: vi.fn(), + }, + CodeActionKind: { + QuickFix: { value: "quickfix" }, + RefactorRewrite: { value: "refactor.rewrite" }, + }, + commands: { + executeCommand: vi.fn().mockResolvedValue(undefined), + }, + window: { + showInformationMessage: vi.fn(), + showWarningMessage: vi.fn(), + showErrorMessage: vi.fn(), + onDidChangeActiveTextEditor: vi.fn(() => ({ dispose: vi.fn() })), + }, + workspace: { + getConfiguration: vi.fn().mockReturnValue({ + get: vi.fn().mockReturnValue([]), + update: vi.fn(), + }), + onDidChangeConfiguration: vi.fn().mockImplementation(() => ({ + dispose: vi.fn(), + })), + onDidSaveTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + onDidChangeTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + onDidOpenTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + onDidCloseTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + }, + env: { + uriScheme: "vscode", + language: "en", + appName: "Visual Studio Code", + }, + ExtensionMode: { + Production: 1, + Development: 2, + Test: 3, + }, + version: "1.85.0", +})) + +// Create a counter for unique task IDs. +let taskIdCounter = 0 + +vi.mock("../../task/Task", () => ({ + Task: vi.fn().mockImplementation((options) => ({ + taskId: options.taskId || `test-task-id-${++taskIdCounter}`, + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + overwriteClineMessages: vi.fn(), + overwriteApiConversationHistory: vi.fn(), + abortTask: vi.fn(), + handleWebviewAskResponse: vi.fn(), + getTaskNumber: vi.fn().mockReturnValue(0), + setTaskNumber: vi.fn(), + setParentTask: vi.fn(), + setRootTask: vi.fn(), + emit: vi.fn(), + parentTask: options.parentTask, + updateApiConfiguration: vi.fn(), + setTaskApiConfigName: vi.fn(), + _taskApiConfigName: options.historyItem?.apiConfigName, + taskApiConfigName: options.historyItem?.apiConfigName, + })), +})) + +vi.mock("../../prompts/sections/custom-instructions") + +vi.mock("../../../utils/safeWriteJson") + +vi.mock("../../../api", () => ({ + buildApiHandler: vi.fn().mockReturnValue({ + getModel: vi.fn().mockReturnValue({ + id: "claude-3-sonnet", + }), + }), +})) + +vi.mock("../../../integrations/workspace/WorkspaceTracker", () => ({ + default: vi.fn().mockImplementation(() => ({ + initializeFilePaths: vi.fn(), + dispose: vi.fn(), + })), +})) + +vi.mock("../../diff/strategies/multi-search-replace", () => ({ + MultiSearchReplaceDiffStrategy: vi.fn().mockImplementation(() => ({ + getToolDescription: () => "test", + getName: () => "test-strategy", + applyDiff: vi.fn(), + })), +})) + +vi.mock("@roo-code/cloud", () => ({ + CloudService: { + hasInstance: vi.fn().mockReturnValue(true), + get instance() { + return { + isAuthenticated: vi.fn().mockReturnValue(false), + } + }, + }, + BridgeOrchestrator: { + isEnabled: vi.fn().mockReturnValue(false), + }, + getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), +})) + +vi.mock("../../../shared/modes", () => ({ + modes: [ + { + slug: "code", + name: "Code Mode", + roleDefinition: "You are a code assistant", + groups: ["read", "edit", "browser"], + }, + { + slug: "architect", + name: "Architect Mode", + roleDefinition: "You are an architect", + groups: ["read", "edit"], + }, + ], + getModeBySlug: vi.fn().mockReturnValue({ + slug: "code", + name: "Code Mode", + roleDefinition: "You are a code assistant", + groups: ["read", "edit", "browser"], + }), + defaultModeSlug: "code", +})) + +vi.mock("../../prompts/system", () => ({ + SYSTEM_PROMPT: vi.fn().mockResolvedValue("mocked system prompt"), + codeMode: "code", +})) + +vi.mock("../../../api/providers/fetchers/modelCache", () => ({ + getModels: vi.fn().mockResolvedValue({}), + flushModels: vi.fn(), +})) + +vi.mock("../../../integrations/misc/extract-text", () => ({ + extractTextFromFile: vi.fn().mockResolvedValue("Mock file content"), +})) + +vi.mock("p-wait-for", () => ({ + default: vi.fn().mockImplementation(async () => Promise.resolve()), +})) + +vi.mock("fs/promises", () => ({ + mkdir: vi.fn().mockResolvedValue(undefined), + writeFile: vi.fn().mockResolvedValue(undefined), + readFile: vi.fn().mockResolvedValue(""), + unlink: vi.fn().mockResolvedValue(undefined), + rmdir: vi.fn().mockResolvedValue(undefined), +})) + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + hasInstance: vi.fn().mockReturnValue(true), + createInstance: vi.fn(), + get instance() { + return { + trackEvent: vi.fn(), + trackError: vi.fn(), + setProvider: vi.fn(), + captureModeSwitch: vi.fn(), + } + }, + }, +})) + +describe("ClineProvider - Sticky Provider Profile", () => { + let provider: ClineProvider + let mockContext: vscode.ExtensionContext + let mockOutputChannel: vscode.OutputChannel + let mockWebviewView: vscode.WebviewView + let mockPostMessage: any + + beforeEach(() => { + vi.clearAllMocks() + taskIdCounter = 0 + + if (!TelemetryService.hasInstance()) { + TelemetryService.createInstance([]) + } + + const globalState: Record = { + mode: "code", + currentApiConfigName: "default-profile", + } + + const secrets: Record = {} + + mockContext = { + extensionPath: "/test/path", + extensionUri: {} as vscode.Uri, + globalState: { + get: vi.fn().mockImplementation((key: string) => globalState[key]), + update: vi.fn().mockImplementation((key: string, value: string | undefined) => { + globalState[key] = value + return Promise.resolve() + }), + keys: vi.fn().mockImplementation(() => Object.keys(globalState)), + }, + secrets: { + get: vi.fn().mockImplementation((key: string) => secrets[key]), + store: vi.fn().mockImplementation((key: string, value: string | undefined) => { + secrets[key] = value + return Promise.resolve() + }), + delete: vi.fn().mockImplementation((key: string) => { + delete secrets[key] + return Promise.resolve() + }), + }, + subscriptions: [], + extension: { + packageJSON: { version: "1.0.0" }, + }, + globalStorageUri: { + fsPath: "/test/storage/path", + }, + } as unknown as vscode.ExtensionContext + + mockOutputChannel = { + appendLine: vi.fn(), + clear: vi.fn(), + dispose: vi.fn(), + } as unknown as vscode.OutputChannel + + mockPostMessage = vi.fn() + + mockWebviewView = { + webview: { + postMessage: mockPostMessage, + html: "", + options: {}, + onDidReceiveMessage: vi.fn(), + asWebviewUri: vi.fn(), + cspSource: "vscode-webview://test-csp-source", + }, + visible: true, + onDidDispose: vi.fn().mockImplementation((callback) => { + callback() + return { dispose: vi.fn() } + }), + onDidChangeVisibility: vi.fn().mockImplementation(() => ({ dispose: vi.fn() })), + } as unknown as vscode.WebviewView + + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) + + // Mock getMcpHub method + provider.getMcpHub = vi.fn().mockReturnValue({ + listTools: vi.fn().mockResolvedValue([]), + callTool: vi.fn().mockResolvedValue({ content: [] }), + listResources: vi.fn().mockResolvedValue([]), + readResource: vi.fn().mockResolvedValue({ contents: [] }), + getAllServers: vi.fn().mockReturnValue([]), + }) + }) + + describe("activateProviderProfile", () => { + beforeEach(async () => { + await provider.resolveWebviewView(mockWebviewView) + }) + + it("should save provider profile to task metadata when switching profiles", async () => { + // Create a mock task + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn(), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock getGlobalState to return task history + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory to track calls + const updateTaskHistorySpy = vi + .spyOn(provider, "updateTaskHistory") + .mockImplementation(() => Promise.resolve([])) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "anthropic" }, + ]) + + // Switch provider profile + await provider.activateProviderProfile({ name: "new-profile" }) + + // Verify task history was updated with new provider profile + expect(updateTaskHistorySpy).toHaveBeenCalledWith( + expect.objectContaining({ + id: mockTask.taskId, + apiConfigName: "new-profile", + }), + ) + + // Verify task's setTaskApiConfigName was called + expect(mockTask.setTaskApiConfigName).toHaveBeenCalledWith("new-profile") + }) + + it("should update task's taskApiConfigName property when switching profiles", async () => { + // Create a mock task with initial profile + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock getGlobalState to return task history + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory + vi.spyOn(provider, "updateTaskHistory").mockImplementation(() => Promise.resolve([])) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "openrouter", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "openrouter" }, + ]) + + // Switch provider profile + await provider.activateProviderProfile({ name: "new-profile" }) + + // Verify task's _taskApiConfigName property was updated + expect(mockTask._taskApiConfigName).toBe("new-profile") + }) + + it("should update in-memory task profile even if task history item does not exist yet", async () => { + await provider.resolveWebviewView(mockWebviewView) + + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + await provider.addClineToStack(mockTask as any) + + // No history item exists yet + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([]) + + const updateTaskHistorySpy = vi + .spyOn(provider, "updateTaskHistory") + .mockImplementation(() => Promise.resolve([])) + + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "openrouter", + }) + + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "openrouter" }, + ]) + + await provider.activateProviderProfile({ name: "new-profile" }) + + // In-memory should still update, even without a history item. + expect(mockTask._taskApiConfigName).toBe("new-profile") + // No history item => no updateTaskHistory call. + expect(updateTaskHistorySpy).not.toHaveBeenCalled() + }) + }) + + describe("createTaskWithHistoryItem", () => { + it("should restore provider profile from history item when reopening task", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with saved provider profile + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + mode: "code", + apiConfigName: "saved-profile", // Saved provider profile + } + + // Mock activateProviderProfile to track calls + const activateProviderProfileSpy = vi + .spyOn(provider, "activateProviderProfile") + .mockResolvedValue(undefined) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "saved-profile", id: "saved-profile-id", apiProvider: "anthropic" }, + ]) + + // Initialize task with history item + await provider.createTaskWithHistoryItem(historyItem) + + // Verify provider profile was restored via activateProviderProfile (restore-only: don't persist mode config) + expect(activateProviderProfileSpy).toHaveBeenCalledWith( + { name: "saved-profile" }, + { persistModeConfig: false, persistTaskHistory: false }, + ) + }) + + it("should use current profile if history item has no saved apiConfigName", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item without saved provider profile + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + // No apiConfigName field + } + + // Mock activateProviderProfile to track calls + const activateProviderProfileSpy = vi + .spyOn(provider, "activateProviderProfile") + .mockResolvedValue(undefined) + + // Initialize task with history item + await provider.createTaskWithHistoryItem(historyItem) + + // Verify activateProviderProfile was NOT called for apiConfigName restoration + // (it might be called for mode-based config, but not for direct apiConfigName) + const callsForApiConfigName = activateProviderProfileSpy.mock.calls.filter( + (call) => call[0] && "name" in call[0] && call[0].name === historyItem.apiConfigName, + ) + expect(callsForApiConfigName.length).toBe(0) + }) + + it("should override mode-based config with task's apiConfigName", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with both mode and apiConfigName + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + mode: "architect", // Mode has a different preferred profile + apiConfigName: "task-specific-profile", // Task's actual profile + } + + // Track all activateProviderProfile calls + const activateCalls: string[] = [] + vi.spyOn(provider, "activateProviderProfile").mockImplementation(async (args) => { + if ("name" in args) { + activateCalls.push(args.name) + } + }) + + // Mock providerSettingsManager methods + vi.spyOn(provider.providerSettingsManager, "getModeConfigId").mockResolvedValue("mode-config-id") + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "mode-preferred-profile", id: "mode-config-id", apiProvider: "anthropic" }, + { name: "task-specific-profile", id: "task-profile-id", apiProvider: "openai" }, + ]) + + // Initialize task with history item + await provider.createTaskWithHistoryItem(historyItem) + + // Verify task's apiConfigName was activated LAST (overriding mode-based config) + expect(activateCalls[activateCalls.length - 1]).toBe("task-specific-profile") + }) + + it("should handle missing provider profile gracefully", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with a provider profile that no longer exists + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + apiConfigName: "deleted-profile", // Profile that doesn't exist + } + + // Mock providerSettingsManager.listConfig to return empty (profile doesn't exist) + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([]) + + // Mock log to verify warning is logged + const logSpy = vi.spyOn(provider, "log") + + // Initialize task with history item - should not throw + await expect(provider.createTaskWithHistoryItem(historyItem)).resolves.not.toThrow() + + // Verify a warning was logged + expect(logSpy).toHaveBeenCalledWith( + expect.stringContaining("Provider profile 'deleted-profile' from history no longer exists"), + ) + }) + }) + + describe("Task metadata persistence", () => { + it("should include apiConfigName in task metadata when saving", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a mock task with provider profile + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "test-profile", + setTaskApiConfigName: vi.fn(), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Mock getGlobalState to return task history with our task + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory to capture the updated history item + let updatedHistoryItem: any + vi.spyOn(provider, "updateTaskHistory").mockImplementation((item) => { + updatedHistoryItem = item + return Promise.resolve([item]) + }) + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "anthropic" }, + ]) + + // Trigger a profile switch + await provider.activateProviderProfile({ name: "new-profile" }) + + // Verify apiConfigName was included in the updated history item + expect(updatedHistoryItem).toBeDefined() + expect(updatedHistoryItem.apiConfigName).toBe("new-profile") + }) + }) + + describe("Multiple workspaces isolation", () => { + it("should preserve task profile when switching profiles in another workspace", async () => { + // This test verifies that each task retains its designated provider profile + // so that switching profiles in one workspace doesn't alter other tasks + + await provider.resolveWebviewView(mockWebviewView) + + // Create task 1 with profile A + const task1 = { + taskId: "task-1", + _taskApiConfigName: "profile-a", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Create task 2 with profile B + const task2 = { + taskId: "task-2", + _taskApiConfigName: "profile-b", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task 1 to stack + await provider.addClineToStack(task1 as any) + + // Mock getGlobalState to return task history for both tasks + const taskHistory = [ + { + id: "task-1", + ts: Date.now(), + task: "Task 1", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + apiConfigName: "profile-a", + }, + { + id: "task-2", + ts: Date.now(), + task: "Task 2", + number: 2, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + apiConfigName: "profile-b", + }, + ] + + vi.spyOn(provider as any, "getGlobalState").mockReturnValue(taskHistory) + + // Mock updateTaskHistory + vi.spyOn(provider, "updateTaskHistory").mockImplementation((item) => { + const index = taskHistory.findIndex((h) => h.id === item.id) + if (index >= 0) { + taskHistory[index] = { ...taskHistory[index], ...item } + } + return Promise.resolve(taskHistory) + }) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "profile-c", + id: "profile-c-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "profile-a", id: "profile-a-id", apiProvider: "anthropic" }, + { name: "profile-b", id: "profile-b-id", apiProvider: "openai" }, + { name: "profile-c", id: "profile-c-id", apiProvider: "anthropic" }, + ]) + + // Switch task 1's profile to profile C + await provider.activateProviderProfile({ name: "profile-c" }) + + // Verify task 1's profile was updated + expect(task1._taskApiConfigName).toBe("profile-c") + expect(taskHistory[0].apiConfigName).toBe("profile-c") + + // Verify task 2's profile remains unchanged + expect(taskHistory[1].apiConfigName).toBe("profile-b") + }) + }) + + describe("Error handling", () => { + it("should handle errors gracefully when saving profile fails", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a mock task + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn(), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock getGlobalState + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory to throw error + vi.spyOn(provider, "updateTaskHistory").mockRejectedValue(new Error("Save failed")) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "anthropic" }, + ]) + + // Mock log to verify error is logged + const logSpy = vi.spyOn(provider, "log") + + // Switch provider profile - should not throw + await expect(provider.activateProviderProfile({ name: "new-profile" })).resolves.not.toThrow() + + // Verify error was logged + expect(logSpy).toHaveBeenCalledWith(expect.stringContaining("Failed to persist provider profile switch")) + }) + + it("should handle null/undefined apiConfigName gracefully", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with null apiConfigName + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + apiConfigName: null as any, // Invalid apiConfigName + } + + // Mock activateProviderProfile to track calls + const activateProviderProfileSpy = vi + .spyOn(provider, "activateProviderProfile") + .mockResolvedValue(undefined) + + // Initialize task with history item - should not throw + await expect(provider.createTaskWithHistoryItem(historyItem)).resolves.not.toThrow() + + // Verify activateProviderProfile was not called with null + expect(activateProviderProfileSpy).not.toHaveBeenCalledWith({ name: null }) + }) + }) + + describe("Profile restoration with activateProfile failure", () => { + it("should continue task restoration even if activateProviderProfile fails", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with saved provider profile + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + apiConfigName: "failing-profile", + } + + // Mock providerSettingsManager.listConfig to return the profile + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "failing-profile", id: "failing-profile-id", apiProvider: "anthropic" }, + ]) + + // Mock activateProviderProfile to throw error + vi.spyOn(provider, "activateProviderProfile").mockRejectedValue(new Error("Activation failed")) + + // Mock log to verify error is logged + const logSpy = vi.spyOn(provider, "log") + + // Initialize task with history item - should not throw even though activation fails + await expect(provider.createTaskWithHistoryItem(historyItem)).resolves.not.toThrow() + + // Verify error was logged + expect(logSpy).toHaveBeenCalledWith( + expect.stringContaining("Failed to restore API configuration 'failing-profile' for task"), + ) + }) + }) +}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.checkpoint.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.checkpoint.spec.ts index bc40ff2e5c..6b87812268 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.checkpoint.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.checkpoint.spec.ts @@ -83,6 +83,10 @@ describe("webviewMessageHandler - checkpoint operations", () => { contextProxy: { globalStorageUri: { fsPath: "/test/storage" }, }, + getState: vi.fn().mockResolvedValue({ + maxImageFileSize: 5, + maxTotalImageSize: 20, + }), } }) @@ -152,7 +156,7 @@ describe("webviewMessageHandler - checkpoint operations", () => { operation: "edit", editData: { editedContent: "Edited checkpoint message", - images: undefined, + images: [], apiConversationHistoryIndex: 0, }, }) diff --git a/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts index fef95d2542..d4a2e73f75 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.edit.spec.ts @@ -96,6 +96,10 @@ describe("webviewMessageHandler - Edit Message with Timestamp Fallback", () => { globalStorageUri: { fsPath: "/mock/storage" }, }, log: vi.fn(), + getState: vi.fn().mockResolvedValue({ + maxImageFileSize: 5, + maxTotalImageSize: 20, + }), } as unknown as ClineProvider }) diff --git a/src/core/webview/__tests__/webviewMessageHandler.imageMentions.integration.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.imageMentions.integration.spec.ts new file mode 100644 index 0000000000..277e56626a --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.imageMentions.integration.spec.ts @@ -0,0 +1,130 @@ +import * as fs from "fs/promises" +import * as path from "path" +import * as os from "os" + +// Must mock dependencies before importing the handler module. +vi.mock("../../../api/providers/fetchers/modelCache") + +import { webviewMessageHandler } from "../webviewMessageHandler" +import type { ClineProvider } from "../ClineProvider" + +vi.mock("vscode", () => ({ + window: { + showInformationMessage: vi.fn(), + showErrorMessage: vi.fn(), + }, + workspace: { + workspaceFolders: [{ uri: { fsPath: "/mock/workspace" } }], + }, +})) + +// Mock imageHelpers - use actual implementations for functions that need real file access +vi.mock("../../tools/helpers/imageHelpers", async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + validateImageForProcessing: vi.fn().mockResolvedValue({ isValid: true, sizeInMB: 0.001 }), + ImageMemoryTracker: vi.fn().mockImplementation(() => ({ + getTotalMemoryUsed: vi.fn().mockReturnValue(0), + addMemoryUsage: vi.fn(), + })), + } +}) + +describe("webviewMessageHandler - image mentions (integration)", () => { + it("resolves image mentions for newTask and passes images to createTask", async () => { + const tmpRoot = await fs.mkdtemp(path.join(os.tmpdir(), "roo-image-mentions-")) + try { + const imgBytes = Buffer.from("png-bytes") + await fs.writeFile(path.join(tmpRoot, "cat.png"), imgBytes) + + const mockProvider = { + cwd: tmpRoot, + getCurrentTask: vi.fn().mockReturnValue(undefined), + createTask: vi.fn().mockResolvedValue(undefined), + postMessageToWebview: vi.fn().mockResolvedValue(undefined), + getState: vi.fn().mockResolvedValue({ + maxImageFileSize: 5, + maxTotalImageSize: 20, + }), + } as unknown as ClineProvider + + await webviewMessageHandler(mockProvider, { + type: "newTask", + text: "Please look at @/cat.png", + images: [], + } as any) + + expect(mockProvider.createTask).toHaveBeenCalledWith("Please look at @/cat.png", [ + `data:image/png;base64,${imgBytes.toString("base64")}`, + ]) + } finally { + await fs.rm(tmpRoot, { recursive: true, force: true }) + } + }) + + it("resolves image mentions for askResponse and passes images to handleWebviewAskResponse", async () => { + const tmpRoot = await fs.mkdtemp(path.join(os.tmpdir(), "roo-image-mentions-")) + try { + const imgBytes = Buffer.from("jpg-bytes") + await fs.writeFile(path.join(tmpRoot, "cat.jpg"), imgBytes) + + const handleWebviewAskResponse = vi.fn() + const mockProvider = { + cwd: tmpRoot, + getCurrentTask: vi.fn().mockReturnValue({ + cwd: tmpRoot, + handleWebviewAskResponse, + }), + getState: vi.fn().mockResolvedValue({ + maxImageFileSize: 5, + maxTotalImageSize: 20, + }), + } as unknown as ClineProvider + + await webviewMessageHandler(mockProvider, { + type: "askResponse", + askResponse: "messageResponse", + text: "Please look at @/cat.jpg", + images: [], + } as any) + + expect(handleWebviewAskResponse).toHaveBeenCalledWith("messageResponse", "Please look at @/cat.jpg", [ + `data:image/jpeg;base64,${imgBytes.toString("base64")}`, + ]) + } finally { + await fs.rm(tmpRoot, { recursive: true, force: true }) + } + }) + + it("resolves gif image mentions (matching read_file behavior)", async () => { + const tmpRoot = await fs.mkdtemp(path.join(os.tmpdir(), "roo-image-mentions-")) + try { + const imgBytes = Buffer.from("gif-bytes") + await fs.writeFile(path.join(tmpRoot, "animation.gif"), imgBytes) + + const mockProvider = { + cwd: tmpRoot, + getCurrentTask: vi.fn().mockReturnValue(undefined), + createTask: vi.fn().mockResolvedValue(undefined), + postMessageToWebview: vi.fn().mockResolvedValue(undefined), + getState: vi.fn().mockResolvedValue({ + maxImageFileSize: 5, + maxTotalImageSize: 20, + }), + } as unknown as ClineProvider + + await webviewMessageHandler(mockProvider, { + type: "newTask", + text: "See @/animation.gif", + images: [], + } as any) + + expect(mockProvider.createTask).toHaveBeenCalledWith("See @/animation.gif", [ + `data:image/gif;base64,${imgBytes.toString("base64")}`, + ]) + } finally { + await fs.rm(tmpRoot, { recursive: true, force: true }) + } + }) +}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.searchFiles.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.searchFiles.spec.ts new file mode 100644 index 0000000000..82f4d765ab --- /dev/null +++ b/src/core/webview/__tests__/webviewMessageHandler.searchFiles.spec.ts @@ -0,0 +1,297 @@ +// npx vitest core/webview/__tests__/webviewMessageHandler.searchFiles.spec.ts + +import type { Mock } from "vitest" + +// Mock dependencies - must come before imports +vi.mock("../../../services/search/file-search") +vi.mock("../../ignore/RooIgnoreController") + +import { webviewMessageHandler } from "../webviewMessageHandler" +import type { ClineProvider } from "../ClineProvider" +import { searchWorkspaceFiles } from "../../../services/search/file-search" +import { RooIgnoreController } from "../../ignore/RooIgnoreController" + +const mockSearchWorkspaceFiles = searchWorkspaceFiles as Mock + +vi.mock("vscode", () => ({ + window: { + showInformationMessage: vi.fn(), + showErrorMessage: vi.fn(), + }, + workspace: { + workspaceFolders: [{ uri: { fsPath: "/mock/workspace" } }], + }, +})) + +describe("webviewMessageHandler - searchFiles with RooIgnore filtering", () => { + let mockClineProvider: ClineProvider + let mockFilterPaths: Mock + let mockDispose: Mock + + beforeEach(() => { + vi.clearAllMocks() + + // Spy on the mock RooIgnoreController prototype methods + mockFilterPaths = vi.fn() + mockDispose = vi.fn() + + // Override the filterPaths method on the prototype + ;(RooIgnoreController.prototype as any).filterPaths = mockFilterPaths + ;(RooIgnoreController.prototype as any).initialize = vi.fn().mockResolvedValue(undefined) + ;(RooIgnoreController.prototype as any).dispose = mockDispose + + // Create mock ClineProvider + mockClineProvider = { + getState: vi.fn(), + postMessageToWebview: vi.fn(), + getCurrentTask: vi.fn(), + cwd: "/mock/workspace", + } as unknown as ClineProvider + }) + + it("should filter results using RooIgnoreController when showRooIgnoredFiles is false", async () => { + // Setup mock results from file search + const mockResults = [ + { path: "src/index.ts", type: "file" as const, label: "index.ts" }, + { path: "secrets/config.json", type: "file" as const, label: "config.json" }, + { path: "src/utils.ts", type: "file" as const, label: "utils.ts" }, + ] + mockSearchWorkspaceFiles.mockResolvedValue(mockResults) + + // Setup state with showRooIgnoredFiles = false + ;(mockClineProvider.getState as Mock).mockResolvedValue({ + showRooIgnoredFiles: false, + }) + + // Setup filter to exclude secrets folder + mockFilterPaths.mockReturnValue(["src/index.ts", "src/utils.ts"]) + + // No current task, so temporary controller will be created + ;(mockClineProvider.getCurrentTask as Mock).mockReturnValue(null) + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "index", + requestId: "test-request-123", + }) + + // Verify filterPaths was called with all result paths + expect(mockFilterPaths).toHaveBeenCalledWith(["src/index.ts", "secrets/config.json", "src/utils.ts"]) + + // Verify filtered results were sent to webview + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "fileSearchResults", + results: [ + { path: "src/index.ts", type: "file", label: "index.ts" }, + { path: "src/utils.ts", type: "file", label: "utils.ts" }, + ], + requestId: "test-request-123", + }) + }) + + it("should not filter results when showRooIgnoredFiles is true", async () => { + // Setup mock results from file search + const mockResults = [ + { path: "src/index.ts", type: "file" as const, label: "index.ts" }, + { path: "secrets/config.json", type: "file" as const, label: "config.json" }, + ] + mockSearchWorkspaceFiles.mockResolvedValue(mockResults) + + // Setup state with showRooIgnoredFiles = true + ;(mockClineProvider.getState as Mock).mockResolvedValue({ + showRooIgnoredFiles: true, + }) + + // No current task + ;(mockClineProvider.getCurrentTask as Mock).mockReturnValue(null) + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "index", + requestId: "test-request-456", + }) + + // Verify filterPaths was NOT called + expect(mockFilterPaths).not.toHaveBeenCalled() + + // Verify all results were sent to webview (unfiltered) + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "fileSearchResults", + results: mockResults, + requestId: "test-request-456", + }) + }) + + it("should use existing RooIgnoreController from current task", async () => { + // Setup mock results from file search + const mockResults = [ + { path: "src/index.ts", type: "file" as const, label: "index.ts" }, + { path: "private/secret.ts", type: "file" as const, label: "secret.ts" }, + ] + mockSearchWorkspaceFiles.mockResolvedValue(mockResults) + + // Setup state with showRooIgnoredFiles = false + ;(mockClineProvider.getState as Mock).mockResolvedValue({ + showRooIgnoredFiles: false, + }) + + // Create a mock task with its own RooIgnoreController + const taskFilterPaths = vi.fn().mockReturnValue(["src/index.ts"]) + const taskRooIgnoreController = { + filterPaths: taskFilterPaths, + initialize: vi.fn(), + } + ;(mockClineProvider.getCurrentTask as Mock).mockReturnValue({ + taskId: "test-task-id", + rooIgnoreController: taskRooIgnoreController, + }) + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "index", + requestId: "test-request-789", + }) + + // Verify the task's controller was used (not the prototype) + expect(taskFilterPaths).toHaveBeenCalledWith(["src/index.ts", "private/secret.ts"]) + + // Verify filtered results were sent to webview + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "fileSearchResults", + results: [{ path: "src/index.ts", type: "file", label: "index.ts" }], + requestId: "test-request-789", + }) + }) + + it("should handle error when no workspace path is available", async () => { + // Create provider without cwd + mockClineProvider = { + ...mockClineProvider, + cwd: undefined, + getCurrentTask: vi.fn().mockReturnValue(null), + } as unknown as ClineProvider + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "test", + requestId: "test-request-error", + }) + + // Verify error response was sent + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "fileSearchResults", + results: [], + requestId: "test-request-error", + error: "No workspace path available", + }) + }) + + it("should handle errors from searchWorkspaceFiles", async () => { + mockSearchWorkspaceFiles.mockRejectedValue(new Error("File search failed")) + + // Setup state + ;(mockClineProvider.getState as Mock).mockResolvedValue({ + showRooIgnoredFiles: false, + }) + ;(mockClineProvider.getCurrentTask as Mock).mockReturnValue(null) + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "test", + requestId: "test-request-fail", + }) + + // Verify error response was sent + expect(mockClineProvider.postMessageToWebview).toHaveBeenCalledWith({ + type: "fileSearchResults", + results: [], + error: "File search failed", + requestId: "test-request-fail", + }) + }) + + it("should default showRooIgnoredFiles to false when state is null", async () => { + // Setup mock results from file search + const mockResults = [{ path: "src/index.ts", type: "file" as const, label: "index.ts" }] + mockSearchWorkspaceFiles.mockResolvedValue(mockResults) + + // Setup state to return null + ;(mockClineProvider.getState as Mock).mockResolvedValue(null) + + // Setup filter to return all paths (no filtering) + mockFilterPaths.mockReturnValue(["src/index.ts"]) + + // No current task + ;(mockClineProvider.getCurrentTask as Mock).mockReturnValue(null) + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "index", + requestId: "test-request-default", + }) + + // Verify filterPaths was called (showRooIgnoredFiles defaults to false) + expect(mockFilterPaths).toHaveBeenCalled() + }) + + it("should dispose temporary RooIgnoreController after use", async () => { + // Setup mock results from file search + const mockResults = [{ path: "src/index.ts", type: "file" as const, label: "index.ts" }] + mockSearchWorkspaceFiles.mockResolvedValue(mockResults) + + // Setup state + ;(mockClineProvider.getState as Mock).mockResolvedValue({ + showRooIgnoredFiles: false, + }) + + // Setup filter + mockFilterPaths.mockReturnValue(["src/index.ts"]) + + // No current task, so temporary controller will be created and should be disposed + ;(mockClineProvider.getCurrentTask as Mock).mockReturnValue(null) + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "index", + requestId: "test-request-dispose", + }) + + // Verify dispose was called on the temporary controller + expect(mockDispose).toHaveBeenCalled() + }) + + it("should not dispose controller from current task", async () => { + // Setup mock results from file search + const mockResults = [{ path: "src/index.ts", type: "file" as const, label: "index.ts" }] + mockSearchWorkspaceFiles.mockResolvedValue(mockResults) + + // Setup state + ;(mockClineProvider.getState as Mock).mockResolvedValue({ + showRooIgnoredFiles: false, + }) + + // Create a mock task with its own RooIgnoreController + const taskFilterPaths = vi.fn().mockReturnValue(["src/index.ts"]) + const taskDispose = vi.fn() + const taskRooIgnoreController = { + filterPaths: taskFilterPaths, + initialize: vi.fn(), + dispose: taskDispose, + } + ;(mockClineProvider.getCurrentTask as Mock).mockReturnValue({ + taskId: "test-task-id", + rooIgnoreController: taskRooIgnoreController, + }) + + await webviewMessageHandler(mockClineProvider, { + type: "searchFiles", + query: "index", + requestId: "test-request-no-dispose", + }) + + // Verify dispose was NOT called on the task's controller + expect(taskDispose).not.toHaveBeenCalled() + // Verify the prototype dispose was also not called + expect(mockDispose).not.toHaveBeenCalled() + }) +}) diff --git a/src/core/webview/__tests__/webviewMessageHandler.spec.ts b/src/core/webview/__tests__/webviewMessageHandler.spec.ts index a04d3403ed..2072bf1af5 100644 --- a/src/core/webview/__tests__/webviewMessageHandler.spec.ts +++ b/src/core/webview/__tests__/webviewMessageHandler.spec.ts @@ -134,6 +134,15 @@ vi.mock("../../../utils/fs") vi.mock("../../../utils/path") vi.mock("../../../utils/globalContext") +vi.mock("../../mentions/resolveImageMentions", () => ({ + resolveImageMentions: vi.fn(async ({ text, images }: { text: string; images?: string[] }) => ({ + text, + images: [...(images ?? []), "data:image/png;base64,from-mention"], + })), +})) + +import { resolveImageMentions } from "../../mentions/resolveImageMentions" + describe("webviewMessageHandler - requestLmStudioModels", () => { beforeEach(() => { vi.clearAllMocks() @@ -176,6 +185,41 @@ describe("webviewMessageHandler - requestLmStudioModels", () => { }) }) +describe("webviewMessageHandler - image mentions", () => { + beforeEach(() => { + vi.clearAllMocks() + mockClineProvider.getState = vi.fn().mockResolvedValue({ + maxImageFileSize: 5, + maxTotalImageSize: 20, + }) + }) + + it("should resolve image mentions for askResponse payloads", async () => { + const mockHandleWebviewAskResponse = vi.fn() + vi.mocked(mockClineProvider.getCurrentTask).mockReturnValue({ + cwd: "/mock/workspace", + rooIgnoreController: undefined, + handleWebviewAskResponse: mockHandleWebviewAskResponse, + } as any) + + await webviewMessageHandler(mockClineProvider, { + type: "askResponse", + askResponse: "messageResponse", + text: "See @/img.png", + images: [], + }) + + expect(vi.mocked(resolveImageMentions)).toHaveBeenCalled() + expect(mockHandleWebviewAskResponse).toHaveBeenCalledWith( + "messageResponse", + "See @/img.png", + ["data:image/png;base64,from-mention"], + "system", + false, + ) + }) +}) + describe("webviewMessageHandler - requestOllamaModels", () => { beforeEach(() => { vi.clearAllMocks() diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index aaa0f56f4b..a63b12b11b 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -56,6 +56,8 @@ import { exportSettings, importSettingsWithFeedback } from "../config/importExpo import { getOpenAiModels } from "../../api/providers/openai" import { getVsCodeLmModels } from "../../api/providers/vscode-lm" import { openMention } from "../mentions" +import { resolveImageMentions } from "../mentions/resolveImageMentions" +import { RooIgnoreController } from "../ignore/RooIgnoreController" import { getWorkspacePath } from "../../utils/path" import { Mode, defaultModeSlug, ZgsmCodeMode } from "../../shared/modes" import { getModels, flushModels } from "../../api/providers/fetchers/modelCache" @@ -95,6 +97,26 @@ export const webviewMessageHandler = async ( const getCurrentCwd = () => { return provider.getCurrentTask()?.cwd || provider.cwd } + + /** + * Resolves image file mentions in incoming messages. + * Matches read_file behavior: respects size limits and model capabilities. + */ + const resolveIncomingImages = async (payload: { text?: string; images?: string[] }) => { + const text = payload.text ?? "" + const images = payload.images + const currentTask = provider.getCurrentTask() + const state = await provider.getState() + const resolved = await resolveImageMentions({ + text, + images, + cwd: getCurrentCwd(), + rooIgnoreController: currentTask?.rooIgnoreController, + maxImageFileSize: state.maxImageFileSize, + maxTotalImageSize: state.maxTotalImageSize, + }) + return resolved + } /** * Shared utility to find message indices based on timestamp. * When multiple messages share the same timestamp (e.g., after condense), @@ -584,10 +606,8 @@ export const webviewMessageHandler = async ( // agentically running promises in old instance don't affect our new // task. This essentially creates a fresh slate for the new task. try { - // if (message.values?.checkProjectWiki) { - // await ensureProjectWikiSubtasksExists() - // } - await provider.createTask(message.text, message.images) + const resolved = await resolveIncomingImages({ text: message.text, images: message.images }) + await provider.createTask(resolved.text, resolved.images) // Task created successfully - notify the UI to reset await provider.postMessageToWebview({ type: "invoke", invoke: "newChat" }) } catch (error) { @@ -604,15 +624,18 @@ export const webviewMessageHandler = async ( break case "askResponse": - provider - .getCurrentTask() - ?.handleWebviewAskResponse( - message.askResponse!, - message.text, - message.images, - message?.values?.chatType || "system", - message?.values?.isCommandInput ?? false, - ) + { + const resolved = await resolveIncomingImages({ text: message.text, images: message.images }) + provider + .getCurrentTask() + ?.handleWebviewAskResponse( + message.askResponse!, + resolved.text, + resolved.images, + message?.values?.chatType || "system", + message?.values?.isCommandInput ?? false, + ) + } break case "updateSettings": @@ -1847,12 +1870,39 @@ export const webviewMessageHandler = async ( 20, // Use default limit, as filtering is now done in the backend ) - // Send results back to webview - await provider.postMessageToWebview({ - type: "fileSearchResults", - results, - requestId: message.requestId, - }) + // Get the RooIgnoreController from the current task, or create a new one + const currentTask = provider.getCurrentTask() + let rooIgnoreController = currentTask?.rooIgnoreController + let tempController: RooIgnoreController | undefined + + // If no current task or no controller, create a temporary one + if (!rooIgnoreController) { + tempController = new RooIgnoreController(workspacePath) + await tempController.initialize() + rooIgnoreController = tempController + } + + try { + // Get showRooIgnoredFiles setting from state + const { showRooIgnoredFiles = false } = (await provider.getState()) ?? {} + + // Filter results using RooIgnoreController if showRooIgnoredFiles is false + let filteredResults = results + if (!showRooIgnoredFiles && rooIgnoreController) { + const allowedPaths = rooIgnoreController.filterPaths(results.map((r) => r.path)) + filteredResults = results.filter((r) => allowedPaths.includes(r.path)) + } + + // Send results back to webview + await provider.postMessageToWebview({ + type: "fileSearchResults", + results: filteredResults, + requestId: message.requestId, + }) + } finally { + // Dispose temporary controller to prevent resource leak + tempController?.dispose() + } } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) @@ -2029,11 +2079,12 @@ export const webviewMessageHandler = async ( break case "editMessageConfirm": if (message.messageTs && message.text) { + const resolved = await resolveIncomingImages({ text: message.text, images: message.images }) await handleEditMessageConfirm( message.messageTs, - message.text, + resolved.text, message.restoreCheckpoint, - message.images, + resolved.images, ) } break @@ -3466,7 +3517,8 @@ export const webviewMessageHandler = async ( */ case "queueMessage": { - provider.getCurrentTask()?.messageQueueService.addMessage(message.text ?? "", message.images) + const resolved = await resolveIncomingImages({ text: message.text, images: message.images }) + provider.getCurrentTask()?.messageQueueService.addMessage(resolved.text, resolved.images) break } case "removeQueuedMessage": { diff --git a/src/esbuild.mjs b/src/esbuild.mjs index 0e248088ff..f7d1fcda0a 100644 --- a/src/esbuild.mjs +++ b/src/esbuild.mjs @@ -31,8 +31,8 @@ async function main() { platform: "node", define: { "process.env.NODE_ENV": production ? '"production"' : '"development"', - "process.env.ZGSM_BASE_URL": JSON.stringify(process.env.ZGSM_BASE_URL || ""), - "process.env.ZGSM_PUBLIC_KEY": JSON.stringify(process.env.ZGSM_PUBLIC_KEY || ""), + "process.env.COSTRICT_BASE_URL": JSON.stringify(process.env.COSTRICT_BASE_URL || ""), + "process.env.COSTRICT_PUBLIC_KEY": JSON.stringify(process.env.COSTRICT_PUBLIC_KEY || process.env.ZGSM_PUBLIC_KEY || ""), }, banner: { js: networkInterfacesCompatible, diff --git a/src/utils/encoding.ts b/src/utils/encoding.ts index 8922ce1e54..29678f50b9 100644 --- a/src/utils/encoding.ts +++ b/src/utils/encoding.ts @@ -242,7 +242,6 @@ export async function isBinaryFileWithEncodingDetection(filePath: string, size?: } catch (error) { // File read error, assume it's binary return false - return true } } diff --git a/turbo.json b/turbo.json index 2df72b43ee..c24ae1e0a9 100644 --- a/turbo.json +++ b/turbo.json @@ -1,6 +1,6 @@ { "$schema": "https://turbo.build/schema.json", - "globalEnv": ["ZGSM_PUBLIC_KEY", "ZGSM_BASE_URL", "NODE_ENV"], + "globalEnv": ["ZGSM_PUBLIC_KEY", "COSTRICT_PUBLIC_KEY", "COSTRICT_BASE_URL", "NODE_ENV"], "tasks": { "lint": {}, "check-types": {}, diff --git a/webview-ui/src/components/settings/ExperimentalSettings.tsx b/webview-ui/src/components/settings/ExperimentalSettings.tsx index 87654c6263..d5de439499 100644 --- a/webview-ui/src/components/settings/ExperimentalSettings.tsx +++ b/webview-ui/src/components/settings/ExperimentalSettings.tsx @@ -58,6 +58,8 @@ export const ExperimentalSettings = ({ .filter(([key]) => key in EXPERIMENT_IDS) // Hide MULTIPLE_NATIVE_TOOL_CALLS - feature is on hold .filter(([key]) => key !== "MULTIPLE_NATIVE_TOOL_CALLS") + // Hide CHAT_SEARCH - moved to UI settings + .filter(([key]) => key !== "CHAT_SEARCH") .map((config) => { if (config[0] === "MULTI_FILE_APPLY_DIFF") { return ( diff --git a/webview-ui/src/components/settings/SettingsView.tsx b/webview-ui/src/components/settings/SettingsView.tsx index e65fdf11ea..75b9f5c783 100644 --- a/webview-ui/src/components/settings/SettingsView.tsx +++ b/webview-ui/src/components/settings/SettingsView.tsx @@ -897,8 +897,10 @@ const SettingsView = forwardRef(({ onDone, t showSpeedInfo={showSpeedInfo ?? false} automaticallyFocus={automaticallyFocus ?? false} enterBehavior={enterBehavior ?? "send"} + experiments={experiments} apiConfiguration={apiConfiguration} setCachedStateField={setCachedStateField} + setExperimentEnabled={setExperimentEnabled} /> )} diff --git a/webview-ui/src/components/settings/UISettings.tsx b/webview-ui/src/components/settings/UISettings.tsx index 98873ee6b1..2b09eaa43f 100644 --- a/webview-ui/src/components/settings/UISettings.tsx +++ b/webview-ui/src/components/settings/UISettings.tsx @@ -3,8 +3,10 @@ import { useAppTranslation } from "@/i18n/TranslationContext" import { VSCodeCheckbox } from "@vscode/webview-ui-toolkit/react" import { Glasses } from "lucide-react" import { telemetryClient } from "@/utils/TelemetryClient" +import type { Experiments } from "@roo-code/types" +import { EXPERIMENT_IDS } from "@roo/experiments" -import { SetCachedStateField } from "./types" +import { SetCachedStateField, SetExperimentEnabled } from "./types" import { SectionHeader } from "./SectionHeader" import { Section } from "./Section" import { ExtensionStateContextType } from "@/context/ExtensionStateContext" @@ -14,8 +16,10 @@ interface UISettingsProps extends HTMLAttributes { showSpeedInfo: boolean automaticallyFocus: boolean enterBehavior: "send" | "newline" + experiments: Experiments apiConfiguration?: any setCachedStateField: SetCachedStateField + setExperimentEnabled: SetExperimentEnabled } export const UISettings = ({ @@ -23,8 +27,10 @@ export const UISettings = ({ showSpeedInfo, automaticallyFocus, enterBehavior, + experiments, apiConfiguration, setCachedStateField, + setExperimentEnabled, ...props }: UISettingsProps) => { const { t } = useAppTranslation() @@ -73,6 +79,15 @@ export const UISettings = ({ }) } + const handleChatSearchChange = (enabled: boolean) => { + setExperimentEnabled(EXPERIMENT_IDS.CHAT_SEARCH, enabled) + + // Track telemetry event + telemetryClient.capture("ui_settings_chat_search_changed", { + enabled, + }) + } + return (
@@ -137,6 +152,18 @@ export const UISettings = ({ {t("settings:ui.requireCtrlEnterToSend.description", { primaryMod })}
+ {/* Chat Search Setting */} +
+ handleChatSearchChange(e.target.checked)} + data-testid="chat-search-checkbox"> + {t("settings:experimental.CHAT_SEARCH.name")} + +
+ {t("settings:experimental.CHAT_SEARCH.description")} +
+
diff --git a/webview-ui/src/components/settings/__tests__/UISettings.spec.tsx b/webview-ui/src/components/settings/__tests__/UISettings.spec.tsx index 7a98566f31..0b744a2435 100644 --- a/webview-ui/src/components/settings/__tests__/UISettings.spec.tsx +++ b/webview-ui/src/components/settings/__tests__/UISettings.spec.tsx @@ -11,7 +11,9 @@ describe("UISettings", () => { apiProvider: "zgsm", }, enterBehavior: "send" as const, + experiments: {}, setCachedStateField: vi.fn(), + setExperimentEnabled: vi.fn(), } it("renders the collapse thinking checkbox", () => { diff --git a/webview-ui/src/components/settings/providers/OpenAICompatible.tsx b/webview-ui/src/components/settings/providers/OpenAICompatible.tsx index ad338d342a..2cf77b8366 100644 --- a/webview-ui/src/components/settings/providers/OpenAICompatible.tsx +++ b/webview-ui/src/components/settings/providers/OpenAICompatible.tsx @@ -280,7 +280,7 @@ export const OpenAICompatible = ({ }} modelInfo={{ ...(apiConfiguration.openAiCustomModelInfo || openAiModelInfoSaneDefaults), - supportsReasoningEffort: true, + supportsReasoningEffort: ["low", "medium", "high", "xhigh"], }} /> )} diff --git a/webview-ui/src/components/settings/providers/ZgsmAI.tsx b/webview-ui/src/components/settings/providers/ZgsmAI.tsx index 879301b032..76040eb7b9 100644 --- a/webview-ui/src/components/settings/providers/ZgsmAI.tsx +++ b/webview-ui/src/components/settings/providers/ZgsmAI.tsx @@ -365,7 +365,7 @@ export const ZgsmAI = ({ }} modelInfo={{ ...(apiConfiguration.zgsmAiCustomModelInfo || zgsmModels.default), - supportsReasoningEffort: true, + supportsReasoningEffort: ["low", "medium", "high", "xhigh"], }} /> )} diff --git a/webview-ui/src/i18n/locales/en/chat.json b/webview-ui/src/i18n/locales/en/chat.json index f1827756b4..c42dd824f4 100644 --- a/webview-ui/src/i18n/locales/en/chat.json +++ b/webview-ui/src/i18n/locales/en/chat.json @@ -334,7 +334,7 @@ "triggerLabel_zero": "0 auto-approve", "triggerLabel_one": "1 auto-approved", "triggerLabel_other": "{{count}} auto-approved", - "triggerLabelAll": "YOLO" + "triggerLabelAll": "BRRR" }, "announcement": { "title": "Roo Code {{version}} Released", diff --git a/webview-ui/src/i18n/locales/zh-CN/chat.json b/webview-ui/src/i18n/locales/zh-CN/chat.json index a4142775b5..7dd1b2060f 100644 --- a/webview-ui/src/i18n/locales/zh-CN/chat.json +++ b/webview-ui/src/i18n/locales/zh-CN/chat.json @@ -305,7 +305,7 @@ "triggerLabel_zero": "0 个自动批准", "triggerLabel_one": "1 个自动批准", "triggerLabel_other": "{{count}} 个自动批准", - "triggerLabelAll": "YOLO" + "triggerLabelAll": "BRRR" }, "reasoning": { "thinking": "思考", diff --git a/webview-ui/src/i18n/locales/zh-TW/chat.json b/webview-ui/src/i18n/locales/zh-TW/chat.json index 88a07b670d..99e0d35898 100644 --- a/webview-ui/src/i18n/locales/zh-TW/chat.json +++ b/webview-ui/src/i18n/locales/zh-TW/chat.json @@ -338,7 +338,7 @@ "triggerLabel_zero": "0 個自動核准", "triggerLabel_one": "1 個自動核准", "triggerLabel_other": "{{count}} 個自動核准", - "triggerLabelAll": "YOLO" + "triggerLabelAll": "BRRR" }, "announcement": { "title": "Roo Code {{version}} 已發布", diff --git a/webview-ui/src/index.css b/webview-ui/src/index.css index 0ddddadf31..a61993d314 100644 --- a/webview-ui/src/index.css +++ b/webview-ui/src/index.css @@ -214,6 +214,29 @@ .history-item-highlight { @apply underline; } + + /* Custom smooth bounce animation for Roo hero */ + @keyframes smooth-bounce { + 0% { + transform: translateY(0); + } + 25% { + transform: translateY(-25%); + } + 50% { + transform: translateY(0); + } + 75% { + transform: translateY(-12.5%); + } + 100% { + transform: translateY(0); + } + } + + .animate-smooth-bounce { + animation: smooth-bounce 1s ease-in-out infinite; + } } /* Form Element Focus States */ diff --git a/webview-ui/vite.config.ts b/webview-ui/vite.config.ts index cb543cd627..35d6bbc090 100644 --- a/webview-ui/vite.config.ts +++ b/webview-ui/vite.config.ts @@ -64,8 +64,10 @@ export default defineConfig(({ mode }) => { "process.env.COSTRICT_PKG_NAME": JSON.stringify(pkg.name), "process.env.COSTRICT_PKG_VERSION": JSON.stringify(pkg.version), "process.env.COSTRICT_PKG_OUTPUT_CHANNEL": JSON.stringify("CoStrict"), - "process.env.ZGSM_BASE_URL": JSON.stringify(process.env.ZGSM_BASE_URL || ""), - "process.env.ZGSM_PUBLIC_KEY": JSON.stringify(process.env.ZGSM_PUBLIC_KEY || ""), + "process.env.COSTRICT_BASE_URL": JSON.stringify(process.env.COSTRICT_BASE_URL || ""), + "process.env.COSTRICT_PUBLIC_KEY": JSON.stringify( + process.env.COSTRICT_PUBLIC_KEY || process.env.ZGSM_PUBLIC_KEY || "", + ), ...(gitSha ? { "process.env.COSTRICT_PKG_SHA": JSON.stringify(gitSha) } : {}), }