1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
@@ -140,6 +140,7 @@ export const SECRET_STATE_KEYS = [
"geminiApiKey",
"openAiNativeApiKey",
"deepSeekApiKey",
"moonshotApiKey",
"mistralApiKey",
"unboundApiKey",
"requestyApiKey",
10 changes: 10 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -22,6 +22,7 @@ export const providerNames = [
"gemini-cli",
"openai-native",
"mistral",
"moonshot",
"deepseek",
"unbound",
"requesty",
@@ -186,6 +187,13 @@ const deepSeekSchema = apiModelIdProviderModelSchema.extend({
deepSeekApiKey: z.string().optional(),
})

const moonshotSchema = apiModelIdProviderModelSchema.extend({
moonshotBaseUrl: z
.union([z.literal("https://api.moonshot.ai/v1"), z.literal("https://api.moonshot.cn/v1")])
.optional(),
moonshotApiKey: z.string().optional(),
})

const unboundSchema = baseProviderSettingsSchema.extend({
unboundApiKey: z.string().optional(),
unboundModelId: z.string().optional(),
@@ -240,6 +248,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProv
openAiNativeSchema.merge(z.object({ apiProvider: z.literal("openai-native") })),
mistralSchema.merge(z.object({ apiProvider: z.literal("mistral") })),
deepSeekSchema.merge(z.object({ apiProvider: z.literal("deepseek") })),
moonshotSchema.merge(z.object({ apiProvider: z.literal("moonshot") })),
unboundSchema.merge(z.object({ apiProvider: z.literal("unbound") })),
requestySchema.merge(z.object({ apiProvider: z.literal("requesty") })),
humanRelaySchema.merge(z.object({ apiProvider: z.literal("human-relay") })),
@@ -268,6 +277,7 @@ export const providerSettingsSchema = z.object({
...openAiNativeSchema.shape,
...mistralSchema.shape,
...deepSeekSchema.shape,
...moonshotSchema.shape,
...unboundSchema.shape,
...requestySchema.shape,
...humanRelaySchema.shape,
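Because moonshotBaseUrl is a union of two literals, the schema accepts only the global endpoint (api.moonshot.ai) and the Chinese endpoint (api.moonshot.cn); any other URL fails validation. A minimal sketch of how a Moonshot entry validates against the discriminated union, assuming the schema is exported from @roo-code/types (the test file below imports from that package):

    import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

    // Accepted: a known provider name plus one of the two allowed base URLs.
    const settings = providerSettingsSchemaDiscriminated.parse({
        apiProvider: "moonshot",
        moonshotApiKey: "sk-...", // placeholder key
        moonshotBaseUrl: "https://api.moonshot.cn/v1",
    })

    // Rejected: any other base URL fails the literal-union check.
    const bad = providerSettingsSchemaDiscriminated.safeParse({
        apiProvider: "moonshot",
        moonshotBaseUrl: "https://example.com/v1",
    })
    console.log(bad.success) // false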
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -9,6 +9,7 @@ export * from "./groq.js"
export * from "./lite-llm.js"
export * from "./lm-studio.js"
export * from "./mistral.js"
export * from "./moonshot.js"
export * from "./ollama.js"
export * from "./openai.js"
export * from "./openrouter.js"
22 changes: 22 additions & 0 deletions packages/types/src/providers/moonshot.ts
@@ -0,0 +1,22 @@
import type { ModelInfo } from "../model.js"

// https://platform.moonshot.ai/
export type MoonshotModelId = keyof typeof moonshotModels

export const moonshotDefaultModelId: MoonshotModelId = "kimi-k2-0711-preview"

export const moonshotModels = {
"kimi-k2-0711-preview": {
maxTokens: 32_000,
contextWindow: 131_072,
supportsImages: false,
supportsPromptCache: true,
inputPrice: 0.6, // $0.60 per million tokens (cache miss)
outputPrice: 2.5, // $2.50 per million tokens
cacheWritesPrice: 0, // $0 per million tokens (cache write)
cacheReadsPrice: 0.15, // $0.15 per million tokens (cache hit)
description: `Kimi K2 is a state-of-the-art mixture-of-experts (MoE) language model with 32 billion activated parameters and 1 trillion total parameters.`,
},
} as const satisfies Record<string, ModelInfo>

export const MOONSHOT_DEFAULT_TEMPERATURE = 0.6
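The pricing fields are USD per million tokens. As a rough sketch of how they combine, assuming (as is common) that cache-hit tokens are billed at the cache-read rate instead of the input rate; the token counts here are hypothetical:

    import { moonshotModels } from "@roo-code/types"

    const info = moonshotModels["kimi-k2-0711-preview"]
    const inputTokens = 100_000 // hypothetical prompt size
    const cacheReadTokens = 40_000 // hypothetical portion served from cache
    const outputTokens = 2_000

    const costUsd =
        ((inputTokens - cacheReadTokens) * info.inputPrice +
            cacheReadTokens * info.cacheReadsPrice +
            outputTokens * info.outputPrice) /
        1_000_000
    // (60_000 * 0.6 + 40_000 * 0.15 + 2_000 * 2.5) / 1e6 = $0.047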
4 changes: 4 additions & 0 deletions src/api/index.ts
@@ -17,6 +17,7 @@ import {
GeminiHandler,
OpenAiNativeHandler,
DeepSeekHandler,
MoonshotHandler,
MistralHandler,
VsCodeLmHandler,
UnboundHandler,
@@ -89,6 +90,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new OpenAiNativeHandler(options)
case "deepseek":
return new DeepSeekHandler(options)
case "moonshot":
return new MoonshotHandler(options)
case "vscode-lm":
return new VsCodeLmHandler(options)
case "mistral":
@@ -110,6 +113,7 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
case "litellm":
return new LiteLLMHandler(options)
default:
apiProvider satisfies "gemini-cli" | undefined
return new AnthropicHandler(options)
}
}
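The line apiProvider satisfies "gemini-cli" | undefined in the default branch is a compile-time exhaustiveness check: once every other provider has its own case, TypeScript narrows the fallthrough to just "gemini-cli" | undefined, so forgetting a case for a future provider becomes a type error rather than a silent fallback to AnthropicHandler. Selecting the new handler then looks something like this (a sketch; the import path and env var are illustrative, option names follow the code above):

    import { buildApiHandler } from "./api/index"

    const handler = buildApiHandler({
        apiProvider: "moonshot",
        moonshotApiKey: process.env.MOONSHOT_API_KEY, // hypothetical env var
        apiModelId: "kimi-k2-0711-preview",
    })
    // handler is a MoonshotHandler; other provider names route to their handlers.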
297 changes: 297 additions & 0 deletions src/api/providers/__tests__/moonshot.spec.ts
@@ -0,0 +1,297 @@
// Mocks must come first, before imports
const mockCreate = vi.fn()
vi.mock("openai", () => {
return {
__esModule: true,
default: vi.fn().mockImplementation(() => ({
chat: {
completions: {
create: mockCreate.mockImplementation(async (options) => {
if (!options.stream) {
return {
id: "test-completion",
choices: [
{
message: { role: "assistant", content: "Test response", refusal: null },
finish_reason: "stop",
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
cached_tokens: 2,
},
}
}

// Return async iterator for streaming
return {
[Symbol.asyncIterator]: async function* () {
yield {
choices: [
{
delta: { content: "Test response" },
index: 0,
},
],
usage: null,
}
yield {
choices: [
{
delta: {},
index: 0,
},
],
usage: {
prompt_tokens: 10,
completion_tokens: 5,
total_tokens: 15,
cached_tokens: 2,
},
}
},
}
}),
},
},
})),
}
})

import OpenAI from "openai"
import type { Anthropic } from "@anthropic-ai/sdk"

import { moonshotDefaultModelId } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../../shared/api"

import { MoonshotHandler } from "../moonshot"

describe("MoonshotHandler", () => {
let handler: MoonshotHandler
let mockOptions: ApiHandlerOptions

beforeEach(() => {
mockOptions = {
moonshotApiKey: "test-api-key",
apiModelId: "moonshot-chat",
moonshotBaseUrl: "https://api.moonshot.ai/v1",
}
handler = new MoonshotHandler(mockOptions)
vi.clearAllMocks()
})

describe("constructor", () => {
it("should initialize with provided options", () => {
expect(handler).toBeInstanceOf(MoonshotHandler)
expect(handler.getModel().id).toBe(mockOptions.apiModelId)
})

it.skip("should throw error if API key is missing", () => {
expect(() => {
new MoonshotHandler({
...mockOptions,
moonshotApiKey: undefined,
})
}).toThrow("Moonshot API key is required")
})

it("should use default model ID if not provided", () => {
const handlerWithoutModel = new MoonshotHandler({
...mockOptions,
apiModelId: undefined,
})
expect(handlerWithoutModel.getModel().id).toBe(moonshotDefaultModelId)
})

it("should use default base URL if not provided", () => {
const handlerWithoutBaseUrl = new MoonshotHandler({
...mockOptions,
moonshotBaseUrl: undefined,
})
expect(handlerWithoutBaseUrl).toBeInstanceOf(MoonshotHandler)
// The base URL is passed to OpenAI client internally
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: "https://api.moonshot.ai/v1",
}),
)
})

it("should use chinese base URL if provided", () => {
const customBaseUrl = "https://api.moonshot.cn/v1"
const handlerWithCustomUrl = new MoonshotHandler({
...mockOptions,
moonshotBaseUrl: customBaseUrl,
})
expect(handlerWithCustomUrl).toBeInstanceOf(MoonshotHandler)
// The custom base URL is passed to OpenAI client
expect(OpenAI).toHaveBeenCalledWith(
expect.objectContaining({
baseURL: customBaseUrl,
}),
)
})

it("should set includeMaxTokens to true", () => {
// Create a new handler and verify OpenAI client was called with includeMaxTokens
const _handler = new MoonshotHandler(mockOptions)
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: mockOptions.moonshotApiKey }))
})
})

describe("getModel", () => {
it("should return model info for valid model ID", () => {
const model = handler.getModel()
expect(model.id).toBe(mockOptions.apiModelId)
expect(model.info).toBeDefined()
expect(model.info.maxTokens).toBe(32_000)
expect(model.info.contextWindow).toBe(131_072)
expect(model.info.supportsImages).toBe(false)
expect(model.info.supportsPromptCache).toBe(true) // Should be true now
})

it("should return provided model ID with default model info if model does not exist", () => {
const handlerWithInvalidModel = new MoonshotHandler({
...mockOptions,
apiModelId: "invalid-model",
})
const model = handlerWithInvalidModel.getModel()
expect(model.id).toBe("invalid-model") // Returns provided ID
expect(model.info).toBeDefined()
// With the current implementation, it's the same object reference when using default model info
expect(model.info).toBe(handler.getModel().info)
// Should have the same base properties
expect(model.info.contextWindow).toBe(handler.getModel().info.contextWindow)
// And should have supportsPromptCache set to true
expect(model.info.supportsPromptCache).toBe(true)
})

it("should return default model if no model ID is provided", () => {
const handlerWithoutModel = new MoonshotHandler({
...mockOptions,
apiModelId: undefined,
})
const model = handlerWithoutModel.getModel()
expect(model.id).toBe(moonshotDefaultModelId)
expect(model.info).toBeDefined()
expect(model.info.supportsPromptCache).toBe(true)
})

it("should include model parameters from getModelParams", () => {
const model = handler.getModel()
expect(model).toHaveProperty("temperature")
expect(model).toHaveProperty("maxTokens")
})
})

describe("createMessage", () => {
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: [
{
type: "text" as const,
text: "Hello!",
},
],
},
]

it("should handle streaming responses", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(chunks.length).toBeGreaterThan(0)
const textChunks = chunks.filter((chunk) => chunk.type === "text")
expect(textChunks).toHaveLength(1)
expect(textChunks[0].text).toBe("Test response")
})

it("should include usage information", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
expect(usageChunks.length).toBeGreaterThan(0)
expect(usageChunks[0].inputTokens).toBe(10)
expect(usageChunks[0].outputTokens).toBe(5)
})

it("should include cache metrics in usage information", async () => {
const stream = handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
expect(usageChunks.length).toBeGreaterThan(0)
expect(usageChunks[0].cacheWriteTokens).toBe(0)
expect(usageChunks[0].cacheReadTokens).toBe(2)
})
})

describe("processUsageMetrics", () => {
it("should correctly process usage metrics including cache information", () => {
// We need to access the protected method, so we'll create a test subclass
class TestMoonshotHandler extends MoonshotHandler {
public testProcessUsageMetrics(usage: any) {
return this.processUsageMetrics(usage)
}
}

const testHandler = new TestMoonshotHandler(mockOptions)

const usage = {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
cached_tokens: 20,
}

const result = testHandler.testProcessUsageMetrics(usage)

expect(result.type).toBe("usage")
expect(result.inputTokens).toBe(100)
expect(result.outputTokens).toBe(50)
expect(result.cacheWriteTokens).toBe(0)
expect(result.cacheReadTokens).toBe(20)
})

it("should handle missing cache metrics gracefully", () => {
class TestMoonshotHandler extends MoonshotHandler {
public testProcessUsageMetrics(usage: any) {
return this.processUsageMetrics(usage)
}
}

const testHandler = new TestMoonshotHandler(mockOptions)

const usage = {
prompt_tokens: 100,
completion_tokens: 50,
total_tokens: 150,
// No cached_tokens
}

const result = testHandler.testProcessUsageMetrics(usage)

expect(result.type).toBe("usage")
expect(result.inputTokens).toBe(100)
expect(result.outputTokens).toBe(50)
expect(result.cacheWriteTokens).toBe(0)
expect(result.cacheReadTokens).toBeUndefined()
})
})
})
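The last two tests pin down the usage mapping: cacheWriteTokens is always 0 (Moonshot's cache-write price is $0) and cacheReadTokens passes cached_tokens through, staying undefined when the API omits it. A minimal sketch of a mapping consistent with those expectations (the real MoonshotHandler.processUsageMetrics may differ):

    interface MoonshotUsage {
        prompt_tokens?: number
        completion_tokens?: number
        total_tokens?: number
        cached_tokens?: number
    }

    function processUsageMetrics(usage: MoonshotUsage) {
        return {
            type: "usage" as const,
            inputTokens: usage.prompt_tokens ?? 0,
            outputTokens: usage.completion_tokens ?? 0,
            cacheWriteTokens: 0, // Moonshot bills cache writes at $0
            cacheReadTokens: usage.cached_tokens, // undefined when absent
        }
    }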