Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 100 additions & 6 deletions src/__tests__/extension.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,22 @@ vi.mock("@dotenvx/dotenvx", () => ({

const mockBridgeOrchestratorDisconnect = vi.fn().mockResolvedValue(undefined)

const mockCloudServiceInstance = {
off: vi.fn(),
on: vi.fn(),
getUserInfo: vi.fn().mockReturnValue(null),
isTaskSyncEnabled: vi.fn().mockReturnValue(false),
authService: {
getSessionToken: vi.fn().mockReturnValue("test-session-token"),
},
}

vi.mock("@roo-code/cloud", () => ({
CloudService: {
createInstance: vi.fn(),
hasInstance: vi.fn().mockReturnValue(true),
get instance() {
return {
off: vi.fn(),
on: vi.fn(),
getUserInfo: vi.fn().mockReturnValue(null),
isTaskSyncEnabled: vi.fn().mockReturnValue(false),
}
return mockCloudServiceInstance
},
},
BridgeOrchestrator: {
Expand Down Expand Up @@ -203,10 +208,12 @@ vi.mock("../core/webview/ClineProvider", async () => {
})

// Mock modelCache to prevent network requests during module loading
const mockRefreshModels = vi.fn().mockResolvedValue({})
vi.mock("../api/providers/fetchers/modelCache", () => ({
flushModels: vi.fn(),
getModels: vi.fn().mockResolvedValue([]),
initializeModelCacheRefresh: vi.fn(),
refreshModels: mockRefreshModels,
}))

describe("extension.ts", () => {
Expand Down Expand Up @@ -244,6 +251,8 @@ describe("extension.ts", () => {
off: vi.fn(),
on: vi.fn(),
telemetryClient: null,
hasActiveSession: vi.fn().mockReturnValue(false),
authService: null,
} as any
})

Expand Down Expand Up @@ -277,6 +286,8 @@ describe("extension.ts", () => {
off: vi.fn(),
on: vi.fn(),
telemetryClient: null,
hasActiveSession: vi.fn().mockReturnValue(false),
authService: null,
} as any
})

Expand All @@ -293,4 +304,87 @@ describe("extension.ts", () => {
// Verify BridgeOrchestrator.disconnect was NOT called.
expect(mockBridgeOrchestratorDisconnect).not.toHaveBeenCalled()
})

describe("Roo model cache refresh on auth state change (ROO-202)", () => {
beforeEach(() => {
vi.resetModules()
mockRefreshModels.mockClear()
})

test("refreshModels is called with session token when auth state changes to active-session", async () => {
const mockAuthService = {
getSessionToken: vi.fn().mockReturnValue("test-session-token"),
}

const { CloudService } = await import("@roo-code/cloud")

vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => {
if (handlers?.["auth-state-changed"]) {
authStateChangedHandler = handlers["auth-state-changed"]
}
return {
off: vi.fn(),
on: vi.fn(),
telemetryClient: null,
authService: mockAuthService,
hasActiveSession: vi.fn().mockReturnValue(false),
} as any
})

vi.mocked(CloudService.hasInstance).mockReturnValue(true)

// Activate the extension
const { activate } = await import("../extension")
await activate(mockContext)

// Clear any calls during activation
mockRefreshModels.mockClear()

// Trigger active-session state
await authStateChangedHandler!({
state: "active-session" as AuthState,
previousState: "logged-out" as AuthState,
})

// Verify refreshModels was called with correct parameters including session token
expect(mockRefreshModels).toHaveBeenCalledWith({
provider: "roo",
baseUrl: expect.any(String),
apiKey: "test-session-token",
})
})

test("flushModels is called when auth state changes to logged-out", async () => {
const { flushModels } = await import("../api/providers/fetchers/modelCache")
const { CloudService } = await import("@roo-code/cloud")

vi.mocked(CloudService.createInstance).mockImplementation(async (_context, _logger, handlers) => {
if (handlers?.["auth-state-changed"]) {
authStateChangedHandler = handlers["auth-state-changed"]
}
return {
off: vi.fn(),
on: vi.fn(),
telemetryClient: null,
authService: null,
hasActiveSession: vi.fn().mockReturnValue(false),
} as any
})

vi.mocked(CloudService.hasInstance).mockReturnValue(true)

// Activate the extension
const { activate } = await import("../extension")
await activate(mockContext)

// Trigger logged-out state
await authStateChangedHandler!({
state: "logged-out" as AuthState,
previousState: "active-session" as AuthState,
})

// Verify flushModels was called to clear the cache on logout
expect(flushModels).toHaveBeenCalledWith("roo", false)
})
})
})
20 changes: 13 additions & 7 deletions src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import {
CodeActionProvider,
} from "./activate"
import { initializeI18n } from "./i18n"
import { flushModels, getModels, initializeModelCacheRefresh } from "./api/providers/fetchers/modelCache"
import { flushModels, initializeModelCacheRefresh, refreshModels } from "./api/providers/fetchers/modelCache"

/**
* Built using https://github.com/microsoft/vscode-webview-ui-toolkit
Expand Down Expand Up @@ -142,16 +142,22 @@ export async function activate(context: vscode.ExtensionContext) {
}
}

// Handle Roo models cache based on auth state
// Handle Roo models cache based on auth state (ROO-202)
const handleRooModelsCache = async () => {
try {
// Flush and refresh cache on auth state changes
await flushModels("roo", true)

if (data.state === "active-session") {
cloudLogger(`[authStateChangedHandler] Refreshed Roo models cache for active session`)
// Refresh with auth token to get authenticated models
const sessionToken = CloudService.hasInstance()
? CloudService.instance.authService?.getSessionToken()
: undefined
await refreshModels({
provider: "roo",
baseUrl: process.env.ROO_CODE_PROVIDER_URL ?? "https://api.roocode.com/proxy",
apiKey: sessionToken,
})
} else {
cloudLogger(`[authStateChangedHandler] Flushed Roo models cache on logout`)
// Flush without refresh on logout
await flushModels("roo", false)
}
} catch (error) {
cloudLogger(
Expand Down
Loading