src/api/providers/__tests__/moonshot.spec.ts (62 additions, 0 deletions)
@@ -294,4 +294,66 @@ describe("MoonshotHandler", () => {
expect(result.cacheReadTokens).toBeUndefined()
})
})

describe("addMaxTokensIfNeeded", () => {
it("should always add max_tokens regardless of includeMaxTokens option", () => {
// Create a test subclass to access the protected method
class TestMoonshotHandler extends MoonshotHandler {
public testAddMaxTokensIfNeeded(requestOptions: any, modelInfo: any) {
this.addMaxTokensIfNeeded(requestOptions, modelInfo)
}
}

const testHandler = new TestMoonshotHandler(mockOptions)
const requestOptions: any = {}
const modelInfo = {
maxTokens: 32_000,
}

// Test with includeMaxTokens set to false - should still add max tokens
testHandler.testAddMaxTokensIfNeeded(requestOptions, modelInfo)

expect(requestOptions.max_tokens).toBe(32_000)
})

it("should use modelMaxTokens when provided", () => {
class TestMoonshotHandler extends MoonshotHandler {
public testAddMaxTokensIfNeeded(requestOptions: any, modelInfo: any) {
this.addMaxTokensIfNeeded(requestOptions, modelInfo)
}
}

const customMaxTokens = 5000
const testHandler = new TestMoonshotHandler({
...mockOptions,
modelMaxTokens: customMaxTokens,
})
const requestOptions: any = {}
const modelInfo = {
maxTokens: 32_000,
}

testHandler.testAddMaxTokensIfNeeded(requestOptions, modelInfo)

expect(requestOptions.max_tokens).toBe(customMaxTokens)
})

it("should fall back to modelInfo.maxTokens when modelMaxTokens is not provided", () => {
class TestMoonshotHandler extends MoonshotHandler {
public testAddMaxTokensIfNeeded(requestOptions: any, modelInfo: any) {
this.addMaxTokensIfNeeded(requestOptions, modelInfo)
}
}

const testHandler = new TestMoonshotHandler(mockOptions)
const requestOptions: any = {}
const modelInfo = {
maxTokens: 16_000,
}

testHandler.testAddMaxTokensIfNeeded(requestOptions, modelInfo)

expect(requestOptions.max_tokens).toBe(16_000)
})
})
})
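
For reference, all three specs above use the same TypeScript idiom: a test-only subclass that re-exposes the protected addMaxTokensIfNeeded method. A minimal, self-contained sketch of that pattern, with hypothetical names standing in for the real handler classes:

	class Handler {
		// Protected hook, standing in for OpenAiHandler.addMaxTokensIfNeeded.
		protected addDefaults(options: Record<string, unknown>): void {
			options.max_tokens ??= 1024
		}
	}

	class TestHandler extends Handler {
		// Test-only wrapper that widens visibility by delegating.
		public testAddDefaults(options: Record<string, unknown>): void {
			this.addDefaults(options)
		}
	}

	const opts: Record<string, unknown> = {}
	new TestHandler().testAddDefaults(opts)
	console.log(opts.max_tokens) // 1024

This keeps the production method protected while still letting the spec exercise it directly.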
src/api/providers/moonshot.ts (13 additions, 1 deletion)
@@ -1,4 +1,5 @@
-import { moonshotModels, moonshotDefaultModelId } from "@roo-code/types"
+import OpenAI from "openai"
+import { moonshotModels, moonshotDefaultModelId, type ModelInfo } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

@@ -36,4 +37,15 @@ export class MoonshotHandler extends OpenAiHandler {
cacheReadTokens: usage?.cached_tokens,
}
}

// Override to always include max_tokens for Moonshot (not max_completion_tokens)
protected override addMaxTokensIfNeeded(
requestOptions:
| OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
| OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
modelInfo: ModelInfo,
): void {
// Moonshot uses max_tokens instead of max_completion_tokens
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
}
}
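
Spelling out the resolution order the override implements (and the three tests above verify): a user-configured modelMaxTokens wins, and modelInfo.maxTokens is the fallback. A standalone sketch of just that logic, with the types trimmed to the two fields involved (not the real ModelInfo or options interfaces):

	interface SketchOptions {
		modelMaxTokens?: number
	}

	interface SketchModelInfo {
		maxTokens?: number
	}

	// Mirrors the override body: user setting first, model default second.
	function resolveMaxTokens(options: SketchOptions, modelInfo: SketchModelInfo): number | undefined {
		return options.modelMaxTokens || modelInfo.maxTokens
	}

	console.log(resolveMaxTokens({ modelMaxTokens: 5000 }, { maxTokens: 32_000 })) // 5000
	console.log(resolveMaxTokens({}, { maxTokens: 16_000 })) // 16_000

One consequence of using || rather than ?? is that a modelMaxTokens of 0 would also fall through to the model default.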
src/api/providers/openai.ts (1 addition, 1 deletion)
@@ -408,7 +408,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
* Note: max_tokens is deprecated in favor of max_completion_tokens as per OpenAI documentation
* O3 family models handle max_tokens separately in handleO3FamilyMessage
*/
-	private addMaxTokensIfNeeded(
+	protected addMaxTokensIfNeeded(
requestOptions:
| OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
| OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
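
Finally, the one-word change in openai.ts (private to protected) is what makes the Moonshot override legal: TypeScript rejects a subclass declaration that clashes with a private base-class method. A toy illustration, unrelated to the real classes:

	class Base {
		// protected, so subclasses may override; private would forbid this.
		protected tokenParam(): string {
			return "max_completion_tokens"
		}
	}

	class Sub extends Base {
		protected override tokenParam(): string {
			return "max_tokens"
		}
	}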