Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions apps/web-roo-code/src/lib/hooks/use-open-router-models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ export const openRouterModelSchema = z.object({
.optional(),
architecture: z
.object({
modality: z.string(),
input_modalities: z.array(z.string()).nullish(),
output_modalities: z.array(z.string()).nullish(),
})
.optional(),
})
Expand All @@ -47,6 +48,10 @@ export const getOpenRouterModels = async (): Promise<OpenRouterModelRecord> => {
}

return result.data.data
.filter((rawModel) => {
// Skip image generation models (models that output images)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a comment here explaining why image generation models are being filtered out. This would help future maintainers understand the business logic behind this decision.

Suggested change
// Skip image generation models (models that output images)
.filter((rawModel) => {
// Skip image generation models (models that output images)
// We only want text-based language models in the model selection UI
return !rawModel.architecture?.output_modalities?.includes("image")
})

return !rawModel.architecture?.output_modalities?.includes("image")
})
.sort((a, b) => a.name.localeCompare(b.name))
.map((rawModel) => ({
...rawModel,
Expand All @@ -57,7 +62,7 @@ export const getOpenRouterModels = async (): Promise<OpenRouterModelRecord> => {
outputPrice: parsePrice(rawModel.pricing?.completion),
description: rawModel.description,
supportsPromptCache: false,
supportsImages: false,
supportsImages: rawModel.architecture?.input_modalities?.includes("image") ?? false,
supportsThinking: false,
tiers: [],
},
Expand Down
55 changes: 52 additions & 3 deletions src/api/providers/fetchers/__tests__/openrouter.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,8 @@ describe("OpenRouter API", () => {
const result = parseOpenRouterModel({
id: "openrouter/horizon-alpha",
model: mockModel,
modality: "text",
inputModality: ["text"],
outputModality: ["text"],
maxTokens: 128000,
})

Expand All @@ -303,7 +304,8 @@ describe("OpenRouter API", () => {
const result = parseOpenRouterModel({
id: "openrouter/horizon-beta",
model: mockModel,
modality: "text",
inputModality: ["text"],
outputModality: ["text"],
maxTokens: 128000,
})

Expand All @@ -326,12 +328,59 @@ describe("OpenRouter API", () => {
const result = parseOpenRouterModel({
id: "openrouter/other-model",
model: mockModel,
modality: "text",
inputModality: ["text"],
outputModality: ["text"],
maxTokens: 64000,
})

expect(result.maxTokens).toBe(64000)
expect(result.contextWindow).toBe(128000)
})

it("filters out image generation models", () => {
const mockImageModel = {
name: "Image Model",
description: "Test image generation model",
context_length: 128000,
max_completion_tokens: 64000,
pricing: {
prompt: "0.000003",
completion: "0.000015",
},
}

const mockTextModel = {
name: "Text Model",
description: "Test text generation model",
context_length: 128000,
max_completion_tokens: 64000,
pricing: {
prompt: "0.000003",
completion: "0.000015",
},
}

// parseOpenRouterModel itself does not filter anything: models whose output modalities
// include "image" are skipped by its callers, getOpenRouterModels and
// getOpenRouterModelEndpoints. Here we only verify that both inputs still parse.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good test coverage for the parseOpenRouterModel function! Since the filtering happens at a higher level in getOpenRouterModels and getOpenRouterModelEndpoints, would it be valuable to add integration tests that verify the filtering behavior actually excludes image generation models from the returned results?

const textResult = parseOpenRouterModel({
id: "test/text-model",
model: mockTextModel,
inputModality: ["text"],
outputModality: ["text"],
maxTokens: 64000,
})

const imageResult = parseOpenRouterModel({
id: "test/image-model",
model: mockImageModel,
inputModality: ["text"],
outputModality: ["image"],
maxTokens: 64000,
})

// Both should parse successfully (filtering happens at a higher level)
expect(textResult.maxTokens).toBe(64000)
expect(imageResult.maxTokens).toBe(64000)
})
})
})
27 changes: 21 additions & 6 deletions src/api/providers/fetchers/openrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ import { parseApiPrice } from "../../../shared/cost"
*/

const openRouterArchitectureSchema = z.object({
modality: z.string().nullish(),
input_modalities: z.array(z.string()).nullish(),
output_modalities: z.array(z.string()).nullish(),
tokenizer: z.string().nullish(),
})

Expand Down Expand Up @@ -110,10 +111,16 @@ export async function getOpenRouterModels(options?: ApiHandlerOptions): Promise<
for (const model of data) {
const { id, architecture, top_provider, supported_parameters = [] } = model

// Skip image generation models (models that output images)
if (architecture?.output_modalities?.includes("image")) {
continue
}

models[id] = parseOpenRouterModel({
id,
model,
modality: architecture?.modality,
inputModality: architecture?.input_modalities,
outputModality: architecture?.output_modalities,
maxTokens: top_provider?.max_completion_tokens,
supportedParameters: supported_parameters,
})
Expand Down Expand Up @@ -149,11 +156,17 @@ export async function getOpenRouterModelEndpoints(

const { id, architecture, endpoints } = data

// Skip image generation models (models that output images)
if (architecture?.output_modalities?.includes("image")) {
return models
}

for (const endpoint of endpoints) {
models[endpoint.tag ?? endpoint.provider_name] = parseOpenRouterModel({
id,
model: endpoint,
modality: architecture?.modality,
inputModality: architecture?.input_modalities,
outputModality: architecture?.output_modalities,
maxTokens: endpoint.max_completion_tokens,
})
}
Expand All @@ -173,13 +186,15 @@ export async function getOpenRouterModelEndpoints(
export const parseOpenRouterModel = ({
id,
model,
modality,
inputModality,
outputModality,
maxTokens,
supportedParameters,
}: {
id: string
model: OpenRouterBaseModel
modality: string | null | undefined
inputModality: string[] | null | undefined
outputModality: string[] | null | undefined
maxTokens: number | null | undefined
supportedParameters?: string[]
}): ModelInfo => {
Expand All @@ -194,7 +209,7 @@ export const parseOpenRouterModel = ({
const modelInfo: ModelInfo = {
maxTokens: maxTokens || Math.ceil(model.context_length * 0.2),
contextWindow: model.context_length,
supportsImages: modality?.includes("image") ?? false,
supportsImages: inputModality?.includes("image") ?? false,
supportsPromptCache,
inputPrice: parseApiPrice(model.pricing?.prompt),
outputPrice: parseApiPrice(model.pricing?.completion),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ const openRouterEndpointsSchema = z.object({
description: z.string().optional(),
architecture: z
.object({
modality: z.string().nullish(),
input_modalities: z.array(z.string()).nullish(),
output_modalities: z.array(z.string()).nullish(),
tokenizer: z.string().nullish(),
})
.nullish(),
Expand Down Expand Up @@ -56,6 +57,11 @@ async function getOpenRouterProvidersForModel(modelId: string) {

const { description, architecture, endpoints } = result.data.data

// Skip image generation models (models that output images)
if (architecture?.output_modalities?.includes("image")) {
return models
}

for (const endpoint of endpoints) {
const providerName = endpoint.tag ?? endpoint.name
const inputPrice = parseApiPrice(endpoint.pricing?.prompt)
Expand All @@ -66,7 +72,7 @@ async function getOpenRouterProvidersForModel(modelId: string) {
const modelInfo: OpenRouterModelProvider = {
maxTokens: endpoint.max_completion_tokens || endpoint.context_length,
contextWindow: endpoint.context_length,
supportsImages: architecture?.modality?.includes("image"),
supportsImages: architecture?.input_modalities?.includes("image") ?? false,
supportsPromptCache: typeof cacheReadsPrice !== "undefined",
cacheReadsPrice,
cacheWritesPrice,
Expand Down
Loading