202 changes: 202 additions & 0 deletions src/api/providers/__tests__/lite-llm.spec.ts
@@ -718,4 +718,206 @@ describe("LiteLLMHandler", () => {
})
})
})

describe("tool ID normalization", () => {
it("should truncate tool IDs longer than 64 characters", async () => {
const optionsWithBedrock: ApiHandlerOptions = {
...mockOptions,
litellmModelId: "bedrock/anthropic.claude-3-sonnet",
}
handler = new LiteLLMHandler(optionsWithBedrock)

vi.spyOn(handler as any, "fetchModel").mockResolvedValue({
id: "bedrock/anthropic.claude-3-sonnet",
info: { ...litellmDefaultModelInfo, maxTokens: 8192 },
})

// Create a tool ID longer than 64 characters
const longToolId = "toolu_" + "a".repeat(70) // 76 characters total

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
{
role: "assistant",
content: [
{ type: "text", text: "I'll help you with that." },
{ type: "tool_use", id: longToolId, name: "read_file", input: { path: "test.txt" } },
],
},
{
role: "user",
content: [{ type: "tool_result", tool_use_id: longToolId, content: "file contents" }],
},
]

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "Response" } }],
usage: { prompt_tokens: 100, completion_tokens: 20 },
}
},
}

mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
for await (const _chunk of generator) {
// Consume
}

// Verify that tool IDs are truncated to 64 characters or less
const createCall = mockCreate.mock.calls[0][0]
const assistantMessage = createCall.messages.find(
(msg: any) => msg.role === "assistant" && msg.tool_calls && msg.tool_calls.length > 0,
)
const toolMessage = createCall.messages.find((msg: any) => msg.role === "tool")

expect(assistantMessage).toBeDefined()
expect(assistantMessage.tool_calls[0].id.length).toBeLessThanOrEqual(64)

expect(toolMessage).toBeDefined()
expect(toolMessage.tool_call_id.length).toBeLessThanOrEqual(64)
})

it("should not modify tool IDs that are already within 64 characters", async () => {
const optionsWithBedrock: ApiHandlerOptions = {
...mockOptions,
litellmModelId: "bedrock/anthropic.claude-3-sonnet",
}
handler = new LiteLLMHandler(optionsWithBedrock)

vi.spyOn(handler as any, "fetchModel").mockResolvedValue({
id: "bedrock/anthropic.claude-3-sonnet",
info: { ...litellmDefaultModelInfo, maxTokens: 8192 },
})

// Create a tool ID within 64 characters
const shortToolId = "toolu_01ABC123" // Well under 64 characters

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
{
role: "assistant",
content: [
{ type: "text", text: "I'll help you with that." },
{ type: "tool_use", id: shortToolId, name: "read_file", input: { path: "test.txt" } },
],
},
{
role: "user",
content: [{ type: "tool_result", tool_use_id: shortToolId, content: "file contents" }],
},
]

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "Response" } }],
usage: { prompt_tokens: 100, completion_tokens: 20 },
}
},
}

mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
for await (const _chunk of generator) {
// Consume
}

// Verify that tool IDs are unchanged
const createCall = mockCreate.mock.calls[0][0]
const assistantMessage = createCall.messages.find(
(msg: any) => msg.role === "assistant" && msg.tool_calls && msg.tool_calls.length > 0,
)
const toolMessage = createCall.messages.find((msg: any) => msg.role === "tool")

expect(assistantMessage).toBeDefined()
expect(assistantMessage.tool_calls[0].id).toBe(shortToolId)

expect(toolMessage).toBeDefined()
expect(toolMessage.tool_call_id).toBe(shortToolId)
})

it("should maintain uniqueness with hash suffix when truncating", async () => {
const optionsWithBedrock: ApiHandlerOptions = {
...mockOptions,
litellmModelId: "bedrock/anthropic.claude-3-sonnet",
}
handler = new LiteLLMHandler(optionsWithBedrock)

vi.spyOn(handler as any, "fetchModel").mockResolvedValue({
id: "bedrock/anthropic.claude-3-sonnet",
info: { ...litellmDefaultModelInfo, maxTokens: 8192 },
})

// Create two tool IDs that differ only near the end
const longToolId1 = "toolu_" + "a".repeat(60) + "_suffix1"
const longToolId2 = "toolu_" + "a".repeat(60) + "_suffix2"

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [
{ role: "user", content: "Hello" },
{
role: "assistant",
content: [
{ type: "text", text: "I'll help." },
{ type: "tool_use", id: longToolId1, name: "read_file", input: { path: "test1.txt" } },
{ type: "tool_use", id: longToolId2, name: "read_file", input: { path: "test2.txt" } },
],
},
{
role: "user",
content: [
{ type: "tool_result", tool_use_id: longToolId1, content: "file1 contents" },
{ type: "tool_result", tool_use_id: longToolId2, content: "file2 contents" },
],
},
]

const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "Response" } }],
usage: { prompt_tokens: 100, completion_tokens: 20 },
}
},
}

mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
for await (const _chunk of generator) {
// Consume
}

// Verify that truncated tool IDs are unique (hash suffix ensures this)
const createCall = mockCreate.mock.calls[0][0]
const assistantMessage = createCall.messages.find(
(msg: any) => msg.role === "assistant" && msg.tool_calls && msg.tool_calls.length > 0,
)

expect(assistantMessage).toBeDefined()
expect(assistantMessage.tool_calls).toHaveLength(2)

const id1 = assistantMessage.tool_calls[0].id
const id2 = assistantMessage.tool_calls[1].id

// Both should be truncated to 64 characters
expect(id1.length).toBeLessThanOrEqual(64)
expect(id2.length).toBeLessThanOrEqual(64)

// They should be different (hash suffix ensures uniqueness)
expect(id1).not.toBe(id2)
})
})
})
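
For reference, the behavior these tests pin down — IDs over 64 characters are shortened, short IDs pass through unchanged, and near-identical long IDs stay distinct — could be implemented along the following lines. This is a minimal sketch only: the helper name, constant, and hash length are illustrative assumptions, not the actual `sanitizeOpenAiCallId` export from `src/utils/tool-id`.

```ts
import { createHash } from "crypto"

// Hypothetical limit matching the 64-character bound asserted in the tests above.
const MAX_TOOL_CALL_ID_LENGTH = 64

function normalizeLongToolCallId(id: string): string {
	if (id.length <= MAX_TOOL_CALL_ID_LENGTH) {
		return id // IDs already within the limit pass through untouched
	}
	// Keep a prefix of the original ID and append a short hash of the full ID,
	// so two long IDs that differ only near the end remain distinct after truncation.
	const hash = createHash("sha256").update(id).digest("hex").slice(0, 8)
	return `${id.slice(0, MAX_TOOL_CALL_ID_LENGTH - hash.length - 1)}_${hash}`
}
```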
60 changes: 60 additions & 0 deletions src/api/providers/__tests__/vscode-lm.spec.ts
@@ -437,6 +437,66 @@ describe("VsCodeLmHandler", () => {
})
})

describe("countTokens", () => {
beforeEach(() => {
handler["client"] = mockLanguageModelChat
})

it("should count tokens when called outside of an active request", async () => {
// Ensure no active request cancellation token exists
handler["currentRequestCancellation"] = null

mockLanguageModelChat.countTokens.mockResolvedValueOnce(42)

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello world" }]
const result = await handler.countTokens(content)

expect(result).toBe(42)
expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Hello world", expect.any(Object))
})

it("should count tokens when called during an active request", async () => {
// Simulate an active request with a cancellation token
const mockCancellation = {
token: { isCancellationRequested: false, onCancellationRequested: vi.fn() },
cancel: vi.fn(),
dispose: vi.fn(),
}
handler["currentRequestCancellation"] = mockCancellation as any

mockLanguageModelChat.countTokens.mockResolvedValueOnce(50)

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
const result = await handler.countTokens(content)

expect(result).toBe(50)
expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Test content", mockCancellation.token)
})

it("should return 0 when no client is available", async () => {
handler["client"] = null
handler["currentRequestCancellation"] = null

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello" }]
const result = await handler.countTokens(content)

expect(result).toBe(0)
})

it("should handle image blocks with placeholder", async () => {
handler["currentRequestCancellation"] = null
mockLanguageModelChat.countTokens.mockResolvedValueOnce(5)

const content: Anthropic.Messages.ContentBlockParam[] = [
{ type: "image", source: { type: "base64", media_type: "image/png", data: "abc" } },
]
const result = await handler.countTokens(content)

expect(result).toBe(5)
expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("[IMAGE]", expect.any(Object))
})
})

describe("completePrompt", () => {
it("should complete single prompt", async () => {
const mockModel = { ...mockLanguageModelChat }
5 changes: 4 additions & 1 deletion src/api/providers/lite-llm.ts
@@ -9,6 +9,7 @@ import { ApiHandlerOptions } from "../../shared/api"

import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
import { convertToOpenAiMessages } from "../transform/openai-format"
import { sanitizeOpenAiCallId } from "../../utils/tool-id"

import type { SingleCompletionHandler, ApiHandlerCreateMessageMetadata } from "../index"
import { RouterProvider } from "./router-provider"
@@ -115,7 +116,9 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
): ApiStream {
const { id: modelId, info } = await this.fetchModel()

const openAiMessages = convertToOpenAiMessages(messages)
const openAiMessages = convertToOpenAiMessages(messages, {
normalizeToolCallId: sanitizeOpenAiCallId,
})

// Prepare messages with cache control if enabled and supported
let systemMessage: OpenAI.Chat.ChatCompletionMessageParam
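Judging by the assertions in the new spec, the `normalizeToolCallId` option is presumably applied to both halves of a tool exchange during conversion — the `id` of each assistant `tool_calls` entry and the matching `tool_call_id` on the tool message. A rough sketch of that mapping with simplified handling; the actual `convertToOpenAiMessages` implementation may differ.

```ts
import OpenAI from "openai"

type NormalizeId = (id: string) => string

// Illustrative only: apply an ID normalizer to both sides of a tool exchange so the
// assistant's tool_calls ids and the tool results' tool_call_id stay in sync.
function applyToolCallIdNormalization(
	messages: OpenAI.Chat.ChatCompletionMessageParam[],
	normalize: NormalizeId,
): OpenAI.Chat.ChatCompletionMessageParam[] {
	return messages.map((msg) => {
		if (msg.role === "assistant" && msg.tool_calls) {
			return { ...msg, tool_calls: msg.tool_calls.map((tc) => ({ ...tc, id: normalize(tc.id) })) }
		}
		if (msg.role === "tool") {
			return { ...msg, tool_call_id: normalize(msg.tool_call_id) }
		}
		return msg
	})
}
```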
25 changes: 18 additions & 7 deletions src/api/providers/vscode-lm.ts
@@ -229,31 +229,37 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
return 0
}

if (!this.currentRequestCancellation) {
console.warn("Roo Code <Language Model API>: No cancellation token available for token counting")
return 0
}

// Validate input
if (!text) {
console.debug("Roo Code <Language Model API>: Empty text provided for token counting")
return 0
}

// Create a temporary cancellation token if we don't have one (e.g., when called outside a request)
let cancellationToken: vscode.CancellationToken
let tempCancellation: vscode.CancellationTokenSource | null = null

if (this.currentRequestCancellation) {
cancellationToken = this.currentRequestCancellation.token
} else {
tempCancellation = new vscode.CancellationTokenSource()
cancellationToken = tempCancellation.token
}

try {
// Handle different input types
let tokenCount: number

if (typeof text === "string") {
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
tokenCount = await this.client.countTokens(text, cancellationToken)
} else if (text instanceof vscode.LanguageModelChatMessage) {
// For chat messages, ensure we have content
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
console.debug("Roo Code <Language Model API>: Empty chat message content")
return 0
}
const countMessage = extractTextCountFromMessage(text)
tokenCount = await this.client.countTokens(countMessage, this.currentRequestCancellation.token)
tokenCount = await this.client.countTokens(countMessage, cancellationToken)
} else {
console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
return 0
@@ -287,6 +293,11 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
}

return 0 // Fallback to prevent stream interruption
} finally {
// Clean up temporary cancellation token
if (tempCancellation) {
tempCancellation.dispose()
}
}
}
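
The `finally` block above keeps the temporary `CancellationTokenSource` from leaking when `countTokens` is called outside an active request (the path the new `vscode-lm.spec.ts` cases exercise). The same pattern in isolation, as a rough sketch:

```ts
import * as vscode from "vscode"

// Use the caller-supplied token when one exists; otherwise create a short-lived
// CancellationTokenSource and dispose it afterwards so it does not leak.
async function countWithOptionalToken(
	client: vscode.LanguageModelChat,
	text: string,
	existing?: vscode.CancellationToken,
): Promise<number> {
	const temp = existing ? undefined : new vscode.CancellationTokenSource()
	try {
		return await client.countTokens(text, existing ?? temp!.token)
	} finally {
		temp?.dispose()
	}
}
```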
