Skip to content

Commit 69208b0

Browse files
committed
feat: improve local provider connectivity with CORS bypass
- Add @tauri-apps/plugin-http dependency
- Implement dual fetch strategy for local vs remote providers
- Auto-detect local providers (localhost, Ollama:11434, LM Studio:1234)
- Make API key optional for local providers
- Add comprehensive test coverage for provider fetching
1 parent 64a7822 commit 69208b0

File tree

3 files changed

+213
-12
lines changed

3 files changed

+213
-12
lines changed

web-app/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"@tauri-apps/api": "^2.5.0",
3333
"@tauri-apps/plugin-deep-link": "~2",
3434
"@tauri-apps/plugin-dialog": "^2.2.1",
35+
"@tauri-apps/plugin-http": "^2.2.1",
3536
"@tauri-apps/plugin-opener": "^2.2.7",
3637
"@tauri-apps/plugin-os": "^2.2.1",
3738
"@tauri-apps/plugin-updater": "^2.7.1",
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import { describe, it, expect, vi, beforeEach } from 'vitest'
2+
3+
// Mock the Tauri fetch - use factory function to avoid hoisting issues
4+
vi.mock('@tauri-apps/plugin-http', () => ({
5+
fetch: vi.fn()
6+
}))
7+
8+
// Import after mocking
9+
import { fetchModelsFromProvider } from './providers'
10+
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
11+
12+
// Mock global fetch
13+
const mockGlobalFetch = vi.fn()
14+
global.fetch = mockGlobalFetch
15+
16+
// Get the mocked fetchTauri
17+
const mockFetchTauri = vi.mocked(fetchTauri)
18+
19+
describe('fetchModelsFromProvider', () => {
20+
beforeEach(() => {
21+
vi.clearAllMocks()
22+
const mockResponse = {
23+
ok: true,
24+
status: 200,
25+
statusText: 'OK',
26+
json: () => Promise.resolve({ data: [{ id: 'test-model' }] })
27+
} as Partial<Response>
28+
mockFetchTauri.mockResolvedValue(mockResponse as Response)
29+
mockGlobalFetch.mockResolvedValue(mockResponse as Response)
30+
})
31+
32+
it('should use fetchTauri for localhost URLs', async () => {
33+
const provider = {
34+
provider: 'test',
35+
base_url: 'http://localhost:8080/v1',
36+
api_key: 'test-key',
37+
models: [],
38+
settings: [],
39+
active: true
40+
}
41+
42+
await fetchModelsFromProvider(provider)
43+
44+
expect(mockFetchTauri).toHaveBeenCalledWith('http://localhost:8080/v1/models', {
45+
method: 'GET',
46+
headers: {
47+
'Content-Type': 'application/json',
48+
'x-api-key': 'test-key',
49+
'Authorization': 'Bearer test-key'
50+
}
51+
})
52+
expect(mockGlobalFetch).not.toHaveBeenCalled()
53+
})
54+
55+
it('should use fetchTauri for 127.0.0.1 URLs', async () => {
56+
const provider = {
57+
provider: 'test',
58+
base_url: 'http://127.0.0.1:8080/v1',
59+
api_key: 'test-key',
60+
models: [],
61+
settings: [],
62+
active: true
63+
}
64+
65+
await fetchModelsFromProvider(provider)
66+
67+
expect(mockFetchTauri).toHaveBeenCalled()
68+
expect(mockGlobalFetch).not.toHaveBeenCalled()
69+
})
70+
71+
it('should use fetchTauri for Ollama port 11434', async () => {
72+
const provider = {
73+
provider: 'ollama',
74+
base_url: 'http://192.168.1.100:11434/v1',
75+
api_key: '',
76+
models: [],
77+
settings: [],
78+
active: true
79+
}
80+
81+
await fetchModelsFromProvider(provider)
82+
83+
expect(mockFetchTauri).toHaveBeenCalled()
84+
expect(mockGlobalFetch).not.toHaveBeenCalled()
85+
})
86+
87+
it('should use fetchTauri for LM Studio port 1234', async () => {
88+
const provider = {
89+
provider: 'lmstudio',
90+
base_url: 'http://192.168.1.100:1234/v1',
91+
api_key: '',
92+
models: [],
93+
settings: [],
94+
active: true
95+
}
96+
97+
await fetchModelsFromProvider(provider)
98+
99+
expect(mockFetchTauri).toHaveBeenCalled()
100+
expect(mockGlobalFetch).not.toHaveBeenCalled()
101+
})
102+
103+
it('should use fetchTauri when skipPreflight is true', async () => {
104+
const provider = {
105+
provider: 'openai',
106+
base_url: 'https://api.openai.com/v1',
107+
api_key: 'test-key',
108+
models: [],
109+
settings: [],
110+
active: true
111+
}
112+
113+
await fetchModelsFromProvider(provider, true)
114+
115+
expect(mockFetchTauri).toHaveBeenCalled()
116+
expect(mockGlobalFetch).not.toHaveBeenCalled()
117+
})
118+
119+
it('should not require API key for local providers', async () => {
120+
const provider = {
121+
provider: 'ollama',
122+
base_url: 'http://127.0.0.1:11434/v1',
123+
api_key: '',
124+
models: [],
125+
settings: [],
126+
active: true
127+
}
128+
129+
await fetchModelsFromProvider(provider)
130+
131+
expect(mockFetchTauri).toHaveBeenCalledWith('http://127.0.0.1:11434/v1/models', {
132+
method: 'GET',
133+
headers: {
134+
'Content-Type': 'application/json'
135+
}
136+
})
137+
})
138+
})

web-app/src/services/providers.ts

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ import {
1313
import { modelSettings } from '@/lib/predefined'
1414
import { fetchModels } from './models'
1515
import { ExtensionManager } from '@/lib/extension'
16+
import { fetch as fetchTauri } from '@tauri-apps/plugin-http'
17+
18+
// Ensure we have access to the global fetch for non-Tauri requests
19+
const globalFetch = globalThis.fetch || window.fetch
1620

1721
export const getProviders = async (): Promise<ModelProvider[]> => {
1822
const engines = !localStorage.getItem('migration_completed')
@@ -163,32 +167,80 @@ export const getProviders = async (): Promise<ModelProvider[]> => {
163167
return runtimeProviders.concat(builtinProviders as ModelProvider[])
164168
}
165169

170+
/**
171+
* Checks if a URL is a local provider (localhost or 127.0.0.1) or uses specific ports (Ollama/LM Studio)
172+
* @param url The URL to check
173+
* @returns boolean indicating if it's a local provider or uses known local AI ports
174+
*/
175+
const shouldBypassPreflightCheck = (url: string): boolean => {
176+
try {
177+
const urlObj = new URL(url)
178+
const isLocalHost = urlObj.hostname === 'localhost' ||
179+
urlObj.hostname === '127.0.0.1' ||
180+
urlObj.hostname === '0.0.0.0'
181+
182+
// Check for specific ports used by local AI providers
183+
const port = parseInt(urlObj.port)
184+
const isOllamaPort = port === 11434
185+
const isLMStudioPort = port === 1234
186+
187+
return isLocalHost || isOllamaPort || isLMStudioPort
188+
} catch {
189+
return false
190+
}
191+
}
192+
166193
/**
167194
* Fetches models from a provider's API endpoint
195+
* Uses Tauri's HTTP client for local providers or when skipPreflight is true to bypass CORS issues
168196
* @param provider The provider object containing base_url and api_key
197+
* @param skipPreflight Whether to skip CORS preflight by using Tauri fetch (default: false)
169198
* @returns Promise<string[]> Array of model IDs
170199
*/
171200
export const fetchModelsFromProvider = async (
172-
provider: ModelProvider
201+
provider: ModelProvider,
202+
skipPreflight: boolean = false
173203
): Promise<string[]> => {
174-
if (!provider.base_url || !provider.api_key) {
175-
throw new Error('Provider must have base_url and api_key configured')
204+
if (!provider.base_url) {
205+
throw new Error('Provider must have base_url configured')
176206
}
177207

208+
// For local providers, we don't require API key as they often don't use authentication
209+
const isPreflightCheckBypassed = shouldBypassPreflightCheck(provider.base_url)
210+
211+
// Determine whether to use Tauri fetch to bypass CORS preflight
212+
const shouldUseTauriFetch = isPreflightCheckBypassed || skipPreflight
213+
178214
try {
179-
const response = await fetch(`${provider.base_url}/models`, {
215+
const headers: Record<string, string> = {
216+
'Content-Type': 'application/json',
217+
}
218+
219+
// Only add authentication headers if API key is provided
220+
if (provider.api_key) {
221+
headers['x-api-key'] = provider.api_key
222+
headers['Authorization'] = `Bearer ${provider.api_key}`
223+
}
224+
225+
// Use Tauri's fetch for local providers or when skipPreflight is true to avoid CORS issues
226+
// Use regular fetch for remote providers to maintain normal browser behavior
227+
const fetchFunction = shouldUseTauriFetch ? fetchTauri : globalFetch
228+
const response = await fetchFunction(`${provider.base_url}/models`, {
180229
method: 'GET',
181-
headers: {
182-
'x-api-key': provider.api_key,
183-
'Authorization': `Bearer ${provider.api_key}`,
184-
'Content-Type': 'application/json',
185-
},
230+
headers,
186231
})
187232

188233
if (!response.ok) {
189-
throw new Error(
190-
`Failed to fetch models: ${response.status} ${response.statusText}`
191-
)
234+
// Provide more specific error messages for local providers
235+
if (isPreflightCheckBypassed) {
236+
throw new Error(
237+
`Failed to connect to local provider at ${provider.base_url}. Please ensure the service is running and accessible.`
238+
)
239+
} else {
240+
throw new Error(
241+
`Failed to fetch models: ${response.status} ${response.statusText}`
242+
)
243+
}
192244
}
193245

194246
const data = await response.json()
@@ -213,6 +265,16 @@ export const fetchModelsFromProvider = async (
213265
}
214266
} catch (error) {
215267
console.error('Error fetching models from provider:', error)
268+
269+
// Provide helpful error messages for common local provider issues
270+
if (isPreflightCheckBypassed && error instanceof Error) {
271+
if (error.message.includes('fetch')) {
272+
throw new Error(
273+
`Cannot connect to ${provider.provider} at ${provider.base_url}. Please check that the service is running and accessible.`
274+
)
275+
}
276+
}
277+
216278
throw error
217279
}
218280
}

0 commit comments

Comments
 (0)