Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion core/internal/testutil/file_base64.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,15 @@ func RunFileBase64ResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx conte
"unable to read", "no file", "corrupted", "unsupported",
}...) // PDF processing failure indicators

responsesRetryConfig := FileInputResponsesRetryConfig()
retryConfig := GetTestRetryConfigForScenario("FileInput", testConfig)
responsesRetryConfig := ResponsesRetryConfig{
MaxAttempts: retryConfig.MaxAttempts,
BaseDelay: retryConfig.BaseDelay,
MaxDelay: retryConfig.MaxDelay,
Conditions: []ResponsesRetryCondition{},
OnRetry: retryConfig.OnRetry,
OnFinalFail: retryConfig.OnFinalFail,
}

response, responsesError := WithResponsesTestRetry(t, responsesRetryConfig, retryContext, expectations, "FileBase64", func() (*schemas.BifrostResponsesResponse, *schemas.BifrostError) {
bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
Expand Down
2 changes: 1 addition & 1 deletion core/providers/anthropic/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,7 @@ func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*
// Set system message if present
if systemContent != nil {
anthropicReq.System = systemContent
} else if bifrostReq.Params != nil && bifrostReq.Params.Instructions != nil {
} else if bifrostReq.Params != nil && bifrostReq.Params.Instructions != nil && *bifrostReq.Params.Instructions != "" {
// if no system content, check if instructions are present
// system messages take precedence over instructions
anthropicReq.System = &AnthropicContent{
Expand Down
2 changes: 1 addition & 1 deletion core/providers/bedrock/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func convertChatParameters(ctx *schemas.BifrostContext, bifrostReq *schemas.Bifr
}
// Handle request metadata
if reqMetadata, exists := bifrostReq.Params.ExtraParams["requestMetadata"]; exists {
if metadata, ok := reqMetadata.(map[string]string); ok {
if metadata, ok := schemas.SafeExtractStringMap(reqMetadata); ok {
bedrockReq.RequestMetadata = metadata
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/providers/gemini/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *Gemi
if bifrostReq.Params.ExtraParams != nil {
// Safety settings
if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
if settings, ok := safetySettings.([]SafetySetting); ok {
if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
geminiReq.SafetySettings = settings
}
}
Expand All @@ -49,7 +49,7 @@ func ToGeminiChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *Gemi

// Labels
if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok {
if labelMap, ok := labels.(map[string]string); ok {
if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
geminiReq.Labels = labelMap
}
}
Expand Down
6 changes: 4 additions & 2 deletions core/providers/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,7 @@ func HandleGeminiResponsesStream(
if isLast {
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
finalResponse.ExtraFields.Latency = time.Since(startTime).Milliseconds()
}
}
providerUtils.ProcessAndSendResponse(ctx, postHookRunner, providerUtils.GetBifrostResponseForStreamResponse(nil, nil, finalResponse, nil, nil), responseChan)
}
}()
Expand Down Expand Up @@ -1032,7 +1032,9 @@ func (provider *GeminiProvider) Speech(ctx *schemas.BifrostContext, key schemas.
if bifrostErr != nil {
return nil, bifrostErr
}
ctx.SetValue(BifrostContextKeyResponseFormat, request.Params.ResponseFormat)
if request.Params != nil {
ctx.SetValue(BifrostContextKeyResponseFormat, request.Params.ResponseFormat)
}
response, convErr := geminiResponse.ToBifrostSpeechResponse(ctx)
if convErr != nil {
return nil, providerUtils.NewBifrostOperationError(schemas.ErrProviderResponseDecode, convErr, provider.GetProviderKey())
Expand Down
2 changes: 1 addition & 1 deletion core/providers/gemini/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *Gemi

if bifrostReq.Params.ExtraParams != nil {
if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
if settings, ok := safetySettings.([]SafetySetting); ok {
if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
geminiReq.SafetySettings = settings
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/providers/gemini/transcription.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionReques
if bifrostReq.Params.ExtraParams != nil {
// Safety settings
if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok {
if settings, ok := safetySettings.([]SafetySetting); ok {
if settings, ok := SafeExtractSafetySettings(safetySettings); ok {
geminiReq.SafetySettings = settings
}
}
Expand All @@ -127,7 +127,7 @@ func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionReques

// Labels
if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok {
if labelMap, ok := labels.(map[string]string); ok {
if labelMap, ok := schemas.SafeExtractStringMap(labels); ok {
geminiReq.Labels = labelMap
}
}
Expand Down
34 changes: 34 additions & 0 deletions core/providers/gemini/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,40 @@ type SafetySetting struct {
Threshold string `json:"threshold,omitempty"`
}

// SafeExtractSafetySettings converts a loosely typed value into a []SafetySetting.
//
// Two input shapes are accepted:
//   - a []SafetySetting, which is returned as-is;
//   - a []interface{} (the JSON-deserialized form) whose elements are all
//     map[string]interface{}; each map is converted into one SafetySetting.
//
// For the map form only the "method", "category" and "threshold" keys are read,
// and a key is copied only when its value is a string; a missing or non-string
// value leaves the corresponding field at its zero value.
//
// Any other input — including nil, or a slice containing a non-map element —
// yields (nil, false).
func SafeExtractSafetySettings(value interface{}) ([]SafetySetting, bool) {
	// Already strongly typed: hand it back untouched.
	if typed, ok := value.([]SafetySetting); ok {
		return typed, true
	}

	// A nil interface fails this assertion too, so nil input ends up here.
	rawItems, ok := value.([]interface{})
	if !ok {
		return nil, false
	}

	converted := make([]SafetySetting, 0, len(rawItems))
	for _, raw := range rawItems {
		fields, isMap := raw.(map[string]interface{})
		if !isMap {
			// One malformed element invalidates the whole list.
			return nil, false
		}

		var entry SafetySetting
		if s, isString := fields["method"].(string); isString {
			entry.Method = s
		}
		if s, isString := fields["category"].(string); isString {
			entry.Category = s
		}
		if s, isString := fields["threshold"].(string); isString {
			entry.Threshold = s
		}
		converted = append(converted, entry)
	}
	return converted, true
}

// FunctionCallingConfig represents function calling configuration.
type FunctionCallingConfig struct {
// Optional. Function calling mode.
Expand Down
6 changes: 2 additions & 4 deletions core/providers/openai/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/maximhq/bifrost/core/schemas"
)

// ToBifrostListModelsResponse converts an OpenAI list models response to a Bifrost list models response
func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels []string) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
Expand All @@ -31,16 +32,14 @@ func (response *OpenAIListModelsResponse) ToBifrostListModelsResponse(providerKe
return bifrostResponse
}

// ToOpenAIListModelsResponse converts a Bifrost list models response to an OpenAI list models response
func ToOpenAIListModelsResponse(response *schemas.BifrostListModelsResponse) *OpenAIListModelsResponse {

if response == nil {
return nil
}

openaiResponse := &OpenAIListModelsResponse{
Data: make([]OpenAIModel, 0, len(response.Data)),
}
Comment thread
akshaydeo marked this conversation as resolved.

for _, model := range response.Data {
openaiModel := OpenAIModel{
ID: model.ID,
Expand All @@ -56,6 +55,5 @@ func ToOpenAIListModelsResponse(response *schemas.BifrostListModelsResponse) *Op
openaiResponse.Data = append(openaiResponse.Data, openaiModel)

}

return openaiResponse
}
3 changes: 2 additions & 1 deletion core/providers/openai/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ func (r *OpenAITranscriptionRequest) IsStreamingRequested() bool {
return r.Stream != nil && *r.Stream
}

// MODEL TYPES
// OpenAIModel represents an OpenAI model
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
Expand All @@ -537,6 +537,7 @@ type OpenAIModel struct {
ContextWindow *int `json:"context_window,omitempty"`
}

// OpenAIListModelsResponse represents an OpenAI list models response
type OpenAIListModelsResponse struct {
Object string `json:"object"`
Data []OpenAIModel `json:"data"`
Expand Down
2 changes: 1 addition & 1 deletion core/providers/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ func ProcessAndSendResponse(
}
}

// Run post hooks on the response first so span reflects post-processed data
// Run post hooks on the response (note: accumulated chunks above contain pre-hook data)
processedResponse, processedError := postHookRunner(ctx, response, nil)

if HandleStreamControlSkip(processedError) {
Expand Down
3 changes: 2 additions & 1 deletion core/schemas/plugin_wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
package schemas

// PluginShortCircuit represents a plugin's decision to short-circuit the normal flow.
// It can contain either a response (success short-circuit), a stream (streaming short-circuit), or an error (error short-circuit).
// It can contain either a response (success short-circuit) or an error (error short-circuit).
// Streams are not supported in WASM plugins.
Comment thread
akshaydeo marked this conversation as resolved.
type PluginShortCircuit struct {
Response *BifrostResponse // If set, short-circuit with this response (skips provider call)
Error *BifrostError // If set, short-circuit with this error (can set AllowFallbacks field)
Expand Down
24 changes: 24 additions & 0 deletions core/schemas/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,30 @@ func SafeExtractFromMap(m map[string]interface{}, key string) (interface{}, bool
return value, exists
}

// SafeExtractStringMap converts a loosely typed value into a map[string]string.
//
// Two input shapes are accepted:
//   - a map[string]string, which is returned as-is (not copied);
//   - a map[string]interface{} (the JSON-deserialized form), which is copied
//     into a new map provided every value is accepted by SafeExtractString.
//
// Any other input — including nil, or a map with at least one value that
// SafeExtractString rejects — yields (nil, false).
func SafeExtractStringMap(value interface{}) (map[string]string, bool) {
	// Already strongly typed: hand it back untouched.
	if direct, ok := value.(map[string]string); ok {
		return direct, true
	}

	// A nil interface fails this assertion too, so nil input ends up here.
	generic, ok := value.(map[string]interface{})
	if !ok {
		return nil, false
	}

	out := make(map[string]string, len(generic))
	for name, raw := range generic {
		text, isString := SafeExtractString(raw)
		if !isString {
			// One unconvertible value invalidates the whole map.
			return nil, false
		}
		out[name] = text
	}
	return out, true
}

func SafeExtractOrderedMap(value interface{}) (OrderedMap, bool) {
if value == nil {
return OrderedMap{}, false
Expand Down
1 change: 1 addition & 0 deletions framework/streaming/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ func (a *Accumulator) createStreamAccumulator(requestID string) *StreamAccumulat
MaxTranscriptionChunkIndex: -1,
MaxAudioChunkIndex: -1,
IsComplete: false,
mu: sync.Mutex{},
Timestamp: now,
StartTimestamp: now, // Set default StartTimestamp for proper TTFT/latency calculation
}
Expand Down
8 changes: 8 additions & 0 deletions framework/streaming/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ type StreamAccumulator struct {

// getLastChatChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastChatChunk() *ChatStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
if sa.MaxChatChunkIndex < 0 {
return nil
}
Expand All @@ -137,6 +139,8 @@ func (sa *StreamAccumulator) getLastChatChunk() *ChatStreamChunk {

// getLastResponsesChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastResponsesChunk() *ResponsesStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
if sa.MaxResponsesChunkIndex < 0 {
return nil
}
Expand All @@ -150,6 +154,8 @@ func (sa *StreamAccumulator) getLastResponsesChunk() *ResponsesStreamChunk {

// getLastTranscriptionChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastTranscriptionChunk() *TranscriptionStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
if sa.MaxTranscriptionChunkIndex < 0 {
return nil
}
Expand All @@ -163,6 +169,8 @@ func (sa *StreamAccumulator) getLastTranscriptionChunk() *TranscriptionStreamChu

// getLastAudioChunk returns the chunk with the highest ChunkIndex (contains metadata like TokenUsage, Cost)
func (sa *StreamAccumulator) getLastAudioChunk() *AudioStreamChunk {
sa.mu.Lock()
defer sa.mu.Unlock()
if sa.MaxAudioChunkIndex < 0 {
return nil
}
Expand Down
6 changes: 6 additions & 0 deletions plugins/governance/advancedscenarios_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ func TestHierarchicalChainBudgetSwitch(t *testing.T) {
// Exhaust Customer1's budget (which is limiting Team1)
consumedBudget := 0.0
requestNum := 1
budgetExhausted := false

for requestNum <= 150 {
resp := MakeRequest(t, APIRequest{
Expand All @@ -471,6 +472,7 @@ func TestHierarchicalChainBudgetSwitch(t *testing.T) {

if resp.StatusCode >= 400 {
if CheckErrorMessage(t, resp, "budget") {
budgetExhausted = true
t.Logf("Customer1 budget exhausted at request %d (consumed: $%.6f)", requestNum, consumedBudget)
break
} else {
Expand All @@ -490,6 +492,10 @@ func TestHierarchicalChainBudgetSwitch(t *testing.T) {
requestNum++
}

if !budgetExhausted {
t.Fatalf("Budget should have been exhausted within 150 requests, but no budget rejection was observed (consumed: $%.6f)", consumedBudget)
}

// Switch VK to Team2 (under Customer2)
updateResp := MakeRequest(t, APIRequest{
Method: "PUT",
Expand Down
2 changes: 2 additions & 0 deletions plugins/semanticcache/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"sort"
"sync"
"time"
)

Expand All @@ -19,6 +20,7 @@ func (plugin *Plugin) createStreamAccumulator(requestID string, embedding []floa
Embedding: embedding,
Metadata: metadata,
TTL: ttl,
mu: sync.Mutex{},
}

plugin.streamAccumulators.Store(requestID, accumulator)
Expand Down
8 changes: 4 additions & 4 deletions transports/bifrost-http/handlers/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,14 +506,14 @@ func (h *CompletionHandler) listModels(ctx *fasthttp.RequestCtx) {
}
if pricingEntry != nil && modelEntry.Pricing == nil {
pricing := &schemas.Pricing{
Prompt: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.InputCostPerToken)),
Completion: bifrost.Ptr(fmt.Sprintf("%f", pricingEntry.OutputCostPerToken)),
Prompt: bifrost.Ptr(fmt.Sprintf("%.10f", pricingEntry.InputCostPerToken)),
Completion: bifrost.Ptr(fmt.Sprintf("%.10f", pricingEntry.OutputCostPerToken)),
}
if pricingEntry.InputCostPerImage != nil {
pricing.Image = bifrost.Ptr(fmt.Sprintf("%f", *pricingEntry.InputCostPerImage))
pricing.Image = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.InputCostPerImage))
}
if pricingEntry.CacheReadInputTokenCost != nil {
pricing.InputCacheRead = bifrost.Ptr(fmt.Sprintf("%f", *pricingEntry.CacheReadInputTokenCost))
pricing.InputCacheRead = bifrost.Ptr(fmt.Sprintf("%.10f", *pricingEntry.CacheReadInputTokenCost))
}
resp.Data[i].Pricing = pricing
}
Expand Down