Merged
13 changes: 7 additions & 6 deletions src/api/providers/base-openai-compatible-provider.ts
@@ -10,6 +10,7 @@ import { convertToOpenAiMessages } from "../transform/openai-format"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import { handleOpenAIError } from "./utils/openai-error-handler"

type BaseOpenAiCompatibleProviderOptions<ModelName extends string> = ApiHandlerOptions & {
providerName: string
@@ -86,7 +87,11 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
params.temperature = this.options.modelTemperature
}

return this.client.chat.completions.create(params, requestOptions)
try {
return this.client.chat.completions.create(params, requestOptions)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}
}

override async *createMessage(
@@ -127,11 +132,7 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>

return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`${this.providerName} completion error: ${error.message}`)
}

throw error
throw handleOpenAIError(error, this.providerName)
}
}

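The shared `handleOpenAIError` helper these hunks import lives in `src/api/providers/utils/openai-error-handler.ts`, which is not included in this diff. Judging only from the call sites above, which replace the old per-provider `"${providerName} completion error: ..."` wrapping, a minimal sketch of the expected shape could look like this (illustrative assumption, not the actual implementation):

```typescript
// Hypothetical sketch inferred from the call sites in this PR; the real helper
// may add richer handling (status codes, invalid-API-key detection, etc.).
export function handleOpenAIError(error: unknown, providerName: string): Error {
	if (error instanceof Error) {
		// Preserve the message format the handlers used before this change.
		return new Error(`${providerName} completion error: ${error.message}`)
	}
	return new Error(`${providerName} completion error: ${String(error)}`)
}
```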
15 changes: 9 additions & 6 deletions src/api/providers/huggingface.ts
@@ -8,11 +8,13 @@ import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from ".
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import { getHuggingFaceModels, getCachedHuggingFaceModels } from "./fetchers/huggingface"
import { handleOpenAIError } from "./utils/openai-error-handler"

export class HuggingFaceHandler extends BaseProvider implements SingleCompletionHandler {
private client: OpenAI
private options: ApiHandlerOptions
private modelCache: ModelRecord | null = null
private readonly providerName = "HuggingFace"

constructor(options: ApiHandlerOptions) {
super()
@@ -64,7 +66,12 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion
params.max_tokens = this.options.modelMaxTokens
}

const stream = await this.client.chat.completions.create(params)
let stream
try {
stream = await this.client.chat.completions.create(params)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
@@ -97,11 +104,7 @@ export class HuggingFaceHandler extends BaseProvider implements SingleCompletion

return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`Hugging Face completion error: ${error.message}`)
}

throw error
throw handleOpenAIError(error, this.providerName)
}
}

21 changes: 18 additions & 3 deletions src/api/providers/lm-studio.ts
@@ -15,18 +15,23 @@ import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { getModels, getModelsFromCache } from "./fetchers/modelCache"
import { getApiRequestTimeout } from "./utils/timeout-config"
import { handleOpenAIError } from "./utils/openai-error-handler"

export class LmStudioHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
private readonly providerName = "LM Studio"

constructor(options: ApiHandlerOptions) {
super()
this.options = options

// LM Studio uses "noop" as a placeholder API key
const apiKey = "noop"
Author

Is this intentional? The comment on line 28 suggests we should validate a real key if provided, but we're always validating the hardcoded "noop" string. Should we check for a real API key from options.lmStudioApiKey first before defaulting to "noop"?

this.client = new OpenAI({
baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
apiKey: "noop",
apiKey: apiKey,
timeout: getApiRequestTimeout(),
})
}
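If the reviewer's suggestion above were adopted, one possible shape of the constructor change would be the following sketch. It is purely illustrative and assumes an optional `lmStudioApiKey` field is (or would be made) available on `ApiHandlerOptions`, which this diff does not confirm:

```typescript
// Sketch of the reviewer's suggestion — not part of this PR.
// Fall back to the "noop" placeholder only when no real key was configured.
const apiKey = this.options.lmStudioApiKey || "noop"

this.client = new OpenAI({
	baseURL: (this.options.lmStudioBaseUrl || "http://localhost:1234") + "/v1",
	apiKey,
	timeout: getApiRequestTimeout(),
})
```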
@@ -88,7 +93,12 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
params.draft_model = this.options.lmStudioDraftModelId
}

const results = await this.client.chat.completions.create(params)
let results
try {
results = await this.client.chat.completions.create(params)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

const matcher = new XmlMatcher(
"think",
@@ -164,7 +174,12 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan
params.draft_model = this.options.lmStudioDraftModelId
}

const response = await this.client.chat.completions.create(params)
let response
try {
response = await this.client.chat.completions.create(params)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}
return response.choices[0]?.message.content || ""
} catch (error) {
throw new Error(
42 changes: 27 additions & 15 deletions src/api/providers/ollama.ts
@@ -14,12 +14,14 @@ import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { getApiRequestTimeout } from "./utils/timeout-config"
import { handleOpenAIError } from "./utils/openai-error-handler"

type CompletionUsage = OpenAI.Chat.Completions.ChatCompletionChunk["usage"]

export class OllamaHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
private readonly providerName = "Ollama"

constructor(options: ApiHandlerOptions) {
super()
@@ -54,13 +56,18 @@ export class OllamaHandler extends BaseProvider implements SingleCompletionHandl
...(useR1Format ? convertToR1Format(messages) : convertToOpenAiMessages(messages)),
]

const stream = await this.client.chat.completions.create({
model: this.getModel().id,
messages: openAiMessages,
temperature: this.options.modelTemperature ?? 0,
stream: true,
stream_options: { include_usage: true },
})
let stream
try {
stream = await this.client.chat.completions.create({
model: this.getModel().id,
messages: openAiMessages,
temperature: this.options.modelTemperature ?? 0,
stream: true,
stream_options: { include_usage: true },
})
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}
const matcher = new XmlMatcher(
"think",
(chunk) =>
@@ -106,14 +113,19 @@ export class OllamaHandler extends BaseProvider implements SingleCompletionHandl
try {
const modelId = this.getModel().id
const useR1Format = modelId.toLowerCase().includes("deepseek-r1")
const response = await this.client.chat.completions.create({
model: this.getModel().id,
messages: useR1Format
? convertToR1Format([{ role: "user", content: prompt }])
: [{ role: "user", content: prompt }],
temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
stream: false,
})
let response
try {
response = await this.client.chat.completions.create({
model: this.getModel().id,
messages: useR1Format
? convertToR1Format([{ role: "user", content: prompt }])
: [{ role: "user", content: prompt }],
temperature: this.options.modelTemperature ?? (useR1Format ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
stream: false,
})
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}
return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
69 changes: 48 additions & 21 deletions src/api/providers/openai.ts
@@ -24,13 +24,15 @@ import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { getApiRequestTimeout } from "./utils/timeout-config"
import { handleOpenAIError } from "./utils/openai-error-handler"

// TODO: Rename this to OpenAICompatibleHandler. Also, I think the
// `OpenAINativeHandler` can subclass from this, since it's obviously
// compatible with the OpenAI API. We can also rename it to `OpenAIHandler`.
export class OpenAiHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
private readonly providerName = "OpenAI"

constructor(options: ApiHandlerOptions) {
super()
@@ -174,10 +176,15 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
// Add max_tokens if needed
this.addMaxTokensIfNeeded(requestOptions, modelInfo)

const stream = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
let stream
try {
stream = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

const matcher = new XmlMatcher(
"think",
@@ -236,10 +243,15 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
// Add max_tokens if needed
this.addMaxTokensIfNeeded(requestOptions, modelInfo)

const response = await this.client.chat.completions.create(
requestOptions,
this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
let response
try {
response = await this.client.chat.completions.create(
requestOptions,
this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

yield {
type: "text",
@@ -281,15 +293,20 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
// Add max_tokens if needed
this.addMaxTokensIfNeeded(requestOptions, modelInfo)

const response = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
let response
try {
response = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

return response.choices[0]?.message.content || ""
} catch (error) {
if (error instanceof Error) {
throw new Error(`OpenAI completion error: ${error.message}`)
throw new Error(`${this.providerName} completion error: ${error.message}`)
}

throw error
@@ -327,10 +344,15 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
// This allows O3 models to limit response length when includeMaxTokens is enabled
this.addMaxTokensIfNeeded(requestOptions, modelInfo)

const stream = await this.client.chat.completions.create(
requestOptions,
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
let stream
try {
stream = await this.client.chat.completions.create(
requestOptions,
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

yield* this.handleStreamResponse(stream)
} else {
@@ -352,10 +374,15 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
// This allows O3 models to limit response length when includeMaxTokens is enabled
this.addMaxTokensIfNeeded(requestOptions, modelInfo)

const response = await this.client.chat.completions.create(
requestOptions,
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
let response
try {
response = await this.client.chat.completions.create(
requestOptions,
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

yield {
type: "text",
16 changes: 14 additions & 2 deletions src/api/providers/openrouter.ts
@@ -25,6 +25,7 @@ import { getModelEndpoints } from "./fetchers/modelEndpointCache"
import { DEFAULT_HEADERS } from "./constants"
import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler } from "../index"
import { handleOpenAIError } from "./utils/openai-error-handler"

// Image generation types
interface ImageGenerationResponse {
@@ -85,6 +86,7 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
private client: OpenAI
protected models: ModelRecord = {}
protected endpoints: ModelRecord = {}
private readonly providerName = "OpenRouter"

constructor(options: ApiHandlerOptions) {
super()
@@ -161,7 +163,12 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
...(reasoning && { reasoning }),
}

const stream = await this.client.chat.completions.create(completionParams)
let stream
try {
stream = await this.client.chat.completions.create(completionParams)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

let lastUsage: CompletionUsage | undefined = undefined

@@ -259,7 +266,12 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH
...(reasoning && { reasoning }),
}

const response = await this.client.chat.completions.create(completionParams)
let response
try {
response = await this.client.chat.completions.create(completionParams)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}

if ("error" in response) {
const error = response.error as { message?: string; code?: number }