diff --git a/.changeset/dynamic-openai-model-fetching.md b/.changeset/dynamic-openai-model-fetching.md new file mode 100644 index 00000000000..2501d85129a --- /dev/null +++ b/.changeset/dynamic-openai-model-fetching.md @@ -0,0 +1,6 @@ +--- +"kilo-code": patch +"@kilocode/types": patch +--- + +Implement dynamic model fetching for OpenAI-compatible providers diff --git a/packages/types/src/vscode-extension-host.ts b/packages/types/src/vscode-extension-host.ts index bc118a4ed1d..e799a376601 100644 --- a/packages/types/src/vscode-extension-host.ts +++ b/packages/types/src/vscode-extension-host.ts @@ -659,6 +659,7 @@ export type ExtensionState = Pick< debug?: boolean speechToTextStatus?: { available: boolean; reason?: "openaiKeyMissing" | "ffmpegNotInstalled" } // kilocode_change: Speech-to-text availability status with failure reason appendSystemPrompt?: string // kilocode_change: Custom text to append to system prompt (CLI only) + openAiModels?: string[] } export interface Command { diff --git a/webview-ui/src/components/kilocode/hooks/__tests__/dynamic-openai-models.spec.ts b/webview-ui/src/components/kilocode/hooks/__tests__/dynamic-openai-models.spec.ts new file mode 100644 index 00000000000..d9c57157c6f --- /dev/null +++ b/webview-ui/src/components/kilocode/hooks/__tests__/dynamic-openai-models.spec.ts @@ -0,0 +1,80 @@ +import { ModelInfo } from "@roo-code/types" +import { RouterModels } from "@roo/api" +import { getModelsByProvider } from "../useProviderModels" + +describe("PR #5562: Dynamic OpenAI model fetching on front page", () => { + const testModel: ModelInfo = { + maxTokens: 4096, + contextWindow: 8192, + supportsImages: false, + supportsPromptCache: false, + inputPrice: 0.1, + outputPrice: 0.2, + description: "Test model", + } + + const routerModels: RouterModels = { + openrouter: { "test-model": testModel }, + requesty: { "test-model": testModel }, + glama: { "test-model": testModel }, + unbound: { "test-model": testModel }, + litellm: { "test-model": testModel }, + kilocode: { "test-model": testModel }, + "nano-gpt": { "test-model": testModel }, + ollama: { "test-model": testModel }, + lmstudio: { "test-model": testModel }, + "io-intelligence": { "test-model": testModel }, + deepinfra: { "test-model": testModel }, + "vercel-ai-gateway": { "test-model": testModel }, + huggingface: { "test-model": testModel }, + gemini: { "test-model": testModel }, + ovhcloud: { "test-model": testModel }, + chutes: { "test-model": testModel }, + "sap-ai-core": { "test-model": testModel }, + synthetic: { "test-model": testModel }, + inception: { "test-model": testModel }, + roo: { "test-model": testModel }, + } + + const baseArgs = { + routerModels, + kilocodeDefaultModel: "test-model", + options: { isChina: false }, + } + + it("returns dynamically fetched models when openAiModels is provided", () => { + const result = getModelsByProvider({ + ...baseArgs, + provider: "openai", + openAiModels: ["gpt-4o", "gpt-4o-mini", "o1-preview"], + }) + + expect(Object.keys(result.models)).toEqual(["gpt-4o", "gpt-4o-mini", "o1-preview"]) + expect(result.defaultModel).toBe("gpt-4o") + // Each model should have sane defaults (128K context, supports images) + expect(result.models["gpt-4o"].contextWindow).toBe(128_000) + expect(result.models["gpt-4o"].supportsImages).toBe(true) + }) + + it("returns empty models when openAiModels is not provided", () => { + const result = getModelsByProvider({ + ...baseArgs, + provider: "openai", + }) + + expect(Object.keys(result.models)).toHaveLength(0) + expect(result.defaultModel).toBe("") + }) + + it("handles empty openAiModels array gracefully", () => { + const result = getModelsByProvider({ + ...baseArgs, + provider: "openai", + openAiModels: [], + }) + + // Empty array is truthy but has no models + expect(Object.keys(result.models)).toHaveLength(0) + expect(result.defaultModel).toBe("") + }) +}) diff --git a/webview-ui/src/components/kilocode/hooks/useProviderModels.ts b/webview-ui/src/components/kilocode/hooks/useProviderModels.ts index 1674d7791d5..138b7fae200 100644 --- a/webview-ui/src/components/kilocode/hooks/useProviderModels.ts +++ b/webview-ui/src/components/kilocode/hooks/useProviderModels.ts @@ -59,6 +59,7 @@ import { internationalZAiDefaultModelId, mainlandZAiModels, mainlandZAiDefaultModelId, + openAiModelInfoSaneDefaults, } from "@roo-code/types" import type { ModelRecord, RouterModels } from "@roo/api" import { useRouterModels } from "../../ui/hooks/useRouterModels" @@ -73,11 +74,13 @@ export const getModelsByProvider = ({ provider, routerModels, kilocodeDefaultModel, + openAiModels, options = { isChina: false }, }: { provider: ProviderName routerModels: RouterModels kilocodeDefaultModel: string + openAiModels?: string[] options: { isChina?: boolean } }): { models: ModelRecord; defaultModel: string } => { switch (provider) { @@ -181,7 +184,12 @@ export const getModelsByProvider = ({ } } case "openai": { - // TODO(catrielmuller): Support the fetch here + if (openAiModels) { + return { + models: Object.fromEntries(openAiModels.map((model) => [model, openAiModelInfoSaneDefaults])), + defaultModel: openAiModels[0] || "", + } + } return { models: {}, defaultModel: "", @@ -351,7 +359,7 @@ export const getOptionsForProvider = (provider: ProviderName, apiConfiguration?: export const useProviderModels = (apiConfiguration?: ProviderSettings) => { const provider = apiConfiguration?.apiProvider || "anthropic" - const { kilocodeDefaultModel } = useExtensionState() + const { kilocodeDefaultModel, openAiModels } = useExtensionState() const routerModels = useRouterModels({ openRouterBaseUrl: apiConfiguration?.openRouterBaseUrl, @@ -375,6 +383,7 @@ export const useProviderModels = (apiConfiguration?: ProviderSettings) => { provider, routerModels: routerModels.data, kilocodeDefaultModel, + openAiModels, options, }) : FALLBACK_MODELS diff --git a/webview-ui/src/context/ExtensionStateContext.tsx b/webview-ui/src/context/ExtensionStateContext.tsx index c0e62b93a29..dd24b067654 100644 --- a/webview-ui/src/context/ExtensionStateContext.tsx +++ b/webview-ui/src/context/ExtensionStateContext.tsx @@ -1,4 +1,5 @@ import React, { createContext, useCallback, useContext, useEffect, useState } from "react" +import { useDebounce } from "react-use" import { type ProviderSettings, @@ -220,6 +221,7 @@ export interface ExtensionStateContextType extends ExtensionState { setIncludeCurrentTime: (value: boolean) => void includeCurrentCost?: boolean setIncludeCurrentCost: (value: boolean) => void + openAiModels?: string[] } export const ExtensionStateContext = createContext(undefined) @@ -360,6 +362,7 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode openRouterImageGenerationSelectedModel: "", includeCurrentTime: true, includeCurrentCost: true, + openAiModels: [], }) const [didHydrateState, setDidHydrateState] = useState(false) @@ -518,6 +521,10 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode setExtensionRouterModels(message.routerModels) break } + case "openAiModels": { + setState((prevState) => ({ ...prevState, openAiModels: message.openAiModels ?? [] })) + break + } case "marketplaceData": { if (message.marketplaceItems !== undefined) { setMarketplaceItems(message.marketplaceItems) @@ -554,6 +561,38 @@ export const ExtensionStateContextProvider: React.FC<{ children: React.ReactNode setPrevCloudIsAuthenticated(currentAuth) }, [state.cloudIsAuthenticated, prevCloudIsAuthenticated, state.apiConfiguration?.apiProvider]) + // Fetch OpenAI models on startup or when configuration changes + useDebounce( + () => { + if (!didHydrateState) { + return + } + + const { apiProvider, openAiBaseUrl, openAiApiKey, openAiHeaders } = state.apiConfiguration || {} + + if (apiProvider === "openai" || apiProvider === "openai-responses") { + if (openAiBaseUrl) { + vscode.postMessage({ + type: "requestOpenAiModels", + values: { + baseUrl: openAiBaseUrl, + apiKey: openAiApiKey, + openAiHeaders, + }, + }) + } + } + }, + 500, + [ + didHydrateState, + state.apiConfiguration?.apiProvider, + state.apiConfiguration?.openAiBaseUrl, + state.apiConfiguration?.openAiApiKey, + state.apiConfiguration?.openAiHeaders, + ], + ) + const contextValue: ExtensionStateContextType = { ...state, reasoningBlockCollapsed: state.reasoningBlockCollapsed ?? true,