diff --git a/src/api/providers/fetchers/__tests__/ollama.test.ts b/src/api/providers/fetchers/__tests__/ollama.test.ts index 5eb7c7686656..b6ddbf91a40e 100644 --- a/src/api/providers/fetchers/__tests__/ollama.test.ts +++ b/src/api/providers/fetchers/__tests__/ollama.test.ts @@ -108,10 +108,10 @@ describe("Ollama Fetcher", () => { const result = await getOllamaModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`, { headers: {} }) expect(mockedAxios.post).toHaveBeenCalledTimes(1) - expect(mockedAxios.post).toHaveBeenCalledWith(`${baseUrl}/api/show`, { model: modelName }) + expect(mockedAxios.post).toHaveBeenCalledWith(`${baseUrl}/api/show`, { model: modelName }, { headers: {} }) expect(typeof result).toBe("object") expect(result).not.toBeInstanceOf(Array) @@ -130,7 +130,7 @@ describe("Ollama Fetcher", () => { const result = await getOllamaModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`, { headers: {} }) expect(mockedAxios.post).not.toHaveBeenCalled() expect(result).toEqual({}) }) @@ -146,7 +146,7 @@ describe("Ollama Fetcher", () => { const result = await getOllamaModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`, { headers: {} }) expect(mockedAxios.post).not.toHaveBeenCalled() expect(consoleInfoSpy).toHaveBeenCalledWith(`Failed connecting to Ollama at ${baseUrl}`) expect(result).toEqual({}) @@ -204,10 +204,10 @@ describe("Ollama Fetcher", () => { const result = await getOllamaModels(baseUrl) expect(mockedAxios.get).toHaveBeenCalledTimes(1) - expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`) + 
expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`, { headers: {} }) expect(mockedAxios.post).toHaveBeenCalledTimes(1) - expect(mockedAxios.post).toHaveBeenCalledWith(`${baseUrl}/api/show`, { model: modelName }) + expect(mockedAxios.post).toHaveBeenCalledWith(`${baseUrl}/api/show`, { model: modelName }, { headers: {} }) expect(typeof result).toBe("object") expect(result).not.toBeInstanceOf(Array) @@ -217,5 +217,73 @@ describe("Ollama Fetcher", () => { // Verify the model was parsed correctly despite null families expect(result[modelName].description).toBe("Family: llama, Context: 4096, Size: 23.6B") }) + + it("should include Authorization header when API key is provided", async () => { + const baseUrl = "http://localhost:11434" + const apiKey = "test-api-key-123" + const modelName = "test-model:latest" + + const mockApiTagsResponse = { + models: [ + { + name: modelName, + model: modelName, + modified_at: "2025-06-03T09:23:22.610222878-04:00", + size: 14333928010, + digest: "6a5f0c01d2c96c687d79e32fdd25b87087feb376bf9838f854d10be8cf3c10a5", + details: { + family: "llama", + families: ["llama"], + format: "gguf", + parameter_size: "23.6B", + parent_model: "", + quantization_level: "Q4_K_M", + }, + }, + ], + } + const mockApiShowResponse = { + license: "Mock License", + modelfile: "FROM /path/to/blob\nTEMPLATE {{ .Prompt }}", + parameters: "num_ctx 4096\nstop_token ", + template: "{{ .System }}USER: {{ .Prompt }}ASSISTANT:", + modified_at: "2025-06-03T09:23:22.610222878-04:00", + details: { + parent_model: "", + format: "gguf", + family: "llama", + families: ["llama"], + parameter_size: "23.6B", + quantization_level: "Q4_K_M", + }, + model_info: { + "ollama.context_length": 4096, + "some.other.info": "value", + }, + capabilities: ["completion"], + } + + mockedAxios.get.mockResolvedValueOnce({ data: mockApiTagsResponse }) + mockedAxios.post.mockResolvedValueOnce({ data: mockApiShowResponse }) + + const result = await getOllamaModels(baseUrl, apiKey) + + 
const expectedHeaders = { Authorization: `Bearer ${apiKey}` } + + expect(mockedAxios.get).toHaveBeenCalledTimes(1) + expect(mockedAxios.get).toHaveBeenCalledWith(`${baseUrl}/api/tags`, { headers: expectedHeaders }) + + expect(mockedAxios.post).toHaveBeenCalledTimes(1) + expect(mockedAxios.post).toHaveBeenCalledWith( + `${baseUrl}/api/show`, + { model: modelName }, + { headers: expectedHeaders }, + ) + + expect(typeof result).toBe("object") + expect(result).not.toBeInstanceOf(Array) + expect(Object.keys(result).length).toBe(1) + expect(result[modelName]).toBeDefined() + }) }) }) diff --git a/src/api/providers/fetchers/modelCache.ts b/src/api/providers/fetchers/modelCache.ts index a91cdaf99422..ce7582fabca9 100644 --- a/src/api/providers/fetchers/modelCache.ts +++ b/src/api/providers/fetchers/modelCache.ts @@ -75,7 +75,7 @@ export const getModels = async (options: GetModelsOptions): Promise<ModelRecord> models = await getLiteLLMModels(options.apiKey, options.baseUrl) break case "ollama": - models = await getOllamaModels(options.baseUrl) + models = await getOllamaModels(options.baseUrl, options.apiKey) break case "lmstudio": models = await getLMStudioModels(options.baseUrl) diff --git a/src/api/providers/fetchers/ollama.ts b/src/api/providers/fetchers/ollama.ts index 8e1e3f7f072e..a679a9027d44 100644 --- a/src/api/providers/fetchers/ollama.ts +++ b/src/api/providers/fetchers/ollama.ts @@ -54,7 +54,10 @@ export const parseOllamaModel = (rawModel: OllamaModelInfoResponse): ModelInfo = return modelInfo } -export async function getOllamaModels(baseUrl = "http://localhost:11434"): Promise<Record<string, ModelInfo>> { +export async function getOllamaModels( + baseUrl = "http://localhost:11434", + apiKey?: string, +): Promise<Record<string, ModelInfo>> { const models: Record<string, ModelInfo> = {} // clearing the input can leave an empty string; use the default in that case @@ -65,7 +68,13 @@ export async function getOllamaModels(baseUrl = "http://localhost:11434"): Promi return models } - const response = await axios.get(`${baseUrl}/api/tags`) + // 
Prepare headers with optional API key + const headers: Record<string, string> = {} + if (apiKey) { + headers["Authorization"] = `Bearer ${apiKey}` + } + + const response = await axios.get(`${baseUrl}/api/tags`, { headers }) const parsedResponse = OllamaModelsResponseSchema.safeParse(response.data) let modelInfoPromises = [] @@ -73,9 +82,13 @@ for (const ollamaModel of parsedResponse.data.models) { modelInfoPromises.push( axios - .post(`${baseUrl}/api/show`, { - model: ollamaModel.model, - }) + .post( + `${baseUrl}/api/show`, + { + model: ollamaModel.model, + }, + { headers }, + ) .then((ollamaModelInfo) => { models[ollamaModel.name] = parseOllamaModel(ollamaModelInfo.data) }), diff --git a/src/api/providers/native-ollama.ts b/src/api/providers/native-ollama.ts index 06c1c33d2368..80231540e8e8 100644 --- a/src/api/providers/native-ollama.ts +++ b/src/api/providers/native-ollama.ts @@ -256,7 +256,7 @@ export class NativeOllamaHandler extends BaseProvider implements SingleCompletio } async fetchModel() { - this.models = await getOllamaModels(this.options.ollamaBaseUrl) + this.models = await getOllamaModels(this.options.ollamaBaseUrl, this.options.ollamaApiKey) return this.getModel() } diff --git a/src/core/webview/webviewMessageHandler.ts b/src/core/webview/webviewMessageHandler.ts index a6e8e73a6aef..a3d40b0d821f 100644 --- a/src/core/webview/webviewMessageHandler.ts +++ b/src/core/webview/webviewMessageHandler.ts @@ -887,6 +887,7 @@ export const webviewMessageHandler = async ( const ollamaModels = await getModels({ provider: "ollama", baseUrl: ollamaApiConfig.ollamaBaseUrl, + apiKey: ollamaApiConfig.ollamaApiKey, }) if (Object.keys(ollamaModels).length > 0) { diff --git a/src/shared/api.ts b/src/shared/api.ts index eb3ae124a825..2a3275957498 100644 --- a/src/shared/api.ts +++ b/src/shared/api.ts @@ -150,7 +150,7 @@ export type GetModelsOptions = | { provider: "requesty"; apiKey?: string; baseUrl?: 
string } | { provider: "unbound"; apiKey?: string } | { provider: "litellm"; apiKey: string; baseUrl: string } - | { provider: "ollama"; baseUrl?: string } + | { provider: "ollama"; baseUrl?: string; apiKey?: string } | { provider: "lmstudio"; baseUrl?: string } | { provider: "deepinfra"; apiKey?: string; baseUrl?: string } | { provider: "io-intelligence"; apiKey: string }