257 changes: 245 additions & 12 deletions src/api/providers/__tests__/lite-llm.spec.ts
@@ -10,15 +10,9 @@ import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"
vi.mock("vscode", () => ({}))

// Mock OpenAI
vi.mock("openai", () => {
const mockStream = {
[Symbol.asyncIterator]: vi.fn(),
}

const mockCreate = vi.fn().mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})
const mockCreate = vi.fn()

vi.mock("openai", () => {
return {
default: vi.fn().mockImplementation(() => ({
chat: {
@@ -35,14 +29,25 @@ vi.mock("../fetchers/modelCache", () => ({
getModels: vi.fn().mockImplementation(() => {
return Promise.resolve({
[litellmDefaultModelId]: litellmDefaultModelInfo,
"gpt-5": { ...litellmDefaultModelInfo, maxTokens: 8192 },
gpt5: { ...litellmDefaultModelInfo, maxTokens: 8192 },
"GPT-5": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"gpt-5-turbo": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"gpt5-preview": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"gpt-5o": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"gpt-5.1": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"gpt-5-mini": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"gpt-4": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"claude-3-opus": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"llama-3": { ...litellmDefaultModelInfo, maxTokens: 8192 },
"gpt-4-turbo": { ...litellmDefaultModelInfo, maxTokens: 8192 },
})
}),
}))

describe("LiteLLMHandler", () => {
let handler: LiteLLMHandler
let mockOptions: ApiHandlerOptions
let mockOpenAIClient: any

beforeEach(() => {
vi.clearAllMocks()
@@ -52,7 +57,6 @@ describe("LiteLLMHandler", () => {
litellmModelId: litellmDefaultModelId,
}
handler = new LiteLLMHandler(mockOptions)
mockOpenAIClient = new OpenAI()
})

describe("prompt caching", () => {
@@ -85,7 +89,7 @@
},
}

mockOpenAIClient.chat.completions.create.mockReturnValue({
mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

@@ -96,7 +100,7 @@
}

// Verify that create was called with cache control headers
const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]
const createCall = mockCreate.mock.calls[0][0]

// Check system message has cache control in the proper format
expect(createCall.messages[0]).toMatchObject({
Expand Down Expand Up @@ -155,4 +159,233 @@ describe("LiteLLMHandler", () => {
})
})
})

describe("GPT-5 model handling", () => {
it("should use max_completion_tokens instead of max_tokens for GPT-5 models", async () => {
const optionsWithGPT5: ApiHandlerOptions = {
...mockOptions,
litellmModelId: "gpt-5",
}
handler = new LiteLLMHandler(optionsWithGPT5)

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]

// Mock the stream response
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "Hello!" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
},
}

mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
const results = []
for await (const chunk of generator) {
results.push(chunk)
}

// Verify that create was called with max_completion_tokens instead of max_tokens
const createCall = mockCreate.mock.calls[0][0]

// Should have max_completion_tokens, not max_tokens
expect(createCall.max_completion_tokens).toBeDefined()
expect(createCall.max_tokens).toBeUndefined()
})

it("should use max_completion_tokens for various GPT-5 model variations", async () => {
const gpt5Variations = [
"gpt-5",
"gpt5",
"GPT-5",
"gpt-5-turbo",
"gpt5-preview",
"gpt-5o",
"gpt-5.1",
"gpt-5-mini",
]

for (const modelId of gpt5Variations) {
vi.clearAllMocks()

const optionsWithGPT5: ApiHandlerOptions = {
...mockOptions,
litellmModelId: modelId,
}
handler = new LiteLLMHandler(optionsWithGPT5)

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }]

// Mock the stream response
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "Response" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
},
}

mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
for await (const chunk of generator) {
// Consume the generator
}

// Verify that create was called with max_completion_tokens for this model variation
const createCall = mockCreate.mock.calls[0][0]

expect(createCall.max_completion_tokens).toBeDefined()
expect(createCall.max_tokens).toBeUndefined()
}
})

it("should still use max_tokens for non-GPT-5 models", async () => {
const nonGPT5Models = ["gpt-4", "claude-3-opus", "llama-3", "gpt-4-turbo"]

for (const modelId of nonGPT5Models) {
vi.clearAllMocks()

const options: ApiHandlerOptions = {
...mockOptions,
litellmModelId: modelId,
}
handler = new LiteLLMHandler(options)

const systemPrompt = "You are a helpful assistant"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }]

// Mock the stream response
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "Response" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
},
}

mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage(systemPrompt, messages)
for await (const chunk of generator) {
// Consume the generator
}

// Verify that create was called with max_tokens for non-GPT-5 models
const createCall = mockCreate.mock.calls[0][0]

expect(createCall.max_tokens).toBeDefined()
expect(createCall.max_completion_tokens).toBeUndefined()
}
})

it("should use max_completion_tokens in completePrompt for GPT-5 models", async () => {
const optionsWithGPT5: ApiHandlerOptions = {
...mockOptions,
litellmModelId: "gpt-5",
}
handler = new LiteLLMHandler(optionsWithGPT5)

mockCreate.mockResolvedValue({
choices: [{ message: { content: "Test response" } }],
})

await handler.completePrompt("Test prompt")

// Verify that create was called with max_completion_tokens
const createCall = mockCreate.mock.calls[0][0]

expect(createCall.max_completion_tokens).toBeDefined()
expect(createCall.max_tokens).toBeUndefined()
})

it("should not set any max token fields when maxTokens is undefined (GPT-5 streaming)", async () => {
const optionsWithGPT5: ApiHandlerOptions = {
...mockOptions,
litellmModelId: "gpt-5",
}
handler = new LiteLLMHandler(optionsWithGPT5)

// Force fetchModel to return undefined maxTokens
vi.spyOn(handler as any, "fetchModel").mockResolvedValue({
id: "gpt-5",
info: { ...litellmDefaultModelInfo, maxTokens: undefined },
})

// Mock the stream response
const mockStream = {
async *[Symbol.asyncIterator]() {
yield {
choices: [{ delta: { content: "Hello!" } }],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
},
}
},
}

mockCreate.mockReturnValue({
withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
})

const generator = handler.createMessage("You are a helpful assistant", [
{ role: "user", content: "Hello" } as unknown as Anthropic.Messages.MessageParam,
])
for await (const _chunk of generator) {
// consume
}

// Should not include either token field
const createCall = mockCreate.mock.calls[0][0]
expect(createCall.max_tokens).toBeUndefined()
expect(createCall.max_completion_tokens).toBeUndefined()
})

it("should not set any max token fields when maxTokens is undefined (GPT-5 completePrompt)", async () => {
const optionsWithGPT5: ApiHandlerOptions = {
...mockOptions,
litellmModelId: "gpt-5",
}
handler = new LiteLLMHandler(optionsWithGPT5)

// Force fetchModel to return undefined maxTokens
vi.spyOn(handler as any, "fetchModel").mockResolvedValue({
id: "gpt-5",
info: { ...litellmDefaultModelInfo, maxTokens: undefined },
})

mockCreate.mockResolvedValue({
choices: [{ message: { content: "Ok" } }],
})

await handler.completePrompt("Test prompt")

const createCall = mockCreate.mock.calls[0][0]
expect(createCall.max_tokens).toBeUndefined()
expect(createCall.max_completion_tokens).toBeUndefined()
})
})
})
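
Note (reviewer sketch, not part of this diff): the model-ID variants exercised above are the ones the GPT-5 detection regex added in src/api/providers/lite-llm.ts (below) is expected to classify. A minimal standalone check, assuming that regex:

const isGpt5 = (modelId: string): boolean => /\bgpt-?5(?!\d)/i.test(modelId)

isGpt5("gpt-5") // true
isGpt5("gpt5-preview") // true — dash is optional
isGpt5("GPT-5") // true — case-insensitive flag
isGpt5("gpt-5.1") // true — "." is not a digit, so the lookahead passes
isGpt5("gpt-4-turbo") // false
isGpt5("gpt-50") // false — (?!\d) rejects gpt-50, gpt-500, etc.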
27 changes: 25 additions & 2 deletions src/api/providers/lite-llm.ts
@@ -32,6 +32,12 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
})
}

private isGpt5(modelId: string): boolean {
// Match gpt-5, gpt5, and variants like gpt-5o, gpt-5-turbo, gpt5-preview, gpt-5.1
// Avoid matching gpt-50, gpt-500, etc.
return /\bgpt-?5(?!\d)/i.test(modelId)
}

override async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
@@ -107,16 +113,25 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
// Required by some providers; others default to max tokens allowed
let maxTokens: number | undefined = info.maxTokens ?? undefined

// Check if this is a GPT-5 model that requires max_completion_tokens instead of max_tokens
const isGPT5Model = this.isGpt5(modelId)

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
max_tokens: maxTokens,
messages: [systemMessage, ...enhancedMessages],
stream: true,
stream_options: {
include_usage: true,
},
}

// GPT-5 models require max_completion_tokens instead of the deprecated max_tokens parameter
if (isGPT5Model && maxTokens) {
requestOptions.max_completion_tokens = maxTokens
} else if (maxTokens) {
requestOptions.max_tokens = maxTokens
}

if (this.supportsTemperature(modelId)) {
requestOptions.temperature = this.options.modelTemperature ?? 0
}
@@ -179,6 +194,9 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
async completePrompt(prompt: string): Promise<string> {
const { id: modelId, info } = await this.fetchModel()

// Check if this is a GPT-5 model that requires max_completion_tokens instead of max_tokens
const isGPT5Model = this.isGpt5(modelId)

try {
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: modelId,
@@ -189,7 +207,12 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
requestOptions.temperature = this.options.modelTemperature ?? 0
}

requestOptions.max_tokens = info.maxTokens
// GPT-5 models require max_completion_tokens instead of the deprecated max_tokens parameter
if (isGPT5Model && info.maxTokens) {
requestOptions.max_completion_tokens = info.maxTokens
} else if (info.maxTokens) {
requestOptions.max_tokens = info.maxTokens
}

const response = await this.client.chat.completions.create(requestOptions)
return response.choices[0]?.message.content || ""
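
Note (reviewer sketch, not part of this diff): the token-limit branching added in both createMessage and completePrompt, pulled out on its own. The helper below is hypothetical and only illustrates the intended request shapes; the field names mirror the ones used in the diff.

// Decide which token-limit field to send to LiteLLM, if any.
function tokenLimitParams(isGPT5Model: boolean, maxTokens: number | undefined) {
	if (isGPT5Model && maxTokens) {
		// GPT-5 models take max_completion_tokens; max_tokens is omitted entirely.
		return { max_completion_tokens: maxTokens }
	} else if (maxTokens) {
		// All other models keep the existing max_tokens behaviour.
		return { max_tokens: maxTokens }
	}
	// No maxTokens in the model info: send neither field.
	return {}
}

tokenLimitParams(true, 8192) // { max_completion_tokens: 8192 }
tokenLimitParams(false, 8192) // { max_tokens: 8192 }
tokenLimitParams(true, undefined) // {}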