diff --git a/src/__tests__/extension.spec.ts b/src/__tests__/extension.spec.ts index 89729fbbf3a..3144a717d26 100644 --- a/src/__tests__/extension.spec.ts +++ b/src/__tests__/extension.spec.ts @@ -48,17 +48,22 @@ vi.mock("@dotenvx/dotenvx", () => ({ const mockBridgeOrchestratorDisconnect = vi.fn().mockResolvedValue(undefined) +const mockCloudServiceInstance = { + off: vi.fn(), + on: vi.fn(), + getUserInfo: vi.fn().mockReturnValue(null), + isTaskSyncEnabled: vi.fn().mockReturnValue(false), + authService: { + getSessionToken: vi.fn().mockReturnValue("test-session-token"), + }, +} + vi.mock("@roo-code/cloud", () => ({ CloudService: { createInstance: vi.fn(), hasInstance: vi.fn().mockReturnValue(true), get instance() { - return { - off: vi.fn(), - on: vi.fn(), - getUserInfo: vi.fn().mockReturnValue(null), - isTaskSyncEnabled: vi.fn().mockReturnValue(false), - } + return mockCloudServiceInstance }, }, BridgeOrchestrator: { @@ -203,10 +208,12 @@ vi.mock("../core/webview/ClineProvider", async () => { }) // Mock modelCache to prevent network requests during module loading +const mockRefreshModels = vi.fn().mockResolvedValue({}) vi.mock("../api/providers/fetchers/modelCache", () => ({ flushModels: vi.fn(), getModels: vi.fn().mockResolvedValue([]), initializeModelCacheRefresh: vi.fn(), + refreshModels: mockRefreshModels, })) describe("extension.ts", () => { @@ -244,6 +251,8 @@ describe("extension.ts", () => { off: vi.fn(), on: vi.fn(), telemetryClient: null, + hasActiveSession: vi.fn().mockReturnValue(false), + authService: null, } as any }) @@ -277,6 +286,8 @@ describe("extension.ts", () => { off: vi.fn(), on: vi.fn(), telemetryClient: null, + hasActiveSession: vi.fn().mockReturnValue(false), + authService: null, } as any }) @@ -293,4 +304,87 @@ describe("extension.ts", () => { // Verify BridgeOrchestrator.disconnect was NOT called. expect(mockBridgeOrchestratorDisconnect).not.toHaveBeenCalled() }) + + describe("Roo model cache refresh on auth state change (ROO-202)", () => { + beforeEach(() => { + vi.resetModules() + mockRefreshModels.mockClear() + }) + + test("refreshModels is called with session token when auth state changes to active-session", async () => { + const mockAuthService = { + getSessionToken: vi.fn().mockReturnValue("test-session-token"), + } + + const { CloudService } = await import("@roo-code/cloud") + + vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => { + if (handlers?.["auth-state-changed"]) { + authStateChangedHandler = handlers["auth-state-changed"] + } + return { + off: vi.fn(), + on: vi.fn(), + telemetryClient: null, + authService: mockAuthService, + hasActiveSession: vi.fn().mockReturnValue(false), + } as any + }) + + vi.mocked(CloudService.hasInstance).mockReturnValue(true) + + // Activate the extension + const { activate } = await import("../extension") + await activate(mockContext) + + // Clear any calls during activation + mockRefreshModels.mockClear() + + // Trigger active-session state + await authStateChangedHandler!({ + state: "active-session" as AuthState, + previousState: "logged-out" as AuthState, + }) + + // Verify refreshModels was called with correct parameters including session token + expect(mockRefreshModels).toHaveBeenCalledWith({ + provider: "roo", + baseUrl: expect.any(String), + apiKey: "test-session-token", + }) + }) + + test("flushModels is called when auth state changes to logged-out", async () => { + const { flushModels } = await import("../api/providers/fetchers/modelCache") + const { CloudService } = await import("@roo-code/cloud") + + vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => { + if (handlers?.["auth-state-changed"]) { + authStateChangedHandler = handlers["auth-state-changed"] + } + return { + off: vi.fn(), + on: vi.fn(), + telemetryClient: null, + authService: null, + hasActiveSession: vi.fn().mockReturnValue(false), + } as any + }) + + vi.mocked(CloudService.hasInstance).mockReturnValue(true) + + // Activate the extension + const { activate } = await import("../extension") + await activate(mockContext) + + // Trigger logged-out state + await authStateChangedHandler!({ + state: "logged-out" as AuthState, + previousState: "active-session" as AuthState, + }) + + // Verify flushModels was called to clear the cache on logout + expect(flushModels).toHaveBeenCalledWith("roo", false) + }) + }) }) diff --git a/src/extension.ts b/src/extension.ts index e286891cdc3..82561fda5f8 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -40,7 +40,7 @@ import { CodeActionProvider, } from "./activate" import { initializeI18n } from "./i18n" -import { flushModels, getModels, initializeModelCacheRefresh } from "./api/providers/fetchers/modelCache" +import { flushModels, initializeModelCacheRefresh, refreshModels } from "./api/providers/fetchers/modelCache" /** * Built using https://github.com/microsoft/vscode-webview-ui-toolkit @@ -142,16 +142,22 @@ export async function activate(context: vscode.ExtensionContext) { } } - // Handle Roo models cache based on auth state + // Handle Roo models cache based on auth state (ROO-202) const handleRooModelsCache = async () => { try { - // Flush and refresh cache on auth state changes - await flushModels("roo", true) - if (data.state === "active-session") { - cloudLogger(`[authStateChangedHandler] Refreshed Roo models cache for active session`) + // Refresh with auth token to get authenticated models + const sessionToken = CloudService.hasInstance() + ? CloudService.instance.authService?.getSessionToken() + : undefined + await refreshModels({ + provider: "roo", + baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy", + apiKey: sessionToken, + }) } else { - cloudLogger(`[authStateChangedHandler] Flushed Roo models cache on logout`) + // Flush without refresh on logout + await flushModels("roo", false) } } catch (error) { cloudLogger(