16 changes: 6 additions & 10 deletions src/api/providers/vscode-lm.ts
@@ -7,7 +7,7 @@ import type { ApiHandlerOptions } from "../../shared/api"
import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"

import { ApiStream } from "../transform/stream"
import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../transform/vscode-lm-format"

import { BaseProvider } from "./base-provider"
import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -231,7 +231,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
console.debug("Roo Code <Language Model API>: Empty chat message content")
return 0
}
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
const countMessage = extractTextCountFromMessage(text)
tokenCount = await this.client.countTokens(countMessage, this.currentRequestCancellation.token)
} else {
console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
return 0
@@ -268,15 +269,10 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
}
}

private async calculateTotalInputTokens(
systemPrompt: string,
vsCodeLmMessages: vscode.LanguageModelChatMessage[],
): Promise<number> {
const systemTokens: number = await this.internalCountTokens(systemPrompt)

private async calculateTotalInputTokens(vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.internalCountTokens(msg)))

return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
return messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
}

private ensureCleanState(): void {
@@ -359,7 +355,7 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHandler
this.currentRequestCancellation = new vscode.CancellationTokenSource()

// Calculate input tokens before starting the stream
const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
const totalInputTokens: number = await this.calculateTotalInputTokens(vsCodeLmMessages)

// Accumulate the text and count at the end of the stream to reduce token counting overhead.
let accumulatedText: string = ""
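With these provider changes, the input-token total is derived from the converted messages alone: each message is flattened to a single string by the new extractTextCountFromMessage helper (added below), and that string is what the VS Code chat client counts. A minimal standalone sketch of this path, not part of the diff — the function name, parameter list, and import path are illustrative; the handler above does the same work through its private methods and cached client:

import * as vscode from "vscode"
// Illustrative import path; the helper is added in src/api/transform/vscode-lm-format.ts below.
import { extractTextCountFromMessage } from "../transform/vscode-lm-format"

// Sum the token counts of all converted messages. Each message is first flattened
// to one string, which vscode.LanguageModelChat.countTokens accepts directly.
async function totalInputTokens(
	client: vscode.LanguageModelChat,
	messages: vscode.LanguageModelChatMessage[],
	cancel: vscode.CancellationToken,
): Promise<number> {
	const counts = await Promise.all(
		messages.map((msg) => client.countTokens(extractTextCountFromMessage(msg), cancel)),
	)
	return counts.reduce((sum, tokens) => sum + tokens, 0)
}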
165 changes: 160 additions & 5 deletions src/api/transform/__tests__/vscode-lm-format.spec.ts
@@ -1,8 +1,9 @@
// npx vitest run src/api/transform/__tests__/vscode-lm-format.spec.ts

import { Anthropic } from "@anthropic-ai/sdk"
import * as vscode from "vscode"

import { convertToVsCodeLmMessages, convertToAnthropicRole } from "../vscode-lm-format"
import { convertToVsCodeLmMessages, convertToAnthropicRole, extractTextCountFromMessage } from "../vscode-lm-format"

// Mock crypto using Vitest
vitest.stubGlobal("crypto", {
@@ -24,8 +25,8 @@ interface MockLanguageModelToolCallPart {

interface MockLanguageModelToolResultPart {
type: "tool_result"
toolUseId: string
parts: MockLanguageModelTextPart[]
callId: string
content: MockLanguageModelTextPart[]
}

// Mock vscode namespace
Expand All @@ -52,8 +53,8 @@ vitest.mock("vscode", () => {
class MockLanguageModelToolResultPart {
type = "tool_result"
constructor(
public toolUseId: string,
public parts: MockLanguageModelTextPart[],
public callId: string,
public content: MockLanguageModelTextPart[],
) {}
}

@@ -189,3 +190,157 @@ describe("convertToAnthropicRole", () => {
expect(result).toBeNull()
})
})

describe("extractTextCountFromMessage", () => {
it("should extract text from simple string content", () => {
const message = {
role: "user",
content: "Hello world",
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("Hello world")
})

it("should extract text from LanguageModelTextPart", () => {
const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
const message = {
role: "user",
content: [mockTextPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("Text content")
})

it("should extract text from multiple LanguageModelTextParts", () => {
const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("First part")
const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Second part")
const message = {
role: "user",
content: [mockTextPart1, mockTextPart2],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("First partSecond part")
})

it("should extract text from LanguageModelToolResultPart", () => {
const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result content")
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("tool-result-id", [
mockTextPart,
])
const message = {
role: "user",
content: [mockToolResultPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("tool-result-idTool result content")
})

it("should extract text from LanguageModelToolCallPart without input", () => {
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
const message = {
role: "assistant",
content: [mockToolCallPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("tool-namecall-id")
})

it("should extract text from LanguageModelToolCallPart with input", () => {
const mockInput = { operation: "add", numbers: [1, 2, 3] }
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)(
"call-id",
"calculator",
mockInput,
)
const message = {
role: "assistant",
content: [mockToolCallPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe(`calculatorcall-id${JSON.stringify(mockInput)}`)
})

it("should extract text from LanguageModelToolCallPart with empty input", () => {
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
const message = {
role: "assistant",
content: [mockToolCallPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("tool-namecall-id")
})

it("should extract text from mixed content types", () => {
const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
const mockToolResultTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result")
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
mockToolResultTextPart,
])
const mockInput = { param: "value" }
const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool", mockInput)

const message = {
role: "assistant",
content: [mockTextPart, mockToolResultPart, mockToolCallPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe(`Text contentresult-idTool resulttoolcall-id${JSON.stringify(mockInput)}`)
})

it("should handle empty array content", () => {
const message = {
role: "user",
content: [],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("")
})

it("should handle undefined content", () => {
const message = {
role: "user",
content: undefined,
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("")
})

it("should handle ToolResultPart with multiple text parts", () => {
const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 1")
const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 2")
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
mockTextPart1,
mockTextPart2,
])

const message = {
role: "user",
content: [mockToolResultPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("result-idPart 1Part 2")
})

it("should handle ToolResultPart with empty parts array", () => {
const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [])

const message = {
role: "user",
content: [mockToolResultPart],
} as any

const result = extractTextCountFromMessage(message)
expect(result).toBe("result-id")
})
})
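The renamed mock fields above (toolUseId/parts becoming callId/content) track the shape of the real VS Code API, where a tool result part is constructed from the originating call ID plus an array of content parts, and a tool call part from a call ID, tool name, and input object. A small sketch against the actual vscode typings, not part of this PR:

import * as vscode from "vscode"

// The real constructors the test mocks mirror.
const toolResult = new vscode.LanguageModelToolResultPart("call-123", [
	new vscode.LanguageModelTextPart("tool output"),
])
const toolCall = new vscode.LanguageModelToolCallPart("call-123", "calculator", { operation: "add" })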
38 changes: 38 additions & 0 deletions src/api/transform/vscode-lm-format.ts
@@ -155,3 +155,41 @@ export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModel
return null
}
}

/**
 * Flattens a VS Code Language Model chat message into a single string for token counting:
 * plain text parts, tool result call IDs and their text content, and tool call names,
 * call IDs, and stringified inputs are concatenated in order.
 * @param message A VS Code Language Model chat message.
 * @returns The concatenated text to pass to the token counter.
 */
export function extractTextCountFromMessage(message: vscode.LanguageModelChatMessage): string {
	let text = ""
	if (Array.isArray(message.content)) {
		for (const item of message.content) {
			if (item instanceof vscode.LanguageModelTextPart) {
				text += item.value
			}
			if (item instanceof vscode.LanguageModelToolResultPart) {
				text += item.callId
				for (const part of item.content) {
					if (part instanceof vscode.LanguageModelTextPart) {
						text += part.value
					}
				}
			}
			if (item instanceof vscode.LanguageModelToolCallPart) {
				text += item.name
				text += item.callId
				if (item.input && Object.keys(item.input).length > 0) {
					try {
						text += JSON.stringify(item.input)
					} catch (error) {
						console.error("Roo Code <Language Model API>: Failed to stringify tool call input:", error)
					}
				}
			}
		}
	} else if (typeof message.content === "string") {
		text += message.content
	}
	return text
}
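As the tests above exercise, the helper's output for a mixed-content message is a plain concatenation with no separators. An illustrative snippet, not part of the diff, showing the expected result for a message built with the real VS Code constructors (the relative import path is an assumption):

import * as vscode from "vscode"
import { extractTextCountFromMessage } from "./vscode-lm-format" // path is illustrative

const message = new vscode.LanguageModelChatMessage(vscode.LanguageModelChatMessageRole.Assistant, [
	new vscode.LanguageModelTextPart("Text content"),
	new vscode.LanguageModelToolResultPart("result-id", [new vscode.LanguageModelTextPart("Tool result")]),
	new vscode.LanguageModelToolCallPart("call-id", "tool", { param: "value" }),
])

// Text parts contribute their value, tool results contribute callId plus nested text,
// and tool calls contribute name, callId, and JSON-stringified input:
// extractTextCountFromMessage(message) === 'Text contentresult-idTool resulttoolcall-id{"param":"value"}'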