Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/api/providers/__tests__/vscode-lm.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,66 @@ describe("VsCodeLmHandler", () => {
})
})

describe("countTokens", () => {
beforeEach(() => {
handler["client"] = mockLanguageModelChat
})

it("should count tokens when called outside of an active request", async () => {
// Ensure no active request cancellation token exists
handler["currentRequestCancellation"] = null

mockLanguageModelChat.countTokens.mockResolvedValueOnce(42)

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello world" }]
const result = await handler.countTokens(content)

expect(result).toBe(42)
expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Hello world", expect.any(Object))
})

it("should count tokens when called during an active request", async () => {
// Simulate an active request with a cancellation token
const mockCancellation = {
token: { isCancellationRequested: false, onCancellationRequested: vi.fn() },
cancel: vi.fn(),
dispose: vi.fn(),
}
handler["currentRequestCancellation"] = mockCancellation as any

mockLanguageModelChat.countTokens.mockResolvedValueOnce(50)

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
const result = await handler.countTokens(content)

expect(result).toBe(50)
expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("Test content", mockCancellation.token)
})

it("should return 0 when no client is available", async () => {
handler["client"] = null
handler["currentRequestCancellation"] = null

const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello" }]
const result = await handler.countTokens(content)

expect(result).toBe(0)
})

it("should handle image blocks with placeholder", async () => {
handler["currentRequestCancellation"] = null
mockLanguageModelChat.countTokens.mockResolvedValueOnce(5)

const content: Anthropic.Messages.ContentBlockParam[] = [
{ type: "image", source: { type: "base64", media_type: "image/png", data: "abc" } },
]
const result = await handler.countTokens(content)

expect(result).toBe(5)
expect(mockLanguageModelChat.countTokens).toHaveBeenCalledWith("[IMAGE]", expect.any(Object))
})
})

describe("completePrompt", () => {
it("should complete single prompt", async () => {
const mockModel = { ...mockLanguageModelChat }
Expand Down
25 changes: 18 additions & 7 deletions src/api/providers/vscode-lm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -229,31 +229,37 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
return 0
}

if (!this.currentRequestCancellation) {
console.warn("Roo Code <Language Model API>: No cancellation token available for token counting")
return 0
}

// Validate input
if (!text) {
console.debug("Roo Code <Language Model API>: Empty text provided for token counting")
return 0
}

// Create a temporary cancellation token if we don't have one (e.g., when called outside a request)
let cancellationToken: vscode.CancellationToken
let tempCancellation: vscode.CancellationTokenSource | null = null

if (this.currentRequestCancellation) {
cancellationToken = this.currentRequestCancellation.token
} else {
tempCancellation = new vscode.CancellationTokenSource()
cancellationToken = tempCancellation.token
}

try {
// Handle different input types
let tokenCount: number

if (typeof text === "string") {
tokenCount = await this.client.countTokens(text, this.currentRequestCancellation.token)
tokenCount = await this.client.countTokens(text, cancellationToken)
} else if (text instanceof vscode.LanguageModelChatMessage) {
// For chat messages, ensure we have content
if (!text.content || (Array.isArray(text.content) && text.content.length === 0)) {
console.debug("Roo Code <Language Model API>: Empty chat message content")
return 0
}
const countMessage = extractTextCountFromMessage(text)
tokenCount = await this.client.countTokens(countMessage, this.currentRequestCancellation.token)
tokenCount = await this.client.countTokens(countMessage, cancellationToken)
} else {
console.warn("Roo Code <Language Model API>: Invalid input type for token counting")
return 0
Expand Down Expand Up @@ -287,6 +293,11 @@ export class VsCodeLmHandler extends BaseProvider implements SingleCompletionHan
}

return 0 // Fallback to prevent stream interruption
} finally {
// Clean up temporary cancellation token
if (tempCancellation) {
tempCancellation.dispose()
}
}
}

Expand Down
Loading