diff --git a/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx b/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx index 4da6f9a0de..ecf5a78022 100644 --- a/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx +++ b/archon-ui-main/src/components/settings/OllamaConfigurationPanel.tsx @@ -6,6 +6,7 @@ import { Badge } from '../ui/Badge'; import { useToast } from '../../features/shared/hooks/useToast'; import { cn } from '../../lib/utils'; import { credentialsService, OllamaInstance } from '../../services/credentialsService'; +import { ollamaService } from '../../services/ollamaService'; import { OllamaModelDiscoveryModal } from './OllamaModelDiscoveryModal'; import type { OllamaInstance as OllamaInstanceType } from './types/OllamaTypes'; @@ -104,61 +105,23 @@ const OllamaConfigurationPanel: React.FC = ({ } }; - // Test connection to an Ollama instance with retry logic + // Test connection to an Ollama instance using ollamaService with smart retry logic const testConnection = async (baseUrl: string, retryCount = 3): Promise => { - const maxRetries = retryCount; - let lastError: Error | null = null; - - for (let attempt = 1; attempt <= maxRetries; attempt++) { - try { - const response = await fetch('/api/providers/validate', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - provider: 'ollama', - base_url: baseUrl - }) - }); - - if (!response.ok) { - throw new Error(`HTTP ${response.status}: ${response.statusText}`); - } - - const data = await response.json(); - - const result = { - isHealthy: data.health_status?.is_available || false, - responseTimeMs: data.health_status?.response_time_ms, - modelsAvailable: data.health_status?.models_available, - error: data.health_status?.error_message - }; - - // If successful, return immediately - if (result.isHealthy) { - return result; - } - - // If not healthy but we got a valid response, still return (but might retry) - lastError = new Error(result.error || 'Instance not available'); - - } catch (error) { - lastError = error instanceof Error ? error : new Error('Unknown error'); - } + try { + const result = await ollamaService.testConnection(baseUrl, retryCount); - // If this wasn't the last attempt, wait before retrying - if (attempt < maxRetries) { - const delayMs = Math.pow(2, attempt - 1) * 1000; // Exponential backoff: 1s, 2s, 4s - await new Promise(resolve => setTimeout(resolve, delayMs)); - } + return { + isHealthy: result.isHealthy, + responseTimeMs: result.responseTime, + modelsAvailable: undefined, // Not available from the simple health check + error: result.error + }; + } catch (error) { + return { + isHealthy: false, + error: error instanceof Error ? error.message : 'Connection test failed' + }; } - - // All retries failed, return error result - return { - isHealthy: false, - error: lastError?.message || 'Connection failed after retries' - }; }; // Handle connection test for a specific instance diff --git a/archon-ui-main/src/features/knowledge/services/knowledgeService.ts b/archon-ui-main/src/features/knowledge/services/knowledgeService.ts index cfab3f7f92..5016590385 100644 --- a/archon-ui-main/src/features/knowledge/services/knowledgeService.ts +++ b/archon-ui-main/src/features/knowledge/services/knowledgeService.ts @@ -117,27 +117,15 @@ export const knowledgeService = { formData.append("tags", JSON.stringify(metadata.tags)); } - // Use fetch directly for file upload (FormData doesn't work well with our ETag wrapper) - // In test environment, we need absolute URLs - let uploadUrl = "/api/documents/upload"; - if (typeof process !== "undefined" && process.env?.NODE_ENV === "test") { - const testHost = process.env?.VITE_HOST || "localhost"; - const testPort = process.env?.ARCHON_SERVER_PORT || "8181"; - uploadUrl = `http://${testHost}:${testPort}${uploadUrl}`; - } - - const response = await fetch(uploadUrl, { - method: "POST", - body: formData, - signal: AbortSignal.timeout(30000), // 30 second timeout for file uploads - }); - - if (!response.ok) { - const err = await response.json().catch(() => ({})); - throw new APIServiceError(err.error || `HTTP ${response.status}`, "HTTP_ERROR", response.status); - } - - return response.json(); + // Use API service with proper FormData handling and timeout + return callAPIWithETag<{ success: boolean; progressId: string; message: string; filename: string }>( + "/api/documents/upload", + { + method: "POST", + body: formData, + signal: AbortSignal.timeout(30000), // 30 second timeout for file uploads + }, + ); }, /** diff --git a/archon-ui-main/src/features/shared/api/apiClient.ts b/archon-ui-main/src/features/shared/api/apiClient.ts index 5d7d47137f..8bc6aa7fae 100644 --- a/archon-ui-main/src/features/shared/api/apiClient.ts +++ b/archon-ui-main/src/features/shared/api/apiClient.ts @@ -34,11 +34,14 @@ function buildFullUrl(cleanEndpoint: string): string { } /** - * Simple API call function for JSON APIs + * Simple API call function for JSON APIs and FormData uploads * Browser automatically handles ETags/304s through its HTTP cache * - * NOTE: This wrapper is designed for JSON-only API calls. - * For file uploads or FormData requests, use fetch() directly. + * Features: + * - Automatic FormData detection (avoids setting Content-Type header) + * - JSON API support with proper Content-Type headers + * - Built-in timeout and error handling + * - ETag/304 optimization through browser HTTP cache */ export async function callAPIWithETag(endpoint: string, options: RequestInit = {}): Promise { try { @@ -48,24 +51,31 @@ export async function callAPIWithETag(endpoint: string, options: Re // Construct the full URL const fullUrl = buildFullUrl(cleanEndpoint); - // Build headers - only set Content-Type for requests with a body + // Detect FormData to avoid setting Content-Type (browser sets multipart/form-data with boundary) + // Guard against environments where FormData is undefined (Node.js, Jest, iframes) + const isFormData = typeof FormData !== "undefined" && options.body instanceof FormData; + + // Build headers - normalize and handle Content-Type properly for FormData // NOTE: We do NOT add If-None-Match headers; the browser handles ETag revalidation automatically - // - // Currently assumes headers are passed as plain objects (Record) - // which works for all our current usage. The API doesn't require Accept headers - // since it always returns JSON, and we only set Content-Type when sending data. - const headers: Record = { - ...((options.headers as Record) || {}), - }; - - // Only set Content-Type for requests that have a body (POST, PUT, PATCH, etc.) - // GET and DELETE requests should not have Content-Type header - const method = options.method?.toUpperCase() || 'GET'; - const hasBody = options.body !== undefined && options.body !== null; - if (hasBody && !headers['Content-Type']) { - headers['Content-Type'] = 'application/json'; + // Normalize headers to support Headers instances, [string, string][] tuples, and plain objects + const headersObj = new Headers(options.headers as HeadersInit | undefined); + + // Only set Accept header if not already provided by caller (preserves caller-provided Accept headers) + if (!headersObj.has("Accept")) { + headersObj.set("Accept", "application/json"); + } + + if (isFormData) { + // For FormData, remove any Content-Type header to let browser set multipart/form-data with boundary + headersObj.delete("Content-Type"); + } else if (!headersObj.has("Content-Type") && options.body != null) { + // Only set Content-Type if not already provided and body is present + headersObj.set("Content-Type", "application/json"); } + // Preserve Headers instance instead of converting to Record + const headers = headersObj; + // Make the request with timeout // NOTE: Increased to 20s due to database performance issues with large DELETE operations // Root cause: Sequential scan on crawled_pages table when deleting sources with 7K+ rows @@ -104,15 +114,32 @@ export async function callAPIWithETag(endpoint: string, options: Re return undefined as T; } - // Parse response data - const result = await response.json(); + // Check content type before parsing as JSON + const contentType = response.headers.get("content-type")?.toLowerCase() ?? ""; + if (contentType.includes("application/json") || contentType.includes("+json")) { + // Parse JSON response + const result = await response.json(); + if (result && typeof result === "object" && "error" in result && result.error) { + throw new APIServiceError(result.error as string, "API_ERROR", response.status); + } + return result as T; + } - // Check for API errors - if (result.error) { - throw new APIServiceError(result.error, "API_ERROR", response.status); + // Handle binary responses (PDFs, images, octet-stream) + if ( + contentType.includes("application/octet-stream") || + contentType.includes("application/pdf") || + contentType.startsWith("image/") || + contentType.includes("video/") || + contentType.includes("audio/") + ) { + const blob = await response.blob(); + return blob as unknown as T; } - return result as T; + // Handle non-JSON or empty body responses + const text = await response.text().catch(() => ""); + return text ? (text as unknown as T) : (undefined as T); } catch (error) { if (error instanceof APIServiceError) { throw error; diff --git a/archon-ui-main/src/features/shared/api/tests/apiClient.test.ts b/archon-ui-main/src/features/shared/api/tests/apiClient.test.ts index bfe9137516..0e4ab021f2 100644 --- a/archon-ui-main/src/features/shared/api/tests/apiClient.test.ts +++ b/archon-ui-main/src/features/shared/api/tests/apiClient.test.ts @@ -56,7 +56,8 @@ describe("apiClient (callAPIWithETag)", () => { expect.stringContaining("/test-endpoint"), expect.objectContaining({ headers: expect.objectContaining({ - "Content-Type": "application/json", + "content-type": "application/json", + "accept": "application/json", }), }), ); @@ -168,9 +169,10 @@ describe("apiClient (callAPIWithETag)", () => { expect.any(String), expect.objectContaining({ headers: expect.objectContaining({ - "Content-Type": "application/json", - Authorization: "Bearer token123", - "Custom-Header": "custom-value", + "content-type": "application/json", + "accept": "application/json", + "authorization": "Bearer token123", + "custom-header": "custom-value", }), }), ); @@ -409,4 +411,190 @@ describe("apiClient (callAPIWithETag)", () => { expect(result2.version).toBeGreaterThan(result1.version); }); }); + + describe("FormData Support", () => { + it("should detect FormData and omit Content-Type header", async () => { + const mockData = { success: true, fileId: "123" }; + const mockResponse = { + ok: true, + status: 200, + json: () => Promise.resolve(mockData), + headers: new Headers(), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse); + + const formData = new FormData(); + formData.append("file", new File(["test content"], "test.txt", { type: "text/plain" })); + formData.append("metadata", "test metadata"); + + const result = await callAPIWithETag("/api/upload", { + method: "POST", + body: formData, + }); + + expect(result).toEqual(mockData); + expect(global.fetch).toHaveBeenCalledWith( + expect.stringContaining("/upload"), + expect.objectContaining({ + method: "POST", + body: formData, + headers: expect.objectContaining({ + "accept": "application/json", + // Content-Type should NOT be present for FormData + }), + }), + ); + + // Verify Content-Type is NOT set (browser sets multipart/form-data with boundary) + const [, options] = (global.fetch as any).mock.calls[0]; + expect(options.headers).not.toHaveProperty("content-type"); + }); + + it("should still set Content-Type for JSON requests", async () => { + const mockData = { success: true }; + const mockResponse = { + ok: true, + status: 200, + json: () => Promise.resolve(mockData), + headers: new Headers(), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse); + + const jsonPayload = { name: "test", type: "document" }; + + const result = await callAPIWithETag("/api/create", { + method: "POST", + body: JSON.stringify(jsonPayload), + }); + + expect(result).toEqual(mockData); + expect(global.fetch).toHaveBeenCalledWith( + expect.stringContaining("/create"), + expect.objectContaining({ + headers: expect.objectContaining({ + "content-type": "application/json", + "accept": "application/json", + }), + }), + ); + }); + + it("should handle FormData upload errors properly", async () => { + const errorResponse = { + ok: false, + status: 413, + text: () => Promise.resolve(JSON.stringify({ detail: "File too large" })), + headers: new Headers(), + }; + + global.fetch = vi.fn().mockResolvedValue(errorResponse); + + const formData = new FormData(); + formData.append("file", new File(["large file content"], "large.txt")); + + await expect(callAPIWithETag("/api/upload", { + method: "POST", + body: formData, + })).rejects.toThrow("File too large"); + }); + + it("should preserve custom headers with FormData", async () => { + const mockData = { uploaded: true }; + const mockResponse = { + ok: true, + status: 200, + json: () => Promise.resolve(mockData), + headers: new Headers(), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse); + + const formData = new FormData(); + formData.append("file", new File(["test"], "test.txt")); + + await callAPIWithETag("/api/upload", { + method: "POST", + body: formData, + headers: { + Authorization: "Bearer token123", + "X-Custom-Header": "custom-value", + }, + }); + + expect(global.fetch).toHaveBeenCalledWith( + expect.any(String), + expect.objectContaining({ + headers: expect.objectContaining({ + "accept": "application/json", + "authorization": "Bearer token123", + "x-custom-header": "custom-value", + // Content-Type should NOT be present + }), + }), + ); + + const [, options] = (global.fetch as any).mock.calls[0]; + expect(options.headers).not.toHaveProperty("Content-Type"); + }); + + it("should work with binary files in FormData", async () => { + const mockData = { fileId: "binary-123", type: "image" }; + const mockResponse = { + ok: true, + status: 200, + json: () => Promise.resolve(mockData), + headers: new Headers(), + }; + + global.fetch = vi.fn().mockResolvedValue(mockResponse); + + // Create a mock binary file + const binaryData = new Uint8Array([0x89, 0x50, 0x4E, 0x47]); // PNG header + const binaryFile = new File([binaryData], "image.png", { type: "image/png" }); + + const formData = new FormData(); + formData.append("image", binaryFile); + formData.append("description", "Test image upload"); + + const result = await callAPIWithETag("/api/images/upload", { + method: "POST", + body: formData, + }); + + expect(result).toEqual(mockData); + expect(global.fetch).toHaveBeenCalledWith( + expect.stringContaining("/images/upload"), + expect.objectContaining({ + body: formData, + headers: expect.not.objectContaining({ + "Content-Type": expect.any(String), + }), + }), + ); + }); + + it("should remove user-provided Content-Type for FormData", async () => { + const mockResponse = { + ok: true, + status: 200, + json: () => Promise.resolve({ ok: true }), + headers: new Headers(), + }; + global.fetch = vi.fn().mockResolvedValue(mockResponse); + + const fd = new FormData(); + fd.append("file", new File(["x"], "x.txt")); + + await callAPIWithETag("/api/upload", { + method: "POST", + body: fd, + headers: { "Content-Type": "multipart/form-data" }, + }); + + const [, options] = (global.fetch as any).mock.calls[0]; + expect(options.headers).not.toHaveProperty("content-type"); + }); + }); }); diff --git a/archon-ui-main/src/features/shared/tests/queryPatterns.test.ts b/archon-ui-main/src/features/shared/tests/queryPatterns.test.ts new file mode 100644 index 0000000000..5fce9141ab --- /dev/null +++ b/archon-ui-main/src/features/shared/tests/queryPatterns.test.ts @@ -0,0 +1,182 @@ +import { describe, expect, it } from "vitest"; +import { createRetryLogic } from "../queryPatterns"; + +describe("createRetryLogic", () => { + describe("should retry network and server errors", () => { + it("should retry network errors", () => { + const retryLogic = createRetryLogic(3); + + const networkError = new Error("Network error"); + + expect(retryLogic(0, networkError)).toBe(true); // First retry + expect(retryLogic(1, networkError)).toBe(true); // Second retry + expect(retryLogic(2, networkError)).toBe(true); // Third retry + expect(retryLogic(3, networkError)).toBe(false); // Exhausted retries + }); + + it("should retry 5xx server errors", () => { + const retryLogic = createRetryLogic(3); + + const serverError = { statusCode: 500, message: "Internal Server Error" }; + + expect(retryLogic(0, serverError)).toBe(true); // Should retry 500 + expect(retryLogic(1, serverError)).toBe(true); + expect(retryLogic(2, serverError)).toBe(true); + expect(retryLogic(3, serverError)).toBe(false); // Max retries reached + }); + + it("should retry 502 bad gateway errors", () => { + const retryLogic = createRetryLogic(2); + + const badGatewayError = { status: 502 }; + + expect(retryLogic(0, badGatewayError)).toBe(true); + expect(retryLogic(1, badGatewayError)).toBe(true); + expect(retryLogic(2, badGatewayError)).toBe(false); + }); + }); + + describe("should NOT retry client errors", () => { + it("should NOT retry 400 bad request", () => { + const retryLogic = createRetryLogic(3); + + const badRequestError = { statusCode: 400, message: "Bad Request" }; + + expect(retryLogic(0, badRequestError)).toBe(false); + expect(retryLogic(1, badRequestError)).toBe(false); + }); + + it("should NOT retry 401 unauthorized", () => { + const retryLogic = createRetryLogic(3); + + const unauthorizedError = { status: 401 }; + + expect(retryLogic(0, unauthorizedError)).toBe(false); + }); + + it("should NOT retry 404 not found", () => { + const retryLogic = createRetryLogic(3); + + const notFoundError = { statusCode: 404, message: "Not Found" }; + + expect(retryLogic(0, notFoundError)).toBe(false); + }); + + it("should NOT retry 422 validation errors", () => { + const retryLogic = createRetryLogic(3); + + const validationError = { status: 422 }; + + expect(retryLogic(0, validationError)).toBe(false); + }); + + it("should NOT retry 429 rate limit (4xx range)", () => { + const retryLogic = createRetryLogic(3); + + const rateLimitError = { statusCode: 429 }; + + expect(retryLogic(0, rateLimitError)).toBe(false); + }); + }); + + describe("should NOT retry abort errors", () => { + it("should NOT retry AbortError by name", () => { + const retryLogic = createRetryLogic(3); + + const abortError = { name: "AbortError", message: "Request aborted" }; + + expect(retryLogic(0, abortError)).toBe(false); + expect(retryLogic(1, abortError)).toBe(false); + }); + + it("should NOT retry ERR_CANCELED by code", () => { + const retryLogic = createRetryLogic(3); + + const cancelError = { code: "ERR_CANCELED", message: "Request canceled" }; + + expect(retryLogic(0, cancelError)).toBe(false); + }); + + it("should NOT retry axios cancel token", () => { + const retryLogic = createRetryLogic(3); + + const axiosCancelError = { + name: "AbortError", + code: "ERR_CANCELED", + message: "Request canceled" + }; + + expect(retryLogic(0, axiosCancelError)).toBe(false); + }); + }); + + describe("edge cases", () => { + it("should handle null/undefined errors", () => { + const retryLogic = createRetryLogic(2); + + expect(retryLogic(0, null)).toBe(true); // No status, so retry + expect(retryLogic(0, undefined)).toBe(true); + expect(retryLogic(0, "string error")).toBe(true); + }); + + it("should respect maxRetries parameter", () => { + const retryLogic1 = createRetryLogic(1); + const retryLogic5 = createRetryLogic(5); + + const networkError = new Error("Network timeout"); + + // MaxRetries = 1 + expect(retryLogic1(0, networkError)).toBe(true); + expect(retryLogic1(1, networkError)).toBe(false); + + // MaxRetries = 5 + expect(retryLogic5(0, networkError)).toBe(true); + expect(retryLogic5(4, networkError)).toBe(true); + expect(retryLogic5(5, networkError)).toBe(false); + }); + + it("should default to 2 retries when no maxRetries specified", () => { + const retryLogic = createRetryLogic(); // No parameter + + const error = new Error("Some error"); + + expect(retryLogic(0, error)).toBe(true); + expect(retryLogic(1, error)).toBe(true); + expect(retryLogic(2, error)).toBe(false); // Default max is 2 + }); + + it("should handle errors with multiple status properties", () => { + const retryLogic = createRetryLogic(3); + + // Object with both status and statusCode (statusCode takes precedence) + const mixedError = { + status: 500, + statusCode: 404, + message: "Conflict" + }; + + expect(retryLogic(0, mixedError)).toBe(false); // 404 (statusCode) = no retry + }); + }); + + describe("integration with ollamaService patterns", () => { + it("should handle ollama-specific error scenarios", () => { + const retryLogic = createRetryLogic(3); + + // Model not found (404) - don't retry + const modelNotFoundError = { + statusCode: 404, + message: "Model 'llama2' not found" + }; + expect(retryLogic(0, modelNotFoundError)).toBe(false); + + // Connection timeout - retry + const timeoutError = new Error("Connection timeout"); + expect(retryLogic(0, timeoutError)).toBe(true); + + // Server overloaded (503) - retry + const overloadedError = { status: 503, message: "Service unavailable" }; + expect(retryLogic(0, overloadedError)).toBe(true); + }); + }); +}); \ No newline at end of file diff --git a/archon-ui-main/src/services/ollamaService.ts b/archon-ui-main/src/services/ollamaService.ts index 7a6097eb19..25815b4131 100644 --- a/archon-ui-main/src/services/ollamaService.ts +++ b/archon-ui-main/src/services/ollamaService.ts @@ -6,6 +6,7 @@ */ import { getApiUrl } from "../config/api"; +import { createRetryLogic } from "../features/shared/config/queryPatterns"; // Type definitions for Ollama API responses export interface OllamaModel { @@ -39,7 +40,7 @@ export interface ModelDiscoveryResponse { name: string; instance_url: string; size: number; - parameters?: any; + parameters?: unknown; // Real API data from /api/show context_window?: number; architecture?: string; @@ -54,7 +55,7 @@ export interface ModelDiscoveryResponse { instance_url: string; dimensions?: number; size: number; - parameters?: any; + parameters?: unknown; // Real API data from /api/show architecture?: string; format?: string; @@ -154,8 +155,9 @@ export interface EmbeddingRouteOptions { class OllamaService { private baseUrl = getApiUrl(); - private handleApiError(error: any, context: string): Error { + private handleApiError(error: unknown, context: string): Error { const errorMessage = error instanceof Error ? error.message : String(error); + const errorName = error instanceof Error ? error.name : ''; // Check for network errors if ( @@ -170,7 +172,7 @@ class OllamaService { } // Check for timeout errors - if (errorMessage.includes("timeout") || errorMessage.includes("AbortError")) { + if (errorMessage.includes("timeout") || errorMessage.includes("AbortError") || errorName === "AbortError") { return new Error( `Timeout error while ${context.toLowerCase()}: The Ollama instance may be slow to respond or unavailable.` ); @@ -202,8 +204,9 @@ class OllamaService { const response = await fetch(`${this.baseUrl}/api/ollama/models?${params.toString()}`, { method: 'GET', headers: { - 'Content-Type': 'application/json', + 'Accept': 'application/json', }, + signal: AbortSignal.timeout(30000), // 30 second timeout }); if (!response.ok) { @@ -221,7 +224,7 @@ class OllamaService { /** * Check health status of multiple Ollama instances */ - async checkInstanceHealth(instanceUrls: string[], includeModels: boolean = false): Promise { + async checkInstanceHealth(instanceUrls: string[], includeModels: boolean = false, signal?: AbortSignal): Promise { try { if (!instanceUrls || instanceUrls.length === 0) { throw new Error("At least one instance URL is required for health checking"); @@ -240,8 +243,9 @@ class OllamaService { const response = await fetch(`${this.baseUrl}/api/ollama/instances/health?${params.toString()}`, { method: 'GET', headers: { - 'Content-Type': 'application/json', + 'Accept': 'application/json', }, + signal: signal ?? AbortSignal.timeout(30000), }); if (!response.ok) { @@ -273,6 +277,7 @@ class OllamaService { 'Content-Type': 'application/json', }, body: JSON.stringify(requestBody), + signal: AbortSignal.timeout(30000), // 30 second timeout }); if (!response.ok) { @@ -304,6 +309,7 @@ class OllamaService { 'Content-Type': 'application/json', }, body: JSON.stringify(requestBody), + signal: AbortSignal.timeout(30000), // 30 second timeout }); if (!response.ok) { @@ -340,8 +346,9 @@ class OllamaService { const response = await fetch(`${this.baseUrl}/api/ollama/embedding/routes?${params.toString()}`, { method: 'GET', headers: { - 'Content-Type': 'application/json', + 'Accept': 'application/json', }, + signal: AbortSignal.timeout(30000), // 30 second timeout }); if (!response.ok) { @@ -366,6 +373,7 @@ class OllamaService { headers: { 'Content-Type': 'application/json', }, + signal: AbortSignal.timeout(30000), // 30 second timeout }); if (!response.ok) { @@ -381,24 +389,27 @@ class OllamaService { } /** - * Test connectivity to a single Ollama instance (quick health check) with retry logic + * Test connectivity to a single Ollama instance (quick health check) with smart retry logic */ async testConnection(instanceUrl: string, retryCount = 3): Promise<{ isHealthy: boolean; responseTime?: number; error?: string }> { - const maxRetries = retryCount; + const retryLogic = createRetryLogic(retryCount); let lastError: Error | null = null; - for (let attempt = 1; attempt <= maxRetries; attempt++) { + for (let attempt = 1; attempt <= retryCount + 1; attempt++) { try { const startTime = Date.now(); - - const healthResponse = await this.checkInstanceHealth([instanceUrl], false); + + const healthResponse = await this.checkInstanceHealth([instanceUrl], false, AbortSignal.timeout(5000)); const responseTime = Date.now() - startTime; - - const instanceStatus = healthResponse.instance_status[instanceUrl]; - + + const normalizedUrl = instanceUrl.replace(/\/+$/, ""); + const instanceStatus = + healthResponse.instance_status[normalizedUrl] ?? + healthResponse.instance_status[instanceUrl]; + const result = { isHealthy: instanceStatus?.is_healthy || false, - responseTime: instanceStatus?.response_time_ms || responseTime, + responseTime: instanceStatus?.response_time_ms ?? responseTime, error: instanceStatus?.error_message, }; @@ -407,17 +418,50 @@ class OllamaService { return result; } - // If not healthy but we got a valid response, store error for potential retry + // If not healthy but we got a valid response, this might be a 4xx error + // Create an error object that smart retry logic can evaluate lastError = new Error(result.error || 'Instance not available'); - + + // For health check failures, we can add a statusCode if we know it's a client error + if (result.error?.includes('404') || result.error?.includes('not found')) { + (lastError as unknown as { statusCode: number; status: number }).statusCode = 404; + (lastError as unknown as { statusCode: number; status: number }).status = 404; + } else if (result.error?.includes('401') || result.error?.includes('unauthorized')) { + (lastError as unknown as { statusCode: number; status: number }).statusCode = 401; + (lastError as unknown as { statusCode: number; status: number }).status = 401; + } else if (result.error?.includes('403') || result.error?.includes('forbidden')) { + (lastError as unknown as { statusCode: number; status: number }).statusCode = 403; + (lastError as unknown as { statusCode: number; status: number }).status = 403; + } else if (result.error?.includes('500') || result.error?.includes('internal server')) { + (lastError as unknown as { statusCode: number; status: number }).statusCode = 500; + (lastError as unknown as { statusCode: number; status: number }).status = 500; + } + } catch (error) { lastError = error instanceof Error ? error : new Error('Unknown error'); + + // Add status code annotation for HTTP errors that the smart retry logic can use + if (error && typeof error === 'object' && 'status' in error) { + (lastError as unknown as { statusCode: number; status: number }).statusCode = (error as { status: number }).status; + (lastError as unknown as { statusCode: number; status: number }).status = (error as { status: number }).status; + } else if (lastError.message.includes('HTTP ')) { + const statusMatch = lastError.message.match(/HTTP (\d+)/); + if (statusMatch) { + const statusCode = parseInt(statusMatch[1], 10); + (lastError as unknown as { statusCode: number; status: number }).statusCode = statusCode; + (lastError as unknown as { statusCode: number; status: number }).status = statusCode; + } + } } - // If this wasn't the last attempt, wait before retrying - if (attempt < maxRetries) { - const delayMs = Math.pow(2, attempt - 1) * 1000; // Exponential backoff: 1s, 2s, 4s + // Use smart retry logic to determine if we should retry + if (attempt <= retryCount && retryLogic(attempt - 1, lastError)) { + const baseDelay = Math.pow(2, attempt - 1) * 1000; // Exponential backoff: 1s, 2s, 4s + const jitter = Math.random() * 0.5; // Add 0-50% jitter + const delayMs = baseDelay * (1 + jitter); await new Promise(resolve => setTimeout(resolve, delayMs)); + } else { + break; } } diff --git a/python/src/agents/mcp_client.py b/python/src/agents/mcp_client.py index 932473f082..2697055d73 100644 --- a/python/src/agents/mcp_client.py +++ b/python/src/agents/mcp_client.py @@ -6,8 +6,10 @@ instead of direct database access or service imports. """ +import asyncio import json import logging +import uuid from typing import Any import httpx @@ -15,10 +17,27 @@ logger = logging.getLogger(__name__) +class MCPError(Exception): + """Base MCP client error.""" + + +class MCPTransportError(MCPError): + def __init__(self, message: str, status_code: int | None = None): + super().__init__(message if status_code is None else f"[HTTP {status_code}] {message}") + self.status_code = status_code + + +class MCPToolError(MCPError): + def __init__(self, message: str, code: int | None = None, data: Any | None = None): + super().__init__(message if code is None else f"[{code}] {message}") + self.code = code + self.data = data + + class MCPClient: """Client for calling MCP tools via HTTP.""" - def __init__(self, mcp_url: str = None): + def __init__(self, mcp_url: str | None = None): """ Initialize MCP client. @@ -43,7 +62,7 @@ def __init__(self, mcp_url: str = None): else: self.mcp_url = f"http://localhost:{mcp_port}" - self.client = httpx.AsyncClient(timeout=30.0) + self.client: httpx.AsyncClient = httpx.AsyncClient(timeout=30.0) logger.info(f"MCP Client initialized with URL: {self.mcp_url}") async def __aenter__(self): @@ -58,7 +77,7 @@ async def close(self): """Close the HTTP client.""" await self.client.aclose() - async def call_tool(self, tool_name: str, **kwargs) -> dict[str, Any]: + async def call_tool(self, tool_name: str, **kwargs) -> Any: """ Call an MCP tool via HTTP. @@ -67,38 +86,65 @@ async def call_tool(self, tool_name: str, **kwargs) -> dict[str, Any]: **kwargs: Tool arguments Returns: - Dict with the tool response + JSON-RPC result value (any JSON-serializable type) """ try: - # MCP tools are called via JSON-RPC protocol - request_data = {"jsonrpc": "2.0", "method": tool_name, "params": kwargs, "id": 1} - - # Make HTTP request to MCP server - response = await self.client.post( - f"{self.mcp_url}/rpc", - json=request_data, - headers={"Content-Type": "application/json"}, - ) + # Use unique JSON-RPC IDs for correlation + request_id = str(uuid.uuid4()) + request_data = {"jsonrpc": "2.0", "method": tool_name, "params": kwargs, "id": request_id} + + # Add X-Request-ID header for cross-service correlation + headers = {"X-Request-ID": request_id} + + # Make HTTP request to MCP server (httpx sets Content-Type for json=) + response = await self.client.post(f"{self.mcp_url}/rpc", json=request_data, headers=headers) + + # Treat 3xx redirects as transport errors for JSON-RPC + if 300 <= response.status_code < 400: + raise MCPTransportError(f"JSON-RPC does not support redirects (got {response.status_code})", status_code=response.status_code) response.raise_for_status() - result = response.json() + + # Handle invalid JSON responses explicitly + try: + result = response.json() + except json.JSONDecodeError as e: + raise MCPTransportError(f"Invalid JSON response from MCP server: {str(e)}", status_code=response.status_code) from e if "error" in result: error = result["error"] - raise Exception(f"MCP tool error: {error.get('message', 'Unknown error')}") + error_msg = error.get("error") or error.get("message", "Unknown error") + code = error.get("code") + data = error.get("data") + raise MCPToolError(error_msg, code=code, data=data) + + if "result" not in result: + raise MCPError(f"Malformed JSON-RPC response: missing 'result' field in response: {result}") - return result.get("result", {}) + return result["result"] except httpx.HTTPError as e: - logger.error(f"HTTP error calling MCP tool {tool_name}: {e}") - raise Exception(f"Failed to call MCP tool: {str(e)}") - except Exception as e: - logger.error(f"Error calling MCP tool {tool_name}: {e}") + # Extract response details for comprehensive logging + resp = getattr(e, "response", None) + status_code = resp.status_code if resp is not None else None + body_snippet = resp.text[:500] if resp is not None else None + + logger.exception( + f"HTTP error calling MCP tool {tool_name} | url={self.mcp_url}/rpc | " + f"status={status_code} | request_id={request_id} | body_snippet={body_snippet}" + ) + raise MCPTransportError(f"HTTP error calling MCP tool {tool_name}", status_code=status_code) from e + + except MCPError: + # Preserve MCPError subclasses without re-wrapping raise + except Exception as e: + logger.exception(f"Unexpected error calling MCP tool {tool_name} | request_id={request_id}") + raise MCPError(f"Failed to call MCP tool {tool_name}: {str(e)}") from e # Convenience methods for common MCP tools - async def perform_rag_query(self, query: str, source: str = None, match_count: int = 5) -> str: + async def perform_rag_query(self, query: str, source: str | None = None, match_count: int = 5) -> str: """Perform a RAG query through MCP.""" result = await self.call_tool( "perform_rag_query", query=query, source=source, match_count=match_count @@ -111,7 +157,7 @@ async def get_available_sources(self) -> str: return json.dumps(result) if isinstance(result, dict) else str(result) async def search_code_examples( - self, query: str, source_id: str = None, match_count: int = 5 + self, query: str, source_id: str | None = None, match_count: int = 5 ) -> str: """Search code examples through MCP.""" result = await self.call_tool( @@ -139,18 +185,51 @@ async def manage_task(self, action: str, project_id: str, **kwargs) -> str: # Global MCP client instance (created on first use) _mcp_client: MCPClient | None = None +_mcp_client_lock: asyncio.Lock | None = None async def get_mcp_client() -> MCPClient: """ Get or create the global MCP client instance. + Thread-safe implementation using double-checked locking pattern. + Returns: MCPClient instance """ - global _mcp_client + global _mcp_client, _mcp_client_lock + + # First check without lock for performance + if _mcp_client is not None: + return _mcp_client + + # Initialize lock if needed + if _mcp_client_lock is None: + _mcp_client_lock = asyncio.Lock() + + # Double-checked locking pattern + async with _mcp_client_lock: + # Check again in case another coroutine created the client + if _mcp_client is None: + _mcp_client = MCPClient() + logger.info("Created new global MCP client instance") + + return _mcp_client + + +async def shutdown_mcp_client() -> None: + """ + Shutdown the global MCP client instance. + + This should be called during application shutdown to properly + close HTTP connections and clean up resources. + """ + global _mcp_client, _mcp_client_lock - if _mcp_client is None: - _mcp_client = MCPClient() + if _mcp_client is not None: + await _mcp_client.close() + _mcp_client = None - return _mcp_client + # Reset global lock on shutdown for test safety + _mcp_client_lock = None + logger.info("Global MCP client shutdown completed") diff --git a/python/src/server/api_routes/ollama_api.py b/python/src/server/api_routes/ollama_api.py index d961551e88..a0e99778a9 100644 --- a/python/src/server/api_routes/ollama_api.py +++ b/python/src/server/api_routes/ollama_api.py @@ -8,9 +8,12 @@ - Embedding routing and dimension analysis """ +import ipaddress import json -from datetime import datetime +import socket +from datetime import UTC, datetime from typing import Any +from urllib.parse import urlparse from fastapi import APIRouter, BackgroundTasks, HTTPException, Query from pydantic import BaseModel, Field @@ -25,6 +28,31 @@ router = APIRouter(prefix="/api/ollama", tags=["ollama"]) +def _is_private_host(host: str) -> bool: + """ + Check if a hostname resolves to private, loopback, link-local, reserved, or unspecified IP addresses. + + Returns True if the host is considered unsafe for server-side requests to prevent SSRF attacks. + """ + try: + infos = socket.getaddrinfo(host, None) + for _, _, _, _, sockaddr in infos: + ip = ipaddress.ip_address(sockaddr[0]) + if ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_reserved + or ip.is_multicast + or ip.is_unspecified + ): + return True + except Exception: + # If resolution fails, treat as unsafe or log/deny explicitly + return True + return False + + # Pydantic models for API requests/responses class InstanceValidationRequest(BaseModel): """Request for validating an Ollama instance.""" @@ -85,7 +113,7 @@ async def discover_models_endpoint( instance_urls: list[str] = Query(..., description="Ollama instance URLs"), include_capabilities: bool = Query(True, description="Include capability detection"), fetch_details: bool = Query(False, description="Fetch comprehensive model details via /api/show"), - background_tasks: BackgroundTasks = None + background_tasks: BackgroundTasks = BackgroundTasks() ) -> ModelDiscoveryResponse: """ Discover models from multiple Ollama instances with capability detection. @@ -95,32 +123,40 @@ async def discover_models_endpoint( """ try: logger.info(f"Starting model discovery for {len(instance_urls)} instances with fetch_details={fetch_details}") - - # Validate instance URLs + + # Validate instance URLs and check for SSRF risks valid_urls = [] for url in instance_urls: try: - # Basic URL validation + # Basic URL validation - require http/https schemes if not url.startswith(('http://', 'https://')): logger.warning(f"Invalid URL format: {url}") continue + + # SSRF protection - check if URL targets private/internal addresses + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL: {url}") + continue + valid_urls.append(url.rstrip('/')) except Exception as e: - logger.warning(f"Error validating URL {url}: {e}") + logger.warning(f"Error validating URL {url}: {e}", exc_info=True) if not valid_urls: raise HTTPException(status_code=400, detail="No valid instance URLs provided") # Perform model discovery with optional detailed fetching discovery_result = await model_discovery_service.discover_models_from_multiple_instances( - valid_urls, + valid_urls, fetch_details=fetch_details ) logger.info(f"Discovery complete: {discovery_result['total_models']} models found") # If background tasks available, schedule cache warming - if background_tasks: + if background_tasks is not None: + # Schedule cache warming as a FastAPI background task (runs after response) background_tasks.add_task(_warm_model_cache, valid_urls) return ModelDiscoveryResponse( @@ -135,7 +171,7 @@ async def discover_models_endpoint( except HTTPException: raise except Exception as e: - logger.error(f"Error in model discovery: {e}") + logger.error(f"Error in model discovery: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Model discovery failed: {str(e)}") @@ -155,10 +191,23 @@ async def health_check_endpoint( health_results = {} - # Check health for each instance + # Check health for each instance (with SSRF protection) for instance_url in instance_urls: try: url = instance_url.rstrip('/') + + # SSRF protection - check if URL targets private/internal addresses + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in health check: {url}") + health_results[url] = { + "is_healthy": False, + "error_message": "URL blocked for security reasons", + "response_time_ms": None, + "models_available": 0, + "last_checked": datetime.now(UTC).isoformat() + } + continue health_status = await model_discovery_service.check_instance_health(url) health_results[url] = { @@ -170,8 +219,8 @@ async def health_check_endpoint( } except Exception as e: - logger.warning(f"Health check failed for {instance_url}: {e}") - health_results[instance_url] = { + logger.warning(f"Health check failed for {instance_url}: {e}", exc_info=True) + health_results[url] = { "is_healthy": False, "response_time_ms": None, "models_available": None, @@ -196,11 +245,11 @@ async def health_check_endpoint( "average_response_time_ms": avg_response_time }, "instance_status": health_results, - "timestamp": model_discovery_service.check_instance_health.__module__ # Use current timestamp + "timestamp": datetime.now(UTC).isoformat() } except Exception as e: - logger.error(f"Error in health check: {e}") + logger.error(f"Error in health check: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}") @@ -218,6 +267,12 @@ async def validate_instance_endpoint(request: InstanceValidationRequest) -> Inst # Clean up URL instance_url = request.instance_url.rstrip('/') + # SSRF protection - check if URL targets private/internal addresses + parsed = urlparse(instance_url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in validate_instance: {instance_url}") + raise HTTPException(status_code=400, detail="URL blocked for security reasons") + # Perform basic validation using the provider service validation_result = await validate_provider_instance("ollama", instance_url) @@ -231,8 +286,8 @@ async def validate_instance_endpoint(request: InstanceValidationRequest) -> Inst "total_models": len(models), "chat_models": [m.name for m in models if "chat" in m.capabilities], "embedding_models": [m.name for m in models if "embedding" in m.capabilities], - "supported_dimensions": list(set(m.embedding_dimensions for m in models - if m.embedding_dimensions)) + "supported_dimensions": list({m.embedding_dimensions for m in models + if m.embedding_dimensions}) } except Exception as e: @@ -250,7 +305,7 @@ async def validate_instance_endpoint(request: InstanceValidationRequest) -> Inst ) except Exception as e: - logger.error(f"Error validating instance {request.instance_url}: {e}") + logger.error(f"Error validating instance {request.instance_url}: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Instance validation failed: {str(e)}") @@ -265,6 +320,12 @@ async def analyze_embedding_route_endpoint(request: EmbeddingRouteRequest) -> Em try: logger.info(f"Analyzing embedding route for {request.model_name} on {request.instance_url}") + # SSRF protection - require http(s) and block private/internal targets + parsed = urlparse(request.instance_url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in embedding route analysis: {request.instance_url}") + raise HTTPException(status_code=400, detail="URL blocked for security reasons") + # Get routing decision from the embedding router routing_decision = await embedding_router.route_embedding( model_name=request.model_name, @@ -273,7 +334,7 @@ async def analyze_embedding_route_endpoint(request: EmbeddingRouteRequest) -> Em ) # Calculate performance score - performance_score = embedding_router._calculate_performance_score(routing_decision.dimensions) + performance_score = embedding_router.calculate_performance_score(routing_decision.dimensions) return EmbeddingRouteResponse( target_column=routing_decision.target_column, @@ -287,7 +348,7 @@ async def analyze_embedding_route_endpoint(request: EmbeddingRouteRequest) -> Em ) except Exception as e: - logger.error(f"Error analyzing embedding route: {e}") + logger.error(f"Error analyzing embedding route: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Embedding route analysis failed: {str(e)}") @@ -305,8 +366,27 @@ async def get_available_embedding_routes_endpoint( try: logger.info(f"Getting embedding routes for {len(instance_urls)} instances") - # Get available routes - routes = await embedding_router.get_available_embedding_routes(instance_urls) + # Validate instance URLs and check for SSRF risks + valid_urls: list[str] = [] + for url in instance_urls: + try: + parsed = urlparse(url.rstrip('/')) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in embedding routes: {url}") + continue + valid_urls.append(url.rstrip('/')) + except Exception as e: + logger.warning(f"Error validating URL {url}: {e}", exc_info=True) + continue + if not valid_urls: + raise HTTPException(status_code=400, detail="No valid instance URLs provided") + + # Get available routes for validated URLs only + routes = await embedding_router.get_available_embedding_routes(valid_urls) + + # If not sorting by performance, provide a stable alternative ordering + if not sort_by_performance: + routes.sort(key=lambda r: (r.model_name, r.instance_url)) # Convert to response format route_data = [] @@ -343,7 +423,7 @@ async def get_available_embedding_routes_endpoint( } except Exception as e: - logger.error(f"Error getting embedding routes: {e}") + logger.error(f"Error getting embedding routes: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Failed to get embedding routes: {str(e)}") @@ -371,14 +451,13 @@ async def clear_ollama_cache_endpoint() -> dict[str, str]: return {"message": "All Ollama caches cleared successfully"} except Exception as e: - logger.error(f"Error clearing caches: {e}") + logger.error(f"Error clearing caches: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Failed to clear caches: {str(e)}") class ModelDiscoveryAndStoreRequest(BaseModel): """Request for discovering and storing models from Ollama instances.""" instance_urls: list[str] = Field(..., description="List of Ollama instance URLs") - force_refresh: bool = Field(False, description="Force refresh even if cached data exists") class StoredModelInfo(BaseModel): @@ -431,6 +510,13 @@ async def discover_and_store_models_endpoint(request: ModelDiscoveryAndStoreRequ for instance_url in request.instance_urls: try: base_url = instance_url.replace('/v1', '').rstrip('/') + + # SSRF protection - check if URL targets private/internal addresses + parsed = urlparse(base_url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in model discovery: {base_url}") + continue + logger.debug(f"Discovering models from {base_url}") # Get detailed model information @@ -454,20 +540,20 @@ async def discover_and_store_models_endpoint(request: ModelDiscoveryAndStoreRequ limitations=compatibility_info['limitations'], performance_rating=_assess_performance_rating(model), description=_generate_model_description(model), - last_updated=datetime.now().isoformat() + last_updated=datetime.now(UTC).isoformat() ) stored_models.append(stored_model) logger.debug(f"Discovered {len(models)} models from {base_url}") except Exception as e: - logger.warning(f"Failed to discover models from {instance_url}: {e}") + logger.warning(f"Failed to discover models from {instance_url}: {e}", exc_info=True) continue # Store models in archon_settings models_data = { "models": [model.dict() for model in stored_models], - "last_discovery": datetime.now().isoformat(), + "last_discovery": datetime.now(UTC).isoformat(), "instances_checked": instances_checked, "total_count": len(stored_models) } @@ -478,7 +564,7 @@ async def discover_and_store_models_endpoint(request: ModelDiscoveryAndStoreRequ "value": json.dumps(models_data), "category": "ollama", "description": "Discovered Ollama models with compatibility information", - "updated_at": datetime.now().isoformat() + "updated_at": datetime.now(UTC).isoformat() }).execute() logger.info(f"Stored {len(stored_models)} models from {instances_checked} instances") @@ -492,7 +578,7 @@ async def discover_and_store_models_endpoint(request: ModelDiscoveryAndStoreRequ ) except Exception as e: - logger.error(f"Error in model discovery and storage: {e}") + logger.error(f"Error in model discovery and storage: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Model discovery failed: {str(e)}") @@ -523,9 +609,19 @@ async def get_stored_models_endpoint() -> ModelListResponse: cache_status="empty" ) - models_data = json.loads(models_setting) if isinstance(models_setting, str) else models_setting - from datetime import datetime - + # Handle both JSON string and native dict from DB driver + if isinstance(models_setting, str): + try: + models_data = json.loads(models_setting) + except json.JSONDecodeError: + logger.error("Corrupted 'ollama_discovered_models' JSON in archon_settings", exc_info=True) + raise HTTPException(status_code=500, detail="Stored models are corrupted") + elif isinstance(models_setting, dict): + models_data = models_setting + else: + logger.error(f"Unexpected type for models_setting: {type(models_setting).__name__}", exc_info=True) + raise HTTPException(status_code=500, detail="Invalid stored models format") + # Handle both old format (direct list) and new format (object with models key) if isinstance(models_data, list): # Old format - direct list of models @@ -539,7 +635,7 @@ async def get_stored_models_endpoint() -> ModelListResponse: total_count = models_data.get("total_count", len(models_list)) instances_checked = models_data.get("instances_checked", 0) last_discovery = models_data.get("last_discovery") - + # Convert to StoredModelInfo objects, handling missing fields stored_models = [] for model in models_list: @@ -559,12 +655,12 @@ async def get_stored_models_endpoint() -> ModelListResponse: limitations=model.get('limitations', []), performance_rating=model.get('performance_rating'), description=model.get('description'), - last_updated=model.get('last_updated', datetime.utcnow().isoformat()), + last_updated=model.get('last_updated', datetime.now(UTC).isoformat()), embedding_dimensions=model.get('embedding_dimensions') ) stored_models.append(stored_model) except Exception as model_error: - logger.warning(f"Failed to parse stored model {model}: {model_error}") + logger.warning(f"Failed to parse stored model {model}: {model_error}", exc_info=True) return ModelListResponse( models=stored_models, @@ -575,7 +671,7 @@ async def get_stored_models_endpoint() -> ModelListResponse: ) except Exception as e: - logger.error(f"Error retrieving stored models: {e}") + logger.error(f"Error retrieving stored models: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Failed to retrieve models: {str(e)}") @@ -587,6 +683,12 @@ async def _warm_model_cache(instance_urls: list[str]) -> None: for url in instance_urls: try: + # SSRF protection - check if URL targets private/internal addresses + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in cache warming: {url}") + continue + await model_discovery_service.discover_models(url) logger.debug(f"Cache warmed for {url}") except Exception as e: @@ -595,35 +697,34 @@ async def _warm_model_cache(instance_urls: list[str]) -> None: logger.info("Model cache warming completed") except Exception as e: - logger.error(f"Error warming model cache: {e}") + logger.error(f"Error warming model cache: {e}", exc_info=True) # Helper functions for model assessment and analysis async def _assess_archon_compatibility_with_testing(model, instance_url: str) -> dict[str, Any]: """Assess Archon compatibility for a given model using actual capability testing.""" - model_name = model.name.lower() capabilities = getattr(model, 'capabilities', []) - + # Test actual model capabilities function_calling_supported = await _test_function_calling_capability(model.name, instance_url) structured_output_supported = await _test_structured_output_capability(model.name, instance_url) - + # Determine compatibility level based on actual test results compatibility_level = 'limited' features = ['Local Processing'] # All Ollama models support local processing limitations = [] - + # Check for chat capability if 'chat' in capabilities: features.append('Text Generation') features.append('MCP Integration') # All chat models can integrate with MCP features.append('Streaming') # All Ollama models support streaming - + # Add advanced features based on actual testing if function_calling_supported: features.append('Function Calls') compatibility_level = 'full' # Function calling indicates full support - + if structured_output_supported: features.append('Structured Output') if compatibility_level != 'full': @@ -631,18 +732,18 @@ async def _assess_archon_compatibility_with_testing(model, instance_url: str) -> else: if compatibility_level != 'full': # Only add limitation if not already full support limitations.append('Limited structured output support') - + # Add embedding capability if 'embedding' in capabilities: features.append('High-quality embeddings') if compatibility_level == 'limited': compatibility_level = 'full' # Embedding models are considered full support for their purpose - + # If no advanced features detected, remain limited if not function_calling_supported and not structured_output_supported and 'embedding' not in capabilities: compatibility_level = 'limited' limitations.append('Compatibility not fully tested') - + return { 'level': compatibility_level, 'features': features, @@ -853,12 +954,12 @@ async def _test_function_calling_capability(model_name: str, instance_url: str) try: # Import here to avoid circular imports from ..services.llm_provider_service import get_llm_client - + # Use OpenAI-compatible client for function calling test async with get_llm_client(provider="ollama") as client: # Set base_url for this specific instance client.base_url = f"{instance_url.rstrip('/')}/v1" - + # Define a simple test function test_function = { "name": "get_weather", @@ -874,7 +975,7 @@ async def _test_function_calling_capability(model_name: str, instance_url: str) "required": ["location"] } } - + # Try to make a function calling request response = await client.chat.completions.create( model=model_name, @@ -883,16 +984,16 @@ async def _test_function_calling_capability(model_name: str, instance_url: str) max_tokens=50, timeout=10 ) - + # Check if the model attempted to use the function if response.choices and len(response.choices) > 0: choice = response.choices[0] if hasattr(choice.message, 'tool_calls') and choice.message.tool_calls: logger.info(f"Model {model_name} supports function calling") return True - + return False - + except Exception as e: logger.debug(f"Function calling test failed for {model_name}: {e}") return False @@ -912,24 +1013,24 @@ async def _test_structured_output_capability(model_name: str, instance_url: str) try: # Import here to avoid circular imports from ..services.llm_provider_service import get_llm_client - + # Use OpenAI-compatible client for structured output test async with get_llm_client(provider="ollama") as client: # Set base_url for this specific instance client.base_url = f"{instance_url.rstrip('/')}/v1" - + # Test structured output with JSON format response = await client.chat.completions.create( model=model_name, messages=[{ - "role": "user", + "role": "user", "content": "Return a JSON object with the structure: {\"city\": \"Paris\", \"country\": \"France\", \"population\": 2140000}. Only return the JSON, no other text." }], max_tokens=100, timeout=10, temperature=0.1 # Low temperature for more consistent output ) - + if response.choices and len(response.choices) > 0: content = response.choices[0].message.content if content: @@ -942,13 +1043,11 @@ async def _test_structured_output_capability(model_name: str, instance_url: str) logger.info(f"Model {model_name} supports structured output") return True except json.JSONDecodeError: - # Try to find JSON-like patterns in the response - if '{' in content and '}' in content and '"' in content: - logger.info(f"Model {model_name} has partial structured output support") - return True - + # Only accept valid JSON - no partial support heuristics + pass + return False - + except Exception as e: logger.debug(f"Structured output test failed for {model_name}: {e}") return False @@ -957,14 +1056,12 @@ async def _test_structured_output_capability(model_name: str, instance_url: str) @router.post("/models/discover-with-details", response_model=ModelDiscoveryResponse) async def discover_models_with_real_details(request: ModelDiscoveryAndStoreRequest) -> ModelDiscoveryResponse: """ - Discover models from Ollama instances with complete real details from both /api/tags and /api/show. + Discover models from Ollama instances using /api/tags endpoint for fast discovery. Only stores actual data from Ollama API endpoints - no fabricated information. """ try: logger.info(f"Starting detailed model discovery for {len(request.instance_urls)} instances") - from datetime import datetime - import httpx from ..utils import get_supabase_client @@ -976,6 +1073,13 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque for instance_url in request.instance_urls: try: base_url = instance_url.replace('/v1', '').rstrip('/') + + # SSRF protection - check if URL targets private/internal addresses + parsed = urlparse(base_url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in detailed discovery: {base_url}") + continue + logger.debug(f"Fetching real model data from {base_url}") async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client: @@ -1028,16 +1132,7 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque # Set default embedding dimensions based on common model patterns embedding_dimensions = None - if model_type == 'embedding': - # Use common defaults based on model name - if "nomic-embed" in model_name.lower(): - embedding_dimensions = 768 - elif "bge" in model_name.lower(): - embedding_dimensions = 768 - elif "e5" in model_name.lower(): - embedding_dimensions = 1024 - else: - embedding_dimensions = 768 # Common default + # Don't fabricate embedding dimensions - leave as None for unknown values # Extract real parameter info parameters = details.get("parameter_size") @@ -1051,22 +1146,12 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque param_parts.append(quantization) param_string = " ".join(param_parts) if param_parts else None - # Create model with only real data - # Skip capability testing for fast discovery - assume basic capabilities - if model_type == 'chat': - # Skip testing, assume basic chat capabilities for fast discovery - features = ['Local Processing', 'Text Generation', 'Chat Support'] - limitations = [] - compatibility_level = 'full' # Assume full for now - - compatibility = { - 'level': compatibility_level, - 'features': features, - 'limitations': limitations - } - else: - # Embedding models are all considered full compatibility for embedding tasks - compatibility = {'level': 'full', 'features': ['High-quality embeddings', 'Local processing'], 'limitations': []} + # Create model with only real data - don't fabricate compatibility + compatibility = { + 'level': 'unknown', + 'features': [], + 'limitations': ['Requires capability testing for accurate assessment'] + } stored_model = StoredModelInfo( name=model_name, @@ -1081,7 +1166,7 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque limitations=compatibility['limitations'], performance_rating=None, description=None, - last_updated=datetime.now().isoformat(), + last_updated=datetime.now(UTC).isoformat(), embedding_dimensions=embedding_dimensions ) @@ -1094,34 +1179,35 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque logger.debug(f"Processed model {model_name} with real data") except Exception as e: - logger.warning(f"Failed to get details for model {model_name}: {e}") + logger.warning(f"Failed to get details for model {model_name}: {e}", exc_info=True) continue instances_checked += 1 logger.debug(f"Completed processing {base_url}") except Exception as e: - logger.warning(f"Failed to process instance {instance_url}: {e}") + logger.warning(f"Failed to process instance {instance_url}: {e}", exc_info=True) continue # Store models with real data only models_data = { "models": stored_models, # Already converted to dicts above - "last_discovery": datetime.now().isoformat(), + "last_discovery": datetime.now(UTC).isoformat(), "instances_checked": instances_checked, "total_count": len(stored_models) } - + # Debug log to check what's in stored_models embedding_models_with_dims = [m for m in stored_models if m.get('model_type') == 'embedding' and m.get('embedding_dimensions')] logger.info(f"Storing {len(embedding_models_with_dims)} embedding models with dimensions: {[(m['name'], m.get('embedding_dimensions')) for m in embedding_models_with_dims]}") - # Update the stored models - result = supabase.table("archon_settings").update({ + # Upsert the stored models + supabase.table("archon_settings").upsert({ + "key": "ollama_discovered_models", "value": json.dumps(models_data), "description": "Real Ollama model data from API endpoints", - "updated_at": datetime.now().isoformat() - }).eq("key", "ollama_discovered_models").execute() + "updated_at": datetime.now(UTC).isoformat() + }).execute() logger.info(f"Stored {len(stored_models)} models with real data from {instances_checked} instances") @@ -1138,10 +1224,10 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque embedding_models = [] host_status = {} unique_model_names = set() - + for model in stored_models: unique_model_names.add(model['name']) - + # Build host status host = model['host'].replace('/v1', '').rstrip('/') if host not in host_status: @@ -1151,7 +1237,7 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque "instance_url": model['host'] } host_status[host]["models_count"] += 1 - + # Categorize models if model['model_type'] == 'embedding': embedding_models.append({ @@ -1166,7 +1252,7 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque "instance_url": model['host'], "size": model.get('size_mb', 0) * 1024 * 1024 if model.get('size_mb') else 0 }) - + return ModelDiscoveryResponse( total_models=len(stored_models), chat_models=chat_models, @@ -1177,7 +1263,7 @@ async def discover_models_with_real_details(request: ModelDiscoveryAndStoreReque ) except Exception as e: - logger.error(f"Error in detailed model discovery: {e}") + logger.error(f"Error in detailed model discovery: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Model discovery failed: {str(e)}") @@ -1238,13 +1324,19 @@ async def test_model_capabilities_endpoint(request: ModelCapabilityTestRequest) """ import time start_time = time.time() - + try: logger.info(f"Testing capabilities for model {request.model_name} on {request.instance_url}") - + + # SSRF protection - check if URL targets private/internal addresses + parsed = urlparse(request.instance_url) + if parsed.scheme not in ('http', 'https') or not parsed.hostname or _is_private_host(parsed.hostname): + logger.warning(f"Blocked private/invalid URL in capability testing: {request.instance_url}") + raise HTTPException(status_code=400, detail="URL blocked for security reasons") + test_results = {} errors = [] - + # Test function calling if requested if request.test_function_calling: try: @@ -1260,7 +1352,7 @@ async def test_model_capabilities_endpoint(request: ModelCapabilityTestRequest) error_msg = f"Function calling test failed: {str(e)}" errors.append(error_msg) test_results["function_calling"] = {"supported": False, "error": error_msg} - + # Test structured output if requested if request.test_structured_output: try: @@ -1276,34 +1368,34 @@ async def test_model_capabilities_endpoint(request: ModelCapabilityTestRequest) error_msg = f"Structured output test failed: {str(e)}" errors.append(error_msg) test_results["structured_output"] = {"supported": False, "error": error_msg} - + # Assess compatibility based on test results compatibility_level = 'limited' features = ['Local Processing', 'Text Generation', 'MCP Integration', 'Streaming'] limitations = [] - + # Determine compatibility level based on test results function_calling_works = test_results.get("function_calling", {}).get("supported", False) structured_output_works = test_results.get("structured_output", {}).get("supported", False) - + if function_calling_works: features.append('Function Calls') compatibility_level = 'full' - + if structured_output_works: features.append('Structured Output') if compatibility_level == 'limited': compatibility_level = 'partial' - + # Add limitations based on what doesn't work if not function_calling_works: limitations.append('No function calling support detected') if not structured_output_works: limitations.append('Limited structured output support') - + if compatibility_level == 'limited': limitations.append('Basic text generation only') - + compatibility_assessment = { 'level': compatibility_level, 'features': features, @@ -1311,11 +1403,11 @@ async def test_model_capabilities_endpoint(request: ModelCapabilityTestRequest) 'testing_method': 'Real-time API testing', 'confidence': 'High' if not errors else 'Medium' } - + duration = time.time() - start_time - + logger.info(f"Capability testing complete for {request.model_name}: {compatibility_level} support detected in {duration:.2f}s") - + return ModelCapabilityTestResponse( model_name=request.model_name, instance_url=request.instance_url, @@ -1324,8 +1416,8 @@ async def test_model_capabilities_endpoint(request: ModelCapabilityTestRequest) test_duration_seconds=duration, errors=errors ) - + except Exception as e: duration = time.time() - start_time - logger.error(f"Error testing model capabilities: {e}") + logger.error(f"Error testing model capabilities: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Capability testing failed: {str(e)}")