Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import { describe, it, expect, vi, beforeEach } from "vitest"
import { webviewMessageHandler } from "../webviewMessageHandler"
import type { ClineProvider } from "../ClineProvider"

// Mock vscode (minimal)
// Minimal stand-in for the "vscode" module: only the surface area the
// message handler touches. The factory may reference `vi` (vi.mock is
// hoisted), so everything is built locally inside it.
vi.mock("vscode", () => {
	// Message boxes: independent spies so tests can assert per channel.
	const window = {
		showErrorMessage: vi.fn(),
		showWarningMessage: vi.fn(),
		showInformationMessage: vi.fn(),
	}

	// Workspace: no folders open; configuration reads/writes are inert spies.
	const workspace = {
		workspaceFolders: undefined,
		getConfiguration: vi.fn(() => ({ get: vi.fn(), update: vi.fn() })),
	}

	return {
		window,
		workspace,
		env: {
			clipboard: { writeText: vi.fn() },
			openExternal: vi.fn(),
		},
		commands: { executeCommand: vi.fn() },
		Uri: {
			parse: vi.fn((value: string) => ({ toString: () => value })),
			file: vi.fn((fsPath: string) => ({ fsPath })),
		},
		ConfigurationTarget: { Global: 1, Workspace: 2, WorkspaceFolder: 3 },
	}
})

// Mock the modelCache module so the handler's getModels/flushModels calls are
// observable. NOTE: vi.mock is hoisted above the `const` below, so the factory
// must close over getModelsMock lazily — the `(...args) => getModelsMock(...)`
// indirection is what keeps the reference valid at call time. Do not "simplify"
// it to `getModels: getModelsMock`.
const getModelsMock = vi.fn()
vi.mock("../../../api/providers/fetchers/modelCache", () => ({
getModels: (...args: any[]) => getModelsMock(...args),
flushModels: vi.fn(),
}))

describe("webviewMessageHandler - requestRouterModels provider filter", () => {
	let mockProvider: ClineProvider & {
		postMessageToWebview: ReturnType<typeof vi.fn>
		getState: ReturnType<typeof vi.fn>
		contextProxy: any
		log: ReturnType<typeof vi.fn>
	}

	// One distinct model per router, keyed by provider id, so assertions can
	// tell which providers were fetched purely from the returned model keys.
	const modelsByProvider: Record<string, Record<string, { contextWindow: number; supportsPromptCache: boolean }>> = {
		roo: { "roo/sonnet": { contextWindow: 8192, supportsPromptCache: false } },
		openrouter: { "openrouter/qwen2.5": { contextWindow: 32768, supportsPromptCache: false } },
		requesty: { "requesty/model": { contextWindow: 8192, supportsPromptCache: false } },
		deepinfra: { "deepinfra/model": { contextWindow: 8192, supportsPromptCache: false } },
		glama: { "glama/model": { contextWindow: 8192, supportsPromptCache: false } },
		unbound: { "unbound/model": { contextWindow: 8192, supportsPromptCache: false } },
		"vercel-ai-gateway": { "vercel/model": { contextWindow: 8192, supportsPromptCache: false } },
		"io-intelligence": { "io/model": { contextWindow: 8192, supportsPromptCache: false } },
		litellm: { "litellm/model": { contextWindow: 8192, supportsPromptCache: false } },
	}

	/** Dispatch a requestRouterModels message (optionally with values) to the handler under test. */
	const requestRouterModels = (values?: Record<string, unknown>) =>
		webviewMessageHandler(mockProvider as any, { type: "requestRouterModels", ...(values ? { values } : {}) } as any)

	/** Find the routerModels message posted back to the webview and return its payload. */
	const findRouterModelsPayload = (): { routerModels: Record<string, Record<string, any>> } => {
		const call = (mockProvider.postMessageToWebview as any).mock.calls.find(
			(c: any[]) => c[0]?.type === "routerModels",
		)
		expect(call).toBeTruthy()
		return call[0]
	}

	/** Providers getModels was asked to fetch, in call order. */
	const fetchedProviders = () => getModelsMock.mock.calls.map((c: any[]) => c[0]?.provider)

	beforeEach(() => {
		vi.clearAllMocks()

		mockProvider = {
			// Only methods used by this code path
			postMessageToWebview: vi.fn(),
			getState: vi.fn().mockResolvedValue({ apiConfiguration: {} }),
			contextProxy: {
				getValue: vi.fn(),
				setValue: vi.fn(),
				globalStorageUri: { fsPath: "/mock/storage" },
			},
			log: vi.fn(),
		} as any

		// Default mock: distinct model maps per provider so keys identify the fetch.
		getModelsMock.mockImplementation(async (options: any) => modelsByProvider[options?.provider] ?? {})
	})

	it("fetches only requested provider when values.provider is present ('roo')", async () => {
		await requestRouterModels({ provider: "roo" })

		// Should post a single routerModels message
		expect(mockProvider.postMessageToWebview).toHaveBeenCalledWith(
			expect.objectContaining({ type: "routerModels", routerModels: expect.any(Object) }),
		)

		const { routerModels } = findRouterModelsPayload()

		// Only "roo" key should be present
		expect(Object.keys(routerModels)).toEqual(["roo"])
		expect(Object.keys(routerModels.roo || {})).toContain("roo/sonnet")

		// getModels should have been called exactly once, for roo
		expect(fetchedProviders()).toEqual(["roo"])
	})

	it("defaults to aggregate fetching when no provider filter is sent", async () => {
		await requestRouterModels()

		const { routerModels } = findRouterModelsPayload()

		// Aggregate handler initializes many known routers - ensure a few expected keys exist
		expect(routerModels).toHaveProperty("openrouter")
		expect(routerModels).toHaveProperty("roo")
		expect(routerModels).toHaveProperty("requesty")
	})

	it("supports filtering another single provider ('openrouter')", async () => {
		await requestRouterModels({ provider: "openrouter" })

		const { routerModels } = findRouterModelsPayload()

		expect(Object.keys(routerModels)).toEqual(["openrouter"])
		expect(Object.keys(routerModels.openrouter || {})).toContain("openrouter/qwen2.5")

		expect(fetchedProviders()).toEqual(["openrouter"])
	})
})
77 changes: 39 additions & 38 deletions src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -757,20 +757,26 @@ export const webviewMessageHandler = async (
case "requestRouterModels":
const { apiConfiguration } = await provider.getState()

const routerModels: Record<RouterName, ModelRecord> = {
openrouter: {},
"vercel-ai-gateway": {},
huggingface: {},
litellm: {},
deepinfra: {},
"io-intelligence": {},
requesty: {},
unbound: {},
glama: {},
ollama: {},
lmstudio: {},
roo: {},
}
// Optional single provider filter from webview
const requestedProvider = message?.values?.provider
const providerFilter = requestedProvider ? toRouterName(requestedProvider) : undefined

const routerModels: Record<RouterName, ModelRecord> = providerFilter
? ({} as Record<RouterName, ModelRecord>)
: {
openrouter: {},
"vercel-ai-gateway": {},
huggingface: {},
litellm: {},
deepinfra: {},
"io-intelligence": {},
requesty: {},
unbound: {},
glama: {},
ollama: {},
lmstudio: {},
roo: {},
}

const safeGetModels = async (options: GetModelsOptions): Promise<ModelRecord> => {
try {
Expand All @@ -785,7 +791,8 @@ export const webviewMessageHandler = async (
}
}

const modelFetchPromises: { key: RouterName; options: GetModelsOptions }[] = [
// Base candidates (only those handled by this aggregate fetcher)
const candidates: { key: RouterName; options: GetModelsOptions }[] = [
{ key: "openrouter", options: { provider: "openrouter" } },
{
key: "requesty",
Expand Down Expand Up @@ -818,29 +825,30 @@ export const webviewMessageHandler = async (
},
]

// Add IO Intelligence if API key is provided.
const ioIntelligenceApiKey = apiConfiguration.ioIntelligenceApiKey

if (ioIntelligenceApiKey) {
modelFetchPromises.push({
// IO Intelligence is conditional on api key
if (apiConfiguration.ioIntelligenceApiKey) {
candidates.push({
key: "io-intelligence",
options: { provider: "io-intelligence", apiKey: ioIntelligenceApiKey },
options: { provider: "io-intelligence", apiKey: apiConfiguration.ioIntelligenceApiKey },
})
}

// Don't fetch Ollama and LM Studio models by default anymore.
// They have their own specific handlers: requestOllamaModels and requestLmStudioModels.

// LiteLLM is conditional on baseUrl+apiKey
const litellmApiKey = apiConfiguration.litellmApiKey || message?.values?.litellmApiKey
const litellmBaseUrl = apiConfiguration.litellmBaseUrl || message?.values?.litellmBaseUrl

if (litellmApiKey && litellmBaseUrl) {
modelFetchPromises.push({
candidates.push({
key: "litellm",
options: { provider: "litellm", apiKey: litellmApiKey, baseUrl: litellmBaseUrl },
})
}

// Apply single provider filter if specified
const modelFetchPromises = providerFilter
? candidates.filter(({ key }) => key === providerFilter)
: candidates

const results = await Promise.allSettled(
modelFetchPromises.map(async ({ key, options }) => {
const models = await safeGetModels(options)
Expand All @@ -854,18 +862,7 @@ export const webviewMessageHandler = async (
if (result.status === "fulfilled") {
routerModels[routerName] = result.value.models

// Ollama and LM Studio settings pages still need these events.
if (routerName === "ollama" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "ollamaModels",
ollamaModels: result.value.models,
})
} else if (routerName === "lmstudio" && Object.keys(result.value.models).length > 0) {
provider.postMessageToWebview({
type: "lmStudioModels",
lmStudioModels: result.value.models,
})
}
// Ollama and LM Studio are intentionally not fetched by this aggregate handler; their settings pages receive models via the dedicated requestOllamaModels / requestLmStudioModels handlers instead.
} else {
// Handle rejection: Post a specific error message for this provider.
const errorMessage = result.reason instanceof Error ? result.reason.message : String(result.reason)
Expand All @@ -882,7 +879,11 @@ export const webviewMessageHandler = async (
}
})

provider.postMessageToWebview({ type: "routerModels", routerModels })
provider.postMessageToWebview({
type: "routerModels",
routerModels,
values: providerFilter ? { provider: requestedProvider } : undefined,
})
break
case "requestOllamaModels": {
// Specific handler for Ollama models only.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ describe("useSelectedModel", () => {
})

describe("loading and error states", () => {
it("should return loading state when router models are loading", () => {
it("should NOT set loading when router models are loading but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: undefined,
isLoading: true,
Expand All @@ -307,10 +307,11 @@ describe("useSelectedModel", () => {
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isLoading).toBe(true)
// With static provider default (anthropic), useSelectedModel gates router fetches, so loading should be false
expect(result.current.isLoading).toBe(false)
})

it("should return loading state when open router model providers are loading", () => {
it("should NOT set loading when openrouter provider metadata is loading but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: { openrouter: {}, requesty: {}, glama: {}, unbound: {}, litellm: {}, "io-intelligence": {} },
isLoading: false,
Expand All @@ -326,10 +327,11 @@ describe("useSelectedModel", () => {
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isLoading).toBe(true)
// With static provider default (anthropic), openrouter providers are irrelevant, so loading should be false
expect(result.current.isLoading).toBe(false)
})

it("should return error state when either hook has an error", () => {
it("should NOT set error when hooks error but provider is static (anthropic)", () => {
mockUseRouterModels.mockReturnValue({
data: undefined,
isLoading: false,
Expand All @@ -345,7 +347,8 @@ describe("useSelectedModel", () => {
const wrapper = createWrapper()
const { result } = renderHook(() => useSelectedModel(), { wrapper })

expect(result.current.isError).toBe(true)
// Error from gated routerModels should not bubble for static provider default
expect(result.current.isError).toBe(false)
})
})

Expand Down
30 changes: 27 additions & 3 deletions webview-ui/src/components/ui/hooks/useRouterModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ import { ExtensionMessage } from "@roo/ExtensionMessage"

import { vscode } from "@src/utils/vscode"

const getRouterModels = async () =>
/**
 * Options accepted by the useRouterModels hook.
 * Both fields are optional; omitting them preserves the legacy behavior
 * (fetch the full aggregate model set, fetching enabled).
 */
type UseRouterModelsOptions = {
provider?: string // single provider filter (e.g. "roo"); omit to fetch all routers
enabled?: boolean // gate fetching entirely; only an explicit false disables the query
}

const getRouterModels = async (provider?: string) =>
new Promise<RouterModels>((resolve, reject) => {
const cleanup = () => {
window.removeEventListener("message", handler)
Expand All @@ -20,6 +25,14 @@ const getRouterModels = async () =>
const message: ExtensionMessage = event.data

if (message.type === "routerModels") {
const msgProvider = message?.values?.provider as string | undefined

// Verify response matches request
if (provider !== msgProvider) {
// Not our response; ignore and wait for the matching one
return
}
Comment on lines +28 to +34
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: The backend never includes the provider in the response values, so msgProvider will always be undefined. When a provider filter is used (e.g., provider="roo"), the check provider !== msgProvider evaluates to "roo" !== undefined, which is always true, causing the response to be ignored indefinitely. The promise will timeout after 10 seconds, breaking router model fetching for all dynamic providers. The backend needs to include the provider in the response: provider.postMessageToWebview({ type: "routerModels", routerModels, values: { provider: requestedProvider } })


clearTimeout(timeout)
cleanup()

Expand All @@ -32,7 +45,18 @@ const getRouterModels = async () =>
}

window.addEventListener("message", handler)
vscode.postMessage({ type: "requestRouterModels" })
if (provider) {
vscode.postMessage({ type: "requestRouterModels", values: { provider } })
} else {
vscode.postMessage({ type: "requestRouterModels" })
}
})

export const useRouterModels = () => useQuery({ queryKey: ["routerModels"], queryFn: getRouterModels })
/**
 * React-Query hook that asks the extension host for router model lists.
 *
 * @param opts.provider - optional single-provider filter (e.g. "roo"); when
 *   omitted the extension fetches the full aggregate set.
 * @param opts.enabled - pass false to suppress the fetch entirely.
 */
export const useRouterModels = (opts: UseRouterModelsOptions = {}) => {
	// Normalize a falsy filter ("" / undefined) to "no filter".
	const providerFilter = opts.provider ? opts.provider : undefined

	return useQuery({
		// Distinct cache entries per filter so a filtered result never
		// shadows the aggregate one (and vice versa).
		queryKey: ["routerModels", providerFilter ?? "all"],
		queryFn: () => getRouterModels(providerFilter),
		enabled: opts.enabled !== false,
	})
}
Loading