Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ import {
// kilocode_change start
import { KilocodeOpenrouterHandler } from "./providers/kilocode-openrouter"
import { InceptionLabsHandler } from "./providers/inception"
import type { FimHandler } from "./providers/kilocode/FimHandler" // kilocode_change
export type { FimHandler } from "./providers/kilocode/FimHandler"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to export this here other than everyone/Opus does it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed this to be best practice as this is the index file, and also has a barrel function. That's not the case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the point of barrel files, and IIRC we have had eng. discussions about it, and IIRC the general tone was that it's not worth it in general.

// kilocode_change end
import { NativeOllamaHandler } from "./providers/native-ollama"

Expand Down Expand Up @@ -137,6 +139,14 @@ export interface ApiHandler {
*/
countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number>

// kilocode_change start
/**
* Returns a FimHandler if the provider supports FIM (Fill-In-the-Middle) completions,
* or undefined if FIM is not supported.
*/
fimSupport?: () => FimHandler | undefined
// kilocode_change end

contextWindow?: number // kilocode_change: Add contextWindow property for virtual quota fallback
}

Expand Down
143 changes: 79 additions & 64 deletions src/api/providers/__tests__/kilocode-openrouter.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -241,86 +241,101 @@ describe("KilocodeOpenrouterHandler", () => {
})

describe("FIM support", () => {
it("supportsFim returns true for codestral models", () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "mistral/codestral-latest",
describe("fimSupport", () => {
it("returns FimHandler for codestral models", () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "mistral/codestral-latest",
})

const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()
expect(typeof fimHandler?.streamFim).toBe("function")
expect(typeof fimHandler?.getModel).toBe("function")
expect(typeof fimHandler?.getTotalCost).toBe("function")
})

expect(handler.supportsFim()).toBe(true)
})
it("returns undefined for non-codestral models", () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "anthropic/claude-sonnet-4",
})

it("supportsFim returns false for non-codestral models", () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "anthropic/claude-sonnet-4",
expect(handler.fimSupport()).toBeUndefined()
})

expect(handler.supportsFim()).toBe(false)
})

it("streamFim yields chunks correctly", async () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "mistral/codestral-latest",
})

// Mock streamSse to return the expected data
;(streamSse as any).mockImplementation(async function* () {
yield { choices: [{ delta: { content: "chunk1" } }] }
yield { choices: [{ delta: { content: "chunk2" } }] }
yield {
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
})
describe("streamFim via fimSupport()", () => {
it("yields chunks correctly", async () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "mistral/codestral-latest",
})

// Mock streamSse to return the expected data
;(streamSse as any).mockImplementation(async function* () {
yield { choices: [{ delta: { content: "chunk1" } }] }
yield { choices: [{ delta: { content: "chunk2" } }] }
yield {
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
},
}
})

const mockResponse = {
ok: true,
status: 200,
statusText: "OK",
} as Response
const mockResponse = {
ok: true,
status: 200,
statusText: "OK",
} as Response

global.fetch = vitest.fn().mockResolvedValue(mockResponse)
global.fetch = vitest.fn().mockResolvedValue(mockResponse)

const chunks: string[] = []
let receivedUsage: any = null
const chunks: string[] = []
let receivedUsage: any = null

for await (const chunk of handler.streamFim("prefix", "suffix", undefined, (usage) => {
receivedUsage = usage
})) {
chunks.push(chunk)
}
const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()

expect(chunks).toEqual(["chunk1", "chunk2"])
expect(receivedUsage).toEqual({
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
})
expect(streamSse).toHaveBeenCalledWith(mockResponse)
})
for await (const chunk of fimHandler!.streamFim("prefix", "suffix", undefined, (usage) => {
receivedUsage = usage
})) {
chunks.push(chunk)
}

it("streamFim handles errors correctly", async () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "mistral/codestral-latest",
expect(chunks).toEqual(["chunk1", "chunk2"])
expect(receivedUsage).toEqual({
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
})
expect(streamSse).toHaveBeenCalledWith(mockResponse)
})

const mockResponse = {
ok: false,
status: 400,
statusText: "Bad Request",
text: vitest.fn().mockResolvedValue("Invalid request"),
}
it("handles errors correctly", async () => {
const handler = new KilocodeOpenrouterHandler({
...mockOptions,
kilocodeModel: "mistral/codestral-latest",
})

const mockResponse = {
ok: false,
status: 400,
statusText: "Bad Request",
text: vitest.fn().mockResolvedValue("Invalid request"),
}

global.fetch = vitest.fn().mockResolvedValue(mockResponse)
global.fetch = vitest.fn().mockResolvedValue(mockResponse)

const generator = handler.streamFim("prefix", "suffix")
await expect(generator.next()).rejects.toThrow("FIM streaming failed: 400 Bad Request - Invalid request")
const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()
const generator = fimHandler!.streamFim("prefix", "suffix")
await expect(generator.next()).rejects.toThrow(
"FIM streaming failed: 400 Bad Request - Invalid request",
)
})
})
})
})
40 changes: 26 additions & 14 deletions src/api/providers/__tests__/mistral-fim.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,49 @@ describe("MistralHandler FIM support", () => {

beforeEach(() => vitest.clearAllMocks())

describe("supportsFim", () => {
it("returns true for codestral models", () => {
describe("fimSupport", () => {
it("returns FimHandler for codestral models", () => {
const handler = new MistralHandler({
...mockOptions,
apiModelId: "codestral-latest",
})

expect(handler.supportsFim()).toBe(true)
const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()
expect(typeof fimHandler?.streamFim).toBe("function")
expect(typeof fimHandler?.getModel).toBe("function")
expect(typeof fimHandler?.getTotalCost).toBe("function")
})

it("returns true for codestral-2405", () => {
it("returns FimHandler for codestral-2405", () => {
const handler = new MistralHandler({
...mockOptions,
apiModelId: "codestral-2405",
})

expect(handler.supportsFim()).toBe(true)
expect(handler.fimSupport()).toBeDefined()
})

it("returns false for non-codestral models", () => {
it("returns undefined for non-codestral models", () => {
const handler = new MistralHandler({
...mockOptions,
apiModelId: "mistral-large-latest",
})

expect(handler.supportsFim()).toBe(false)
expect(handler.fimSupport()).toBeUndefined()
})

it("returns true when no model is specified (defaults to codestral-latest)", () => {
it("returns FimHandler when no model is specified (defaults to codestral-latest)", () => {
const handler = new MistralHandler({
mistralApiKey: "test-api-key",
})

// Default model is codestral-latest, which supports FIM
expect(handler.supportsFim()).toBe(true)
expect(handler.fimSupport()).toBeDefined()
})
})

describe("streamFim", () => {
describe("streamFim via fimSupport()", () => {
it("yields chunks correctly", async () => {
const handler = new MistralHandler({
...mockOptions,
Expand All @@ -85,8 +89,10 @@ describe("MistralHandler FIM support", () => {
global.fetch = vitest.fn().mockResolvedValue(mockResponse)

const chunks: string[] = []
const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()

for await (const chunk of handler.streamFim("prefix", "suffix")) {
for await (const chunk of fimHandler!.streamFim("prefix", "suffix")) {
chunks.push(chunk)
}

Expand All @@ -109,7 +115,9 @@ describe("MistralHandler FIM support", () => {

global.fetch = vitest.fn().mockResolvedValue(mockResponse)

const generator = handler.streamFim("prefix", "suffix")
const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()
const generator = fimHandler!.streamFim("prefix", "suffix")
await expect(generator.next()).rejects.toThrow("FIM streaming failed: 400 Bad Request - Invalid request")
})

Expand All @@ -131,7 +139,9 @@ describe("MistralHandler FIM support", () => {

global.fetch = vitest.fn().mockResolvedValue(mockResponse)

const generator = handler.streamFim("prefix", "suffix")
const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()
const generator = fimHandler!.streamFim("prefix", "suffix")
await generator.next()

expect(global.fetch).toHaveBeenCalledWith(
Expand Down Expand Up @@ -166,7 +176,9 @@ describe("MistralHandler FIM support", () => {

global.fetch = vitest.fn().mockResolvedValue(mockResponse)

const generator = handler.streamFim("prefix", "suffix")
const fimHandler = handler.fimSupport()
expect(fimHandler).toBeDefined()
const generator = fimHandler!.streamFim("prefix", "suffix")
await generator.next()

expect(global.fetch).toHaveBeenCalledWith(
Expand Down
9 changes: 7 additions & 2 deletions src/api/providers/kilocode-openrouter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import { KILOCODE_TOKEN_REQUIRED_ERROR } from "../../shared/kilocode/errorUtils"
import { DEFAULT_HEADERS } from "./constants"
import { streamSse } from "../../services/continuedev/core/fetch/stream"
import { getEditorNameHeader } from "../../core/kilocode/wrapper"
import type { FimHandler } from "./kilocode/FimHandler"

/**
* A custom OpenRouter handler that overrides the getModel function
Expand Down Expand Up @@ -143,9 +144,13 @@ export class KilocodeOpenrouterHandler extends OpenRouterHandler {
return this.getModel()
}

supportsFim(): boolean {
fimSupport(): FimHandler | undefined {
const modelId = this.options.kilocodeModel ?? this.defaultModel
return modelId.includes("codestral")
if (!modelId.includes("codestral")) {
return undefined
}

return this
}

async *streamFim(
Expand Down
36 changes: 36 additions & 0 deletions src/api/providers/kilocode/FimHandler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// kilocode_change - new file
import type { ModelInfo } from "@roo-code/types"
import type { CompletionUsage } from "../openrouter"

/**
 * Interface for FIM (Fill-In-the-Middle) completion handlers.
 *
 * A handler is obtained from a provider's `fimSupport()` method, which returns
 * `undefined` instead of a handler when the selected model does not support FIM.
 */
export interface FimHandler {
	/**
	 * Stream code completion between a prefix and suffix
	 * @param prefix - The code before the cursor/insertion point
	 * @param suffix - The code after the cursor/insertion point
	 * @param taskId - Optional task ID for tracking
	 * @param onUsage - Optional callback invoked with usage information when available
	 *   (token counts such as `prompt_tokens` and `completion_tokens`)
	 * @returns An async generator yielding code chunks as strings
	 *
	 * NOTE(review): the provider tests expect a non-OK HTTP response to surface as a
	 * rejection from the generator's first `next()` call — confirm that any new
	 * implementation follows the same convention.
	 */
	streamFim(
		prefix: string,
		suffix: string,
		taskId?: string,
		onUsage?: (usage: CompletionUsage) => void,
	): AsyncGenerator<string>

	/**
	 * Get the model information for the FIM handler
	 * @returns Object containing model id, info, and optional maxTokens
	 */
	getModel(): { id: string; info: ModelInfo; maxTokens?: number }

	/**
	 * Calculate the total cost for a completion based on usage
	 *
	 * For example, the Mistral implementation scales the prompt and completion
	 * token counts by the model's per-million-token input and output prices and
	 * sums the two, treating missing counts or prices as 0.
	 *
	 * @param usage - The completion usage information
	 * @returns The total cost in dollars
	 *   (NOTE(review): the currency unit follows the pricing fields in `ModelInfo` — confirm)
	 */
	getTotalCost(usage: CompletionUsage): number
}
21 changes: 18 additions & 3 deletions src/api/providers/mistral.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from ".
import { DEFAULT_HEADERS } from "./constants" // kilocode_change
import { streamSse } from "../../services/continuedev/core/fetch/stream" // kilocode_change
import type { CompletionUsage } from "./openrouter" // kilocode_change
import type { FimHandler } from "./kilocode/FimHandler" // kilocode_change

// Type helper to handle thinking chunks from Mistral API
// The SDK includes ThinkChunk but TypeScript has trouble with the discriminated union
Expand Down Expand Up @@ -214,12 +215,26 @@ export class MistralHandler extends BaseProvider implements SingleCompletionHand
}

// kilocode_change start
supportsFim(): boolean {
fimSupport(): FimHandler | undefined {
const modelId = this.options.apiModelId ?? mistralDefaultModelId
return modelId.startsWith("codestral-")
if (!modelId.startsWith("codestral-")) {
return undefined
}

return {
streamFim: this.streamFim.bind(this),
getModel: () => this.getModel(),
getTotalCost: (usage: CompletionUsage) => {
// Calculate cost based on model pricing
const { info } = this.getModel()
const inputCost = ((usage.prompt_tokens ?? 0) / 1_000_000) * (info.inputPrice ?? 0)
const outputCost = ((usage.completion_tokens ?? 0) / 1_000_000) * (info.outputPrice ?? 0)
return inputCost + outputCost
},
}
}

async *streamFim(
private async *streamFim(
prefix: string,
suffix: string,
_taskId?: string,
Expand Down
Loading