-
Notifications
You must be signed in to change notification settings - Fork 512
/
Copy pathmodels.ts
92 lines (82 loc) · 2.65 KB
/
models.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import { createAnthropic } from '@ai-sdk/anthropic'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createVertex } from '@ai-sdk/google-vertex'
import { createMistral } from '@ai-sdk/mistral'
import { createOpenAI } from '@ai-sdk/openai'
import { createOllama } from 'ollama-ai-provider'
/**
 * Describes a selectable language model.
 */
export type LLMModel = {
  /** Key used by `getModelClient` to choose the provider factory. */
  providerId: string
  /** Provider-specific model identifier handed to the provider factory. */
  id: string
  /** Human-readable model name (presumably for display — not read in this module). */
  name: string
  /** Human-readable provider label (presumably for display — not read in this module). */
  provider: string
}
/**
 * Optional per-request model settings.
 *
 * Only `apiKey` and `baseURL` are consumed by `getModelClient`. The remaining
 * fields are not read in this module.
 * NOTE(review): they are presumably forwarded to generation calls elsewhere;
 * confirm at the call sites.
 */
export type LLMModelConfig = {
  // Connection overrides (read by getModelClient).
  /** API key override for the selected provider. */
  apiKey?: string
  /** Endpoint override for the selected provider. */
  baseURL?: string

  /** Model id override (not read in this module). */
  model?: string

  // Sampling / generation parameters (not read in this module).
  temperature?: number
  topP?: number
  topK?: number
  frequencyPenalty?: number
  presencePenalty?: number
  maxTokens?: number
}
/**
 * Builds an AI SDK language-model client for `model`. The provider factory is
 * chosen by `model.providerId`.
 *
 * Connection overrides come from `config`:
 * - anthropic, openai, google, mistral: `apiKey` and `baseURL` are passed
 *   through as-is.
 * - groq, togetherai, fireworks, xai, deepseek: these use the OpenAI client.
 *   A missing or empty `apiKey` falls back to the provider's environment
 *   variable. A missing or empty `baseURL` falls back to the provider's public
 *   endpoint.
 * - ollama: only `baseURL` is used.
 * - vertex: `apiKey` and `baseURL` are ignored. Credentials are parsed from the
 *   `GOOGLE_VERTEX_CREDENTIALS` environment variable, or `{}` when it is unset.
 *
 * @param model - `id` is the provider-specific model name; `providerId` selects the factory.
 * @param config - Only `apiKey` and `baseURL` are read here.
 * @returns The model client produced by the selected provider factory.
 * @throws Error `Unsupported provider: <providerId>` when `providerId` is not one of the keys below.
 * @throws SyntaxError (vertex only) when `GOOGLE_VERTEX_CREDENTIALS` is set but is not valid JSON.
 */
export function getModelClient(model: LLMModel, config: LLMModelConfig) {
  const { id: modelNameString, providerId } = model
  const { apiKey, baseURL } = config

  // Each entry is a thunk, so only the selected provider's factory is invoked.
  const providerConfigs = {
    anthropic: () => createAnthropic({ apiKey, baseURL })(modelNameString),
    openai: () => createOpenAI({ apiKey, baseURL })(modelNameString),
    google: () =>
      createGoogleGenerativeAI({ apiKey, baseURL })(modelNameString),
    mistral: () => createMistral({ apiKey, baseURL })(modelNameString),
    // `||` (not `??`) so an empty-string override also falls back to the env var / default URL.
    groq: () =>
      createOpenAI({
        apiKey: apiKey || process.env.GROQ_API_KEY,
        baseURL: baseURL || 'https://api.groq.com/openai/v1',
      })(modelNameString),
    togetherai: () =>
      createOpenAI({
        apiKey: apiKey || process.env.TOGETHER_API_KEY,
        baseURL: baseURL || 'https://api.together.xyz/v1',
      })(modelNameString),
    ollama: () => createOllama({ baseURL })(modelNameString),
    fireworks: () =>
      createOpenAI({
        apiKey: apiKey || process.env.FIREWORKS_API_KEY,
        baseURL: baseURL || 'https://api.fireworks.ai/inference/v1',
      })(modelNameString),
    vertex: () =>
      createVertex({
        googleAuthOptions: {
          // Throws SyntaxError if the env var holds malformed JSON.
          credentials: JSON.parse(
            process.env.GOOGLE_VERTEX_CREDENTIALS || '{}',
          ),
        },
      })(modelNameString),
    xai: () =>
      createOpenAI({
        apiKey: apiKey || process.env.XAI_API_KEY,
        baseURL: baseURL || 'https://api.x.ai/v1',
      })(modelNameString),
    deepseek: () =>
      createOpenAI({
        apiKey: apiKey || process.env.DEEPSEEK_API_KEY,
        baseURL: baseURL || 'https://api.deepseek.com/v1',
      })(modelNameString),
  }

  // Own-property check. A bare `providerConfigs[providerId]` lookup also finds
  // inherited Object.prototype members ('toString', 'constructor', '__proto__', ...).
  // An unknown id would then slip past the guard and either return garbage or
  // fail with an opaque TypeError.
  if (!Object.prototype.hasOwnProperty.call(providerConfigs, providerId)) {
    throw new Error(`Unsupported provider: ${providerId}`)
  }
  const createClient =
    providerConfigs[providerId as keyof typeof providerConfigs]
  return createClient()
}
/**
 * Returns the default generation mode for `model`.
 *
 * Fireworks models are forced to `'json'` as a provider-specific workaround.
 * Every other provider gets `'auto'`.
 * NOTE(review): presumably this value is used as the AI SDK `mode` option for
 * structured output; confirm at the call site.
 *
 * @param model - Only `providerId` is inspected.
 * @returns `'json'` for the `fireworks` provider, otherwise `'auto'`.
 */
export function getDefaultMode(model: LLMModel): 'json' | 'auto' {
  // monkey patch fireworks: force JSON mode for this provider
  if (model.providerId === 'fireworks') {
    return 'json'
  }
  return 'auto'
}