1 change: 0 additions & 1 deletion src/api/providers/__tests__/chutes.spec.ts
@@ -394,7 +394,6 @@ describe("ChutesHandler", () => {
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
temperature: 0.5,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
1 change: 0 additions & 1 deletion src/api/providers/__tests__/fireworks.spec.ts
@@ -352,7 +352,6 @@ describe("FireworksHandler", () => {
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
temperature: 0.5,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
77 changes: 76 additions & 1 deletion src/api/providers/__tests__/groq.spec.ts
@@ -114,7 +114,11 @@ describe("GroqHandler", () => {
it("createMessage should pass correct parameters to Groq client", async () => {
const modelId: GroqModelId = "llama-3.1-8b-instant"
const modelInfo = groqModels[modelId]
const handlerWithModel = new GroqHandler({ apiModelId: modelId, groqApiKey: "test-groq-api-key" })
const handlerWithModel = new GroqHandler({
apiModelId: modelId,
groqApiKey: "test-groq-api-key",
Author:
Is this intentional? The temperature is set inline here (line 119) while in other tests it's part of the options object. Could we standardize this for better consistency?

modelTemperature: 0.5, // Explicitly set temperature for this test
})

mockCreate.mockImplementationOnce(() => {
return {
@@ -143,4 +147,75 @@
}),
)
})

it("should omit temperature when modelTemperature is undefined", async () => {
const modelId: GroqModelId = "llama-3.1-8b-instant"
const handlerWithoutTemp = new GroqHandler({
apiModelId: modelId,
groqApiKey: "test-groq-api-key",
// modelTemperature is not set
})

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const systemPrompt = "Test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]

const messageGenerator = handlerWithoutTemp.createMessage(systemPrompt, messages)
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
}),
)

// Verify temperature is NOT included
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("temperature")
})

it("should include temperature when modelTemperature is explicitly set", async () => {
const modelId: GroqModelId = "llama-3.1-8b-instant"
const handlerWithTemp = new GroqHandler({
apiModelId: modelId,
groqApiKey: "test-groq-api-key",
modelTemperature: 0.7,
})

mockCreate.mockImplementationOnce(() => {
return {
[Symbol.asyncIterator]: () => ({
async next() {
return { done: true }
},
}),
}
})

const systemPrompt = "Test system prompt"
const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message" }]

const messageGenerator = handlerWithTemp.createMessage(systemPrompt, messages)
await messageGenerator.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: modelId,
temperature: 0.7,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
}),
)
})
})
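
For reference, a minimal sketch of the contract these two Groq tests pin down — illustrative only, not the handler implementation: the request body only gains a temperature key when modelTemperature is defined, so an explicit value (including 0) is forwarded while an unset option is omitted entirely.

// Illustrative sketch of the omit-vs-include rule exercised above; names are hypothetical.
interface HandlerOptions {
	modelTemperature?: number
}

function buildStreamParams(options: HandlerOptions, model: string, systemPrompt: string) {
	const params: Record<string, unknown> = {
		model,
		messages: [{ role: "system", content: systemPrompt }],
		stream: true,
	}
	// Only attach temperature when it was explicitly configured.
	if (options.modelTemperature !== undefined) {
		params.temperature = options.modelTemperature
	}
	return params
}

// buildStreamParams({}, "llama-3.1-8b-instant", "Test system prompt") has no "temperature" key,
// while buildStreamParams({ modelTemperature: 0.7 }, "llama-3.1-8b-instant", "Test system prompt") includes temperature: 0.7.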
67 changes: 66 additions & 1 deletion src/api/providers/__tests__/openai.spec.ts
@@ -315,6 +315,71 @@ describe("OpenAiHandler", () => {
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_completion_tokens).toBe(4096)
})

it("should omit temperature when modelTemperature is undefined", async () => {
const optionsWithoutTemperature: ApiHandlerOptions = {
...mockOptions,
// modelTemperature is not set, should not include temperature
}
const handlerWithoutTemperature = new OpenAiHandler(optionsWithoutTemperature)
const stream = handlerWithoutTemperature.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called without temperature
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("temperature")
})

it("should include temperature when modelTemperature is explicitly set to 0", async () => {
const optionsWithZeroTemperature: ApiHandlerOptions = {
...mockOptions,
modelTemperature: 0,
}
const handlerWithZeroTemperature = new OpenAiHandler(optionsWithZeroTemperature)
const stream = handlerWithZeroTemperature.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with temperature: 0
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0)
})

it("should include temperature when modelTemperature is set to a non-zero value", async () => {
const optionsWithCustomTemperature: ApiHandlerOptions = {
...mockOptions,
modelTemperature: 0.7,
}
const handlerWithCustomTemperature = new OpenAiHandler(optionsWithCustomTemperature)
const stream = handlerWithCustomTemperature.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with temperature: 0.7
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0.7)
})

it("should include DEEP_SEEK_DEFAULT_TEMPERATURE for deepseek-reasoner models when temperature is not set", async () => {
const deepseekOptions: ApiHandlerOptions = {
...mockOptions,
openAiModelId: "deepseek-reasoner",
// modelTemperature is not set
}
const deepseekHandler = new OpenAiHandler(deepseekOptions)
const stream = deepseekHandler.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with DEEP_SEEK_DEFAULT_TEMPERATURE (0.6)
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.temperature).toBe(0.6)
})
})

describe("error handling", () => {
@@ -450,7 +515,7 @@ describe("OpenAiHandler", () => {
],
stream: true,
stream_options: { include_usage: true },
temperature: 0,
// temperature should be omitted when not set
},
{ path: "/models/chat/completions" },
)
6 changes: 3 additions & 3 deletions src/api/providers/__tests__/roo.spec.ts
@@ -350,16 +350,16 @@ describe("RooHandler", () => {
})

describe("temperature and model configuration", () => {
it("should use default temperature of 0.7", async () => {
it("should omit temperature when not explicitly set", async () => {
handler = new RooHandler(mockOptions)
const stream = handler.createMessage(systemPrompt, messages)
for await (const _chunk of stream) {
// Consume stream
}

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.7,
expect.not.objectContaining({
temperature: expect.anything(),
}),
)
})
1 change: 0 additions & 1 deletion src/api/providers/__tests__/sambanova.spec.ts
@@ -144,7 +144,6 @@ describe("SambaNovaHandler", () => {
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
temperature: 0.7,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
1 change: 0 additions & 1 deletion src/api/providers/__tests__/zai.spec.ts
@@ -220,7 +220,6 @@ describe("ZAiHandler", () => {
expect.objectContaining({
model: modelId,
max_tokens: modelInfo.maxTokens,
temperature: ZAI_DEFAULT_TEMPERATURE,
messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
stream: true,
stream_options: { include_usage: true },
8 changes: 5 additions & 3 deletions src/api/providers/base-openai-compatible-provider.ts
@@ -72,17 +72,19 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
info: { maxTokens: max_tokens },
} = this.getModel()

const temperature = this.options.modelTemperature ?? this.defaultTemperature

const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model,
max_tokens,
temperature,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
}

// Only include temperature if explicitly set
if (this.options.modelTemperature !== undefined) {
params.temperature = this.options.modelTemperature
}

return this.client.chat.completions.create(params)
}

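As a design note, the same behavior could be expressed with a conditional spread instead of the mutate-after-construction pattern above; a sketch under the same assumptions as the diff (spreading false adds no properties, so temperature is simply absent when unset):

const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
	model,
	max_tokens,
	messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
	stream: true,
	stream_options: { include_usage: true },
	// Conditionally spread so the key is never present as an explicit undefined.
	...(this.options.modelTemperature !== undefined && { temperature: this.options.modelTemperature }),
}

Either form satisfies the new tests; the explicit if block in the diff arguably reads more clearly if further conditional fields are added later.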
9 changes: 8 additions & 1 deletion src/api/providers/openai.ts
@@ -157,13 +157,20 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
messages: convertedMessages,
stream: true as const,
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
...(reasoning && reasoning),
}

// Only include temperature if explicitly set
if (this.options.modelTemperature !== undefined) {
requestOptions.temperature = this.options.modelTemperature
} else if (deepseekReasoner) {
// DeepSeek Reasoner has a specific default temperature
requestOptions.temperature = DEEP_SEEK_DEFAULT_TEMPERATURE
Author:
The special handling for DeepSeek Reasoner models is well-implemented. Could we add a comment explaining why this model requires a specific default temperature? This would help future maintainers understand the reasoning behind this special case.

Member:
I wonder if we don't need to send this anymore since we are letting the provider set it

}

// Add max_tokens if needed
this.addMaxTokensIfNeeded(requestOptions, modelInfo)

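Putting the branches together, the effective precedence implemented above can be summarized in a small sketch (the helper name is illustrative, not part of the diff): an explicitly configured modelTemperature always wins, deepseek-reasoner models otherwise fall back to DEEP_SEEK_DEFAULT_TEMPERATURE (0.6), and in every other case temperature is left off the request so the provider's own default applies.

// Hypothetical helper mirroring the logic in createMessage above.
function resolveTemperature(modelTemperature: number | undefined, deepseekReasoner: boolean): number | undefined {
	if (modelTemperature !== undefined) return modelTemperature // explicit user setting
	if (deepseekReasoner) return 0.6 // DEEP_SEEK_DEFAULT_TEMPERATURE
	return undefined // omit the field; let the provider apply its own default
}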