diff --git a/src/api/index.ts b/src/api/index.ts index 7bbac561cc1..e36a4cb3782 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -55,6 +55,8 @@ import { // kilocode_change start import { KilocodeOpenrouterHandler } from "./providers/kilocode-openrouter" import { InceptionLabsHandler } from "./providers/inception" +import type { FimHandler } from "./providers/kilocode/FimHandler" // kilocode_change +export type { FimHandler } from "./providers/kilocode/FimHandler" // kilocode_change end import { NativeOllamaHandler } from "./providers/native-ollama" @@ -137,6 +139,14 @@ export interface ApiHandler { */ countTokens(content: Array): Promise + // kilocode_change start + /** + * Returns a FimHandler if the provider supports FIM (Fill-In-the-Middle) completions, + * or undefined if FIM is not supported. + */ + fimSupport?: () => FimHandler | undefined + // kilocode_change end + contextWindow?: number // kilocode_change: Add contextWindow property for virtual quota fallback } diff --git a/src/api/providers/__tests__/kilocode-openrouter.spec.ts b/src/api/providers/__tests__/kilocode-openrouter.spec.ts index c1138019118..5418b0b7d1e 100644 --- a/src/api/providers/__tests__/kilocode-openrouter.spec.ts +++ b/src/api/providers/__tests__/kilocode-openrouter.spec.ts @@ -241,86 +241,101 @@ describe("KilocodeOpenrouterHandler", () => { }) describe("FIM support", () => { - it("supportsFim returns true for codestral models", () => { - const handler = new KilocodeOpenrouterHandler({ - ...mockOptions, - kilocodeModel: "mistral/codestral-latest", + describe("fimSupport", () => { + it("returns FimHandler for codestral models", () => { + const handler = new KilocodeOpenrouterHandler({ + ...mockOptions, + kilocodeModel: "mistral/codestral-latest", + }) + + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() + expect(typeof fimHandler?.streamFim).toBe("function") + expect(typeof fimHandler?.getModel).toBe("function") + expect(typeof fimHandler?.getTotalCost).toBe("function") }) - expect(handler.supportsFim()).toBe(true) - }) + it("returns undefined for non-codestral models", () => { + const handler = new KilocodeOpenrouterHandler({ + ...mockOptions, + kilocodeModel: "anthropic/claude-sonnet-4", + }) - it("supportsFim returns false for non-codestral models", () => { - const handler = new KilocodeOpenrouterHandler({ - ...mockOptions, - kilocodeModel: "anthropic/claude-sonnet-4", + expect(handler.fimSupport()).toBeUndefined() }) - - expect(handler.supportsFim()).toBe(false) }) - it("streamFim yields chunks correctly", async () => { - const handler = new KilocodeOpenrouterHandler({ - ...mockOptions, - kilocodeModel: "mistral/codestral-latest", - }) - - // Mock streamSse to return the expected data - ;(streamSse as any).mockImplementation(async function* () { - yield { choices: [{ delta: { content: "chunk1" } }] } - yield { choices: [{ delta: { content: "chunk2" } }] } - yield { - usage: { - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }, - } - }) + describe("streamFim via fimSupport()", () => { + it("yields chunks correctly", async () => { + const handler = new KilocodeOpenrouterHandler({ + ...mockOptions, + kilocodeModel: "mistral/codestral-latest", + }) + + // Mock streamSse to return the expected data + ;(streamSse as any).mockImplementation(async function* () { + yield { choices: [{ delta: { content: "chunk1" } }] } + yield { choices: [{ delta: { content: "chunk2" } }] } + yield { + usage: { + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }, + } + }) - const mockResponse = { - ok: true, - status: 200, - statusText: "OK", - } as Response + const mockResponse = { + ok: true, + status: 200, + statusText: "OK", + } as Response - global.fetch = vitest.fn().mockResolvedValue(mockResponse) + global.fetch = vitest.fn().mockResolvedValue(mockResponse) - const chunks: string[] = [] - let receivedUsage: any = null + const chunks: string[] = [] + let receivedUsage: any = null - for await (const chunk of handler.streamFim("prefix", "suffix", undefined, (usage) => { - receivedUsage = usage - })) { - chunks.push(chunk) - } + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() - expect(chunks).toEqual(["chunk1", "chunk2"]) - expect(receivedUsage).toEqual({ - prompt_tokens: 10, - completion_tokens: 5, - total_tokens: 15, - }) - expect(streamSse).toHaveBeenCalledWith(mockResponse) - }) + for await (const chunk of fimHandler!.streamFim("prefix", "suffix", undefined, (usage) => { + receivedUsage = usage + })) { + chunks.push(chunk) + } - it("streamFim handles errors correctly", async () => { - const handler = new KilocodeOpenrouterHandler({ - ...mockOptions, - kilocodeModel: "mistral/codestral-latest", + expect(chunks).toEqual(["chunk1", "chunk2"]) + expect(receivedUsage).toEqual({ + prompt_tokens: 10, + completion_tokens: 5, + total_tokens: 15, + }) + expect(streamSse).toHaveBeenCalledWith(mockResponse) }) - const mockResponse = { - ok: false, - status: 400, - statusText: "Bad Request", - text: vitest.fn().mockResolvedValue("Invalid request"), - } + it("handles errors correctly", async () => { + const handler = new KilocodeOpenrouterHandler({ + ...mockOptions, + kilocodeModel: "mistral/codestral-latest", + }) + + const mockResponse = { + ok: false, + status: 400, + statusText: "Bad Request", + text: vitest.fn().mockResolvedValue("Invalid request"), + } - global.fetch = vitest.fn().mockResolvedValue(mockResponse) + global.fetch = vitest.fn().mockResolvedValue(mockResponse) - const generator = handler.streamFim("prefix", "suffix") - await expect(generator.next()).rejects.toThrow("FIM streaming failed: 400 Bad Request - Invalid request") + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() + const generator = fimHandler!.streamFim("prefix", "suffix") + await expect(generator.next()).rejects.toThrow( + "FIM streaming failed: 400 Bad Request - Invalid request", + ) + }) }) }) }) diff --git a/src/api/providers/__tests__/mistral-fim.spec.ts b/src/api/providers/__tests__/mistral-fim.spec.ts index b9ebade8fac..d8c3d1bc684 100644 --- a/src/api/providers/__tests__/mistral-fim.spec.ts +++ b/src/api/providers/__tests__/mistral-fim.spec.ts @@ -24,45 +24,49 @@ describe("MistralHandler FIM support", () => { beforeEach(() => vitest.clearAllMocks()) - describe("supportsFim", () => { - it("returns true for codestral models", () => { + describe("fimSupport", () => { + it("returns FimHandler for codestral models", () => { const handler = new MistralHandler({ ...mockOptions, apiModelId: "codestral-latest", }) - expect(handler.supportsFim()).toBe(true) + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() + expect(typeof fimHandler?.streamFim).toBe("function") + expect(typeof fimHandler?.getModel).toBe("function") + expect(typeof fimHandler?.getTotalCost).toBe("function") }) - it("returns true for codestral-2405", () => { + it("returns FimHandler for codestral-2405", () => { const handler = new MistralHandler({ ...mockOptions, apiModelId: "codestral-2405", }) - expect(handler.supportsFim()).toBe(true) + expect(handler.fimSupport()).toBeDefined() }) - it("returns false for non-codestral models", () => { + it("returns undefined for non-codestral models", () => { const handler = new MistralHandler({ ...mockOptions, apiModelId: "mistral-large-latest", }) - expect(handler.supportsFim()).toBe(false) + expect(handler.fimSupport()).toBeUndefined() }) - it("returns true when no model is specified (defaults to codestral-latest)", () => { + it("returns FimHandler when no model is specified (defaults to codestral-latest)", () => { const handler = new MistralHandler({ mistralApiKey: "test-api-key", }) // Default model is codestral-latest, which supports FIM - expect(handler.supportsFim()).toBe(true) + expect(handler.fimSupport()).toBeDefined() }) }) - describe("streamFim", () => { + describe("streamFim via fimSupport()", () => { it("yields chunks correctly", async () => { const handler = new MistralHandler({ ...mockOptions, @@ -85,8 +89,10 @@ describe("MistralHandler FIM support", () => { global.fetch = vitest.fn().mockResolvedValue(mockResponse) const chunks: string[] = [] + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() - for await (const chunk of handler.streamFim("prefix", "suffix")) { + for await (const chunk of fimHandler!.streamFim("prefix", "suffix")) { chunks.push(chunk) } @@ -109,7 +115,9 @@ describe("MistralHandler FIM support", () => { global.fetch = vitest.fn().mockResolvedValue(mockResponse) - const generator = handler.streamFim("prefix", "suffix") + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() + const generator = fimHandler!.streamFim("prefix", "suffix") await expect(generator.next()).rejects.toThrow("FIM streaming failed: 400 Bad Request - Invalid request") }) @@ -131,7 +139,9 @@ describe("MistralHandler FIM support", () => { global.fetch = vitest.fn().mockResolvedValue(mockResponse) - const generator = handler.streamFim("prefix", "suffix") + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() + const generator = fimHandler!.streamFim("prefix", "suffix") await generator.next() expect(global.fetch).toHaveBeenCalledWith( @@ -166,7 +176,9 @@ describe("MistralHandler FIM support", () => { global.fetch = vitest.fn().mockResolvedValue(mockResponse) - const generator = handler.streamFim("prefix", "suffix") + const fimHandler = handler.fimSupport() + expect(fimHandler).toBeDefined() + const generator = fimHandler!.streamFim("prefix", "suffix") await generator.next() expect(global.fetch).toHaveBeenCalledWith( diff --git a/src/api/providers/kilocode-openrouter.ts b/src/api/providers/kilocode-openrouter.ts index a09e00e03ba..eaa895059ef 100644 --- a/src/api/providers/kilocode-openrouter.ts +++ b/src/api/providers/kilocode-openrouter.ts @@ -20,6 +20,7 @@ import { KILOCODE_TOKEN_REQUIRED_ERROR } from "../../shared/kilocode/errorUtils" import { DEFAULT_HEADERS } from "./constants" import { streamSse } from "../../services/continuedev/core/fetch/stream" import { getEditorNameHeader } from "../../core/kilocode/wrapper" +import type { FimHandler } from "./kilocode/FimHandler" /** * A custom OpenRouter handler that overrides the getModel function @@ -143,9 +144,13 @@ export class KilocodeOpenrouterHandler extends OpenRouterHandler { return this.getModel() } - supportsFim(): boolean { + fimSupport(): FimHandler | undefined { const modelId = this.options.kilocodeModel ?? this.defaultModel - return modelId.includes("codestral") + if (!modelId.includes("codestral")) { + return undefined + } + + return this } async *streamFim( diff --git a/src/api/providers/kilocode/FimHandler.ts b/src/api/providers/kilocode/FimHandler.ts new file mode 100644 index 00000000000..0d115332092 --- /dev/null +++ b/src/api/providers/kilocode/FimHandler.ts @@ -0,0 +1,36 @@ +// kilocode_change - new file +import type { ModelInfo } from "@roo-code/types" +import type { CompletionUsage } from "../openrouter" + +/** + * Interface for FIM (Fill-In-the-Middle) completion handlers. + */ +export interface FimHandler { + /** + * Stream code completion between a prefix and suffix + * @param prefix - The code before the cursor/insertion point + * @param suffix - The code after the cursor/insertion point + * @param taskId - Optional task ID for tracking + * @param onUsage - Optional callback invoked with usage information when available + * @returns An async generator yielding code chunks as strings + */ + streamFim( + prefix: string, + suffix: string, + taskId?: string, + onUsage?: (usage: CompletionUsage) => void, + ): AsyncGenerator + + /** + * Get the model information for the FIM handler + * @returns Object containing model id, info, and optional maxTokens + */ + getModel(): { id: string; info: ModelInfo; maxTokens?: number } + + /** + * Calculate the total cost for a completion based on usage + * @param usage - The completion usage information + * @returns The total cost in dollars + */ + getTotalCost(usage: CompletionUsage): number +} diff --git a/src/api/providers/mistral.ts b/src/api/providers/mistral.ts index 6ef99acd93f..492ffe0f910 100644 --- a/src/api/providers/mistral.ts +++ b/src/api/providers/mistral.ts @@ -14,6 +14,7 @@ import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from ". import { DEFAULT_HEADERS } from "./constants" // kilocode_change import { streamSse } from "../../services/continuedev/core/fetch/stream" // kilocode_change import type { CompletionUsage } from "./openrouter" // kilocode_change +import type { FimHandler } from "./kilocode/FimHandler" // kilocode_change // Type helper to handle thinking chunks from Mistral API // The SDK includes ThinkChunk but TypeScript has trouble with the discriminated union @@ -214,12 +215,26 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand } // kilocode_change start - supportsFim(): boolean { + fimSupport(): FimHandler | undefined { const modelId = this.options.apiModelId ?? mistralDefaultModelId - return modelId.startsWith("codestral-") + if (!modelId.startsWith("codestral-")) { + return undefined + } + + return { + streamFim: this.streamFim.bind(this), + getModel: () => this.getModel(), + getTotalCost: (usage: CompletionUsage) => { + // Calculate cost based on model pricing + const { info } = this.getModel() + const inputCost = ((usage.prompt_tokens ?? 0) / 1_000_000) * (info.inputPrice ?? 0) + const outputCost = ((usage.completion_tokens ?? 0) / 1_000_000) * (info.outputPrice ?? 0) + return inputCost + outputCost + }, + } } - async *streamFim( + private async *streamFim( prefix: string, suffix: string, _taskId?: string, diff --git a/src/services/ghost/GhostModel.ts b/src/services/ghost/GhostModel.ts index 3e9a74ffea4..1ce81230d9b 100644 --- a/src/services/ghost/GhostModel.ts +++ b/src/services/ghost/GhostModel.ts @@ -1,6 +1,6 @@ // kilocode_change new file import { modelIdKeysByProvider, ProviderName } from "@roo-code/types" -import { ApiHandler, buildApiHandler } from "../../api" +import { ApiHandler, buildApiHandler, FimHandler } from "../../api" import { ProviderSettingsManager } from "../../core/config/ProviderSettingsManager" import { OpenRouterHandler } from "../../api/providers" import { CompletionUsage } from "../../api/providers/openrouter" @@ -10,31 +10,11 @@ import { KilocodeOpenrouterHandler } from "../../api/providers/kilocode-openrout import { PROVIDERS } from "../../../webview-ui/src/components/settings/constants" import { ResponseMetaData } from "./types" -/** - * Interface for handlers that support FIM (Fill-in-the-Middle) completions. - * Uses duck typing - any handler implementing these methods can be used for FIM. - */ -interface FimCapableHandler { - supportsFim(): boolean - streamFim( - prefix: string, - suffix: string, - taskId?: string, - onUsage?: (usage: CompletionUsage) => void, - ): AsyncGenerator - getModel(): { id: string; info: any; maxTokens?: number } - getTotalCost?(usage: CompletionUsage): number -} - -/** - * Type guard to check if a handler supports FIM operations using duck typing. - */ -function isFimCapable(handler: ApiHandler): handler is ApiHandler & FimCapableHandler { - return ( - typeof (handler as any).supportsFim === "function" && - typeof (handler as any).streamFim === "function" && - (handler as any).supportsFim() === true - ) +function getFimHandler(handler: ApiHandler): FimHandler | undefined { + if (typeof handler.fimSupport === "function") { + return handler.fimSupport() + } + return undefined } // Convert PROVIDERS array to a lookup map for display names @@ -120,13 +100,11 @@ export class GhostModel { return false } - // Use duck typing to check if the handler supports FIM - return isFimCapable(this.apiHandler) + return getFimHandler(this.apiHandler) !== undefined } /** * Generate FIM completion using the FIM API endpoint. - * Uses duck typing to support any handler that implements supportsFim() and streamFim(). */ public async generateFimResponse( prefix: string, @@ -139,23 +117,23 @@ export class GhostModel { throw new Error("API handler is not initialized. Please check your configuration.") } - if (!isFimCapable(this.apiHandler)) { + const fimHandler = getFimHandler(this.apiHandler) + if (!fimHandler) { throw new Error("Current provider/model does not support FIM completions") } - console.log("USED MODEL (FIM)", this.apiHandler.getModel()) + console.log("USED MODEL (FIM)", fimHandler.getModel()) let usage: CompletionUsage | undefined - for await (const chunk of this.apiHandler.streamFim(prefix, suffix, taskId, (u) => { + for await (const chunk of fimHandler.streamFim(prefix, suffix, taskId, (u: CompletionUsage) => { usage = u })) { onChunk(chunk) } - // Calculate cost if the handler supports it (duck typing) - const cost = - usage && typeof this.apiHandler.getTotalCost === "function" ? this.apiHandler.getTotalCost(usage) : 0 + // Calculate cost using the FimHandler's getTotalCost method + const cost = usage ? fimHandler.getTotalCost(usage) : 0 const inputTokens = usage?.prompt_tokens ?? 0 const outputTokens = usage?.completion_tokens ?? 0 const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens ?? 0