diff --git a/packages/types/src/history.ts b/packages/types/src/history.ts index d97884d216e..b4d84cb9a51 100644 --- a/packages/types/src/history.ts +++ b/packages/types/src/history.ts @@ -29,6 +29,7 @@ export const historyItemSchema = z.object({ * This ensures task resumption works correctly even when NTC settings change. */ toolProtocol: z.enum(["xml", "native"]).optional(), + apiConfigName: z.string().optional(), // Provider profile name for sticky profile feature status: z.enum(["active", "completed", "delegated"]).optional(), delegatedToId: z.string().optional(), // Last child this parent delegated to childIds: z.array(z.string()).optional(), // All children spawned by this task diff --git a/src/core/task-persistence/taskMetadata.ts b/src/core/task-persistence/taskMetadata.ts index eb872a6f7e9..cf8d9adb529 100644 --- a/src/core/task-persistence/taskMetadata.ts +++ b/src/core/task-persistence/taskMetadata.ts @@ -21,6 +21,8 @@ export type TaskMetadataOptions = { globalStoragePath: string workspace: string mode?: string + /** Provider profile name for the task (sticky profile feature) */ + apiConfigName?: string /** Initial status for the task (e.g., "active" for child tasks) */ initialStatus?: "active" | "delegated" | "completed" /** @@ -39,6 +41,7 @@ export async function taskMetadata({ globalStoragePath, workspace, mode, + apiConfigName, initialStatus, toolProtocol, }: TaskMetadataOptions) { @@ -116,6 +119,7 @@ export async function taskMetadata({ workspace, mode, ...(toolProtocol && { toolProtocol }), + ...(typeof apiConfigName === "string" && apiConfigName.length > 0 ? { apiConfigName } : {}), ...(initialStatus && { status: initialStatus }), } diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 33e8245ccc6..91927450913 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -247,6 +247,49 @@ export class Task extends EventEmitter implements TaskLike { */ private taskModeReady: Promise + /** + * The API configuration name (provider profile) associated with this task. + * Persisted across sessions to maintain the provider profile when reopening tasks from history. + * + * ## Lifecycle + * + * ### For new tasks: + * 1. Initially `undefined` during construction + * 2. Asynchronously initialized from provider state via `initializeTaskApiConfigName()` + * 3. Falls back to "default" if provider state is unavailable + * + * ### For history items: + * 1. Immediately set from `historyItem.apiConfigName` during construction + * 2. Falls back to undefined if not stored in history (for backward compatibility) + * + * ## Important + * If you need a non-`undefined` provider profile (e.g., for profile-dependent operations), + * wait for `taskApiConfigReady` first (or use `getTaskApiConfigName()`). + * The sync `taskApiConfigName` getter may return `undefined` for backward compatibility. + * + * @private + * @see {@link getTaskApiConfigName} - For safe async access + * @see {@link taskApiConfigName} - For sync access after initialization + */ + private _taskApiConfigName: string | undefined + + /** + * Promise that resolves when the task API config name has been initialized. + * This ensures async API config name initialization completes before the task is used. + * + * ## Purpose + * - Prevents race conditions when accessing task API config name + * - Ensures provider state is properly loaded before profile-dependent operations + * - Provides a synchronization point for async initialization + * + * ## Resolution timing + * - For history items: Resolves immediately (sync initialization) + * - For new tasks: Resolves after provider state is fetched (async initialization) + * + * @private + */ + private taskApiConfigReady: Promise + providerRef: WeakRef private readonly globalStoragePath: string abort: boolean = false @@ -480,21 +523,25 @@ export class Task extends EventEmitter implements TaskLike { this.taskNumber = taskNumber this.initialStatus = initialStatus - // Store the task's mode when it's created. - // For history items, use the stored mode; for new tasks, we'll set it + // Store the task's mode and API config name when it's created. + // For history items, use the stored values; for new tasks, we'll set them // after getting state. if (historyItem) { this._taskMode = historyItem.mode || defaultModeSlug + this._taskApiConfigName = historyItem.apiConfigName this.taskModeReady = Promise.resolve() + this.taskApiConfigReady = Promise.resolve() TelemetryService.instance.captureTaskRestarted(this.taskId) // For history items, use the persisted tool protocol if available. // If not available (old tasks), it will be detected in resumeTaskFromHistory. this._taskToolProtocol = historyItem.toolProtocol } else { - // For new tasks, don't set the mode yet - wait for async initialization. + // For new tasks, don't set the mode/apiConfigName yet - wait for async initialization. this._taskMode = undefined + this._taskApiConfigName = undefined this.taskModeReady = this.initializeTaskMode(provider) + this.taskApiConfigReady = this.initializeTaskApiConfigName(provider) TelemetryService.instance.captureTaskCreated(this.taskId) // For new tasks, resolve and lock the tool protocol immediately. @@ -617,6 +664,47 @@ export class Task extends EventEmitter implements TaskLike { } } + /** + * Initialize the task API config name from the provider state. + * This method handles async initialization with proper error handling. + * + * ## Flow + * 1. Attempts to fetch the current API config name from provider state + * 2. Sets `_taskApiConfigName` to the fetched name or "default" if unavailable + * 3. Handles errors gracefully by falling back to "default" + * 4. Logs any initialization errors for debugging + * + * ## Error handling + * - Network failures when fetching provider state + * - Provider not yet initialized + * - Invalid state structure + * + * All errors result in fallback to "default" to ensure task can proceed. + * + * @private + * @param provider - The ClineProvider instance to fetch state from + * @returns Promise that resolves when initialization is complete + */ + private async initializeTaskApiConfigName(provider: ClineProvider): Promise { + try { + const state = await provider.getState() + + // Avoid clobbering a newer value that may have been set while awaiting provider state + // (e.g., user switches provider profile immediately after task creation). + if (this._taskApiConfigName === undefined) { + this._taskApiConfigName = state?.currentApiConfigName ?? "default" + } + } catch (error) { + // If there's an error getting state, use the default profile (unless a newer value was set). + if (this._taskApiConfigName === undefined) { + this._taskApiConfigName = "default" + } + // Use the provider's log method for better error visibility + const errorMessage = `Failed to initialize task API config name: ${error instanceof Error ? error.message : String(error)}` + provider.log(errorMessage) + } + } + /** * Sets up a listener for provider profile changes to automatically update the parser state. * This ensures the XML/native protocol parser stays synchronized with the current model. @@ -737,6 +825,73 @@ export class Task extends EventEmitter implements TaskLike { return this._taskMode } + /** + * Wait for the task API config name to be initialized before proceeding. + * This method ensures that any operations depending on the task's provider profile + * will have access to the correct value. + * + * ## When to use + * - Before accessing provider profile-specific configurations + * - When switching between tasks with different provider profiles + * - Before operations that depend on the provider profile + * + * @returns Promise that resolves when the task API config name is initialized + * @public + */ + public async waitForApiConfigInitialization(): Promise { + return this.taskApiConfigReady + } + + /** + * Get the task API config name asynchronously, ensuring it's properly initialized. + * This is the recommended way to access the task's provider profile as it guarantees + * the value is available before returning. + * + * ## Async behavior + * - Internally waits for `taskApiConfigReady` promise to resolve + * - Returns the initialized API config name or undefined as fallback + * - Safe to call multiple times - subsequent calls return immediately if already initialized + * + * @returns Promise resolving to the task API config name string or undefined + * @public + */ + public async getTaskApiConfigName(): Promise { + await this.taskApiConfigReady + return this._taskApiConfigName + } + + /** + * Get the task API config name synchronously. This should only be used when you're certain + * that the value has already been initialized (e.g., after waitForApiConfigInitialization). + * + * ## When to use + * - In synchronous contexts where async/await is not available + * - After explicitly waiting for initialization via `waitForApiConfigInitialization()` + * - In event handlers or callbacks where API config name is guaranteed to be initialized + * + * Note: Unlike taskMode, this getter does not throw if uninitialized since the API config + * name can legitimately be undefined (backward compatibility with tasks created before + * this feature was added). + * + * @returns The task API config name string or undefined + * @public + */ + public get taskApiConfigName(): string | undefined { + return this._taskApiConfigName + } + + /** + * Update the task's API config name. This is called when the user switches + * provider profiles while a task is active, allowing the task to remember + * its new provider profile. + * + * @param apiConfigName - The new API config name to set + * @internal + */ + public setTaskApiConfigName(apiConfigName: string | undefined): void { + this._taskApiConfigName = apiConfigName + } + static create(options: TaskOptions): [Task, Promise] { const instance = new Task({ ...options, startTask: false }) const { images, task, historyItem } = options @@ -1005,6 +1160,10 @@ export class Task extends EventEmitter implements TaskLike { globalStoragePath: this.globalStoragePath, }) + if (this._taskApiConfigName === undefined) { + await this.taskApiConfigReady + } + const { historyItem, tokenUsage } = await taskMetadata({ taskId: this.taskId, rootTaskId: this.rootTaskId, @@ -1014,6 +1173,7 @@ export class Task extends EventEmitter implements TaskLike { globalStoragePath: this.globalStoragePath, workspace: this.cwd, mode: this._taskMode || defaultModeSlug, // Use the task's own mode, not the current provider mode. + apiConfigName: this._taskApiConfigName, // Use the task's own provider profile, not the current provider profile. initialStatus: this.initialStatus, toolProtocol: this._taskToolProtocol, // Persist the locked tool protocol. }) diff --git a/src/core/task/__tests__/Task.sticky-profile-race.spec.ts b/src/core/task/__tests__/Task.sticky-profile-race.spec.ts new file mode 100644 index 00000000000..e78301541df --- /dev/null +++ b/src/core/task/__tests__/Task.sticky-profile-race.spec.ts @@ -0,0 +1,142 @@ +// npx vitest run core/task/__tests__/Task.sticky-profile-race.spec.ts + +import * as vscode from "vscode" + +import type { ProviderSettings } from "@roo-code/types" +import { Task } from "../Task" +import { ClineProvider } from "../../webview/ClineProvider" + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + hasInstance: vi.fn().mockReturnValue(true), + createInstance: vi.fn(), + get instance() { + return { + captureTaskCreated: vi.fn(), + captureTaskRestarted: vi.fn(), + captureModeSwitch: vi.fn(), + captureConversationMessage: vi.fn(), + captureLlmCompletion: vi.fn(), + captureConsecutiveMistakeError: vi.fn(), + captureCodeActionUsed: vi.fn(), + setProvider: vi.fn(), + } + }, + }, +})) + +vi.mock("vscode", () => { + const mockDisposable = { dispose: vi.fn() } + const mockEventEmitter = { event: vi.fn(), fire: vi.fn() } + const mockTextDocument = { uri: { fsPath: "/mock/workspace/path/file.ts" } } + const mockTextEditor = { document: mockTextDocument } + const mockTab = { input: { uri: { fsPath: "/mock/workspace/path/file.ts" } } } + const mockTabGroup = { tabs: [mockTab] } + + return { + TabInputTextDiff: vi.fn(), + CodeActionKind: { + QuickFix: { value: "quickfix" }, + RefactorRewrite: { value: "refactor.rewrite" }, + }, + window: { + createTextEditorDecorationType: vi.fn().mockReturnValue({ + dispose: vi.fn(), + }), + visibleTextEditors: [mockTextEditor], + tabGroups: { + all: [mockTabGroup], + close: vi.fn(), + onDidChangeTabs: vi.fn(() => ({ dispose: vi.fn() })), + }, + showErrorMessage: vi.fn(), + }, + workspace: { + getConfiguration: vi.fn(() => ({ get: (_k: string, d: any) => d })), + workspaceFolders: [ + { + uri: { fsPath: "/mock/workspace/path" }, + name: "mock-workspace", + index: 0, + }, + ], + createFileSystemWatcher: vi.fn(() => ({ + onDidCreate: vi.fn(() => mockDisposable), + onDidDelete: vi.fn(() => mockDisposable), + onDidChange: vi.fn(() => mockDisposable), + dispose: vi.fn(), + })), + fs: { + stat: vi.fn().mockResolvedValue({ type: 1 }), + }, + onDidSaveTextDocument: vi.fn(() => mockDisposable), + }, + env: { + uriScheme: "vscode", + language: "en", + }, + EventEmitter: vi.fn().mockImplementation(() => mockEventEmitter), + Disposable: { + from: vi.fn(), + }, + TabInputText: vi.fn(), + version: "1.85.0", + } +}) + +vi.mock("../../environment/getEnvironmentDetails", () => ({ + getEnvironmentDetails: vi.fn().mockResolvedValue(""), +})) + +vi.mock("../../ignore/RooIgnoreController") + +vi.mock("p-wait-for", () => ({ + default: vi.fn().mockImplementation(async () => Promise.resolve()), +})) + +vi.mock("delay", () => ({ + __esModule: true, + default: vi.fn().mockResolvedValue(undefined), +})) + +describe("Task - sticky provider profile init race", () => { + it("does not overwrite task apiConfigName if set during async initialization", async () => { + const apiConfig: ProviderSettings = { + apiProvider: "anthropic", + apiModelId: "claude-3-5-sonnet-20241022", + apiKey: "test-api-key", + } as any + + let resolveGetState: ((v: any) => void) | undefined + const getStatePromise = new Promise((resolve) => { + resolveGetState = resolve + }) + + const mockProvider = { + context: { + globalStorageUri: { fsPath: "/test/storage" }, + }, + getState: vi.fn().mockImplementation(() => getStatePromise), + log: vi.fn(), + on: vi.fn(), + off: vi.fn(), + postStateToWebview: vi.fn().mockResolvedValue(undefined), + updateTaskHistory: vi.fn().mockResolvedValue(undefined), + } as unknown as ClineProvider + + const task = new Task({ + provider: mockProvider, + apiConfiguration: apiConfig, + task: "test task", + startTask: false, + }) + + // Simulate a profile switch happening before provider.getState resolves. + task.setTaskApiConfigName("new-profile") + + resolveGetState?.({ currentApiConfigName: "old-profile" }) + await task.waitForApiConfigInitialization() + + expect(task.taskApiConfigName).toBe("new-profile") + }) +}) diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index a34fb817ee2..6153af6160c 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -890,29 +890,64 @@ export class ClineProvider await this.updateGlobalState("mode", historyItem.mode) // Load the saved API config for the restored mode if it exists. - const savedConfigId = await this.providerSettingsManager.getModeConfigId(historyItem.mode) - const listApiConfig = await this.providerSettingsManager.listConfig() + // Skip mode-based profile activation if historyItem.apiConfigName exists, + // since the task's specific provider profile will override it anyway. + if (!historyItem.apiConfigName) { + const savedConfigId = await this.providerSettingsManager.getModeConfigId(historyItem.mode) + const listApiConfig = await this.providerSettingsManager.listConfig() + + // Update listApiConfigMeta first to ensure UI has latest data. + await this.updateGlobalState("listApiConfigMeta", listApiConfig) + + // If this mode has a saved config, use it. + if (savedConfigId) { + const profile = listApiConfig.find(({ id }) => id === savedConfigId) + + if (profile?.name) { + try { + await this.activateProviderProfile({ name: profile.name }) + } catch (error) { + // Log the error but continue with task restoration. + this.log( + `Failed to restore API configuration for mode '${historyItem.mode}': ${ + error instanceof Error ? error.message : String(error) + }. Continuing with default configuration.`, + ) + // The task will continue with the current/default configuration. + } + } + } + } + } - // Update listApiConfigMeta first to ensure UI has latest data. + // If the history item has a saved API config name (provider profile), restore it. + // This overrides any mode-based config restoration above, because the task's + // specific provider profile takes precedence over mode defaults. + if (historyItem.apiConfigName) { + const listApiConfig = await this.providerSettingsManager.listConfig() + // Keep global state/UI in sync with latest profiles for parity with mode restoration above. await this.updateGlobalState("listApiConfigMeta", listApiConfig) + const profile = listApiConfig.find(({ name }) => name === historyItem.apiConfigName) - // If this mode has a saved config, use it. - if (savedConfigId) { - const profile = listApiConfig.find(({ id }) => id === savedConfigId) - - if (profile?.name) { - try { - await this.activateProviderProfile({ name: profile.name }) - } catch (error) { - // Log the error but continue with task restoration. - this.log( - `Failed to restore API configuration for mode '${historyItem.mode}': ${ - error instanceof Error ? error.message : String(error) - }. Continuing with default configuration.`, - ) - // The task will continue with the current/default configuration. - } + if (profile?.name) { + try { + await this.activateProviderProfile( + { name: profile.name }, + { persistModeConfig: false, persistTaskHistory: false }, + ) + } catch (error) { + // Log the error but continue with task restoration. + this.log( + `Failed to restore API configuration '${historyItem.apiConfigName}' for task: ${ + error instanceof Error ? error.message : String(error) + }. Continuing with current configuration.`, + ) } + } else { + // Profile no longer exists, log warning but continue + this.log( + `Provider profile '${historyItem.apiConfigName}' from history no longer exists. Using current configuration.`, + ) } } @@ -1399,6 +1434,9 @@ export class ClineProvider // Change the provider for the current task. // TODO: We should rename `buildApiHandler` for clarity (e.g. `getProviderClient`). this.updateTaskApiHandlerIfNeeded(providerSettings, { forceRebuild: true }) + + // Keep the current task's sticky provider profile in sync with the newly-activated profile. + await this.persistStickyProviderProfileToCurrentTask(name) } else { await this.updateGlobalState("listApiConfigMeta", await this.providerSettingsManager.listConfig()) } @@ -1438,9 +1476,42 @@ export class ClineProvider await this.postStateToWebview() } - async activateProviderProfile(args: { name: string } | { id: string }) { + private async persistStickyProviderProfileToCurrentTask(apiConfigName: string): Promise { + const task = this.getCurrentTask() + if (!task) { + return + } + + try { + // Update in-memory state immediately so sticky behavior works even before the task has + // been persisted into taskHistory (it will be captured on the next save). + task.setTaskApiConfigName(apiConfigName) + + const history = this.getGlobalState("taskHistory") ?? [] + const taskHistoryItem = history.find((item) => item.id === task.taskId) + + if (taskHistoryItem) { + await this.updateTaskHistory({ ...taskHistoryItem, apiConfigName }) + } + } catch (error) { + // If persistence fails, log the error but don't fail the profile switch. + this.log( + `Failed to persist provider profile switch for task ${task.taskId}: ${ + error instanceof Error ? error.message : String(error) + }`, + ) + } + } + + async activateProviderProfile( + args: { name: string } | { id: string }, + options?: { persistModeConfig?: boolean; persistTaskHistory?: boolean }, + ) { const { name, id, ...providerSettings } = await this.providerSettingsManager.activateProfile(args) + const persistModeConfig = options?.persistModeConfig ?? true + const persistTaskHistory = options?.persistTaskHistory ?? true + // See `upsertProviderProfile` for a description of what this is doing. await Promise.all([ this.contextProxy.setValue("listApiConfigMeta", await this.providerSettingsManager.listConfig()), @@ -1450,12 +1521,19 @@ export class ClineProvider const { mode } = await this.getState() - if (id) { + if (id && persistModeConfig) { await this.providerSettingsManager.setModeConfig(mode, id) } + // Change the provider for the current task. this.updateTaskApiHandlerIfNeeded(providerSettings, { forceRebuild: true }) + // Update the current task's sticky provider profile, unless this activation is + // being used purely as a non-persisting restoration (e.g., reopening a task from history). + if (persistTaskHistory) { + await this.persistStickyProviderProfileToCurrentTask(name) + } + await this.postStateToWebview() if (providerSettings.apiProvider) { diff --git a/src/core/webview/__tests__/ClineProvider.sticky-profile.spec.ts b/src/core/webview/__tests__/ClineProvider.sticky-profile.spec.ts new file mode 100644 index 00000000000..3df4408b718 --- /dev/null +++ b/src/core/webview/__tests__/ClineProvider.sticky-profile.spec.ts @@ -0,0 +1,883 @@ +// npx vitest run core/webview/__tests__/ClineProvider.sticky-profile.spec.ts + +import * as vscode from "vscode" +import { TelemetryService } from "@roo-code/telemetry" +import { ClineProvider } from "../ClineProvider" +import { ContextProxy } from "../../config/ContextProxy" +import type { HistoryItem } from "@roo-code/types" + +vi.mock("vscode", () => ({ + ExtensionContext: vi.fn(), + OutputChannel: vi.fn(), + WebviewView: vi.fn(), + Uri: { + joinPath: vi.fn(), + file: vi.fn(), + }, + CodeActionKind: { + QuickFix: { value: "quickfix" }, + RefactorRewrite: { value: "refactor.rewrite" }, + }, + commands: { + executeCommand: vi.fn().mockResolvedValue(undefined), + }, + window: { + showInformationMessage: vi.fn(), + showWarningMessage: vi.fn(), + showErrorMessage: vi.fn(), + onDidChangeActiveTextEditor: vi.fn(() => ({ dispose: vi.fn() })), + }, + workspace: { + getConfiguration: vi.fn().mockReturnValue({ + get: vi.fn().mockReturnValue([]), + update: vi.fn(), + }), + onDidChangeConfiguration: vi.fn().mockImplementation(() => ({ + dispose: vi.fn(), + })), + onDidSaveTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + onDidChangeTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + onDidOpenTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + onDidCloseTextDocument: vi.fn(() => ({ dispose: vi.fn() })), + }, + env: { + uriScheme: "vscode", + language: "en", + appName: "Visual Studio Code", + }, + ExtensionMode: { + Production: 1, + Development: 2, + Test: 3, + }, + version: "1.85.0", +})) + +// Create a counter for unique task IDs. +let taskIdCounter = 0 + +vi.mock("../../task/Task", () => ({ + Task: vi.fn().mockImplementation((options) => ({ + taskId: options.taskId || `test-task-id-${++taskIdCounter}`, + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + overwriteClineMessages: vi.fn(), + overwriteApiConversationHistory: vi.fn(), + abortTask: vi.fn(), + handleWebviewAskResponse: vi.fn(), + getTaskNumber: vi.fn().mockReturnValue(0), + setTaskNumber: vi.fn(), + setParentTask: vi.fn(), + setRootTask: vi.fn(), + emit: vi.fn(), + parentTask: options.parentTask, + updateApiConfiguration: vi.fn(), + setTaskApiConfigName: vi.fn(), + _taskApiConfigName: options.historyItem?.apiConfigName, + taskApiConfigName: options.historyItem?.apiConfigName, + })), +})) + +vi.mock("../../prompts/sections/custom-instructions") + +vi.mock("../../../utils/safeWriteJson") + +vi.mock("../../../api", () => ({ + buildApiHandler: vi.fn().mockReturnValue({ + getModel: vi.fn().mockReturnValue({ + id: "claude-3-sonnet", + }), + }), +})) + +vi.mock("../../../integrations/workspace/WorkspaceTracker", () => ({ + default: vi.fn().mockImplementation(() => ({ + initializeFilePaths: vi.fn(), + dispose: vi.fn(), + })), +})) + +vi.mock("../../diff/strategies/multi-search-replace", () => ({ + MultiSearchReplaceDiffStrategy: vi.fn().mockImplementation(() => ({ + getToolDescription: () => "test", + getName: () => "test-strategy", + applyDiff: vi.fn(), + })), +})) + +vi.mock("@roo-code/cloud", () => ({ + CloudService: { + hasInstance: vi.fn().mockReturnValue(true), + get instance() { + return { + isAuthenticated: vi.fn().mockReturnValue(false), + } + }, + }, + BridgeOrchestrator: { + isEnabled: vi.fn().mockReturnValue(false), + }, + getRooCodeApiUrl: vi.fn().mockReturnValue("https://app.roocode.com"), +})) + +vi.mock("../../../shared/modes", () => ({ + modes: [ + { + slug: "code", + name: "Code Mode", + roleDefinition: "You are a code assistant", + groups: ["read", "edit", "browser"], + }, + { + slug: "architect", + name: "Architect Mode", + roleDefinition: "You are an architect", + groups: ["read", "edit"], + }, + ], + getModeBySlug: vi.fn().mockReturnValue({ + slug: "code", + name: "Code Mode", + roleDefinition: "You are a code assistant", + groups: ["read", "edit", "browser"], + }), + defaultModeSlug: "code", +})) + +vi.mock("../../prompts/system", () => ({ + SYSTEM_PROMPT: vi.fn().mockResolvedValue("mocked system prompt"), + codeMode: "code", +})) + +vi.mock("../../../api/providers/fetchers/modelCache", () => ({ + getModels: vi.fn().mockResolvedValue({}), + flushModels: vi.fn(), +})) + +vi.mock("../../../integrations/misc/extract-text", () => ({ + extractTextFromFile: vi.fn().mockResolvedValue("Mock file content"), +})) + +vi.mock("p-wait-for", () => ({ + default: vi.fn().mockImplementation(async () => Promise.resolve()), +})) + +vi.mock("fs/promises", () => ({ + mkdir: vi.fn().mockResolvedValue(undefined), + writeFile: vi.fn().mockResolvedValue(undefined), + readFile: vi.fn().mockResolvedValue(""), + unlink: vi.fn().mockResolvedValue(undefined), + rmdir: vi.fn().mockResolvedValue(undefined), +})) + +vi.mock("@roo-code/telemetry", () => ({ + TelemetryService: { + hasInstance: vi.fn().mockReturnValue(true), + createInstance: vi.fn(), + get instance() { + return { + trackEvent: vi.fn(), + trackError: vi.fn(), + setProvider: vi.fn(), + captureModeSwitch: vi.fn(), + } + }, + }, +})) + +describe("ClineProvider - Sticky Provider Profile", () => { + let provider: ClineProvider + let mockContext: vscode.ExtensionContext + let mockOutputChannel: vscode.OutputChannel + let mockWebviewView: vscode.WebviewView + let mockPostMessage: any + + beforeEach(() => { + vi.clearAllMocks() + taskIdCounter = 0 + + if (!TelemetryService.hasInstance()) { + TelemetryService.createInstance([]) + } + + const globalState: Record = { + mode: "code", + currentApiConfigName: "default-profile", + } + + const secrets: Record = {} + + mockContext = { + extensionPath: "/test/path", + extensionUri: {} as vscode.Uri, + globalState: { + get: vi.fn().mockImplementation((key: string) => globalState[key]), + update: vi.fn().mockImplementation((key: string, value: string | undefined) => { + globalState[key] = value + return Promise.resolve() + }), + keys: vi.fn().mockImplementation(() => Object.keys(globalState)), + }, + secrets: { + get: vi.fn().mockImplementation((key: string) => secrets[key]), + store: vi.fn().mockImplementation((key: string, value: string | undefined) => { + secrets[key] = value + return Promise.resolve() + }), + delete: vi.fn().mockImplementation((key: string) => { + delete secrets[key] + return Promise.resolve() + }), + }, + subscriptions: [], + extension: { + packageJSON: { version: "1.0.0" }, + }, + globalStorageUri: { + fsPath: "/test/storage/path", + }, + } as unknown as vscode.ExtensionContext + + mockOutputChannel = { + appendLine: vi.fn(), + clear: vi.fn(), + dispose: vi.fn(), + } as unknown as vscode.OutputChannel + + mockPostMessage = vi.fn() + + mockWebviewView = { + webview: { + postMessage: mockPostMessage, + html: "", + options: {}, + onDidReceiveMessage: vi.fn(), + asWebviewUri: vi.fn(), + cspSource: "vscode-webview://test-csp-source", + }, + visible: true, + onDidDispose: vi.fn().mockImplementation((callback) => { + callback() + return { dispose: vi.fn() } + }), + onDidChangeVisibility: vi.fn().mockImplementation(() => ({ dispose: vi.fn() })), + } as unknown as vscode.WebviewView + + provider = new ClineProvider(mockContext, mockOutputChannel, "sidebar", new ContextProxy(mockContext)) + + // Mock getMcpHub method + provider.getMcpHub = vi.fn().mockReturnValue({ + listTools: vi.fn().mockResolvedValue([]), + callTool: vi.fn().mockResolvedValue({ content: [] }), + listResources: vi.fn().mockResolvedValue([]), + readResource: vi.fn().mockResolvedValue({ contents: [] }), + getAllServers: vi.fn().mockReturnValue([]), + }) + }) + + describe("activateProviderProfile", () => { + beforeEach(async () => { + await provider.resolveWebviewView(mockWebviewView) + }) + + it("should save provider profile to task metadata when switching profiles", async () => { + // Create a mock task + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn(), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock getGlobalState to return task history + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory to track calls + const updateTaskHistorySpy = vi + .spyOn(provider, "updateTaskHistory") + .mockImplementation(() => Promise.resolve([])) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "anthropic" }, + ]) + + // Switch provider profile + await provider.activateProviderProfile({ name: "new-profile" }) + + // Verify task history was updated with new provider profile + expect(updateTaskHistorySpy).toHaveBeenCalledWith( + expect.objectContaining({ + id: mockTask.taskId, + apiConfigName: "new-profile", + }), + ) + + // Verify task's setTaskApiConfigName was called + expect(mockTask.setTaskApiConfigName).toHaveBeenCalledWith("new-profile") + }) + + it("should update task's taskApiConfigName property when switching profiles", async () => { + // Create a mock task with initial profile + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock getGlobalState to return task history + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory + vi.spyOn(provider, "updateTaskHistory").mockImplementation(() => Promise.resolve([])) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "openrouter", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "openrouter" }, + ]) + + // Switch provider profile + await provider.activateProviderProfile({ name: "new-profile" }) + + // Verify task's _taskApiConfigName property was updated + expect(mockTask._taskApiConfigName).toBe("new-profile") + }) + + it("should update in-memory task profile even if task history item does not exist yet", async () => { + await provider.resolveWebviewView(mockWebviewView) + + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + await provider.addClineToStack(mockTask as any) + + // No history item exists yet + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([]) + + const updateTaskHistorySpy = vi + .spyOn(provider, "updateTaskHistory") + .mockImplementation(() => Promise.resolve([])) + + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "openrouter", + }) + + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "openrouter" }, + ]) + + await provider.activateProviderProfile({ name: "new-profile" }) + + // In-memory should still update, even without a history item. + expect(mockTask._taskApiConfigName).toBe("new-profile") + // No history item => no updateTaskHistory call. + expect(updateTaskHistorySpy).not.toHaveBeenCalled() + }) + }) + + describe("createTaskWithHistoryItem", () => { + it("should restore provider profile from history item when reopening task", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with saved provider profile + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + mode: "code", + apiConfigName: "saved-profile", // Saved provider profile + } + + // Mock activateProviderProfile to track calls + const activateProviderProfileSpy = vi + .spyOn(provider, "activateProviderProfile") + .mockResolvedValue(undefined) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "saved-profile", id: "saved-profile-id", apiProvider: "anthropic" }, + ]) + + // Initialize task with history item + await provider.createTaskWithHistoryItem(historyItem) + + // Verify provider profile was restored via activateProviderProfile (restore-only: don't persist mode config) + expect(activateProviderProfileSpy).toHaveBeenCalledWith( + { name: "saved-profile" }, + { persistModeConfig: false, persistTaskHistory: false }, + ) + }) + + it("should use current profile if history item has no saved apiConfigName", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item without saved provider profile + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + // No apiConfigName field + } + + // Mock activateProviderProfile to track calls + const activateProviderProfileSpy = vi + .spyOn(provider, "activateProviderProfile") + .mockResolvedValue(undefined) + + // Initialize task with history item + await provider.createTaskWithHistoryItem(historyItem) + + // Verify activateProviderProfile was NOT called for apiConfigName restoration + // (it might be called for mode-based config, but not for direct apiConfigName) + const callsForApiConfigName = activateProviderProfileSpy.mock.calls.filter( + (call) => call[0] && "name" in call[0] && call[0].name === historyItem.apiConfigName, + ) + expect(callsForApiConfigName.length).toBe(0) + }) + + it("should override mode-based config with task's apiConfigName", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with both mode and apiConfigName + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + mode: "architect", // Mode has a different preferred profile + apiConfigName: "task-specific-profile", // Task's actual profile + } + + // Track all activateProviderProfile calls + const activateCalls: string[] = [] + vi.spyOn(provider, "activateProviderProfile").mockImplementation(async (args) => { + if ("name" in args) { + activateCalls.push(args.name) + } + }) + + // Mock providerSettingsManager methods + vi.spyOn(provider.providerSettingsManager, "getModeConfigId").mockResolvedValue("mode-config-id") + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "mode-preferred-profile", id: "mode-config-id", apiProvider: "anthropic" }, + { name: "task-specific-profile", id: "task-profile-id", apiProvider: "openai" }, + ]) + + // Initialize task with history item + await provider.createTaskWithHistoryItem(historyItem) + + // Verify task's apiConfigName was activated LAST (overriding mode-based config) + expect(activateCalls[activateCalls.length - 1]).toBe("task-specific-profile") + }) + + it("should handle missing provider profile gracefully", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with a provider profile that no longer exists + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + apiConfigName: "deleted-profile", // Profile that doesn't exist + } + + // Mock providerSettingsManager.listConfig to return empty (profile doesn't exist) + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([]) + + // Mock log to verify warning is logged + const logSpy = vi.spyOn(provider, "log") + + // Initialize task with history item - should not throw + await expect(provider.createTaskWithHistoryItem(historyItem)).resolves.not.toThrow() + + // Verify a warning was logged + expect(logSpy).toHaveBeenCalledWith( + expect.stringContaining("Provider profile 'deleted-profile' from history no longer exists"), + ) + }) + }) + + describe("Task metadata persistence", () => { + it("should include apiConfigName in task metadata when saving", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a mock task with provider profile + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "test-profile", + setTaskApiConfigName: vi.fn(), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Mock getGlobalState to return task history with our task + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory to capture the updated history item + let updatedHistoryItem: any + vi.spyOn(provider, "updateTaskHistory").mockImplementation((item) => { + updatedHistoryItem = item + return Promise.resolve([item]) + }) + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "anthropic" }, + ]) + + // Trigger a profile switch + await provider.activateProviderProfile({ name: "new-profile" }) + + // Verify apiConfigName was included in the updated history item + expect(updatedHistoryItem).toBeDefined() + expect(updatedHistoryItem.apiConfigName).toBe("new-profile") + }) + }) + + describe("Multiple workspaces isolation", () => { + it("should preserve task profile when switching profiles in another workspace", async () => { + // This test verifies that each task retains its designated provider profile + // so that switching profiles in one workspace doesn't alter other tasks + + await provider.resolveWebviewView(mockWebviewView) + + // Create task 1 with profile A + const task1 = { + taskId: "task-1", + _taskApiConfigName: "profile-a", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Create task 2 with profile B + const task2 = { + taskId: "task-2", + _taskApiConfigName: "profile-b", + setTaskApiConfigName: vi.fn().mockImplementation(function (this: any, name: string) { + this._taskApiConfigName = name + }), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task 1 to stack + await provider.addClineToStack(task1 as any) + + // Mock getGlobalState to return task history for both tasks + const taskHistory = [ + { + id: "task-1", + ts: Date.now(), + task: "Task 1", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + apiConfigName: "profile-a", + }, + { + id: "task-2", + ts: Date.now(), + task: "Task 2", + number: 2, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + apiConfigName: "profile-b", + }, + ] + + vi.spyOn(provider as any, "getGlobalState").mockReturnValue(taskHistory) + + // Mock updateTaskHistory + vi.spyOn(provider, "updateTaskHistory").mockImplementation((item) => { + const index = taskHistory.findIndex((h) => h.id === item.id) + if (index >= 0) { + taskHistory[index] = { ...taskHistory[index], ...item } + } + return Promise.resolve(taskHistory) + }) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "profile-c", + id: "profile-c-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "profile-a", id: "profile-a-id", apiProvider: "anthropic" }, + { name: "profile-b", id: "profile-b-id", apiProvider: "openai" }, + { name: "profile-c", id: "profile-c-id", apiProvider: "anthropic" }, + ]) + + // Switch task 1's profile to profile C + await provider.activateProviderProfile({ name: "profile-c" }) + + // Verify task 1's profile was updated + expect(task1._taskApiConfigName).toBe("profile-c") + expect(taskHistory[0].apiConfigName).toBe("profile-c") + + // Verify task 2's profile remains unchanged + expect(taskHistory[1].apiConfigName).toBe("profile-b") + }) + }) + + describe("Error handling", () => { + it("should handle errors gracefully when saving profile fails", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a mock task + const mockTask = { + taskId: "test-task-id", + _taskApiConfigName: "default-profile", + setTaskApiConfigName: vi.fn(), + emit: vi.fn(), + saveClineMessages: vi.fn(), + clineMessages: [], + apiConversationHistory: [], + updateApiConfiguration: vi.fn(), + } + + // Add task to provider stack + await provider.addClineToStack(mockTask as any) + + // Mock getGlobalState + vi.spyOn(provider as any, "getGlobalState").mockReturnValue([ + { + id: mockTask.taskId, + ts: Date.now(), + task: "Test task", + number: 1, + tokensIn: 0, + tokensOut: 0, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0, + }, + ]) + + // Mock updateTaskHistory to throw error + vi.spyOn(provider, "updateTaskHistory").mockRejectedValue(new Error("Save failed")) + + // Mock providerSettingsManager.activateProfile + vi.spyOn(provider.providerSettingsManager, "activateProfile").mockResolvedValue({ + name: "new-profile", + id: "new-profile-id", + apiProvider: "anthropic", + }) + + // Mock providerSettingsManager.listConfig + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "new-profile", id: "new-profile-id", apiProvider: "anthropic" }, + ]) + + // Mock log to verify error is logged + const logSpy = vi.spyOn(provider, "log") + + // Switch provider profile - should not throw + await expect(provider.activateProviderProfile({ name: "new-profile" })).resolves.not.toThrow() + + // Verify error was logged + expect(logSpy).toHaveBeenCalledWith(expect.stringContaining("Failed to persist provider profile switch")) + }) + + it("should handle null/undefined apiConfigName gracefully", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with null apiConfigName + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + apiConfigName: null as any, // Invalid apiConfigName + } + + // Mock activateProviderProfile to track calls + const activateProviderProfileSpy = vi + .spyOn(provider, "activateProviderProfile") + .mockResolvedValue(undefined) + + // Initialize task with history item - should not throw + await expect(provider.createTaskWithHistoryItem(historyItem)).resolves.not.toThrow() + + // Verify activateProviderProfile was not called with null + expect(activateProviderProfileSpy).not.toHaveBeenCalledWith({ name: null }) + }) + }) + + describe("Profile restoration with activateProfile failure", () => { + it("should continue task restoration even if activateProviderProfile fails", async () => { + await provider.resolveWebviewView(mockWebviewView) + + // Create a history item with saved provider profile + const historyItem: HistoryItem = { + id: "test-task-id", + number: 1, + ts: Date.now(), + task: "Test task", + tokensIn: 100, + tokensOut: 200, + cacheWrites: 0, + cacheReads: 0, + totalCost: 0.001, + apiConfigName: "failing-profile", + } + + // Mock providerSettingsManager.listConfig to return the profile + vi.spyOn(provider.providerSettingsManager, "listConfig").mockResolvedValue([ + { name: "failing-profile", id: "failing-profile-id", apiProvider: "anthropic" }, + ]) + + // Mock activateProviderProfile to throw error + vi.spyOn(provider, "activateProviderProfile").mockRejectedValue(new Error("Activation failed")) + + // Mock log to verify error is logged + const logSpy = vi.spyOn(provider, "log") + + // Initialize task with history item - should not throw even though activation fails + await expect(provider.createTaskWithHistoryItem(historyItem)).resolves.not.toThrow() + + // Verify error was logged + expect(logSpy).toHaveBeenCalledWith( + expect.stringContaining("Failed to restore API configuration 'failing-profile' for task"), + ) + }) + }) +})