1 change: 1 addition & 0 deletions .github/ISSUE_TEMPLATE/bug_report.yml
@@ -38,6 +38,7 @@ body:
- OpenAI Compatible
- OpenRouter
- Requesty
- SambaNova
- Unbound
- VS Code Language Model API
- xAI (Grok)
1 change: 1 addition & 0 deletions packages/types/src/global-settings.ts
@@ -185,6 +185,7 @@ export const SECRET_STATE_KEYS = [
"codebaseIndexGeminiApiKey",
"codebaseIndexMistralApiKey",
"huggingFaceApiKey",
"sambaNovaApiKey",
] as const satisfies readonly (keyof ProviderSettings)[]
export type SecretState = Pick<ProviderSettings, (typeof SECRET_STATE_KEYS)[number]>

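The `satisfies` clause is what keeps this list honest. A minimal sketch of the compile-time effect, assuming `SecretState` is re-exported from `@roo-code/types` like the other types touched in this PR:

```ts
import type { SecretState } from "@roo-code/types"

// `as const satisfies readonly (keyof ProviderSettings)[]` rejects any key
// that ProviderSettings does not declare, so a typo such as "sambaNovaApiKye"
// would fail to compile. SecretState then picks exactly these keys:
const secrets: SecretState = {
	sambaNovaApiKey: "sk-example-not-a-real-key", // hypothetical value
}
```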
7 changes: 7 additions & 0 deletions packages/types/src/provider-settings.ts
@@ -33,6 +33,7 @@ export const providerNames = [
"chutes",
"litellm",
"huggingface",
"sambanova",
] as const

export const providerNamesSchema = z.enum(providerNames)
@@ -240,6 +241,10 @@ const litellmSchema = baseProviderSettingsSchema.extend({
litellmModelId: z.string().optional(),
})

const sambaNovaSchema = apiModelIdProviderModelSchema.extend({
sambaNovaApiKey: z.string().optional(),
})

const defaultSchema = z.object({
apiProvider: z.undefined(),
})
@@ -270,6 +275,7 @@ export const providerSettingsSchemaDiscriminated = z.discriminatedUnion("apiProvider", [
huggingFaceSchema.merge(z.object({ apiProvider: z.literal("huggingface") })),
chutesSchema.merge(z.object({ apiProvider: z.literal("chutes") })),
litellmSchema.merge(z.object({ apiProvider: z.literal("litellm") })),
sambaNovaSchema.merge(z.object({ apiProvider: z.literal("sambanova") })),
defaultSchema,
])

@@ -300,6 +306,7 @@ export const providerSettingsSchema = z.object({
...huggingFaceSchema.shape,
...chutesSchema.shape,
...litellmSchema.shape,
...sambaNovaSchema.shape,
...codebaseIndexProviderSchema.shape,
})

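To make the two schema changes concrete, a minimal usage sketch (assuming `providerSettingsSchemaDiscriminated` is re-exported from `@roo-code/types`): the discriminated union routes on `apiProvider`, so a `sambanova` profile is validated against `sambaNovaSchema`, while the flat `providerSettingsSchema` simply gains the optional `sambaNovaApiKey` field.

```ts
import { providerSettingsSchemaDiscriminated } from "@roo-code/types"

const result = providerSettingsSchemaDiscriminated.safeParse({
	apiProvider: "sambanova",
	apiModelId: "Meta-Llama-3.3-70B-Instruct",
	sambaNovaApiKey: "sk-example", // hypothetical key
})
console.log(result.success) // true; an unlisted apiProvider value would fail the union
```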
1 change: 1 addition & 0 deletions packages/types/src/providers/index.ts
@@ -15,6 +15,7 @@ export * from "./ollama.js"
export * from "./openai.js"
export * from "./openrouter.js"
export * from "./requesty.js"
export * from "./sambanova.js"
export * from "./unbound.js"
export * from "./vertex.js"
export * from "./vscode-llm.js"
90 changes: 90 additions & 0 deletions packages/types/src/providers/sambanova.ts
@@ -0,0 +1,90 @@
import type { ModelInfo } from "../model.js"

// https://docs.sambanova.ai/cloud/docs/get-started/supported-models
export type SambaNovaModelId =
| "Meta-Llama-3.1-8B-Instruct"
| "Meta-Llama-3.3-70B-Instruct"
| "DeepSeek-R1"
| "DeepSeek-V3-0324"
| "DeepSeek-R1-Distill-Llama-70B"
| "Llama-4-Maverick-17B-128E-Instruct"
| "Llama-3.3-Swallow-70B-Instruct-v0.4"
| "Qwen3-32B"

export const sambaNovaDefaultModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"

export const sambaNovaModels = {
"Meta-Llama-3.1-8B-Instruct": {
maxTokens: 8192,
contextWindow: 16384,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.1,
outputPrice: 0.2,
description: "Meta Llama 3.1 8B Instruct model with 16K context window.",
},
"Meta-Llama-3.3-70B-Instruct": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.6,
outputPrice: 1.2,
description: "Meta Llama 3.3 70B Instruct model with 128K context window.",
},
"DeepSeek-R1": {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: false,
supportsReasoningBudget: true,
inputPrice: 5.0,
outputPrice: 7.0,
description: "DeepSeek R1 reasoning model with 32K context window.",
},
"DeepSeek-V3-0324": {
maxTokens: 8192,
contextWindow: 32768,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 3.0,
outputPrice: 4.5,
description: "DeepSeek V3 model with 32K context window.",
},
"DeepSeek-R1-Distill-Llama-70B": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.7,
outputPrice: 1.4,
description: "DeepSeek R1 distilled Llama 70B model with 128K context window.",
},
"Llama-4-Maverick-17B-128E-Instruct": {
maxTokens: 8192,
contextWindow: 131072,
supportsImages: true,
supportsPromptCache: false,
inputPrice: 0.63,
outputPrice: 1.8,
description: "Meta Llama 4 Maverick 17B 128E Instruct model with 128K context window.",
},
"Llama-3.3-Swallow-70B-Instruct-v0.4": {
maxTokens: 8192,
contextWindow: 16384,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.6,
outputPrice: 1.2,
description: "Tokyotech Llama 3.3 Swallow 70B Instruct v0.4 model with 16K context window.",
},
"Qwen3-32B": {
maxTokens: 8192,
contextWindow: 8192,
supportsImages: false,
supportsPromptCache: false,
inputPrice: 0.4,
outputPrice: 0.8,
description: "Alibaba Qwen 3 32B model with 8K context window.",
},
} as const satisfies Record<string, ModelInfo>
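The `as const satisfies Record<string, ModelInfo>` pattern validates every entry against `ModelInfo` while preserving literal key and value types, so downstream lookups stay narrowly typed. A small sketch (the re-export through `@roo-code/types` matches the imports used in the tests below):

```ts
import { sambaNovaModels, sambaNovaDefaultModelId } from "@roo-code/types"

const info = sambaNovaModels[sambaNovaDefaultModelId]
console.log(info.contextWindow) // 131072, typed as the literal rather than plain number
```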
3 changes: 3 additions & 0 deletions src/api/index.ts
@@ -30,6 +30,7 @@ import {
ChutesHandler,
LiteLLMHandler,
ClaudeCodeHandler,
SambaNovaHandler,
} from "./providers"

export interface SingleCompletionHandler {
@@ -115,6 +116,8 @@ export function buildApiHandler(configuration: ProviderSettings): ApiHandler {
return new ChutesHandler(options)
case "litellm":
return new LiteLLMHandler(options)
case "sambanova":
return new SambaNovaHandler(options)
default:
apiProvider satisfies "gemini-cli" | undefined
return new AnthropicHandler(options)
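With the new case, a `sambanova` profile resolves to `SambaNovaHandler`; anything unmatched still falls through to `AnthropicHandler`. A hedged sketch of the call site (the import path is illustrative):

```ts
import { buildApiHandler } from "../src/api"

const handler = buildApiHandler({
	apiProvider: "sambanova",
	apiModelId: "DeepSeek-R1",
	sambaNovaApiKey: "sk-example", // hypothetical key
})
// handler is a SambaNovaHandler instance
```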
154 changes: 154 additions & 0 deletions src/api/providers/__tests__/sambanova.spec.ts
@@ -0,0 +1,154 @@
// npx vitest run src/api/providers/__tests__/sambanova.spec.ts

// Mock vscode first to avoid import errors
vitest.mock("vscode", () => ({}))

import OpenAI from "openai"
import { Anthropic } from "@anthropic-ai/sdk"

import { type SambaNovaModelId, sambaNovaDefaultModelId, sambaNovaModels } from "@roo-code/types"

import { SambaNovaHandler } from "../sambanova"

vitest.mock("openai", () => {
const createMock = vitest.fn()
return {
default: vitest.fn(() => ({ chat: { completions: { create: createMock } } })),
}
})

describe("SambaNovaHandler", () => {
let handler: SambaNovaHandler
let mockCreate: any

beforeEach(() => {
vitest.clearAllMocks()
mockCreate = (OpenAI as unknown as any)().chat.completions.create
handler = new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
})

it("should use the correct SambaNova base URL", () => {
new SambaNovaHandler({ sambaNovaApiKey: "test-sambanova-api-key" })
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ baseURL: "https://api.sambanova.ai/v1" }))
})

it("should use the provided API key", () => {
const sambaNovaApiKey = "test-sambanova-api-key"
new SambaNovaHandler({ sambaNovaApiKey })
expect(OpenAI).toHaveBeenCalledWith(expect.objectContaining({ apiKey: sambaNovaApiKey }))
})

it("should return default model when no model is specified", () => {
const model = handler.getModel()
expect(model.id).toBe(sambaNovaDefaultModelId)
expect(model.info).toEqual(sambaNovaModels[sambaNovaDefaultModelId])
})

it("should return specified model when valid model is provided", () => {
const testModelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
const handlerWithModel = new SambaNovaHandler({
apiModelId: testModelId,
sambaNovaApiKey: "test-sambanova-api-key",
})
const model = handlerWithModel.getModel()
expect(model.id).toBe(testModelId)
expect(model.info).toEqual(sambaNovaModels[testModelId])
})

it("completePrompt method should return text from SambaNova API", async () => {
const expectedResponse = "This is a test response from SambaNova"
mockCreate.mockResolvedValueOnce({ choices: [{ message: { content: expectedResponse } }] })
const result = await handler.completePrompt("test prompt")
expect(result).toBe(expectedResponse)
})

it("should handle errors in completePrompt", async () => {
const errorMessage = "SambaNova API error"
mockCreate.mockRejectedValueOnce(new Error(errorMessage))
await expect(handler.completePrompt("test prompt")).rejects.toThrow(
`SambaNova completion error: ${errorMessage}`,
)
})

it("createMessage should yield text content from stream", async () => {
const testContent = "This is test content from SambaNova stream"

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vitest
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: { content: testContent } }] },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "text", text: testContent })
})

it("createMessage should yield usage data from stream", async () => {
mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
next: vitest
.fn()
.mockResolvedValueOnce({
done: false,
value: { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 20 } },
})
.mockResolvedValueOnce({ done: true }),
}),
}
})

const stream = handler.createMessage("system prompt", [])
const firstChunk = await stream.next()

expect(firstChunk.done).toBe(false)
expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
})

it("createMessage should pass correct parameters to SambaNova client", async () => {
const modelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
const modelInfo = sambaNovaModels[modelId]
const handlerWithModel = new SambaNovaHandler({
apiModelId: modelId,
sambaNovaApiKey: "test-sambanova-api-key",
})

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const systemPrompt = "Test system prompt for SambaNova"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for SambaNova" }]

const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
temperature: 0.7,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
}),
)
})
})
1 change: 1 addition & 0 deletions src/api/providers/index.ts
@@ -19,6 +19,7 @@ export { OpenAiNativeHandler } from "./openai-native"
export { OpenAiHandler } from "./openai"
export { OpenRouterHandler } from "./openrouter"
export { RequestyHandler } from "./requesty"
export { SambaNovaHandler } from "./sambanova"
export { UnboundHandler } from "./unbound"
export { VertexHandler } from "./vertex"
export { VsCodeLmHandler } from "./vscode-lm"
19 changes: 19 additions & 0 deletions src/api/providers/sambanova.ts
@@ -0,0 +1,19 @@
import { type SambaNovaModelId, sambaNovaDefaultModelId, sambaNovaModels } from "@roo-code/types"

import type { ApiHandlerOptions } from "../../shared/api"

import { BaseOpenAiCompatibleProvider } from "./base-openai-compatible-provider"

export class SambaNovaHandler extends BaseOpenAiCompatibleProvider<SambaNovaModelId> {
constructor(options: ApiHandlerOptions) {
super({
...options,
providerName: "SambaNova",
baseURL: "https://api.sambanova.ai/v1",
apiKey: options.sambaNovaApiKey,
defaultProviderModelId: sambaNovaDefaultModelId,
providerModels: sambaNovaModels,
defaultTemperature: 0.7,
})
}
}
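The handler is a thin subclass: base URL, API key, model table, and default temperature are all it contributes, while `BaseOpenAiCompatibleProvider` supplies the OpenAI-compatible streaming. A usage sketch whose chunk shapes (`text`, `usage`) are taken from the spec above, not from the base class itself:

```ts
import { SambaNovaHandler } from "./sambanova"

async function demo() {
	const handler = new SambaNovaHandler({
		apiModelId: "Meta-Llama-3.3-70B-Instruct",
		sambaNovaApiKey: process.env.SAMBANOVA_API_KEY ?? "",
	})

	// createMessage yields text and usage chunks, per the spec above.
	for await (const chunk of handler.createMessage("You are concise.", [
		{ role: "user", content: "Summarize SambaNova Cloud in one sentence." },
	])) {
		if (chunk.type === "text") process.stdout.write(chunk.text)
		if (chunk.type === "usage") console.log(`tokens in/out: ${chunk.inputTokens}/${chunk.outputTokens}`)
	}
}
```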
1 change: 1 addition & 0 deletions src/shared/ProfileValidator.ts
@@ -65,6 +65,7 @@ export class ProfileValidator {
case "deepseek":
case "xai":
case "groq":
case "sambanova":
case "chutes":
return profile.apiModelId
case "litellm":
1 change: 1 addition & 0 deletions src/shared/__tests__/ProfileValidator.spec.ts
@@ -192,6 +192,7 @@ describe("ProfileValidator", () => {
"xai",
"groq",
"chutes",
"sambanova",
]

apiModelProviders.forEach((provider) => {