diff --git a/.changeset/public-radios-pull.md b/.changeset/public-radios-pull.md new file mode 100644 index 00000000000..bf2cf98d216 --- /dev/null +++ b/.changeset/public-radios-pull.md @@ -0,0 +1,5 @@ +--- +"@kilocode/cli": patch +--- + +Improve "/model list" command with pagination, filters and sorting diff --git a/cli/src/commands/__tests__/helpers/mockContext.ts b/cli/src/commands/__tests__/helpers/mockContext.ts index 4ca3c836052..254e6f5240c 100644 --- a/cli/src/commands/__tests__/helpers/mockContext.ts +++ b/cli/src/commands/__tests__/helpers/mockContext.ts @@ -71,6 +71,14 @@ export function createMockContext(overrides: Partial = {}): Comm previousTaskHistoryPage: vi.fn().mockResolvedValue(null), sendWebviewMessage: vi.fn().mockResolvedValue(undefined), chatMessages: [], + modelListPageIndex: 0, + modelListFilters: { + sort: "preferred", + capabilities: [], + }, + updateModelListFilters: vi.fn(), + changeModelListPage: vi.fn(), + resetModelListState: vi.fn(), } return { diff --git a/cli/src/commands/__tests__/model.test.ts b/cli/src/commands/__tests__/model.test.ts index 6aa637e3317..190c00a0c79 100644 --- a/cli/src/commands/__tests__/model.test.ts +++ b/cli/src/commands/__tests__/model.test.ts @@ -4,9 +4,11 @@ import { describe, it, expect, vi, beforeEach } from "vitest" import { modelCommand } from "../model.js" +import { createMockContext } from "./helpers/mockContext.js" import type { CommandContext } from "../core/types.js" import type { RouterModels } from "../../types/messages.js" import type { ProviderConfig } from "../../config/types.js" +import type { ModelRecord } from "../../constants/providers/models.js" describe("/model command", () => { let mockContext: CommandContext @@ -51,26 +53,35 @@ describe("/model command", () => { apiKey: "test-key", } + // Create many models for pagination testing + const createManyModels = (count: number): ModelRecord => { + const models: ModelRecord = {} + for (let i = 1; i <= count; i++) { + models[`model-${i}`] = { + contextWindow: 100000 + i * 1000, + supportsPromptCache: i % 2 === 0, + supportsImages: i % 3 === 0, + inputPrice: i * 0.5, + outputPrice: i * 1.0, + displayName: `Model ${i}`, + } + } + return models + } + beforeEach(() => { addMessageMock = vi.fn() updateProviderModelMock = vi.fn().mockResolvedValue(undefined) - mockContext = { + mockContext = createMockContext({ input: "/model", args: [], - options: {}, - sendMessage: vi.fn().mockResolvedValue(undefined), - addMessage: addMessageMock, - clearMessages: vi.fn(), - clearTask: vi.fn().mockResolvedValue(undefined), - setMode: vi.fn(), - exit: vi.fn(), routerModels: mockRouterModels, currentProvider: mockProvider, kilocodeDefaultModel: "", updateProviderModel: updateProviderModelMock, - refreshRouterModels: vi.fn().mockResolvedValue(undefined), - } + addMessage: addMessageMock, + }) }) describe("Command metadata", () => { @@ -107,7 +118,7 @@ describe("/model command", () => { it("should have arguments defined", () => { expect(modelCommand.arguments).toBeDefined() - expect(modelCommand.arguments).toHaveLength(2) + expect(modelCommand.arguments).toHaveLength(3) }) it("should have subcommand argument with values", () => { @@ -313,32 +324,47 @@ describe("/model command", () => { }) it("should filter models when filter is provided", async () => { + // Mock updateModelListFilters to actually update the filters + const updateFiltersMock = vi.fn((filters) => { + mockContext.modelListFilters = { ...mockContext.modelListFilters, ...filters } + }) + mockContext.updateModelListFilters = updateFiltersMock mockContext.args = ["list", "gpt-4"] await modelCommand.handler(mockContext) + // Verify the filter was persisted + expect(updateFiltersMock).toHaveBeenCalledWith({ search: "gpt-4" }) + const message = addMessageMock.mock.calls[0][0] - expect(message.content).toContain("Filtered by") + expect(message.content).toContain('Search: "gpt-4"') expect(message.content).toContain("gpt-4") expect(message.content).not.toContain("gpt-3.5-turbo") }) it("should show message when no models match filter", async () => { + // Mock updateModelListFilters to actually update the filters + const updateFiltersMock = vi.fn((filters) => { + mockContext.modelListFilters = { ...mockContext.modelListFilters, ...filters } + }) + mockContext.updateModelListFilters = updateFiltersMock mockContext.args = ["list", "nonexistent"] await modelCommand.handler(mockContext) + // Verify the filter was persisted + expect(updateFiltersMock).toHaveBeenCalledWith({ search: "nonexistent" }) + const message = addMessageMock.mock.calls[0][0] expect(message.type).toBe("system") expect(message.content).toContain("No models found") }) - it("should display model count", async () => { + it("should display model count with pagination", async () => { await modelCommand.handler(mockContext) const message = addMessageMock.mock.calls[0][0] - expect(message.content).toContain("Total:") - expect(message.content).toContain("2 models") + expect(message.content).toContain("Showing 1-2 of 2") }) it("should show error when no provider configured", async () => { @@ -429,4 +455,204 @@ describe("/model command", () => { } }) }) + + describe("Model list pagination", () => { + beforeEach(() => { + mockContext.routerModels = { + ...mockRouterModels, + openrouter: createManyModels(25), + } + mockContext.args = ["list"] + }) + + it("should paginate results with 10 items per page", async () => { + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.content).toContain("Showing 1-10 of 25") + expect(message.content).toContain("Page 1/3") + }) + + it("should navigate to specific page", async () => { + mockContext.args = ["list", "page", "2"] + + await modelCommand.handler(mockContext) + + expect(mockContext.changeModelListPage).toHaveBeenCalledWith(1) + }) + + it("should go to next page", async () => { + mockContext.args = ["list", "next"] + mockContext.modelListPageIndex = 0 + + await modelCommand.handler(mockContext) + + expect(mockContext.changeModelListPage).toHaveBeenCalledWith(1) + }) + + it("should go to previous page", async () => { + mockContext.args = ["list", "prev"] + mockContext.modelListPageIndex = 1 + + await modelCommand.handler(mockContext) + + expect(mockContext.changeModelListPage).toHaveBeenCalledWith(0) + }) + + it("should show error when already on first page", async () => { + mockContext.args = ["list", "prev"] + mockContext.modelListPageIndex = 0 + + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.type).toBe("system") + expect(message.content).toContain("Already on the first page") + }) + + it("should show error when already on last page", async () => { + mockContext.args = ["list", "next"] + mockContext.modelListPageIndex = 2 + + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.type).toBe("system") + expect(message.content).toContain("Already on the last page") + }) + + it("should validate page number", async () => { + mockContext.args = ["list", "page", "invalid"] + + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.type).toBe("error") + expect(message.content).toContain("Invalid page number") + }) + + it("should validate page number is within range", async () => { + mockContext.args = ["list", "page", "10"] + + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.type).toBe("error") + expect(message.content).toContain("Must be between 1 and") + }) + }) + + describe("Model list sorting", () => { + beforeEach(() => { + mockContext.args = ["list"] + }) + + it("should sort by name", async () => { + mockContext.args = ["list", "sort", "name"] + + await modelCommand.handler(mockContext) + + expect(mockContext.updateModelListFilters).toHaveBeenCalledWith({ sort: "name" }) + }) + + it("should sort by context window", async () => { + mockContext.args = ["list", "sort", "context"] + + await modelCommand.handler(mockContext) + + expect(mockContext.updateModelListFilters).toHaveBeenCalledWith({ sort: "context" }) + }) + + it("should sort by price", async () => { + mockContext.args = ["list", "sort", "price"] + + await modelCommand.handler(mockContext) + + expect(mockContext.updateModelListFilters).toHaveBeenCalledWith({ sort: "price" }) + }) + + it("should show error for invalid sort option", async () => { + mockContext.args = ["list", "sort", "invalid"] + + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.type).toBe("error") + expect(message.content).toContain("Invalid sort option") + }) + + it("should show error when sort option is missing", async () => { + mockContext.args = ["list", "sort"] + + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.type).toBe("error") + expect(message.content).toContain("Usage: /model list sort") + }) + }) + + describe("Model list filtering", () => { + beforeEach(() => { + mockContext.args = ["list"] + }) + + it("should filter by images capability", async () => { + mockContext.args = ["list", "filter", "images"] + + await modelCommand.handler(mockContext) + + expect(mockContext.updateModelListFilters).toHaveBeenCalledWith({ + capabilities: ["images"], + }) + }) + + it("should filter by cache capability", async () => { + mockContext.args = ["list", "filter", "cache"] + + await modelCommand.handler(mockContext) + + expect(mockContext.updateModelListFilters).toHaveBeenCalledWith({ + capabilities: ["cache"], + }) + }) + + it("should toggle filter off when already active", async () => { + mockContext.args = ["list", "filter", "images"] + mockContext.modelListFilters = { + sort: "preferred", + capabilities: ["images"], + } + + await modelCommand.handler(mockContext) + + expect(mockContext.updateModelListFilters).toHaveBeenCalledWith({ + capabilities: [], + }) + }) + + it("should clear all filters", async () => { + mockContext.args = ["list", "filter", "all"] + mockContext.modelListFilters = { + sort: "preferred", + capabilities: ["images", "cache"], + } + + await modelCommand.handler(mockContext) + + expect(mockContext.updateModelListFilters).toHaveBeenCalledWith({ + capabilities: [], + }) + }) + + it("should show error for invalid filter option", async () => { + mockContext.args = ["list", "filter", "invalid"] + + await modelCommand.handler(mockContext) + + const message = addMessageMock.mock.calls[0][0] + expect(message.type).toBe("error") + expect(message.content).toContain("Invalid filter option") + }) + }) }) diff --git a/cli/src/commands/core/types.ts b/cli/src/commands/core/types.ts index 67d217cce5b..7c4ebf5f3cc 100644 --- a/cli/src/commands/core/types.ts +++ b/cli/src/commands/core/types.ts @@ -7,6 +7,7 @@ import type { CliMessage } from "../../types/cli.js" import type { CLIConfig, ProviderConfig } from "../../config/types.js" import type { ProfileData, BalanceData } from "../../state/atoms/profile.js" import type { TaskHistoryData, TaskHistoryFilters } from "../../state/atoms/taskHistory.js" +import type { ModelListFilters } from "../../state/atoms/modelList.js" export interface Command { name: string @@ -76,6 +77,12 @@ export interface CommandContext { sendWebviewMessage: (message: WebviewMessage) => Promise refreshTerminal: () => Promise chatMessages: ExtensionMessage[] + // Model list context + modelListPageIndex: number + modelListFilters: ModelListFilters + updateModelListFilters: (filters: Partial) => void + changeModelListPage: (pageIndex: number) => void + resetModelListState: () => void } export type CommandHandler = (context: CommandContext) => Promise | void diff --git a/cli/src/commands/model.ts b/cli/src/commands/model.ts index 42cf2a38309..13478f631ac 100644 --- a/cli/src/commands/model.ts +++ b/cli/src/commands/model.ts @@ -3,7 +3,9 @@ */ import type { Command, ArgumentProviderContext, CommandContext } from "./core/types.js" -import type { ModelInfo } from "../constants/providers/models.js" +import type { ModelRecord } from "../constants/providers/models.js" +import type { ProviderConfig } from "../config/types.js" +import type { RouterModels } from "../types/messages.js" import { getModelsByProvider, getCurrentModelId, @@ -13,6 +15,22 @@ import { formatPrice, prettyModelName, } from "../constants/providers/models.js" +import { MODEL_LIST_PAGE_SIZE, type ModelListFilters } from "../state/atoms/modelList.js" + +/** + * Sort options for model list + */ +const MODEL_SORT_OPTIONS: Record = { + name: "name", + context: "context", + price: "price", + preferred: "preferred", +} + +/** + * Filter options for model list + */ +const MODEL_FILTER_OPTIONS = ["images", "cache", "reasoning", "free", "all"] /** * Ensure router models are loaded for the current provider @@ -73,6 +91,137 @@ async function ensureRouterModels(context: CommandContext): Promise { return true } +/** + * Sort models by different criteria + */ +function sortModels(models: ModelRecord, sortBy: string): string[] { + const modelIds = Object.keys(models) + + switch (sortBy) { + case "name": + return modelIds.sort((a, b) => a.localeCompare(b)) + case "context": + return modelIds.sort((a, b) => (models[b]?.contextWindow || 0) - (models[a]?.contextWindow || 0)) + case "price": + return modelIds.sort((a, b) => { + const priceA = models[a]?.inputPrice ?? Infinity + const priceB = models[b]?.inputPrice ?? Infinity + return priceA - priceB + }) + case "preferred": + default: + return sortModelsByPreference(models) + } +} + +/** + * Filter models by capabilities + */ +function filterModelsByCapabilities(models: ModelRecord, capabilities: string[]): string[] { + if (capabilities.length === 0) { + return Object.keys(models) + } + + return Object.keys(models).filter((id) => { + const model = models[id] + if (!model) return false + + return capabilities.every((cap) => { + switch (cap) { + case "images": + return model.supportsImages === true + case "cache": + return model.supportsPromptCache === true + case "reasoning": + return ( + model.supportsReasoningEffort === true || + model.supportsReasoningBudget === true || + (Array.isArray(model.supportsReasoningEffort) && model.supportsReasoningEffort.length > 0) + ) + case "free": + // Consider a model free if it has isFree flag OR if both input and output prices are 0 + return ( + model.isFree === true || + (model.inputPrice !== undefined && + model.outputPrice !== undefined && + model.inputPrice === 0 && + model.outputPrice === 0) + ) + default: + return true + } + }) + }) +} + +/** + * Paginate model list + */ +function paginateModels( + modelIds: string[], + pageIndex: number, + pageSize: number = MODEL_LIST_PAGE_SIZE, +): { + pageIds: string[] + pageCount: number + totalCount: number +} { + const totalCount = modelIds.length + if (totalCount === 0) { + return { pageIds: [], pageCount: 0, totalCount: 0 } + } + const pageCount = Math.ceil(totalCount / pageSize) + const start = pageIndex * pageSize + const pageIds = modelIds.slice(start, start + pageSize) + + return { pageIds, pageCount, totalCount } +} +/** + * Get filtered model IDs and page count based on current filters + * This helper avoids duplication in pagination functions + */ +function getFilteredModelsWithPageCount(params: { + currentProvider: ProviderConfig + routerModels: RouterModels | null + kilocodeDefaultModel: string + filters: { + search?: string | undefined + capabilities: ("images" | "cache" | "reasoning" | "free")[] + } +}): { + models: ModelRecord + modelIds: string[] + pageCount: number + totalCount: number +} { + const { currentProvider, routerModels, kilocodeDefaultModel, filters } = params + + const { models } = getModelsByProvider({ + provider: currentProvider.provider, + routerModels, + kilocodeDefaultModel, + }) + + // Apply search filter + let modelIds = filters.search ? fuzzyFilterModels(models, filters.search) : Object.keys(models) + + // Apply capability filters + modelIds = filterModelsByCapabilities( + modelIds.reduce((acc, id) => { + const model = models[id] + if (model) { + acc[id] = model + } + return acc + }, {} as ModelRecord), + filters.capabilities, + ) + + const { pageCount, totalCount } = paginateModels(modelIds, 0) + + return { models, modelIds, pageCount, totalCount } +} + /** * Show current model information */ @@ -316,10 +465,19 @@ async function selectModel(context: CommandContext, modelId: string): Promise { - const { currentProvider, routerModels, kilocodeDefaultModel, addMessage } = context +async function listModels( + context: CommandContext, + pageIndexOverride?: number, + filtersOverride?: { sort?: string; capabilities?: string[]; search?: string | undefined }, +): Promise { + const { currentProvider, routerModels, kilocodeDefaultModel, addMessage, modelListPageIndex, modelListFilters } = + context + + // Use overrides if provided, otherwise use context values + const effectivePageIndex = pageIndexOverride !== undefined ? pageIndexOverride : modelListPageIndex + const effectiveFilters = filtersOverride ? { ...modelListFilters, ...filtersOverride } : modelListFilters if (!currentProvider) { addMessage({ @@ -349,42 +507,73 @@ async function listModels(context: CommandContext, filter?: string): Promise { - const model = models[id] - if (model) { - acc[id] = model - } - return acc - }, - {} as Record, - ), + // Apply search filter from stored filters + const effectiveSearch = effectiveFilters.search + let modelIds = effectiveSearch ? fuzzyFilterModels(models, effectiveSearch) : Object.keys(models) + + // Apply capability filters + modelIds = filterModelsByCapabilities( + modelIds.reduce((acc, id) => { + const model = models[id] + if (model) { + acc[id] = model + } + return acc + }, {} as ModelRecord), + effectiveFilters.capabilities, + ) + + // Apply sorting + modelIds = sortModels( + modelIds.reduce((acc, id) => { + const model = models[id] + if (model) { + acc[id] = model + } + return acc + }, {} as ModelRecord), + effectiveFilters.sort, ) if (modelIds.length === 0) { addMessage({ id: Date.now().toString(), type: "system", - content: filter ? `No models found matching "${filter}".` : "No models available for this provider.", + content: effectiveSearch + ? `No models found matching "${effectiveSearch}".` + : "No models available for this provider.", ts: Date.now(), }) return } + // Paginate results + const { pageIds, pageCount, totalCount } = paginateModels(modelIds, effectivePageIndex) + const providerName = currentProvider.provider .split("-") .map((word: string) => word.charAt(0).toUpperCase() + word.slice(1)) .join(" ") - let content = `**Available Models (${providerName})**` - if (filter) { - content += ` - Filtered by "${filter}"` + let content = `**Available Models (${providerName})** - Page ${effectivePageIndex + 1}/${pageCount}\n` + + // Show active filters + const filterParts: string[] = [] + if (effectiveFilters.sort !== "preferred") { + filterParts.push(`Sort: ${effectiveFilters.sort}`) + } + if (effectiveFilters.capabilities.length > 0) { + filterParts.push(`Filter: ${effectiveFilters.capabilities.join(", ")}`) } - content += `:\n\n` + if (effectiveSearch) { + filterParts.push(`Search: "${effectiveSearch}"`) + } + if (filterParts.length > 0) { + content += filterParts.join(" | ") + "\n" + } + content += `\n` - for (const modelId of modelIds) { + for (const modelId of pageIds) { const model = models[modelId] if (!model) continue @@ -407,9 +596,22 @@ async function listModels(context: CommandContext, filter?: string): Promise\` to switch models\n` - content += `Use \`/model info \` for detailed information\n` + const start = effectivePageIndex * MODEL_LIST_PAGE_SIZE + 1 + const end = Math.min((effectivePageIndex + 1) * MODEL_LIST_PAGE_SIZE, totalCount) + content += `**Showing ${start}-${end} of ${totalCount} model${totalCount !== 1 ? "s" : ""}**\n\n` + + // Show navigation hints + if (pageCount > 1) { + if (effectivePageIndex < pageCount - 1) { + content += `Use \`/model list next\` for next page\n` + } + if (effectivePageIndex > 0) { + content += `Use \`/model list prev\` for previous page\n` + } + content += `Use \`/model list page \` to go to a specific page\n` + } + content += `Use \`/model list sort