diff --git a/src/api/providers/vscode-lm.ts b/src/api/providers/vscode-lm.ts
index 6474371bee..d8a492f772 100644
--- a/src/api/providers/vscode-lm.ts
+++ b/src/api/providers/vscode-lm.ts
@@ -7,7 +7,7 @@ import type { ApiHandlerOptions } from "../../shared/api"
 import { SELECTOR_SEPARATOR, stringifyVsCodeLmModelSelector } from "../../shared/vsCodeSelectorUtils"
 
 import { ApiStream } from "../transform/stream"
-import { convertToVsCodeLmMessages } from "../transform/vscode-lm-format"
+import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../transform/vscode-lm-format"
 import { BaseProvider } from "./base-provider"
 
 import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
@@ -231,7 +231,8 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 					console.debug("Roo Code : Empty chat message content")
 					return 0
 				}
-				tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
+				const countMessage = extractTextCountFromMessage(text)
+				tokenCount = await this.client.countTokens(countMessage, this.currentRequestCancellation.token)
 			} else {
 				console.warn("Roo Code : Invalid input type for token counting")
 				return 0
@@ -268,15 +269,10 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 		}
 	}
 
-	private async calculateTotalInputTokens(
-		systemPrompt: string,
-		vsCodeLmMessages: vscode.LanguageModelChatMessage[],
-	): Promise<number> {
-		const systemTokens: number = await this.internalCountTokens(systemPrompt)
-
+	private async calculateTotalInputTokens(vsCodeLmMessages: vscode.LanguageModelChatMessage[]): Promise<number> {
 		const messageTokens: number[] = await Promise.all(vsCodeLmMessages.map((msg) => this.internalCountTokens(msg)))
 
-		return systemTokens + messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
+		return messageTokens.reduce((sum: number, tokens: number): number => sum + tokens, 0)
 	}
 
 	private ensureCleanState(): void {
@@ -359,7 +355,7 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
 		this.currentRequestCancellation = new vscode.CancellationTokenSource()
 
 		// Calculate input tokens before starting the stream
-		const totalInputTokens: number = await this.calculateTotalInputTokens(systemPrompt, vsCodeLmMessages)
+		const totalInputTokens: number = await this.calculateTotalInputTokens(vsCodeLmMessages)
 
 		// Accumulate the text and count at the end of the stream to reduce token counting overhead.
 		let accumulatedText: string = ""
diff --git a/src/api/transform/__tests__/vscode-lm-format.spec.ts b/src/api/transform/__tests__/vscode-lm-format.spec.ts
index 73878033c2..1f53cc5751 100644
--- a/src/api/transform/__tests__/vscode-lm-format.spec.ts
+++ b/src/api/transform/__tests__/vscode-lm-format.spec.ts
@@ -1,8 +1,9 @@
 // npx vitest run src/api/transform/__tests__/vscode-lm-format.spec.ts
 
 import { Anthropic } from "@anthropic-ai/sdk"
+import * as vscode from "vscode"
 
-import { convertToVsCodeLmMessages, convertToAnthropicRole } from "../vscode-lm-format"
+import { convertToVsCodeLmMessages, convertToAnthropicRole, extractTextCountFromMessage } from "../vscode-lm-format"
 
 // Mock crypto using Vitest
 vitest.stubGlobal("crypto", {
@@ -24,8 +25,8 @@ interface MockLanguageModelToolCallPart {
 
 interface MockLanguageModelToolResultPart {
 	type: "tool_result"
-	toolUseId: string
-	parts: MockLanguageModelTextPart[]
+	callId: string
+	content: MockLanguageModelTextPart[]
 }
 
 // Mock vscode namespace
@@ -52,8 +53,8 @@ vitest.mock("vscode", () => {
 	class MockLanguageModelToolResultPart {
 		type = "tool_result"
 		constructor(
-			public toolUseId: string,
-			public parts: MockLanguageModelTextPart[],
+			public callId: string,
+			public content: MockLanguageModelTextPart[],
 		) {}
 	}
 
@@ -189,3 +190,157 @@ describe("convertToAnthropicRole", () => {
 		expect(result).toBeNull()
 	})
 })
+
+describe("extractTextCountFromMessage", () => {
+	it("should extract text from simple string content", () => {
+		const message = {
+			role: "user",
+			content: "Hello world",
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("Hello world")
+	})
+
+	it("should extract text from LanguageModelTextPart", () => {
+		const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
+		const message = {
+			role: "user",
+			content: [mockTextPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("Text content")
+	})
+
+	it("should extract text from multiple LanguageModelTextParts", () => {
+		const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("First part")
+		const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Second part")
+		const message = {
+			role: "user",
+			content: [mockTextPart1, mockTextPart2],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("First partSecond part")
+	})
+
+	it("should extract text from LanguageModelToolResultPart", () => {
+		const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result content")
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("tool-result-id", [
+			mockTextPart,
+		])
+		const message = {
+			role: "user",
+			content: [mockToolResultPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("tool-result-idTool result content")
+	})
+
+	it("should extract text from LanguageModelToolCallPart without input", () => {
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
+		const message = {
+			role: "assistant",
+			content: [mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("tool-namecall-id")
+	})
+
+	it("should extract text from LanguageModelToolCallPart with input", () => {
+		const mockInput = { operation: "add", numbers: [1, 2, 3] }
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)(
+			"call-id",
+			"calculator",
+			mockInput,
+		)
+		const message = {
+			role: "assistant",
+			content: [mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe(`calculatorcall-id${JSON.stringify(mockInput)}`)
+	})
+
+	it("should extract text from LanguageModelToolCallPart with empty input", () => {
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool-name", {})
+		const message = {
+			role: "assistant",
+			content: [mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("tool-namecall-id")
+	})
+
+	it("should extract text from mixed content types", () => {
+		const mockTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Text content")
+		const mockToolResultTextPart = new (vitest.mocked(vscode).LanguageModelTextPart)("Tool result")
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
+			mockToolResultTextPart,
+		])
+		const mockInput = { param: "value" }
+		const mockToolCallPart = new (vitest.mocked(vscode).LanguageModelToolCallPart)("call-id", "tool", mockInput)
+
+		const message = {
+			role: "assistant",
+			content: [mockTextPart, mockToolResultPart, mockToolCallPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe(`Text contentresult-idTool resulttoolcall-id${JSON.stringify(mockInput)}`)
+	})
+
+	it("should handle empty array content", () => {
+		const message = {
+			role: "user",
+			content: [],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("")
+	})
+
+	it("should handle undefined content", () => {
+		const message = {
+			role: "user",
+			content: undefined,
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("")
+	})
+
+	it("should handle ToolResultPart with multiple text parts", () => {
+		const mockTextPart1 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 1")
+		const mockTextPart2 = new (vitest.mocked(vscode).LanguageModelTextPart)("Part 2")
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [
+			mockTextPart1,
+			mockTextPart2,
+		])
+
+		const message = {
+			role: "user",
+			content: [mockToolResultPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("result-idPart 1Part 2")
+	})
+
+	it("should handle ToolResultPart with empty parts array", () => {
+		const mockToolResultPart = new (vitest.mocked(vscode).LanguageModelToolResultPart)("result-id", [])
+
+		const message = {
+			role: "user",
+			content: [mockToolResultPart],
+		} as any
+
+		const result = extractTextCountFromMessage(message)
+		expect(result).toBe("result-id")
+	})
+})
diff --git a/src/api/transform/vscode-lm-format.ts b/src/api/transform/vscode-lm-format.ts
index 080267b221..58b85f19a9 100644
--- a/src/api/transform/vscode-lm-format.ts
+++ b/src/api/transform/vscode-lm-format.ts
@@ -155,3 +155,41 @@ export function convertToAnthropicRole(vsCodeLmMessageRole: vscode.LanguageModel
 		return null
 	}
 }
+
+/**
+ * Extracts the text content from a VS Code Language Model chat message.
+ * @param message A VS Code Language Model chat message.
+ * @returns The extracted text content.
+ */
+export function extractTextCountFromMessage(message: vscode.LanguageModelChatMessage): string {
+	let text = ""
+	if (Array.isArray(message.content)) {
+		for (const item of message.content) {
+			if (item instanceof vscode.LanguageModelTextPart) {
+				text += item.value
+			}
+			if (item instanceof vscode.LanguageModelToolResultPart) {
+				text += item.callId
+				for (const part of item.content) {
+					if (part instanceof vscode.LanguageModelTextPart) {
+						text += part.value
+					}
+				}
+			}
+			if (item instanceof vscode.LanguageModelToolCallPart) {
+				text += item.name
+				text += item.callId
+				if (item.input && Object.keys(item.input).length > 0) {
+					try {
+						text += JSON.stringify(item.input)
+					} catch (error) {
+						console.error("Roo Code : Failed to stringify tool call input:", error)
+					}
+				}
+			}
+		}
+	} else if (typeof message.content === "string") {
+		text += message.content
+	}
+	return text
+}
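Usage note (outside the diff): a minimal sketch of how the new helper is intended to be driven, assuming the stable VS Code Language Model API (vscode.lm.selectChatModels and LanguageModelChat.countTokens). The estimateInputTokens wrapper, the "copilot" vendor selector, and the relative import path are illustrative only and are not part of this change.

	import * as vscode from "vscode"
	import { Anthropic } from "@anthropic-ai/sdk"
	import { convertToVsCodeLmMessages, extractTextCountFromMessage } from "../transform/vscode-lm-format"

	// Hypothetical wrapper: mirrors what calculateTotalInputTokens does after this change.
	// Each converted message is flattened to plain text (text parts, tool call names/ids/inputs,
	// tool result ids/contents) and counted individually; no separate system-prompt count is added.
	async function estimateInputTokens(
		model: vscode.LanguageModelChat,
		messages: Anthropic.Messages.MessageParam[],
	): Promise<number> {
		const vsCodeMessages = convertToVsCodeLmMessages(messages)
		const counts = await Promise.all(
			vsCodeMessages.map((msg) => model.countTokens(extractTextCountFromMessage(msg))),
		)
		return counts.reduce((sum, n) => sum + n, 0)
	}

	// Example call site (inside an async extension command):
	//   const [model] = await vscode.lm.selectChatModels({ vendor: "copilot" })
	//   const total = await estimateInputTokens(model, [{ role: "user", content: "Hello world" }])

Passing the flattened string to countTokens sidesteps counting a full LanguageModelChatMessage object, which is the overhead this PR targets.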