From 87b86650d60d54c65e70f850b7de6d9fd387b102 Mon Sep 17 00:00:00 2001 From: Pratham-Mishra04 Date: Thu, 11 Sep 2025 14:11:26 +0530 Subject: [PATCH] feat: responses api option added in openai provider --- core/bifrost.go | 345 +++- core/mcp.go | 147 +- core/providers/anthropic.go | 165 +- core/providers/azure.go | 166 +- core/providers/bedrock.go | 347 ++-- core/providers/cerebras.go | 221 +-- core/providers/cohere.go | 131 +- core/providers/gemini.go | 323 ++-- core/providers/groq.go | 139 +- core/providers/mistral.go | 213 +-- core/providers/ollama.go | 168 +- core/providers/openai.go | 482 ++++-- core/providers/openrouter.go | 223 +-- core/providers/parasail.go | 139 +- core/providers/sgl.go | 160 +- core/providers/utils.go | 194 +-- core/providers/vertex.go | 128 +- core/schemas/account.go | 7 + core/schemas/bifrost.go | 632 ++----- core/schemas/chatcompletions.go | 323 ++++ core/schemas/embedding.go | 91 + core/schemas/mux.go | 865 ++++++++++ core/schemas/provider.go | 53 +- core/schemas/providers/anthropic/chat.go | 387 ++--- core/schemas/providers/anthropic/responses.go | 1051 ++++++++++++ core/schemas/providers/anthropic/text.go | 39 +- core/schemas/providers/anthropic/types.go | 73 +- core/schemas/providers/anthropic/utils.go | 104 +- core/schemas/providers/bedrock/chat.go | 51 +- core/schemas/providers/bedrock/embedding.go | 96 ++ core/schemas/providers/bedrock/responses.go | 532 ++++++ core/schemas/providers/bedrock/text.go | 85 +- core/schemas/providers/bedrock/types.go | 48 +- core/schemas/providers/bedrock/utils.go | 212 +-- core/schemas/providers/cohere/chat.go | 178 +- core/schemas/providers/cohere/embedding.go | 42 +- core/schemas/providers/cohere/responses.go | 424 +++++ core/schemas/providers/cohere/types.go | 56 +- core/schemas/providers/gemini/chat.go | 255 +-- core/schemas/providers/gemini/embedding.go | 16 +- core/schemas/providers/gemini/responses.go | 721 ++++++++ core/schemas/providers/gemini/speech.go | 48 + 
.../schemas/providers/gemini/transcription.go | 73 + core/schemas/providers/gemini/types.go | 59 +- core/schemas/providers/gemini/utils.go | 265 ++- core/schemas/providers/mistral/embedding.go | 36 +- core/schemas/providers/openai/chat.go | 48 +- core/schemas/providers/openai/embedding.go | 72 +- core/schemas/providers/openai/error.go | 50 - core/schemas/providers/openai/responses.go | 36 + core/schemas/providers/openai/speech.go | 62 +- core/schemas/providers/openai/stream.go | 85 - core/schemas/providers/openai/text.go | 85 +- .../schemas/providers/openai/transcription.go | 74 +- core/schemas/providers/openai/types.go | 191 +-- core/schemas/providers/openai/utils.go | 130 -- core/schemas/providers/vertex/embedding.go | 34 +- core/schemas/responses.go | 1488 +++++++++++++++++ core/schemas/speech.go | 84 + core/schemas/textcompletions.go | 69 + core/schemas/transcriptions.go | 54 + core/schemas/utils.go | 655 ++++---- core/utils.go | 18 +- docs/apis/openapi.json | 40 +- docs/features/fallbacks.mdx | 6 +- docs/features/mcp.mdx | 16 +- docs/features/plugins/jsonparser.mdx | 8 +- docs/features/plugins/mocker.mdx | 8 +- docs/quickstart/go-sdk/multimodal.mdx | 24 +- .../go-sdk/provider-configuration.mdx | 10 +- docs/quickstart/go-sdk/setting-up.mdx | 8 +- docs/quickstart/go-sdk/tool-calling.mdx | 20 +- framework/configstore/sqlite.go | 19 + framework/configstore/tables.go | 17 + framework/logstore/tables.go | 32 +- framework/pricing/main.go | 26 +- framework/pricing/utils.go | 6 +- plugins/governance/main.go | 36 +- plugins/jsonparser/main.go | 22 +- plugins/jsonparser/plugin_test.go | 47 +- plugins/logging/main.go | 200 ++- plugins/logging/operations.go | 32 +- plugins/logging/streaming.go | 66 +- plugins/maxim/main.go | 110 +- plugins/maxim/plugin_test.go | 14 +- plugins/mocker/benchmark_test.go | 120 +- plugins/mocker/main.go | 127 +- plugins/mocker/plugin_test.go | 102 +- plugins/semanticcache/main.go | 21 +- .../plugin_conversation_config_test.go | 56 +- 
.../semanticcache/plugin_edge_cases_test.go | 330 ++-- .../semanticcache/plugin_integration_test.go | 248 +-- .../plugin_normalization_test.go | 62 +- .../semanticcache/plugin_responses_test.go | 415 +++++ plugins/semanticcache/search.go | 14 +- plugins/semanticcache/test_utils.go | 144 +- plugins/semanticcache/utils.go | 607 +++++-- plugins/telemetry/main.go | 75 +- plugins/telemetry/setup.go | 2 +- tests/core-chatbot/main.go | 30 +- tests/core-providers/custom_test.go | 15 +- tests/core-providers/openai_test.go | 3 +- .../scenarios/automatic_function_calling.go | 2 +- .../scenarios/chat_completion_stream.go | 4 +- .../scenarios/complete_end_to_end.go | 4 +- .../scenarios/end_to_end_tool_calling.go | 4 +- .../core-providers/scenarios/image_base64.go | 2 +- tests/core-providers/scenarios/image_url.go | 2 +- .../scenarios/multi_turn_conversation.go | 4 +- .../scenarios/multiple_images.go | 6 +- .../scenarios/multiple_tool_calls.go | 2 +- .../scenarios/provider_specific.go | 2 +- tests/core-providers/scenarios/simple_chat.go | 2 +- tests/core-providers/scenarios/tool_calls.go | 2 +- tests/core-providers/scenarios/utils.go | 22 +- .../bifrost-http/handlers/completions.go | 962 +++++++---- transports/bifrost-http/handlers/mcp.go | 2 +- .../bifrost-http/integrations/anthropic.go | 4 +- transports/bifrost-http/integrations/genai.go | 4 +- .../bifrost-http/integrations/openai.go | 110 +- transports/bifrost-http/integrations/utils.go | 337 ++-- transports/bifrost-http/lib/config.go | 32 +- transports/bifrost-http/main.go | 6 +- transports/go.mod | 19 +- transports/go.sum | 25 - ui/app/logs/page.tsx | 6 +- ui/app/logs/views/logMessageView.tsx | 4 +- .../fragments/apiKeysFormFragment.tsx | 26 + ui/lib/schemas/providerForm.ts | 5 + ui/lib/types/config.ts | 9 + ui/lib/types/logs.ts | 10 +- ui/lib/types/schemas.ts | 6 + 132 files changed, 12055 insertions(+), 6456 deletions(-) create mode 100644 core/schemas/chatcompletions.go create mode 100644 core/schemas/embedding.go create 
mode 100644 core/schemas/mux.go create mode 100644 core/schemas/providers/anthropic/responses.go create mode 100644 core/schemas/providers/bedrock/embedding.go create mode 100644 core/schemas/providers/bedrock/responses.go create mode 100644 core/schemas/providers/cohere/responses.go create mode 100644 core/schemas/providers/gemini/responses.go create mode 100644 core/schemas/providers/gemini/speech.go create mode 100644 core/schemas/providers/gemini/transcription.go delete mode 100644 core/schemas/providers/openai/error.go create mode 100644 core/schemas/providers/openai/responses.go delete mode 100644 core/schemas/providers/openai/stream.go delete mode 100644 core/schemas/providers/openai/utils.go create mode 100644 core/schemas/responses.go create mode 100644 core/schemas/speech.go create mode 100644 core/schemas/textcompletions.go create mode 100644 core/schemas/transcriptions.go create mode 100644 plugins/semanticcache/plugin_responses_test.go diff --git a/core/bifrost.go b/core/bifrost.go index c071a565c4..de061db9dd 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -25,7 +25,6 @@ type ChannelMessage struct { Response chan *schemas.BifrostResponse ResponseStream chan chan *schemas.BifrostStream Err chan schemas.BifrostError - Type schemas.RequestType } // Bifrost manages providers and maintains specified open channels for concurrent processing. 
@@ -42,6 +41,7 @@ type Bifrost struct { errorChannelPool sync.Pool // Pool for error channels, initial pool size is set in Init responseStreamPool sync.Pool // Pool for response stream channels, initial pool size is set in Init pluginPipelinePool sync.Pool // Pool for PluginPipeline objects + bifrostRequestPool sync.Pool // Pool for BifrostRequest objects logger schemas.Logger // logger instance, default logger is used if not provided mcpManager *MCPManager // MCP integration manager (nil if MCP not configured) dropExcessRequests atomic.Bool // If true, in cases where the queue is full, requests will not wait for the queue to be empty and will be dropped instead. @@ -117,6 +117,11 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { } }, } + bifrost.bifrostRequestPool = sync.Pool{ + New: func() interface{} { + return &schemas.BifrostRequest{} + }, + } // Prewarm pools with multiple objects for range config.InitialPoolSize { @@ -129,6 +134,7 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { preHookErrors: make([]error, 0), postHookErrors: make([]error, 0), }) + bifrost.bifrostRequestPool.Put(&schemas.BifrostRequest{}) } providerKeys, err := bifrost.account.GetConfiguredProviders() @@ -182,8 +188,8 @@ func Init(ctx context.Context, config schemas.BifrostConfig) (*Bifrost, error) { // PUBLIC API METHODS // TextCompletionRequest sends a text completion request to the specified provider. 
-func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if req.Input.TextCompletionInput == nil { +func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.PromptStr == nil && req.Input.PromptArray == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -192,12 +198,19 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. } } - return bifrost.handleRequest(ctx, req, schemas.TextCompletionRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.TextCompletionRequest + bifrostReq.TextCompletionRequest = req + + return bifrost.handleRequest(ctx, bifrostReq) } // ChatCompletionRequest sends a chat completion request to the specified provider. -func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if req.Input.ChatCompletionInput == nil { +func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -206,12 +219,19 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. 
} } - return bifrost.handleRequest(ctx, req, schemas.ChatCompletionRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.ChatCompletionRequest + bifrostReq.ChatRequest = req + + return bifrost.handleRequest(ctx, bifrostReq) } // ChatCompletionStreamRequest sends a chat completion stream request to the specified provider. -func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if req.Input.ChatCompletionInput == nil { +func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req.Input == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -220,12 +240,61 @@ func (bifrost *Bifrost) ChatCompletionStreamRequest(ctx context.Context, req *sc } } - return bifrost.handleStreamRequest(ctx, req, schemas.ChatCompletionStreamRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.ChatCompletionStreamRequest + bifrostReq.ChatRequest = req + + return bifrost.handleStreamRequest(ctx, bifrostReq) +} + +// ResponsesRequest sends a responses request to the specified provider. 
+func (bifrost *Bifrost) ResponsesRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "responses not provided for responses request", + }, + } + } + + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.ResponsesRequest + bifrostReq.ResponsesRequest = req + + return bifrost.handleRequest(ctx, bifrostReq) +} + +// ResponsesStreamRequest sends a responses stream request to the specified provider. +func (bifrost *Bifrost) ResponsesStreamRequest(ctx context.Context, req *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req.Input == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "responses not provided for responses stream request", + }, + } + } + + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.ResponsesStreamRequest + bifrostReq.ResponsesRequest = req + + return bifrost.handleStreamRequest(ctx, bifrostReq) } // EmbeddingRequest sends an embedding request to the specified provider. 
-func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if req.Input.EmbeddingInput == nil { +func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -234,12 +303,19 @@ func (bifrost *Bifrost) EmbeddingRequest(ctx context.Context, req *schemas.Bifro } } - return bifrost.handleRequest(ctx, req, schemas.EmbeddingRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.EmbeddingRequest + bifrostReq.EmbeddingRequest = req + + return bifrost.handleRequest(ctx, bifrostReq) } // SpeechRequest sends a speech request to the specified provider. 
-func (bifrost *Bifrost) SpeechRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if req.Input.SpeechInput == nil { +func (bifrost *Bifrost) SpeechRequest(ctx context.Context, req *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.Input == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -248,12 +324,19 @@ func (bifrost *Bifrost) SpeechRequest(ctx context.Context, req *schemas.BifrostR } } - return bifrost.handleRequest(ctx, req, schemas.SpeechRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.SpeechRequest + bifrostReq.SpeechRequest = req + + return bifrost.handleRequest(ctx, bifrostReq) } // SpeechStreamRequest sends a speech stream request to the specified provider. -func (bifrost *Bifrost) SpeechStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if req.Input.SpeechInput == nil { +func (bifrost *Bifrost) SpeechStreamRequest(ctx context.Context, req *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req.Input.Input == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -262,12 +345,19 @@ func (bifrost *Bifrost) SpeechStreamRequest(ctx context.Context, req *schemas.Bi } } - return bifrost.handleStreamRequest(ctx, req, schemas.SpeechStreamRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.SpeechStreamRequest + bifrostReq.SpeechRequest = req + + return bifrost.handleStreamRequest(ctx, bifrostReq) } // TranscriptionRequest sends a transcription request to the specified provider. 
-func (bifrost *Bifrost) TranscriptionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if req.Input.TranscriptionInput == nil { +func (bifrost *Bifrost) TranscriptionRequest(ctx context.Context, req *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if req.Input.File == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -276,12 +366,19 @@ func (bifrost *Bifrost) TranscriptionRequest(ctx context.Context, req *schemas.B } } - return bifrost.handleRequest(ctx, req, schemas.TranscriptionRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.TranscriptionRequest + bifrostReq.TranscriptionRequest = req + + return bifrost.handleRequest(ctx, bifrostReq) } // TranscriptionStreamRequest sends a transcription stream request to the specified provider. 
-func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if req.Input.TranscriptionInput == nil { +func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if req.Input.File == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -290,7 +387,14 @@ func (bifrost *Bifrost) TranscriptionStreamRequest(ctx context.Context, req *sch } } - return bifrost.handleStreamRequest(ctx, req, schemas.TranscriptionStreamRequest) + bifrostReq := bifrost.getBifrostRequest() + bifrostReq.Provider = req.Provider + bifrostReq.Model = req.Model + bifrostReq.Fallbacks = req.Fallbacks + bifrostReq.RequestType = schemas.TranscriptionStreamRequest + bifrostReq.TranscriptionRequest = req + + return bifrost.handleStreamRequest(ctx, bifrostReq) } // UpdateProviderConcurrency dynamically updates the queue size and concurrency for an existing provider. 
@@ -467,7 +571,7 @@ func (bifrost *Bifrost) getProviderMutex(providerKey schemas.ModelProvider) *syn // func(args EchoArgs) (string, error) { // return args.Message, nil // }, toolSchema) -func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.Tool) error { +func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(args any) (string, error), toolSchema schemas.ChatTool) error { if bifrost.mcpManager == nil { return fmt.Errorf("MCP is not configured in this Bifrost instance") } @@ -483,9 +587,9 @@ func (bifrost *Bifrost) RegisterMCPTool(name, description string, handler func(a // - toolCall: The tool call to execute (from assistant message) // // Returns: -// - schemas.BifrostMessage: Tool message with execution result +// - schemas.ChatMessage: Tool message with execution result // - schemas.BifrostError: Any execution error -func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.ToolCall) (*schemas.BifrostMessage, *schemas.BifrostError) { +func (bifrost *Bifrost) ExecuteMCPTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, *schemas.BifrostError) { if bifrost.mcpManager == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -789,20 +893,17 @@ func (bifrost *Bifrost) shouldTryFallbacks(req *schemas.BifrostRequest, primaryE // Handle request cancellation if primaryErr.Error.Type != nil && *primaryErr.Error.Type == schemas.RequestCancelled { - primaryErr.Provider = req.Provider return false } // Check if this is a short-circuit error that doesn't allow fallbacks // Note: AllowFallbacks = nil is treated as true (allow fallbacks by default) if primaryErr.AllowFallbacks != nil && !*primaryErr.AllowFallbacks { - primaryErr.Provider = req.Provider return false } // If no fallbacks configured, return primary error if len(req.Fallbacks) == 0 { - primaryErr.Provider = req.Provider return false } @@ -822,8 
+923,52 @@ func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fall // Create a new request with the fallback provider and model fallbackReq := *req + + if req.TextCompletionRequest != nil { + tmp := *req.TextCompletionRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.TextCompletionRequest = &tmp + } + + if req.ChatRequest != nil { + tmp := *req.ChatRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.ChatRequest = &tmp + } + + if req.ResponsesRequest != nil { + tmp := *req.ResponsesRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.ResponsesRequest = &tmp + } + + if req.EmbeddingRequest != nil { + tmp := *req.EmbeddingRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.EmbeddingRequest = &tmp + } + + if req.SpeechRequest != nil { + tmp := *req.SpeechRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.SpeechRequest = &tmp + } + + if req.TranscriptionRequest != nil { + tmp := *req.TranscriptionRequest + tmp.Provider = fallback.Provider + tmp.Model = fallback.Model + fallbackReq.TranscriptionRequest = &tmp + } + fallbackReq.Provider = fallback.Provider fallbackReq.Model = fallback.Model + return &fallbackReq } @@ -831,13 +976,11 @@ func (bifrost *Bifrost) prepareFallbackRequest(req *schemas.BifrostRequest, fall // Returns true if we should continue with more fallbacks, false if we should stop func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, fallbackErr *schemas.BifrostError) bool { if fallbackErr.Error.Type != nil && *fallbackErr.Error.Type == schemas.RequestCancelled { - fallbackErr.Provider = fallback.Provider return false } // Check if it was a short-circuit error that doesn't allow fallbacks if fallbackErr.AllowFallbacks != nil && !*fallbackErr.AllowFallbacks { - fallbackErr.Provider = fallback.Provider return false } @@ -849,9 +992,13 @@ 
func (bifrost *Bifrost) shouldContinueWithFallbacks(fallback schemas.Fallback, f // It handles plugin hooks, request validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. // It is the wrapper for all non-streaming public API methods. -func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if err := validateRequest(req); err != nil { - err.Provider = req.Provider + err.ExtraFields = schemas.BifrostErrorExtraFields{ + Provider: req.Provider, + ModelRequested: req.Model, + RequestType: req.RequestType, + } return nil, err } @@ -861,7 +1008,7 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR } // Try the primary provider first - primaryResult, primaryErr := bifrost.tryRequest(req, ctx, requestType) + primaryResult, primaryErr := bifrost.tryRequest(req, ctx) // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) @@ -877,7 +1024,7 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR } // Try the fallback provider - result, fallbackErr := bifrost.tryRequest(fallbackReq, ctx, requestType) + result, fallbackErr := bifrost.tryRequest(fallbackReq, ctx) if fallbackErr == nil { bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) return result, nil @@ -889,7 +1036,6 @@ func (bifrost *Bifrost) handleRequest(ctx context.Context, req *schemas.BifrostR } } - primaryErr.Provider = req.Provider // All providers failed, return the original error return nil, primaryErr } @@ -898,9 +1044,13 @@ func (bifrost *Bifrost) handleRequest(ctx 
context.Context, req *schemas.BifrostR // It handles plugin hooks, request validation, response processing, and fallback providers. // If the primary provider fails, it will try each fallback provider in order until one succeeds. // It is the wrapper for all streaming public API methods. -func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if err := validateRequest(req); err != nil { - err.Provider = req.Provider + err.ExtraFields = schemas.BifrostErrorExtraFields{ + Provider: req.Provider, + ModelRequested: req.Model, + RequestType: req.RequestType, + } return nil, err } @@ -910,7 +1060,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi } // Try the primary provider first - primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx, requestType) + primaryResult, primaryErr := bifrost.tryStreamRequest(req, ctx) // Check if we should proceed with fallbacks shouldTryFallbacks := bifrost.shouldTryFallbacks(req, primaryErr) @@ -926,7 +1076,7 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi } // Try the fallback provider - result, fallbackErr := bifrost.tryStreamRequest(fallbackReq, ctx, requestType) + result, fallbackErr := bifrost.tryStreamRequest(fallbackReq, ctx) if fallbackErr == nil { bifrost.logger.Info(fmt.Sprintf("Successfully used fallback provider %s with model %s", fallback.Provider, fallback.Model)) return result, nil @@ -937,27 +1087,22 @@ func (bifrost *Bifrost) handleStreamRequest(ctx context.Context, req *schemas.Bi return nil, fallbackErr } } - - primaryErr.Provider = req.Provider // All providers failed, return the original error return nil, primaryErr } // tryRequest is a generic function that handles 
common request processing logic // It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling -func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Context, requestType schemas.RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { return nil, newBifrostError(err) } - // Attach context keys to the context - ctx = attachContextKeys(ctx, req, requestType) - // Add MCP tools to request if MCP is configured and requested - if requestType != schemas.EmbeddingRequest && - requestType != schemas.SpeechRequest && - requestType != schemas.TranscriptionRequest && + if req.RequestType != schemas.EmbeddingRequest && + req.RequestType != schemas.SpeechRequest && + req.RequestType != schemas.TranscriptionRequest && bifrost.mcpManager != nil { req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) } @@ -988,7 +1133,7 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } - msg := bifrost.getChannelMessage(*preReq, requestType) + msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx select { @@ -1036,17 +1181,14 @@ func (bifrost *Bifrost) tryRequest(req *schemas.BifrostRequest, ctx context.Cont // tryStreamRequest is a generic function that handles common request processing logic // It consolidates queue setup, plugin pipeline execution, enqueue logic, and response handling -func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx context.Context, requestType schemas.RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx context.Context) (chan *schemas.BifrostStream, 
*schemas.BifrostError) { queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { return nil, newBifrostError(err) } - // Attach context keys to the context - ctx = attachContextKeys(ctx, req, requestType) - // Add MCP tools to request if MCP is configured and requested - if requestType != schemas.SpeechStreamRequest && requestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { + if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.mcpManager != nil { req = bifrost.mcpManager.addMCPToolsToBifrostRequest(ctx, req) } @@ -1106,7 +1248,7 @@ func (bifrost *Bifrost) tryStreamRequest(req *schemas.BifrostRequest, ctx contex return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } - msg := bifrost.getChannelMessage(*preReq, requestType) + msg := bifrost.getChannelMessage(*preReq) msg.Context = ctx select { @@ -1195,7 +1337,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas // Create plugin pipeline for streaming requests outside retry loop to prevent leaks var postHookRunner schemas.PostHookRunner - if IsStreamRequestType(req.Type) { + if IsStreamRequestType(req.RequestType) { pipeline := bifrost.getPluginPipeline() defer bifrost.releasePluginPipeline(pipeline) @@ -1222,13 +1364,13 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas bifrost.logger.Debug("attempting request for provider %s", provider.GetProviderKey()) // Attempt the request - if IsStreamRequestType(req.Type) { - stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner, req.Type) + if IsStreamRequestType(req.RequestType) { + stream, bifrostError = handleProviderStreamRequest(provider, &req, key, postHookRunner) if bifrostError != nil && !bifrostError.IsBifrostError { break // Don't retry client errors } } else { - result, bifrostError = handleProviderRequest(provider, &req, key, 
req.Type) + result, bifrostError = handleProviderRequest(provider, &req, key) if bifrostError != nil { break // Don't retry client errors } @@ -1250,6 +1392,12 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas if attempts > 0 { bifrost.logger.Warn("request failed after %d %s", attempts, map[bool]string{true: "retries", false: "retry"}[attempts > 1]) } + bifrostError.ExtraFields = schemas.BifrostErrorExtraFields{ + Provider: provider.GetProviderKey(), + ModelRequested: req.Model, + RequestType: req.RequestType, + } + // Send error with context awareness to prevent deadlock select { case req.Err <- *bifrostError: @@ -1262,7 +1410,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas bifrost.logger.Warn("Timeout while sending error response, client may have disconnected") } } else { - if IsStreamRequestType(req.Type) { + if IsStreamRequestType(req.RequestType) { // Send stream with context awareness to prevent deadlock select { case req.ResponseStream <- stream: @@ -1275,6 +1423,10 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas bifrost.logger.Warn("Timeout while sending stream response, client may have disconnected") } } else { + result.ExtraFields.RequestType = req.RequestType + result.ExtraFields.Provider = provider.GetProviderKey() + result.ExtraFields.ModelRequested = req.Model + // Send response with context awareness to prevent deadlock select { case req.Response <- result: @@ -1294,42 +1446,46 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } // handleProviderRequest handles the request to the provider based on the request type -func handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, reqType schemas.RequestType) (*schemas.BifrostResponse, *schemas.BifrostError) { - switch reqType { +func handleProviderRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key) 
(*schemas.BifrostResponse, *schemas.BifrostError) { + switch req.RequestType { case schemas.TextCompletionRequest: - return provider.TextCompletion(req.Context, key, &req.BifrostRequest) + return provider.TextCompletion(req.Context, key, req.BifrostRequest.TextCompletionRequest) case schemas.ChatCompletionRequest: - return provider.ChatCompletion(req.Context, key, &req.BifrostRequest) + return provider.ChatCompletion(req.Context, key, req.BifrostRequest.ChatRequest) + case schemas.ResponsesRequest: + return provider.Responses(req.Context, key, req.BifrostRequest.ResponsesRequest) case schemas.EmbeddingRequest: - return provider.Embedding(req.Context, key, &req.BifrostRequest) + return provider.Embedding(req.Context, key, req.BifrostRequest.EmbeddingRequest) case schemas.SpeechRequest: - return provider.Speech(req.Context, key, &req.BifrostRequest) + return provider.Speech(req.Context, key, req.BifrostRequest.SpeechRequest) case schemas.TranscriptionRequest: - return provider.Transcription(req.Context, key, &req.BifrostRequest) + return provider.Transcription(req.Context, key, req.BifrostRequest.TranscriptionRequest) default: return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Message: fmt.Sprintf("unsupported request type: %s", reqType), + Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, } } } // handleProviderStreamRequest handles the stream request to the provider based on the request type -func handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner, reqType schemas.RequestType) (chan *schemas.BifrostStream, *schemas.BifrostError) { - switch reqType { +func handleProviderStreamRequest(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStream, *schemas.BifrostError) { + switch req.RequestType { case schemas.ChatCompletionStreamRequest: - return 
provider.ChatCompletionStream(req.Context, postHookRunner, key, &req.BifrostRequest) + return provider.ChatCompletionStream(req.Context, postHookRunner, key, req.BifrostRequest.ChatRequest) + case schemas.ResponsesStreamRequest: + return provider.ResponsesStream(req.Context, postHookRunner, key, req.BifrostRequest.ResponsesRequest) case schemas.SpeechStreamRequest: - return provider.SpeechStream(req.Context, postHookRunner, key, &req.BifrostRequest) + return provider.SpeechStream(req.Context, postHookRunner, key, req.BifrostRequest.SpeechRequest) case schemas.TranscriptionStreamRequest: - return provider.TranscriptionStream(req.Context, postHookRunner, key, &req.BifrostRequest) + return provider.TranscriptionStream(req.Context, postHookRunner, key, req.BifrostRequest.TranscriptionRequest) default: return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Message: fmt.Sprintf("unsupported request type: %s", reqType), + Message: fmt.Sprintf("unsupported request type: %s", req.RequestType), }, } } @@ -1411,11 +1567,38 @@ func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) { bifrost.pluginPipelinePool.Put(pipeline) } +// resetBifrostRequest resets a BifrostRequest instance for reuse +func resetBifrostRequest(req *schemas.BifrostRequest) { + req.Provider = "" + req.Model = "" + req.Fallbacks = nil + req.RequestType = "" + req.TextCompletionRequest = nil + req.ChatRequest = nil + req.ResponsesRequest = nil + req.EmbeddingRequest = nil + req.SpeechRequest = nil + req.TranscriptionRequest = nil +} + +// getBifrostRequest gets a BifrostRequest from the pool +func (bifrost *Bifrost) getBifrostRequest() *schemas.BifrostRequest { + req := bifrost.bifrostRequestPool.Get().(*schemas.BifrostRequest) + resetBifrostRequest(req) + return req +} + +// releaseBifrostRequest returns a BifrostRequest to the pool +func (bifrost *Bifrost) releaseBifrostRequest(req *schemas.BifrostRequest) { + resetBifrostRequest(req) + 
bifrost.bifrostRequestPool.Put(req) +} + // POOL & RESOURCE MANAGEMENT // getChannelMessage gets a ChannelMessage from the pool and configures it with the request. // It also gets response and error channels from their respective pools. -func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest, reqType schemas.RequestType) *ChannelMessage { +func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest) *ChannelMessage { // Get channels from pool responseChan := bifrost.responseChannelPool.Get().(chan *schemas.BifrostResponse) errorChan := bifrost.errorChannelPool.Get().(chan schemas.BifrostError) @@ -1435,10 +1618,9 @@ func (bifrost *Bifrost) getChannelMessage(req schemas.BifrostRequest, reqType sc msg.BifrostRequest = req msg.Response = responseChan msg.Err = errorChan - msg.Type = reqType // Conditionally allocate ResponseStream for streaming requests only - if IsStreamRequestType(reqType) { + if IsStreamRequestType(req.RequestType) { responseStreamChan := bifrost.responseStreamPool.Get().(chan chan *schemas.BifrostStream) // Clear any previous values to avoid leaking between requests select { @@ -1467,6 +1649,9 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { bifrost.responseStreamPool.Put(msg.ResponseStream) } + // Reset and return BifrostRequest to pool + bifrost.releaseBifrostRequest(&msg.BifrostRequest) + // Clear references and return to pool msg.Response = nil msg.ResponseStream = nil diff --git a/core/mcp.go b/core/mcp.go index fc7b19d484..1f50f4ab54 100644 --- a/core/mcp.go +++ b/core/mcp.go @@ -56,12 +56,12 @@ type MCPManager struct { // MCPClient represents a connected MCP client with its configuration and tools. 
type MCPClient struct { - Name string // Unique name for this client - Conn *client.Client // Active MCP client connection - ExecutionConfig schemas.MCPClientConfig // Tool filtering settings - ToolMap map[string]schemas.Tool // Available tools mapped by name - ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management - cancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) + Name string // Unique name for this client + Conn *client.Client // Active MCP client connection + ExecutionConfig schemas.MCPClientConfig // Tool filtering settings + ToolMap map[string]schemas.ChatTool // Available tools mapped by name + ConnectionInfo MCPClientConnectionInfo `json:"connection_info"` // Connection metadata for management + cancelFunc context.CancelFunc `json:"-"` // Cancel function for SSE connections (not serialized) } // MCPClientConnectionInfo stores metadata about how a client is connected. @@ -173,7 +173,7 @@ func (m *MCPManager) AddClient(config schemas.MCPClientConfig) error { m.clientMap[config.Name] = &MCPClient{ Name: config.Name, ExecutionConfig: config, - ToolMap: make(map[string]schemas.Tool), + ToolMap: make(map[string]schemas.ChatTool), } // Temporarily unlock for the connection attempt @@ -228,7 +228,7 @@ func (m *MCPManager) removeClientUnsafe(name string) error { } // Clear client tool map - client.ToolMap = make(map[string]schemas.Tool) + client.ToolMap = make(map[string]schemas.ChatTool) delete(m.clientMap, name) return nil @@ -256,7 +256,7 @@ func (m *MCPManager) EditClientTools(name string, toolsToAdd []string, toolsToRe client.ExecutionConfig = config // Clear current tool map - client.ToolMap = make(map[string]schemas.Tool) + client.ToolMap = make(map[string]schemas.ChatTool) // Temporarily unlock for the network call m.mu.Unlock() @@ -288,7 +288,7 @@ func (m *MCPManager) EditClientTools(name string, toolsToAdd []string, toolsToRe // getAvailableTools returns all 
tools from connected MCP clients. // Applies client filtering if specified in the context. -func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool { +func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.ChatTool { m.mu.RLock() defer m.mu.RUnlock() @@ -303,7 +303,7 @@ func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool { excludeClients = existingExcludeClients } - tools := make([]schemas.Tool, 0) + tools := make([]schemas.ChatTool, 0) for clientName, client := range m.clientMap { // Apply client filtering logic if !m.shouldIncludeClient(clientName, includeClients, excludeClients) { @@ -348,7 +348,7 @@ func (m *MCPManager) getAvailableTools(ctx context.Context) []schemas.Tool { // func(args EchoArgs) (string, error) { // return args.Message, nil // }, toolSchema) -func (m *MCPManager) registerTool(name, description string, handler MCPToolHandler[any], toolSchema schemas.Tool) error { +func (m *MCPManager) registerTool(name, description string, handler MCPToolHandler[any], toolSchema schemas.ChatTool) error { // Ensure local server is set up if err := m.setupLocalHost(); err != nil { return fmt.Errorf("failed to setup local host: %w", err) @@ -453,7 +453,7 @@ func (m *MCPManager) createLocalMCPClient() (*MCPClient, error) { ExecutionConfig: schemas.MCPClientConfig{ Name: BifrostMCPClientName, }, - ToolMap: make(map[string]schemas.Tool), + ToolMap: make(map[string]schemas.ChatTool), ConnectionInfo: MCPClientConnectionInfo{ Type: schemas.MCPConnectionTypeInProcess, // Accurate: in-process (in-memory) transport }, @@ -524,9 +524,9 @@ func (m *MCPManager) startLocalMCPServer() error { // - toolCall: The tool call to execute (from assistant message) // // Returns: -// - schemas.BifrostMessage: Tool message with execution result +// - schemas.ChatMessage: Tool message with execution result // - error: Any execution error -func (m *MCPManager) executeTool(ctx context.Context, toolCall schemas.ToolCall) 
(*schemas.BifrostMessage, error) { +func (m *MCPManager) executeTool(ctx context.Context, toolCall schemas.ChatAssistantMessageToolCall) (*schemas.ChatMessage, error) { if toolCall.Function.Name == nil { return nil, fmt.Errorf("tool call missing function name") } @@ -600,7 +600,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { m.clientMap[config.Name] = &MCPClient{ Name: config.Name, ExecutionConfig: config, - ToolMap: make(map[string]schemas.Tool), + ToolMap: make(map[string]schemas.ChatTool), ConnectionInfo: MCPClientConnectionInfo{ Type: config.ConnectionType, }, @@ -679,7 +679,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { if err != nil { m.logger.Warn(fmt.Sprintf("%s Failed to retrieve tools from %s: %v", MCPLogPrefix, config.Name, err)) // Continue with connection even if tool retrieval fails - tools = make(map[string]schemas.Tool) + tools = make(map[string]schemas.ChatTool) } // Second lock: Update client with final connection details and tools @@ -711,7 +711,7 @@ func (m *MCPManager) connectToMCPClient(config schemas.MCPClientConfig) error { } // retrieveExternalTools retrieves and filters tools from an external MCP server without holding locks. 
-func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.Client, config schemas.MCPClientConfig) (map[string]schemas.Tool, error) { +func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.Client, config schemas.MCPClientConfig) (map[string]schemas.ChatTool, error) { // Get available tools from external server listRequest := mcp.ListToolsRequest{ PaginatedRequest: mcp.PaginatedRequest{ @@ -727,10 +727,10 @@ func (m *MCPManager) retrieveExternalTools(ctx context.Context, client *client.C } if toolsResponse == nil { - return make(map[string]schemas.Tool), nil // No tools available + return make(map[string]schemas.ChatTool), nil // No tools available } - tools := make(map[string]schemas.Tool) + tools := make(map[string]schemas.ChatTool) // toolsResponse is already a ListToolsResult for _, mcpTool := range toolsResponse.Tools { @@ -796,13 +796,13 @@ func (m *MCPManager) shouldSkipToolForRequest(toolName string, ctx context.Conte } // convertMCPToolToBifrostSchema converts an MCP tool definition to Bifrost format. -func (m *MCPManager) convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.Tool { - return schemas.Tool{ - Type: "function", - Function: schemas.Function{ +func (m *MCPManager) convertMCPToolToBifrostSchema(mcpTool *mcp.Tool) schemas.ChatTool { + return schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ Name: mcpTool.Name, - Description: mcpTool.Description, - Parameters: schemas.FunctionParameters{ + Description: Ptr(mcpTool.Description), + Parameters: &schemas.ToolFunctionParameters{ Type: mcpTool.InputSchema.Type, Properties: mcpTool.InputSchema.Properties, Required: mcpTool.InputSchema.Required, @@ -852,13 +852,13 @@ func (m *MCPManager) extractTextFromMCPResponse(toolResponse *mcp.CallToolResult } // createToolResponseMessage creates a tool response message with the execution result. 
-func (m *MCPManager) createToolResponseMessage(toolCall schemas.ToolCall, responseText string) *schemas.BifrostMessage { - return &schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleTool, - Content: schemas.MessageContent{ +func (m *MCPManager) createToolResponseMessage(toolCall schemas.ChatAssistantMessageToolCall, responseText string) *schemas.ChatMessage { + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: schemas.ChatMessageContent{ ContentStr: &responseText, }, - ToolMessage: &schemas.ToolMessage{ + ChatToolMessage: &schemas.ChatToolMessage{ ToolCallID: toolCall.ID, }, } @@ -867,31 +867,74 @@ func (m *MCPManager) createToolResponseMessage(toolCall schemas.ToolCall, respon func (m *MCPManager) addMCPToolsToBifrostRequest(ctx context.Context, req *schemas.BifrostRequest) *schemas.BifrostRequest { mcpTools := m.getAvailableTools(ctx) if len(mcpTools) > 0 { - // Initialize tools array if needed - if req.Params == nil { - req.Params = &schemas.ModelParameters{} - } - if req.Params.Tools == nil { - req.Params.Tools = &[]schemas.Tool{} - } - tools := *req.Params.Tools + switch req.RequestType { + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ChatRequest.Params == nil { + req.ChatRequest.Params = &schemas.ChatParameters{} + } - // Create a map of existing tool names for O(1) lookup - existingToolsMap := make(map[string]bool) - for _, tool := range tools { - existingToolsMap[tool.Function.Name] = true - } + tools := req.ChatRequest.Params.Tools - // Add MCP tools that are not already present - for _, mcpTool := range mcpTools { - if !existingToolsMap[mcpTool.Function.Name] { - tools = append(tools, mcpTool) - // Update the map to prevent duplicates within MCP tools as well - existingToolsMap[mcpTool.Function.Name] = true + // Create a map of existing tool names for O(1) lookup + existingToolsMap := 
make(map[string]bool) + for _, tool := range tools { + if tool.Function != nil && tool.Function.Name != "" { + existingToolsMap[tool.Function.Name] = true + } + } + + // Add MCP tools that are not already present + for _, mcpTool := range mcpTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + if !existingToolsMap[mcpTool.Function.Name] { + tools = append(tools, mcpTool) + // Update the map to prevent duplicates within MCP tools as well + existingToolsMap[mcpTool.Function.Name] = true + } + } + req.ChatRequest.Params.Tools = tools + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + // Only allocate new Params if it's nil to preserve caller-supplied settings + if req.ResponsesRequest.Params == nil { + req.ResponsesRequest.Params = &schemas.ResponsesParameters{} } - } - req.Params.Tools = &tools + tools := req.ResponsesRequest.Params.Tools + + // Create a map of existing tool names for O(1) lookup + existingToolsMap := make(map[string]bool) + for _, tool := range tools { + if tool.Name != nil { + existingToolsMap[*tool.Name] = true + } + } + + // Add MCP tools that are not already present + for _, mcpTool := range mcpTools { + // Skip tools with nil Function or empty Name + if mcpTool.Function == nil || mcpTool.Function.Name == "" { + continue + } + + if !existingToolsMap[mcpTool.Function.Name] { + responsesTool := mcpTool.ToResponsesTool() + // Skip if the converted tool has nil Name + if responsesTool.Name == nil { + continue + } + + tools = append(tools, *responsesTool) + // Update the map to prevent duplicates within MCP tools as well + existingToolsMap[*responsesTool.Name] = true + } + } + req.ResponsesRequest.Params.Tools = tools + } } return req } diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 3ffec408ad..6199cbaa4e 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -33,7 +33,7 @@ type AnthropicProvider 
struct { // anthropicChatResponsePool provides a pool for Anthropic chat response objects. var anthropicChatResponsePool = sync.Pool{ New: func() interface{} { - return &anthropic.AnthropicChatResponse{} + return &anthropic.AnthropicMessageResponse{} }, } @@ -175,14 +175,14 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB // TextCompletion performs a text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.OperationTextCompletion); err != nil { +func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil { return nil, err } // Convert to Anthropic format using the centralized converter - anthropicReq := anthropic.ToAnthropicTextCompletionRequest(input) - if anthropicReq == nil { + reqBody := anthropic.ToAnthropicTextCompletionRequest(request) + if reqBody == nil { return nil, newBifrostOperationError("text completion input is not provided", nil, provider.GetProviderKey()) } @@ -203,30 +203,31 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, key schem bifrostResponse := response.ToBifrostResponse() + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest + // Set raw response 
if enabled if provider.sendBackRawResponse { bifrostResponse.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - return bifrostResponse, nil } // ChatCompletion performs a chat completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { +func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } // Convert to Anthropic format using the centralized converter - anthropicReq := anthropic.ToAnthropicChatCompletionRequest(input) - if anthropicReq == nil { - return nil, newBifrostOperationError("failed to convert request", fmt.Errorf("conversion returned nil"), provider.GetProviderKey()) + reqBody := anthropic.ToAnthropicChatCompletionRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, provider.GetProviderKey()) } // Use struct directly for JSON marshaling @@ -247,34 +248,80 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, key schem // Create final response bifrostResponse := response.ToBifrostResponse() + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = 
schemas.ChatCompletionRequest + // Set raw response if enabled if provider.sendBackRawResponse { bifrostResponse.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params + return bifrostResponse, nil +} + +// Responses performs a chat completion request to Anthropic's API. +// It formats the request, sends it to Anthropic, and processes the response. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *AnthropicProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + return nil, err + } + + // Convert to Anthropic format using the centralized converter + reqBody := anthropic.ToAnthropicResponsesRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("responses input is not provided", nil, provider.GetProviderKey()) + } + + // Use struct directly for JSON marshaling + responseBody, err := provider.completeRequest(ctx, reqBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value) + if err != nil { + return nil, err + } + + // Create response object from pool + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Create final response + bifrostResponse := response.ToResponsesBifrostResponse() + + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = provider.GetProviderKey() + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest + + // Set raw response if enabled + if provider.sendBackRawResponse { + 
bifrostResponse.ExtraFields.RawResponse = rawResponse } return bifrostResponse, nil } // Embedding is not supported by the Anthropic provider. -func (provider *AnthropicProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "anthropic") } // ChatCompletionStream performs a streaming chat completion request to the Anthropic API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { +func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Anthropic, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } // Convert to Anthropic format using the centralized converter - anthropicReq := anthropic.ToAnthropicChatCompletionRequest(input) - if anthropicReq == nil { + reqBody := anthropic.ToAnthropicChatCompletionRequest(request) + if reqBody == nil { return nil, newBifrostOperationError("failed to convert request", fmt.Errorf("conversion returned nil"), provider.GetProviderKey()) } reqBody.Stream = schemas.Ptr(true) @@ -297,7 
+344,6 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, pos headers, provider.networkConfig.ExtraHeaders, provider.GetProviderKey(), - input.Params, postHookRunner, provider.logger, ) @@ -313,7 +359,6 @@ func handleAnthropicStreaming( headers map[string]string, extraHeaders map[string]string, providerType schemas.ModelProvider, - params *schemas.ModelParameters, postHookRunner schemas.PostHookRunner, logger schemas.Logger, ) (chan *schemas.BifrostStream, *schemas.BifrostError) { @@ -408,7 +453,7 @@ func handleAnthropicStreaming( } } if event.Delta != nil && event.Delta.StopReason != nil { - mappedReason := anthropic.MapAnthropicFinishReason(*event.Delta.StopReason) + mappedReason := anthropic.MapAnthropicFinishReasonToBifrost(*event.Delta.StopReason) finishReason = &mappedReason } @@ -429,7 +474,7 @@ func handleAnthropicStreaming( ID: messageID, Object: "chat.completion.chunk", Model: modelName, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -440,8 +485,10 @@ func handleAnthropicStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerType, - ChunkIndex: chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerType, + ModelRequested: modelName, + ChunkIndex: chunkIndex, }, } @@ -465,16 +512,16 @@ func handleAnthropicStreaming( ID: messageID, Object: "chat.completion.chunk", Model: modelName, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: *event.Index, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ Delta: schemas.BifrostStreamDelta{ - ToolCalls: []schemas.ToolCall{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ { Type: func() *string { s := "function"; return &s }(), ID: event.ContentBlock.ID, - Function: schemas.FunctionCall{ + Function: 
schemas.ChatAssistantMessageToolCallFunction{ Name: event.ContentBlock.Name, }, }, @@ -484,8 +531,10 @@ func handleAnthropicStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerType, - ChunkIndex: chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerType, + ModelRequested: modelName, + ChunkIndex: chunkIndex, }, } @@ -507,7 +556,7 @@ func handleAnthropicStreaming( ID: messageID, Object: "chat.completion.chunk", Model: modelName, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: *event.Index, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -519,8 +568,10 @@ func handleAnthropicStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerType, - ChunkIndex: chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerType, + ModelRequested: modelName, + ChunkIndex: chunkIndex, }, } @@ -542,7 +593,7 @@ func handleAnthropicStreaming( ID: messageID, Object: "chat.completion.chunk", Model: modelName, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: *event.Index, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -553,8 +604,10 @@ func handleAnthropicStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerType, - ChunkIndex: chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerType, + ModelRequested: modelName, + ChunkIndex: chunkIndex, }, } @@ -570,15 +623,15 @@ func handleAnthropicStreaming( ID: messageID, Object: "chat.completion.chunk", Model: modelName, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: *event.Index, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ Delta: schemas.BifrostStreamDelta{ - ToolCalls: []schemas.ToolCall{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ { 
Type: func() *string { s := "function"; return &s }(), - Function: schemas.FunctionCall{ + Function: schemas.ChatAssistantMessageToolCallFunction{ Arguments: event.Delta.PartialJSON, }, }, @@ -588,8 +641,10 @@ func handleAnthropicStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerType, - ChunkIndex: chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerType, + ModelRequested: modelName, + ChunkIndex: chunkIndex, }, } @@ -605,7 +660,7 @@ func handleAnthropicStreaming( ID: messageID, Object: "chat.completion.chunk", Model: modelName, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: *event.Index, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -616,8 +671,10 @@ func handleAnthropicStreaming( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerType, - ChunkIndex: chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerType, + ModelRequested: modelName, + ChunkIndex: chunkIndex, }, } @@ -676,9 +733,9 @@ func handleAnthropicStreaming( if err := scanner.Err(); err != nil { logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerType, err)) - processAndSendError(ctx, postHookRunner, err, responseChan, logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerType, modelName, logger) } else { - response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, params, providerType) + response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerType, modelName) handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, logger) } }() @@ -686,18 +743,22 @@ func handleAnthropicStreaming( return responseChan, nil } -func (provider *AnthropicProvider) Speech(ctx context.Context, key schemas.Key, input 
*schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "anthropic") } -func (provider *AnthropicProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "anthropic") } -func (provider *AnthropicProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "anthropic") } -func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "anthropic") } + +func (provider *AnthropicProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, 
*schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "anthropic") +} diff --git a/core/providers/azure.go b/core/providers/azure.go index a30e54e9fe..a202329487 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -6,7 +6,6 @@ import ( "context" "fmt" "net/http" - "sync" "time" "github.com/bytedance/sonic" @@ -16,49 +15,7 @@ import ( ) // AzureAuthorizationTokenKey is the context key for the Azure authentication token. -const AzureAuthorizationTokenKey ContextKey = "azure-authorization-token" - -// azureTextCompletionResponsePool provides a pool for Azure text completion response objects. -var azureTextCompletionResponsePool = sync.Pool{ - New: func() interface{} { - return &openai.OpenAITextCompletionResponse{} - }, -} - -// // azureChatResponsePool provides a pool for Azure chat response objects. -// var azureChatResponsePool = sync.Pool{ -// New: func() interface{} { -// return &schemas.BifrostResponse{} -// }, -// } - -// // acquireAzureChatResponse gets an Azure chat response from the pool and resets it. -// func acquireAzureChatResponse() *schemas.BifrostResponse { -// resp := azureChatResponsePool.Get().(*schemas.BifrostResponse) -// *resp = schemas.BifrostResponse{} // Reset the struct -// return resp -// } - -// // releaseAzureChatResponse returns an Azure chat response to the pool. -// func releaseAzureChatResponse(resp *schemas.BifrostResponse) { -// if resp != nil { -// azureChatResponsePool.Put(resp) -// } -// } - -// acquireAzureTextResponse gets an Azure text completion response from the pool and resets it. -func acquireAzureTextResponse() *openai.OpenAITextCompletionResponse { - resp := azureTextCompletionResponsePool.Get().(*openai.OpenAITextCompletionResponse) - *resp = openai.OpenAITextCompletionResponse{} // Reset the struct - return resp -} - -// releaseAzureTextResponse returns an Azure text completion response to the pool. 
-func releaseAzureTextResponse(resp *openai.OpenAITextCompletionResponse) { - if resp != nil { - azureTextCompletionResponsePool.Put(resp) - } -} +const AzureAuthorizationTokenKey schemas.BifrostContextKey = "azure-authorization-token" // AzureProvider implements the Provider interface for Azure's OpenAI API. type AzureProvider struct { @@ -86,13 +43,6 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), } - // Pre-warm response pools - for range config.ConcurrencyAndBufferSize.Concurrency { - // azureChatResponsePool.Put(&schemas.BifrostResponse{}) - azureTextCompletionResponsePool.Put(&openai.OpenAITextCompletionResponse{}) - - } - // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) @@ -138,7 +88,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = Ptr("2024-02-01") + apiVersion = schemas.Ptr("2024-02-01") } url = fmt.Sprintf("%s/openai/deployments/%s/%s?api-version=%s", url, deployment, path, *apiVersion) @@ -178,11 +128,10 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug(fmt.Sprintf("error from azure provider: %s", string(resp.Body()))) - var errorResp openai.OpenAIChatError + var errorResp map[string]interface{} bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Type = &errorResp.Error.Code - bifrostErr.Error.Message = errorResp.Error.Message + bifrostErr.Error.Message = fmt.Sprintf("%s error: %v", schemas.Azure, errorResp) return nil, bifrostErr } @@ -196,49 +145,48 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody // TextCompletion performs a text completion request to Azure's API. 
// It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Use centralized OpenAI text converter (Azure is OpenAI-compatible) - reqBody := openai.ToOpenAITextCompletionRequest(input) + reqBody := openai.ToOpenAITextCompletionRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("text completion input is not provided", nil, schemas.Azure) + } - responseBody, err := provider.completeRequest(ctx, reqBody, "completions", key, input.Model) + responseBody, err := provider.completeRequest(ctx, reqBody, "completions", key, request.Model) if err != nil { return nil, err } - // Create response object from pool - response := acquireAzureTextResponse() - defer releaseAzureTextResponse(response) + response := &schemas.BifrostResponse{} rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } - // Use centralized OpenAI response converter (Azure is OpenAI-compatible) - bifrostResponse := response.ToBifrostResponse() - - bifrostResponse.ExtraFields.Provider = schemas.Azure + response.ExtraFields.Provider = schemas.Azure + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.TextCompletionRequest // Set raw response if enabled if provider.sendBackRawResponse { - bifrostResponse.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params + response.ExtraFields.RawResponse = rawResponse } - return 
bifrostResponse, nil + return response, nil } // ChatCompletion performs a chat completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Use centralized OpenAI converter since Azure is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) + reqBody := openai.ToOpenAIChatRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, schemas.Azure) + } - responseBody, err := provider.completeRequest(ctx, reqBody, "chat/completions", key, input.Model) + responseBody, err := provider.completeRequest(ctx, reqBody, "chat/completions", key, request.Model) if err != nil { return nil, err } @@ -255,31 +203,46 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, key schemas.K } response.ExtraFields.Provider = schemas.Azure + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.ChatCompletionRequest // Set raw response if enabled if provider.sendBackRawResponse { response.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - response.ExtraFields.Params = *input.Params + return response, nil +} + +// Responses performs a responses request to Azure's API. +// It formats the request, sends it to Azure, and processes the response. +// Returns a BifrostResponse containing the completion results or an error if the request fails. 
+func (provider *AzureProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { + return nil, err } + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model + return response, nil } // Embedding generates embeddings for the given input text(s) using Azure OpenAI. // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Use centralized converter - reqBody := openai.ToOpenAIEmbeddingRequest(input) + reqBody := openai.ToOpenAIEmbeddingRequest(request) if reqBody == nil { return nil, newBifrostOperationError("embedding input is not provided", nil, schemas.Azure) } - responseBody, err := provider.completeRequest(ctx, reqBody, "embeddings", key, input.Model) + responseBody, err := provider.completeRequest(ctx, reqBody, "embeddings", key, request.Model) if err != nil { return nil, err } @@ -293,10 +256,8 @@ func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, i } response.ExtraFields.Provider = schemas.Azure - - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.EmbeddingRequest if provider.sendBackRawResponse { response.ExtraFields.RawResponse = 
rawResponse @@ -309,10 +270,7 @@ func (provider *AzureProvider) Embedding(ctx context.Context, key schemas.Key, i // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - +func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if key.AzureKeyConfig == nil { return nil, newConfigurationError("azure key config not set", schemas.Azure) } @@ -326,14 +284,14 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo var fullURL string if key.AzureKeyConfig.Deployments != nil { - deployment := key.AzureKeyConfig.Deployments[input.Model] + deployment := key.AzureKeyConfig.Deployments[request.Model] if deployment == "" { - return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", input.Model), schemas.Azure) + return nil, newConfigurationError(fmt.Sprintf("deployment not found for model %s", request.Model), schemas.Azure) } apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = Ptr("2024-02-01") + apiVersion = schemas.Ptr("2024-02-01") } fullURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", baseURL, deployment, *apiVersion) @@ -342,16 +300,13 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo } // Prepare Azure-specific 
headers - headers := make(map[string]string) - headers["Content-Type"] = "application/json" - headers["Accept"] = "text/event-stream" - headers["Cache-Control"] = "no-cache" + authHeader := make(map[string]string) // Set Azure authentication - either Bearer token or api-key if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { - headers["Authorization"] = fmt.Sprintf("Bearer %s", authToken) + authHeader["Authorization"] = fmt.Sprintf("Bearer %s", authToken) } else { - headers["api-key"] = key.Value + authHeader["api-key"] = key.Value } // Use shared streaming logic from OpenAI @@ -359,28 +314,31 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo ctx, provider.streamClient, fullURL, - reqBody, - headers, + request, + authHeader, provider.networkConfig.ExtraHeaders, - schemas.Azure, // Provider type - input.Params, + schemas.Azure, postHookRunner, provider.logger, ) } -func (provider *AzureProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "azure") } -func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "azure") } -func (provider *AzureProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider 
*AzureProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "azure") } -func (provider *AzureProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "azure") } + +func (provider *AzureProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "azure") +} diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 00586038ec..66a078c906 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "maps" "net/http" "net/url" "strings" @@ -102,7 +101,7 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Type: Ptr(schemas.RequestCancelled), + Type: schemas.Ptr(schemas.RequestCancelled), Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), Error: err, }, @@ -195,8 +194,8 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod // TextCompletion performs a text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. 
// Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationTextCompletion); err != nil { +func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil { return nil, err } @@ -206,9 +205,12 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas return nil, newConfigurationError("bedrock key config is not provided", providerName) } - requestBody := bedrock.ToBedrockTextCompletionRequest(input) + requestBody := bedrock.ToBedrockTextCompletionRequest(request) + if requestBody == nil { + return nil, newBifrostOperationError("text completion input is not provided", nil, providerName) + } - path := provider.getModelPath("invoke", input.Model, key) + path := provider.getModelPath("invoke", request.Model, key) body, err := provider.completeRequest(ctx, requestBody, path, key) if err != nil { return nil, err @@ -217,14 +219,14 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas // Handle model-specific response conversion var bifrostResponse *schemas.BifrostResponse switch { - case strings.Contains(input.Model, "anthropic."): + case strings.Contains(request.Model, "anthropic.") || strings.Contains(request.Model, "claude"): var response bedrock.BedrockAnthropicTextResponse if err := sonic.Unmarshal(body, &response); err != nil { return nil, newBifrostOperationError("error parsing anthropic response", err, providerName) } bifrostResponse = response.ToBifrostResponse() - 
case strings.Contains(input.Model, "mistral."): + case strings.Contains(request.Model, "mistral."): var response bedrock.BedrockMistralTextResponse if err := sonic.Unmarshal(body, &response); err != nil { return nil, newBifrostOperationError("error parsing mistral response", err, providerName) @@ -232,9 +234,14 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas bifrostResponse = response.ToBifrostResponse() default: - return nil, newConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", input.Model), providerName) + return nil, newConfigurationError(fmt.Sprintf("unsupported model type for text completion: %s", request.Model), providerName) } + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.TextCompletionRequest + // Parse raw response if enabled if provider.sendBackRawResponse { var rawResponse interface{} @@ -244,18 +251,14 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, key schemas bifrostResponse.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - return bifrostResponse, nil } // ChatCompletion performs a chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { +func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } @@ -266,17 +269,13 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas } // pool the request - bedrockReq, err := bedrock.ToBedrockChatCompletionRequest(input) + reqBody, err := bedrock.ToBedrockChatCompletionRequest(request) if err != nil { return nil, newBifrostOperationError("failed to convert request", err, providerName) } - if bedrockReq == nil { - return nil, newBifrostOperationError("failed to convert request", fmt.Errorf("conversion returned nil"), providerName) - } - // Format the path with proper model identifier - path := provider.getModelPath("converse", input.Model, key) + path := provider.getModelPath("converse", request.Model, key) // Create the signed request responseBody, bifrostErr := provider.completeRequest(ctx, reqBody, path, key) @@ -299,6 +298,11 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, key schemas return nil, newBifrostOperationError("failed to convert bedrock response", err, providerName) } + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest + // Set raw response if enabled if provider.sendBackRawResponse { var rawResponse interface{} @@ -307,8 +311,64 @@ func (provider *BedrockProvider) ChatCompletion(ctx 
context.Context, key schemas } } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params + return bifrostResponse, nil +} + +// Responses performs a chat completion request to Anthropic's API. +// It formats the request, sends it to Anthropic, and processes the response. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *BedrockProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + if key.BedrockKeyConfig == nil { + return nil, newConfigurationError("bedrock key config is not provided", providerName) + } + + // pool the request + reqBody, err := bedrock.ToBedrockResponsesRequest(request) + if err != nil { + return nil, newBifrostOperationError("failed to convert request", err, providerName) + } + + // Format the path with proper model identifier + path := provider.getModelPath("converse", request.Model, key) + + // Create the signed request + responseBody, bifrostErr := provider.completeRequest(ctx, reqBody, path, key) + if bifrostErr != nil { + return nil, bifrostErr + } + + // pool the response + bedrockResponse := acquireBedrockChatResponse() + defer releaseBedrockChatResponse(bedrockResponse) + + // Parse the response using the new Bedrock type + if err := sonic.Unmarshal(responseBody, bedrockResponse); err != nil { + return nil, newBifrostOperationError("failed to parse bedrock response", err, providerName) + } + + // Convert using the new response converter + bifrostResponse, err := bedrockResponse.ToResponsesBifrostResponse() + if err != nil { + return nil, newBifrostOperationError("failed to convert bedrock response", err, providerName) + } + + // Set ExtraFields + 
bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest + + // Set raw response if enabled + if provider.sendBackRawResponse { + var rawResponse interface{} + if err := sonic.Unmarshal(responseBody, &rawResponse); err == nil { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } } return bifrostResponse, nil @@ -390,152 +450,81 @@ func signAWSRequest(ctx context.Context, req *http.Request, accessKey, secretKey // Embedding generates embeddings for the given input text(s) using Amazon Bedrock. // Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *BedrockProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { +func (provider *BedrockProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() - embeddingInput := input.Input.EmbeddingInput - if key.BedrockKeyConfig == nil { return nil, newConfigurationError("bedrock key config is not provided", providerName) } - switch { - case strings.Contains(input.Model, "amazon.titan-embed-text"): - return provider.handleTitanEmbedding(ctx, input.Model, key, embeddingInput, input.Params, providerName) - case strings.Contains(input.Model, "cohere.embed"): - return provider.handleCohereEmbedding(ctx, input.Model, key, embeddingInput, input.Params, providerName) - default: - return nil, newConfigurationError("embedding is not supported for 
this Bedrock model", providerName) - } -} - -// handleTitanEmbedding handles embedding requests for Amazon Titan models. -func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Titan Text Embeddings V1/V2 - only supports single text input - if input.Text == nil && len(input.Texts) == 0 { - return nil, newConfigurationError("no input text provided for embedding", providerName) - } - - // Validate that only single text input is provided for Titan models - if input.Text == nil && len(input.Texts) > 1 { - return nil, newConfigurationError("Amazon Titan embedding models only support single text input, but multiple texts were provided", providerName) - } - - requestBody := map[string]interface{}{} - - if input.Text != nil { - requestBody["inputText"] = *input.Text - } else if len(input.Texts) == 1 { - requestBody["inputText"] = input.Texts[0] - } - - if params != nil { - // Titan models do not support the dimensions parameter - they have fixed dimensions - if params.Dimensions != nil { - return nil, newConfigurationError("Amazon Titan embedding models do not support custom dimensions parameter", providerName) - } - if params.ExtraParams != nil { - for k, v := range params.ExtraParams { - requestBody[k] = v - } - } - } - - // Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly - path := provider.getModelPath("invoke", model, key) - rawResponse, err := provider.completeRequest(ctx, requestBody, path, key) + // Determine model type + modelType, err := bedrock.DetermineEmbeddingModelType(request.Model) if err != nil { - return nil, err - } - - // Parse Titan response from raw message - var titanResp struct { - Embedding []float32 `json:"embedding"` - InputTextTokenCount int `json:"inputTextTokenCount"` - } - if err := 
sonic.Unmarshal(rawResponse, &titanResp); err != nil { - return nil, newBifrostOperationError("error parsing Titan embedding response", err, providerName) - } - - bifrostResponse := &schemas.BifrostResponse{ - Object: "list", - Data: []schemas.BifrostEmbedding{ - { - Index: 0, - Object: "embedding", - Embedding: schemas.BifrostEmbeddingResponse{ - Embedding2DArray: &[][]float32{titanResp.Embedding}, - }, - }, - }, - Model: model, - Usage: &schemas.LLMUsage{ - PromptTokens: titanResp.InputTextTokenCount, - TotalTokens: titanResp.InputTextTokenCount, - }, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - }, - } - - if provider.sendBackRawResponse { - bifrostResponse.ExtraFields.RawResponse = rawResponse + return nil, newConfigurationError(err.Error(), providerName) } - if params != nil { - bifrostResponse.ExtraFields.Params = *params - } - - return bifrostResponse, nil -} + // Convert request and execute based on model type + var rawResponse []byte + var bifrostError *schemas.BifrostError -// handleCohereEmbedding handles embedding requests for Cohere models on Bedrock. 
-func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) { - if input.Text == nil && len(input.Texts) == 0 { - return nil, newConfigurationError("no input text provided for embedding", providerName) - } + switch modelType { + case "titan": + titanReq, err := bedrock.ToBedrockTitanEmbeddingRequest(request) + if err != nil { + return nil, newBifrostOperationError("failed to convert Titan request", err, providerName) + } + path := provider.getModelPath("invoke", request.Model, key) + rawResponse, bifrostError = provider.completeRequest(ctx, titanReq, path, key) - requestBody := map[string]interface{}{ - "input_type": "search_document", - } + case "cohere": + cohereReq, err := bedrock.ToBedrockCohereEmbeddingRequest(request) + if err != nil { + return nil, newBifrostOperationError("failed to convert Cohere request", err, providerName) + } + path := provider.getModelPath("invoke", request.Model, key) + rawResponse, bifrostError = provider.completeRequest(ctx, cohereReq, path, key) - if input.Text != nil { - requestBody["texts"] = []string{*input.Text} - } else { - requestBody["texts"] = input.Texts + default: + return nil, newConfigurationError("unsupported embedding model type", providerName) } - if params != nil && params.ExtraParams != nil { - maps.Copy(requestBody, params.ExtraParams) + if bifrostError != nil { + return nil, bifrostError } - // Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly - path := provider.getModelPath("invoke", model, key) - rawResponse, err := provider.completeRequest(ctx, requestBody, path, key) - if err != nil { - return nil, err - } + // Parse response based on model type + var bifrostResponse *schemas.BifrostResponse + switch modelType { + case "titan": + var titanResp bedrock.BedrockTitanEmbeddingResponse 
+ if err := sonic.Unmarshal(rawResponse, &titanResp); err != nil { + return nil, newBifrostOperationError("error parsing Titan embedding response", err, providerName) + } + bifrostResponse = titanResp.ToBifrostResponse(request.Model) - var cohereResp cohere.CohereEmbeddingResponse - if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { - return nil, newBifrostOperationError("error parsing embedding response", err, providerName) + case "cohere": + var cohereResp cohere.CohereEmbeddingResponse + if err := sonic.Unmarshal(rawResponse, &cohereResp); err != nil { + return nil, newBifrostOperationError("error parsing Cohere embedding response", err, providerName) + } + bifrostResponse = cohereResp.ToBifrostResponse() + bifrostResponse.Model = request.Model } - // Create BifrostResponse - bifrostResponse := cohereResp.ToBifrostResponse() - bifrostResponse.Model = model + // Set ExtraFields bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest - // Only include RawResponse if sendBackRawResponse is enabled + // Set raw response if enabled if provider.sendBackRawResponse { - bifrostResponse.ExtraFields.RawResponse = rawResponse - } - - if params != nil { - bifrostResponse.ExtraFields.Params = *params + var rawResponseData interface{} + if err := sonic.Unmarshal(rawResponse, &rawResponseData); err == nil { + bifrostResponse.ExtraFields.RawResponse = rawResponseData + } } return bifrostResponse, nil @@ -544,8 +533,8 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode // ChatCompletionStream performs a streaming chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the streaming response. // Returns a channel for streaming BifrostResponse objects or an error if the request fails. 
-func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { +func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } @@ -555,13 +544,13 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH return nil, newConfigurationError("bedrock key config is not provided", providerName) } - bedrockReq, err := bedrock.ToBedrockChatCompletionRequest(input) + reqBody, err := bedrock.ToBedrockChatCompletionRequest(request) if err != nil { return nil, newBifrostOperationError("failed to convert request", err, providerName) } // Format the path with proper model identifier for streaming - path := provider.getModelPath("converse-stream", input.Model, key) + path := provider.getModelPath("converse-stream", request.Model, key) region := "us-east-1" if key.BedrockKeyConfig.Region != nil { @@ -631,12 +620,12 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH if err == io.EOF { // Process any remaining data in the accumulator if len(accumulator) > 0 { - _ = provider.processAWSEventStreamData(ctx, postHookRunner, accumulator, &messageID, &chunkIndex, &usage, &finishReason, input.Model, providerName, responseChan) + _ = provider.processAWSEventStreamData(ctx, postHookRunner, accumulator, &messageID, &chunkIndex, &usage, &finishReason, request.Model, providerName, responseChan) } break } provider.logger.Warn(fmt.Sprintf("Error 
reading %s stream: %v", providerName, err)) - processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) return } @@ -648,14 +637,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH accumulator = append(accumulator, buffer[:n]...) // Process the accumulated data and get the remaining unprocessed part - remaining := provider.processAWSEventStreamData(ctx, postHookRunner, accumulator, &messageID, &chunkIndex, &usage, &finishReason, input.Model, providerName, responseChan) + remaining := provider.processAWSEventStreamData(ctx, postHookRunner, accumulator, &messageID, &chunkIndex, &usage, &finishReason, request.Model, providerName, responseChan) // Reset accumulator with remaining data accumulator = remaining } // Send final response - response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, input.Params, providerName) + response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model) handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger) }() @@ -763,7 +752,7 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo ID: *messageID, Object: "chat.completion.chunk", Model: model, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -774,8 +763,10 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: *chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: *chunkIndex, 
}, } @@ -792,7 +783,7 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo toolUseStart := streamEvent.Start.ToolUse // Create tool call structure for start event - var toolCall schemas.ToolCall + var toolCall schemas.ChatAssistantMessageToolCall toolCall.Type = schemas.Ptr("function") toolCall.Function.Name = schemas.Ptr(toolUseStart.Name) toolCall.Function.Arguments = "{}" // Start with empty arguments @@ -801,19 +792,21 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo ID: *messageID, Object: "chat.completion.chunk", Model: model, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: contentBlockIndex, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ Delta: schemas.BifrostStreamDelta{ - ToolCalls: []schemas.ToolCall{toolCall}, + ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall}, }, }, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: *chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: *chunkIndex, }, } @@ -833,7 +826,7 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo ID: *messageID, Object: "chat.completion.chunk", Model: model, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: contentBlockIndex, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -844,8 +837,10 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: *chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: *chunkIndex, }, } @@ -863,7 +858,7 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo } // Create tool 
call structure - var toolCall schemas.ToolCall + var toolCall schemas.ChatAssistantMessageToolCall toolCall.Type = schemas.Ptr("function") // For streaming, we need to accumulate tool use data @@ -874,19 +869,21 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo ID: *messageID, Object: "chat.completion.chunk", Model: model, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: contentBlockIndex, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ Delta: schemas.BifrostStreamDelta{ - ToolCalls: []schemas.ToolCall{toolCall}, + ToolCalls: []schemas.ChatAssistantMessageToolCall{toolCall}, }, }, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: *chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: model, + ChunkIndex: *chunkIndex, }, } @@ -912,19 +909,19 @@ func (provider *BedrockProvider) processEventBuffer(ctx context.Context, postHoo } } -func (provider *BedrockProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "bedrock") } -func (provider *BedrockProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "bedrock") } -func (provider *BedrockProvider) Transcription(ctx 
context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "bedrock") } -func (provider *BedrockProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "bedrock") } @@ -943,3 +940,7 @@ func (provider *BedrockProvider) getModelPath(basePath string, model string, key return path } + +func (provider *BedrockProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "bedrock") +} diff --git a/core/providers/cerebras.go b/core/providers/cerebras.go index 2444d75792..2057648791 100644 --- a/core/providers/cerebras.go +++ b/core/providers/cerebras.go @@ -4,38 +4,14 @@ package providers import ( "context" - "fmt" "net/http" "strings" - "sync" "time" - "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// cerebrasTextResponsePool provides a pool for Cerebras text completion response objects. 
-var cerebrasTextResponsePool = sync.Pool{ - New: func() interface{} { - return &openai.OpenAITextCompletionResponse{} - }, -} - -// acquireCerebrasTextResponse gets a Cerebras text completion response from the pool and resets it. -func acquireCerebrasTextResponse() *openai.OpenAITextCompletionResponse { - resp := cerebrasTextResponsePool.Get().(*openai.OpenAITextCompletionResponse) - *resp = openai.OpenAITextCompletionResponse{} // Reset the struct - return resp -} - -// releaseCerebrasTextResponse returns a Cerebras text completion response to the pool. -func releaseCerebrasTextResponse(resp *openai.OpenAITextCompletionResponse) { - if resp != nil { - cerebrasTextResponsePool.Put(resp) - } -} - // CerebrasProvider implements the Provider interface for Cerebras's API. type CerebrasProvider struct { logger schemas.Logger // Logger for provider operations @@ -62,12 +38,6 @@ func NewCerebrasProvider(config *schemas.ProviderConfig, logger schemas.Logger) Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), } - // Pre-warm response pools - for range config.ConcurrencyAndBufferSize.Concurrency { - // cerebrasChatResponsePool.Put(&schemas.BifrostResponse{}) - cerebrasTextResponsePool.Put(&openai.OpenAITextCompletionResponse{}) - } - // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) @@ -94,144 +64,51 @@ func (provider *CerebrasProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion performs a text completion request to Cerebras's API. // It formats the request, sends it to Cerebras, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *CerebrasProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI text converter (Cerebras is OpenAI-compatible) - reqBody := openai.ToOpenAITextCompletionRequest(input) - - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - jsonBody, err := sonic.Marshal(reqBody) - if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cerebras) - } - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from cerebras provider: %s", string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("Cerebras error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - response := acquireCerebrasTextResponse() - defer releaseCerebrasTextResponse(response) - - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Use centralized OpenAI response converter (Cerebras is OpenAI-compatible) - bifrostResponse := response.ToBifrostResponse() - - 
bifrostResponse.ExtraFields.Provider = schemas.Cerebras - - // Set raw response if enabled - if provider.sendBackRawResponse { - bifrostResponse.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - - return bifrostResponse, nil +func (provider *CerebrasProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAITextCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) } // ChatCompletion performs a chat completion request to the Cerebras API. -func (provider *CerebrasProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI converter since Cerebras is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) +func (provider *CerebrasProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} - jsonBody, err := sonic.Marshal(reqBody) +func (provider *CerebrasProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Cerebras) + return 
nil, err } - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from cerebras provider: %s", string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("Cerebras error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - response := &schemas.BifrostResponse{} - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Create final response - response.ExtraFields.Provider = schemas.Cerebras - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model return response, nil } // Embedding is not supported by the Cerebras provider. 
-func (provider *CerebrasProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "cerebras") } @@ -239,47 +116,37 @@ func (provider *CerebrasProvider) Embedding(ctx context.Context, key schemas.Key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Cerebras's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare Cerebras headers - headers := map[string]string{ - "Content-Type": "application/json", - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - - headers["Authorization"] = "Bearer " + key.Value - +func (provider *CerebrasProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.Cerebras, - input.Params, postHookRunner, provider.logger, ) } -func (provider *CerebrasProvider) Speech(ctx context.Context, key schemas.Key, input 
*schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "cerebras") } -func (provider *CerebrasProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CerebrasProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "cerebras") } -func (provider *CerebrasProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CerebrasProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "cerebras") } -func (provider *CerebrasProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CerebrasProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "cerebras") } + +func (provider *CerebrasProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { 
+ return nil, newUnsupportedOperationError("responses stream", "cerebras") +} diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 121541cba1..23c373f4f6 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -96,28 +96,54 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. -func (provider *CohereProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "cohere") } // ChatCompletion performs a chat completion request to the Cohere API using v2 converter. // It formats the request, sends it to Cohere, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Check if chat completion is allowed - if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Convert to Cohere v2 request - reqBody := cohere.ToCohereChatCompletionRequest(input) + reqBody := cohere.ToCohereChatCompletionRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, providerName) + } + + cohereResponse, rawResponse, err := provider.handleCohereChatCompletionRequest(ctx, reqBody, key) + if err != nil { + return nil, err + } + + // Convert Cohere v2 response to Bifrost response + bifrostResponse := cohereResponse.ToBifrostResponse() + + bifrostResponse.Model = request.Model + bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ChatCompletionRequest + + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse + } + + return bifrostResponse, nil +} + +func (provider *CohereProvider) handleCohereChatCompletionRequest(ctx context.Context, reqBody *cohere.CohereChatRequest, key schemas.Key) (*cohere.CohereChatResponse, interface{}, *schemas.BifrostError) { + providerName := provider.GetProviderKey() // Marshal request body jsonBody, err := sonic.Marshal(reqBody) if err != nil { - return nil, &schemas.BifrostError{ + return nil, nil, &schemas.BifrostError{ 
IsBifrostError: true, Error: schemas.ErrorField{ Message: schemas.ErrProviderJSONMarshaling, @@ -145,7 +171,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas. // Make request bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) if bifrostErr != nil { - return nil, bifrostErr + return nil, nil, bifrostErr } // Handle error response @@ -156,13 +182,13 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas. bifrostErr := handleProviderAPIError(resp, &errorResp) bifrostErr.Error.Message = errorResp.Message - return nil, bifrostErr + return nil, nil, bifrostErr } // Parse Cohere v2 response var cohereResponse cohere.CohereChatResponse if err := sonic.Unmarshal(resp.Body(), &cohereResponse); err != nil { - return nil, &schemas.BifrostError{ + return nil, nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ Message: "error parsing Cohere v2 response", @@ -175,7 +201,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas. var rawResponse interface{} if provider.sendBackRawResponse { if err := sonic.Unmarshal(resp.Body(), &rawResponse); err != nil { - return nil, &schemas.BifrostError{ + return nil, nil, &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ Message: "error parsing raw response", @@ -185,35 +211,58 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, key schemas. 
} } + return &cohereResponse, rawResponse, nil +} + +func (provider *CohereProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if chat completion is allowed + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + return nil, err + } + + providerName := provider.GetProviderKey() + + // Convert to Cohere v2 request + reqBody := cohere.ToCohereResponsesRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("responses input is not provided", nil, providerName) + } + + cohereResponse, rawResponse, err := provider.handleCohereChatCompletionRequest(ctx, reqBody, key) + if err != nil { + return nil, err + } + // Convert Cohere v2 response to Bifrost response - bifrostResponse := cohereResponse.ToBifrostResponse() + bifrostResponse := cohereResponse.ToResponsesBifrostResponse() - bifrostResponse.Model = input.Model + bifrostResponse.Model = request.Model bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.ResponsesRequest if provider.sendBackRawResponse { bifrostResponse.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - return bifrostResponse, nil } // Embedding generates embeddings for the given input text(s) using the Cohere API. // Supports Cohere's embedding models and returns a BifrostResponse containing the embedding(s). 
-func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Check if embedding is allowed - if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Create Bifrost request for conversion - reqBody := cohere.ToCohereEmbeddingRequest(input) + reqBody := cohere.ToCohereEmbeddingRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("embedding input is not provided", nil, providerName) + } // Marshal request body jsonBody, err := sonic.Marshal(reqBody) @@ -268,33 +317,31 @@ func (provider *CohereProvider) Embedding(ctx context.Context, key schemas.Key, // Create BifrostResponse bifrostResponse := cohereResp.ToBifrostResponse() - bifrostResponse.Model = input.Model + bifrostResponse.Model = request.Model bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest // Only include RawResponse if sendBackRawResponse is enabled if provider.sendBackRawResponse { bifrostResponse.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - return bifrostResponse, nil } // ChatCompletionStream performs a streaming chat completion request to the Cohere API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed - if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + if err := checkOperationAllowed(schemas.Cohere, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Convert to Cohere v2 request and add streaming - reqBody := cohere.ToCohereChatCompletionRequest(input) + reqBody := cohere.ToCohereChatCompletionRequest(request) if reqBody == nil { return nil, newBifrostOperationError("chat completion input is not provided", nil, providerName) } @@ -388,8 +435,8 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo response := &schemas.BifrostResponse{ ID: responseID, Object: "chat.completion.chunk", - Model: input.Model, - Choices: []schemas.BifrostResponseChoice{ + Model: request.Model, + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -398,8 +445,10 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: chunkIndex, + RequestType: schemas.ChatCompletionStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex, }, } @@ -431,7 +480,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo if event.Delta != nil && event.Delta.Message 
!= nil && event.Delta.Message.ToolCalls != nil && event.Delta.Message.ToolCalls.ToolCall != nil { // Handle single tool call object (tool-call-start/delta events) cohereToolCall := event.Delta.Message.ToolCalls.ToolCall - toolCall := schemas.ToolCall{} + toolCall := schemas.ChatAssistantMessageToolCall{} if cohereToolCall.ID != nil { toolCall.ID = cohereToolCall.ID @@ -444,7 +493,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo toolCall.Function.Arguments = cohereToolCall.Function.Arguments } - response.Choices[0].BifrostStreamResponseChoice.Delta.ToolCalls = []schemas.ToolCall{toolCall} + response.Choices[0].BifrostStreamResponseChoice.Delta.ToolCalls = []schemas.ChatAssistantMessageToolCall{toolCall} } case cohere.StreamEventMessageEnd: @@ -481,10 +530,6 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo } } - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) } @@ -507,25 +552,29 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) } }() return responseChan, nil } -func (provider *CohereProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "cohere") } -func (provider *CohereProvider) SpeechStream(ctx context.Context, postHookRunner 
schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "cohere") } -func (provider *CohereProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "cohere") } -func (provider *CohereProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "cohere") } + +func (provider *CohereProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "cohere") +} diff --git a/core/providers/gemini.go b/core/providers/gemini.go index 9adffe3c0a..374d76ef21 100644 --- a/core/providers/gemini.go +++ b/core/providers/gemini.go @@ -14,69 +14,11 @@ import ( "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/core/schemas/providers/openai" 
"github.com/maximhq/bifrost/core/schemas/providers/gemini" + "github.com/maximhq/bifrost/core/schemas/providers/openai" "github.com/valyala/fasthttp" ) -// Response message for PredictionService.GenerateContent. -type GenerateContentResponse struct { - // Response variations returned by the model. - Candidates []*Candidate `json:"candidates,omitempty"` - // Usage metadata about the response(s). - UsageMetadata *GenerateContentResponseUsageMetadata `json:"usageMetadata,omitempty"` -} - -// A response candidate generated from the model. -type Candidate struct { - // Optional. Contains the multi-part content of the response. - Content *Content `json:"content,omitempty"` - // Optional. The reason why the model stopped generating tokens. - // If empty, the model has not stopped generating the tokens. - FinishReason string `json:"finishReason,omitempty"` - // Output only. Index of the candidate. - Index int32 `json:"index,omitempty"` -} - -// Contains the multi-part content of a message. -type Content struct { - // Optional. List of parts that constitute a single message. Each part may have - // a different IANA MIME type. - Parts []*Part `json:"parts,omitempty"` - // Optional. The producer of the content. Must be either 'user' or - // 'model'. Useful to set for multi-turn conversations, otherwise can be - // empty. If role is not specified, SDK will determine the role. - Role string `json:"role,omitempty"` -} - -// A datatype containing media content. -// Exactly one field within a Part should be set, representing the specific type -// of content being conveyed. Using multiple fields within the same `Part` -// instance is considered invalid. -type Part struct { - // Optional. Inlined bytes data. - InlineData *Blob `json:"inlineData,omitempty"` - // Optional. Text part (can be code). - Text string `json:"text,omitempty"` -} - -// Content blob. -type Blob struct { - // Required. Raw bytes. - Data []byte `json:"data,omitempty"` -} - -// Usage metadata about response(s). 
-type GenerateContentResponseUsageMetadata struct { - // Number of tokens in the response(s). This includes all the generated response candidates. - CandidatesTokenCount int32 `json:"candidatesTokenCount,omitempty"` - // Number of tokens in the prompt. When cached_content is set, this is still the total - // effective prompt size meaning this includes the number of tokens in the cached content. - PromptTokenCount int32 `json:"promptTokenCount,omitempty"` - // Total token count for prompt, response candidates, and tool-use prompts (if present). - TotalTokenCount int32 `json:"totalTokenCount,omitempty"` -} - type GeminiProvider struct { logger schemas.Logger // Logger for provider operations client *fasthttp.Client // HTTP client for API requests @@ -128,21 +70,24 @@ func (provider *GeminiProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the Gemini provider. -func (provider *GeminiProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", string(provider.GetProviderKey())) } // ChatCompletion performs a chat completion request to the Gemini API. 
-func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Check if chat completion is allowed for this provider - if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Use centralized OpenAI converter since Gemini uses OpenAI-compatible endpoints - reqBody := openai.ToOpenAIChatCompletionRequest(input) + reqBody := openai.ToOpenAIChatRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, providerName) + } jsonBody, err := sonic.Marshal(reqBody) if err != nil { @@ -193,13 +138,12 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas. 
} for _, choice := range response.Choices { - if choice.Message.AssistantMessage == nil || choice.Message.AssistantMessage.ToolCalls == nil { - continue - } - for i, toolCall := range *choice.Message.AssistantMessage.ToolCalls { - if (toolCall.ID == nil || *toolCall.ID == "") && toolCall.Function.Name != nil && *toolCall.Function.Name != "" { - id := *toolCall.Function.Name - (*choice.Message.AssistantMessage.ToolCalls)[i].ID = &id + if choice.BifrostNonStreamResponseChoice.Message.ChatAssistantMessage != nil && choice.BifrostNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls != nil { + for i, toolCall := range *choice.BifrostNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls { + if (toolCall.ID == nil || *toolCall.ID == "") && toolCall.Function.Name != nil && *toolCall.Function.Name != "" { + id := *toolCall.Function.Name + (*choice.BifrostNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls)[i].ID = &id + } } } } @@ -210,10 +154,6 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas. response.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } - return response, nil } @@ -221,87 +161,79 @@ func (provider *GeminiProvider) ChatCompletion(ctx context.Context, key schemas. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Gemini's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider - if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - - // Use centralized OpenAI converter since Gemini uses OpenAI-compatible endpoints - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare Gemini headers - headers := map[string]string{ - "Content-Type": "application/json", - "Authorization": "Bearer " + key.Value, - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/openai/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, - providerName, - input.Params, + provider.GetProviderKey(), postHookRunner, provider.logger, ) } -// Embedding performs an embedding request to the Gemini API. 
-func (provider *GeminiProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Check if embedding is allowed for this provider - if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { +// Responses performs a chat completion request to Gemini's API. +// It formats the request, sends it to Gemini, and processes the response. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *GeminiProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { return nil, err } - providerName := provider.GetProviderKey() - embeddingInput := input.Input.EmbeddingInput + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model - if embeddingInput.Text == nil && len(embeddingInput.Texts) == 0 { - return nil, newBifrostOperationError("invalid embedding input: at least one text is required", nil, providerName) - } + return response, nil +} - requestBody := openai.ToOpenAIEmbeddingRequest(input) +// Embedding performs an embedding request to the Gemini API.
+func (provider *GeminiProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if embedding is allowed for this provider + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { + return nil, err + } // Use the shared embedding request handler return handleOpenAIEmbeddingRequest( ctx, provider.client, provider.networkConfig.BaseURL+"/openai/embeddings", - requestBody, + request, key, - input.Params, provider.networkConfig.ExtraHeaders, - providerName, + provider.GetProviderKey(), provider.sendBackRawResponse, provider.logger, ) } -func (provider *GeminiProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Check if speech is allowed for this provider - if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationSpeech); err != nil { + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.SpeechRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Validate input - if input == nil || input.Input.SpeechInput == nil || input.Input.SpeechInput.Input == "" { + if request == nil || request.Input.Input == "" { return nil, newBifrostOperationError("invalid speech input: no text provided", fmt.Errorf("empty text input"), providerName) } - // Prepare request body using shared function - requestBody := gemini.ToGeminiGenerationRequest(input, []string{"AUDIO"}) + // Prepare request body using speech-specific function + requestBody := gemini.ToGeminiSpeechRequest(request, []string{"AUDIO"}) if requestBody == nil { return nil, newBifrostOperationError("speech 
input is not provided", nil, providerName) } @@ -312,50 +244,46 @@ func (provider *GeminiProvider) Speech(ctx context.Context, key schemas.Key, inp } // Use common request function - bifrostResponse, geminiResponse, bifrostErr := provider.completeRequest(ctx, input.Model, key, jsonBody, ":generateContent", input.Params) + geminiResponse, rawResponse, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonBody, ":generateContent") if bifrostErr != nil { return nil, bifrostErr } bifrostResponse := geminiResponse.ToBifrostResponse() - if provider.sendBackRawResponse { - var rawResponse interface{} - if err := sonic.Unmarshal(jsonBody, &rawResponse); err == nil { - bifrostResponse.ExtraFields.RawResponse = rawResponse - } - } + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.SpeechRequest - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse } return bifrostResponse, nil } -func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if speech stream is allowed for this provider - if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationSpeechStream); err != nil { + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() - // Validate input - if input == nil || 
input.Input.SpeechInput == nil || input.Input.SpeechInput.Input == "" { - return nil, newBifrostOperationError("invalid speech input: no text provided", fmt.Errorf("empty text input"), providerName) + // Prepare request body using speech-specific function + requestBody := gemini.ToGeminiSpeechRequest(request, []string{"AUDIO"}) + if requestBody == nil { + return nil, newBifrostOperationError("speech input is not provided", nil, providerName) } - // Prepare request body using shared function - requestBody := gemini.ToGeminiGenerationRequest(input, []string{"AUDIO"}) - jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+input.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+request.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) } @@ -424,7 +352,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner if strings.Contains(err.Error(), "gemini api error") { // Handle API error bifrostErr := &schemas.BifrostError{ - Type: Ptr("gemini_api_error"), + Type: schemas.Ptr("gemini_api_error"), IsBifrostError: false, Error: schemas.ErrorField{ Message: err.Error(), @@ -472,7 +400,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner // Create Bifrost speech response for streaming response := &schemas.BifrostResponse{ Object: "audio.speech.chunk", - Model: input.Model, + Model: request.Model, Speech: &schemas.BifrostSpeech{ Audio: audioChunk, BifrostSpeechStreamResponse: &schemas.BifrostSpeechStreamResponse{ @@ -480,8 +408,10 @@ func (provider *GeminiProvider) 
SpeechStream(ctx context.Context, postHookRunner }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: chunkIndex, + RequestType: schemas.SpeechStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex, }, } @@ -493,7 +423,7 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner // Handle scanner errors if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) } else { response := &schemas.BifrostResponse{ Object: "audio.speech.chunk", @@ -501,14 +431,13 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner Usage: usage, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: chunkIndex + 1, + RequestType: schemas.SpeechStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex + 1, }, } - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger) } }() @@ -516,28 +445,22 @@ func (provider *GeminiProvider) SpeechStream(ctx context.Context, postHookRunner return responseChan, nil } -func (provider *GeminiProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Check if transcription is allowed for this provider - if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationTranscription); 
err != nil { + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() - transcriptionInput := input.Input.TranscriptionInput - // Check file size limit (Gemini has a 20MB limit for inline data) const maxFileSize = 20 * 1024 * 1024 // 20MB - if len(transcriptionInput.File) > maxFileSize { - return nil, newBifrostOperationError("audio file too large for inline transcription", fmt.Errorf("file size %d bytes exceeds 20MB limit", len(transcriptionInput.File)), providerName) - } - - if transcriptionInput.Prompt == nil { - input.Input.TranscriptionInput.Prompt = Ptr("Generate a transcript of the speech.") + if len(request.Input.File) > maxFileSize { + return nil, newBifrostOperationError("audio file too large for inline transcription", fmt.Errorf("file size %d bytes exceeds 20MB limit", len(request.Input.File)), providerName) } - // Prepare request body using shared function - requestBody := gemini.ToGeminiGenerationRequest(input, nil) + // Prepare request body using transcription-specific function + requestBody := gemini.ToGeminiTranscriptionRequest(request) if requestBody == nil { return nil, newBifrostOperationError("transcription input is not provided", nil, providerName) } @@ -548,58 +471,54 @@ func (provider *GeminiProvider) Transcription(ctx context.Context, key schemas.K } // Use common request function - bifrostResponse, geminiResponse, bifrostErr := provider.completeRequest(ctx, input.Model, key, jsonBody, ":generateContent", input.Params) + geminiResponse, rawResponse, bifrostErr := provider.completeRequest(ctx, request.Model, key, jsonBody, ":generateContent") if bifrostErr != nil { return nil, bifrostErr } bifrostResponse := geminiResponse.ToBifrostResponse() - if provider.sendBackRawResponse { - var rawResponse interface{} - if err := sonic.Unmarshal(jsonBody, &rawResponse); err == nil { - bifrostResponse.ExtraFields.RawResponse = 
rawResponse - } - } + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = providerName + bifrostResponse.ExtraFields.ModelRequested = request.Model + bifrostResponse.ExtraFields.RequestType = schemas.TranscriptionRequest - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params + if provider.sendBackRawResponse { + bifrostResponse.ExtraFields.RawResponse = rawResponse } return bifrostResponse, nil } -func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if transcription stream is allowed for this provider - if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.OperationTranscriptionStream); err != nil { + if err := checkOperationAllowed(schemas.Gemini, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() - transcriptionInput := input.Input.TranscriptionInput // Check file size limit (Gemini has a 20MB limit for inline data) - if transcriptionInput.File != nil { + if request.Input.File != nil { const maxFileSize = 20 * 1024 * 1024 // 20MB - if len(transcriptionInput.File) > maxFileSize { - return nil, newBifrostOperationError("audio file too large for inline transcription", fmt.Errorf("file size %d bytes exceeds 20MB limit", len(transcriptionInput.File)), providerName) + if len(request.Input.File) > maxFileSize { + return nil, newBifrostOperationError("audio file too large for inline transcription", fmt.Errorf("file size %d bytes exceeds 20MB limit", len(request.Input.File)), providerName) } } - if transcriptionInput.Prompt 
== nil { - transcriptionInput.Prompt = Ptr("Generate a transcript of the speech.") + // Prepare request body using transcription-specific function + requestBody := gemini.ToGeminiTranscriptionRequest(request) + if requestBody == nil { + return nil, newBifrostOperationError("transcription input is not provided", nil, providerName) } - // Prepare request body using shared function - requestBody := gemini.ToGeminiGenerationRequest(input, nil) - jsonBody, err := sonic.Marshal(requestBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) } // Create HTTP request for streaming - req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+input.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, "POST", provider.networkConfig.BaseURL+"/models/"+request.Model+":streamGenerateContent?alt=sse", bytes.NewReader(jsonBody)) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) } @@ -670,7 +589,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo // Handle error responses if _, hasError := errorCheck["error"]; hasError { bifrostErr := &schemas.BifrostError{ - Type: Ptr("gemini_api_error"), + Type: schemas.Ptr("gemini_api_error"), IsBifrostError: false, Error: schemas.ErrorField{ Message: fmt.Sprintf("Gemini API error: %v", errorCheck["error"]), @@ -710,9 +629,9 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo if len(geminiResponse.Candidates) > 0 && (geminiResponse.Candidates[0].FinishReason != "" || geminiResponse.UsageMetadata != nil) { // Extract usage metadata from Gemini response inputTokens, outputTokens, totalTokens := extractGeminiUsageMetadata(&geminiResponse) - usage.InputTokens = Ptr(inputTokens) - usage.OutputTokens = Ptr(outputTokens) - usage.TotalTokens = Ptr(totalTokens) + usage.InputTokens = 
schemas.Ptr(inputTokens) + usage.OutputTokens = schemas.Ptr(outputTokens) + usage.TotalTokens = schemas.Ptr(totalTokens) } // Only send response if we have actual text content @@ -724,14 +643,16 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo Object: "audio.transcription.chunk", Transcribe: &schemas.BifrostTranscribe{ BifrostTranscribeStreamResponse: &schemas.BifrostTranscribeStreamResponse{ - Type: Ptr("transcript.text.delta"), + Type: schemas.Ptr("transcript.text.delta"), Delta: &deltaText, // Delta text for this chunk }, }, - Model: input.Model, + Model: request.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: chunkIndex, + RequestType: schemas.TranscriptionStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex, }, } @@ -743,7 +664,7 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo // Handle scanner errors if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) } else { response := &schemas.BifrostResponse{ Object: "audio.transcription.chunk", @@ -757,14 +678,13 @@ func (provider *GeminiProvider) TranscriptionStream(ctx context.Context, postHoo }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: chunkIndex + 1, + RequestType: schemas.TranscriptionStreamRequest, + Provider: providerName, + ModelRequested: request.Model, + ChunkIndex: chunkIndex + 1, }, } - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger) } }() @@ -807,7 +727,7 @@ func extractGeminiUsageMetadata(geminiResponse 
*gemini.GenerateContentResponse) } // completeRequest handles the common HTTP request pattern for Gemini API calls -func (provider *GeminiProvider) completeRequest(ctx context.Context, model string, key schemas.Key, jsonBody []byte, endpoint string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *gemini.GenerateContentResponse, *schemas.BifrostError) { +func (provider *GeminiProvider) completeRequest(ctx context.Context, model string, key schemas.Key, jsonBody []byte, endpoint string) (*gemini.GenerateContentResponse, interface{}, *schemas.BifrostError) { providerName := provider.GetProviderKey() // Create request @@ -846,27 +766,12 @@ func (provider *GeminiProvider) completeRequest(ctx context.Context, model strin return nil, nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } - // Create base response - bifrostResponse := &schemas.BifrostResponse{ - Model: model, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - }, - } - - if params != nil { - bifrostResponse.ExtraFields.Params = *params - } - - // Set raw response if enabled - if provider.sendBackRawResponse { - var rawResponse interface{} - if err := sonic.Unmarshal(responseBody, &rawResponse); err == nil { - bifrostResponse.ExtraFields.RawResponse = rawResponse - } + var rawResponse interface{} + if err := sonic.Unmarshal(responseBody, &rawResponse); err != nil { + return nil, nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, providerName) } - return bifrostResponse, &geminiResponse, nil + return &geminiResponse, rawResponse, nil } // parseStreamGeminiError parses Gemini streaming error responses @@ -903,3 +808,7 @@ func parseGeminiError(providerName schemas.ModelProvider, resp *fasthttp.Respons return newBifrostOperationError(fmt.Sprintf("Gemini error: %v", errorResp), fmt.Errorf("HTTP %d", resp.StatusCode()), providerName) } + +func (provider *GeminiProvider) ResponsesStream(ctx context.Context, 
postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "gemini") +} diff --git a/core/providers/groq.go b/core/providers/groq.go index 6fcf7cc572..a6639172a4 100644 --- a/core/providers/groq.go +++ b/core/providers/groq.go @@ -4,37 +4,14 @@ package providers import ( "context" - "fmt" "net/http" "strings" "time" - "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// // groqResponsePool provides a pool for Groq response objects. -// var groqResponsePool = sync.Pool{ -// New: func() interface{} { -// return &schemas.BifrostResponse{} -// }, -// } - -// // acquireGroqResponse gets a Groq response from the pool and resets it. -// func acquireGroqResponse() *schemas.BifrostResponse { -// resp := groqResponsePool.Get().(*schemas.BifrostResponse) -// *resp = schemas.BifrostResponse{} // Reset the struct -// return resp -// } - -// // releaseGroqResponse returns a Groq response to the pool. -// func releaseGroqResponse(resp *schemas.BifrostResponse) { -// if resp != nil { -// groqResponsePool.Put(resp) -// } -// } - // GroqProvider implements the Provider interface for Groq's API. type GroqProvider struct { logger schemas.Logger // Logger for provider operations @@ -90,81 +67,41 @@ func (provider *GroqProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the Groq provider. 
-func (provider *GroqProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GroqProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "groq") } // ChatCompletion performs a chat completion request to the Groq API. -func (provider *GroqProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI converter since Groq is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) +func (provider *GroqProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} - jsonBody, err := sonic.Marshal(reqBody) +func (provider *GroqProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Groq) + return nil, err } - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + 
"/v1/chat/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from groq provider: %s", string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("Groq error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - // response := acquireGroqResponse() - // defer releaseGroqResponse(response) - response := &schemas.BifrostResponse{} - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Create final response - response.ExtraFields.Provider = schemas.Groq - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model return response, nil } // Embedding is not supported by the Groq provider. 
-func (provider *GroqProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "groq") } @@ -172,47 +109,37 @@ func (provider *GroqProvider) Embedding(ctx context.Context, key schemas.Key, in // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Groq's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Use centralized OpenAI converter since Groq is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare Groq headers - headers := map[string]string{ - "Content-Type": "application/json", - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - - headers["Authorization"] = "Bearer " + key.Value - +func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.Groq, - input.Params, postHookRunner, provider.logger, ) } -func (provider *GroqProvider) Speech(ctx context.Context, key 
schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "groq") } -func (provider *GroqProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GroqProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "groq") } -func (provider *GroqProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "groq") } -func (provider *GroqProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GroqProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "groq") } + +func (provider *GroqProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, 
newUnsupportedOperationError("responses stream", "groq") +} diff --git a/core/providers/mistral.go b/core/providers/mistral.go index fc93cef7d3..7be69fbd7a 100644 --- a/core/providers/mistral.go +++ b/core/providers/mistral.go @@ -4,37 +4,14 @@ package providers import ( "context" - "fmt" "net/http" "strings" "time" - "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// // mistralResponsePool provides a pool for Mistral response objects. -// var mistralResponsePool = sync.Pool{ -// New: func() interface{} { -// return &schemas.BifrostResponse{} -// }, -// } - -// // acquireMistralResponse gets a Mistral response from the pool and resets it. -// func acquireMistralResponse() *schemas.BifrostResponse { -// resp := mistralResponsePool.Get().(*schemas.BifrostResponse) -// *resp = schemas.BifrostResponse{} // Reset the struct -// return resp -// } - -// // releaseMistralResponse returns a Mistral response to the pool. -// func releaseMistralResponse(resp *schemas.BifrostResponse) { -// if resp != nil { -// mistralResponsePool.Put(resp) -// } -// } - // MistralProvider implements the Provider interface for Mistral's API. type MistralProvider struct { logger schemas.Logger // Logger for provider operations @@ -90,191 +67,91 @@ func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the Mistral provider. -func (provider *MistralProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *MistralProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "mistral") } // ChatCompletion performs a chat completion request to the Mistral API. 
-func (provider *MistralProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI converter since Mistral is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) +func (provider *MistralProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} - jsonBody, err := sonic.Marshal(reqBody) +func (provider *MistralProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) - } - - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr + return nil, err } - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from mistral provider: %s", 
string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("Mistral error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - // response := acquireMistralResponse() - // defer releaseMistralResponse(response) - response := &schemas.BifrostResponse{} - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - response.ExtraFields.Provider = schemas.Mistral - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model return response, nil } // Embedding generates embeddings for the given input text(s) using the Mistral API. // Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s). 
-func (provider *MistralProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - - mistralRequest := mistral.ToMistralEmbeddingRequest(input) - - jsonBody, err := sonic.Marshal(mistralRequest) - if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Mistral) - } - - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/embeddings") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from mistral embedding provider: %s", string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("Mistral embedding error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - // response := acquireMistralResponse() - response := &schemas.BifrostResponse{} - // defer releaseMistralResponse(response) - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - response.ExtraFields.Provider = schemas.Mistral - - if input.Params != nil { - response.ExtraFields.Params = 
*input.Params - } - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - return response, nil +func (provider *MistralProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Use the shared embedding request handler + return handleOpenAIEmbeddingRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/embeddings", + request, + key, + provider.networkConfig.ExtraHeaders, + schemas.Mistral, + provider.sendBackRawResponse, + provider.logger, + ) } // ChatCompletionStream performs a streaming chat completion request to the Mistral API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Mistral's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Use centralized OpenAI converter since Mistral is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare Mistral headers - headers := map[string]string{ - "Content-Type": "application/json", - "Authorization": "Bearer " + key.Value, - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - +func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": 
"Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.Mistral, - input.Params, postHookRunner, provider.logger, ) } -func (provider *MistralProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *MistralProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "mistral") } -func (provider *MistralProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *MistralProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "mistral") } -func (provider *MistralProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *MistralProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "mistral") } -func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *MistralProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "mistral") } + +func (provider 
*MistralProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "mistral") +} diff --git a/core/providers/ollama.go b/core/providers/ollama.go index f974b37ea4..400f409fca 100644 --- a/core/providers/ollama.go +++ b/core/providers/ollama.go @@ -9,32 +9,10 @@ import ( "strings" "time" - "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// // ollamaResponsePool provides a pool for Ollama response objects. -// var ollamaResponsePool = sync.Pool{ -// New: func() interface{} { -// return &schemas.BifrostResponse{} -// }, -// } - -// // acquireOllamaResponse gets a Ollama response from the pool and resets it. -// func acquireOllamaResponse() *schemas.BifrostResponse { -// resp := ollamaResponsePool.Get().(*schemas.BifrostResponse) -// *resp = schemas.BifrostResponse{} // Reset the struct -// return resp -// } - -// // releaseOllamaResponse returns a Ollama response to the pool. -// func releaseOllamaResponse(resp *schemas.BifrostResponse) { -// if resp != nil { -// ollamaResponsePool.Put(resp) -// } -// } - // OllamaProvider implements the Provider interface for Ollama's API. type OllamaProvider struct { logger schemas.Logger // Logger for provider operations @@ -90,134 +68,98 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { return schemas.Ollama } -// TextCompletion is not supported by the Ollama provider. 
-func (provider *OllamaProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, newUnsupportedOperationError("text completion", "ollama") +func (provider *OllamaProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAITextCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) } // ChatCompletion performs a chat completion request to the Ollama API. -func (provider *OllamaProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI converter since Ollama is OpenAI-compatible - openaiReq := openai.ToOpenAIChatCompletionRequest(input) +func (provider *OllamaProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} - jsonBody, err := sonic.Marshal(openaiReq) +func (provider *OllamaProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Ollama) - } - - // Create request - req := fasthttp.AcquireRequest() - resp := 
fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - if key.Value != "" { - req.Header.Set("Authorization", "Bearer "+key.Value) + return nil, err } - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from ollama provider: %s", string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("Ollama error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - // response := acquireOllamaResponse() - // defer releaseOllamaResponse(response) - response := &schemas.BifrostResponse{} - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - response.ExtraFields.Provider = schemas.Ollama - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model return response, nil } -// Embedding is not supported by the Ollama provider. 
-func (provider *OllamaProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, newUnsupportedOperationError("embedding", "ollama") +func (provider *OllamaProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIEmbeddingRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/embeddings", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) } // ChatCompletionStream performs a streaming chat completion request to the Ollama API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Ollama's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Use centralized OpenAI converter since Ollama is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare Ollama headers (Ollama typically doesn't require authorization, but we include it if provided) - headers := map[string]string{ - "Content-Type": "application/json", - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - - // Only add Authorization header if key is provided (Ollama can run without auth) - if key.Value != "" { - headers["Authorization"] = "Bearer " + key.Value - } - +func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, 
*schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.Ollama, - input.Params, postHookRunner, provider.logger, ) } -func (provider *OllamaProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "ollama") } -func (provider *OllamaProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OllamaProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "ollama") } -func (provider *OllamaProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "ollama") } -func (provider *OllamaProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OllamaProvider) TranscriptionStream(ctx context.Context, 
postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "ollama") } + +func (provider *OllamaProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "ollama") +} diff --git a/core/providers/openai.go b/core/providers/openai.go index f877856df0..945a11f822 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "io" + "maps" "mime/multipart" "net/http" "strings" @@ -19,27 +20,6 @@ import ( "github.com/valyala/fasthttp" ) -// // openAIResponsePool provides a pool for OpenAI response objects. -// var openAIResponsePool = sync.Pool{ -// New: func() interface{} { -// return &schemas.BifrostResponse{} -// }, -// } - -// // acquireOpenAIResponse gets an OpenAI response from the pool and resets it. -// func acquireOpenAIResponse() *schemas.BifrostResponse { -// resp := openAIResponsePool.Get().(*schemas.BifrostResponse) -// *resp = schemas.BifrostResponse{} // Reset the struct -// return resp -// } - -// // releaseOpenAIResponse returns an OpenAI response to the pool. -// func releaseOpenAIResponse(resp *schemas.BifrostResponse) { -// if resp != nil { -// openAIResponsePool.Put(resp) -// } -// } - // OpenAIProvider implements the Provider interface for OpenAI's GPT API. type OpenAIProvider struct { logger schemas.Logger // Logger for provider operations @@ -98,23 +78,134 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the OpenAI provider. // Returns an error indicating that text completion is not available. 
-func (provider *OpenAIProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, newUnsupportedOperationError("text completion", "openai") +func (provider *OpenAIProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TextCompletionRequest); err != nil { + return nil, err + } + return handleOpenAITextCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} + +func handleOpenAITextCompletionRequest( + ctx context.Context, + client *fasthttp.Client, + url string, + request *schemas.BifrostTextCompletionRequest, + key schemas.Key, + extraHeaders map[string]string, + providerName schemas.ModelProvider, + sendBackRawResponse bool, + logger schemas.Logger, +) (*schemas.BifrostResponse, *schemas.BifrostError) { + reqBody := openai.ToOpenAITextCompletionRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("text completion input is not provided", nil, providerName) + } + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + jsonBody, err := sonic.Marshal(reqBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Set any extra headers from network config + setExtraHeaders(req, extraHeaders, nil) + + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make 
request + bifrostErr := makeRequestWithContext(ctx, client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + + var errorResp map[string]interface{} + bifrostErr := handleProviderAPIError(resp, &errorResp) + bifrostErr.Error.Message = fmt.Sprintf("%s error: %v", providerName, errorResp) + return nil, bifrostErr + } + + responseBody := resp.Body() + + response := &schemas.BifrostResponse{} + + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + response.ExtraFields.Provider = providerName + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.TextCompletionRequest + + // Set raw response if enabled + if sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse + } + + return response, nil } // ChatCompletion performs a chat completion request to the OpenAI API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Check if chat completion is allowed for this provider - if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationChatCompletion); err != nil { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} +func handleOpenAIChatCompletionRequest( + ctx context.Context, + client *fasthttp.Client, + url string, + request *schemas.BifrostChatRequest, + key schemas.Key, + extraHeaders map[string]string, + providerName schemas.ModelProvider, + sendBackRawResponse bool, + logger schemas.Logger, +) (*schemas.BifrostResponse, *schemas.BifrostError) { // Use centralized converter - reqBody := openai.ToOpenAIChatCompletionRequest(input) + reqBody := openai.ToOpenAIChatRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, providerName) + } jsonBody, err := sonic.Marshal(reqBody) if err != nil { @@ -128,9 +219,9 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas. 
defer fasthttp.ReleaseResponse(resp) // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) + setExtraHeaders(req, extraHeaders, nil) - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") + req.SetRequestURI(url) req.Header.SetMethod("POST") req.Header.SetContentType("application/json") req.Header.Set("Authorization", "Bearer "+key.Value) @@ -138,14 +229,14 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas. req.SetBody(jsonBody) // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) + bifrostErr := makeRequestWithContext(ctx, client, req, resp) if bifrostErr != nil { return nil, bifrostErr } // Handle error response if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) return nil, parseOpenAIError(resp) } @@ -157,21 +248,113 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas. 
response := &schemas.BifrostResponse{} // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, sendBackRawResponse) if bifrostErr != nil { return nil, bifrostErr } // Set raw response if enabled - if provider.sendBackRawResponse { + if sendBackRawResponse { response.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - response.ExtraFields.Params = *input.Params + response.ExtraFields.Provider = providerName + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.ChatCompletionRequest + + return response, nil +} + +func (provider *OpenAIProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Check if responses is allowed for this provider + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesRequest); err != nil { + return nil, err + } + + return handleOpenAIResponsesRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/responses", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} + +func handleOpenAIResponsesRequest( + ctx context.Context, + client *fasthttp.Client, + url string, + request *schemas.BifrostResponsesRequest, + key schemas.Key, + extraHeaders map[string]string, + providerName schemas.ModelProvider, + sendBackRawResponse bool, + logger schemas.Logger, +) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Use centralized converter + reqBody := openai.ToOpenAIResponsesRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("responses input is not provided", nil, providerName) + } + + jsonBody, err := sonic.Marshal(reqBody) + 
if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) + } + + // Create request + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + // Set any extra headers from network config + setExtraHeaders(req, extraHeaders, nil) + + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("Authorization", "Bearer "+key.Value) + + req.SetBody(jsonBody) + + // Make request + bifrostErr := makeRequestWithContext(ctx, client, req, resp) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + logger.Debug(fmt.Sprintf("error from %s provider: %s", providerName, string(resp.Body()))) + return nil, parseOpenAIError(resp) + } + + responseBody := resp.Body() + + // Pre-allocate response structs from pools + // response := acquireOpenAIResponse() + // defer releaseOpenAIResponse(response) + response := &schemas.BifrostResponse{} + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(responseBody, response, sendBackRawResponse) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Set raw response if enabled + if sendBackRawResponse { + response.ExtraFields.RawResponse = rawResponse } response.ExtraFields.Provider = providerName + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.ResponsesRequest return response, nil } @@ -179,34 +362,46 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, key schemas. // Embedding generates embeddings for the given input text(s). // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
-func (provider *OpenAIProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { // Check if embedding is allowed for this provider - if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationEmbedding); err != nil { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.EmbeddingRequest); err != nil { return nil, err } - providerName := provider.GetProviderKey() - - // Use centralized converter - reqBody := openai.ToOpenAIEmbeddingRequest(input) - // Use the shared embedding request handler return handleOpenAIEmbeddingRequest( ctx, provider.client, provider.networkConfig.BaseURL+"/v1/embeddings", - reqBody, + request, key, - input.Params, provider.networkConfig.ExtraHeaders, - providerName, + provider.GetProviderKey(), provider.sendBackRawResponse, provider.logger, ) } -func handleOpenAIEmbeddingRequest(ctx context.Context, client *fasthttp.Client, url string, requestBody interface{}, key schemas.Key, params *schemas.ModelParameters, extraHeaders map[string]string, providerName schemas.ModelProvider, sendBackRawResponse bool, logger schemas.Logger) (*schemas.BifrostResponse, *schemas.BifrostError) { - jsonBody, err := sonic.Marshal(requestBody) +// handleOpenAIEmbeddingRequest handles embedding requests for OpenAI-compatible APIs. +// This shared function reduces code duplication between providers that use the same embedding request format. 
+func handleOpenAIEmbeddingRequest( + ctx context.Context, + client *fasthttp.Client, + url string, + request *schemas.BifrostEmbeddingRequest, + key schemas.Key, + extraHeaders map[string]string, + providerName schemas.ModelProvider, + sendBackRawResponse bool, + logger schemas.Logger, +) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Use centralized converter + reqBody := openai.ToOpenAIEmbeddingRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("embedding input is not provided", nil, providerName) + } + + jsonBody, err := sonic.Marshal(reqBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) } @@ -251,10 +446,8 @@ func handleOpenAIEmbeddingRequest(ctx context.Context, client *fasthttp.Client, } response.ExtraFields.Provider = providerName - - if params != nil { - response.ExtraFields.Params = *params - } + response.ExtraFields.ModelRequested = request.Model + response.ExtraFields.RequestType = schemas.EmbeddingRequest if sendBackRawResponse { response.ExtraFields.RawResponse = rawResponse @@ -266,60 +459,59 @@ func handleOpenAIEmbeddingRequest(ctx context.Context, client *fasthttp.Client, // ChatCompletionStream handles streaming for OpenAI chat completions. // It formats messages, prepares request body, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. 
-func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider - if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationChatCompletionStream); err != nil { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } - // Use centralized converter - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - reqBody.StreamOptions = &map[string]interface{}{ - "include_usage": true, - } - - // Prepare OpenAI headers - headers := map[string]string{ - "Content-Type": "application/json", - "Authorization": "Bearer " + key.Value, - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - - providerName := provider.GetProviderKey() - // Use shared streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, - providerName, - input.Params, + provider.GetProviderKey(), postHookRunner, provider.logger, ) } -// performOpenAICompatibleStreaming handles streaming for OpenAI-compatible APIs (OpenAI, Azure). +// handleOpenAIStreaming handles streaming for OpenAI-compatible APIs. // This shared function reduces code duplication between providers that use the same SSE format. 
func handleOpenAIStreaming( ctx context.Context, - httpClient *http.Client, + client *http.Client, url string, - requestBody T, - headers map[string]string, + request *schemas.BifrostChatRequest, + authHeader map[string]string, extraHeaders map[string]string, providerName schemas.ModelProvider, - params *schemas.ModelParameters, postHookRunner schemas.PostHookRunner, logger schemas.Logger, ) (chan *schemas.BifrostStream, *schemas.BifrostError) { + reqBody := openai.ToOpenAIChatRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, providerName) + } + reqBody.Stream = schemas.Ptr(true) + reqBody.StreamOptions = &schemas.ChatStreamOptions{ + IncludeUsage: schemas.Ptr(true), + } + + // Prepare base SSE headers; the caller-supplied authorization header (if any) is merged in below + headers := map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + + // Copy auth header to headers + maps.Copy(headers, authHeader) - jsonBody, err := sonic.Marshal(requestBody) + jsonBody, err := sonic.Marshal(reqBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) } @@ -339,7 +531,7 @@ func handleOpenAIStreaming( } // Make the request - resp, err := httpClient.Do(req) + resp, err := client.Do(req) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderRequest, err, providerName) } @@ -362,7 +554,7 @@ func handleOpenAIStreaming( usage := &schemas.LLMUsage{} var finishReason *string - var id string + var messageID string for scanner.Scan() { line := scanner.Text() @@ -453,15 +645,17 @@ func handleOpenAIStreaming( } response.Choices[0].FinishReason = nil } - if response.ID != "" && id == "" { - id = response.ID + if response.ID != "" && messageID == "" { + messageID = response.ID } // Handle regular content chunks if choice.BifrostStreamResponseChoice != nil && 
(choice.BifrostStreamResponseChoice.Delta.Content != nil || len(choice.BifrostStreamResponseChoice.Delta.ToolCalls) > 0) { chunkIndex++ + response.ExtraFields.RequestType = schemas.ChatCompletionStreamRequest response.ExtraFields.Provider = providerName + response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = chunkIndex processAndSendResponse(ctx, postHookRunner, &response, responseChan, logger) @@ -471,9 +665,9 @@ func handleOpenAIStreaming( // Handle scanner errors first if err := scanner.Err(); err != nil { logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - processAndSendError(ctx, postHookRunner, err, responseChan, logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, logger) } else { - response := createBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, params, providerName) + response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, schemas.ChatCompletionStreamRequest, providerName, request.Model) handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, logger) } }() @@ -484,17 +678,20 @@ func handleOpenAIStreaming( // Speech handles non-streaming speech synthesis requests. // It formats the request body, makes the API call, and returns the response. // Returns the response and any error that occurred. 
-func (provider *OpenAIProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationSpeech); err != nil { +func (provider *OpenAIProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.SpeechRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Use centralized converter - openaiReq := openai.ToOpenAISpeechRequest(input) + reqBody := openai.ToOpenAISpeechRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("speech input is not provided", nil, providerName) + } - jsonBody, err := sonic.Marshal(openaiReq) + jsonBody, err := sonic.Marshal(reqBody) if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, providerName) } @@ -535,34 +732,35 @@ func (provider *OpenAIProvider) Speech(ctx context.Context, key schemas.Key, inp // The audio data is typically in MP3, WAV, or other audio formats as specified by response_format bifrostResponse := &schemas.BifrostResponse{ Object: "audio.speech", - Model: input.Model, + Model: request.Model, Speech: &schemas.BifrostSpeech{ Audio: audioData, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, + RequestType: schemas.SpeechRequest, + Provider: providerName, + ModelRequested: request.Model, }, } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - return bifrostResponse, nil } // SpeechStream handles streaming for speech synthesis. // It formats the request body, creates HTTP request, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. 
-func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationSpeechStream); err != nil { +func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.SpeechStreamRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Use centralized converter - reqBody := openai.ToOpenAISpeechRequest(input) + reqBody := openai.ToOpenAISpeechRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("speech input is not provided", nil, providerName) + } reqBody.StreamFormat = schemas.Ptr("sse") jsonBody, err := sonic.Marshal(reqBody) @@ -674,18 +872,16 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner response.Speech = &speechResponse response.Object = "audio.speech.chunk" - response.Model = input.Model + response.Model = request.Model response.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: providerName, + RequestType: schemas.SpeechStreamRequest, + Provider: providerName, + ModelRequested: request.Model, } response.ExtraFields.ChunkIndex = chunkIndex if speechResponse.Usage != nil { - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } - ctx = context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) processAndSendResponse(ctx, postHookRunner, &response, responseChan, provider.logger) return @@ -697,7 +893,7 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner // Handle scanner errors if err := scanner.Err(); err != nil { 
provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.SpeechStreamRequest, providerName, request.Model, provider.logger) } }() @@ -707,15 +903,18 @@ func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner // Transcription handles non-streaming transcription requests. // It creates a multipart form, adds fields, makes the API call, and returns the response. // Returns the response and any error that occurred. -func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationTranscription); err != nil { +func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TranscriptionRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Use centralized converter - reqBody := openai.ToOpenAITranscriptionRequest(input) + reqBody := openai.ToOpenAITranscriptionRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("transcription input is not provided", nil, providerName) + } // Create multipart form var body bytes.Buffer @@ -773,10 +972,12 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.K // Create final response bifrostResponse := &schemas.BifrostResponse{ Object: "audio.transcription", - Model: input.Model, + Model: request.Model, Transcribe: transcribeResponse, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, + RequestType: schemas.TranscriptionRequest, + Provider: 
providerName, + ModelRequested: request.Model, }, } @@ -784,23 +985,19 @@ func (provider *OpenAIProvider) Transcription(ctx context.Context, key schemas.K bifrostResponse.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - return bifrostResponse, nil } -func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.OperationTranscriptionStream); err != nil { +func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if err := checkOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.TranscriptionStreamRequest); err != nil { return nil, err } providerName := provider.GetProviderKey() // Use centralized converter - reqBody := openai.ToOpenAITranscriptionRequest(input) + reqBody := openai.ToOpenAITranscriptionRequest(request) if reqBody == nil { return nil, newBifrostOperationError("transcription input is not provided", nil, providerName) } @@ -916,18 +1113,16 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo response.Transcribe = &transcriptionResponse response.Object = "audio.transcription.chunk" - response.Model = input.Model + response.Model = request.Model response.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: providerName, + RequestType: schemas.TranscriptionStreamRequest, + Provider: providerName, + ModelRequested: request.Model, } response.ExtraFields.ChunkIndex = chunkIndex if transcriptionResponse.Usage != nil { - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } - ctx = 
context.WithValue(ctx, schemas.BifrostContextKeyStreamEndIndicator, true) processAndSendResponse(ctx, postHookRunner, &response, responseChan, provider.logger) return @@ -939,7 +1134,7 @@ func (provider *OpenAIProvider) TranscriptionStream(ctx context.Context, postHoo // Handle scanner errors if err := scanner.Err(); err != nil { provider.logger.Warn(fmt.Sprintf("Error reading stream: %v", err)) - processAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) + processAndSendError(ctx, postHookRunner, err, responseChan, schemas.TranscriptionStreamRequest, providerName, request.Model, provider.logger) } }() @@ -980,31 +1175,12 @@ func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiR } } - if openaiReq.Temperature != nil { - if err := writer.WriteField("temperature", fmt.Sprintf("%f", *openaiReq.Temperature)); err != nil { - return newBifrostOperationError("failed to write temperature field", err, providerName) - } - } - if openaiReq.Stream != nil && *openaiReq.Stream { if err := writer.WriteField("stream", "true"); err != nil { return newBifrostOperationError("failed to write stream field", err, providerName) } } - // Handle array parameters specially for OpenAI's form data format - for _, item := range openaiReq.TimestampGranularities { - if err := writer.WriteField("timestamp_granularities[]", item); err != nil { - return newBifrostOperationError("failed to write timestamp_granularities param", err, providerName) - } - } - - for _, item := range openaiReq.Include { - if err := writer.WriteField("include[]", item); err != nil { - return newBifrostOperationError("failed to write include param", err, providerName) - } - } - // Close the multipart writer if err := writer.Close(); err != nil { return newBifrostOperationError("failed to close multipart writer", err, providerName) @@ -1013,7 +1189,7 @@ func parseTranscriptionFormDataBodyFromRequest(writer *multipart.Writer, openaiR return nil } -func 
parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.TranscriptionInput, model string, params *schemas.ModelParameters, providerName schemas.ModelProvider) *schemas.BifrostError { +func parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.TranscriptionInput, model string, params *schemas.TranscriptionParameters, providerName schemas.ModelProvider) *schemas.BifrostError { // Add file field fileWriter, err := writer.CreateFormFile("file", "audio.mp3") // OpenAI requires a filename if err != nil { @@ -1029,20 +1205,20 @@ func parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.Tra } // Add optional fields - if input.Language != nil { - if err := writer.WriteField("language", *input.Language); err != nil { + if params.Language != nil { + if err := writer.WriteField("language", *params.Language); err != nil { return newBifrostOperationError("failed to write language field", err, providerName) } } - if input.Prompt != nil { - if err := writer.WriteField("prompt", *input.Prompt); err != nil { + if params.Prompt != nil { + if err := writer.WriteField("prompt", *params.Prompt); err != nil { return newBifrostOperationError("failed to write prompt field", err, providerName) } } - if input.ResponseFormat != nil { - if err := writer.WriteField("response_format", *input.ResponseFormat); err != nil { + if params.ResponseFormat != nil { + if err := writer.WriteField("response_format", *params.ResponseFormat); err != nil { return newBifrostOperationError("failed to write response_format field", err, providerName) } } @@ -1085,6 +1261,10 @@ func parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.Tra return nil } +func (provider *OpenAIProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "openai") +} 
+ func parseOpenAIError(resp *fasthttp.Response) *schemas.BifrostError { var errorResp schemas.BifrostError diff --git a/core/providers/openrouter.go b/core/providers/openrouter.go index 34333c1838..7476621db0 100644 --- a/core/providers/openrouter.go +++ b/core/providers/openrouter.go @@ -4,13 +4,10 @@ package providers import ( "context" - "fmt" "net/http" "strings" - "sync" "time" - "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) @@ -24,27 +21,6 @@ type OpenRouterProvider struct { sendBackRawResponse bool // Whether to include raw response in BifrostResponse } -// openRouterTextCompletionResponsePool provides a pool for OpenRouter text completion response objects. -var openRouterTextCompletionResponsePool = sync.Pool{ - New: func() interface{} { - return &openai.OpenAITextCompletionResponse{} - }, -} - -// acquireOpenRouterTextResponse gets an OpenRouter text completion response from the pool and resets it. -func acquireOpenRouterTextResponse() *openai.OpenAITextCompletionResponse { - resp := openRouterTextCompletionResponsePool.Get().(*openai.OpenAITextCompletionResponse) - *resp = openai.OpenAITextCompletionResponse{} // Reset the struct - return resp -} - -// releaseOpenRouterTextResponse returns an OpenRouter text completion response to the pool. -func releaseOpenRouterTextResponse(resp *openai.OpenAITextCompletionResponse) { - if resp != nil { - openRouterTextCompletionResponsePool.Put(resp) - } -} - // NewOpenRouterProvider creates a new OpenRouter provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. 
@@ -62,11 +38,6 @@ func NewOpenRouterProvider(config *schemas.ProviderConfig, logger schemas.Logger Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), } - // Pre-warm response pools - for i := 0; i < config.ConcurrencyAndBufferSize.Concurrency; i++ { - openRouterTextCompletionResponsePool.Put(&openai.OpenAITextCompletionResponse{}) - } - // Configure proxy if provided client = configureProxy(client, config.ProxyConfig, logger) @@ -91,184 +62,88 @@ func (provider *OpenRouterProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion performs a text completion request to the OpenRouter API. -func (provider *OpenRouterProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenRouter converter for completion request - reqBody := openai.ToOpenAITextCompletionRequest(input) - - jsonBody, err := sonic.Marshal(reqBody) - if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenRouter) - } - - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("OpenRouter 
error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Create response object from pool - response := acquireOpenRouterTextResponse() - defer releaseOpenRouterTextResponse(response) - - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Use centralized OpenRouter converter - bifrostResponse := response.ToBifrostResponse() - - // Set raw response if enabled - if provider.sendBackRawResponse { - bifrostResponse.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - - return bifrostResponse, nil +func (provider *OpenRouterProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAITextCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) } // ChatCompletion performs a chat completion request to the OpenRouter API. 
-func (provider *OpenRouterProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI converter since OpenRouter is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) - - jsonBody, err := sonic.Marshal(reqBody) - if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.OpenRouter) - } - - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from openrouter provider: %s", string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("OpenRouter error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - response := &schemas.BifrostResponse{} - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - response.ExtraFields.Provider = schemas.OpenRouter - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - 
response.ExtraFields.Params = *input.Params - } +func (provider *OpenRouterProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} - return response, nil +func (provider *OpenRouterProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIResponsesRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/alpha/responses", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) } // ChatCompletionStream performs a streaming chat completion request to the OpenRouter API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses OpenRouter's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *OpenRouterProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Use centralized OpenAI converter since OpenRouter is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare OpenRouter headers - headers := map[string]string{ - "Content-Type": "application/json", - "Authorization": "Bearer " + key.Value, - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - +func (provider *OpenRouterProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.OpenRouter, - input.Params, postHookRunner, provider.logger, ) } -func (provider *OpenRouterProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "openrouter") } -func (provider *OpenRouterProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, 
newUnsupportedOperationError("speech", "openrouter") } -func (provider *OpenRouterProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenRouterProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "openrouter") } -func (provider *OpenRouterProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenRouterProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "openrouter") } -func (provider *OpenRouterProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenRouterProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "openrouter") } + +func (provider *OpenRouterProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "openrouter") +} diff --git a/core/providers/parasail.go b/core/providers/parasail.go index a707296007..e90370828d 100644 --- a/core/providers/parasail.go +++ 
b/core/providers/parasail.go @@ -4,37 +4,14 @@ package providers import ( "context" - "fmt" "net/http" "strings" "time" - "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// // parasailResponsePool provides a pool for Parasail response objects. -// var parasailResponsePool = sync.Pool{ -// New: func() interface{} { -// return &schemas.BifrostResponse{} -// }, -// } - -// // acquireParasailResponse gets a Parasail response from the pool and resets it. -// func acquireParasailResponse() *schemas.BifrostResponse { -// resp := parasailResponsePool.Get().(*schemas.BifrostResponse) -// *resp = schemas.BifrostResponse{} // Reset the struct -// return resp -// } - -// // releaseParasailResponse returns a Parasail response to the pool. -// func releaseParasailResponse(resp *schemas.BifrostResponse) { -// if resp != nil { -// parasailResponsePool.Put(resp) -// } -// } - // ParasailProvider implements the Provider interface for Parasail's API. type ParasailProvider struct { logger schemas.Logger // Logger for provider operations @@ -90,81 +67,41 @@ func (provider *ParasailProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the Parasail provider. -func (provider *ParasailProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "parasail") } // ChatCompletion performs a chat completion request to the Parasail API. 
-func (provider *ParasailProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI converter since Parasail is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) +func (provider *ParasailProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} - jsonBody, err := sonic.Marshal(reqBody) +func (provider *ParasailProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Parasail) + return nil, err } - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key.Value) - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from parasail provider: %s", 
string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("Parasail error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - // response := acquireParasailResponse() - // defer releaseParasailResponse(response) - response := &schemas.BifrostResponse{} - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - // Create final response - response.ExtraFields.Provider = schemas.Parasail - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model return response, nil } // Embedding is not supported by the Parasail provider. -func (provider *ParasailProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "parasail") } @@ -172,51 +109,41 @@ func (provider *ParasailProvider) Embedding(ctx context.Context, key schemas.Key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Parasail's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *ParasailProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Use centralized OpenAI converter since Parasail is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare Parasail headers - headers := map[string]string{ - "Content-Type": "application/json", - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - - headers["Authorization"] = "Bearer " + key.Value - +func (provider *ParasailProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.Parasail, - input.Params, postHookRunner, provider.logger, ) } // Speech is not supported by the Parasail provider. -func (provider *ParasailProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "parasail") } // SpeechStream is not supported by the Parasail provider. 
-func (provider *ParasailProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ParasailProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "parasail") } // Transcription is not supported by the Parasail provider. -func (provider *ParasailProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *ParasailProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "parasail") } // TranscriptionStream is not supported by the Parasail provider. 
-func (provider *ParasailProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *ParasailProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "parasail") } + +func (provider *ParasailProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "parasail") +} diff --git a/core/providers/sgl.go b/core/providers/sgl.go index 11cb6dd145..e93885c618 100644 --- a/core/providers/sgl.go +++ b/core/providers/sgl.go @@ -9,32 +9,10 @@ import ( "strings" "time" - "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" ) -// // sglResponsePool provides a pool for SGL response objects. -// var sglResponsePool = sync.Pool{ -// New: func() interface{} { -// return &schemas.BifrostResponse{} -// }, -// } - -// // acquireSGLResponse gets a SGL response from the pool and resets it. -// func acquireSGLResponse() *schemas.BifrostResponse { -// resp := sglResponsePool.Get().(*schemas.BifrostResponse) -// *resp = schemas.BifrostResponse{} // Reset the struct -// return resp -// } - -// // releaseSGLResponse returns a SGL response to the pool. -// func releaseSGLResponse(resp *schemas.BifrostResponse) { -// if resp != nil { -// sglResponsePool.Put(resp) -// } -// } - // SGLProvider implements the Provider interface for SGL's API. 
type SGLProvider struct { logger schemas.Logger // Logger for provider operations @@ -91,88 +69,51 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the SGL provider. -func (provider *SGLProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - return nil, newUnsupportedOperationError("text completion", "sgl") +func (provider *SGLProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAITextCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) } // ChatCompletion performs a chat completion request to the SGL API. -func (provider *SGLProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Use centralized OpenAI converter since SGL is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) +func (provider *SGLProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + return handleOpenAIChatCompletionRequest( + ctx, + provider.client, + provider.networkConfig.BaseURL+"/v1/chat/completions", + request, + key, + provider.networkConfig.ExtraHeaders, + provider.GetProviderKey(), + provider.sendBackRawResponse, + provider.logger, + ) +} - jsonBody, err := sonic.Marshal(reqBody) +func (provider *SGLProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, 
request.ToChatRequest()) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: true, - Error: schemas.ErrorField{ - Message: schemas.ErrProviderJSONMarshaling, - Error: err, - }, - } - } - - // Create request - req := fasthttp.AcquireRequest() - resp := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(req) - defer fasthttp.ReleaseResponse(resp) - - // Set any extra headers from network config - setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - - req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") - req.Header.SetMethod("POST") - req.Header.SetContentType("application/json") - if key.Value != "" { - req.Header.Set("Authorization", "Bearer "+key.Value) - } - - req.SetBody(jsonBody) - - // Make request - bifrostErr := makeRequestWithContext(ctx, provider.client, req, resp) - if bifrostErr != nil { - return nil, bifrostErr + return nil, err } - // Handle error response - if resp.StatusCode() != fasthttp.StatusOK { - provider.logger.Debug(fmt.Sprintf("error from sgl provider: %s", string(resp.Body()))) - - var errorResp map[string]interface{} - bifrostErr := handleProviderAPIError(resp, &errorResp) - bifrostErr.Error.Message = fmt.Sprintf("SGL error: %v", errorResp) - return nil, bifrostErr - } - - responseBody := resp.Body() - - // Pre-allocate response structs from pools - // response := acquireSGLResponse() - response := &schemas.BifrostResponse{} - // defer releaseSGLResponse(response) - - // Use enhanced response handler with pre-allocated response - rawResponse, bifrostErr := handleProviderResponse(responseBody, response, provider.sendBackRawResponse) - if bifrostErr != nil { - return nil, bifrostErr - } - - response.ExtraFields.Provider = schemas.SGL - - if provider.sendBackRawResponse { - response.ExtraFields.RawResponse = rawResponse - } - - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } + response.ToResponsesOnly() + response.ExtraFields.RequestType = 
schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model return response, nil } // Embedding is not supported by the SGL provider. -func (provider *SGLProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "sgl") } @@ -180,50 +121,37 @@ func (provider *SGLProvider) Embedding(ctx context.Context, key schemas.Key, inp // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses SGL's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { - // Use centralized OpenAI converter since SGL is OpenAI-compatible - reqBody := openai.ToOpenAIChatCompletionRequest(input) - reqBody.Stream = schemas.Ptr(true) - - // Prepare SGL headers (SGL typically doesn't require authorization, but we include it if provided) - headers := map[string]string{ - "Content-Type": "application/json", - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - - // Only add Authorization header if key is provided (SGL can run without auth) - if key.Value != "" { - headers["Authorization"] = "Bearer " + key.Value - } - +func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Use shared OpenAI-compatible streaming 
logic return handleOpenAIStreaming( ctx, provider.streamClient, provider.networkConfig.BaseURL+"/v1/chat/completions", - reqBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.SGL, - input.Params, postHookRunner, provider.logger, ) } -func (provider *SGLProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "sgl") } -func (provider *SGLProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *SGLProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", "sgl") } -func (provider *SGLProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "sgl") } -func (provider *SGLProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *SGLProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan 
*schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "sgl") } + +func (provider *SGLProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "sgl") +} diff --git a/core/providers/utils.go b/core/providers/utils.go index 53256dcfa5..b197998909 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -3,13 +3,11 @@ package providers import ( - "bytes" "context" "fmt" "net/http" "net/textproto" "net/url" - "reflect" "slices" "strings" "sync" @@ -18,81 +16,8 @@ import ( schemas "github.com/maximhq/bifrost/core/schemas" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpproxy" - - "maps" ) -// ContextKey is a custom type for context keys to prevent key collisions in the context. -// It provides type safety for context values and ensures that context keys are unique -// across different packages. -type ContextKey string - -// mergeConfig merges a default configuration map with custom parameters. -// It creates a new map containing all default values, then overrides them with any custom values. -// Returns a new map containing the merged configuration. -func mergeConfig(defaultConfig map[string]interface{}, customParams map[string]interface{}) map[string]interface{} { - merged := make(map[string]interface{}) - - // Copy default config - for k, v := range defaultConfig { - merged[k] = v - } - - // Override with custom parameters - for k, v := range customParams { - merged[k] = v - } - - return merged -} - -// prepareParams converts ModelParameters into a flat map of parameters. -// It handles both standard fields and extra parameters, using reflection to process -// the struct fields and their JSON tags. -// Returns a map containing all parameters ready for use in API requests. 
-func prepareParams(params *schemas.ModelParameters) map[string]interface{} { - flatParams := make(map[string]interface{}) - - // Return empty map if params is nil - if params == nil { - return flatParams - } - - // Use reflection to get the type and value of params - val := reflect.ValueOf(params).Elem() - typ := val.Type() - - // Iterate through all fields - for i := range val.NumField() { - field := val.Field(i) - fieldType := typ.Field(i) - - // Skip the ExtraParams field as it's handled separately - if fieldType.Name == "ExtraParams" { - continue - } - - // Get the JSON tag name - jsonTag := fieldType.Tag.Get("json") - if jsonTag == "" || jsonTag == "-" { - continue - } - - // Strip out ,omitempty and others from the tag - jsonTag = strings.Split(jsonTag, ",")[0] - - // Handle pointer fields - if field.Kind() == reflect.Ptr && !field.IsNil() { - flatParams[jsonTag] = field.Elem().Interface() - } - } - - // Handle ExtraParams - maps.Copy(flatParams, params.ExtraParams) - - return flatParams -} - // IMPORTANT: This function does NOT truly cancel the underlying fasthttp network request if the // context is done. The fasthttp client call will continue in its goroutine until it completes // or times out based on its own settings. This function merely stops *waiting* for the @@ -113,7 +38,7 @@ func makeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *f return &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ - Type: Ptr(schemas.RequestCancelled), + Type: schemas.Ptr(schemas.RequestCancelled), Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), Error: ctx.Err(), }, @@ -320,100 +245,18 @@ func handleProviderResponse[T any](responseBody []byte, response *T, sendBackRaw return nil, nil } -// getRoleFromMessage extracts and validates the role from a message map. 
-func getRoleFromMessage(msg map[string]interface{}) (schemas.ModelChatMessageRole, bool) { - roleVal, exists := msg["role"] - if !exists { - return "", false // Role key doesn't exist - } - - // Try direct assertion to ModelChatMessageRole - roleAsModelType, ok := roleVal.(schemas.ModelChatMessageRole) - if ok { - return roleAsModelType, true - } - - // Try assertion to string and then convert - roleAsString, okStr := roleVal.(string) - if okStr { - return schemas.ModelChatMessageRole(roleAsString), true - } - - return "", false // Role is of an unexpected or invalid type -} - -// Ptr creates a pointer to any value. -// This is a helper function for creating pointers to values. -func Ptr[T any](v T) *T { - return &v -} - -var ( - riff = []byte("RIFF") - wave = []byte("WAVE") - id3 = []byte("ID3") - form = []byte("FORM") - aiff = []byte("AIFF") - aifc = []byte("AIFC") - flac = []byte("fLaC") - oggs = []byte("OggS") - adif = []byte("ADIF") -) - -// detectAudioMimeType attempts to detect the MIME type from audio file headers -// Gemini supports: WAV, MP3, AIFF, AAC, OGG Vorbis, FLAC -func detectAudioMimeType(audioData []byte) string { - if len(audioData) < 4 { - return "audio/mp3" - } - // WAV (RIFF/WAVE) - if len(audioData) >= 12 && - bytes.Equal(audioData[:4], riff) && - bytes.Equal(audioData[8:12], wave) { - return "audio/wav" - } - // MP3: ID3v2 tag (keep this check for MP3) - if len(audioData) >= 3 && bytes.Equal(audioData[:3], id3) { - return "audio/mp3" - } - // AAC: ADIF or ADTS (0xFFF sync) - check before MP3 frame sync to avoid misclassification - if bytes.HasPrefix(audioData, adif) { - return "audio/aac" - } - if len(audioData) >= 2 && audioData[0] == 0xFF && (audioData[1]&0xF6) == 0xF0 { - return "audio/aac" - } - // AIFF / AIFC (map both to audio/aiff) - if len(audioData) >= 12 && bytes.Equal(audioData[:4], form) && - (bytes.Equal(audioData[8:12], aiff) || bytes.Equal(audioData[8:12], aifc)) { - return "audio/aiff" - } - // FLAC - if 
bytes.HasPrefix(audioData, flac) { - return "audio/flac" - } - // OGG container - if bytes.HasPrefix(audioData, oggs) { - return "audio/ogg" - } - // MP3: MPEG frame sync (cover common variants) - check after AAC to avoid misclassification - if len(audioData) >= 2 && audioData[0] == 0xFF && - (audioData[1] == 0xFB || audioData[1] == 0xF3 || audioData[1] == 0xF2 || audioData[1] == 0xFA) { - return "audio/mp3" - } - // Fallback within supported set - return "audio/mp3" -} - // newUnsupportedOperationError creates a standardized error for unsupported operations. // This helper reduces code duplication across providers that don't support certain operations. func newUnsupportedOperationError(operation string, providerName string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, - Provider: schemas.ModelProvider(providerName), Error: schemas.ErrorField{ Message: fmt.Sprintf("%s is not supported by %s provider", operation, providerName), }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: schemas.ModelProvider(providerName), + RequestType: schemas.RequestType(operation), + }, } } @@ -421,7 +264,7 @@ func newUnsupportedOperationError(operation string, providerName string) *schema // Behavior: // - If no gating is configured (config == nil or AllowedRequests == nil), the operation is allowed. // - If gating is configured, returns an error when the operation is not explicitly allowed. 
-func checkOperationAllowed(defaultProvider schemas.ModelProvider, config *schemas.CustomProviderConfig, operation schemas.Operation) *schemas.BifrostError { +func checkOperationAllowed(defaultProvider schemas.ModelProvider, config *schemas.CustomProviderConfig, operation schemas.RequestType) *schemas.BifrostError { // No gating configured => allowed if config == nil || config.AllowedRequests == nil { return nil @@ -440,7 +283,6 @@ func checkOperationAllowed(defaultProvider schemas.ModelProvider, config *schema func newConfigurationError(message string, providerType schemas.ModelProvider) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, - Provider: providerType, Error: schemas.ErrorField{ Message: message, }, @@ -452,7 +294,6 @@ func newConfigurationError(message string, providerType schemas.ModelProvider) * func newBifrostOperationError(message string, err error, providerType schemas.ModelProvider) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: true, - Provider: providerType, Error: schemas.ErrorField{ Message: message, Error: err, @@ -465,7 +306,6 @@ func newBifrostOperationError(message string, err error, providerType schemas.Mo func newProviderAPIError(message string, err error, statusCode int, providerType schemas.ModelProvider, errorType *string, eventID *string) *schemas.BifrostError { return &schemas.BifrostError{ IsBifrostError: false, - Provider: providerType, StatusCode: &statusCode, Type: errorType, EventID: eventID, @@ -557,6 +397,9 @@ func processAndSendError( postHookRunner schemas.PostHookRunner, err error, responseChan chan *schemas.BifrostStream, + requestType schemas.RequestType, + providerName schemas.ModelProvider, + model string, logger schemas.Logger, ) { // Send scanner error through channel @@ -567,6 +410,11 @@ func processAndSendError( Message: fmt.Sprintf("Error reading stream: %v", err), Error: err, }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: requestType, + 
Provider: providerName, + ModelRequested: model, + }, } processedResponse, processedError := postHookRunner(&ctx, nil, bifrostError) @@ -589,14 +437,15 @@ func createBifrostChatCompletionChunkResponse( usage *schemas.LLMUsage, finishReason *string, currentChunkIndex int, - params *schemas.ModelParameters, + requestType schemas.RequestType, providerName schemas.ModelProvider, + model string, ) *schemas.BifrostResponse { response := &schemas.BifrostResponse{ ID: id, Object: "chat.completion.chunk", Usage: usage, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { FinishReason: finishReason, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ @@ -605,13 +454,12 @@ func createBifrostChatCompletionChunkResponse( }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: providerName, - ChunkIndex: currentChunkIndex + 1, + RequestType: requestType, + Provider: providerName, + ModelRequested: model, + ChunkIndex: currentChunkIndex + 1, }, } - if params != nil { - response.ExtraFields.Params = *params - } return response } diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 825e6f5166..8fc66ff6d3 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -70,7 +70,7 @@ func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { // openAIResponsePool.Put(&schemas.BifrostResponse{}) - anthropicChatResponsePool.Put(&anthropic.AnthropicChatResponse{}) + anthropicChatResponsePool.Put(&anthropic.AnthropicMessageResponse{}) } @@ -133,39 +133,44 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Vertex provider. // Returns an error indicating that text completion is not available. 
-func (provider *VertexProvider) TextCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) TextCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostTextCompletionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "vertex") } // ChatCompletion performs a chat completion request to the Vertex API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if key.VertexKeyConfig == nil { return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) } // Format messages for Vertex API - var requestBody map[string]interface{} - if strings.Contains(input.Model, "claude") { + if strings.Contains(request.Model, "claude") { // Use centralized Anthropic converter - anthropicReq := anthropic.ToAnthropicChatCompletionRequest(input) + reqBody := anthropic.ToAnthropicChatCompletionRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, schemas.Vertex) + } // Convert struct to map for Vertex API reqBytes, _ := sonic.Marshal(reqBody) sonic.Unmarshal(reqBytes, &requestBody) } else { // Use centralized OpenAI converter for non-Claude models - openaiReq := openai.ToOpenAIChatCompletionRequest(input) + reqBody := openai.ToOpenAIChatRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, schemas.Vertex) + 
} // Convert struct to map for Vertex API reqBytes, _ := sonic.Marshal(reqBody) sonic.Unmarshal(reqBytes, &requestBody) } - if strings.Contains(input.Model, "claude") { + if strings.Contains(request.Model, "claude") { if _, exists := requestBody["anthropic_version"]; !exists { requestBody["anthropic_version"] = "vertex-2023-10-16" } @@ -192,8 +197,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) - if strings.Contains(input.Model, "claude") { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, input.Model) + if strings.Contains(request.Model, "claude") { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:rawPredict", region, projectID, region, request.Model) } // Create request @@ -227,7 +232,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Type: Ptr(schemas.RequestCancelled), + Type: schemas.Ptr(schemas.RequestCancelled), Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), Error: err, }, @@ -269,7 +274,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. return nil, newProviderAPIError(openAIErr.Error.Message, nil, resp.StatusCode, schemas.Vertex, nil, nil) } - if strings.Contains(input.Model, "claude") { + if strings.Contains(request.Model, "claude") { // Create response object from pool response := acquireAnthropicChatResponse() defer releaseAnthropicChatResponse(response) @@ -283,17 +288,15 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. 
bifrostResponse := response.ToBifrostResponse() bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Vertex, + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.Vertex, + ModelRequested: request.Model, } if provider.sendBackRawResponse { bifrostResponse.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - bifrostResponse.ExtraFields.Params = *input.Params - } - return bifrostResponse, nil } else { // Pre-allocate response structs from pools @@ -307,24 +310,36 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, key schemas. return nil, bifrostErr } + response.ExtraFields.RequestType = schemas.ChatCompletionRequest response.ExtraFields.Provider = schemas.Vertex + response.ExtraFields.ModelRequested = request.Model if provider.sendBackRawResponse { response.ExtraFields.RawResponse = rawResponse } - if input.Params != nil { - response.ExtraFields.Params = *input.Params - } - return response, nil } } +func (provider *VertexProvider) Responses(ctx context.Context, key schemas.Key, request *schemas.BifrostResponsesRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { + response, err := provider.ChatCompletion(ctx, key, request.ToChatRequest()) + if err != nil { + return nil, err + } + + response.ToResponsesOnly() + response.ExtraFields.RequestType = schemas.ResponsesRequest + response.ExtraFields.Provider = provider.GetProviderKey() + response.ExtraFields.ModelRequested = request.Model + + return response, nil +} + // Embedding generates embeddings for the given input text(s) using Vertex AI. // All Vertex AI embedding models use the same response format regardless of the model type. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
-func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key, request *schemas.BifrostEmbeddingRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if key.VertexKeyConfig == nil { return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) } @@ -339,21 +354,19 @@ func (provider *VertexProvider) Embedding(ctx context.Context, key schemas.Key, return nil, newConfigurationError("region is not set in key config", schemas.Vertex) } - // Validate input - if input.Input.EmbeddingInput == nil || len(input.Input.EmbeddingInput.Texts) == 0 { + // Use centralized Vertex converter + reqBody := vertex.ToVertexEmbeddingRequest(request) + if reqBody == nil { return nil, newConfigurationError("embedding input texts are empty", schemas.Vertex) } - // Use centralized Vertex converter - reqBody := vertex.ToVertexEmbeddingRequest(input) - // All Vertex AI embedding models use the same native Vertex embedding API - return provider.handleVertexEmbedding(ctx, input.Model, key, reqBody, input.Params) + return provider.handleVertexEmbedding(ctx, request.Model, key, reqBody, request.Params) } // handleVertexEmbedding handles embedding requests using Vertex's native embedding API // This is used for all Vertex AI embedding models as they all use the same response format -func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model string, key schemas.Key, vertexReq *vertex.VertexEmbeddingRequest, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model string, key schemas.Key, vertexReq *vertex.VertexEmbeddingRequest, params *schemas.EmbeddingParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Use the typed request directly jsonBody, err := 
sonic.Marshal(vertexReq) if err != nil { @@ -389,7 +402,7 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Type: Ptr(schemas.RequestCancelled), + Type: schemas.Ptr(schemas.RequestCancelled), Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), Error: err, }, @@ -443,6 +456,11 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model // Use centralized Vertex converter bifrostResponse := vertexResponse.ToBifrostResponse() + // Set ExtraFields + bifrostResponse.ExtraFields.Provider = schemas.Vertex + bifrostResponse.ExtraFields.ModelRequested = model + bifrostResponse.ExtraFields.RequestType = schemas.EmbeddingRequest + // Set raw response if enabled if provider.sendBackRawResponse { // Convert back to map for raw response @@ -458,7 +476,7 @@ func (provider *VertexProvider) handleVertexEmbedding(ctx context.Context, model // ChatCompletionStream performs a streaming chat completion request to the Vertex API. // It supports both OpenAI-style streaming (for non-Claude models) and Anthropic-style streaming (for Claude models). // Returns a channel of BifrostResponse objects for streaming results or an error if the request fails. 
-func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { if key.VertexKeyConfig == nil { return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) } @@ -480,10 +498,14 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex) } - if strings.Contains(input.Model, "claude") { + if strings.Contains(request.Model, "claude") { // Use Anthropic-style streaming for Claude models - anthropicReq := anthropic.ToAnthropicChatCompletionRequest(input) - anthropicReq.Stream = schemas.Ptr(true) + reqBody := anthropic.ToAnthropicChatCompletionRequest(request) + if reqBody == nil { + return nil, newBifrostOperationError("chat completion input is not provided", nil, schemas.Vertex) + } + + reqBody.Stream = schemas.Ptr(true) // Convert struct to map for Vertex API reqBytes, _ := sonic.Marshal(reqBody) @@ -497,7 +519,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo delete(requestBody, "model") delete(requestBody, "region") - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, input.Model) + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, request.Model) // Prepare headers for Vertex Anthropic headers := map[string]string{ @@ -515,59 +537,43 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, 
postHo headers, provider.networkConfig.ExtraHeaders, schemas.Vertex, - input.Params, postHookRunner, provider.logger, ) } else { - // Use OpenAI-style streaming for non-Claude models - openaiReq := openai.ToOpenAIChatCompletionRequest(input) - openaiReq.Stream = schemas.Ptr(true) - - // Convert struct to map for Vertex API - reqBytes, _ := sonic.Marshal(openaiReq) - var requestBody map[string]interface{} - sonic.Unmarshal(reqBytes, &requestBody) - - delete(requestBody, "region") - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) - // Prepare headers for Vertex OpenAI-compatible - headers := map[string]string{ - "Content-Type": "application/json", - "Accept": "text/event-stream", - "Cache-Control": "no-cache", - } - // Use shared OpenAI streaming logic return handleOpenAIStreaming( ctx, client, url, - requestBody, - headers, + request, + map[string]string{"Authorization": "Bearer " + key.Value}, provider.networkConfig.ExtraHeaders, schemas.Vertex, - input.Params, postHookRunner, provider.logger, ) } } -func (provider *VertexProvider) Speech(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Speech(ctx context.Context, key schemas.Key, request *schemas.BifrostSpeechRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech", "vertex") } -func (provider *VertexProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostSpeechRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("speech stream", 
"vertex") } -func (provider *VertexProvider) Transcription(ctx context.Context, key schemas.Key, input *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Transcription(ctx context.Context, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription", "vertex") } -func (provider *VertexProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, input *schemas.BifrostRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *VertexProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostTranscriptionRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { return nil, newUnsupportedOperationError("transcription stream", "vertex") } + +func (provider *VertexProvider) ResponsesStream(ctx context.Context, postHookRunner schemas.PostHookRunner, key schemas.Key, request *schemas.BifrostResponsesRequest) (chan *schemas.BifrostStream, *schemas.BifrostError) { + return nil, newUnsupportedOperationError("responses stream", "vertex") +} diff --git a/core/schemas/account.go b/core/schemas/account.go index 44563ca7be..a96342d1ae 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -10,11 +10,18 @@ type Key struct { Value string `json:"value"` // The actual API key value Models []string `json:"models"` // List of models this key can access Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + OpenAIKeyConfig *OpenAIKeyConfig `json:"openai_key_config,omitempty"` // OpenAI-specific key configuration AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration BedrockKeyConfig 
*BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration
 }
 
+// OpenAIKeyConfig represents the OpenAI-specific configuration.
+// It contains OpenAI-specific settings that determine which endpoint to use (Chat Completions API or Responses API).
+type OpenAIKeyConfig struct {
+	UseResponsesAPI bool `json:"use_responses_api,omitempty"`
+}
+
 // AzureKeyConfig represents the Azure-specific configuration.
 // It contains Azure-specific settings required for service access and deployment management.
 type AzureKeyConfig struct {
diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go
index 6b0537d99f..7e7f03697e 100644
--- a/core/schemas/bifrost.go
+++ b/core/schemas/bifrost.go
@@ -11,6 +11,21 @@ const (
 	DefaultInitialPoolSize = 5000
 )
 
+// BifrostRequest wraps a single typed request; exactly ONE of the typed request fields below should be set.
+type BifrostRequest struct {
+	Provider    ModelProvider
+	Model       string
+	Fallbacks   []Fallback
+	RequestType RequestType
+
+	TextCompletionRequest *BifrostTextCompletionRequest
+	ChatRequest           *BifrostChatRequest
+	ResponsesRequest      *BifrostResponsesRequest
+	EmbeddingRequest      *BifrostEmbeddingRequest
+	SpeechRequest         *BifrostSpeechRequest
+	TranscriptionRequest  *BifrostTranscriptionRequest
+}
+
 // BifrostConfig represents the configuration for initializing a Bifrost instance.
 // It contains the necessary components for setting up the system including account details,
 // plugins, logging, and initial pool size. 
@@ -23,17 +38,6 @@ type BifrostConfig struct { MCPConfig *MCPConfig // MCP (Model Context Protocol) configuration for tool integration } -// ModelChatMessageRole represents the role of a chat message -type ModelChatMessageRole string - -const ( - ModelChatMessageRoleAssistant ModelChatMessageRole = "assistant" - ModelChatMessageRoleUser ModelChatMessageRole = "user" - ModelChatMessageRoleSystem ModelChatMessageRole = "system" - ModelChatMessageRoleChatbot ModelChatMessageRole = "chatbot" - ModelChatMessageRoleTool ModelChatMessageRole = "tool" -) - // ModelProvider represents the different AI model providers supported by Bifrost. type ModelProvider string @@ -88,6 +92,8 @@ const ( TextCompletionRequest RequestType = "text_completion" ChatCompletionRequest RequestType = "chat_completion" ChatCompletionStreamRequest RequestType = "chat_completion_stream" + ResponsesRequest RequestType = "responses" + ResponsesStreamRequest RequestType = "responses_stream" EmbeddingRequest RequestType = "embedding" SpeechRequest RequestType = "speech" SpeechStreamRequest RequestType = "speech_stream" @@ -102,9 +108,6 @@ type BifrostContextKey string const ( BifrostContextKeyDirectKey BifrostContextKey = "bifrost-direct-key" BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" - BifrostContextKeyRequestType BifrostContextKey = "bifrost-request-type" - BifrostContextKeyRequestProvider BifrostContextKey = "bifrost-request-provider" - BifrostContextKeyRequestModel BifrostContextKey = "bifrost-request-model" ) // NOTE: for custom plugin implementation dealing with streaming short circuit, @@ -112,175 +115,52 @@ const ( //* Request Structs -// RequestInput represents the input for a model request, which can be either -// a text completion, a chat completion, an embedding request, a speech request, or a transcription request. 
-type RequestInput struct { - TextCompletionInput *string `json:"text_completion_input,omitempty"` - ChatCompletionInput *[]BifrostMessage `json:"chat_completion_input,omitempty"` - EmbeddingInput *EmbeddingInput `json:"embedding_input,omitempty"` - SpeechInput *SpeechInput `json:"speech_input,omitempty"` - TranscriptionInput *TranscriptionInput `json:"transcription_input,omitempty"` -} - -// EmbeddingInput represents the input for an embedding request. -type EmbeddingInput struct { - Text *string - Texts []string - Embedding []int - Embeddings [][]int -} - -func (e *EmbeddingInput) MarshalJSON() ([]byte, error) { - // enforce one-of - set := 0 - if e.Text != nil { - set++ - } - if e.Texts != nil { - set++ - } - if e.Embedding != nil { - set++ - } - if e.Embeddings != nil { - set++ - } - if set == 0 { - return nil, fmt.Errorf("embedding input is empty") - } - if set > 1 { - return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings") - } - - if e.Text != nil { - return sonic.Marshal(*e.Text) - } - if e.Texts != nil { - return sonic.Marshal(e.Texts) - } - if e.Embedding != nil { - return sonic.Marshal(e.Embedding) - } - if e.Embeddings != nil { - return sonic.Marshal(e.Embeddings) - } - - return nil, fmt.Errorf("invalid embedding input") -} - -func (e *EmbeddingInput) UnmarshalJSON(data []byte) error { - // Try string - var s string - if err := sonic.Unmarshal(data, &s); err == nil { - e.Text = &s - return nil - } - // Try []string - var ss []string - if err := sonic.Unmarshal(data, &ss); err == nil { - e.Texts = ss - return nil - } - // Try []int - var i []int - if err := sonic.Unmarshal(data, &i); err == nil { - e.Embedding = i - return nil - } - // Try [][]int - var i2 [][]int - if err := sonic.Unmarshal(data, &i2); err == nil { - e.Embeddings = i2 - return nil - } - - return fmt.Errorf("unsupported embedding input shape") +type BifrostTextCompletionRequest struct { + Provider ModelProvider `json:"provider"` + Model string 
`json:"model"` + Input TextCompletionInput `json:"input,omitempty"` + Params *TextCompletionParameters `json:"params,omitempty"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` } -// SpeechInput represents the input for a speech request. -type SpeechInput struct { - Input string `json:"input"` - VoiceConfig SpeechVoiceInput `json:"voice"` - Instructions string `json:"instructions,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` // Default is "mp3" +type BifrostChatRequest struct { + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input []ChatMessage `json:"input,omitempty"` + Params *ChatParameters `json:"params,omitempty"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` } -type SpeechVoiceInput struct { - Voice *string - MultiVoiceConfig []VoiceConfig -} - -type VoiceConfig struct { - Speaker string `json:"speaker"` - Voice string `json:"voice"` -} - -// MarshalJSON implements custom JSON marshalling for SpeechVoiceInput. -// It marshals either Voice or MultiVoiceConfig directly without wrapping. -func (tc SpeechVoiceInput) MarshalJSON() ([]byte, error) { - // Validation: ensure only one field is set at a time - if tc.Voice != nil && len(tc.MultiVoiceConfig) > 0 { - return nil, fmt.Errorf("both Voice and MultiVoiceConfig are set; only one should be non-nil") - } - - if tc.Voice != nil { - return sonic.Marshal(*tc.Voice) - } - if len(tc.MultiVoiceConfig) > 0 { - return sonic.Marshal(tc.MultiVoiceConfig) - } - // If both are nil, return null - return sonic.Marshal(nil) +type BifrostResponsesRequest struct { + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input []ResponsesMessage `json:"input,omitempty"` + Params *ResponsesParameters `json:"params,omitempty"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` } -// UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput. 
-// It determines whether "voice" is a string or a VoiceConfig object/array and assigns to the appropriate field. -// It also handles direct string/array content without a wrapper object. -func (tc *SpeechVoiceInput) UnmarshalJSON(data []byte) error { - // First, try to unmarshal as a direct string - var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { - tc.Voice = &stringContent - return nil - } - - // Try to unmarshal as an array of VoiceConfig objects - var voiceConfigs []VoiceConfig - if err := sonic.Unmarshal(data, &voiceConfigs); err == nil { - // Validate each VoiceConfig and append to MultiVoiceConfig - for _, config := range voiceConfigs { - if config.Voice == "" { - return fmt.Errorf("voice config has empty voice field") - } - tc.MultiVoiceConfig = append(tc.MultiVoiceConfig, config) - } - return nil - } - - return fmt.Errorf("voice field is neither a string, nor an array of VoiceConfig objects") +type BifrostEmbeddingRequest struct { + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input EmbeddingInput `json:"input,omitempty"` + Params *EmbeddingParameters `json:"params,omitempty"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` } -type TranscriptionInput struct { - File []byte `json:"file"` - Language *string `json:"language,omitempty"` - Prompt *string `json:"prompt,omitempty"` - ResponseFormat *string `json:"response_format,omitempty"` // Default is "json" - Format *string `json:"file_format,omitempty"` // Type of file, not required in openai, but required in gemini +type BifrostSpeechRequest struct { + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input SpeechInput `json:"input,omitempty"` + Params *SpeechParameters `json:"params,omitempty"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` } -// BifrostRequest represents a request to be processed by Bifrost. 
-// It must be provided when calling the Bifrost for text completion, chat completion, or embedding. -// It contains the model identifier, input data, and parameters for the request. -type BifrostRequest struct { - Provider ModelProvider `json:"provider"` - Model string `json:"model"` - Input RequestInput `json:"input"` - Params *ModelParameters `json:"params,omitempty"` - - // Fallbacks are tried in order, the first one to succeed is returned - // Provider config must be available for each fallback's provider in account's GetConfigForProvider, - // else it will be skipped. - Fallbacks []Fallback `json:"fallbacks,omitempty"` +type BifrostTranscriptionRequest struct { + Provider ModelProvider `json:"provider"` + Model string `json:"model"` + Input TranscriptionInput `json:"input,omitempty"` + Params *TranscriptionParameters `json:"params,omitempty"` + Fallbacks []Fallback `json:"fallbacks,omitempty"` } // Fallback represents a fallback model to be used if the primary model is not available. @@ -289,277 +169,33 @@ type Fallback struct { Model string `json:"model"` } -// ModelParameters represents the parameters that can be used to configure -// your request to the model. Bifrost follows a standard set of parameters which -// mapped to the provider's parameters. 
-type ModelParameters struct { - ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool - Tools *[]Tool `json:"tools,omitempty"` // Tools to use - Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output - TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling - TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling - MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens - ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls - EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") - Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output - User *string `json:"user,omitempty"` // User identifier for tracking - N *int `json:"n,omitempty"` - Stop interface{} `json:"stop,omitempty"` - MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` - ReasoningEffort *string `json:"reasoning_effort,omitempty"` - StreamOptions *map[string]interface{} `json:"stream_options,omitempty"` - Stream *bool `json:"stream,omitempty"` - LogProbs *bool `json:"logprobs,omitempty"` - TopLogProbs *int `json:"top_logprobs,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` - LogitBias map[string]float64 `json:"logit_bias,omitempty"` - // Dynamic parameters that can be provider-specific, they are directly - // added to the request as is. - ExtraParams map[string]interface{} `json:"-"` -} - -// FunctionParameters represents the parameters for a function definition. 
-type FunctionParameters struct { - Type string `json:"type"` // Type of the parameters - Description *string `json:"description,omitempty"` // Description of the parameters - Required []string `json:"required,omitempty"` // Required parameter names - Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties - Enum *[]string `json:"enum,omitempty"` // Enum values for the parameters -} - -// Function represents a function that can be called by the model. -type Function struct { - Name string `json:"name"` // Name of the function - Description string `json:"description"` // Description of the function - Parameters FunctionParameters `json:"parameters"` // Parameters of the function -} - -// Tool represents a tool that can be used with the model. -type Tool struct { - ID *string `json:"id,omitempty"` // Optional tool identifier - Type string `json:"type"` // Type of the tool - Function Function `json:"function"` // Function definition -} - -// Combined tool choices for all providers, make sure to check the provider's -// documentation to see which tool choices are supported. -type ToolChoiceType string - -const ( - // ToolChoiceTypeNone means no tool will be called - ToolChoiceTypeNone ToolChoiceType = "none" - // ToolChoiceTypeAuto means the model can choose whether to call a tool - ToolChoiceTypeAuto ToolChoiceType = "auto" - // ToolChoiceTypeAny means any tool can be called - ToolChoiceTypeAny ToolChoiceType = "any" - // ToolChoiceTypeFunction means a specific tool must be called (converted to "tool" for Anthropic) - ToolChoiceTypeFunction ToolChoiceType = "function" - // ToolChoiceTypeRequired means a tool must be called - ToolChoiceTypeRequired ToolChoiceType = "required" -) - -// ToolChoiceFunction represents a specific function to be called. -type ToolChoiceFunction struct { - Name string `json:"name"` // Name of the function to call -} - -// ToolChoiceStruct represents a specific tool choice. 
-type ToolChoiceStruct struct { - Type ToolChoiceType `json:"type"` // Type of tool choice - Function ToolChoiceFunction `json:"function,omitempty"` // Function to call if type is ToolChoiceTypeFunction -} - -// ToolChoice represents how a tool should be chosen for a request. (either a string or a struct) -type ToolChoice struct { - ToolChoiceStr *string - ToolChoiceStruct *ToolChoiceStruct -} - -// MarshalJSON implements custom JSON marshalling for ToolChoice. -// It marshals either ToolChoiceStr or ToolChoiceStruct directly without wrapping. -func (tc ToolChoice) MarshalJSON() ([]byte, error) { - // Validation: ensure only one field is set at a time - if tc.ToolChoiceStr != nil && tc.ToolChoiceStruct != nil { - return nil, fmt.Errorf("both ToolChoiceStr and ToolChoiceStruct are set; only one should be non-nil") - } - - if tc.ToolChoiceStr != nil { - return sonic.Marshal(*tc.ToolChoiceStr) - } - if tc.ToolChoiceStruct != nil { - return sonic.Marshal(*tc.ToolChoiceStruct) - } - // If both are nil, return null - return sonic.Marshal(nil) -} - -// UnmarshalJSON implements custom JSON unmarshalling for ToolChoice. -// It determines whether "tool_choice" is a string or struct and assigns to the appropriate field. -// It also handles direct string/array content without a wrapper object. 
-func (tc *ToolChoice) UnmarshalJSON(data []byte) error { - // First, try to unmarshal as a direct string - var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { - tc.ToolChoiceStr = &stringContent - return nil - } - - // Try to unmarshal as a direct struct of ToolChoiceStruct - var toolChoiceStruct ToolChoiceStruct - if err := sonic.Unmarshal(data, &toolChoiceStruct); err == nil { - // Validate the Type field is not empty and is a valid value - if toolChoiceStruct.Type == "" { - return fmt.Errorf("tool_choice struct has empty type field") - } - - tc.ToolChoiceStruct = &toolChoiceStruct - return nil - } - - return fmt.Errorf("tool_choice field is neither a string nor a struct") -} - -// BifrostMessage represents a message in a chat conversation. -type BifrostMessage struct { - Role ModelChatMessageRole `json:"role"` - Content MessageContent `json:"content"` - - // Embedded pointer structs - when non-nil, their exported fields are flattened into the top-level JSON object - // IMPORTANT: Only one of the following can be non-nil at a time, otherwise the JSON marshalling will override the common fields - *ToolMessage - *AssistantMessage -} - -type MessageContent struct { - ContentStr *string - ContentBlocks *[]ContentBlock -} - -// MarshalJSON implements custom JSON marshalling for MessageContent. -// It marshals either ContentStr or ContentBlocks directly without wrapping. -func (mc MessageContent) MarshalJSON() ([]byte, error) { - // Validation: ensure only one field is set at a time - if mc.ContentStr != nil && mc.ContentBlocks != nil { - return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") - } - - if mc.ContentStr != nil { - return sonic.Marshal(*mc.ContentStr) - } - if mc.ContentBlocks != nil { - return sonic.Marshal(*mc.ContentBlocks) - } - // If both are nil, return null - return sonic.Marshal(nil) -} - -// UnmarshalJSON implements custom JSON unmarshalling for MessageContent. 
-// It determines whether "content" is a string or array and assigns to the appropriate field. -// It also handles direct string/array content without a wrapper object. -func (mc *MessageContent) UnmarshalJSON(data []byte) error { - // First, try to unmarshal as a direct string - var stringContent string - if err := sonic.Unmarshal(data, &stringContent); err == nil { - mc.ContentStr = &stringContent - return nil - } - - // Try to unmarshal as a direct array of ContentBlock - var arrayContent []ContentBlock - if err := sonic.Unmarshal(data, &arrayContent); err == nil { - mc.ContentBlocks = &arrayContent - return nil - } - - return fmt.Errorf("content field is neither a string nor an array of ContentBlock") -} - -type ContentBlockType string - -const ( - ContentBlockTypeText ContentBlockType = "text" - ContentBlockTypeImage ContentBlockType = "image_url" - ContentBlockTypeInputAudio ContentBlockType = "input_audio" -) - -type ContentBlock struct { - Type ContentBlockType `json:"type"` - Text *string `json:"text,omitempty"` - ImageURL *ImageURLStruct `json:"image_url,omitempty"` - InputAudio *InputAudioStruct `json:"input_audio,omitempty"` -} - -// ToolMessage represents a message from a tool -type ToolMessage struct { - ToolCallID *string `json:"tool_call_id,omitempty"` -} - -// AssistantMessage represents a message from an assistant -type AssistantMessage struct { - Refusal *string `json:"refusal,omitempty"` - Annotations []Annotation `json:"annotations,omitempty"` - ToolCalls *[]ToolCall `json:"tool_calls,omitempty"` - Thought *string `json:"thought,omitempty"` -} - -// ImageContent represents image data in a message. 
-type ImageURLStruct struct { - URL string `json:"url"` - Detail *string `json:"detail,omitempty"` -} - -// ImageContentType represents the type of image content -type ImageContentType string - -const ( - ImageContentTypeBase64 ImageContentType = "base64" - ImageContentTypeURL ImageContentType = "url" -) - -// URLTypeInfo contains extracted information about a URL -type URLTypeInfo struct { - Type ImageContentType - MediaType *string - DataURLWithoutPrefix *string // URL without the prefix (eg data:image/png;base64,iVBORw0KGgo...) -} - -// InputAudioStruct represents audio data in a message. -// Data carries the audio payload as a string (e.g., data URL or provider-accepted encoded content). -// Format is optional (e.g., "wav", "mp3"); when nil, providers may attempt auto-detection. -type InputAudioStruct struct { - Data string `json:"data"` - Format *string `json:"format,omitempty"` -} - //* Response Structs // BifrostResponse represents the complete result from any bifrost request. type BifrostResponse struct { - ID string `json:"id,omitempty"` - Object string `json:"object,omitempty"` // text.completion, chat.completion, embedding, speech, transcribe - Choices []BifrostResponseChoice `json:"choices,omitempty"` - Data []BifrostEmbedding `json:"data,omitempty"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format) - Speech *BifrostSpeech `json:"speech,omitempty"` // Maps to "speech" field in provider responses (e.g., OpenAI speech format) - Transcribe *BifrostTranscribe `json:"transcribe,omitempty"` // Maps to "transcribe" field in provider responses (e.g., OpenAI transcription format) - Model string `json:"model,omitempty"` - Created int `json:"created,omitempty"` // The Unix timestamp (in seconds). 
- ServiceTier *string `json:"service_tier,omitempty"` - SystemFingerprint *string `json:"system_fingerprint,omitempty"` - Usage *LLMUsage `json:"usage,omitempty"` - ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + ID string `json:"id,omitempty"` + Object string `json:"object,omitempty"` // text.completion, chat.completion, embedding, speech, transcribe + Choices []BifrostChatResponseChoice `json:"choices,omitempty"` + Data []BifrostEmbedding `json:"data,omitempty"` // Maps to "data" field in provider responses (e.g., OpenAI embedding format) + Speech *BifrostSpeech `json:"speech,omitempty"` // Maps to "speech" field in provider responses (e.g., OpenAI speech format) + Transcribe *BifrostTranscribe `json:"transcribe,omitempty"` // Maps to "transcribe" field in provider responses (e.g., OpenAI transcription format) + Model string `json:"model,omitempty"` + Created int `json:"created,omitempty"` // The Unix timestamp (in seconds). + ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` + Usage *LLMUsage `json:"usage,omitempty"` + ExtraFields BifrostResponseExtraFields `json:"extra_fields"` + + *ResponsesResponse } // LLMUsage represents token usage information type LLMUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - TokenDetails *TokenDetails `json:"prompt_tokens_details,omitempty"` - CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"` + PromptTokens int `json:"prompt_tokens,omitempty"` + CompletionTokens int `json:"completion_tokens,omitempty"` + TotalTokens int `json:"total_tokens"` + + *ResponsesExtendedResponseUsage } type AudioLLMUsage struct { @@ -629,35 +265,6 @@ type LogProbs struct { *TextCompletionLogProb } -// FunctionCall represents a call to a function. 
-type FunctionCall struct { - Name *string `json:"name"` - Arguments string `json:"arguments"` // stringified json as retured by OpenAI, might not be a valid JSON always -} - -// ToolCall represents a tool call in a message -type ToolCall struct { - Type *string `json:"type,omitempty"` - ID *string `json:"id,omitempty"` - Function FunctionCall `json:"function"` -} - -// Citation represents a citation in a response. -type Citation struct { - StartIndex int `json:"start_index"` - EndIndex int `json:"end_index"` - Title string `json:"title"` - URL *string `json:"url,omitempty"` - Sources *interface{} `json:"sources,omitempty"` - Type *string `json:"type,omitempty"` -} - -// Annotation represents an annotation in a response. -type Annotation struct { - Type string `json:"type"` - Citation Citation `json:"url_citation"` -} - type BifrostEmbedding struct { Index int `json:"index"` Object string `json:"object"` // embedding @@ -708,11 +315,11 @@ func (be *BifrostEmbeddingResponse) UnmarshalJSON(data []byte) error { return fmt.Errorf("embedding field is neither a string nor an array of float32 nor a 2D array of float32") } -// BifrostResponseChoice represents a choice in the completion result. +// BifrostChatResponseChoice represents a choice in the completion result. // This struct can represent either a streaming or non-streaming response choice. // IMPORTANT: Only one of BifrostNonStreamResponseChoice or BifrostStreamResponseChoice // should be non-nil at a time. 
-type BifrostResponseChoice struct { +type BifrostChatResponseChoice struct { Index int `json:"index"` FinishReason *string `json:"finish_reason,omitempty"` LogProbs *LogProbs `json:"log_probs,omitempty"` @@ -729,8 +336,8 @@ type BifrostTextCompletionResponseChoice struct { // BifrostNonStreamResponseChoice represents a choice in the non-stream response type BifrostNonStreamResponseChoice struct { - Message BifrostMessage `json:"message"` - StopString *string `json:"stop,omitempty"` + Message ChatMessage `json:"message"` + StopString *string `json:"stop,omitempty"` } // BifrostStreamResponseChoice represents a choice in the stream response @@ -740,11 +347,11 @@ type BifrostStreamResponseChoice struct { // BifrostStreamDelta represents a delta in the stream response type BifrostStreamDelta struct { - Role *string `json:"role,omitempty"` // Only in the first chunk - Content *string `json:"content,omitempty"` // May be empty string or null - Thought *string `json:"thought,omitempty"` // May be empty string or null - Refusal *string `json:"refusal,omitempty"` // Refusal content if any - ToolCalls []ToolCall `json:"tool_calls,omitempty"` // If tool calls used (supports incremental updates) + Role *string `json:"role,omitempty"` // Only in the first chunk + Content *string `json:"content,omitempty"` // May be empty string or null + Thought *string `json:"thought,omitempty"` // May be empty string or null + Refusal *string `json:"refusal,omitempty"` // Refusal content if any + ToolCalls []ChatAssistantMessageToolCall `json:"tool_calls,omitempty"` // If tool calls used (supports incremental updates) } type BifrostSpeech struct { @@ -753,6 +360,7 @@ type BifrostSpeech struct { *BifrostSpeechStreamResponse } + type BifrostSpeechStreamResponse struct { Type string `json:"type"` } @@ -784,54 +392,16 @@ type BifrostTranscribeStreamResponse struct { Delta *string `json:"delta,omitempty"` // For delta events } -// TranscriptionLogProb represents log probability information for 
transcription -type TranscriptionLogProb struct { - Token string `json:"token"` - LogProb float64 `json:"logprob"` - Bytes []int `json:"bytes"` -} - -// TranscriptionWord represents word-level timing information -type TranscriptionWord struct { - Word string `json:"word"` - Start float64 `json:"start"` - End float64 `json:"end"` -} - -// TranscriptionSegment represents segment-level transcription information -type TranscriptionSegment struct { - ID int `json:"id"` - Seek int `json:"seek"` - Start float64 `json:"start"` - End float64 `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` - Temperature float64 `json:"temperature"` - AvgLogProb float64 `json:"avg_logprob"` - CompressionRatio float64 `json:"compression_ratio"` - NoSpeechProb float64 `json:"no_speech_prob"` -} - -// TranscriptionUsage represents usage information for transcription -type TranscriptionUsage struct { - Type string `json:"type"` // "tokens" or "duration" - InputTokens *int `json:"input_tokens,omitempty"` - InputTokenDetails *AudioTokenDetails `json:"input_token_details,omitempty"` - OutputTokens *int `json:"output_tokens,omitempty"` - TotalTokens *int `json:"total_tokens,omitempty"` - Seconds *int `json:"seconds,omitempty"` // For duration-based usage -} - // BifrostResponseExtraFields contains additional fields in a response. 
type BifrostResponseExtraFields struct { - Provider ModelProvider `json:"provider"` - Params ModelParameters `json:"model_params"` - Latency *float64 `json:"latency,omitempty"` - ChatHistory *[]BifrostMessage `json:"chat_history,omitempty"` - BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` - ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses - RawResponse interface{} `json:"raw_response,omitempty"` - CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` + RequestType RequestType `json:"request_type"` + Provider ModelProvider `json:"provider"` + ModelRequested string `json:"model_requested"` + Latency *float64 `json:"latency,omitempty"` + BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + RawResponse interface{} `json:"raw_response,omitempty"` + CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` } // BifrostCacheDebug represents debug information about the cache. 
@@ -869,14 +439,14 @@ type BifrostStream struct { // - AllowFallbacks = &false: Bifrost will return this error immediately, no fallbacks // - AllowFallbacks = nil: Treated as true by default (fallbacks allowed for resilience) type BifrostError struct { - Provider ModelProvider `json:"-"` - EventID *string `json:"event_id,omitempty"` - Type *string `json:"type,omitempty"` - IsBifrostError bool `json:"is_bifrost_error"` - StatusCode *int `json:"status_code,omitempty"` - Error ErrorField `json:"error"` - AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) - StreamControl *StreamControl `json:"-"` // Optional: Controls stream behavior + EventID *string `json:"event_id,omitempty"` + Type *string `json:"type,omitempty"` + IsBifrostError bool `json:"is_bifrost_error"` + StatusCode *int `json:"status_code,omitempty"` + Error ErrorField `json:"error"` + AllowFallbacks *bool `json:"-"` // Optional: Controls fallback behavior (nil = true by default) + StreamControl *StreamControl `json:"-"` // Optional: Controls stream behavior + ExtraFields BifrostErrorExtraFields `json:"extra_fields,omitempty"` } type StreamControl struct { @@ -893,3 +463,9 @@ type ErrorField struct { Param interface{} `json:"param,omitempty"` EventID *string `json:"event_id,omitempty"` } + +type BifrostErrorExtraFields struct { + Provider ModelProvider `json:"provider"` + ModelRequested string `json:"model_requested"` + RequestType RequestType `json:"request_type"` +} diff --git a/core/schemas/chatcompletions.go b/core/schemas/chatcompletions.go new file mode 100644 index 0000000000..6aa243440b --- /dev/null +++ b/core/schemas/chatcompletions.go @@ -0,0 +1,323 @@ +package schemas + +import ( + "fmt" + + "github.com/bytedance/sonic" +) + +// Parameters + +type ChatParameters struct { + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens + LogitBias *map[string]float64 `json:"logit_bias,omitempty"` // Bias for logit values + 
LogProbs *bool `json:"logprobs,omitempty"` // Number of logprobs to return + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` // Maximum number of tokens to generate + Metadata *map[string]any `json:"metadata,omitempty"` // Metadata to be returned with the response + Modalities *[]string `json:"modalities,omitempty"` // Modalities to be returned with the response + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens + PromptCacheKey *string `json:"prompt_cache_key,omitempty"` // Prompt cache key + ReasoningEffort *string `json:"reasoning_effort,omitempty"` // "minimal" | "low" | "medium" | "high" + ResponseFormat *interface{} `json:"response_format,omitempty"` // Format for the response + SafetyIdentifier *string `json:"safety_identifier,omitempty"` // Safety identifier + Seed *int `json:"seed,omitempty"` + ServiceTier *string `json:"service_tier,omitempty"` + StreamOptions *ChatStreamOptions `json:"stream_options,omitempty"` + Stop *[]string `json:"stop,omitempty"` + Store *bool `json:"store,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` + TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling + ToolChoice *ChatToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool + Tools []ChatTool `json:"tools,omitempty"` // Tools to use + User *string `json:"user,omitempty"` // User identifier for tracking + Verbosity *string `json:"verbosity,omitempty"` // "low" | "medium" | "high" + + // Dynamic parameters that can be provider-specific, they are directly + // added to the request as is. 
+ ExtraParams map[string]interface{} `json:"-"` +} + +type ChatStreamOptions struct { + IncludeObfuscation *bool `json:"include_obfuscation,omitempty"` + IncludeUsage *bool `json:"include_usage,omitempty"` // Bifrost marks this as true by default +} + +// TOOLS + +type ChatToolType string + +const ( + ChatToolTypeFunction ChatToolType = "function" + ChatToolTypeCustom ChatToolType = "custom" +) + +type ChatTool struct { + Type ChatToolType `json:"type"` + Function *ChatToolFunction `json:"function,omitempty"` // Function definition + Custom *ChatToolCustom `json:"custom,omitempty"` // Custom tool definition +} +type ChatToolFunction struct { + Name string `json:"name"` // Name of the function + Description *string `json:"description,omitempty"` // Description of the parameters + Parameters *ToolFunctionParameters `json:"parameters,omitempty"` // A JSON schema object describing the parameters + Strict *bool `json:"strict,omitempty"` // Whether to enforce strict parameter validation +} + +// FunctionParameters represents the parameters for a function definition. 
+type ToolFunctionParameters struct { + Type string `json:"type"` // Type of the parameters + Description *string `json:"description,omitempty"` // Description of the parameters + Required []string `json:"required,omitempty"` // Required parameter names + Properties map[string]interface{} `json:"properties,omitempty"` // Parameter properties + Enum *[]string `json:"enum,omitempty"` // Enum values for the parameters +} + +type ChatToolCustom struct { + Format *ChatToolCustomFormat `json:"format,omitempty"` // The input format +} + +type ChatToolCustomFormat struct { + Type string `json:"type"` // always "text" + Grammar *ChatToolCustomGrammarFormat `json:"grammar,omitempty"` +} + +// ChatCustomToolGrammarFormat - A grammar defined by the user +type ChatToolCustomGrammarFormat struct { + Definition string `json:"definition"` // The grammar definition + Syntax string `json:"syntax"` // "lark" | "regex" +} + +// Combined tool choices for all providers, make sure to check the provider's +// documentation to see which tool choices are supported. 
+type ChatToolChoiceType string + +const ( + ChatToolChoiceTypeNone ChatToolChoiceType = "none" + ChatToolChoiceTypeAny ChatToolChoiceType = "any" + ChatToolChoiceTypeRequired ChatToolChoiceType = "required" + // ChatToolChoiceTypeFunction means a specific tool must be called + ChatToolChoiceTypeFunction ChatToolChoiceType = "function" + // ChatToolChoiceTypeAllowedTools means a specific tool must be called + ChatToolChoiceTypeAllowedTools ChatToolChoiceType = "allowed_tools" + // ChatToolChoiceTypeCustom means a custom tool must be called + ChatToolChoiceTypeCustom ChatToolChoiceType = "custom" +) + +type ChatToolChoiceStruct struct { + Type ChatToolChoiceType `json:"type"` // Type of tool choice + Function ChatToolChoiceFunction `json:"function,omitempty"` // Function to call if type is ToolChoiceTypeFunction + Custom ChatToolChoiceCustom `json:"custom,omitempty"` // Custom tool to call if type is ToolChoiceTypeCustom + AllowedTools ChatToolChoiceAllowedTools `json:"allowed_tools,omitempty"` // Allowed tools to call if type is ToolChoiceTypeAllowedTools +} + +type ChatToolChoice struct { + ChatToolChoiceStr *string + ChatToolChoiceStruct *ChatToolChoiceStruct +} + +// MarshalJSON implements custom JSON marshalling for ChatMessageContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (bc ChatToolChoice) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if bc.ChatToolChoiceStr != nil && bc.ChatToolChoiceStruct != nil { + return nil, fmt.Errorf("both ChatToolChoiceStr, ChatToolChoiceStruct are set; only one should be non-nil") + } + + if bc.ChatToolChoiceStr != nil { + return sonic.Marshal(bc.ChatToolChoiceStr) + } + if bc.ChatToolChoiceStruct != nil { + return sonic.Marshal(bc.ChatToolChoiceStruct) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. 
+// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (bc *ChatToolChoice) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var toolChoiceStr string + if err := sonic.Unmarshal(data, &toolChoiceStr); err == nil { + bc.ChatToolChoiceStr = &toolChoiceStr + bc.ChatToolChoiceStruct = nil + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var chatToolChoice ChatToolChoiceStruct + if err := sonic.Unmarshal(data, &chatToolChoice); err == nil { + bc.ChatToolChoiceStr = nil + bc.ChatToolChoiceStruct = &chatToolChoice + return nil + } + + return fmt.Errorf("tool_choice field is neither a string nor a ChatToolChoiceStruct object") +} + +type ChatToolChoiceFunction struct { + Name string `json:"name"` +} + +type ChatToolChoiceCustom struct { + Name string `json:"name"` +} + +type ChatToolChoiceAllowedTools struct { + Mode string `json:"mode"` // "auto" | "required" + Tools []ChatToolChoiceAllowedToolsTool `json:"tools"` +} + +type ChatToolChoiceAllowedToolsTool struct { + Type string `json:"type"` // "function" + Function ChatToolChoiceFunction `json:"function,omitempty"` +} + +// MESSAGES + +// ModelChatMessageRole represents the role of a chat message +type ChatMessageRole string + +const ( + ChatMessageRoleAssistant ChatMessageRole = "assistant" + ChatMessageRoleUser ChatMessageRole = "user" + ChatMessageRoleSystem ChatMessageRole = "system" + ChatMessageRoleTool ChatMessageRole = "tool" + ChatMessageRoleDeveloper ChatMessageRole = "developer" +) + +// ChatMessage represents a message in a chat conversation. 
+type ChatMessage struct { + Name *string `json:"name,omitempty"` // for chat completions + Role ChatMessageRole `json:"role,omitempty"` + Content ChatMessageContent `json:"content,omitempty"` + + // Embedded pointer structs - when non-nil, their exported fields are flattened into the top-level JSON object + // IMPORTANT: Only one of the following can be non-nil at a time, otherwise the JSON marshalling will override the common fields + *ChatToolMessage + *ChatAssistantMessage +} + +type ChatMessageContent struct { + ContentStr *string + ContentBlocks *[]ChatContentBlock +} + +// MarshalJSON implements custom JSON marshalling for ChatMessageContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (mc ChatMessageContent) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if mc.ContentStr != nil && mc.ContentBlocks != nil { + return nil, fmt.Errorf("both Content string and Content blocks are set; only one should be non-nil") + } + + if mc.ContentStr != nil { + return sonic.Marshal(*mc.ContentStr) + } + if mc.ContentBlocks != nil { + return sonic.Marshal(*mc.ContentBlocks) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. 
+func (mc *ChatMessageContent) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + mc.ContentStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []ChatContentBlock + if err := sonic.Unmarshal(data, &arrayContent); err == nil { + mc.ContentBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of Content blocks") +} + +type ChatContentBlockType string + +const ( + ChatContentBlockTypeText ChatContentBlockType = "text" + ChatContentBlockTypeImage ChatContentBlockType = "image_url" + ChatContentBlockTypeInputAudio ChatContentBlockType = "input_audio" + ChatContentBlockTypeFile ChatContentBlockType = "input_file" + ChatContentBlockTypeRefusal ChatContentBlockType = "refusal" +) + +type ChatContentBlock struct { + Type ChatContentBlockType `json:"type"` + Text *string `json:"text,omitempty"` + Refusal *string `json:"refusal,omitempty"` + ImageURLStruct *ChatInputImage `json:"image_url,omitempty"` + InputAudio *ChatInputAudio `json:"input_audio,omitempty"` + File *ChatInputFile `json:"file,omitempty"` +} + +// ImageContent represents image data in a message. +type ChatInputImage struct { + URL string `json:"url"` + Detail *string `json:"detail,omitempty"` +} + +// InputAudioStruct represents audio data in a message. +// Data carries the audio payload as a string (e.g., data URL or provider-accepted encoded content). +// Format is optional (e.g., "wav", "mp3"); when nil, providers may attempt auto-detection. 
+type ChatInputAudio struct { + Data string `json:"data"` + Format *string `json:"format,omitempty"` +} + +type ChatInputFile struct { + FileData *string `json:"file_data,omitempty"` // Base64 encoded file data + FileID *string `json:"file_id,omitempty"` // Reference to uploaded file + Filename *string `json:"filename,omitempty"` // Name of the file +} + +type ChatToolMessage struct { + ToolCallID *string `json:"tool_call_id,omitempty"` +} + +type ChatAssistantMessage struct { + Refusal *string `json:"refusal,omitempty"` + Annotations []ChatAssistantMessageAnnotation `json:"annotations,omitempty"` + ToolCalls *[]ChatAssistantMessageToolCall `json:"tool_calls,omitempty"` +} + +type ChatAssistantMessageAnnotation struct { + Type string `json:"type"` + Citation ChatAssistantMessageAnnotationCitation `json:"url_citation"` +} + +// Citation represents a citation in a response. +type ChatAssistantMessageAnnotationCitation struct { + StartIndex int `json:"start_index"` + EndIndex int `json:"end_index"` + Title string `json:"title"` + URL *string `json:"url,omitempty"` + Sources *interface{} `json:"sources,omitempty"` + Type *string `json:"type,omitempty"` +} + +// ToolCall represents a tool call in a message +type ChatAssistantMessageToolCall struct { + Type *string `json:"type,omitempty"` + ID *string `json:"id,omitempty"` + Function ChatAssistantMessageToolCallFunction `json:"function"` +} + +// FunctionCall represents a call to a function. +type ChatAssistantMessageToolCallFunction struct { + Name *string `json:"name"` + Arguments string `json:"arguments"` // stringified json as retured by OpenAI, might not be a valid JSON always +} diff --git a/core/schemas/embedding.go b/core/schemas/embedding.go new file mode 100644 index 0000000000..b71cf706d8 --- /dev/null +++ b/core/schemas/embedding.go @@ -0,0 +1,91 @@ +package schemas + +import ( + "fmt" + + "github.com/bytedance/sonic" +) + +// EmbeddingInput represents the input for an embedding request. 
+type EmbeddingInput struct { + Text *string + Texts []string + Embedding []int + Embeddings [][]int +} + +func (e *EmbeddingInput) MarshalJSON() ([]byte, error) { + // enforce one-of + set := 0 + if e.Text != nil { + set++ + } + if e.Texts != nil { + set++ + } + if e.Embedding != nil { + set++ + } + if e.Embeddings != nil { + set++ + } + if set == 0 { + return nil, fmt.Errorf("embedding input is empty") + } + if set > 1 { + return nil, fmt.Errorf("embedding input must set exactly one of: text, texts, embedding, embeddings") + } + + if e.Text != nil { + return sonic.Marshal(*e.Text) + } + if e.Texts != nil { + return sonic.Marshal(e.Texts) + } + if e.Embedding != nil { + return sonic.Marshal(e.Embedding) + } + if e.Embeddings != nil { + return sonic.Marshal(e.Embeddings) + } + + return nil, fmt.Errorf("invalid embedding input") +} + +func (e *EmbeddingInput) UnmarshalJSON(data []byte) error { + // Try string + var s string + if err := sonic.Unmarshal(data, &s); err == nil { + e.Text = &s + return nil + } + // Try []string + var ss []string + if err := sonic.Unmarshal(data, &ss); err == nil { + e.Texts = ss + return nil + } + // Try []int + var i []int + if err := sonic.Unmarshal(data, &i); err == nil { + e.Embedding = i + return nil + } + // Try [][]int + var i2 [][]int + if err := sonic.Unmarshal(data, &i2); err == nil { + e.Embeddings = i2 + return nil + } + + return fmt.Errorf("unsupported embedding input shape") +} + +type EmbeddingParameters struct { + EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") + Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output + + // Dynamic parameters that can be provider-specific, they are directly + // added to the request as is. 
+ ExtraParams map[string]interface{} `json:"-"` +} diff --git a/core/schemas/mux.go b/core/schemas/mux.go new file mode 100644 index 0000000000..f36cfe3513 --- /dev/null +++ b/core/schemas/mux.go @@ -0,0 +1,865 @@ +package schemas + +// ============================================================================= +// BIDIRECTIONAL CONVERSION METHODS +// ============================================================================= +// +// This section contains methods for converting between Chat Completions API +// and Responses API formats. These methods are attached to the structs themselves +// for easy conversion in both directions. +// +// Key Features: +// 1. Bidirectional: Convert to and from both formats +// 2. Data preservation: All relevant data is preserved during conversion +// 3. Aggregation/Spreading: Handle tool messages properly for each format +// 4. Validation: Ensure data integrity during conversion +// +// ============================================================================= + +// ============================================================================= +// TOOL CONVERSION METHODS +// ============================================================================= + +// ToResponsesTool converts a ChatTool to ResponsesTool format +func (ct *ChatTool) ToResponsesTool() *ResponsesTool { + if ct == nil { + return &ResponsesTool{} + } + + rt := &ResponsesTool{ + Type: string(ct.Type), + } + + // Convert function tools + if ct.Type == ChatToolTypeFunction && ct.Function != nil { + rt.Name = &ct.Function.Name + rt.Description = ct.Function.Description + + // Create ResponsesToolFunction if needed + if ct.Function.Parameters != nil || ct.Function.Strict != nil { + rt.ResponsesToolFunction = &ResponsesToolFunction{ + Parameters: ct.Function.Parameters, + Strict: ct.Function.Strict, + } + } + } + + // Convert custom tools + if ct.Type == ChatToolTypeCustom && ct.Custom != nil { + if ct.Custom.Format != nil { + rt.ResponsesToolCustom = 
&ResponsesToolCustom{ + Format: &ResponsesToolCustomFormat{ + Type: ct.Custom.Format.Type, + }, + } + if ct.Custom.Format.Grammar != nil { + rt.ResponsesToolCustom.Format.Definition = &ct.Custom.Format.Grammar.Definition + rt.ResponsesToolCustom.Format.Syntax = &ct.Custom.Format.Grammar.Syntax + } + } + } + + return rt +} + +// ToChatTool converts a ResponsesTool to ChatTool format +func (rt *ResponsesTool) ToChatTool() *ChatTool { + if rt == nil { + return &ChatTool{} + } + + ct := &ChatTool{ + Type: ChatToolType(rt.Type), + } + + // Convert function tools + if rt.Type == "function" { + ct.Function = &ChatToolFunction{} + + if rt.Name != nil { + ct.Function.Name = *rt.Name + } + if rt.Description != nil { + ct.Function.Description = rt.Description + } + if rt.ResponsesToolFunction != nil { + ct.Function.Parameters = rt.ResponsesToolFunction.Parameters + ct.Function.Strict = rt.ResponsesToolFunction.Strict + } + } + + // Convert custom tools + if rt.Type == "custom" && rt.ResponsesToolCustom != nil { + ct.Custom = &ChatToolCustom{} + if rt.ResponsesToolCustom.Format != nil { + ct.Custom.Format = &ChatToolCustomFormat{ + Type: rt.ResponsesToolCustom.Format.Type, + } + if rt.ResponsesToolCustom.Format.Definition != nil && rt.ResponsesToolCustom.Format.Syntax != nil { + ct.Custom.Format.Grammar = &ChatToolCustomGrammarFormat{ + Definition: *rt.ResponsesToolCustom.Format.Definition, + Syntax: *rt.ResponsesToolCustom.Format.Syntax, + } + } + } + } + + return ct +} + +// ============================================================================= +// TOOL CHOICE CONVERSION METHODS +// ============================================================================= + +// ToResponsesToolChoice converts a ChatToolChoice to ResponsesToolChoice format +func (ctc *ChatToolChoice) ToResponsesToolChoice() *ResponsesToolChoice { + if ctc == nil { + return &ResponsesToolChoice{} + } + + rtc := &ResponsesToolChoice{} + + // Handle string choice (e.g., "none", "auto", "required") + if 
ctc.ChatToolChoiceStr != nil { + rtc.ResponsesToolChoiceStr = ctc.ChatToolChoiceStr + return rtc + } + + // Handle structured choice + if ctc.ChatToolChoiceStruct != nil { + rtc.ResponsesToolChoiceStruct = &ResponsesToolChoiceStruct{ + Type: ResponsesToolChoiceType(ctc.ChatToolChoiceStruct.Type), + } + + switch ctc.ChatToolChoiceStruct.Type { + case ChatToolChoiceTypeNone, ChatToolChoiceTypeAny, ChatToolChoiceTypeRequired: + // These map to mode field + modeStr := string(ctc.ChatToolChoiceStruct.Type) + rtc.ResponsesToolChoiceStruct.Mode = &modeStr + + case ChatToolChoiceTypeFunction: + // Map function choice + if ctc.ChatToolChoiceStruct.Function.Name != "" { + rtc.ResponsesToolChoiceStruct.Name = &ctc.ChatToolChoiceStruct.Function.Name + } + + case ChatToolChoiceTypeAllowedTools: + // Map allowed tools + if len(ctc.ChatToolChoiceStruct.AllowedTools.Tools) > 0 { + tools := make([]ResponsesToolChoiceAllowedToolDef, len(ctc.ChatToolChoiceStruct.AllowedTools.Tools)) + for i, tool := range ctc.ChatToolChoiceStruct.AllowedTools.Tools { + tools[i] = ResponsesToolChoiceAllowedToolDef{ + Type: tool.Type, + } + if tool.Function.Name != "" { + name := tool.Function.Name + tools[i].Name = &name + } + } + rtc.ResponsesToolChoiceStruct.Tools = tools + } + // Copy the mode (e.g., "auto", "required") + if ctc.ChatToolChoiceStruct.AllowedTools.Mode != "" { + mode := ctc.ChatToolChoiceStruct.AllowedTools.Mode + rtc.ResponsesToolChoiceStruct.Mode = &mode + } + + case ChatToolChoiceTypeCustom: + // Map custom choice + if ctc.ChatToolChoiceStruct.Custom.Name != "" { + rtc.ResponsesToolChoiceStruct.Name = &ctc.ChatToolChoiceStruct.Custom.Name + } + } + } + + return rtc +} + +// ToChatToolChoice converts a ResponsesToolChoice to ChatToolChoice format +func (rtc *ResponsesToolChoice) ToChatToolChoice() *ChatToolChoice { + if rtc == nil { + return &ChatToolChoice{} + } + + ctc := &ChatToolChoice{} + + // Handle string choice + if rtc.ResponsesToolChoiceStr != nil { + 
ctc.ChatToolChoiceStr = rtc.ResponsesToolChoiceStr + return ctc + } + + // Handle structured choice + if rtc.ResponsesToolChoiceStruct != nil { + ctc.ChatToolChoiceStruct = &ChatToolChoiceStruct{ + Type: ChatToolChoiceType(rtc.ResponsesToolChoiceStruct.Type), + } + + // Handle mode-based choices (none, auto, required) + if rtc.ResponsesToolChoiceStruct.Mode != nil { + switch *rtc.ResponsesToolChoiceStruct.Mode { + case "none": + ctc.ChatToolChoiceStruct.Type = ChatToolChoiceTypeNone + case "auto": + ctc.ChatToolChoiceStruct.Type = ChatToolChoiceTypeAny + case "required": + ctc.ChatToolChoiceStruct.Type = ChatToolChoiceTypeRequired + } + } + + // Handle function choice + if rtc.ResponsesToolChoiceStruct.Type == ResponsesToolChoiceTypeFunction && rtc.ResponsesToolChoiceStruct.Name != nil { + ctc.ChatToolChoiceStruct.Function = ChatToolChoiceFunction{ + Name: *rtc.ResponsesToolChoiceStruct.Name, + } + } + + // Handle custom choice + if rtc.ResponsesToolChoiceStruct.Type == ResponsesToolChoiceTypeCustom && rtc.ResponsesToolChoiceStruct.Name != nil { + ctc.ChatToolChoiceStruct.Custom = ChatToolChoiceCustom{ + Name: *rtc.ResponsesToolChoiceStruct.Name, + } + } + + // Handle allowed tools + if len(rtc.ResponsesToolChoiceStruct.Tools) > 0 { + ctc.ChatToolChoiceStruct.Type = ChatToolChoiceTypeAllowedTools + tools := make([]ChatToolChoiceAllowedToolsTool, len(rtc.ResponsesToolChoiceStruct.Tools)) + for i, tool := range rtc.ResponsesToolChoiceStruct.Tools { + tools[i] = ChatToolChoiceAllowedToolsTool{ + Type: tool.Type, + } + if tool.Name != nil { + tools[i].Function = ChatToolChoiceFunction{Name: *tool.Name} + } + } + // Copy the mode if present, otherwise default to "auto" + mode := "auto" + if rtc.ResponsesToolChoiceStruct.Mode != nil && *rtc.ResponsesToolChoiceStruct.Mode != "" { + mode = *rtc.ResponsesToolChoiceStruct.Mode + } + ctc.ChatToolChoiceStruct.AllowedTools = ChatToolChoiceAllowedTools{ + Mode: mode, + Tools: tools, + } + } + + return ctc + } + + return nil +} + 
+// ============================================================================= +// MESSAGE CONVERSION METHODS +// ============================================================================= + +// ToResponsesMessages converts a ChatMessage to one or more ResponsesMessages +// This handles the expansion of assistant messages with tool calls into separate function_call messages +func (cm *ChatMessage) ToResponsesMessages() []ResponsesMessage { + if cm == nil { + return []ResponsesMessage{} + } + + var messages []ResponsesMessage + + // Check if this is an assistant message with multiple tool calls that need expansion + if cm.ChatAssistantMessage != nil && cm.ChatAssistantMessage.ToolCalls != nil && len(*cm.ChatAssistantMessage.ToolCalls) > 0 { + // Expand multiple tool calls into separate function_call items + for _, tc := range *cm.ChatAssistantMessage.ToolCalls { + messageType := ResponsesMessageTypeFunctionCall + + var callID *string + if tc.ID != nil && *tc.ID != "" { + callID = tc.ID + } + + var namePtr *string + if tc.Function.Name != nil && *tc.Function.Name != "" { + namePtr = tc.Function.Name + } + + // Create a copy of the arguments string to avoid range loop variable capture + var argumentsPtr *string + if tc.Function.Arguments != "" { + argumentsPtr = Ptr(tc.Function.Arguments) + } + + rm := ResponsesMessage{ + Type: &messageType, + Role: Ptr(ResponsesInputMessageRoleAssistant), + ResponsesToolMessage: &ResponsesToolMessage{ + CallID: callID, + Name: namePtr, + Arguments: argumentsPtr, + }, + } + + messages = append(messages, rm) + } + return messages + } + + // Regular message conversion + messageType := ResponsesMessageTypeMessage + role := ResponsesInputMessageRoleUser + + // Determine message type and role + switch cm.Role { + case ChatMessageRoleAssistant: + role = ResponsesInputMessageRoleAssistant + // Check for refusal + if cm.ChatAssistantMessage != nil && cm.ChatAssistantMessage.Refusal != nil { + messageType = ResponsesMessageTypeRefusal + 
} + case ChatMessageRoleUser: + role = ResponsesInputMessageRoleUser + case ChatMessageRoleSystem: + role = ResponsesInputMessageRoleSystem + case ChatMessageRoleTool: + messageType = ResponsesMessageTypeFunctionCallOutput + role = ResponsesInputMessageRoleUser // Tool messages are typically user role in responses + case ChatMessageRoleDeveloper: + role = ResponsesInputMessageRoleDeveloper + } + + rm := ResponsesMessage{ + Type: &messageType, + Role: &role, + } + + // Handle refusal content specifically - use content blocks with ResponsesOutputMessageContentRefusal + if messageType == ResponsesMessageTypeRefusal && cm.ChatAssistantMessage != nil && cm.ChatAssistantMessage.Refusal != nil { + refusalBlock := ResponsesMessageContentBlock{ + Type: ResponsesOutputMessageContentTypeRefusal, + ResponsesOutputMessageContentRefusal: &ResponsesOutputMessageContentRefusal{ + Refusal: *cm.ChatAssistantMessage.Refusal, + }, + } + rm.Content = &ResponsesMessageContent{ + ContentBlocks: &[]ResponsesMessageContentBlock{refusalBlock}, + } + } else if cm.Content.ContentStr != nil { + // Convert regular string content + rm.Content = &ResponsesMessageContent{ + ContentStr: cm.Content.ContentStr, + } + } else if cm.Content.ContentBlocks != nil { + // Convert content blocks + responseBlocks := make([]ResponsesMessageContentBlock, len(*cm.Content.ContentBlocks)) + for i, block := range *cm.Content.ContentBlocks { + responseBlocks[i] = ResponsesMessageContentBlock{ + Type: ResponsesMessageContentBlockType(block.Type), + Text: block.Text, + } + + // Convert specific block types + if block.ImageURLStruct != nil { + responseBlocks[i].ResponsesInputMessageContentBlockImage = &ResponsesInputMessageContentBlockImage{ + ImageURL: &block.ImageURLStruct.URL, + Detail: block.ImageURLStruct.Detail, + } + } + if block.File != nil { + responseBlocks[i].ResponsesInputMessageContentBlockFile = &ResponsesInputMessageContentBlockFile{ + FileData: block.File.FileData, + Filename: block.File.Filename, + } + 
responseBlocks[i].FileID = block.File.FileID + } + if block.InputAudio != nil { + format := "" + if block.InputAudio.Format != nil { + format = *block.InputAudio.Format + } + responseBlocks[i].Audio = &ResponsesInputMessageContentBlockAudio{ + Data: block.InputAudio.Data, + Format: format, + } + } + } + rm.Content = &ResponsesMessageContent{ + ContentBlocks: &responseBlocks, + } + } + + // Handle tool messages + if cm.ChatToolMessage != nil { + rm.ResponsesToolMessage = &ResponsesToolMessage{} + if cm.ChatToolMessage.ToolCallID != nil { + rm.ResponsesToolMessage.CallID = cm.ChatToolMessage.ToolCallID + } + + // If tool output content exists, add it to function_call_output + if rm.Content != nil && rm.Content.ContentStr != nil && *rm.Content.ContentStr != "" { + rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput = &ResponsesFunctionToolCallOutput{ + ResponsesFunctionToolCallOutputStr: rm.Content.ContentStr, + } + } + } + + messages = append(messages, rm) + return messages +} + +// ToChatMessages converts a slice of ResponsesMessages back to ChatMessages +// This handles the aggregation of function_call messages back into assistant messages with tool calls +func ToChatMessages(rms []ResponsesMessage) []ChatMessage { + if len(rms) == 0 { + return []ChatMessage{} + } + + var chatMessages []ChatMessage + var currentToolCalls []ChatAssistantMessageToolCall + + for _, rm := range rms { + if rm.Type != nil && *rm.Type == ResponsesMessageTypeReasoning { + continue + } + + // Handle function_call messages - collect them for aggregation + if rm.Type != nil && *rm.Type == ResponsesMessageTypeFunctionCall { + if rm.ResponsesToolMessage != nil { + tc := ChatAssistantMessageToolCall{ + Type: Ptr("function"), + } + + if rm.ResponsesToolMessage.CallID != nil { + tc.ID = rm.ResponsesToolMessage.CallID + } + + tc.Function = ChatAssistantMessageToolCallFunction{} + if rm.ResponsesToolMessage.Name != nil { + tc.Function.Name = rm.ResponsesToolMessage.Name + } + if 
rm.ResponsesToolMessage.Arguments != nil { + tc.Function.Arguments = *rm.ResponsesToolMessage.Arguments + } + + currentToolCalls = append(currentToolCalls, tc) + } + continue + } + + // If we have collected tool calls, create an assistant message with them + if len(currentToolCalls) > 0 { + // Create a copy of the slice to avoid shared slice header issues + toolCallsCopy := append([]ChatAssistantMessageToolCall(nil), currentToolCalls...) + chatMessages = append(chatMessages, ChatMessage{ + Role: ChatMessageRoleAssistant, + ChatAssistantMessage: &ChatAssistantMessage{ + ToolCalls: &toolCallsCopy, + }, + }) + currentToolCalls = nil // Reset for next batch + } + + // Convert regular message + cm := ChatMessage{} + + // Set role + if rm.Role != nil { + switch *rm.Role { + case ResponsesInputMessageRoleAssistant: + cm.Role = ChatMessageRoleAssistant + case ResponsesInputMessageRoleUser: + cm.Role = ChatMessageRoleUser + case ResponsesInputMessageRoleSystem: + cm.Role = ChatMessageRoleSystem + case ResponsesInputMessageRoleDeveloper: + cm.Role = ChatMessageRoleDeveloper + } + } + + // Handle special message types + if rm.Type != nil { + switch *rm.Type { + case ResponsesMessageTypeFunctionCallOutput: + cm.Role = ChatMessageRoleTool + if rm.ResponsesToolMessage != nil && rm.ResponsesToolMessage.CallID != nil { + cm.ChatToolMessage = &ChatToolMessage{ + ToolCallID: rm.ResponsesToolMessage.CallID, + } + + // Extract content from ResponsesFunctionToolCallOutput if present + // This is needed because OpenAI Responses API uses an "output" field + // which is stored in ResponsesFunctionToolCallOutput + if rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput != nil { + if rm.Content == nil { + rm.Content = &ResponsesMessageContent{} + } + // If Content is not already set, extract from ResponsesFunctionToolCallOutput + if rm.Content.ContentStr == nil && rm.Content.ContentBlocks == nil { + if 
rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr != nil { + rm.Content.ContentStr = rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr + } else if rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks != nil { + rm.Content.ContentBlocks = rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks + } + } + } + } + case ResponsesMessageTypeRefusal: + cm.ChatAssistantMessage = &ChatAssistantMessage{} + // Extract refusal from content blocks or ContentStr + if rm.Content != nil { + if rm.Content.ContentBlocks != nil { + // Look for refusal content block + for _, block := range *rm.Content.ContentBlocks { + if block.Type == ResponsesOutputMessageContentTypeRefusal && block.ResponsesOutputMessageContentRefusal != nil { + refusalText := block.ResponsesOutputMessageContentRefusal.Refusal + cm.ChatAssistantMessage.Refusal = &refusalText + break + } + } + } else if rm.Content.ContentStr != nil { + // Fallback to ContentStr for backward compatibility + cm.ChatAssistantMessage.Refusal = rm.Content.ContentStr + } + } + } + } + + // Convert content (skip for refusal messages since refusal is already extracted) + if rm.Content != nil && (rm.Type == nil || *rm.Type != ResponsesMessageTypeRefusal) { + if rm.Content.ContentStr != nil { + cm.Content = ChatMessageContent{ + ContentStr: rm.Content.ContentStr, + } + } else if rm.Content.ContentBlocks != nil { + chatBlocks := make([]ChatContentBlock, len(*rm.Content.ContentBlocks)) + for i, block := range *rm.Content.ContentBlocks { + // Map ResponsesMessageContentBlockType to ChatContentBlockType + var chatBlockType ChatContentBlockType + switch block.Type { + case ResponsesInputMessageContentBlockTypeText: + chatBlockType = ChatContentBlockTypeText // "input_text" -> "text" + case ResponsesInputMessageContentBlockTypeImage: + chatBlockType = ChatContentBlockTypeImage // 
"input_image" -> "image_url" + case ResponsesInputMessageContentBlockTypeFile: + chatBlockType = ChatContentBlockTypeFile // "input_file" -> "input_file" (same) + case ResponsesInputMessageContentBlockTypeAudio: + chatBlockType = ChatContentBlockTypeInputAudio // "input_audio" -> "input_audio" (same) + default: + // For unknown types, fall back to direct conversion + chatBlockType = ChatContentBlockType(block.Type) + } + + chatBlocks[i] = ChatContentBlock{ + Type: chatBlockType, + Text: block.Text, + } + + // Convert specific block types + if block.ResponsesInputMessageContentBlockImage != nil { + chatBlocks[i].ImageURLStruct = &ChatInputImage{ + Detail: block.ResponsesInputMessageContentBlockImage.Detail, + } + if block.ResponsesInputMessageContentBlockImage.ImageURL != nil { + chatBlocks[i].ImageURLStruct.URL = *block.ResponsesInputMessageContentBlockImage.ImageURL + } + } + if block.ResponsesInputMessageContentBlockFile != nil { + chatBlocks[i].File = &ChatInputFile{ + FileData: block.ResponsesInputMessageContentBlockFile.FileData, + Filename: block.ResponsesInputMessageContentBlockFile.Filename, + FileID: block.FileID, + } + } + if block.Audio != nil { + chatBlocks[i].InputAudio = &ChatInputAudio{ + Data: block.Audio.Data, + } + if block.Audio.Format != "" { + chatBlocks[i].InputAudio.Format = &block.Audio.Format + } + } + } + cm.Content = ChatMessageContent{ + ContentBlocks: &chatBlocks, + } + } + } + + chatMessages = append(chatMessages, cm) + } + + // Handle any remaining tool calls at the end + if len(currentToolCalls) > 0 { + // Create a copy of the slice to avoid shared slice header issues + toolCallsCopy := append([]ChatAssistantMessageToolCall(nil), currentToolCalls...) 
+ chatMessages = append(chatMessages, ChatMessage{ + Role: ChatMessageRoleAssistant, + ChatAssistantMessage: &ChatAssistantMessage{ + ToolCalls: &toolCallsCopy, + }, + }) + } + + return chatMessages +} + +// ============================================================================= +// REQUEST CONVERSION METHODS +// ============================================================================= + +// ToResponsesRequest converts a BifrostChatRequest to BifrostResponsesRequest format +func (bcr *BifrostChatRequest) ToResponsesRequest() *BifrostResponsesRequest { + if bcr == nil { + return &BifrostResponsesRequest{} + } + + brr := &BifrostResponsesRequest{ + Provider: bcr.Provider, + Model: bcr.Model, + Fallbacks: bcr.Fallbacks, // Copy fallbacks as-is + } + + // Convert Input messages using existing ChatMessage.ToResponsesMessages() + var allResponsesMessages []ResponsesMessage + for _, chatMsg := range bcr.Input { + responsesMessages := chatMsg.ToResponsesMessages() + allResponsesMessages = append(allResponsesMessages, responsesMessages...) 
+ } + brr.Input = allResponsesMessages + + // Convert Parameters + if bcr.Params != nil { + brr.Params = &ResponsesParameters{ + // Map common fields + ParallelToolCalls: bcr.Params.ParallelToolCalls, + PromptCacheKey: bcr.Params.PromptCacheKey, + SafetyIdentifier: bcr.Params.SafetyIdentifier, + ServiceTier: bcr.Params.ServiceTier, + Store: bcr.Params.Store, + Temperature: bcr.Params.Temperature, + TopLogProbs: bcr.Params.TopLogProbs, + TopP: bcr.Params.TopP, + ExtraParams: bcr.Params.ExtraParams, + + // Map specific fields + MaxOutputTokens: bcr.Params.MaxCompletionTokens, // max_completion_tokens -> max_output_tokens + Metadata: bcr.Params.Metadata, + } + + // Convert StreamOptions + if bcr.Params.StreamOptions != nil { + brr.Params.StreamOptions = &ResponsesStreamOptions{ + IncludeObfuscation: bcr.Params.StreamOptions.IncludeObfuscation, + } + } + + // Convert Tools using existing ChatTool.ToResponsesTool() + if len(bcr.Params.Tools) > 0 { + responsesTools := make([]ResponsesTool, 0, len(bcr.Params.Tools)) + for _, chatTool := range bcr.Params.Tools { + responsesTool := chatTool.ToResponsesTool() + responsesTools = append(responsesTools, *responsesTool) + } + brr.Params.Tools = responsesTools + } + + // Convert ToolChoice using existing ChatToolChoice.ToResponsesToolChoice() + if bcr.Params.ToolChoice != nil { + responsesToolChoice := bcr.Params.ToolChoice.ToResponsesToolChoice() + brr.Params.ToolChoice = responsesToolChoice + } + + // Handle Reasoning from reasoning_effort + if bcr.Params.ReasoningEffort != nil { + brr.Params.Reasoning = &ResponsesParametersReasoning{ + Effort: bcr.Params.ReasoningEffort, + } + } + + // Handle Verbosity + if bcr.Params.Verbosity != nil { + if brr.Params.Text == nil { + brr.Params.Text = &ResponsesTextConfig{} + } + brr.Params.Text.Verbosity = bcr.Params.Verbosity + } + } + + return brr +} + +// ToChatRequest converts a BifrostResponsesRequest to BifrostChatRequest format +func (brr *BifrostResponsesRequest) ToChatRequest() 
*BifrostChatRequest { + if brr == nil { + return &BifrostChatRequest{} + } + + bcr := &BifrostChatRequest{ + Provider: brr.Provider, + Model: brr.Model, + Fallbacks: brr.Fallbacks, // Copy fallbacks as-is + } + + // Convert Input messages using existing ToChatMessages() + bcr.Input = ToChatMessages(brr.Input) + + // Convert Parameters + if brr.Params != nil { + bcr.Params = &ChatParameters{ + // Map common fields + ParallelToolCalls: brr.Params.ParallelToolCalls, + PromptCacheKey: brr.Params.PromptCacheKey, + SafetyIdentifier: brr.Params.SafetyIdentifier, + ServiceTier: brr.Params.ServiceTier, + Store: brr.Params.Store, + Temperature: brr.Params.Temperature, + TopLogProbs: brr.Params.TopLogProbs, + TopP: brr.Params.TopP, + ExtraParams: brr.Params.ExtraParams, + + // Map specific fields + MaxCompletionTokens: brr.Params.MaxOutputTokens, // max_output_tokens -> max_completion_tokens + Metadata: brr.Params.Metadata, + } + + // Convert StreamOptions + if brr.Params.StreamOptions != nil { + bcr.Params.StreamOptions = &ChatStreamOptions{ + IncludeObfuscation: brr.Params.StreamOptions.IncludeObfuscation, + IncludeUsage: Ptr(true), // Default for Chat API + } + } + + // Convert Tools using existing ResponsesTool.ToChatTool() + if len(brr.Params.Tools) > 0 { + chatTools := make([]ChatTool, 0, len(brr.Params.Tools)) + for _, responsesTool := range brr.Params.Tools { + chatTool := responsesTool.ToChatTool() + chatTools = append(chatTools, *chatTool) + } + bcr.Params.Tools = chatTools + } + + // Convert ToolChoice using existing ResponsesToolChoice.ToChatToolChoice() + if brr.Params.ToolChoice != nil { + chatToolChoice := brr.Params.ToolChoice.ToChatToolChoice() + bcr.Params.ToolChoice = chatToolChoice + } + + // Handle ReasoningEffort from Reasoning + if brr.Params.Reasoning != nil && brr.Params.Reasoning.Effort != nil { + bcr.Params.ReasoningEffort = brr.Params.Reasoning.Effort + } + + // Handle Verbosity from Text config + if brr.Params.Text != nil && 
brr.Params.Text.Verbosity != nil { + bcr.Params.Verbosity = brr.Params.Text.Verbosity + } + } + + return bcr +} + +// ============================================================================= +// RESPONSE CONVERSION METHODS +// ============================================================================= + +// ToResponsesOnly converts the BifrostResponse to use only Responses API format +// This converts Chat-style fields (Choices) to embedded ResponsesResponse format +func (br *BifrostResponse) ToResponsesOnly() { + // If ResponsesResponse already exists, keep it and clear Chat fields + if br.ResponsesResponse != nil { + br.Choices = nil + return + } + + // Create ResponsesResponse from Chat fields + br.ResponsesResponse = &ResponsesResponse{ + CreatedAt: br.Created, + } + + br.Created = 0 + + // Convert Choices to Output messages + var outputMessages []ResponsesMessage + for _, choice := range br.Choices { + if choice.BifrostNonStreamResponseChoice != nil { + // Convert ChatMessage to ResponsesMessages + responsesMessages := choice.BifrostNonStreamResponseChoice.Message.ToResponsesMessages() + outputMessages = append(outputMessages, responsesMessages...) 
+ } + // Note: Stream choices would need different handling if needed + } + + if len(outputMessages) > 0 { + br.ResponsesResponse.Output = outputMessages + } + + // Convert Usage if needed + if br.Usage != nil { + if br.Usage.ResponsesExtendedResponseUsage == nil { + br.Usage.ResponsesExtendedResponseUsage = &ResponsesExtendedResponseUsage{ + InputTokens: br.Usage.PromptTokens, + OutputTokens: br.Usage.CompletionTokens, + } + + if br.Usage.TotalTokens == 0 { + br.Usage.TotalTokens = br.Usage.PromptTokens + br.Usage.CompletionTokens + } + + br.Usage.PromptTokens = 0 + br.Usage.CompletionTokens = 0 + } + } + + // Clear Chat fields after conversion + br.Choices = nil + br.ExtraFields.RequestType = ResponsesRequest +} + +// ToChatOnly converts the BifrostResponse to use only Chat API format +// This converts embedded ResponsesResponse format to Chat-style fields (Choices) +func (br *BifrostResponse) ToChatOnly() { + if br == nil { + return + } + + // If Choices already exist, keep them and clear ResponsesResponse + if len(br.Choices) > 0 { + br.ResponsesResponse = nil + return + } + + // Create Choices from ResponsesResponse + if br.ResponsesResponse != nil && len(br.ResponsesResponse.Output) > 0 { + // Convert ResponsesMessages back to ChatMessages + chatMessages := ToChatMessages(br.ResponsesResponse.Output) + + // Create choices from chat messages + choices := make([]BifrostChatResponseChoice, 0, len(chatMessages)) + for i, chatMsg := range chatMessages { + choice := BifrostChatResponseChoice{ + Index: i, + BifrostNonStreamResponseChoice: &BifrostNonStreamResponseChoice{ + Message: chatMsg, + }, + } + choices = append(choices, choice) + } + + br.Choices = choices + + // Update Created timestamp from ResponsesResponse + if br.ResponsesResponse.CreatedAt > 0 { + br.Created = br.ResponsesResponse.CreatedAt + } + } + + // Convert Usage if needed + if br.Usage != nil && br.Usage.ResponsesExtendedResponseUsage != nil { + // Map Responses usage to Chat usage + 
br.Usage.PromptTokens = br.Usage.ResponsesExtendedResponseUsage.InputTokens + br.Usage.CompletionTokens = br.Usage.ResponsesExtendedResponseUsage.OutputTokens + if br.Usage.TotalTokens == 0 { + br.Usage.TotalTokens = br.Usage.PromptTokens + br.Usage.CompletionTokens + } + } + + // Clear ResponsesResponse after conversion + br.ResponsesResponse = nil +} diff --git a/core/schemas/provider.go b/core/schemas/provider.go index ca8d8a2a5c..0452397888 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -96,27 +96,27 @@ type AllowedRequests struct { } // IsOperationAllowed checks if a specific operation is allowed -func (ar *AllowedRequests) IsOperationAllowed(operation Operation) bool { +func (ar *AllowedRequests) IsOperationAllowed(operation RequestType) bool { if ar == nil { return true // Default to allowed if no restrictions } switch operation { - case OperationTextCompletion: + case TextCompletionRequest: return ar.TextCompletion - case OperationChatCompletion: + case ChatCompletionRequest: return ar.ChatCompletion - case OperationChatCompletionStream: + case ChatCompletionStreamRequest: return ar.ChatCompletionStream - case OperationEmbedding: + case EmbeddingRequest: return ar.Embedding - case OperationSpeech: + case SpeechRequest: return ar.Speech - case OperationSpeechStream: + case SpeechStreamRequest: return ar.SpeechStream - case OperationTranscription: + case TranscriptionRequest: return ar.Transcription - case OperationTranscriptionStream: + case TranscriptionStreamRequest: return ar.TranscriptionStream default: return false // Default to not allowed for unknown operations @@ -130,7 +130,7 @@ type CustomProviderConfig struct { } // IsOperationAllowed checks if a specific operation is allowed for this custom provider -func (cpc *CustomProviderConfig) IsOperationAllowed(operation Operation) bool { +func (cpc *CustomProviderConfig) IsOperationAllowed(operation RequestType) bool { if cpc == nil || cpc.AllowedRequests == nil { return true // 
Default to allowed if no restrictions } @@ -150,19 +150,6 @@ type ProviderConfig struct { CustomProviderConfig *CustomProviderConfig `json:"custom_provider_config,omitempty"` } -type Operation string - -const ( - OperationTextCompletion Operation = "text_completion" - OperationChatCompletion Operation = "chat_completion" - OperationChatCompletionStream Operation = "chat_completion_stream" - OperationEmbedding Operation = "embedding" - OperationSpeech Operation = "speech" - OperationSpeechStream Operation = "speech_stream" - OperationTranscription Operation = "transcription" - OperationTranscriptionStream Operation = "transcription_stream" -) - func (config *ProviderConfig) CheckAndSetDefaults() { if config.ConcurrencyAndBufferSize.Concurrency == 0 { config.ConcurrencyAndBufferSize.Concurrency = DefaultConcurrency @@ -203,19 +190,23 @@ type Provider interface { // GetProviderKey returns the provider's identifier GetProviderKey() ModelProvider // TextCompletion performs a text completion request - TextCompletion(ctx context.Context, key Key, input *BifrostRequest) (*BifrostResponse, *BifrostError) + TextCompletion(ctx context.Context, key Key, request *BifrostTextCompletionRequest) (*BifrostResponse, *BifrostError) // ChatCompletion performs a chat completion request - ChatCompletion(ctx context.Context, key Key, input *BifrostRequest) (*BifrostResponse, *BifrostError) + ChatCompletion(ctx context.Context, key Key, request *BifrostChatRequest) (*BifrostResponse, *BifrostError) // ChatCompletionStream performs a chat completion stream request - ChatCompletionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, input *BifrostRequest) (chan *BifrostStream, *BifrostError) + ChatCompletionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostChatRequest) (chan *BifrostStream, *BifrostError) + // Responses performs a completion request using the Responses API (uses chat completion request internally for non-openai providers) + 
Responses(ctx context.Context, key Key, request *BifrostResponsesRequest) (*BifrostResponse, *BifrostError) + // ResponsesStream performs a completion request using the Responses API stream (uses chat completion stream request internally for non-openai providers) + ResponsesStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostResponsesRequest) (chan *BifrostStream, *BifrostError) // Embedding performs an embedding request - Embedding(ctx context.Context, key Key, input *BifrostRequest) (*BifrostResponse, *BifrostError) + Embedding(ctx context.Context, key Key, request *BifrostEmbeddingRequest) (*BifrostResponse, *BifrostError) // Speech performs a text to speech request - Speech(ctx context.Context, key Key, input *BifrostRequest) (*BifrostResponse, *BifrostError) + Speech(ctx context.Context, key Key, request *BifrostSpeechRequest) (*BifrostResponse, *BifrostError) // SpeechStream performs a text to speech stream request - SpeechStream(ctx context.Context, postHookRunner PostHookRunner, key Key, input *BifrostRequest) (chan *BifrostStream, *BifrostError) + SpeechStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostSpeechRequest) (chan *BifrostStream, *BifrostError) // Transcription performs a transcription request - Transcription(ctx context.Context, key Key, input *BifrostRequest) (*BifrostResponse, *BifrostError) + Transcription(ctx context.Context, key Key, request *BifrostTranscriptionRequest) (*BifrostResponse, *BifrostError) // TranscriptionStream performs a transcription stream request - TranscriptionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, input *BifrostRequest) (chan *BifrostStream, *BifrostError) + TranscriptionStream(ctx context.Context, postHookRunner PostHookRunner, key Key, request *BifrostTranscriptionRequest) (chan *BifrostStream, *BifrostError) } diff --git a/core/schemas/providers/anthropic/chat.go b/core/schemas/providers/anthropic/chat.go index 
a0899184bb..49b71f1df6 100644 --- a/core/schemas/providers/anthropic/chat.go +++ b/core/schemas/providers/anthropic/chat.go @@ -7,39 +7,41 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -var fnTypePtr = schemas.Ptr(string(schemas.ToolChoiceTypeFunction)) +var fnTypePtr = schemas.Ptr(string(schemas.ChatToolChoiceTypeFunction)) // ToChatCompletionRequest converts an Anthropic messages request to Bifrost format -func (request *AnthropicMessageRequest) ToBifrostRequest() *schemas.BifrostRequest { +func (request *AnthropicMessageRequest) ToBifrostRequest() *schemas.BifrostChatRequest { provider, model := schemas.ParseModelString(request.Model, schemas.Anthropic) - bifrostReq := &schemas.BifrostRequest{ + bifrostReq := &schemas.BifrostChatRequest{ Provider: provider, Model: model, } - messages := []schemas.BifrostMessage{} + messages := []schemas.ChatMessage{} // Add system message if present if request.System != nil { if request.System.ContentStr != nil && *request.System.ContentStr != "" { - messages = append(messages, schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ ContentStr: request.System.ContentStr, }, }) } else if request.System.ContentBlocks != nil { - contentBlocks := []schemas.ContentBlock{} + contentBlocks := []schemas.ChatContentBlock{} for _, block := range *request.System.ContentBlocks { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, - Text: block.Text, - }) + if block.Text != nil { // System messages will only have text content + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: block.Text, + }) + } } - messages = append(messages, schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ + messages = 
append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ ContentBlocks: &contentBlocks, }, }) @@ -48,112 +50,79 @@ func (request *AnthropicMessageRequest) ToBifrostRequest() *schemas.BifrostReque // Convert messages for _, msg := range request.Messages { - var bifrostMsg schemas.BifrostMessage - bifrostMsg.Role = schemas.ModelChatMessageRole(msg.Role) + var bifrostMsg schemas.ChatMessage + bifrostMsg.Role = schemas.ChatMessageRole(msg.Role) if msg.Content.ContentStr != nil { - bifrostMsg.Content = schemas.MessageContent{ + bifrostMsg.Content = schemas.ChatMessageContent{ ContentStr: msg.Content.ContentStr, } } else if msg.Content.ContentBlocks != nil { // Handle different content types - var toolCalls []schemas.ToolCall - var contentBlocks []schemas.ContentBlock + var toolCalls []schemas.ChatAssistantMessageToolCall + var contentBlocks []schemas.ChatContentBlock for _, content := range *msg.Content.ContentBlocks { switch content.Type { - case "text": + case AnthropicContentBlockTypeText: if content.Text != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: content.Text, }) } - case "image": + case AnthropicContentBlockTypeImage: if content.Source != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeImage, - ImageURL: &schemas.ImageURLStruct{ - URL: func() string { - if content.Source.Data != nil { - mime := "image/png" - if content.Source.MediaType != nil && *content.Source.MediaType != "" { - mime = *content.Source.MediaType - } - return "data:" + mime + ";base64," + *content.Source.Data - } - if content.Source.URL != nil { - return *content.Source.URL - } - return "" - }(), - }, - }) + contentBlocks = append(contentBlocks, content.ToBifrostContentImageBlock()) } - case "tool_use": + case 
AnthropicContentBlockTypeToolUse: if content.ID != nil && content.Name != nil { - tc := schemas.ToolCall{ + tc := schemas.ChatAssistantMessageToolCall{ Type: fnTypePtr, ID: content.ID, - Function: schemas.FunctionCall{ + Function: schemas.ChatAssistantMessageToolCallFunction{ Name: content.Name, - Arguments: jsonifyInput(content.Input), + Arguments: schemas.JsonifyInput(content.Input), }, } toolCalls = append(toolCalls, tc) } - case "tool_result": + case AnthropicContentBlockTypeToolResult: if content.ToolUseID != nil { - bifrostMsg.ToolMessage = &schemas.ToolMessage{ + bifrostMsg.ChatToolMessage = &schemas.ChatToolMessage{ ToolCallID: content.ToolUseID, } if content.Content.ContentStr != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: content.Content.ContentStr, }) } else if content.Content.ContentBlocks != nil { for _, block := range *content.Content.ContentBlocks { if block.Text != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: block.Text, }) } else if block.Source != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeImage, - ImageURL: &schemas.ImageURLStruct{ - URL: func() string { - if block.Source.Data != nil { - mime := "image/png" - if block.Source.MediaType != nil && *block.Source.MediaType != "" { - mime = *block.Source.MediaType - } - return "data:" + mime + ";base64," + *block.Source.Data - } - if block.Source.URL != nil { - return *block.Source.URL - } - return "" - }()}, - }) + contentBlocks = append(contentBlocks, block.ToBifrostContentImageBlock()) } } } - bifrostMsg.Role = schemas.ModelChatMessageRoleTool + bifrostMsg.Role = schemas.ChatMessageRoleTool } } } // 
Concatenate all text contents if len(contentBlocks) > 0 { - bifrostMsg.Content = schemas.MessageContent{ + bifrostMsg.Content = schemas.ChatMessageContent{ ContentBlocks: &contentBlocks, } } - if len(toolCalls) > 0 && msg.Role == string(schemas.ModelChatMessageRoleAssistant) { - bifrostMsg.AssistantMessage = &schemas.AssistantMessage{ + if len(toolCalls) > 0 && msg.Role == AnthropicMessageRoleAssistant { + bifrostMsg.ChatAssistantMessage = &schemas.ChatAssistantMessage{ ToolCalls: &toolCalls, } } @@ -161,14 +130,16 @@ func (request *AnthropicMessageRequest) ToBifrostRequest() *schemas.BifrostReque messages = append(messages, bifrostMsg) } - bifrostReq.Input.ChatCompletionInput = &messages + bifrostReq.Input = messages // Convert parameters if request.MaxTokens > 0 || request.Temperature != nil || request.TopP != nil || request.TopK != nil || request.StopSequences != nil { - params := &schemas.ModelParameters{} + params := &schemas.ChatParameters{ + ExtraParams: make(map[string]interface{}), + } if request.MaxTokens > 0 { - params.MaxTokens = &request.MaxTokens + params.MaxCompletionTokens = &request.MaxTokens } if request.Temperature != nil { params.Temperature = request.Temperature @@ -177,10 +148,10 @@ func (request *AnthropicMessageRequest) ToBifrostRequest() *schemas.BifrostReque params.TopP = request.TopP } if request.TopK != nil { - params.TopK = request.TopK + params.ExtraParams["top_k"] = *request.TopK } if request.StopSequences != nil { - params.StopSequences = request.StopSequences + params.Stop = request.StopSequences } bifrostReq.Params = params @@ -188,10 +159,10 @@ func (request *AnthropicMessageRequest) ToBifrostRequest() *schemas.BifrostReque // Convert tools if request.Tools != nil { - tools := []schemas.Tool{} + tools := []schemas.ChatTool{} for _, tool := range *request.Tools { // Convert input_schema to FunctionParameters - params := schemas.FunctionParameters{ + params := schemas.ToolFunctionParameters{ Type: "object", } if tool.InputSchema != 
nil { @@ -200,49 +171,44 @@ func (request *AnthropicMessageRequest) ToBifrostRequest() *schemas.BifrostReque params.Properties = tool.InputSchema.Properties } - tools = append(tools, schemas.Tool{ - Type: "function", - Function: schemas.Function{ + tools = append(tools, schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ Name: tool.Name, - Description: tool.Description, - Parameters: params, + Description: schemas.Ptr(tool.Description), + Parameters: ¶ms, }, }) } if bifrostReq.Params == nil { - bifrostReq.Params = &schemas.ModelParameters{} + bifrostReq.Params = &schemas.ChatParameters{} } - bifrostReq.Params.Tools = &tools + bifrostReq.Params.Tools = tools } // Convert tool choice if request.ToolChoice != nil { if bifrostReq.Params == nil { - bifrostReq.Params = &schemas.ModelParameters{} + bifrostReq.Params = &schemas.ChatParameters{} } - toolChoice := &schemas.ToolChoice{ - ToolChoiceStruct: &schemas.ToolChoiceStruct{ - Type: func() schemas.ToolChoiceType { + toolChoice := &schemas.ChatToolChoice{ + ChatToolChoiceStruct: &schemas.ChatToolChoiceStruct{ + Type: func() schemas.ChatToolChoiceType { if request.ToolChoice.Type == "tool" { - return schemas.ToolChoiceTypeFunction + return schemas.ChatToolChoiceTypeFunction } - return schemas.ToolChoiceType(request.ToolChoice.Type) + return schemas.ChatToolChoiceType(request.ToolChoice.Type) }(), }, } if request.ToolChoice.Type == "tool" && request.ToolChoice.Name != "" { - toolChoice.ToolChoiceStruct.Function = schemas.ToolChoiceFunction{ + toolChoice.ChatToolChoiceStruct.Function = schemas.ChatToolChoiceFunction{ Name: request.ToolChoice.Name, } } bifrostReq.Params.ToolChoice = toolChoice } - // Apply parameter validation - if bifrostReq.Params != nil { - bifrostReq.Params = schemas.ValidateAndFilterParamsForProvider(provider, bifrostReq.Params) - } - return bifrostReq } @@ -257,94 +223,95 @@ func (response *AnthropicMessageResponse) ToBifrostResponse() *schemas.BifrostRe ID: 
response.ID, Model: response.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Anthropic, + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.Anthropic, }, } // Collect all content and tool calls into a single message - var toolCalls []schemas.ToolCall - var thinking string - var contentBlocks []schemas.ContentBlock + var toolCalls []schemas.ChatAssistantMessageToolCall + var contentBlocks []schemas.ChatContentBlock // Process content and tool calls - for _, c := range response.Content { - switch c.Type { - case "thinking": - if c.Thinking != nil { - thinking = *c.Thinking - } - case "text": - if c.Text != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, - Text: c.Text, - }) - } - case "tool_use": - if c.ID != nil && c.Name != nil { - function := schemas.FunctionCall{ - Name: c.Name, + if response.Content != nil { + for _, c := range response.Content { + switch c.Type { + case AnthropicContentBlockTypeText: + if c.Text != nil { + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: c.Text, + }) } + case AnthropicContentBlockTypeToolUse: + if c.ID != nil && c.Name != nil { + function := schemas.ChatAssistantMessageToolCallFunction{ + Name: c.Name, + } - // Marshal the input to JSON string - if c.Input != nil { - args, err := json.Marshal(c.Input) - if err != nil { - function.Arguments = fmt.Sprintf("%v", c.Input) + // Marshal the input to JSON string + if c.Input != nil { + args, err := json.Marshal(c.Input) + if err != nil { + function.Arguments = fmt.Sprintf("%v", c.Input) + } else { + function.Arguments = string(args) + } } else { - function.Arguments = string(args) + function.Arguments = "{}" } - } else { - function.Arguments = "{}" - } - toolCalls = append(toolCalls, schemas.ToolCall{ - Type: schemas.Ptr("function"), - ID: c.ID, - Function: function, - }) + toolCalls = append(toolCalls, 
schemas.ChatAssistantMessageToolCall{ + Type: schemas.Ptr(string(schemas.ChatToolTypeFunction)), + ID: c.ID, + Function: function, + }) + } } } } // Create the assistant message - var assistantMessage *schemas.AssistantMessage + var assistantMessage *schemas.ChatAssistantMessage // Create AssistantMessage if we have tool calls or thinking - if len(toolCalls) > 0 || thinking != "" { - assistantMessage = &schemas.AssistantMessage{} - if len(toolCalls) > 0 { - assistantMessage.ToolCalls = &toolCalls - } - if thinking != "" { - assistantMessage.Thought = &thinking + if len(toolCalls) > 0 { + assistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: &toolCalls, } } // Create a single choice with the collected content - bifrostResponse.Choices = []schemas.BifrostResponseChoice{ - { - Index: 0, - BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ - ContentBlocks: &contentBlocks, - }, - AssistantMessage: assistantMessage, - }, - StopString: response.StopSequence, - }, - FinishReason: func() *string { - if response.StopReason != nil && *response.StopReason != "" { - mapped := MapAnthropicFinishReason(*response.StopReason) - return &mapped - } - return nil - }(), + // Create message content + messageContent := schemas.ChatMessageContent{ + ContentBlocks: &contentBlocks, + } + + // Create message + message := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: messageContent, + ChatAssistantMessage: assistantMessage, + } + + // Create choice + choice := schemas.BifrostChatResponseChoice{ + Index: 0, + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: message, + StopString: response.StopSequence, }, + FinishReason: func() *string { + if response.StopReason != nil && *response.StopReason != "" { + mapped := MapAnthropicFinishReasonToBifrost(*response.StopReason) + return &mapped + 
} + return nil + }(), } + bifrostResponse.Choices = []schemas.BifrostChatResponseChoice{choice} + // Convert usage information if response.Usage != nil { bifrostResponse.Usage = &schemas.LLMUsage{ @@ -359,12 +326,12 @@ func (response *AnthropicMessageResponse) ToBifrostResponse() *schemas.BifrostRe // ToAnthropicChatCompletionRequest converts a Bifrost request to Anthropic format // This is the reverse of ConvertChatRequestToBifrost for provider-side usage -func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *AnthropicMessageRequest { - if bifrostReq == nil || bifrostReq.Input.ChatCompletionInput == nil { +func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *AnthropicMessageRequest { + if bifrostReq == nil || bifrostReq.Input == nil { return nil } - messages := *bifrostReq.Input.ChatCompletionInput + messages := bifrostReq.Input anthropicReq := &AnthropicMessageRequest{ Model: bifrostReq.Model, MaxTokens: AnthropicDefaultMaxTokens, @@ -372,14 +339,13 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr // Convert parameters if bifrostReq.Params != nil { - if bifrostReq.Params.MaxTokens != nil { - anthropicReq.MaxTokens = *bifrostReq.Params.MaxTokens + if bifrostReq.Params.MaxCompletionTokens != nil { + anthropicReq.MaxTokens = *bifrostReq.Params.MaxCompletionTokens } anthropicReq.Temperature = bifrostReq.Params.Temperature anthropicReq.TopP = bifrostReq.Params.TopP - anthropicReq.TopK = bifrostReq.Params.TopK - anthropicReq.StopSequences = bifrostReq.Params.StopSequences + anthropicReq.StopSequences = bifrostReq.Params.Stop topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]) if ok { anthropicReq.TopK = topK @@ -387,20 +353,18 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr // Convert tools if bifrostReq.Params.Tools != nil { - tools := make([]AnthropicTool, 0, len(*bifrostReq.Params.Tools)) - for _, tool := range 
*bifrostReq.Params.Tools { + tools := make([]AnthropicTool, 0, len(bifrostReq.Params.Tools)) + for _, tool := range bifrostReq.Params.Tools { anthropicTool := AnthropicTool{ - Name: tool.Function.Name, - Description: tool.Function.Description, + Name: tool.Function.Name, + } + if tool.Function.Description != nil { + anthropicTool.Description = *tool.Function.Description } // Convert function parameters to input_schema if tool.Function.Parameters.Type != "" || tool.Function.Parameters.Properties != nil { - anthropicTool.InputSchema = &struct { - Type string `json:"type"` - Properties map[string]interface{} `json:"properties"` - Required []string `json:"required"` - }{ + anthropicTool.InputSchema = &schemas.ToolFunctionParameters{ Type: tool.Function.Parameters.Type, Properties: tool.Function.Parameters.Properties, Required: tool.Function.Parameters.Required, @@ -416,19 +380,28 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr if bifrostReq.Params.ToolChoice != nil { toolChoice := &AnthropicToolChoice{} - if bifrostReq.Params.ToolChoice.ToolChoiceStr != nil { - toolChoice.Type = *bifrostReq.Params.ToolChoice.ToolChoiceStr - } else if bifrostReq.Params.ToolChoice.ToolChoiceStruct != nil { - switch bifrostReq.Params.ToolChoice.ToolChoiceStruct.Type { - case schemas.ToolChoiceTypeFunction: + if bifrostReq.Params.ToolChoice.ChatToolChoiceStr != nil { + switch schemas.ChatToolChoiceType(*bifrostReq.Params.ToolChoice.ChatToolChoiceStr) { + case schemas.ChatToolChoiceTypeAny: + toolChoice.Type = "any" + case schemas.ChatToolChoiceTypeRequired: + toolChoice.Type = "any" + case schemas.ChatToolChoiceTypeNone: + toolChoice.Type = "none" + default: + toolChoice.Type = "auto" + } + } else if bifrostReq.Params.ToolChoice.ChatToolChoiceStruct != nil { + switch bifrostReq.Params.ToolChoice.ChatToolChoiceStruct.Type { + case schemas.ChatToolChoiceTypeFunction: toolChoice.Type = "tool" - toolChoice.Name = 
bifrostReq.Params.ToolChoice.ToolChoiceStruct.Function.Name + toolChoice.Name = bifrostReq.Params.ToolChoice.ChatToolChoiceStruct.Function.Name case schemas.ChatToolChoiceTypeAllowedTools: toolChoice.Type = "any" case schemas.ChatToolChoiceTypeCustom: toolChoice.Type = "auto" default: - toolChoice.Type = string(bifrostReq.Params.ToolChoice.ToolChoiceStruct.Type) + toolChoice.Type = "auto" } } @@ -442,7 +415,7 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr for _, msg := range messages { switch msg.Role { - case schemas.ModelChatMessageRoleSystem: + case schemas.ChatMessageRoleSystem: // Handle system message separately if msg.Content.ContentStr != nil { systemContent = &AnthropicContent{ContentStr: msg.Content.ContentStr} @@ -461,14 +434,14 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr } } - case schemas.ModelChatMessageRoleTool: + case schemas.ChatMessageRoleTool: // Convert tool message to user message with tool_result content - if msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { + if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil { content := make([]AnthropicContentBlock, 0, 1) toolResult := AnthropicContentBlock{ Type: "tool_result", - ToolUseID: msg.ToolMessage.ToolCallID, + ToolUseID: msg.ChatToolMessage.ToolCallID, } // Convert tool result content @@ -482,8 +455,8 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr Type: "text", Text: block.Text, }) - } else if block.ImageURL != nil { - blocks = append(blocks, convertImageBlock(block)) + } else if block.ImageURLStruct != nil { + blocks = append(blocks, ConvertToAnthropicImageBlock(block)) } } if len(blocks) > 0 { @@ -501,7 +474,7 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr default: // Handle user and assistant messages anthropicMsg := AnthropicMessage{ - Role: string(msg.Role), + Role: AnthropicMessageRole(msg.Role), } var content 
[]AnthropicContentBlock @@ -519,23 +492,15 @@ func ToAnthropicChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr Type: "text", Text: block.Text, }) - } else if block.ImageURL != nil { - content = append(content, convertImageBlock(block)) + } else if block.ImageURLStruct != nil { + content = append(content, ConvertToAnthropicImageBlock(block)) } } } - // Convert thinking content - if msg.AssistantMessage != nil && msg.AssistantMessage.Thought != nil { - content = append(content, AnthropicContentBlock{ - Type: "thinking", - Thinking: msg.AssistantMessage.Thought, - }) - } - // Convert tool calls - if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { - for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.ChatAssistantMessage.ToolCalls { toolUse := AnthropicContentBlock{ Type: "tool_use", ID: toolCall.ID, @@ -582,7 +547,7 @@ func ToAnthropicChatCompletionResponse(bifrostResp *schemas.BifrostResponse) *An anthropicResp := &AnthropicMessageResponse{ ID: bifrostResp.ID, Type: "message", - Role: string(schemas.ModelChatMessageRoleAssistant), + Role: string(schemas.ChatMessageRoleAssistant), Model: bifrostResp.Model, } @@ -607,14 +572,6 @@ func ToAnthropicChatCompletionResponse(bifrostResp *schemas.BifrostResponse) *An anthropicResp.StopSequence = choice.StopString } - // Add thinking content if present - if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.Thought != nil && *choice.Message.AssistantMessage.Thought != "" { - content = append(content, AnthropicContentBlock{ - Type: "thinking", - Text: choice.Message.AssistantMessage.Thought, - }) - } - // Add text content if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { content = append(content, AnthropicContentBlock{ @@ -633,8 +590,8 @@ func ToAnthropicChatCompletionResponse(bifrostResp 
*schemas.BifrostResponse) *An } // Add tool calls as tool_use content - if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { - for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + if choice.Message.ChatAssistantMessage != nil && choice.Message.ChatAssistantMessage.ToolCalls != nil { + for _, toolCall := range *choice.Message.ChatAssistantMessage.ToolCalls { // Parse arguments JSON string back to map var input map[string]interface{} if toolCall.Function.Arguments != "" { diff --git a/core/schemas/providers/anthropic/responses.go b/core/schemas/providers/anthropic/responses.go new file mode 100644 index 0000000000..a994e8c4c3 --- /dev/null +++ b/core/schemas/providers/anthropic/responses.go @@ -0,0 +1,1051 @@ +package anthropic + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +func (request *AnthropicMessageRequest) ToResponsesBifrostRequest() *schemas.BifrostResponsesRequest { + provider, model := schemas.ParseModelString(request.Model, schemas.Anthropic) + + bifrostReq := &schemas.BifrostResponsesRequest{ + Provider: provider, + Model: model, + } + + // Convert basic parameters + params := &schemas.ResponsesParameters{ + ExtraParams: make(map[string]interface{}), + } + + if request.MaxTokens > 0 { + params.MaxOutputTokens = &request.MaxTokens + } + if request.Temperature != nil { + params.Temperature = request.Temperature + } + if request.TopP != nil { + params.TopP = request.TopP + } + if request.TopK != nil { + params.ExtraParams["top_k"] = *request.TopK + } + if request.StopSequences != nil { + params.ExtraParams["stop"] = *request.StopSequences + } + bifrostReq.Params = params + + // Convert messages directly to ChatMessage format + var bifrostMessages []schemas.ResponsesMessage + + // Handle system message - convert Anthropic system field to first message with role "system" + if request.System != nil { + var systemText string + if 
request.System.ContentStr != nil { + systemText = *request.System.ContentStr + } else if request.System.ContentBlocks != nil { + // Combine text blocks from system content + var textParts []string + for _, block := range *request.System.ContentBlocks { + if block.Text != nil { + textParts = append(textParts, *block.Text) + } + } + systemText = strings.Join(textParts, "\n") + } + + if systemText != "" { + systemMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleSystem), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &systemText, + }, + } + bifrostMessages = append(bifrostMessages, systemMsg) + } + } + + // Convert regular messages + for _, msg := range request.Messages { + convertedMessages := convertAnthropicMessageToBifrostResponsesMessages(&msg) + bifrostMessages = append(bifrostMessages, convertedMessages...) + } + + // Convert tools if present + if request.Tools != nil { + var bifrostTools []schemas.ResponsesTool + for _, tool := range *request.Tools { + bifrostTool := convertAnthropicToolToBifrost(&tool) + if bifrostTool != nil { + bifrostTools = append(bifrostTools, *bifrostTool) + } + } + if len(bifrostTools) > 0 { + bifrostReq.Params.Tools = bifrostTools + } + } + + // Convert tool choice if present + if request.ToolChoice != nil { + bifrostToolChoice := convertAnthropicToolChoiceToBifrost(request.ToolChoice) + if bifrostToolChoice != nil { + bifrostReq.Params.ToolChoice = bifrostToolChoice + } + } + + // Set the converted messages + if len(bifrostMessages) > 0 { + bifrostReq.Input = bifrostMessages + } + + return bifrostReq +} + +// ToAnthropicResponsesRequest converts a BifrostRequest with Responses structure back to AnthropicMessageRequest +func ToAnthropicResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *AnthropicMessageRequest { + anthropicReq := &AnthropicMessageRequest{ + Model: bifrostReq.Model, + MaxTokens: 
AnthropicDefaultMaxTokens, + } + + // Convert basic parameters + if bifrostReq.Params != nil { + if bifrostReq.Params.MaxOutputTokens != nil { + anthropicReq.MaxTokens = *bifrostReq.Params.MaxOutputTokens + } + if bifrostReq.Params.Temperature != nil { + anthropicReq.Temperature = bifrostReq.Params.Temperature + } + if bifrostReq.Params.TopP != nil { + anthropicReq.TopP = bifrostReq.Params.TopP + } + if bifrostReq.Params.ExtraParams != nil { + topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]) + if ok { + anthropicReq.TopK = topK + } + if stop, ok := schemas.SafeExtractStringSlicePointer(bifrostReq.Params.ExtraParams["stop"]); ok { + anthropicReq.StopSequences = stop + } + } + + // Convert tools + if bifrostReq.Params.Tools != nil { + anthropicTools := []AnthropicTool{} + for _, tool := range bifrostReq.Params.Tools { + anthropicTool := convertBifrostToolToAnthropic(&tool) + if anthropicTool != nil { + anthropicTools = append(anthropicTools, *anthropicTool) + } + } + if len(anthropicTools) > 0 { + anthropicReq.Tools = &anthropicTools + } + } + + // Convert tool choice + if bifrostReq.Params.ToolChoice != nil { + anthropicToolChoice := convertResponsesToolChoiceToAnthropic(bifrostReq.Params.ToolChoice) + if anthropicToolChoice != nil { + anthropicReq.ToolChoice = anthropicToolChoice + } + } + } + + if bifrostReq.Input != nil { + anthropicMessages, systemContent := convertResponsesMessagesToAnthropicMessages(bifrostReq.Input) + + // Set system message if present + if systemContent != nil { + anthropicReq.System = systemContent + } + + // Set regular messages + anthropicReq.Messages = anthropicMessages + } + + return anthropicReq +} + +// ToAnthropicResponsesResponse converts an Anthropic response to BifrostResponse with Responses structure +func (anthropicResp *AnthropicMessageResponse) ToResponsesBifrostResponse() *schemas.BifrostResponse { + if anthropicResp == nil { + return nil + } + + // Create the BifrostResponse with Responses 
structure + bifrostResp := &schemas.BifrostResponse{ + ID: anthropicResp.ID, + Model: anthropicResp.Model, + Object: "response", + ResponsesResponse: &schemas.ResponsesResponse{ + CreatedAt: int(time.Now().Unix()), + }, + } + + // Convert usage information + if anthropicResp.Usage != nil { + bifrostResp.Usage = &schemas.LLMUsage{ + TotalTokens: anthropicResp.Usage.InputTokens + anthropicResp.Usage.OutputTokens, + ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{ + InputTokens: anthropicResp.Usage.InputTokens, + OutputTokens: anthropicResp.Usage.OutputTokens, + }, + } + + // Handle cached tokens if present + if anthropicResp.Usage.CacheReadInputTokens > 0 { + if bifrostResp.Usage.ResponsesExtendedResponseUsage.InputTokensDetails == nil { + bifrostResp.Usage.ResponsesExtendedResponseUsage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{} + } + bifrostResp.Usage.ResponsesExtendedResponseUsage.InputTokensDetails.CachedTokens = anthropicResp.Usage.CacheReadInputTokens + } + } + + // Convert content to Responses output messages + outputMessages := convertAnthropicContentBlocksToResponsesMessages(anthropicResp.Content) + if len(outputMessages) > 0 { + bifrostResp.ResponsesResponse.Output = outputMessages + } + + return bifrostResp +} + +// ConvertBifrostResponseToAnthropic converts a BifrostResponse with Responses structure back to AnthropicMessageResponse +func ToAnthropicResponsesResponse(bifrostResp *schemas.BifrostResponse) *AnthropicMessageResponse { + anthropicResp := &AnthropicMessageResponse{ + ID: bifrostResp.ID, + Model: bifrostResp.Model, + Type: "message", + Role: "assistant", + } + + // Convert usage information + if bifrostResp.Usage != nil { + anthropicResp.Usage = &AnthropicUsage{ + InputTokens: bifrostResp.Usage.PromptTokens, + OutputTokens: bifrostResp.Usage.CompletionTokens, + } + + responsesUsage := bifrostResp.Usage.ResponsesExtendedResponseUsage + + if responsesUsage != nil && responsesUsage.InputTokens > 0 { + 
anthropicResp.Usage.InputTokens = responsesUsage.InputTokens + } + + if responsesUsage != nil && responsesUsage.OutputTokens > 0 { + anthropicResp.Usage.OutputTokens = responsesUsage.OutputTokens + } + + // Handle cached tokens if present + if responsesUsage != nil && + responsesUsage.InputTokensDetails != nil && + responsesUsage.InputTokensDetails.CachedTokens > 0 { + anthropicResp.Usage.CacheReadInputTokens = responsesUsage.InputTokensDetails.CachedTokens + } + } + + // Convert output messages to Anthropic content blocks + var contentBlocks []AnthropicContentBlock + if bifrostResp.ResponsesResponse != nil && bifrostResp.ResponsesResponse.Output != nil { + contentBlocks = convertBifrostMessagesToAnthropicContent(bifrostResp.ResponsesResponse.Output) + } + + if len(contentBlocks) > 0 { + anthropicResp.Content = contentBlocks + } + + // Set default stop reason - could be enhanced based on additional context + stopReason := "end_turn" + anthropicResp.StopReason = &stopReason + + // Check if there are tool calls to set appropriate stop reason + for _, block := range contentBlocks { + if block.Type == AnthropicContentBlockTypeToolUse { + toolStopReason := "tool_use" + anthropicResp.StopReason = &toolStopReason + break + } + } + + return anthropicResp +} + +// convertAnthropicMessageToBifrostResponsesMessages converts AnthropicMessage to ChatMessage format +func convertAnthropicMessageToBifrostResponsesMessages(msg *AnthropicMessage) []schemas.ResponsesMessage { + var bifrostMessages []schemas.ResponsesMessage + + // Handle text content + if msg.Content.ContentStr != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesMessageRoleType(msg.Role)), + Content: &schemas.ResponsesMessageContent{ + ContentStr: msg.Content.ContentStr, + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } else if msg.Content.ContentBlocks != nil { + // Handle content blocks + for _, block := 
range *msg.Content.ContentBlocks { + switch block.Type { + case AnthropicContentBlockTypeText: + if block.Text != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesMessageRoleType(msg.Role)), + Content: &schemas.ResponsesMessageContent{ + ContentStr: block.Text, + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + case AnthropicContentBlockTypeImage: + if block.Source != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesMessageRoleType(msg.Role)), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: &[]schemas.ResponsesMessageContentBlock{block.toBifrostResponsesImageBlock()}, + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + case AnthropicContentBlockTypeToolUse: + // Convert tool use to function call message + if block.ID != nil && block.Name != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: block.ID, + Name: block.Name, + Arguments: schemas.Ptr(schemas.JsonifyInput(block.Input)), + }, + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + case AnthropicContentBlockTypeToolResult: + // Convert tool result to function call output message + if block.ToolUseID != nil { + if block.Content != nil { + bifrostMsg := schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: block.ToolUseID, + }, + } + // Initialize the nested struct before any writes + bifrostMsg.ResponsesToolMessage.ResponsesFunctionToolCallOutput = &schemas.ResponsesFunctionToolCallOutput{} + + if block.Content.ContentStr != nil { + 
bifrostMsg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr = block.Content.ContentStr + } else if block.Content.ContentBlocks != nil { + var toolMsgContentBlocks []schemas.ResponsesMessageContentBlock + for _, contentBlock := range *block.Content.ContentBlocks { + switch contentBlock.Type { + case AnthropicContentBlockTypeText: + if contentBlock.Text != nil { + toolMsgContentBlocks = append(toolMsgContentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: contentBlock.Text, + }) + } + case AnthropicContentBlockTypeImage: + if contentBlock.Source != nil { + toolMsgContentBlocks = append(toolMsgContentBlocks, contentBlock.toBifrostResponsesImageBlock()) + } + } + } + bifrostMsg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks = &toolMsgContentBlocks + } + bifrostMessages = append(bifrostMessages, bifrostMsg) + } + } + } + } + } + + return bifrostMessages +} + +// convertAnthropicToolToBifrost converts AnthropicTool to schemas.Tool +func convertAnthropicToolToBifrost(tool *AnthropicTool) *schemas.ResponsesTool { + if tool == nil { + return nil + } + + bifrostTool := &schemas.ResponsesTool{ + Type: "function", + Name: &tool.Name, + Description: &tool.Description, + ResponsesToolFunction: &schemas.ResponsesToolFunction{ + Parameters: tool.InputSchema, + }, + } + + return bifrostTool +} + +// convertAnthropicToolChoiceToBifrost converts AnthropicToolChoice to schemas.ToolChoice +func convertAnthropicToolChoiceToBifrost(toolChoice *AnthropicToolChoice) *schemas.ResponsesToolChoice { + if toolChoice == nil { + return nil + } + + bifrostToolChoice := &schemas.ResponsesToolChoice{} + + // Handle string format + if toolChoice.Type != "" { + switch toolChoice.Type { + case "auto": + bifrostToolChoice.ResponsesToolChoiceStr = schemas.Ptr(string(schemas.ResponsesToolChoiceTypeAuto)) + case "any": + bifrostToolChoice.ResponsesToolChoiceStr 
= schemas.Ptr(string(schemas.ResponsesToolChoiceTypeAny)) + case "none": + bifrostToolChoice.ResponsesToolChoiceStr = schemas.Ptr(string(schemas.ResponsesToolChoiceTypeNone)) + case "tool": + // Handle forced tool choice with specific function name + bifrostToolChoice.ResponsesToolChoiceStruct = &schemas.ResponsesToolChoiceStruct{ + Type: schemas.ResponsesToolChoiceTypeFunction, + Name: &toolChoice.Name, + } + return bifrostToolChoice + default: + bifrostToolChoice.ResponsesToolChoiceStr = schemas.Ptr(string(schemas.ResponsesToolChoiceTypeAuto)) + } + } + + return bifrostToolChoice +} + +// Helper function to convert ResponsesInputItems back to AnthropicMessages +func convertResponsesMessagesToAnthropicMessages(messages []schemas.ResponsesMessage) ([]AnthropicMessage, *AnthropicContent) { + var anthropicMessages []AnthropicMessage + var systemContent *AnthropicContent + var pendingToolCalls []AnthropicContentBlock + var currentAssistantMessage *AnthropicMessage + + for _, msg := range messages { + // Handle nil Type as regular message + msgType := schemas.ResponsesMessageTypeMessage + if msg.Type != nil { + msgType = *msg.Type + } + + switch msgType { + case schemas.ResponsesMessageTypeMessage: + // Flush any pending tool calls first + if len(pendingToolCalls) > 0 && currentAssistantMessage != nil { + // Copy the slice to avoid aliasing issues + copied := make([]AnthropicContentBlock, len(pendingToolCalls)) + copy(copied, pendingToolCalls) + currentAssistantMessage.Content = AnthropicContent{ + ContentBlocks: &copied, + } + anthropicMessages = append(anthropicMessages, *currentAssistantMessage) + pendingToolCalls = nil + currentAssistantMessage = nil + } + + // Handle system messages separately + if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { + if msg.Content != nil { + if msg.Content.ContentStr != nil { + systemContent = &AnthropicContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.Content.ContentBlocks != nil { + 
contentBlocks := []AnthropicContentBlock{} + for _, block := range *msg.Content.ContentBlocks { + if anthropicBlock := convertContentBlockToAnthropic(block); anthropicBlock != nil { + contentBlocks = append(contentBlocks, *anthropicBlock) + } + } + if len(contentBlocks) > 0 { + systemContent = &AnthropicContent{ + ContentBlocks: &contentBlocks, + } + } + } + } + continue + } + + // Regular user/assistant message + anthropicMsg := AnthropicMessage{} + + // Set role + if msg.Role != nil { + switch *msg.Role { + case schemas.ResponsesInputMessageRoleUser: + anthropicMsg.Role = AnthropicMessageRoleUser + case schemas.ResponsesInputMessageRoleAssistant: + anthropicMsg.Role = AnthropicMessageRoleAssistant + default: + anthropicMsg.Role = AnthropicMessageRoleUser // Default fallback + } + } else { + anthropicMsg.Role = AnthropicMessageRoleUser // Default fallback + } + + // Convert content + if msg.Content != nil { + if msg.Content.ContentStr != nil { + anthropicMsg.Content = AnthropicContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.Content.ContentBlocks != nil { + contentBlocks := []AnthropicContentBlock{} + for _, block := range *msg.Content.ContentBlocks { + if anthropicBlock := convertContentBlockToAnthropic(block); anthropicBlock != nil { + contentBlocks = append(contentBlocks, *anthropicBlock) + } + } + if len(contentBlocks) > 0 { + anthropicMsg.Content = AnthropicContent{ + ContentBlocks: &contentBlocks, + } + } + } + } + + anthropicMessages = append(anthropicMessages, anthropicMsg) + + case schemas.ResponsesMessageTypeReasoning: + // Handle reasoning as thinking content + if msg.ResponsesReasoning != nil && len(msg.ResponsesReasoning.Summary) > 0 { + // Find the last assistant message or create one + var targetMsg *AnthropicMessage + if len(anthropicMessages) > 0 && anthropicMessages[len(anthropicMessages)-1].Role == AnthropicMessageRoleAssistant { + targetMsg = &anthropicMessages[len(anthropicMessages)-1] + } else { + // Create new assistant 
message for reasoning + newMsg := AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + } + anthropicMessages = append(anthropicMessages, newMsg) + targetMsg = &anthropicMessages[len(anthropicMessages)-1] + } + + // Add thinking blocks + var contentBlocks []AnthropicContentBlock + if targetMsg.Content.ContentBlocks != nil { + contentBlocks = *targetMsg.Content.ContentBlocks + } + + for _, reasoningContent := range msg.ResponsesReasoning.Summary { + thinkingBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeThinking, + Thinking: &reasoningContent.Text, + } + contentBlocks = append(contentBlocks, thinkingBlock) + } + + targetMsg.Content = AnthropicContent{ + ContentBlocks: &contentBlocks, + } + } + + case schemas.ResponsesMessageTypeFunctionCall: + // Start accumulating tool calls for assistant message + if currentAssistantMessage == nil { + currentAssistantMessage = &AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + } + } + + if msg.ResponsesToolMessage != nil { + toolUseBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolUse, + } + + if msg.ResponsesToolMessage.CallID != nil { + toolUseBlock.ID = msg.ResponsesToolMessage.CallID + } + if msg.ResponsesToolMessage.Name != nil { + toolUseBlock.Name = msg.ResponsesToolMessage.Name + } + + // Parse arguments as JSON input + if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { + toolUseBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) + } + + pendingToolCalls = append(pendingToolCalls, toolUseBlock) + } + + case schemas.ResponsesMessageTypeFunctionCallOutput: + // Flush any pending tool calls first before processing tool results + if len(pendingToolCalls) > 0 && currentAssistantMessage != nil { + // Copy the slice to avoid aliasing issues + copied := make([]AnthropicContentBlock, len(pendingToolCalls)) + copy(copied, pendingToolCalls) + currentAssistantMessage.Content = AnthropicContent{ + ContentBlocks: &copied, + } + 
anthropicMessages = append(anthropicMessages, *currentAssistantMessage) + pendingToolCalls = nil + currentAssistantMessage = nil + } + + // Handle tool call output - convert to user message with tool_result + if msg.ResponsesToolMessage != nil { + toolResultMsg := AnthropicMessage{ + Role: AnthropicMessageRoleUser, + } + + toolResultBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolResult, + } + + if msg.ResponsesToolMessage.CallID != nil { + toolResultBlock.ToolUseID = msg.ResponsesToolMessage.CallID + } + + // Convert tool output content + if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput != nil { + output := msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput + if output.ResponsesFunctionToolCallOutputStr != nil { + toolResultBlock.Content = &AnthropicContent{ + ContentStr: output.ResponsesFunctionToolCallOutputStr, + } + } else if output.ResponsesFunctionToolCallOutputBlocks != nil { + var resultContentBlocks []AnthropicContentBlock + for _, block := range *output.ResponsesFunctionToolCallOutputBlocks { + if convertedBlock := convertContentBlockToAnthropic(block); convertedBlock != nil { + resultContentBlocks = append(resultContentBlocks, *convertedBlock) + } + } + if len(resultContentBlocks) > 0 { + toolResultBlock.Content = &AnthropicContent{ + ContentBlocks: &resultContentBlocks, + } + } + } + } + + toolResultMsg.Content = AnthropicContent{ + ContentBlocks: &[]AnthropicContentBlock{toolResultBlock}, + } + + anthropicMessages = append(anthropicMessages, toolResultMsg) + } + + case schemas.ResponsesMessageTypeItemReference: + // Handle item reference as regular text message + if msg.Content != nil && msg.Content.ContentStr != nil { + referenceMsg := AnthropicMessage{ + Role: AnthropicMessageRoleUser, // Default to user for references + } + if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleAssistant { + referenceMsg.Role = AnthropicMessageRoleAssistant + } + + referenceMsg.Content = AnthropicContent{ + 
ContentStr: msg.Content.ContentStr, + } + + anthropicMessages = append(anthropicMessages, referenceMsg) + } + + // Handle other tool call types that are not natively supported by Anthropic + case schemas.ResponsesMessageTypeFileSearchCall, + schemas.ResponsesMessageTypeComputerCall, + schemas.ResponsesMessageTypeWebSearchCall, + schemas.ResponsesMessageTypeCodeInterpreterCall, + schemas.ResponsesMessageTypeLocalShellCall, + schemas.ResponsesMessageTypeMCPCall, + schemas.ResponsesMessageTypeCustomToolCall, + schemas.ResponsesMessageTypeImageGenerationCall: + // Convert unsupported tool calls to regular text messages + if msg.ResponsesToolMessage != nil { + toolCallMsg := AnthropicMessage{ + Role: AnthropicMessageRoleAssistant, + } + + var description string + if msg.ResponsesToolMessage.Name != nil { + description = fmt.Sprintf("Tool call: %s", *msg.ResponsesToolMessage.Name) + if msg.ResponsesToolMessage.Arguments != nil { + description += fmt.Sprintf(" with arguments: %s", *msg.ResponsesToolMessage.Arguments) + } + } else { + description = fmt.Sprintf("Tool call of type: %s", msgType) + } + + toolCallMsg.Content = AnthropicContent{ + ContentStr: &description, + } + + anthropicMessages = append(anthropicMessages, toolCallMsg) + } + + case schemas.ResponsesMessageTypeComputerCallOutput, + schemas.ResponsesMessageTypeLocalShellCallOutput, + schemas.ResponsesMessageTypeCustomToolCallOutput: + // Handle tool outputs as user messages + if msg.ResponsesToolMessage != nil { + toolOutputMsg := AnthropicMessage{ + Role: AnthropicMessageRoleUser, + } + + var outputText string + // Try to extract output text based on tool type + switch msgType { + case schemas.ResponsesMessageTypeLocalShellCallOutput: + if msg.ResponsesToolMessage.ResponsesLocalShellCallOutput != nil { + outputText = msg.ResponsesToolMessage.ResponsesLocalShellCallOutput.Output + } + case schemas.ResponsesMessageTypeCustomToolCallOutput: + if msg.ResponsesToolMessage.ResponsesCustomToolCallOutput != nil { + 
outputText = msg.ResponsesToolMessage.ResponsesCustomToolCallOutput.Output + } + } + + if outputText != "" { + toolOutputMsg.Content = AnthropicContent{ + ContentStr: &outputText, + } + anthropicMessages = append(anthropicMessages, toolOutputMsg) + } + } + + default: + // Skip unknown message types or log them for debugging + continue + } + } + + // Flush any remaining pending tool calls + if len(pendingToolCalls) > 0 && currentAssistantMessage != nil { + // Copy the slice to avoid aliasing issues + copied := make([]AnthropicContentBlock, len(pendingToolCalls)) + copy(copied, pendingToolCalls) + currentAssistantMessage.Content = AnthropicContent{ + ContentBlocks: &copied, + } + anthropicMessages = append(anthropicMessages, *currentAssistantMessage) + } + + return anthropicMessages, systemContent +} + +// Helper function to parse JSON input arguments back to interface{} +func parseJSONInput(jsonStr string) interface{} { + if jsonStr == "" || jsonStr == "{}" { + return map[string]interface{}{} + } + + var result interface{} + if err := json.Unmarshal([]byte(jsonStr), &result); err != nil { + // If parsing fails, return as string + return jsonStr + } + + return result +} + +// Helper function to convert Tool back to AnthropicTool +func convertBifrostToolToAnthropic(tool *schemas.ResponsesTool) *AnthropicTool { + if tool == nil { + return nil + } + + anthropicTool := &AnthropicTool{ + Type: schemas.Ptr(AnthropicToolTypeCustom), + } + + // Try to extract from ResponsesExtendedTool if present + if tool.Name != nil { + anthropicTool.Name = *tool.Name + } + + if tool.Description != nil { + anthropicTool.Description = *tool.Description + } + + // Convert parameters from ToolFunction + if tool.ResponsesToolFunction != nil { + anthropicTool.InputSchema = tool.ResponsesToolFunction.Parameters + } + + return anthropicTool +} + +// Helper function to convert ResponsesToolChoice back to AnthropicToolChoice +func convertResponsesToolChoiceToAnthropic(toolChoice 
*schemas.ResponsesToolChoice) *AnthropicToolChoice { + if toolChoice == nil || toolChoice.ResponsesToolChoiceStruct == nil { + return nil + } + + anthropicChoice := &AnthropicToolChoice{} + + var toolChoiceType *string + if toolChoice.ResponsesToolChoiceStruct != nil { + toolChoiceType = schemas.Ptr(string(toolChoice.ResponsesToolChoiceStruct.Type)) + } else { + toolChoiceType = toolChoice.ResponsesToolChoiceStr + } + + switch *toolChoiceType { + case "auto": + anthropicChoice.Type = "auto" + case "required": + anthropicChoice.Type = "any" + case "function": + // Handle function type - set as "tool" with specific function name + if toolChoice.ResponsesToolChoiceStruct != nil && toolChoice.ResponsesToolChoiceStruct.Name != nil { + anthropicChoice.Type = "tool" + anthropicChoice.Name = *toolChoice.ResponsesToolChoiceStruct.Name + } + return anthropicChoice + } + + // Legacy fallback: also check for Name field (for backward compatibility) + if toolChoice.ResponsesToolChoiceStruct != nil && toolChoice.ResponsesToolChoiceStruct.Name != nil { + anthropicChoice.Type = "tool" + anthropicChoice.Name = *toolChoice.ResponsesToolChoiceStruct.Name + } + + return anthropicChoice +} + +// Helper function to convert Anthropic content blocks to Responses output messages +func convertAnthropicContentBlocksToResponsesMessages(content []AnthropicContentBlock) []schemas.ResponsesMessage { + var messages []schemas.ResponsesMessage + + for _, block := range content { + switch block.Type { + case "text": + if block.Text != nil { + // Append text to existing message + messages = append(messages, schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: block.Text, + }, + }) + } + + case "thinking": + if block.Thinking != nil { + // Create reasoning message + messages = append(messages, schemas.ResponsesMessage{ + Type: 
schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: &[]schemas.ResponsesMessageContentBlock{ + { + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: block.Thinking, + }, + }, + }, + }) + } + + case "tool_use": + if block.ID != nil && block.Name != nil { + // Create function call message + messages = append(messages, schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: block.ID, + Name: block.Name, + Arguments: schemas.Ptr(schemas.JsonifyInput(block.Input)), + }, + }) + } + + default: + // Handle other block types if needed + } + } + + return messages +} + +// Helper function to convert ChatMessage output to Anthropic content blocks +func convertBifrostMessagesToAnthropicContent(messages []schemas.ResponsesMessage) []AnthropicContentBlock { + var contentBlocks []AnthropicContentBlock + + for _, msg := range messages { + // Handle different message types based on Responses structure + if msg.Type != nil { + switch *msg.Type { + case schemas.ResponsesMessageTypeMessage: + // Regular text message + if msg.Content.ContentStr != nil { + contentBlocks = append(contentBlocks, AnthropicContentBlock{ + Type: "text", + Text: msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + // Convert content blocks + for _, block := range *msg.Content.ContentBlocks { + anthropicBlock := convertContentBlockToAnthropic(block) + if anthropicBlock != nil { + contentBlocks = append(contentBlocks, *anthropicBlock) + } + } + } + + case schemas.ResponsesMessageTypeFunctionCall: + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + toolBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeToolUse, + ID: msg.ResponsesToolMessage.CallID, + } + if 
msg.ResponsesToolMessage.Name != nil { + toolBlock.Name = msg.ResponsesToolMessage.Name + } + if msg.ResponsesToolMessage.Arguments != nil && *msg.ResponsesToolMessage.Arguments != "" { + toolBlock.Input = parseJSONInput(*msg.ResponsesToolMessage.Arguments) + } + contentBlocks = append(contentBlocks, toolBlock) + } + + case schemas.ResponsesMessageTypeFunctionCallOutput: + // Tool result block - need to extract from ToolMessage + resultBlock := AnthropicContentBlock{ + Type: "tool_result", + } + + // Extract result content from ToolMessage or Content + if msg.ResponsesToolMessage != nil { + // Copy the call ID to maintain association between result and call + if msg.ResponsesToolMessage.CallID != nil { + resultBlock.ToolUseID = msg.ResponsesToolMessage.CallID + } + + // Try to get content from the tool message structure + if msg.Content != nil && msg.Content.ContentStr != nil { + resultBlock.Content = &AnthropicContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput != nil { + // Guard access to ResponsesFunctionToolCallOutput + if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr != nil { + resultBlock.Content = &AnthropicContent{ + ContentStr: msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr, + } + } else if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks != nil { + var resultBlocks []AnthropicContentBlock + for _, block := range *msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks { + if block.Type == schemas.ResponsesInputMessageContentBlockTypeText { + resultBlocks = append(resultBlocks, AnthropicContentBlock{ + Type: AnthropicContentBlockTypeText, + Text: block.Text, + }) + } else if block.Type == schemas.ResponsesInputMessageContentBlockTypeImage { + if block.ResponsesInputMessageContentBlockImage != nil && 
block.ResponsesInputMessageContentBlockImage.ImageURL != nil { + resultBlocks = append(resultBlocks, AnthropicContentBlock{ + Type: AnthropicContentBlockTypeImage, + Source: &AnthropicImageSource{ + Type: "url", + URL: block.ResponsesInputMessageContentBlockImage.ImageURL, + }, + }) + } + } + } + resultBlock.Content = &AnthropicContent{ + ContentBlocks: &resultBlocks, + } + } + } + } else if msg.Content != nil { + // Fallback to msg.Content when ResponsesToolMessage is nil + if msg.Content.ContentStr != nil { + resultBlock.Content = &AnthropicContent{ + ContentStr: msg.Content.ContentStr, + } + } + } + + contentBlocks = append(contentBlocks, resultBlock) + + case schemas.ResponsesMessageTypeReasoning: + // Thinking block (Claude 3.5 Sonnet specific) + if msg.Content.ContentStr != nil { + contentBlock := AnthropicContentBlock{ + Type: AnthropicContentBlockTypeThinking, + } + + if msg.ResponsesReasoning != nil { + var thinking string + if msg.ResponsesReasoning.Summary != nil { + for _, block := range msg.ResponsesReasoning.Summary { + thinking += block.Text + } + } + contentBlock.Thinking = &thinking + } + contentBlocks = append(contentBlocks, contentBlock) + } + + default: + // Handle other types as text if they have content + if msg.Content.ContentStr != nil { + contentBlocks = append(contentBlocks, AnthropicContentBlock{ + Type: AnthropicContentBlockTypeText, + Text: msg.Content.ContentStr, + }) + } + } + } + } + + return contentBlocks +} + +// Helper function to convert ContentBlock to AnthropicContentBlock +func convertContentBlockToAnthropic(block schemas.ResponsesMessageContentBlock) *AnthropicContentBlock { + switch block.Type { + case schemas.ResponsesInputMessageContentBlockTypeText, schemas.ResponsesOutputMessageContentTypeText: + if block.Text != nil { + return &AnthropicContentBlock{ + Type: AnthropicContentBlockTypeText, + Text: block.Text, + } + } + case schemas.ResponsesInputMessageContentBlockTypeImage: + if 
block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { + // Convert using the same logic as ConvertToAnthropicImageBlock + chatBlock := schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ + URL: *block.ResponsesInputMessageContentBlockImage.ImageURL, + }, + } + anthropicBlock := ConvertToAnthropicImageBlock(chatBlock) + return &anthropicBlock + } + case schemas.ResponsesOutputMessageContentTypeReasoning: + if block.Text != nil { + return &AnthropicContentBlock{ + Type: AnthropicContentBlockTypeThinking, + Thinking: block.Text, + } + } + } + return nil +} + +func (block AnthropicContentBlock) toBifrostResponsesImageBlock() schemas.ResponsesMessageContentBlock { + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeImage, + ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: schemas.Ptr(getImageURLFromBlock(block)), + }, + } +} diff --git a/core/schemas/providers/anthropic/text.go b/core/schemas/providers/anthropic/text.go index 3198d87901..d8859e4050 100644 --- a/core/schemas/providers/anthropic/text.go +++ b/core/schemas/providers/anthropic/text.go @@ -2,15 +2,27 @@ package anthropic import ( "fmt" + "strings" "github.com/maximhq/bifrost/core/schemas" ) // ToAnthropicTextCompletionRequest converts a Bifrost text completion request to Anthropic format -func ToAnthropicTextCompletionRequest(bifrostReq *schemas.BifrostRequest) *AnthropicTextRequest { +func ToAnthropicTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *AnthropicTextRequest { + if bifrostReq == nil { + return nil + } + + prompt := "" + if bifrostReq.Input.PromptStr != nil { + prompt = *bifrostReq.Input.PromptStr + } else if len(bifrostReq.Input.PromptArray) > 0 { + prompt = strings.Join(bifrostReq.Input.PromptArray, "\n\n") + } + anthropicReq := &AnthropicTextRequest{ 
Model: bifrostReq.Model, - Prompt: fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", *bifrostReq.Input.TextCompletionInput), + Prompt: fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", prompt), MaxTokensToSample: AnthropicDefaultMaxTokens, // Default value } @@ -21,8 +33,13 @@ func ToAnthropicTextCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr } anthropicReq.Temperature = bifrostReq.Params.Temperature anthropicReq.TopP = bifrostReq.Params.TopP - anthropicReq.TopK = bifrostReq.Params.TopK - anthropicReq.StopSequences = bifrostReq.Params.StopSequences + anthropicReq.StopSequences = bifrostReq.Params.Stop + + if bifrostReq.Params.ExtraParams != nil { + if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok { + anthropicReq.TopK = topK + } + } } return anthropicReq @@ -31,16 +48,11 @@ func ToAnthropicTextCompletionRequest(bifrostReq *schemas.BifrostRequest) *Anthr func (response *AnthropicTextResponse) ToBifrostResponse() *schemas.BifrostResponse { return &schemas.BifrostResponse{ ID: response.ID, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, - BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ - ContentStr: &response.Completion, - }, - }, + BifrostTextCompletionResponseChoice: &schemas.BifrostTextCompletionResponseChoice{ + Text: &response.Completion, }, }, }, @@ -51,7 +63,8 @@ func (response *AnthropicTextResponse) ToBifrostResponse() *schemas.BifrostRespo }, Model: response.Model, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Anthropic, + RequestType: schemas.TextCompletionRequest, + Provider: schemas.Anthropic, }, } } diff --git a/core/schemas/providers/anthropic/types.go b/core/schemas/providers/anthropic/types.go index 8cc1a59c01..a72228b408 100644 --- a/core/schemas/providers/anthropic/types.go +++ 
b/core/schemas/providers/anthropic/types.go @@ -45,10 +45,17 @@ func (r *AnthropicMessageRequest) IsStreamingRequested() bool { return r.Stream != nil && *r.Stream } +type AnthropicMessageRole string + +const ( + AnthropicMessageRoleUser AnthropicMessageRole = "user" + AnthropicMessageRoleAssistant AnthropicMessageRole = "assistant" +) + // AnthropicMessage represents a message in Anthropic format type AnthropicMessage struct { - Role string `json:"role"` // "user", "assistant" - Content AnthropicContent `json:"content"` // Array of content blocks + Role AnthropicMessageRole `json:"role"` // "user", "assistant" + Content AnthropicContent `json:"content"` // Array of content blocks } // AnthropicContent represents content that can be either string or array of blocks @@ -95,17 +102,27 @@ func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { return fmt.Errorf("content field is neither a string nor an array of ContentBlock") } +type AnthropicContentBlockType string + +const ( + AnthropicContentBlockTypeText AnthropicContentBlockType = "text" + AnthropicContentBlockTypeImage AnthropicContentBlockType = "image" + AnthropicContentBlockTypeToolUse AnthropicContentBlockType = "tool_use" + AnthropicContentBlockTypeToolResult AnthropicContentBlockType = "tool_result" + AnthropicContentBlockTypeThinking AnthropicContentBlockType = "thinking" +) + // AnthropicContentBlock represents content in Anthropic message format type AnthropicContentBlock struct { - Type string `json:"type"` // "text", "image", "tool_use", "tool_result", "thinking" - Text *string `json:"text,omitempty"` // For text content - Thinking *string `json:"thinking,omitempty"` // For thinking content - ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content - ID *string `json:"id,omitempty"` // For tool_use content - Name *string `json:"name,omitempty"` // For tool_use content - Input interface{} `json:"input,omitempty"` // For tool_use content - Content *AnthropicContent 
`json:"content,omitempty"` // For tool_result content - Source *AnthropicImageSource `json:"source,omitempty"` // For image content + Type AnthropicContentBlockType `json:"type"` // "text", "image", "tool_use", "tool_result", "thinking" + Text *string `json:"text,omitempty"` // For text content + Thinking *string `json:"thinking,omitempty"` // For thinking content + ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content + ID *string `json:"id,omitempty"` // For tool_use content + Name *string `json:"name,omitempty"` // For tool_use content + Input interface{} `json:"input,omitempty"` // For tool_use content + Content *AnthropicContent `json:"content,omitempty"` // For tool_result content + Source *AnthropicImageSource `json:"source,omitempty"` // For image content } // AnthropicImageSource represents image source in Anthropic format @@ -136,14 +153,10 @@ const ( // AnthropicTool represents a tool in Anthropic format type AnthropicTool struct { - Name string `json:"name"` - Type *string `json:"type,omitempty"` - Description string `json:"description"` - InputSchema *struct { - Type string `json:"type"` // "object" - Properties map[string]interface{} `json:"properties"` - Required []string `json:"required"` - } `json:"input_schema,omitempty"` + Name string `json:"name"` + Type *AnthropicToolType `json:"type,omitempty"` + Description string `json:"description"` + InputSchema *schemas.ToolFunctionParameters `json:"input_schema,omitempty"` } // AnthropicToolChoice represents tool choice in Anthropic format @@ -176,28 +189,6 @@ type AnthropicMessageResponse struct { Usage *AnthropicUsage `json:"usage,omitempty"` } -// AnthropicChatResponse represents the response structure from Anthropic's chat completion API (legacy) -type AnthropicChatResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Type string `json:"type"` // Type of completion - Role string `json:"role"` // Role of the message sender - Content []struct { - Type 
string `json:"type"` // Type of content - Text string `json:"text,omitempty"` // Text content - Thinking string `json:"thinking,omitempty"` // Thinking process - ID string `json:"id"` // Content identifier - Name string `json:"name"` // Name of the content - Input map[string]interface{} `json:"input"` // Input parameters - } `json:"content"` // Array of content items - Model string `json:"model"` // Model used for the completion - StopReason string `json:"stop_reason,omitempty"` // Reason for completion termination - StopSequence *string `json:"stop_sequence,omitempty"` // Sequence that caused completion to stop - Usage struct { - InputTokens int `json:"input_tokens"` // Number of input tokens used - OutputTokens int `json:"output_tokens"` // Number of output tokens generated - } `json:"usage"` // Token usage statistics -} - // AnthropicTextResponse represents the response structure from Anthropic's text completion API type AnthropicTextResponse struct { ID string `json:"id"` // Unique identifier for the completion diff --git a/core/schemas/providers/anthropic/utils.go b/core/schemas/providers/anthropic/utils.go index 8e786c6618..9acc1cf5b8 100644 --- a/core/schemas/providers/anthropic/utils.go +++ b/core/schemas/providers/anthropic/utils.go @@ -1,49 +1,68 @@ package anthropic import ( - "encoding/json" "github.com/maximhq/bifrost/core/schemas" ) -// mapAnthropicFinishReasonToOpenAI maps Anthropic finish reasons to OpenAI-compatible ones -func MapAnthropicFinishReason(anthropicReason string) string { - switch anthropicReason { - case "end_turn": - return "stop" - case "max_tokens": - return "length" - case "stop_sequence": - return "stop" - case "tool_use": - return "tool_calls" - default: - // Pass through Anthropic-specific reasons like "pause_turn", "refusal", etc. 
- return anthropicReason +var ( + finishReasonMap = map[string]string{ + "end_turn": "stop", + "max_tokens": "length", + "stop_sequence": "stop", + "tool_use": "tool_calls", } -} -// Helper function to convert interface{} to JSON string -func jsonifyInput(input interface{}) string { - if input == nil { - return "{}" + // reverseFinishReasonMap = func() map[string]string { + // m := make(map[string]string, len(finishReasonMap)) + // for k, v := range finishReasonMap { + // m[v] = k + // } + // return m + // }() + + reverseFinishReasonMap = map[string]string{ + "stop": "end_turn", // canonical default + "length": "max_tokens", + "tool_calls": "tool_use", } - jsonBytes, err := json.Marshal(input) - if err != nil { - return "{}" +) + +// MapAnthropicFinishReasonToOpenAI maps Anthropic finish reasons to OpenAI-compatible ones +func MapAnthropicFinishReasonToBifrost(anthropicReason string) string { + if _, ok := finishReasonMap[anthropicReason]; ok { + return finishReasonMap[anthropicReason] } - return string(jsonBytes) + return anthropicReason +} + +// MapBifrostFinishReasonToAnthropic maps Bifrost finish reasons back to Anthropic +func MapBifrostFinishReasonToAnthropic(bifrostReason string) string { + if mapped, ok := reverseFinishReasonMap[bifrostReason]; ok { + return mapped + } + return bifrostReason } // ConvertToAnthropicImageBlock converts a Bifrost image block to Anthropic format // Uses the same pattern as the original buildAnthropicImageSourceMap function -func convertImageBlock(block schemas.ContentBlock) AnthropicContentBlock { +func ConvertToAnthropicImageBlock(block schemas.ChatContentBlock) AnthropicContentBlock { imageBlock := AnthropicContentBlock{ Type: "image", Source: &AnthropicImageSource{}, } + if block.ImageURLStruct == nil { + return imageBlock + } + // Use the centralized utility functions from schemas package - sanitizedURL, _ := schemas.SanitizeImageURL(block.ImageURL.URL) + sanitizedURL, err := 
schemas.SanitizeImageURL(block.ImageURLStruct.URL) + if err != nil { + // Best-effort: treat as a regular URL without sanitization + imageBlock.Source.Type = "url" + imageBlock.Source.URL = &block.ImageURLStruct.URL + return imageBlock + } urlTypeInfo := schemas.ExtractURLTypeInfo(sanitizedURL) formattedImgContent := &AnthropicImageContent{ @@ -78,4 +97,35 @@ func convertImageBlock(block schemas.ContentBlock) AnthropicContentBlock { } return imageBlock -} \ No newline at end of file +} + +func (block AnthropicContentBlock) ToBifrostContentImageBlock() schemas.ChatContentBlock { + return schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ + URL: getImageURLFromBlock(block), + }, + } +} + +func getImageURLFromBlock(block AnthropicContentBlock) string { + if block.Source == nil { + return "" + } + + // Handle base64 data - convert to data URL + if block.Source.Data != nil { + mime := "image/png" + if block.Source.MediaType != nil && *block.Source.MediaType != "" { + mime = *block.Source.MediaType + } + return "data:" + mime + ";base64," + *block.Source.Data + } + + // Handle regular URLs + if block.Source.URL != nil { + return *block.Source.URL + } + + return "" +} diff --git a/core/schemas/providers/bedrock/chat.go b/core/schemas/providers/bedrock/chat.go index ead98b8700..34d67fe708 100644 --- a/core/schemas/providers/bedrock/chat.go +++ b/core/schemas/providers/bedrock/chat.go @@ -8,12 +8,12 @@ import ( ) // ToBedrockChatCompletionRequest converts a Bifrost request to Bedrock Converse API format -func ToBedrockChatCompletionRequest(bifrostReq *schemas.BifrostRequest) (*BedrockConverseRequest, error) { +func ToBedrockChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) (*BedrockConverseRequest, error) { if bifrostReq == nil { return nil, fmt.Errorf("bifrost request is nil") } - if bifrostReq.Input.ChatCompletionInput == nil { + if bifrostReq.Input == nil { return nil, fmt.Errorf("only chat 
completion requests are supported for Bedrock Converse API") } @@ -22,7 +22,7 @@ func ToBedrockChatCompletionRequest(bifrostReq *schemas.BifrostRequest) (*Bedroc } // Convert messages and system messages - messages, systemMessages, err := convertMessages(*bifrostReq.Input.ChatCompletionInput) + messages, systemMessages, err := convertMessages(bifrostReq.Input) if err != nil { return nil, fmt.Errorf("failed to convert messages: %w", err) } @@ -32,10 +32,10 @@ func ToBedrockChatCompletionRequest(bifrostReq *schemas.BifrostRequest) (*Bedroc } // Convert parameters and configurations - convertParameters(bifrostReq, bedrockReq) + convertChatParameters(bifrostReq, bedrockReq) // Ensure tool config is present when needed - ensureToolConfigForConversation(bifrostReq, bedrockReq) + ensureChatToolConfigForConversation(bifrostReq, bedrockReq) return bedrockReq, nil } @@ -47,15 +47,15 @@ func (bedrockResp *BedrockConverseResponse) ToBifrostResponse() (*schemas.Bifros } // Convert content blocks and tool calls - var contentBlocks []schemas.ContentBlock - var toolCalls []schemas.ToolCall + var contentBlocks []schemas.ChatContentBlock + var toolCalls []schemas.ChatAssistantMessageToolCall if bedrockResp.Output.Message != nil { for _, contentBlock := range bedrockResp.Output.Message.Content { // Handle text content if contentBlock.Text != nil && *contentBlock.Text != "" { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: contentBlock.Text, }) } @@ -74,11 +74,15 @@ func (bedrockResp *BedrockConverseResponse) ToBifrostResponse() (*schemas.Bifros arguments = "{}" } - toolCalls = append(toolCalls, schemas.ToolCall{ + // Create copies of the values to avoid range loop variable capture + toolUseID := contentBlock.ToolUse.ToolUseID + toolUseName := contentBlock.ToolUse.Name + + toolCalls = append(toolCalls, 
schemas.ChatAssistantMessageToolCall{ Type: schemas.Ptr("function"), - ID: &contentBlock.ToolUse.ToolUseID, - Function: schemas.FunctionCall{ - Name: &contentBlock.ToolUse.Name, + ID: &toolUseID, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &toolUseName, Arguments: arguments, }, }) @@ -87,28 +91,28 @@ func (bedrockResp *BedrockConverseResponse) ToBifrostResponse() (*schemas.Bifros } // Create assistant message if we have tool calls - var assistantMessage *schemas.AssistantMessage + var assistantMessage *schemas.ChatAssistantMessage if len(toolCalls) > 0 { - assistantMessage = &schemas.AssistantMessage{ + assistantMessage = &schemas.ChatAssistantMessage{ ToolCalls: &toolCalls, } } // Create the message content - messageContent := schemas.MessageContent{} + messageContent := schemas.ChatMessageContent{} if len(contentBlocks) > 0 { messageContent.ContentBlocks = &contentBlocks } // Create the response choice - choices := []schemas.BifrostResponseChoice{ + choices := []schemas.BifrostChatResponseChoice{ { Index: 0, BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: messageContent, - AssistantMessage: assistantMessage, + Message: schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: messageContent, + ChatAssistantMessage: assistantMessage, }, }, FinishReason: &bedrockResp.StopReason, @@ -130,8 +134,9 @@ func (bedrockResp *BedrockConverseResponse) ToBifrostResponse() (*schemas.Bifros Choices: choices, Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - Latency: &latency, - Provider: schemas.Bedrock, + RequestType: schemas.ChatCompletionRequest, + Latency: &latency, + Provider: schemas.Bedrock, }, } diff --git a/core/schemas/providers/bedrock/embedding.go b/core/schemas/providers/bedrock/embedding.go new file mode 100644 index 0000000000..fc8f0450cf --- /dev/null +++ 
b/core/schemas/providers/bedrock/embedding.go @@ -0,0 +1,96 @@ +package bedrock + +import ( + "fmt" + "strings" + + "github.com/maximhq/bifrost/core/schemas" + cohere "github.com/maximhq/bifrost/core/schemas/providers/cohere" +) + +// ToBedrockTitanEmbeddingRequest converts a Bifrost embedding request to Bedrock Titan format +func ToBedrockTitanEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*BedrockTitanEmbeddingRequest, error) { + if bifrostReq == nil { + return nil, fmt.Errorf("bifrost embedding request is nil") + } + + // Validate that only single text input is provided for Titan models + if bifrostReq.Input.Text == nil && len(bifrostReq.Input.Texts) == 0 { + return nil, fmt.Errorf("no input text provided for embedding") + } + + // Validate dimensions parameter - Titan models do not support it + if bifrostReq.Params != nil && bifrostReq.Params.Dimensions != nil { + return nil, fmt.Errorf("amazon Titan embedding models do not support custom dimensions parameter") + } + + titanReq := &BedrockTitanEmbeddingRequest{} + + // Set input text + if bifrostReq.Input.Text != nil { + titanReq.InputText = *bifrostReq.Input.Text + } else if len(bifrostReq.Input.Texts) > 0 { + var embeddingText string + for _, text := range bifrostReq.Input.Texts { + embeddingText += text + " \n" + } + titanReq.InputText = embeddingText + } + + return titanReq, nil +} + +// ToBifrostResponse converts a Bedrock Titan embedding response to Bifrost format +func (titanResp *BedrockTitanEmbeddingResponse) ToBifrostResponse(model string) *schemas.BifrostResponse { + if titanResp == nil { + return nil + } + + bifrostResponse := &schemas.BifrostResponse{ + Object: "list", + Data: []schemas.BifrostEmbedding{ + { + Index: 0, + Object: "embedding", + Embedding: schemas.BifrostEmbeddingResponse{ + Embedding2DArray: &[][]float32{titanResp.Embedding}, + }, + }, + }, + Model: model, + Usage: &schemas.LLMUsage{ + PromptTokens: titanResp.InputTextTokenCount, + TotalTokens: 
titanResp.InputTextTokenCount, + }, + } + + return bifrostResponse +} + +// ToBedrockCohereEmbeddingRequest converts a Bifrost embedding request to Bedrock Cohere format +// Reuses the Cohere converter since the format is identical +func ToBedrockCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) (*cohere.CohereEmbeddingRequest, error) { + if bifrostReq == nil { + return nil, fmt.Errorf("bifrost embedding request is nil") + } + + // Reuse Cohere's converter - the format is identical for Bedrock + cohereReq := cohere.ToCohereEmbeddingRequest(bifrostReq) + if cohereReq == nil { + return nil, fmt.Errorf("failed to convert to Cohere embedding request") + } + + return cohereReq, nil +} + +// DetermineEmbeddingModelType determines the embedding model type from the model name +func DetermineEmbeddingModelType(model string) (string, error) { + switch { + case strings.Contains(model, "amazon.titan-embed-text"): + return "titan", nil + case strings.Contains(model, "cohere.embed"): + return "cohere", nil + default: + return "", fmt.Errorf("unsupported embedding model: %s", model) + } +} diff --git a/core/schemas/providers/bedrock/responses.go b/core/schemas/providers/bedrock/responses.go new file mode 100644 index 0000000000..24d9bfb98b --- /dev/null +++ b/core/schemas/providers/bedrock/responses.go @@ -0,0 +1,532 @@ +package bedrock + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// ToBedrockResponsesRequest converts a BifrostRequest (Responses structure) back to BedrockConverseRequest +func ToBedrockResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*BedrockConverseRequest, error) { + if bifrostReq == nil { + return nil, fmt.Errorf("bifrost request is nil") + } + + bedrockReq := &BedrockConverseRequest{ + ModelID: bifrostReq.Model, + } + + // map bifrost messages to bedrock messages + if bifrostReq.Input != nil { + messages, systemMessages, err := 
convertResponsesItemsToBedrockMessages(bifrostReq.Input) + if err != nil { + return nil, fmt.Errorf("failed to convert Responses messages: %w", err) + } + bedrockReq.Messages = messages + if len(systemMessages) > 0 { + bedrockReq.System = &systemMessages + } + } + + // Map basic parameters to inference config + if bifrostReq.Params != nil { + inferenceConfig := &BedrockInferenceConfig{} + + if bifrostReq.Params.MaxOutputTokens != nil { + inferenceConfig.MaxTokens = bifrostReq.Params.MaxOutputTokens + } + if bifrostReq.Params.Temperature != nil { + inferenceConfig.Temperature = bifrostReq.Params.Temperature + } + if bifrostReq.Params.TopP != nil { + inferenceConfig.TopP = bifrostReq.Params.TopP + } + if bifrostReq.Params.ExtraParams != nil { + if stop, ok := schemas.SafeExtractStringSlicePointer(bifrostReq.Params.ExtraParams["stop"]); ok { + inferenceConfig.StopSequences = stop + } + } + + bedrockReq.InferenceConfig = inferenceConfig + } + + // Convert tools + if bifrostReq.Params != nil && bifrostReq.Params.Tools != nil { + var bedrockTools []BedrockTool + for _, tool := range bifrostReq.Params.Tools { + if tool.ResponsesToolFunction != nil { + // Create the complete schema object that Bedrock expects + var schemaObject interface{} + if tool.ResponsesToolFunction.Parameters != nil { + schemaObject = tool.ResponsesToolFunction.Parameters + } else { + // Fallback to empty object schema if no parameters + schemaObject = map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } + } + + if tool.Name == nil || *tool.Name == "" { + return nil, fmt.Errorf("responses tool is missing required name for Bedrock function conversion") + } + name := *tool.Name + + // Use the tool description if available, otherwise use a generic description + description := "Function tool" + if tool.Description != nil { + description = *tool.Description + } + + bedrockTool := BedrockTool{ + ToolSpec: &BedrockToolSpec{ + Name: name, + Description: &description, + 
InputSchema: BedrockToolInputSchema{ + JSON: schemaObject, + }, + }, + } + bedrockTools = append(bedrockTools, bedrockTool) + } + } + + if len(bedrockTools) > 0 { + bedrockReq.ToolConfig = &BedrockToolConfig{ + Tools: &bedrockTools, + } + } + } + + // Convert tool choice + if bifrostReq.Params != nil && bifrostReq.Params.ToolChoice != nil { + bedrockToolChoice := convertResponsesToolChoice(*bifrostReq.Params.ToolChoice) + if bedrockToolChoice != nil { + if bedrockReq.ToolConfig == nil { + bedrockReq.ToolConfig = &BedrockToolConfig{} + } + bedrockReq.ToolConfig.ToolChoice = bedrockToolChoice + } + } + + // Ensure tool config is present when tool content exists (similar to Chat Completions) + ensureResponsesToolConfigForConversation(bifrostReq, bedrockReq) + + return bedrockReq, nil +} + +// ensureResponsesToolConfigForConversation ensures toolConfig is present when tool content exists +func ensureResponsesToolConfigForConversation(bifrostReq *schemas.BifrostResponsesRequest, bedrockReq *BedrockConverseRequest) { + if bedrockReq.ToolConfig != nil { + return // Already has tool config + } + + hasToolContent, tools := extractToolsFromResponsesConversationHistory(bifrostReq.Input) + if hasToolContent && len(tools) > 0 { + bedrockReq.ToolConfig = &BedrockToolConfig{Tools: &tools} + } +} + +// extractToolsFromResponsesConversationHistory extracts tools from Responses conversation history +func extractToolsFromResponsesConversationHistory(messages []schemas.ResponsesMessage) (bool, []BedrockTool) { + var hasToolContent bool + toolMap := make(map[string]*schemas.ResponsesTool) // Use map to deduplicate by name + + for _, msg := range messages { + // Check if message contains tool use or tool result + if msg.Type != nil { + switch *msg.Type { + case schemas.ResponsesMessageTypeFunctionCall, schemas.ResponsesMessageTypeFunctionCallOutput: + hasToolContent = true + // Try to infer tool definition from tool call/result + if msg.ResponsesToolMessage != nil && 
msg.ResponsesToolMessage.Name != nil { + toolName := *msg.ResponsesToolMessage.Name + if _, exists := toolMap[toolName]; !exists { + // Create a minimal tool definition + toolMap[toolName] = &schemas.ResponsesTool{ + Type: "function", + Name: &toolName, + ResponsesToolFunction: &schemas.ResponsesToolFunction{ + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: make(map[string]interface{}), + }, + }, + } + } + } + } + } + } + + // Convert map to slice + var tools []BedrockTool + for _, tool := range toolMap { + if tool.Name != nil && tool.ResponsesToolFunction != nil { + schemaObject := tool.ResponsesToolFunction.Parameters + if schemaObject == nil { + schemaObject = &schemas.ToolFunctionParameters{ + Type: "object", + Properties: make(map[string]interface{}), + } + } + + description := "Function tool" + if tool.Description != nil { + description = *tool.Description + } + + bedrockTool := BedrockTool{ + ToolSpec: &BedrockToolSpec{ + Name: *tool.Name, + Description: &description, + InputSchema: BedrockToolInputSchema{ + JSON: schemaObject, + }, + }, + } + tools = append(tools, bedrockTool) + } + } + + return hasToolContent, tools +} + +// ToBedrockResponsesResponse converts BedrockConverseResponse to BifrostResponse (Responses structure) +func (bedrockResp *BedrockConverseResponse) ToResponsesBifrostResponse() (*schemas.BifrostResponse, error) { + if bedrockResp == nil { + return nil, fmt.Errorf("bedrock response is nil") + } + + bifrostResp := &schemas.BifrostResponse{ + ID: "", // Bedrock doesn't provide response ID + Model: "", // Will be set by provider + Object: "response", + ResponsesResponse: &schemas.ResponsesResponse{ + CreatedAt: int(time.Now().Unix()), + }, + } + + // Convert usage information + usage := &schemas.LLMUsage{ + ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{ + InputTokens: bedrockResp.Usage.InputTokens, + OutputTokens: bedrockResp.Usage.OutputTokens, + }, + TotalTokens: 
bedrockResp.Usage.TotalTokens, + } + bifrostResp.Usage = usage + + // Convert output message to Responses format + if bedrockResp.Output.Message != nil { + outputMessages := convertBedrockMessageToResponsesMessages(*bedrockResp.Output.Message) + bifrostResp.ResponsesResponse.Output = outputMessages + } + + return bifrostResp, nil +} + +// Helper functions + +func convertResponsesToolChoice(toolChoice schemas.ResponsesToolChoice) *BedrockToolChoice { + // Check if it's a string choice + if toolChoice.ResponsesToolChoiceStr != nil { + switch schemas.ResponsesToolChoiceType(*toolChoice.ResponsesToolChoiceStr) { + case schemas.ResponsesToolChoiceTypeAny, schemas.ResponsesToolChoiceTypeRequired: + return &BedrockToolChoice{ + Any: &BedrockToolChoiceAny{}, + } + case schemas.ResponsesToolChoiceTypeNone: + // Bedrock doesn't have explicit "none" - just don't include tools + return nil + } + } + + // Check if it's a struct choice + if toolChoice.ResponsesToolChoiceStruct != nil { + switch toolChoice.ResponsesToolChoiceStruct.Type { + case schemas.ResponsesToolChoiceTypeFunction: + // Extract the actual function name from the struct + if toolChoice.ResponsesToolChoiceStruct.Name != nil && *toolChoice.ResponsesToolChoiceStruct.Name != "" { + return &BedrockToolChoice{ + Tool: &BedrockToolChoiceTool{ + Name: *toolChoice.ResponsesToolChoiceStruct.Name, + }, + } + } + // If Name is nil or empty, return nil as we can't construct a valid tool choice + return nil + case schemas.ResponsesToolChoiceTypeAuto, schemas.ResponsesToolChoiceTypeAny, schemas.ResponsesToolChoiceTypeRequired: + return &BedrockToolChoice{ + Any: &BedrockToolChoiceAny{}, + } + case schemas.ResponsesToolChoiceTypeNone: + return nil + } + } + + return nil +} + +// convertResponsesItemsToBedrockMessages converts Responses items back to Bedrock messages +func convertResponsesItemsToBedrockMessages(messages []schemas.ResponsesMessage) ([]BedrockMessage, []BedrockSystemMessage, error) { + var bedrockMessages 
[]BedrockMessage + var systemMessages []BedrockSystemMessage + + for _, msg := range messages { + // Handle Responses items + if msg.Type != nil { + switch *msg.Type { + case "message": + // Check if Role is present, skip message if not + if msg.Role == nil { + continue + } + + // Extract role from the Responses message structure + role := *msg.Role + + if role == schemas.ResponsesInputMessageRoleSystem { + // Convert to system message + // Ensure Content and ContentStr are present + if msg.Content != nil { + if msg.Content.ContentStr != nil { + systemMessages = append(systemMessages, BedrockSystemMessage{ + Text: msg.Content.ContentStr, + }) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + systemMessages = append(systemMessages, BedrockSystemMessage{ + Text: block.Text, + }) + } + } + } + } + // Skip system messages with no content + } else { + // Convert regular message + // Ensure Content is present + if msg.Content == nil { + // Skip messages without content or create with empty content + continue + } + + bedrockMsg := BedrockMessage{ + Role: BedrockMessageRole(role), + } + + // Convert content + contentBlocks, err := convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(*msg.Content) + if err != nil { + return nil, nil, fmt.Errorf("failed to convert content blocks: %w", err) + } + bedrockMsg.Content = contentBlocks + + bedrockMessages = append(bedrockMessages, bedrockMsg) + } + + case "function_call": + // Handle function calls from Responses + if msg.ResponsesToolMessage != nil { + // Create tool use content block + var toolUseID string + if msg.ResponsesToolMessage.CallID != nil { + toolUseID = *msg.ResponsesToolMessage.CallID + } + + // Get function name from ToolMessage + var functionName string + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { + functionName = *msg.ResponsesToolMessage.Name + } + + // Parse JSON arguments into interface{} 
+ var input interface{} = map[string]interface{}{} + if msg.ResponsesToolMessage.Arguments != nil { + var parsedInput interface{} + if err := json.Unmarshal([]byte(*msg.ResponsesToolMessage.Arguments), &parsedInput); err != nil { + return nil, nil, fmt.Errorf("failed to parse tool arguments JSON: %w", err) + } + input = parsedInput + } + + toolUseBlock := BedrockContentBlock{ + ToolUse: &BedrockToolUse{ + ToolUseID: toolUseID, + Name: functionName, + Input: input, + }, + } + + // Create assistant message with tool use + assistantMsg := BedrockMessage{ + Role: BedrockMessageRoleAssistant, + Content: []BedrockContentBlock{toolUseBlock}, + } + bedrockMessages = append(bedrockMessages, assistantMsg) + + } + + case "function_call_output": + // Handle function call outputs from Responses + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput != nil { + var toolUseID string + if msg.ResponsesToolMessage.CallID != nil { + toolUseID = *msg.ResponsesToolMessage.CallID + } + + toolResultBlock := BedrockContentBlock{ + ToolResult: &BedrockToolResult{ + ToolUseID: toolUseID, + }, + } + + // Set content based on available data + if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr != nil { + // Unmarshal the JSON string into an interface{} to get a proper JSON object + var parsedOutput interface{} + if err := json.Unmarshal([]byte(*msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr), &parsedOutput); err != nil { + return nil, nil, fmt.Errorf("failed to parse tool result JSON: %w", err) + } + toolResultBlock.ToolResult.Content = []BedrockContentBlock{ + {JSON: parsedOutput}, + } + } else if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks != nil { + toolResultContent, err := convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(schemas.ResponsesMessageContent{ + ContentBlocks: 
msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to convert tool result content blocks: %w", err) + } + toolResultBlock.ToolResult.Content = toolResultContent + } + + // Create user message with tool result + userMsg := BedrockMessage{ + Role: "user", + Content: []BedrockContentBlock{toolResultBlock}, + } + bedrockMessages = append(bedrockMessages, userMsg) + } + } + } + } + + return bedrockMessages, systemMessages, nil +} + +// convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks converts Bifrost content to Bedrock content blocks +func convertBifrostResponsesMessageContentBlocksToBedrockContentBlocks(content schemas.ResponsesMessageContent) ([]BedrockContentBlock, error) { + var blocks []BedrockContentBlock + + if content.ContentStr != nil { + blocks = append(blocks, BedrockContentBlock{ + Text: content.ContentStr, + }) + } else if content.ContentBlocks != nil { + for _, block := range *content.ContentBlocks { + + bedrockBlock := BedrockContentBlock{} + + switch block.Type { + case schemas.ResponsesInputMessageContentBlockTypeText: + bedrockBlock.Text = block.Text + case schemas.ResponsesInputMessageContentBlockTypeImage: + if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { + imageSource, err := convertImageToBedrockSource(*block.ResponsesInputMessageContentBlockImage.ImageURL) + if err != nil { + return nil, fmt.Errorf("failed to convert image in responses content block: %w", err) + } + bedrockBlock.Image = imageSource + } + default: + // Don't add anything + } + + blocks = append(blocks, bedrockBlock) + } + } + + return blocks, nil +} + +// convertBedrockMessageToResponsesMessages converts Bedrock message to ChatMessage output format +func convertBedrockMessageToResponsesMessages(bedrockMsg BedrockMessage) []schemas.ResponsesMessage { + var outputMessages 
[]schemas.ResponsesMessage + + for _, block := range bedrockMsg.Content { + if block.Text != nil { + // Text content + outputMessages = append(outputMessages, schemas.ResponsesMessage{ + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: block.Text, + }, + }) + } else if block.ToolUse != nil { + // Tool use content + // Create copies of the values to avoid range loop variable capture + toolUseID := block.ToolUse.ToolUseID + toolUseName := block.ToolUse.Name + + toolMsg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolUseID, + Name: &toolUseName, + Arguments: schemas.Ptr(schemas.JsonifyInput(block.ToolUse.Input)), + }, + } + outputMessages = append(outputMessages, toolMsg) + } else if block.ToolResult != nil { + // Tool result content - typically not in assistant output but handled for completeness + // Prefer JSON payloads without unmarshalling; fallback to text + var resultContent string + if len(block.ToolResult.Content) > 0 { + // JSON first (no unmarshal; just one marshal to string when present) + for _, c := range block.ToolResult.Content { + if c.JSON != nil { + resultContent = schemas.JsonifyInput(c.JSON) + break + } + } + // Fallback to first available text block + if resultContent == "" { + for _, c := range block.ToolResult.Content { + if c.Text != nil { + resultContent = *c.Text + break + } + } + } + } + + // Create a copy of the value to avoid range loop variable capture + toolResultID := block.ToolResult.ToolUseID + + resultMsg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &resultContent, + }, + Type: 
schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &toolResultID, + ResponsesFunctionToolCallOutput: &schemas.ResponsesFunctionToolCallOutput{ + ResponsesFunctionToolCallOutputStr: &resultContent, + }, + }, + } + outputMessages = append(outputMessages, resultMsg) + } + } + + return outputMessages +} diff --git a/core/schemas/providers/bedrock/text.go b/core/schemas/providers/bedrock/text.go index 5cd5901489..01a2644181 100644 --- a/core/schemas/providers/bedrock/text.go +++ b/core/schemas/providers/bedrock/text.go @@ -4,46 +4,54 @@ import ( "strings" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/providers/anthropic" ) -const AnthropicDefaultMaxTokens = 4096 - // ToBedrockTextCompletionRequest converts a Bifrost text completion request to Bedrock format -func ToBedrockTextCompletionRequest(bifrostReq *schemas.BifrostRequest) *BedrockTextCompletionRequest { - if bifrostReq == nil || bifrostReq.Input.TextCompletionInput == nil { +func ToBedrockTextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *BedrockTextCompletionRequest { + if bifrostReq == nil || (bifrostReq.Input.PromptStr == nil && len(bifrostReq.Input.PromptArray) == 0) { return nil } + // Extract the raw prompt from bifrostReq + prompt := "" + if bifrostReq.Input.PromptStr != nil { + prompt = *bifrostReq.Input.PromptStr + } else if len(bifrostReq.Input.PromptArray) > 0 { + prompt = strings.Join(bifrostReq.Input.PromptArray, "\n\n") + } + bedrockReq := &BedrockTextCompletionRequest{ - Prompt: *bifrostReq.Input.TextCompletionInput, + Prompt: prompt, } - // Convert parameters if present + // Apply parameters if bifrostReq.Params != nil { - // Handle max tokens with model-specific logic - if bifrostReq.Params.MaxTokens != nil { - if strings.Contains(bifrostReq.Model, "anthropic.") { - bedrockReq.MaxTokensToSample = bifrostReq.Params.MaxTokens - } else { - bedrockReq.MaxTokens = 
bifrostReq.Params.MaxTokens - } - } - - // Standard sampling parameters bedrockReq.Temperature = bifrostReq.Params.Temperature bedrockReq.TopP = bifrostReq.Params.TopP - bedrockReq.TopK = bifrostReq.Params.TopK - // Handle stop sequences with dual support - if bifrostReq.Params.StopSequences != nil { - if strings.Contains(bifrostReq.Model, "anthropic.") { - bedrockReq.StopSequences = bifrostReq.Params.StopSequences - } else { - bedrockReq.Stop = bifrostReq.Params.StopSequences + if bifrostReq.Params.ExtraParams != nil { + if topK, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok { + bedrockReq.TopK = topK } } } + // Apply model-specific formatting and field naming + if strings.Contains(bifrostReq.Model, "anthropic.") || strings.Contains(bifrostReq.Model, "claude") { + // For Claude models, wrap the prompt in Anthropic format and use Anthropic field names + anthropicReq := anthropic.ToAnthropicTextCompletionRequest(bifrostReq) + bedrockReq.Prompt = anthropicReq.Prompt + bedrockReq.MaxTokensToSample = &anthropicReq.MaxTokensToSample + bedrockReq.StopSequences = anthropicReq.StopSequences + } else { + // For other models, use standard field names with raw prompt + if bifrostReq.Params != nil { + bedrockReq.MaxTokens = bifrostReq.Params.MaxTokens + bedrockReq.Stop = bifrostReq.Params.Stop + } + } + return bedrockReq } @@ -54,23 +62,18 @@ func (response *BedrockAnthropicTextResponse) ToBifrostResponse() *schemas.Bifro } return &schemas.BifrostResponse{ - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, - BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ - ContentStr: &response.Completion, - }, - }, - StopString: &response.Stop, + BifrostTextCompletionResponseChoice: &schemas.BifrostTextCompletionResponseChoice{ + Text: &response.Completion, }, 
FinishReason: &response.StopReason, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Bedrock, + RequestType: schemas.TextCompletionRequest, + Provider: schemas.Bedrock, }, } } @@ -81,17 +84,12 @@ func (response *BedrockMistralTextResponse) ToBifrostResponse() *schemas.Bifrost return nil } - var choices []schemas.BifrostResponseChoice + var choices []schemas.BifrostChatResponseChoice for i, output := range response.Outputs { - choices = append(choices, schemas.BifrostResponseChoice{ + choices = append(choices, schemas.BifrostChatResponseChoice{ Index: i, - BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ - ContentStr: &output.Text, - }, - }, + BifrostTextCompletionResponseChoice: &schemas.BifrostTextCompletionResponseChoice{ + Text: &output.Text, }, FinishReason: &output.StopReason, }) @@ -100,7 +98,8 @@ func (response *BedrockMistralTextResponse) ToBifrostResponse() *schemas.Bifrost return &schemas.BifrostResponse{ Choices: choices, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Bedrock, + RequestType: schemas.TextCompletionRequest, + Provider: schemas.Bedrock, }, } } diff --git a/core/schemas/providers/bedrock/types.go b/core/schemas/providers/bedrock/types.go index a38f92d89c..225494b510 100644 --- a/core/schemas/providers/bedrock/types.go +++ b/core/schemas/providers/bedrock/types.go @@ -37,9 +37,16 @@ type BedrockConverseRequest struct { RequestMetadata *map[string]string `json:"requestMetadata,omitempty"` // Request metadata } +type BedrockMessageRole string + +const ( + BedrockMessageRoleUser BedrockMessageRole = "user" + BedrockMessageRoleAssistant BedrockMessageRole = "assistant" +) + // BedrockMessage represents a message in the conversation type BedrockMessage struct { - Role string `json:"role"` // Required: "user" or "assistant" + Role BedrockMessageRole `json:"role"` 
// Required: "user" or "assistant" Content []BedrockContentBlock `json:"content"` // Required: Array of content blocks } @@ -68,6 +75,9 @@ type BedrockContentBlock struct { // Guard content (for guardrails) GuardContent *BedrockGuardContent `json:"guardContent,omitempty"` + + // For Tool Call Result content + JSON interface{} `json:"json,omitempty"` } // BedrockImageSource represents image content @@ -102,24 +112,9 @@ type BedrockToolUse struct { // BedrockToolResult represents the result of a tool use type BedrockToolResult struct { - ToolUseID string `json:"toolUseId"` // Required: ID of the tool use this result corresponds to - Content []BedrockToolResultContent `json:"content"` // Required: Content of the tool result - Status *string `json:"status,omitempty"` // Optional: Status of tool execution ("success" or "error") -} - -// BedrockToolResultContent represents content within a tool result -type BedrockToolResultContent struct { - // Text content - Text *string `json:"text,omitempty"` - - // Image content - Image *BedrockImageSource `json:"image,omitempty"` - - // Document content - Document *BedrockDocumentSource `json:"document,omitempty"` - - // JSON content - JSON interface{} `json:"json,omitempty"` + ToolUseID string `json:"toolUseId"` // Required: ID of the tool use this result corresponds to + Content []BedrockContentBlock `json:"content"` // Required: Content of the tool result + Status *string `json:"status,omitempty"` // Optional: Status of tool execution ("success" or "error") } // BedrockGuardContent represents guard content for guardrails @@ -420,3 +415,18 @@ type BedrockMetadataEvent struct { Metrics *BedrockConverseMetrics `json:"metrics,omitempty"` // Performance metrics Trace *BedrockConverseTrace `json:"trace,omitempty"` // Trace information } + +// ==================== EMBEDDING TYPES ==================== + +// BedrockTitanEmbeddingRequest represents a Bedrock Titan embedding request +type BedrockTitanEmbeddingRequest struct { + InputText 
string `json:"inputText"` // Required: Text to embed + // Note: Titan models have fixed dimensions and don't support the dimensions parameter + // ExtraParams can be used for any additional model-specific parameters +} + +// BedrockTitanEmbeddingResponse represents a Bedrock Titan embedding response +type BedrockTitanEmbeddingResponse struct { + Embedding []float32 `json:"embedding"` // The embedding vector + InputTextTokenCount int `json:"inputTextTokenCount"` // Number of tokens in input +} diff --git a/core/schemas/providers/bedrock/utils.go b/core/schemas/providers/bedrock/utils.go index c78751da17..1a44557b9c 100644 --- a/core/schemas/providers/bedrock/utils.go +++ b/core/schemas/providers/bedrock/utils.go @@ -10,7 +10,7 @@ import ( ) // convertParameters handles parameter conversion -func convertParameters(bifrostReq *schemas.BifrostRequest, bedrockReq *BedrockConverseRequest) { +func convertChatParameters(bifrostReq *schemas.BifrostChatRequest, bedrockReq *BedrockConverseRequest) { if bifrostReq.Params == nil { return } @@ -101,13 +101,13 @@ func convertParameters(bifrostReq *schemas.BifrostRequest, bedrockReq *BedrockCo } } -// ensureToolConfigForConversation ensures toolConfig is present when tool content exists -func ensureToolConfigForConversation(bifrostReq *schemas.BifrostRequest, bedrockReq *BedrockConverseRequest) { +// ensureChatToolConfigForConversation ensures toolConfig is present when tool content exists +func ensureChatToolConfigForConversation(bifrostReq *schemas.BifrostChatRequest, bedrockReq *BedrockConverseRequest) { if bedrockReq.ToolConfig != nil { return // Already has tool config } - hasToolContent, tools := extractToolsFromConversationHistory(*bifrostReq.Input.ChatCompletionInput) + hasToolContent, tools := extractToolsFromConversationHistory(bifrostReq.Input) if hasToolContent && len(tools) > 0 { bedrockReq.ToolConfig = &BedrockToolConfig{Tools: &tools} } @@ -115,13 +115,13 @@ func ensureToolConfigForConversation(bifrostReq 
*schemas.BifrostRequest, bedrock // convertMessages converts Bifrost messages to Bedrock format // Returns regular messages and system messages separately -func convertMessages(bifrostMessages []schemas.BifrostMessage) ([]BedrockMessage, []BedrockSystemMessage, error) { +func convertMessages(bifrostMessages []schemas.ChatMessage) ([]BedrockMessage, []BedrockSystemMessage, error) { var messages []BedrockMessage var systemMessages []BedrockSystemMessage for _, msg := range bifrostMessages { switch msg.Role { - case schemas.ModelChatMessageRoleSystem: + case schemas.ChatMessageRoleSystem: // Convert system message systemMsg, err := convertSystemMessage(msg) if err != nil { @@ -129,7 +129,7 @@ func convertMessages(bifrostMessages []schemas.BifrostMessage) ([]BedrockMessage } systemMessages = append(systemMessages, systemMsg) - case schemas.ModelChatMessageRoleUser, schemas.ModelChatMessageRoleAssistant: + case schemas.ChatMessageRoleUser, schemas.ChatMessageRoleAssistant: // Convert regular message bedrockMsg, err := convertMessage(msg) if err != nil { @@ -137,7 +137,7 @@ func convertMessages(bifrostMessages []schemas.BifrostMessage) ([]BedrockMessage } messages = append(messages, bedrockMsg) - case schemas.ModelChatMessageRoleTool: + case schemas.ChatMessageRoleTool: // Convert tool message - this should be part of the conversation bedrockMsg, err := convertToolMessage(msg) if err != nil { @@ -154,7 +154,7 @@ func convertMessages(bifrostMessages []schemas.BifrostMessage) ([]BedrockMessage } // convertSystemMessage converts a Bifrost system message to Bedrock format -func convertSystemMessage(msg schemas.BifrostMessage) (BedrockSystemMessage, error) { +func convertSystemMessage(msg schemas.ChatMessage) (BedrockSystemMessage, error) { systemMsg := BedrockSystemMessage{} // Convert content @@ -165,7 +165,7 @@ func convertSystemMessage(msg schemas.BifrostMessage) (BedrockSystemMessage, err // Combine all text blocks into a single string var textParts []string for _, block 
:= range *msg.Content.ContentBlocks { - if block.Type == schemas.ContentBlockTypeText && block.Text != nil { + if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil { textParts = append(textParts, *block.Text) } } @@ -179,9 +179,9 @@ func convertSystemMessage(msg schemas.BifrostMessage) (BedrockSystemMessage, err } // convertMessage converts a Bifrost message to Bedrock format -func convertMessage(msg schemas.BifrostMessage) (BedrockMessage, error) { +func convertMessage(msg schemas.ChatMessage) (BedrockMessage, error) { bedrockMsg := BedrockMessage{ - Role: string(msg.Role), + Role: BedrockMessageRole(msg.Role), } // Convert content @@ -191,8 +191,8 @@ func convertMessage(msg schemas.BifrostMessage) (BedrockMessage, error) { } // Add tool calls if present (for assistant messages) - if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { - for _, toolCall := range *msg.AssistantMessage.ToolCalls { + if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil { + for _, toolCall := range *msg.ChatAssistantMessage.ToolCalls { toolUseBlock := convertToolCallToContentBlock(toolCall) contentBlocks = append(contentBlocks, toolUseBlock) } @@ -203,35 +203,49 @@ func convertMessage(msg schemas.BifrostMessage) (BedrockMessage, error) { } // convertToolMessage converts a Bifrost tool message to Bedrock format -func convertToolMessage(msg schemas.BifrostMessage) (BedrockMessage, error) { +func convertToolMessage(msg schemas.ChatMessage) (BedrockMessage, error) { bedrockMsg := BedrockMessage{ Role: "user", // Tool messages are typically treated as user messages in Bedrock } // Tool messages should have a tool_call_id - if msg.ToolMessage == nil || msg.ToolMessage.ToolCallID == nil { + if msg.ChatToolMessage == nil || msg.ChatToolMessage.ToolCallID == nil { return BedrockMessage{}, fmt.Errorf("tool message missing tool_call_id") } // Convert content to tool result - var toolResultContent []BedrockToolResultContent + var 
toolResultContent []BedrockContentBlock if msg.Content.ContentStr != nil { - toolResultContent = append(toolResultContent, BedrockToolResultContent{ - Text: msg.Content.ContentStr, - }) + // Bedrock expects JSON to be a parsed object, not a string + // Try to unmarshal the string content as JSON + var parsedOutput interface{} + if err := json.Unmarshal([]byte(*msg.Content.ContentStr), &parsedOutput); err != nil { + // If it's not valid JSON, wrap it as a text block instead + toolResultContent = append(toolResultContent, BedrockContentBlock{ + Text: msg.Content.ContentStr, + }) + } else { + // Use the parsed JSON object + toolResultContent = append(toolResultContent, BedrockContentBlock{ + JSON: parsedOutput, + }) + } } else if msg.Content.ContentBlocks != nil { for _, block := range *msg.Content.ContentBlocks { switch block.Type { - case schemas.ContentBlockTypeText: + case schemas.ChatContentBlockTypeText: if block.Text != nil { - toolResultContent = append(toolResultContent, BedrockToolResultContent{ + toolResultContent = append(toolResultContent, BedrockContentBlock{ Text: block.Text, }) } - case schemas.ContentBlockTypeImage: - if block.ImageURL != nil { - imageSource := convertImageToBedrockSource(*block.ImageURL) - toolResultContent = append(toolResultContent, BedrockToolResultContent{ + case schemas.ChatContentBlockTypeImage: + if block.ImageURLStruct != nil { + imageSource, err := convertImageToBedrockSource(block.ImageURLStruct.URL) + if err != nil { + return BedrockMessage{}, fmt.Errorf("failed to convert image in tool result: %w", err) + } + toolResultContent = append(toolResultContent, BedrockContentBlock{ Image: imageSource, }) } @@ -242,7 +256,7 @@ func convertToolMessage(msg schemas.BifrostMessage) (BedrockMessage, error) { // Create tool result content block toolResultBlock := BedrockContentBlock{ ToolResult: &BedrockToolResult{ - ToolUseID: *msg.ToolMessage.ToolCallID, + ToolUseID: *msg.ChatToolMessage.ToolCallID, Content: toolResultContent, 
Status: schemas.Ptr("success"), // Default to success }, @@ -253,7 +267,7 @@ func convertToolMessage(msg schemas.BifrostMessage) (BedrockMessage, error) { } // convertContent converts Bifrost message content to Bedrock content blocks -func convertContent(content schemas.MessageContent) ([]BedrockContentBlock, error) { +func convertContent(content schemas.ChatMessageContent) ([]BedrockContentBlock, error) { var contentBlocks []BedrockContentBlock if content.ContentStr != nil { @@ -276,24 +290,27 @@ func convertContent(content schemas.MessageContent) ([]BedrockContentBlock, erro } // convertContentBlock converts a Bifrost content block to Bedrock format -func convertContentBlock(block schemas.ContentBlock) (BedrockContentBlock, error) { +func convertContentBlock(block schemas.ChatContentBlock) (BedrockContentBlock, error) { switch block.Type { - case schemas.ContentBlockTypeText: + case schemas.ChatContentBlockTypeText: return BedrockContentBlock{ Text: block.Text, }, nil - case schemas.ContentBlockTypeImage: - if block.ImageURL == nil { + case schemas.ChatContentBlockTypeImage: + if block.ImageURLStruct == nil { return BedrockContentBlock{}, fmt.Errorf("image_url block missing image_url field") } - imageSource := convertImageToBedrockSource(*block.ImageURL) + imageSource, err := convertImageToBedrockSource(block.ImageURLStruct.URL) + if err != nil { + return BedrockContentBlock{}, fmt.Errorf("failed to convert image: %w", err) + } return BedrockContentBlock{ Image: imageSource, }, nil - case schemas.ContentBlockTypeInputAudio: + case schemas.ChatContentBlockTypeInputAudio: // Bedrock doesn't support audio input in Converse API return BedrockContentBlock{}, fmt.Errorf("audio input not supported in Bedrock Converse API") @@ -304,11 +321,20 @@ func convertContentBlock(block schemas.ContentBlock) (BedrockContentBlock, error // convertImageToBedrockSource converts a Bifrost image URL to Bedrock image source // Uses centralized utility functions like Anthropic converter 
-func convertImageToBedrockSource(imageURL schemas.ImageURLStruct) *BedrockImageSource { +// Returns an error for URL-based images (non-base64) since Bedrock requires base64 data +func convertImageToBedrockSource(imageURL string) (*BedrockImageSource, error) { // Use centralized utility functions from schemas package - sanitizedURL, _ := schemas.SanitizeImageURL(imageURL.URL) + sanitizedURL, err := schemas.SanitizeImageURL(imageURL) + if err != nil { + return nil, fmt.Errorf("failed to sanitize image URL: %w", err) + } urlTypeInfo := schemas.ExtractURLTypeInfo(sanitizedURL) + // Check if this is a URL-based image (not base64/data URI) + if urlTypeInfo.Type != schemas.ImageContentTypeBase64 || urlTypeInfo.DataURLWithoutPrefix == nil { + return nil, fmt.Errorf("only base64-encoded images (data URI format) are supported; remote image URLs are not allowed") + } + // Determine format from media type or default to jpeg format := "jpeg" if urlTypeInfo.MediaType != nil { @@ -326,31 +352,19 @@ func convertImageToBedrockSource(imageURL schemas.ImageURLStruct) *BedrockImageS imageSource := &BedrockImageSource{ Format: format, - } - - // Set source data based on type - if urlTypeInfo.Type == schemas.ImageContentTypeBase64 && urlTypeInfo.DataURLWithoutPrefix != nil { - // Base64 data - imageSource.Source = BedrockImageSourceData{ + Source: BedrockImageSourceData{ Bytes: urlTypeInfo.DataURLWithoutPrefix, - } - } else { - // For URLs, Bedrock requires base64 - this would need additional handling - // For now, we'll use empty bytes (this may cause errors but is consistent with old behavior) - emptyBytes := "" - imageSource.Source = BedrockImageSourceData{ - Bytes: &emptyBytes, - } + }, } - return imageSource + return imageSource, nil } // convertInferenceConfig converts Bifrost parameters to Bedrock inference config -func convertInferenceConfig(params *schemas.ModelParameters) *BedrockInferenceConfig { +func convertInferenceConfig(params *schemas.ChatParameters) 
*BedrockInferenceConfig { var config BedrockInferenceConfig - if params.MaxTokens != nil { - config.MaxTokens = params.MaxTokens + if params.MaxCompletionTokens != nil { + config.MaxTokens = params.MaxCompletionTokens } if params.Temperature != nil { @@ -361,21 +375,21 @@ func convertInferenceConfig(params *schemas.ModelParameters) *BedrockInferenceCo config.TopP = params.TopP } - if params.StopSequences != nil { - config.StopSequences = params.StopSequences + if params.Stop != nil { + config.StopSequences = params.Stop } return &config } // convertToolConfig converts Bifrost tools to Bedrock tool config -func convertToolConfig(params *schemas.ModelParameters) *BedrockToolConfig { - if params.Tools == nil || len(*params.Tools) == 0 { +func convertToolConfig(params *schemas.ChatParameters) *BedrockToolConfig { + if len(params.Tools) == 0 { return nil } var bedrockTools []BedrockTool - for _, tool := range *params.Tools { + for _, tool := range params.Tools { // Create the complete schema object that Bedrock expects var schemaObject interface{} if tool.Function.Parameters != nil { @@ -405,9 +419,9 @@ func convertToolConfig(params *schemas.ModelParameters) *BedrockToolConfig { bedrockTool := BedrockTool{ ToolSpec: &BedrockToolSpec{ Name: tool.Function.Name, - Description: &tool.Function.Description, + Description: schemas.Ptr(description), InputSchema: BedrockToolInputSchema{ - JSON: convertFunctionParameters(tool.Function.Parameters), + JSON: schemaObject, }, }, } @@ -429,64 +443,20 @@ func convertToolConfig(params *schemas.ModelParameters) *BedrockToolConfig { return toolConfig } -// convertFunctionParameters converts Bifrost function parameters to Bedrock input schema -func convertFunctionParameters(params schemas.FunctionParameters) map[string]interface{} { - schema := map[string]interface{}{ - "type": params.Type, - } - - if params.Description != nil { - schema["description"] = *params.Description - } - - if params.Properties != nil { - schema["properties"] = 
params.Properties - } - - if len(params.Required) > 0 { - schema["required"] = params.Required - } - - return schema -} - // convertToolChoice converts Bifrost tool choice to Bedrock format -func convertToolChoice(toolChoice schemas.ToolChoice) *BedrockToolChoice { +func convertToolChoice(toolChoice schemas.ChatToolChoice) *BedrockToolChoice { // Check if it's a string choice - if toolChoice.ToolChoiceStr != nil { - switch schemas.ToolChoiceType(*toolChoice.ToolChoiceStr) { - case schemas.ToolChoiceTypeAuto: + if toolChoice.ChatToolChoiceStr != nil { + switch schemas.ChatToolChoiceType(*toolChoice.ChatToolChoiceStr) { + case schemas.ChatToolChoiceTypeFunction: return &BedrockToolChoice{ Auto: &BedrockToolChoiceAuto{}, } - case schemas.ToolChoiceTypeAny, schemas.ToolChoiceTypeRequired: + case schemas.ChatToolChoiceTypeAny, schemas.ChatToolChoiceTypeRequired: return &BedrockToolChoice{ Any: &BedrockToolChoiceAny{}, } - case schemas.ToolChoiceTypeNone: - // Bedrock doesn't have explicit "none" - just don't include tools - return nil - } - } - - // Check if it's a struct choice - if toolChoice.ToolChoiceStruct != nil { - switch toolChoice.ToolChoiceStruct.Type { - case schemas.ToolChoiceTypeAuto: - return &BedrockToolChoice{ - Auto: &BedrockToolChoiceAuto{}, - } - case schemas.ToolChoiceTypeAny, schemas.ToolChoiceTypeRequired: - return &BedrockToolChoice{ - Any: &BedrockToolChoiceAny{}, - } - case schemas.ToolChoiceTypeFunction: - return &BedrockToolChoice{ - Tool: &BedrockToolChoiceTool{ - Name: toolChoice.ToolChoiceStruct.Function.Name, - }, - } - case schemas.ToolChoiceTypeNone: + case schemas.ChatToolChoiceTypeNone: // Bedrock doesn't have explicit "none" - just don't include tools return nil } @@ -496,7 +466,7 @@ func convertToolChoice(toolChoice schemas.ToolChoice) *BedrockToolChoice { } // extractToolsFromConversationHistory analyzes conversation history for tool content -func extractToolsFromConversationHistory(messages []schemas.BifrostMessage) (bool, 
[]BedrockTool) { +func extractToolsFromConversationHistory(messages []schemas.ChatMessage) (bool, []BedrockTool) { hasToolContent := false toolsMap := make(map[string]BedrockTool) @@ -513,13 +483,13 @@ func extractToolsFromConversationHistory(messages []schemas.BifrostMessage) (boo } // checkMessageForToolContent checks a single message for tool content and updates the tools map -func checkMessageForToolContent(msg schemas.BifrostMessage, toolsMap map[string]BedrockTool) bool { +func checkMessageForToolContent(msg schemas.ChatMessage, toolsMap map[string]BedrockTool) bool { hasContent := false // Check assistant tool calls - if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil { hasContent = true - for _, toolCall := range *msg.AssistantMessage.ToolCalls { + for _, toolCall := range *msg.ChatAssistantMessage.ToolCalls { if toolCall.Function.Name != nil { if _, exists := toolsMap[*toolCall.Function.Name]; !exists { // Create a complete schema object for extracted tools @@ -529,10 +499,10 @@ func checkMessageForToolContent(msg schemas.BifrostMessage, toolsMap map[string] } toolsMap[*toolCall.Function.Name] = BedrockTool{ - ToolSpec: &BedrockToolSpec{ - Name: *toolCall.Function.Name, - Description: schemas.Ptr("Tool extracted from conversation history"), - InputSchema: BedrockToolInputSchema{ + ToolSpec: &BedrockToolSpec{ + Name: *toolCall.Function.Name, + Description: schemas.Ptr("Tool extracted from conversation history"), + InputSchema: BedrockToolInputSchema{ JSON: schemaObject, }, }, @@ -543,7 +513,7 @@ func checkMessageForToolContent(msg schemas.BifrostMessage, toolsMap map[string] } // Check tool messages - if msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { + if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil { hasContent = true } @@ -560,7 +530,7 @@ func checkMessageForToolContent(msg schemas.BifrostMessage, toolsMap 
map[string] } // convertToolCallToContentBlock converts a Bifrost tool call to a Bedrock content block -func convertToolCallToContentBlock(toolCall schemas.ToolCall) BedrockContentBlock { +func convertToolCallToContentBlock(toolCall schemas.ChatAssistantMessageToolCall) BedrockContentBlock { toolUseID := "" if toolCall.ID != nil { toolUseID = *toolCall.ID diff --git a/core/schemas/providers/cohere/chat.go b/core/schemas/providers/cohere/chat.go index 022d686860..88b6e691c3 100644 --- a/core/schemas/providers/cohere/chat.go +++ b/core/schemas/providers/cohere/chat.go @@ -1,14 +1,16 @@ package cohere -import "github.com/maximhq/bifrost/core/schemas" +import ( + "github.com/maximhq/bifrost/core/schemas" +) // ConvertChatRequestToCohere converts a Bifrost request to Cohere v2 format -func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *CohereChatRequest { - if bifrostReq == nil || bifrostReq.Input.ChatCompletionInput == nil { +func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostChatRequest) *CohereChatRequest { + if bifrostReq == nil || bifrostReq.Input == nil { return nil } - messages := *bifrostReq.Input.ChatCompletionInput + messages := bifrostReq.Input cohereReq := &CohereChatRequest{ Model: bifrostReq.Model, } @@ -28,14 +30,14 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *CohereCh for _, block := range *msg.Content.ContentBlocks { if block.Text != nil { contentBlocks = append(contentBlocks, CohereContentBlock{ - Type: "text", + Type: CohereContentBlockTypeText, Text: block.Text, }) - } else if block.ImageURL != nil { + } else if block.ImageURLStruct != nil { contentBlocks = append(contentBlocks, CohereContentBlock{ - Type: "image_url", + Type: CohereContentBlockTypeImage, ImageURL: &CohereImageURL{ - URL: block.ImageURL.URL, + URL: block.ImageURLStruct.URL, }, }) } @@ -46,15 +48,29 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *CohereCh } // Convert tool calls for assistant 
messages - if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { + if msg.ChatAssistantMessage != nil && msg.ChatAssistantMessage.ToolCalls != nil { var toolCalls []CohereToolCall - for _, toolCall := range *msg.AssistantMessage.ToolCalls { + for _, toolCall := range *msg.ChatAssistantMessage.ToolCalls { + // Safely extract function name and arguments + var functionName *string + var functionArguments string + + if toolCall.Function.Name != nil { + functionName = toolCall.Function.Name + } else { + // Use empty string if Name is nil + functionName = schemas.Ptr("") + } + + // Arguments is a string, not a pointer, so it's safe to access directly + functionArguments = toolCall.Function.Arguments + cohereToolCall := CohereToolCall{ ID: toolCall.ID, Type: "function", Function: &CohereFunction{ - Name: toolCall.Function.Name, - Arguments: toolCall.Function.Arguments, + Name: functionName, + Arguments: functionArguments, }, } toolCalls = append(toolCalls, cohereToolCall) @@ -63,8 +79,8 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *CohereCh } // Convert tool messages - if msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { - cohereMsg.ToolCallID = msg.ToolMessage.ToolCallID + if msg.ChatToolMessage != nil && msg.ChatToolMessage.ToolCallID != nil { + cohereMsg.ToolCallID = msg.ChatToolMessage.ToolCallID } cohereMessages = append(cohereMessages, cohereMsg) @@ -74,30 +90,26 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *CohereCh // Convert parameters if bifrostReq.Params != nil { - cohereReq.MaxTokens = bifrostReq.Params.MaxTokens + cohereReq.MaxTokens = bifrostReq.Params.MaxCompletionTokens cohereReq.Temperature = bifrostReq.Params.Temperature cohereReq.P = bifrostReq.Params.TopP - cohereReq.K = bifrostReq.Params.TopK - cohereReq.StopSequences = bifrostReq.Params.StopSequences + cohereReq.StopSequences = bifrostReq.Params.Stop cohereReq.FrequencyPenalty = 
bifrostReq.Params.FrequencyPenalty cohereReq.PresencePenalty = bifrostReq.Params.PresencePenalty // Convert extra params if bifrostReq.Params.ExtraParams != nil { // Handle thinking parameter - if thinkingParam, ok := bifrostReq.Params.ExtraParams["thinking"]; ok { + if thinkingParam, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "thinking"); ok { if thinkingMap, ok := thinkingParam.(map[string]interface{}); ok { thinking := &CohereThinking{} - if typeStr, ok := thinkingMap["type"].(string); ok { + if typeStr, ok := schemas.SafeExtractString(thinkingMap["type"]); ok { thinking.Type = CohereThinkingType(typeStr) } - if tokenBudget, ok := thinkingMap["token_budget"].(int); ok { - thinking.TokenBudget = &tokenBudget - } else if tokenBudgetFloat, ok := thinkingMap["token_budget"].(float64); ok { - tokenBudgetInt := int(tokenBudgetFloat) - thinking.TokenBudget = &tokenBudgetInt + if tokenBudget, ok := schemas.SafeExtractIntPointer(thinkingMap["token_budget"]); ok { + thinking.TokenBudget = tokenBudget } cohereReq.Thinking = thinking @@ -105,42 +117,59 @@ func ToCohereChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *CohereCh } // Handle other Cohere-specific extra params - if safetyMode, ok := bifrostReq.Params.ExtraParams["safety_mode"].(string); ok { - cohereReq.SafetyMode = &safetyMode + if safetyMode, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["safety_mode"]); ok { + cohereReq.SafetyMode = safetyMode } - if logProbs, ok := bifrostReq.Params.ExtraParams["log_probs"].(bool); ok { - cohereReq.LogProbs = &logProbs + if logProbs, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["log_probs"]); ok { + cohereReq.LogProbs = logProbs } - if strictToolChoice, ok := bifrostReq.Params.ExtraParams["strict_tool_choice"].(bool); ok { - cohereReq.StrictToolChoice = &strictToolChoice + if strictToolChoice, ok := schemas.SafeExtractBoolPointer(bifrostReq.Params.ExtraParams["strict_tool_choice"]); ok { + 
cohereReq.StrictToolChoice = strictToolChoice } } - // Convert tools - direct assignment since formats are identical + // Convert tools to Cohere-specific format (without "strict" field) if bifrostReq.Params.Tools != nil { - cohereReq.Tools = bifrostReq.Params.Tools + cohereTools := make([]CohereChatRequestTool, len(bifrostReq.Params.Tools)) + for i, tool := range bifrostReq.Params.Tools { + cohereTools[i] = CohereChatRequestTool{ + Type: string(tool.Type), + } + if tool.Function != nil { + cohereTools[i].Function = CohereChatRequestFunction{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: tool.Function.Parameters, // Convert to map + // Note: No "strict" field - Cohere doesn't support it + } + } + } + cohereReq.Tools = &cohereTools } // Convert tool choice if bifrostReq.Params.ToolChoice != nil { - if bifrostReq.Params.ToolChoice.ToolChoiceStr != nil { - toolChoice := CohereToolChoice(*bifrostReq.Params.ToolChoice.ToolChoiceStr) - cohereReq.ToolChoice = &toolChoice - } else if bifrostReq.Params.ToolChoice.ToolChoiceStruct != nil { - switch bifrostReq.Params.ToolChoice.ToolChoiceStruct.Type { - case schemas.ToolChoiceTypeFunction: - toolChoice := CohereToolChoice("REQUIRED") - cohereReq.ToolChoice = &toolChoice - case schemas.ToolChoiceTypeNone: - toolChoice := CohereToolChoice("NONE") + toolChoice := bifrostReq.Params.ToolChoice + + if toolChoice.ChatToolChoiceStr != nil { + switch schemas.ChatToolChoiceType(*toolChoice.ChatToolChoiceStr) { + case schemas.ChatToolChoiceTypeNone: + toolChoice := ToolChoiceNone cohereReq.ToolChoice = &toolChoice default: - toolChoiceStr := string(bifrostReq.Params.ToolChoice.ToolChoiceStruct.Type) - toolChoice := CohereToolChoice(toolChoiceStr) + toolChoice := ToolChoiceRequired cohereReq.ToolChoice = &toolChoice } + } else if toolChoice.ChatToolChoiceStruct != nil { + switch toolChoice.ChatToolChoiceStruct.Type { + case schemas.ChatToolChoiceTypeFunction: + toolChoice := ToolChoiceRequired + 
cohereReq.ToolChoice = &toolChoice + default: + cohereReq.ToolChoice = nil + } } } } @@ -157,18 +186,19 @@ func (cohereResp *CohereChatResponse) ToBifrostResponse() *schemas.BifrostRespon bifrostResponse := &schemas.BifrostResponse{ ID: cohereResp.ID, Object: "chat.completion", - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, + Message: schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, }, }, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Cohere, + RequestType: schemas.ChatCompletionRequest, + Provider: schemas.Cohere, }, } @@ -177,30 +207,30 @@ func (cohereResp *CohereChatResponse) ToBifrostResponse() *schemas.BifrostRespon if cohereResp.Message.Content != nil { if cohereResp.Message.Content.IsString() { content := cohereResp.Message.Content.GetString() - bifrostResponse.Choices[0].BifrostNonStreamResponseChoice.Message.Content = schemas.MessageContent{ + bifrostResponse.Choices[0].BifrostNonStreamResponseChoice.Message.Content = schemas.ChatMessageContent{ ContentStr: content, } } else if cohereResp.Message.Content.IsBlocks() { blocks := cohereResp.Message.Content.GetBlocks() if blocks != nil { - var contentBlocks []schemas.ContentBlock + var contentBlocks []schemas.ChatContentBlock for _, block := range *blocks { - if block.Type == "text" && block.Text != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: "text", + if block.Type == CohereContentBlockTypeText && block.Text != nil { + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: block.Text, }) - } else if block.Type == "image_url" && block.ImageURL != nil { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: "image_url", - ImageURL: &schemas.ImageURLStruct{ + } 
else if block.Type == CohereContentBlockTypeImage && block.ImageURL != nil { + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ URL: block.ImageURL.URL, }, }) } } if len(contentBlocks) > 0 { - bifrostResponse.Choices[0].BifrostNonStreamResponseChoice.Message.Content = schemas.MessageContent{ + bifrostResponse.Choices[0].BifrostNonStreamResponseChoice.Message.Content = schemas.ChatMessageContent{ ContentBlocks: &contentBlocks, } } @@ -210,18 +240,38 @@ func (cohereResp *CohereChatResponse) ToBifrostResponse() *schemas.BifrostRespon // Convert tool calls if cohereResp.Message.ToolCalls != nil { - var toolCalls []schemas.ToolCall + var toolCalls []schemas.ChatAssistantMessageToolCall for _, toolCall := range *cohereResp.Message.ToolCalls { - bifrostToolCall := schemas.ToolCall{ + // Check if Function is nil to avoid nil pointer dereference + if toolCall.Function == nil { + // Skip this tool call if Function is nil + continue + } + + // Safely extract function name and arguments + var functionName *string + var functionArguments string + + if toolCall.Function.Name != nil { + functionName = toolCall.Function.Name + } else { + // Use empty string if Name is nil + functionName = schemas.Ptr("") + } + + // Arguments is a string, not a pointer, so it's safe to access directly + functionArguments = toolCall.Function.Arguments + + bifrostToolCall := schemas.ChatAssistantMessageToolCall{ ID: toolCall.ID, - Function: schemas.FunctionCall{ - Name: toolCall.Function.Name, - Arguments: toolCall.Function.Arguments, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: functionName, + Arguments: functionArguments, }, } toolCalls = append(toolCalls, bifrostToolCall) } - bifrostResponse.Choices[0].BifrostNonStreamResponseChoice.Message.AssistantMessage = &schemas.AssistantMessage{ + bifrostResponse.Choices[0].BifrostNonStreamResponseChoice.Message.ChatAssistantMessage = 
&schemas.ChatAssistantMessage{ ToolCalls: &toolCalls, } } diff --git a/core/schemas/providers/cohere/embedding.go b/core/schemas/providers/cohere/embedding.go index 0796e3fe94..a2d0c5b1d0 100644 --- a/core/schemas/providers/cohere/embedding.go +++ b/core/schemas/providers/cohere/embedding.go @@ -3,19 +3,26 @@ package cohere import "github.com/maximhq/bifrost/core/schemas" // ToCohereEmbeddingRequest converts a Bifrost embedding request to Cohere format -func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *CohereEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input.EmbeddingInput == nil { +func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *CohereEmbeddingRequest { + if bifrostReq == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) { return nil } - embeddingInput := bifrostReq.Input.EmbeddingInput + embeddingInput := bifrostReq.Input cohereReq := &CohereEmbeddingRequest{ Model: bifrostReq.Model, } + texts := []string{} + if embeddingInput.Text != nil { + texts = append(texts, *embeddingInput.Text) + } else { + texts = embeddingInput.Texts + } + // Convert texts from Bifrost format - if len(embeddingInput.Texts) > 0 { - cohereReq.Texts = &embeddingInput.Texts + if len(texts) > 0 { + cohereReq.Texts = &texts } // Set default input type if not specified in extra params @@ -23,32 +30,31 @@ func ToCohereEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *CohereEmbeddi if bifrostReq.Params != nil { cohereReq.OutputDimension = bifrostReq.Params.Dimensions - cohereReq.MaxTokens = bifrostReq.Params.MaxTokens + + if bifrostReq.Params.ExtraParams != nil { + if maxTokens, ok := schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["max_tokens"]); ok { + cohereReq.MaxTokens = maxTokens + } + } } // Handle extra params if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil { // Input type - if inputType, ok := bifrostReq.Params.ExtraParams["input_type"].(string); ok { + if inputType, ok := 
schemas.SafeExtractString(bifrostReq.Params.ExtraParams["input_type"]); ok { cohereReq.InputType = inputType } // Embedding types - if embeddingTypes, ok := bifrostReq.Params.ExtraParams["embedding_types"].([]interface{}); ok { - var types []string - for _, t := range embeddingTypes { - if typeStr, ok := t.(string); ok { - types = append(types, typeStr) - } - } - if len(types) > 0 { - cohereReq.EmbeddingTypes = &types + if embeddingTypes, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["embedding_types"]); ok { + if len(embeddingTypes) > 0 { + cohereReq.EmbeddingTypes = &embeddingTypes } } // Truncate - if truncate, ok := bifrostReq.Params.ExtraParams["truncate"].(string); ok { - cohereReq.Truncate = &truncate + if truncate, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["truncate"]); ok { + cohereReq.Truncate = truncate } } diff --git a/core/schemas/providers/cohere/responses.go b/core/schemas/providers/cohere/responses.go new file mode 100644 index 0000000000..dc0d91c1db --- /dev/null +++ b/core/schemas/providers/cohere/responses.go @@ -0,0 +1,424 @@ +package cohere + +import ( + "strings" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// ToCohereResponsesRequest converts a BifrostRequest (Responses structure) to CohereChatRequest +func ToCohereResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *CohereChatRequest { + if bifrostReq == nil { + return nil + } + + cohereReq := &CohereChatRequest{ + Model: bifrostReq.Model, + } + + // Map basic parameters + if bifrostReq.Params != nil { + if bifrostReq.Params.MaxOutputTokens != nil { + cohereReq.MaxTokens = bifrostReq.Params.MaxOutputTokens + } + if bifrostReq.Params.Temperature != nil { + cohereReq.Temperature = bifrostReq.Params.Temperature + } + if bifrostReq.Params.TopP != nil { + cohereReq.P = bifrostReq.Params.TopP + } + if bifrostReq.Params.ExtraParams != nil { + if topK, ok := 
schemas.SafeExtractIntPointer(bifrostReq.Params.ExtraParams["top_k"]); ok { + cohereReq.K = topK + } + if stop, ok := schemas.SafeExtractStringSlicePointer(bifrostReq.Params.ExtraParams["stop"]); ok { + cohereReq.StopSequences = stop + } + if frequencyPenalty, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["frequency_penalty"]); ok { + cohereReq.FrequencyPenalty = frequencyPenalty + } + if presencePenalty, ok := schemas.SafeExtractFloat64Pointer(bifrostReq.Params.ExtraParams["presence_penalty"]); ok { + cohereReq.PresencePenalty = presencePenalty + } + } + } + + // Convert tools + if bifrostReq.Params.Tools != nil { + var cohereTools []CohereChatRequestTool + for _, tool := range bifrostReq.Params.Tools { + if tool.ResponsesToolFunction != nil && tool.Name != nil { + cohereTool := CohereChatRequestTool{ + Type: "function", + Function: CohereChatRequestFunction{ + Name: *tool.Name, + Description: tool.Description, + Parameters: tool.ResponsesToolFunction.Parameters, + }, + } + cohereTools = append(cohereTools, cohereTool) + } + } + + if len(cohereTools) > 0 { + cohereReq.Tools = &cohereTools + } + } + + // Convert tool choice + if bifrostReq.Params.ToolChoice != nil { + cohereReq.ToolChoice = convertBifrostToolChoiceToCohereToolChoice(*bifrostReq.Params.ToolChoice) + } + + // Process ResponsesInput (which contains the Responses items) + if bifrostReq.Input != nil { + cohereReq.Messages = convertResponsesMessagesToCohereMessages(bifrostReq.Input) + } + + return cohereReq +} + +// ToResponsesBifrostResponse converts CohereChatResponse to BifrostResponse (Responses structure) +func (cohereResp *CohereChatResponse) ToResponsesBifrostResponse() *schemas.BifrostResponse { + if cohereResp == nil { + return nil + } + + bifrostResp := &schemas.BifrostResponse{ + ID: cohereResp.ID, + Object: "response", + ResponsesResponse: &schemas.ResponsesResponse{ + CreatedAt: int(time.Now().Unix()), // Set current timestamp + }, + } + + // Convert usage information 
+ if cohereResp.Usage != nil { + usage := &schemas.LLMUsage{ + ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{}, + } + + if cohereResp.Usage.Tokens != nil { + if cohereResp.Usage.Tokens.InputTokens != nil { + usage.PromptTokens = int(*cohereResp.Usage.Tokens.InputTokens) + } + if cohereResp.Usage.Tokens.OutputTokens != nil { + usage.CompletionTokens = int(*cohereResp.Usage.Tokens.OutputTokens) + } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + + bifrostResp.Usage = usage + } + + // Convert output message to Responses format + if cohereResp.Message != nil { + outputMessages := convertCohereMessageToResponsesOutput(*cohereResp.Message) + bifrostResp.ResponsesResponse.Output = outputMessages + } + + return bifrostResp +} + +// Helper functions + +// convertBifrostToolChoiceToCohereToolChoice converts schemas.ResponsesToolChoice to CohereToolChoice +func convertBifrostToolChoiceToCohereToolChoice(toolChoice schemas.ResponsesToolChoice) *CohereToolChoice { + toolChoiceString := toolChoice.ResponsesToolChoiceStr + + if toolChoiceString != nil { + switch *toolChoiceString { + case "none": + choice := ToolChoiceNone + return &choice + case "required", "auto", "function": + choice := ToolChoiceRequired + return &choice + default: + choice := ToolChoiceRequired + return &choice + } + } + + return nil +} + +// convertResponsesMessagesToCohereMessages converts Responses items to Cohere messages +func convertResponsesMessagesToCohereMessages(messages []schemas.ResponsesMessage) []CohereMessage { + var cohereMessages []CohereMessage + var systemContent []string + + for _, msg := range messages { + // Handle nil Type with default + msgType := schemas.ResponsesMessageTypeMessage + if msg.Type != nil { + msgType = *msg.Type + } + + switch msgType { + case schemas.ResponsesMessageTypeMessage: + // Handle nil Role with default + role := "user" + if msg.Role != nil { + role = string(*msg.Role) + } + + if role == "system" { + // Collect system messages
separately for Cohere + if msg.Content != nil { + if msg.Content.ContentStr != nil { + systemContent = append(systemContent, *msg.Content.ContentStr) + } else if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + systemContent = append(systemContent, *block.Text) + } + } + } + } + } else { + cohereMsg := CohereMessage{ + Role: role, + } + + // Convert content - only if Content is not nil + if msg.Content != nil { + if msg.Content.ContentStr != nil { + cohereMsg.Content = NewStringContent(*msg.Content.ContentStr) + } else if msg.Content.ContentBlocks != nil { + contentBlocks := convertResponsesMessageContentBlocksToCohere(*msg.Content.ContentBlocks) + cohereMsg.Content = NewBlocksContent(contentBlocks) + } + } + + cohereMessages = append(cohereMessages, cohereMsg) + } + + case "function_call": + // Handle function calls from Responses + assistantMsg := CohereMessage{ + Role: "assistant", + } + + // Extract function call details + var cohereToolCalls []CohereToolCall + toolCall := CohereToolCall{ + Type: "function", + Function: &CohereFunction{}, + } + + if msg.ID != nil { + toolCall.ID = msg.ID + } + + // Get function details from AssistantMessage + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Arguments != nil { + toolCall.Function.Arguments = *msg.ResponsesToolMessage.Arguments + } + + // Get name from ToolMessage if available + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.Name != nil { + toolCall.Function.Name = msg.ResponsesToolMessage.Name + } + + cohereToolCalls = append(cohereToolCalls, toolCall) + + if len(cohereToolCalls) > 0 { + assistantMsg.ToolCalls = &cohereToolCalls + } + + cohereMessages = append(cohereMessages, assistantMsg) + + case "function_call_output": + // Handle function call outputs + if msg.ResponsesToolMessage != nil && msg.ResponsesToolMessage.CallID != nil { + toolMsg := CohereMessage{ + Role: "tool", + } + + // Extract content from 
ResponsesFunctionToolCallOutput if Content is not set + // This is needed for OpenAI Responses API which uses an "output" field + content := msg.Content + if content == nil && msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput != nil { + content = &schemas.ResponsesMessageContent{} + if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr != nil { + content.ContentStr = msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputStr + } else if msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks != nil { + content.ContentBlocks = msg.ResponsesToolMessage.ResponsesFunctionToolCallOutput.ResponsesFunctionToolCallOutputBlocks + } + } + + // Convert content - only if Content is not nil + if content != nil { + if content.ContentStr != nil { + toolMsg.Content = NewStringContent(*content.ContentStr) + } else if content.ContentBlocks != nil { + contentBlocks := convertResponsesMessageContentBlocksToCohere(*content.ContentBlocks) + toolMsg.Content = NewBlocksContent(contentBlocks) + } + } + + toolMsg.ToolCallID = msg.ResponsesToolMessage.CallID + + cohereMessages = append(cohereMessages, toolMsg) + } + } + } + + // Prepend system messages if any + if len(systemContent) > 0 { + systemMsg := CohereMessage{ + Role: "system", + Content: NewStringContent(strings.Join(systemContent, "\n")), + } + cohereMessages = append([]CohereMessage{systemMsg}, cohereMessages...) 
+ } + + return cohereMessages +} + +// convertResponsesMessageContentBlocksToCohere converts Responses message content blocks to Cohere format +func convertResponsesMessageContentBlocksToCohere(blocks []schemas.ResponsesMessageContentBlock) []CohereContentBlock { + var cohereBlocks []CohereContentBlock + + for _, block := range blocks { + switch block.Type { + case schemas.ResponsesInputMessageContentBlockTypeText: + if block.Text != nil { + cohereBlocks = append(cohereBlocks, CohereContentBlock{ + Type: CohereContentBlockTypeText, + Text: block.Text, + }) + } + case schemas.ResponsesInputMessageContentBlockTypeImage: + if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil && *block.ResponsesInputMessageContentBlockImage.ImageURL != "" { + cohereBlocks = append(cohereBlocks, CohereContentBlock{ + Type: CohereContentBlockTypeImage, + ImageURL: &CohereImageURL{ + URL: *block.ResponsesInputMessageContentBlockImage.ImageURL, + }, + }) + } + case schemas.ResponsesOutputMessageContentTypeReasoning: + if block.Text != nil { + cohereBlocks = append(cohereBlocks, CohereContentBlock{ + Type: CohereContentBlockTypeThinking, + Thinking: block.Text, + }) + } + } + } + + return cohereBlocks +} + +// convertCohereMessageToResponsesOutput converts Cohere message to Responses output format +func convertCohereMessageToResponsesOutput(cohereMsg CohereMessage) []schemas.ResponsesMessage { + var outputMessages []schemas.ResponsesMessage + + // Handle text content first + if cohereMsg.Content != nil { + var content schemas.ResponsesMessageContent + + var contentBlocks []schemas.ResponsesMessageContentBlock + + if cohereMsg.Content.StringContent != nil { + contentBlocks = append(contentBlocks, schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: cohereMsg.Content.StringContent, + }) + } else if cohereMsg.Content.BlocksContent != nil { + // Convert content blocks + for _, block := range
*cohereMsg.Content.BlocksContent { + contentBlocks = append(contentBlocks, convertCohereContentBlockToBifrost(block)) + } + } + content.ContentBlocks = &contentBlocks + + // Create message output + if content.ContentBlocks != nil { + outputMsg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &content, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + } + + outputMessages = append(outputMessages, outputMsg) + } + } + + // Handle tool calls + if cohereMsg.ToolCalls != nil { + for _, toolCall := range *cohereMsg.ToolCalls { + // Check if Function is nil to avoid nil pointer dereference + if toolCall.Function == nil { + // Skip this tool call if Function is nil + continue + } + + // Safely extract function name and arguments + var functionName *string + var functionArguments *string + + if toolCall.Function.Name != nil { + functionName = toolCall.Function.Name + } else { + // Use empty string if Name is nil + functionName = schemas.Ptr("") + } + + // Arguments is a string, not a pointer, so it's safe to access directly + functionArguments = schemas.Ptr(toolCall.Function.Arguments) + + toolCallMsg := schemas.ResponsesMessage{ + ID: toolCall.ID, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Name: functionName, + CallID: toolCall.ID, + Arguments: functionArguments, + }, + } + + outputMessages = append(outputMessages, toolCallMsg) + } + } + + return outputMessages +} + +// convertCohereContentBlockToBifrost converts CohereContentBlock to schemas.ContentBlock for Responses +func convertCohereContentBlockToBifrost(cohereBlock CohereContentBlock) schemas.ResponsesMessageContentBlock { + switch cohereBlock.Type { + case CohereContentBlockTypeText: + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: cohereBlock.Text, + } + case 
CohereContentBlockTypeImage: + // For images, create a text block describing the image + if cohereBlock.ImageURL == nil { + // Skip invalid image blocks without ImageURL + return schemas.ResponsesMessageContentBlock{} + } + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeImage, + ResponsesInputMessageContentBlockImage: &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: &cohereBlock.ImageURL.URL, + }, + } + case CohereContentBlockTypeThinking: + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesOutputMessageContentTypeReasoning, + Text: cohereBlock.Thinking, + } + default: + // Fallback to text block + return schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeText, + Text: schemas.Ptr(string(cohereBlock.Type)), + } + } +} diff --git a/core/schemas/providers/cohere/types.go b/core/schemas/providers/cohere/types.go index 5921ac2191..d0b301c3b9 100644 --- a/core/schemas/providers/cohere/types.go +++ b/core/schemas/providers/cohere/types.go @@ -3,30 +3,39 @@ package cohere import ( "encoding/json" "fmt" - - "github.com/maximhq/bifrost/core/schemas" ) // ==================== REQUEST TYPES ==================== // CohereChatRequest represents a Cohere chat completion request type CohereChatRequest struct { - Model string `json:"model"` // Required: Model to use for chat completion - Messages []CohereMessage `json:"messages"` // Required: Array of message objects - Tools *[]schemas.Tool `json:"tools,omitempty"` // Optional: Tools available for the model - ToolChoice *CohereToolChoice `json:"tool_choice,omitempty"` // Optional: Tool choice configuration - Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature - P *float64 `json:"p,omitempty"` // Optional: Top-p sampling - K *int `json:"k,omitempty"` // Optional: Top-k sampling - MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate - 
StopSequences *[]string `json:"stop_sequences,omitempty"` // Optional: Stop sequences - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty - Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming - SafetyMode *string `json:"safety_mode,omitempty"` // Optional: Safety mode - LogProbs *bool `json:"log_probs,omitempty"` // Optional: Log probabilities - StrictToolChoice *bool `json:"strict_tool_choice,omitempty"` // Optional: Strict tool choice - Thinking *CohereThinking `json:"thinking,omitempty"` // Optional: Reasoning configuration + Model string `json:"model"` // Required: Model to use for chat completion + Messages []CohereMessage `json:"messages"` // Required: Array of message objects + Tools *[]CohereChatRequestTool `json:"tools,omitempty"` // Optional: Tools available for the model + ToolChoice *CohereToolChoice `json:"tool_choice,omitempty"` // Optional: Tool choice configuration + Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature + P *float64 `json:"p,omitempty"` // Optional: Top-p sampling + K *int `json:"k,omitempty"` // Optional: Top-k sampling + MaxTokens *int `json:"max_tokens,omitempty"` // Optional: Maximum tokens to generate + StopSequences *[]string `json:"stop_sequences,omitempty"` // Optional: Stop sequences + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Optional: Frequency penalty + PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Optional: Presence penalty + Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming + SafetyMode *string `json:"safety_mode,omitempty"` // Optional: Safety mode + LogProbs *bool `json:"log_probs,omitempty"` // Optional: Log probabilities + StrictToolChoice *bool `json:"strict_tool_choice,omitempty"` // Optional: Strict tool choice + Thinking *CohereThinking `json:"thinking,omitempty"` // 
Optional: Reasoning configuration +} + +type CohereChatRequestTool struct { + Type string `json:"type"` // always "function" + Function CohereChatRequestFunction `json:"function"` +} + +type CohereChatRequestFunction struct { + Name string `json:"name"` // Function name + Parameters interface{} `json:"parameters,omitempty"` // Function parameters (JSON string) + Description *string `json:"description,omitempty"` // Optional: Function description } // CohereMessage represents a message in Cohere format @@ -111,10 +120,19 @@ func (c *CohereMessageContent) GetBlocks() *[]CohereContentBlock { return c.BlocksContent } +type CohereContentBlockType string + +const ( + CohereContentBlockTypeText CohereContentBlockType = "text" + CohereContentBlockTypeImage CohereContentBlockType = "image_url" + CohereContentBlockTypeThinking CohereContentBlockType = "thinking" + CohereContentBlockTypeDocument CohereContentBlockType = "document" +) + // CohereContentBlock represents a content block in Cohere format // This is a union type that can be text, image_url, thinking, or document type CohereContentBlock struct { - Type string `json:"type"` // Required: Content block type + Type CohereContentBlockType `json:"type"` // Required: Content block type // Text content block Text *string `json:"text,omitempty"` diff --git a/core/schemas/providers/gemini/chat.go b/core/schemas/providers/gemini/chat.go index 45f86b6062..ade4e982a2 100644 --- a/core/schemas/providers/gemini/chat.go +++ b/core/schemas/providers/gemini/chat.go @@ -10,7 +10,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { +func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostChatRequest { provider, model := schemas.ParseModelString(r.Model, schemas.Gemini) if provider == schemas.Vertex && !r.IsEmbedding { @@ -20,67 +20,14 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { } } - // Handle 
embedding requests - if r.IsEmbedding { - // Extract texts from content (embedding requests) or contents (chat completion requests) - var texts []string - - // Check for batch embedding requests first - if len(r.Requests) > 0 { - for _, req := range r.Requests { - if req.Content != nil { - for _, part := range req.Content.Parts { - if part.Text != "" { - texts = append(texts, part.Text) - } - } - } - } - } - - // Fallback to contents (plural) for backward compatibility - if len(texts) == 0 { - for _, content := range r.Contents { - for _, part := range content.Parts { - if part.Text != "" { - texts = append(texts, part.Text) - } - } - } - } - - // Create embedding input - embeddingInput := &schemas.EmbeddingInput{ - Texts: texts, - } - - bifrostReq := &schemas.BifrostRequest{ - Provider: provider, - Model: model, - Input: schemas.RequestInput{ - EmbeddingInput: embeddingInput, - }, - } - - // Convert embedding parameters - params := r.convertEmbeddingParameters() - if params != nil { - bifrostReq.Params = params - } - - return bifrostReq - } - // Handle chat completion requests - bifrostReq := &schemas.BifrostRequest{ + bifrostReq := &schemas.BifrostChatRequest{ Provider: provider, Model: model, - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{}, - }, + Input: []schemas.ChatMessage{}, } - messages := []schemas.BifrostMessage{} + messages := []schemas.ChatMessage{} allGenAiMessages := []Content{} if r.SystemInstruction != nil { @@ -96,8 +43,8 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { } // Handle multiple parts - collect all content and tool calls - var toolCalls []schemas.ToolCall - var contentBlocks []schemas.ContentBlock + var toolCalls []schemas.ChatAssistantMessageToolCall + var contentBlocks []schemas.ChatContentBlock var thoughtStr string // Track thought content for assistant/model for _, part := range content.Parts { @@ -105,28 +52,33 @@ func (r *GeminiGenerationRequest) 
ToBifrostRequest() *schemas.BifrostRequest { case part.Text != "": // Handle thought content specially for assistant messages if part.Thought && - (content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(RoleModel)) { + (content.Role == string(schemas.ChatMessageRoleAssistant) || content.Role == string(RoleModel)) { thoughtStr = thoughtStr + part.Text + "\n" } else { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: &part.Text, }) } case part.FunctionCall != nil: // Only add function calls for assistant messages - if content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(RoleModel) { + if content.Role == string(schemas.ChatMessageRoleAssistant) || content.Role == string(RoleModel) { jsonArgs, err := json.Marshal(part.FunctionCall.Args) if err != nil { jsonArgs = []byte(fmt.Sprintf("%v", part.FunctionCall.Args)) } - id := part.FunctionCall.ID // create local copy name := part.FunctionCall.Name // create local copy - toolCall := schemas.ToolCall{ - ID: schemas.Ptr(id), - Type: schemas.Ptr(string(schemas.ToolChoiceTypeFunction)), - Function: schemas.FunctionCall{ + // Gemini primarily works with function names for correlation + // Use ID if provided, otherwise fallback to name for stable correlation + callID := name + if strings.TrimSpace(part.FunctionCall.ID) != "" { + callID = part.FunctionCall.ID + } + toolCall := schemas.ChatAssistantMessageToolCall{ + ID: schemas.Ptr(callID), + Type: schemas.Ptr(string(schemas.ChatToolChoiceTypeFunction)), + Function: schemas.ChatAssistantMessageToolCallFunction{ Name: &name, Arguments: string(jsonArgs), }, @@ -141,13 +93,28 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { responseContent = []byte(fmt.Sprintf("%v", part.FunctionResponse.Response)) } - toolResponseMsg := 
schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleTool, - Content: schemas.MessageContent{ + // Correlate with the function call: prefer ID if available, otherwise use name + callID := part.FunctionResponse.Name + if strings.TrimSpace(part.FunctionResponse.ID) != "" { + callID = part.FunctionResponse.ID + } else { + // Fallback: correlate with the prior function call by name to reuse its ID + for _, tc := range toolCalls { + if tc.Function.Name != nil && *tc.Function.Name == part.FunctionResponse.Name && + tc.ID != nil && *tc.ID != "" { + callID = *tc.ID + break + } + } + } + + toolResponseMsg := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: schemas.ChatMessageContent{ ContentStr: schemas.Ptr(string(responseContent)), }, - ToolMessage: &schemas.ToolMessage{ - ToolCallID: &part.FunctionResponse.Name, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: &callID, }, } @@ -156,9 +123,9 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { case part.InlineData != nil: // Handle inline images/media - only append if it's actually an image if isImageMimeType(part.InlineData.MIMEType) { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeImage, - ImageURL: &schemas.ImageURLStruct{ + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ URL: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data)), }, }) @@ -167,9 +134,9 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { case part.FileData != nil: // Handle file data - only append if it's actually an image if isImageMimeType(part.FileData.MIMEType) { - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeImage, - ImageURL: &schemas.ImageURLStruct{ + contentBlocks = append(contentBlocks, 
schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ URL: part.FileData.FileURI, }, }) @@ -178,16 +145,16 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { case part.ExecutableCode != nil: // Handle executable code as text content codeText := fmt.Sprintf("```%s\n%s\n```", part.ExecutableCode.Language, part.ExecutableCode.Code) - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: &codeText, }) case part.CodeExecutionResult != nil: // Handle code execution results as text content resultText := fmt.Sprintf("Code execution result (%s):\n%s", part.CodeExecutionResult.Outcome, part.CodeExecutionResult.Output) - contentBlocks = append(contentBlocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks = append(contentBlocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: &resultText, }) } @@ -196,31 +163,28 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { // Only create message if there's actual content, tool calls, or thought content if len(contentBlocks) > 0 || len(toolCalls) > 0 || thoughtStr != "" { // Create main message with content blocks - bifrostMsg := schemas.BifrostMessage{ - Role: func(r string) schemas.ModelChatMessageRole { + bifrostMsg := schemas.ChatMessage{ + Role: func(r string) schemas.ChatMessageRole { if r == string(RoleModel) { // GenAI's internal alias - return schemas.ModelChatMessageRoleAssistant + return schemas.ChatMessageRoleAssistant } - return schemas.ModelChatMessageRole(r) + return schemas.ChatMessageRole(r) }(content.Role), } // Set content only if there are content blocks if len(contentBlocks) > 0 { - bifrostMsg.Content = schemas.MessageContent{ + bifrostMsg.Content = schemas.ChatMessageContent{ ContentBlocks: 
&contentBlocks, } } // Set assistant-specific fields for assistant/model messages - if content.Role == string(schemas.ModelChatMessageRoleAssistant) || content.Role == string(RoleModel) { + if content.Role == string(schemas.ChatMessageRoleAssistant) || content.Role == string(RoleModel) { if len(toolCalls) > 0 || thoughtStr != "" { - bifrostMsg.AssistantMessage = &schemas.AssistantMessage{} + bifrostMsg.ChatAssistantMessage = &schemas.ChatAssistantMessage{} if len(toolCalls) > 0 { - bifrostMsg.AssistantMessage.ToolCalls = &toolCalls - } - if thoughtStr != "" { - bifrostMsg.AssistantMessage.Thought = &thoughtStr + bifrostMsg.ChatAssistantMessage.ToolCalls = &toolCalls } } } @@ -229,10 +193,10 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { } } - bifrostReq.Input.ChatCompletionInput = &messages + bifrostReq.Input = messages // Convert generation config to parameters - if params := r.convertGenerationConfigToParams(); params != nil { + if params := r.convertGenerationConfigToChatParameters(); params != nil { bifrostReq.Params = params } @@ -264,20 +228,21 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { if len(r.Tools) > 0 { ensureExtraParams(bifrostReq) - tools := make([]schemas.Tool, 0, len(r.Tools)) + tools := make([]schemas.ChatTool, 0, len(r.Tools)) for _, tool := range r.Tools { if len(tool.FunctionDeclarations) > 0 { for _, fn := range tool.FunctionDeclarations { - bifrostTool := schemas.Tool{ - Type: "function", - Function: schemas.Function{ + bifrostTool := schemas.ChatTool{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ Name: fn.Name, - Description: fn.Description, + Description: schemas.Ptr(fn.Description), }, } // Convert parameters schema if present if fn.Parameters != nil { - bifrostTool.Function.Parameters = r.convertSchemaToFunctionParameters(fn.Parameters) + params := r.convertSchemaToFunctionParameters(fn.Parameters) + bifrostTool.Function.Parameters = ¶ms 
} tools = append(tools, bifrostTool) } @@ -295,7 +260,7 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { } if len(tools) > 0 { - bifrostReq.Params.Tools = &tools + bifrostReq.Params.Tools = tools } } @@ -308,8 +273,8 @@ func (r *GeminiGenerationRequest) ToBifrostRequest() *schemas.BifrostRequest { return bifrostReq } -// ToGeminiGenerationRequest converts a BifrostRequest to Gemini's generation request format -func ToGeminiGenerationRequest(bifrostReq *schemas.BifrostRequest, responseModalities []string) *GeminiGenerationRequest { +// ToGeminiChatGenerationRequest converts a BifrostChatRequest to Gemini's generation request format for chat completion +func ToGeminiChatGenerationRequest(bifrostReq *schemas.BifrostChatRequest, responseModalities []string) *GeminiGenerationRequest { if bifrostReq == nil { return nil } @@ -324,8 +289,8 @@ func ToGeminiGenerationRequest(bifrostReq *schemas.BifrostRequest, responseModal geminiReq.GenerationConfig = convertParamsToGenerationConfig(bifrostReq.Params, responseModalities) // Handle tool-related parameters - if bifrostReq.Params.Tools != nil && len(*bifrostReq.Params.Tools) > 0 { - geminiReq.Tools = convertBifrostToolsToGemini(*bifrostReq.Params.Tools) + if len(bifrostReq.Params.Tools) > 0 { + geminiReq.Tools = convertBifrostToolsToGemini(bifrostReq.Params.Tools) // Convert tool choice to tool config if bifrostReq.Params.ToolChoice != nil { @@ -336,26 +301,24 @@ func ToGeminiGenerationRequest(bifrostReq *schemas.BifrostRequest, responseModal // Handle extra parameters if bifrostReq.Params.ExtraParams != nil { // Safety settings - if safetySettings, ok := bifrostReq.Params.ExtraParams["safety_settings"]; ok { + if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok { if settings, ok := safetySettings.([]SafetySetting); ok { geminiReq.SafetySettings = settings } } // Cached content - if cachedContent, ok := 
bifrostReq.Params.ExtraParams["cached_content"].(string); ok { + if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok { geminiReq.CachedContent = cachedContent } // Response modalities - if modalities, ok := bifrostReq.Params.ExtraParams["response_modalities"]; ok { - if modalitySlice, ok := modalities.([]string); ok { - geminiReq.ResponseModalities = modalitySlice - } + if modalities, ok := schemas.SafeExtractStringSlice(bifrostReq.Params.ExtraParams["response_modalities"]); ok { + geminiReq.ResponseModalities = modalities } // Labels - if labels, ok := bifrostReq.Params.ExtraParams["labels"]; ok { + if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok { if labelMap, ok := labels.(map[string]string); ok { geminiReq.Labels = labelMap } @@ -363,56 +326,8 @@ func ToGeminiGenerationRequest(bifrostReq *schemas.BifrostRequest, responseModal } } - // Convert input based on type - if bifrostReq.Input.SpeechInput != nil { - // Speech/TTS request - geminiReq.Contents = []CustomContent{ - { - Parts: []*CustomPart{ - { - Text: bifrostReq.Input.SpeechInput.Input, - }, - }, - }, - } - - // Add speech config to generation config - addSpeechConfigToGenerationConfig(&geminiReq.GenerationConfig, bifrostReq.Input.SpeechInput.VoiceConfig) - - } else if bifrostReq.Input.TranscriptionInput != nil { - var prompt string - if bifrostReq.Input.TranscriptionInput.Prompt != nil { - prompt = *bifrostReq.Input.TranscriptionInput.Prompt - } else { - prompt = "Generate a transcript of the speech." 
- } - // Transcription request - parts := []*CustomPart{ - { - Text: prompt, - }, - } - - // Add audio file if present - if len(bifrostReq.Input.TranscriptionInput.File) > 0 { - parts = append(parts, &CustomPart{ - InlineData: &CustomBlob{ - MIMEType: detectAudioMimeType(bifrostReq.Input.TranscriptionInput.File), - Data: bifrostReq.Input.TranscriptionInput.File, - }, - }) - } - - geminiReq.Contents = []CustomContent{ - { - Parts: parts, - }, - } - - } else if bifrostReq.Input.ChatCompletionInput != nil { - // Chat completion request - convert messages to Gemini format - geminiReq.Contents = convertBifrostMessagesToGemini(*bifrostReq.Input.ChatCompletionInput) - } + // Convert chat completion messages to Gemini format + geminiReq.Contents = convertBifrostMessagesToGemini(bifrostReq.Input) return geminiReq } @@ -599,8 +514,8 @@ func ToGeminiGenerationResponse(bifrostResp *schemas.BifrostResponse) interface{ } // Handle tool calls - if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.ToolCalls != nil { - for _, toolCall := range *choice.Message.AssistantMessage.ToolCalls { + if choice.Message.ChatAssistantMessage != nil && choice.Message.ChatAssistantMessage.ToolCalls != nil { + for _, toolCall := range *choice.Message.ChatAssistantMessage.ToolCalls { argsMap := make(map[string]interface{}) if toolCall.Function.Arguments != "" { json.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap) @@ -618,14 +533,6 @@ func ToGeminiGenerationResponse(bifrostResp *schemas.BifrostResponse) interface{ } } - // Handle thinking content if present - if choice.Message.AssistantMessage != nil && choice.Message.AssistantMessage.Thought != nil && *choice.Message.AssistantMessage.Thought != "" { - parts = append(parts, &Part{ - Text: *choice.Message.AssistantMessage.Thought, - Thought: true, - }) - } - if len(parts) > 0 { candidate.Content = &Content{ Parts: parts, diff --git a/core/schemas/providers/gemini/embedding.go 
b/core/schemas/providers/gemini/embedding.go index d0194cc8d6..6b5a6e1a1e 100644 --- a/core/schemas/providers/gemini/embedding.go +++ b/core/schemas/providers/gemini/embedding.go @@ -4,13 +4,13 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -// FromBifrostEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's embedding request format -func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *GeminiEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input.EmbeddingInput == nil { +// ToGeminiEmbeddingRequest converts a BifrostRequest with embedding input to Gemini's embedding request format +func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *GeminiEmbeddingRequest { + if bifrostReq == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) { return nil } - embeddingInput := bifrostReq.Input.EmbeddingInput + embeddingInput := bifrostReq.Input // Get the text to embed var text string @@ -45,11 +45,11 @@ func ToGeminiEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *GeminiEmbeddi // Handle extra parameters if bifrostReq.Params.ExtraParams != nil { - if taskType, ok := bifrostReq.Params.ExtraParams["taskType"].(string); ok { - request.TaskType = &taskType + if taskType, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["taskType"]); ok { + request.TaskType = taskType } - if title, ok := bifrostReq.Params.ExtraParams["title"].(string); ok { - request.Title = &title + if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok { + request.Title = title } } } diff --git a/core/schemas/providers/gemini/responses.go b/core/schemas/providers/gemini/responses.go new file mode 100644 index 0000000000..30ce270ee9 --- /dev/null +++ b/core/schemas/providers/gemini/responses.go @@ -0,0 +1,721 @@ +package gemini + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "strings" + + "github.com/bytedance/sonic" + 
"github.com/maximhq/bifrost/core/schemas" +) + +func ToGeminiResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) (*GeminiGenerationRequest, error) { + if bifrostReq == nil { + return nil, nil + } + + // Create the base Gemini generation request + geminiReq := &GeminiGenerationRequest{ + Model: bifrostReq.Model, + } + + // Convert parameters to generation config + if bifrostReq.Params != nil { + geminiReq.GenerationConfig = convertParamsToGenerationConfigResponses(bifrostReq.Params) + + // Handle tool-related parameters + if len(bifrostReq.Params.Tools) > 0 { + geminiReq.Tools = convertResponsesToolsToGemini(bifrostReq.Params.Tools) + + // Convert tool choice if present + if bifrostReq.Params.ToolChoice != nil { + geminiReq.ToolConfig = convertResponsesToolChoiceToGemini(bifrostReq.Params.ToolChoice) + } + } + } + + // Convert ResponsesInput messages to Gemini contents + if bifrostReq.Input != nil { + contents, systemInstruction, err := convertResponsesMessagesToGeminiContents(bifrostReq.Input) + if err != nil { + return nil, fmt.Errorf("failed to convert messages: %w", err) + } + geminiReq.Contents = contents + + if systemInstruction != nil { + geminiReq.SystemInstruction = systemInstruction + } + } + + return geminiReq, nil +} + +func (response *GenerateContentResponse) ToResponsesBifrostResponse() *schemas.BifrostResponse { + if response == nil { + return nil + } + + // Parse model string to get provider and model + + // Create the BifrostResponse with Responses structure + bifrostResp := &schemas.BifrostResponse{ + ID: response.ResponseID, + Object: "response", + Model: response.ModelVersion, + } + + // Convert usage information + if response.UsageMetadata != nil { + bifrostResp.Usage = &schemas.LLMUsage{ + TotalTokens: int(response.UsageMetadata.TotalTokenCount), + ResponsesExtendedResponseUsage: &schemas.ResponsesExtendedResponseUsage{ + InputTokens: int(response.UsageMetadata.PromptTokenCount), + OutputTokens: 
int(response.UsageMetadata.CandidatesTokenCount), + }, + } + + // Handle cached tokens if present + if response.UsageMetadata.CachedContentTokenCount > 0 { + if bifrostResp.Usage.ResponsesExtendedResponseUsage.InputTokensDetails == nil { + bifrostResp.Usage.ResponsesExtendedResponseUsage.InputTokensDetails = &schemas.ResponsesResponseInputTokens{} + } + bifrostResp.Usage.ResponsesExtendedResponseUsage.InputTokensDetails.CachedTokens = int(response.UsageMetadata.CachedContentTokenCount) + } + } + + // Convert candidates to Responses output messages + if len(response.Candidates) > 0 { + outputMessages := convertGeminiCandidatesToResponsesOutput(response.Candidates) + if len(outputMessages) > 0 { + // Initialize ResponsesResponse if not already allocated + if bifrostResp.ResponsesResponse == nil { + bifrostResp.ResponsesResponse = &schemas.ResponsesResponse{} + } + bifrostResp.ResponsesResponse.Output = outputMessages + } + } + + return bifrostResp +} + +// Helper functions for Responses conversion +// convertGeminiCandidatesToResponsesOutput converts Gemini candidates to Responses output messages +func convertGeminiCandidatesToResponsesOutput(candidates []*Candidate) []schemas.ResponsesMessage { + var messages []schemas.ResponsesMessage + + for _, candidate := range candidates { + if candidate.Content == nil || len(candidate.Content.Parts) == 0 { + continue + } + + for _, part := range candidate.Content.Parts { + // Handle different types of parts + switch { + case part.Text != "": + // Regular text message + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &part.Text, + }, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + } + messages = append(messages, msg) + + case part.Thought: + // Thinking/reasoning message + if part.Text != "" { + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: 
&schemas.ResponsesMessageContent{ + ContentStr: &part.Text, + }, + Type: schemas.Ptr(schemas.ResponsesMessageTypeReasoning), + } + messages = append(messages, msg) + } + + case part.FunctionCall != nil: + // Function call message + // Convert Args to JSON string if it's not already a string + argumentsStr := "" + if part.FunctionCall.Args != nil { + if argsBytes, err := json.Marshal(part.FunctionCall.Args); err == nil { + argumentsStr = string(argsBytes) + } + } + + // Create copies of the values to avoid range loop variable capture + functionCallID := part.FunctionCall.ID + functionCallName := part.FunctionCall.Name + + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{}, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCall), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: &functionCallID, + Name: &functionCallName, + Arguments: &argumentsStr, + }, + } + messages = append(messages, msg) + + case part.FunctionResponse != nil: + // Function response message + output := "" + if part.FunctionResponse.Response != nil { + if outputVal, ok := part.FunctionResponse.Response["output"]; ok { + if outputStr, ok := outputVal.(string); ok { + output = outputStr + } + } + } + + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &output, + }, + Type: schemas.Ptr(schemas.ResponsesMessageTypeFunctionCallOutput), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + CallID: schemas.Ptr(part.FunctionResponse.ID), + }, + } + + // Also set the tool name if present (Gemini associates on name) + if name := strings.TrimSpace(part.FunctionResponse.Name); name != "" { + msg.ResponsesToolMessage.Name = schemas.Ptr(name) + } + + messages = append(messages, msg) + + case part.InlineData != nil: + // Handle inline data (images, audio, etc.) 
+ contentBlocks := []schemas.ResponsesMessageContentBlock{ + { + Type: func() schemas.ResponsesMessageContentBlockType { + if strings.HasPrefix(part.InlineData.MIMEType, "image/") { + return schemas.ResponsesInputMessageContentBlockTypeImage + } else if strings.HasPrefix(part.InlineData.MIMEType, "audio/") { + return schemas.ResponsesInputMessageContentBlockTypeAudio + } + return schemas.ResponsesInputMessageContentBlockTypeText + }(), + ResponsesInputMessageContentBlockImage: func() *schemas.ResponsesInputMessageContentBlockImage { + if strings.HasPrefix(part.InlineData.MIMEType, "image/") { + return &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: schemas.Ptr("data:" + part.InlineData.MIMEType + ";base64," + base64.StdEncoding.EncodeToString(part.InlineData.Data)), + } + } + return nil + }(), + Audio: func() *schemas.ResponsesInputMessageContentBlockAudio { + if strings.HasPrefix(part.InlineData.MIMEType, "audio/") { + // Extract format from MIME type (e.g., "audio/wav" -> "wav") + format := strings.TrimPrefix(part.InlineData.MIMEType, "audio/") + return &schemas.ResponsesInputMessageContentBlockAudio{ + Format: format, + Data: base64.StdEncoding.EncodeToString(part.InlineData.Data), + } + } + return nil + }(), + }, + } + + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: &contentBlocks, + }, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + } + messages = append(messages, msg) + + case part.FileData != nil: + // Handle file data + block := schemas.ResponsesMessageContentBlock{ + Type: schemas.ResponsesInputMessageContentBlockTypeFile, + ResponsesInputMessageContentBlockFile: &schemas.ResponsesInputMessageContentBlockFile{ + FileURL: schemas.Ptr(part.FileData.FileURI), + }, + } + if strings.HasPrefix(part.FileData.MIMEType, "image/") { + block.Type = schemas.ResponsesInputMessageContentBlockTypeImage + 
block.ResponsesInputMessageContentBlockImage = &schemas.ResponsesInputMessageContentBlockImage{ + ImageURL: schemas.Ptr(part.FileData.FileURI), + } + } + contentBlocks := []schemas.ResponsesMessageContentBlock{block} + + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: &contentBlocks, + }, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + } + messages = append(messages, msg) + + case part.CodeExecutionResult != nil: + // Handle code execution results + output := part.CodeExecutionResult.Output + if part.CodeExecutionResult.Outcome != OutcomeOK { + output = "Error: " + output + } + + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &output, + }, + Type: schemas.Ptr(schemas.ResponsesMessageTypeCodeInterpreterCall), + } + messages = append(messages, msg) + + case part.ExecutableCode != nil: + // Handle executable code + codeContent := "```" + part.ExecutableCode.Language + "\n" + part.ExecutableCode.Code + "\n```" + + msg := schemas.ResponsesMessage{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleAssistant), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &codeContent, + }, + Type: schemas.Ptr(schemas.ResponsesMessageTypeMessage), + } + messages = append(messages, msg) + } + } + } + + return messages +} + +// convertParamsToGenerationConfigResponses converts ChatParameters to GenerationConfig for Responses +func convertParamsToGenerationConfigResponses(params *schemas.ResponsesParameters) GenerationConfig { + config := GenerationConfig{} + + if params.Temperature != nil { + config.Temperature = schemas.Ptr(float64(*params.Temperature)) + } + if params.TopP != nil { + config.TopP = schemas.Ptr(float64(*params.TopP)) + } + if params.MaxOutputTokens != nil { + config.MaxOutputTokens = int32(*params.MaxOutputTokens) + } + + if 
params.ExtraParams != nil { + if topK, ok := params.ExtraParams["top_k"]; ok { + if val, success := schemas.SafeExtractInt(topK); success { + config.TopK = schemas.Ptr(val) + } + } + if frequencyPenalty, ok := params.ExtraParams["frequency_penalty"]; ok { + if val, success := schemas.SafeExtractFloat64(frequencyPenalty); success { + config.FrequencyPenalty = schemas.Ptr(val) + } + } + if presencePenalty, ok := params.ExtraParams["presence_penalty"]; ok { + if val, success := schemas.SafeExtractFloat64(presencePenalty); success { + config.PresencePenalty = schemas.Ptr(val) + } + } + if stopSequences, ok := params.ExtraParams["stop_sequences"]; ok { + if val, success := schemas.SafeExtractStringSlice(stopSequences); success { + config.StopSequences = val + } + } + } + + return config +} + +// convertResponsesToolsToGemini converts Responses tools to Gemini tools +func convertResponsesToolsToGemini(tools []schemas.ResponsesTool) []Tool { + var geminiTools []Tool + + for _, tool := range tools { + if tool.Type == "function" { + geminiTool := Tool{} + + // Extract function information from ResponsesExtendedTool + if tool.ResponsesToolFunction != nil { + if tool.Name != nil && tool.ResponsesToolFunction != nil { + funcDecl := &FunctionDeclaration{ + Name: *tool.Name, + Description: func() string { + if tool.Description != nil { + return *tool.Description + } + return "" + }(), + Parameters: func() *Schema { + if tool.ResponsesToolFunction.Parameters != nil { + return convertFunctionParametersToGeminiSchema(*tool.ResponsesToolFunction.Parameters) + } + return nil + }(), + } + geminiTool.FunctionDeclarations = []*FunctionDeclaration{funcDecl} + } + } + + if len(geminiTool.FunctionDeclarations) > 0 { + geminiTools = append(geminiTools, geminiTool) + } + } + } + + return geminiTools +} + +// convertResponsesToolChoiceToGemini converts Responses tool choice to Gemini tool config +func convertResponsesToolChoiceToGemini(toolChoice *schemas.ResponsesToolChoice) ToolConfig { + 
config := ToolConfig{} + + if toolChoice.ResponsesToolChoiceStruct != nil { + funcConfig := &FunctionCallingConfig{} + ext := toolChoice.ResponsesToolChoiceStruct + + if ext.Mode != nil { + switch *ext.Mode { + case "auto": + funcConfig.Mode = FunctionCallingConfigModeAuto + case "required": + funcConfig.Mode = FunctionCallingConfigModeAny + case "none": + funcConfig.Mode = FunctionCallingConfigModeNone + } + } + + if ext.Name != nil { + funcConfig.Mode = FunctionCallingConfigModeAny + funcConfig.AllowedFunctionNames = []string{*ext.Name} + } + + config.FunctionCallingConfig = funcConfig + return config + } + + // Handle string-based tool choice modes + if toolChoice.ResponsesToolChoiceStr != nil { + funcConfig := &FunctionCallingConfig{} + switch *toolChoice.ResponsesToolChoiceStr { + case "none": + funcConfig.Mode = FunctionCallingConfigModeNone + case "required", "any": + funcConfig.Mode = FunctionCallingConfigModeAny + default: // "auto" or any other value + funcConfig.Mode = FunctionCallingConfigModeAuto + } + config.FunctionCallingConfig = funcConfig + } + + return config +} + +// convertFunctionParametersToGeminiSchema converts function parameters to Gemini Schema +func convertFunctionParametersToGeminiSchema(params schemas.ToolFunctionParameters) *Schema { + schema := &Schema{ + Type: Type(params.Type), + } + + if params.Description != nil { + schema.Description = *params.Description + } + + if params.Properties != nil { + schema.Properties = make(map[string]*Schema) + for key, prop := range params.Properties { + propSchema := convertPropertyToGeminiSchema(prop) + schema.Properties[key] = propSchema + } + } + + if params.Required != nil { + schema.Required = params.Required + } + + return schema +} + +// convertPropertyToGeminiSchema converts a property to Gemini Schema +func convertPropertyToGeminiSchema(prop interface{}) *Schema { + schema := &Schema{} + + // Handle property as map[string]interface{} + if propMap, ok := prop.(map[string]interface{}); ok { 
+ if propType, exists := propMap["type"]; exists { + if typeStr, ok := propType.(string); ok { + schema.Type = Type(typeStr) + } + } + + if desc, exists := propMap["description"]; exists { + if descStr, ok := desc.(string); ok { + schema.Description = descStr + } + } + + if enum, exists := propMap["enum"]; exists { + if enumSlice, ok := enum.([]interface{}); ok { + var enumStrs []string + for _, item := range enumSlice { + if str, ok := item.(string); ok { + enumStrs = append(enumStrs, str) + } + } + schema.Enum = enumStrs + } + } + + // Handle nested properties for object types + if props, exists := propMap["properties"]; exists { + if propsMap, ok := props.(map[string]interface{}); ok { + schema.Properties = make(map[string]*Schema) + for key, nestedProp := range propsMap { + schema.Properties[key] = convertPropertyToGeminiSchema(nestedProp) + } + } + } + + // Handle array items + if items, exists := propMap["items"]; exists { + schema.Items = convertPropertyToGeminiSchema(items) + } + } + + return schema +} + +// convertResponsesMessagesToGeminiContents converts Responses messages to Gemini contents +func convertResponsesMessagesToGeminiContents(messages []schemas.ResponsesMessage) ([]CustomContent, *CustomContent, error) { + var contents []CustomContent + var systemInstruction *CustomContent + + for _, msg := range messages { + // Handle system messages separately + if msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { + if systemInstruction == nil { + systemInstruction = &CustomContent{} + } + + // Convert system message content + if msg.Content != nil { + if msg.Content.ContentStr != nil { + systemInstruction.Parts = append(systemInstruction.Parts, &CustomPart{ + Text: *msg.Content.ContentStr, + }) + } + if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + part, err := convertContentBlockToGeminiPart(block) + if err != nil { + return nil, nil, fmt.Errorf("failed to convert system message content 
block: %w", err) + } + if part != nil { + systemInstruction.Parts = append(systemInstruction.Parts, part) + } + } + } + } + + continue + } + + // Handle regular messages + content := CustomContent{} + + if msg.Role != nil { + content.Role = string(*msg.Role) + } else { + content.Role = "user" // Default role if msg.Role is nil + } + + // Convert message content + if msg.Content != nil { + if msg.Content.ContentStr != nil { + content.Parts = append(content.Parts, &CustomPart{ + Text: *msg.Content.ContentStr, + }) + } + + if msg.Content.ContentBlocks != nil { + for _, block := range *msg.Content.ContentBlocks { + part, err := convertContentBlockToGeminiPart(block) + if err != nil { + return nil, nil, fmt.Errorf("failed to convert message content block: %w", err) + } + if part != nil { + content.Parts = append(content.Parts, part) + } + } + } + } + + // Handle tool calls from assistant messages + if msg.ResponsesToolMessage != nil && msg.Type != nil { + switch *msg.Type { + case schemas.ResponsesMessageTypeFunctionCall: + // Convert function call to Gemini FunctionCall + if msg.ResponsesToolMessage.Name != nil { + argsMap := make(map[string]any) + if msg.ResponsesToolMessage.Arguments != nil { + // Parse JSON arguments + if err := sonic.Unmarshal([]byte(*msg.ResponsesToolMessage.Arguments), &argsMap); err == nil { + part := &CustomPart{ + FunctionCall: &FunctionCall{ + Name: *msg.ResponsesToolMessage.Name, + Args: argsMap, + }, + } + if msg.ResponsesToolMessage.CallID != nil { + part.FunctionCall.ID = *msg.ResponsesToolMessage.CallID + } + content.Parts = append(content.Parts, part) + } + } + } + + case schemas.ResponsesMessageTypeFunctionCallOutput: + // Convert function response to Gemini FunctionResponse + if msg.ResponsesToolMessage.CallID != nil { + responseMap := make(map[string]any) + if msg.Content != nil && msg.Content.ContentStr != nil { + responseMap["output"] = *msg.Content.ContentStr + } + + // Prefer the declared tool name; fallback to CallID if the name 
is absent + funcName := "" + if msg.ResponsesToolMessage.Name != nil && strings.TrimSpace(*msg.ResponsesToolMessage.Name) != "" { + funcName = *msg.ResponsesToolMessage.Name + } else { + funcName = *msg.ResponsesToolMessage.CallID + } + + part := &CustomPart{ + FunctionResponse: &FunctionResponse{ + Name: funcName, + Response: responseMap, + }, + } + // Keep ID = CallID + part.FunctionResponse.ID = *msg.ResponsesToolMessage.CallID + content.Parts = append(content.Parts, part) + } + } + } + + if len(content.Parts) > 0 { + contents = append(contents, content) + } + } + + return contents, systemInstruction, nil +} + +// convertContentBlockToGeminiPart converts a content block to Gemini part +func convertContentBlockToGeminiPart(block schemas.ResponsesMessageContentBlock) (*CustomPart, error) { + switch block.Type { + case schemas.ResponsesInputMessageContentBlockTypeText: + if block.Text != nil { + return &CustomPart{ + Text: *block.Text, + }, nil + } + + case schemas.ResponsesInputMessageContentBlockTypeImage: + if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { + imageURL := *block.ResponsesInputMessageContentBlockImage.ImageURL + + // Use existing utility functions to handle URL parsing + sanitizedURL, err := schemas.SanitizeImageURL(imageURL) + if err != nil { + return nil, fmt.Errorf("failed to sanitize image URL: %w", err) + } + + urlInfo := schemas.ExtractURLTypeInfo(sanitizedURL) + mimeType := "image/jpeg" // default + if urlInfo.MediaType != nil { + mimeType = *urlInfo.MediaType + } + + if urlInfo.Type == schemas.ImageContentTypeBase64 { + data := "" + if urlInfo.DataURLWithoutPrefix != nil { + data = *urlInfo.DataURLWithoutPrefix + } + + // Decode base64 data + decodedData, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 image data: %w", err) + } + + return &CustomPart{ + InlineData: &CustomBlob{ + MIMEType: mimeType, + Data: 
decodedData, + }, + }, nil + } else { + return &CustomPart{ + FileData: &FileData{ + MIMEType: mimeType, + FileURI: sanitizedURL, + }, + }, nil + } + } + + case schemas.ResponsesInputMessageContentBlockTypeAudio: + if block.Audio != nil { + // Decode base64 audio data + decodedData, err := base64.StdEncoding.DecodeString(block.Audio.Data) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 audio data: %w", err) + } + + return &CustomPart{ + InlineData: &CustomBlob{ + MIMEType: func() string { + f := strings.ToLower(strings.TrimSpace(block.Audio.Format)) + if f == "" { + return "audio/mpeg" + } + if strings.HasPrefix(f, "audio/") { + return f + } + return "audio/" + f + }(), + Data: decodedData, + }, + }, nil + } + + case schemas.ResponsesInputMessageContentBlockTypeFile: + if block.ResponsesInputMessageContentBlockFile != nil { + if block.ResponsesInputMessageContentBlockFile.FileURL != nil { + return &CustomPart{ + FileData: &FileData{ + MIMEType: "application/octet-stream", // default + FileURI: *block.ResponsesInputMessageContentBlockFile.FileURL, + }, + }, nil + } else if block.ResponsesInputMessageContentBlockFile.FileData != nil { + return &CustomPart{ + InlineData: &CustomBlob{ + MIMEType: "application/octet-stream", // default + Data: []byte(*block.ResponsesInputMessageContentBlockFile.FileData), + }, + }, nil + } + } + } + + return nil, nil +} diff --git a/core/schemas/providers/gemini/speech.go b/core/schemas/providers/gemini/speech.go new file mode 100644 index 0000000000..ffc9b0a92f --- /dev/null +++ b/core/schemas/providers/gemini/speech.go @@ -0,0 +1,48 @@ +package gemini + +import "github.com/maximhq/bifrost/core/schemas" + +func ToGeminiSpeechRequest(bifrostReq *schemas.BifrostSpeechRequest, responseModalities []string) *GeminiGenerationRequest { + if bifrostReq == nil { + return nil + } + + // Create the base Gemini generation request + geminiReq := &GeminiGenerationRequest{ + Model: bifrostReq.Model, + } + + // Set response 
modalities for speech generation + if len(responseModalities) > 0 { + geminiReq.ResponseModalities = responseModalities + } + + // Convert parameters to generation config + if len(responseModalities) > 0 { + var modalities []Modality + for _, mod := range responseModalities { + modalities = append(modalities, Modality(mod)) + } + geminiReq.GenerationConfig.ResponseModalities = modalities + } + + // Convert speech input to Gemini format + if bifrostReq.Input.Input != "" { + geminiReq.Contents = []CustomContent{ + { + Parts: []*CustomPart{ + { + Text: bifrostReq.Input.Input, + }, + }, + }, + } + + // Add speech config to generation config if voice config is provided + if bifrostReq.Params != nil && bifrostReq.Params.VoiceConfig.Voice != nil { + addSpeechConfigToGenerationConfig(&geminiReq.GenerationConfig, bifrostReq.Params.VoiceConfig) + } + } + + return geminiReq +} diff --git a/core/schemas/providers/gemini/transcription.go b/core/schemas/providers/gemini/transcription.go new file mode 100644 index 0000000000..38f2c17ae4 --- /dev/null +++ b/core/schemas/providers/gemini/transcription.go @@ -0,0 +1,73 @@ +package gemini + +import "github.com/maximhq/bifrost/core/schemas" + +func ToGeminiTranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *GeminiGenerationRequest { + if bifrostReq == nil { + return nil + } + + // Create the base Gemini generation request + geminiReq := &GeminiGenerationRequest{ + Model: bifrostReq.Model, + } + + // Convert parameters to generation config + if bifrostReq.Params != nil { + + // Handle extra parameters + if bifrostReq.Params.ExtraParams != nil { + // Safety settings + if safetySettings, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "safety_settings"); ok { + if settings, ok := safetySettings.([]SafetySetting); ok { + geminiReq.SafetySettings = settings + } + } + + // Cached content + if cachedContent, ok := schemas.SafeExtractString(bifrostReq.Params.ExtraParams["cached_content"]); ok { + 
geminiReq.CachedContent = cachedContent + } + + // Labels + if labels, ok := schemas.SafeExtractFromMap(bifrostReq.Params.ExtraParams, "labels"); ok { + if labelMap, ok := labels.(map[string]string); ok { + geminiReq.Labels = labelMap + } + } + } + } + + // Determine the prompt text + var prompt string + if bifrostReq.Params != nil && bifrostReq.Params.Prompt != nil { + prompt = *bifrostReq.Params.Prompt + } else { + prompt = "Generate a transcript of the speech." + } + + // Create parts for the transcription request + parts := []*CustomPart{ + { + Text: prompt, + }, + } + + // Add audio file if present + if len(bifrostReq.Input.File) > 0 { + parts = append(parts, &CustomPart{ + InlineData: &CustomBlob{ + MIMEType: detectAudioMimeType(bifrostReq.Input.File), + Data: bifrostReq.Input.File, + }, + }) + } + + geminiReq.Contents = []CustomContent{ + { + Parts: parts, + }, + } + + return geminiReq +} diff --git a/core/schemas/providers/gemini/types.go b/core/schemas/providers/gemini/types.go index 1cbb605cd1..a77632c563 100644 --- a/core/schemas/providers/gemini/types.go +++ b/core/schemas/providers/gemini/types.go @@ -52,19 +52,19 @@ const ( ) type GeminiGenerationRequest struct { - Model string `json:"model,omitempty"` // Model field for explicit model specification - Contents []CustomContent `json:"contents,omitempty"` // For chat completion requests - Requests []GeminiEmbeddingRequest `json:"requests,omitempty"` // For batch embedding requests - SystemInstruction *CustomContent `json:"systemInstruction,omitempty"` - GenerationConfig GenerationConfig `json:"generationConfig,omitempty"` - SafetySettings []SafetySetting `json:"safetySettings,omitempty"` - Tools []Tool `json:"tools,omitempty"` - ToolConfig ToolConfig `json:"toolConfig,omitempty"` - Labels map[string]string `json:"labels,omitempty"` - CachedContent string `json:"cachedContent,omitempty"` - ResponseModalities []string `json:"responseModalities,omitempty"` - Stream bool `json:"-"` // Internal field to 
track streaming requests - IsEmbedding bool `json:"-"` // Internal field to track if this is an embedding request + Model string `json:"model,omitempty"` // Model field for explicit model specification + Contents []CustomContent `json:"contents,omitempty"` // For chat completion requests + Requests []GeminiEmbeddingRequest `json:"requests,omitempty"` // For batch embedding requests + SystemInstruction *CustomContent `json:"systemInstruction,omitempty"` + GenerationConfig GenerationConfig `json:"generationConfig,omitempty"` + SafetySettings []SafetySetting `json:"safetySettings,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolConfig ToolConfig `json:"toolConfig,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + CachedContent string `json:"cachedContent,omitempty"` + ResponseModalities []string `json:"responseModalities,omitempty"` + Stream bool `json:"-"` // Internal field to track streaming requests + IsEmbedding bool `json:"-"` // Internal field to track if this is an embedding request // Embedding-specific parameters TaskType *string `json:"taskType,omitempty"` @@ -72,6 +72,11 @@ type GeminiGenerationRequest struct { OutputDimensionality *int `json:"outputDimensionality,omitempty"` } +// IsStreamingRequested implements the StreamingRequest interface +func (r *GeminiGenerationRequest) IsStreamingRequested() bool { + return r.Stream +} + // Safety settings. type SafetySetting struct { // Optional. Determines if the harm block method uses probability or probability @@ -116,7 +121,6 @@ const ( FunctionCallingConfigModeValidated FunctionCallingConfigMode = "VALIDATED" ) - // An object that represents a latitude/longitude pair. // This is expressed as a pair of doubles to represent degrees latitude and // degrees longitude. 
Unless specified otherwise, this object must conform to the @@ -200,7 +204,6 @@ const ( BehaviorNonBlocking Behavior = "NON_BLOCKING" ) - // Represents a time interval, encoded as a start time (inclusive) and an end time (exclusive). // The start time must be less than or equal to the end time. // When the start equals the end time, the interval is an empty interval. @@ -464,7 +467,6 @@ const ( APISpecElasticSearch APISpec = "ELASTIC_SEARCH" ) - // Define data stores within engine to filter on in a search call and configurations // for those data stores. For more information, see https://cloud.google.com/generative-ai-app-builder/docs/reference/rpc/google.cloud.discoveryengine.v1#datastorespec type VertexAISearchDataStoreSpec struct { @@ -518,7 +520,7 @@ type RAGRetrievalConfigHybridSearch struct { // results. The range is [0, 1], while 0 means sparse vector search only and 1 means // dense vector search only. The default value is 0.5 which balances sparse and dense // vector search equally. - Alpha *float32 `json:"alpha,omitempty"` + Alpha *float64 `json:"alpha,omitempty"` } // Config for LlmRanker. @@ -624,7 +626,6 @@ type Tool struct { CodeExecution *ToolCodeExecution `json:"codeExecution,omitempty"` } - // Generation config. You can find API default values and more details at https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#generationconfig // and https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/content-generation-parameters. type GenerationConfig struct { @@ -638,7 +639,7 @@ type GenerationConfig struct { // Optional. If enabled, the model will detect emotions and adapt its responses accordingly. EnableAffectiveDialog *bool `json:"enableAffectiveDialog,omitempty"` // Optional. Frequency penalties. - FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"` + FrequencyPenalty *float64 `json:"frequencyPenalty,omitempty"` // Optional. Logit probabilities. Logprobs *int32 `json:"logprobs,omitempty"` // Optional. 
The maximum number of output tokens to generate per message. If empty, @@ -647,7 +648,7 @@ type GenerationConfig struct { // Optional. If specified, the media resolution specified will be used. MediaResolution string `json:"mediaResolution,omitempty"` // Optional. Positive penalties. - PresencePenalty *float32 `json:"presencePenalty,omitempty"` + PresencePenalty *float64 `json:"presencePenalty,omitempty"` // Optional. Output schema of the generated response. This is an alternative to `response_schema` // that accepts [JSON Schema](https://json-schema.org/). If set, `response_schema` must // be omitted, but `response_mime_type` is required. While the full JSON Schema may @@ -685,14 +686,14 @@ type GenerationConfig struct { // Optional. Stop sequences. StopSequences []string `json:"stopSequences,omitempty"` // Optional. Controls the randomness of predictions. - Temperature *float32 `json:"temperature,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` // Optional. Config for thinking features. An error will be returned if this field is // set for models that don't support thinking. ThinkingConfig *GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"` // Optional. If specified, top-k sampling will be used. - TopK *float32 `json:"topK,omitempty"` + TopK *int `json:"topK,omitempty"` // Optional. If specified, nucleus sampling will be used. - TopP *float32 `json:"topP,omitempty"` + TopP *float64 `json:"topP,omitempty"` } // Config for model selection. 
@@ -917,14 +918,14 @@ func (b *CustomBlob) UnmarshalJSON(data []byte) error { // CustomPart handles Google GenAI Part with custom Blob unmarshalling type CustomPart struct { VideoMetadata *VideoMetadata `json:"videoMetadata,omitempty"` - Thought bool `json:"thought,omitempty"` + Thought bool `json:"thought,omitempty"` CodeExecutionResult *CodeExecutionResult `json:"codeExecutionResult,omitempty"` ExecutableCode *ExecutableCode `json:"executableCode,omitempty"` FileData *FileData `json:"fileData,omitempty"` FunctionCall *FunctionCall `json:"functionCall,omitempty"` FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` - InlineData *CustomBlob `json:"inlineData,omitempty"` - Text string `json:"text,omitempty"` + InlineData *CustomBlob `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` } // ToGenAIPart converts CustomPart to Part @@ -950,7 +951,6 @@ func (p *CustomPart) ToGenAIPart() *Part { return part } - // Contains the multi-part content of a message. type Content struct { // Optional. List of parts that constitute a single message. Each part may have @@ -960,7 +960,7 @@ type Content struct { // 'model'. Useful to set for multi-turn conversations, otherwise can be // empty. If role is not specified, SDK will determine the role. Role string `json:"role,omitempty"` -} +} // CustomContent handles Google GenAI Content with custom Part unmarshalling type CustomContent struct { @@ -1118,7 +1118,6 @@ type FunctionResponse struct { Response map[string]any `json:"response,omitempty"` } - // ==================== RESPONSE TYPES ==================== // GeminiEmbeddingResponse represents a Google GenAI embedding response type GeminiEmbeddingResponse struct { @@ -1187,7 +1186,6 @@ type SafetyRating struct { SeverityScore float32 `json:"severityScore,omitempty"` } - // Context for a single URL retrieval. type URLMetadata struct { // Optional. The URL retrieved by the tool. 
@@ -1202,7 +1200,6 @@ type URLContextMetadata struct { URLMetadata []*URLMetadata `json:"urlMetadata,omitempty"` } - // A response candidate generated from the model. type Candidate struct { // Optional. Contains the multi-part content of the response. diff --git a/core/schemas/providers/gemini/utils.go b/core/schemas/providers/gemini/utils.go index ab0bef5973..cb5c896af3 100644 --- a/core/schemas/providers/gemini/utils.go +++ b/core/schemas/providers/gemini/utils.go @@ -2,52 +2,15 @@ package gemini import ( "bytes" - "encoding/json" - "strconv" "strings" "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" ) -// convertEmbeddingParameters converts Gemini embedding request parameters to ModelParameters -func (r *GeminiGenerationRequest) convertEmbeddingParameters() *schemas.ModelParameters { - params := &schemas.ModelParameters{ - ExtraParams: make(map[string]interface{}), - } - - // Check for parameters from batch embedding requests first - if len(r.Requests) > 0 { - // Use parameters from the first request in the batch - firstReq := r.Requests[0] - if firstReq.TaskType != nil { - params.ExtraParams["taskType"] = *firstReq.TaskType - } - if firstReq.Title != nil { - params.ExtraParams["title"] = *firstReq.Title - } - if firstReq.OutputDimensionality != nil { - params.Dimensions = firstReq.OutputDimensionality - } - } else { - // Fallback to top-level embedding parameters for single requests - if r.TaskType != nil { - params.ExtraParams["taskType"] = *r.TaskType - } - if r.Title != nil { - params.ExtraParams["title"] = *r.Title - } - if r.OutputDimensionality != nil { - params.Dimensions = r.OutputDimensionality - } - } - - return params -} - -// convertGenerationConfigToParams converts Gemini GenerationConfig to ModelParameters -func (r *GeminiGenerationRequest) convertGenerationConfigToParams() *schemas.ModelParameters { - params := &schemas.ModelParameters{ +// convertGenerationConfigToChatParameters converts Gemini GenerationConfig to 
ModelParameters +func (r *GeminiGenerationRequest) convertGenerationConfigToChatParameters() *schemas.ChatParameters { + params := &schemas.ChatParameters{ ExtraParams: make(map[string]interface{}), } @@ -55,33 +18,31 @@ func (r *GeminiGenerationRequest) convertGenerationConfigToParams() *schemas.Mod // Map generation config fields to parameters if config.Temperature != nil { - temp := float64(*config.Temperature) - params.Temperature = &temp + params.Temperature = config.Temperature } if config.TopP != nil { - params.TopP = schemas.Ptr(float64(*config.TopP)) + params.TopP = config.TopP } if config.TopK != nil { - params.TopK = schemas.Ptr(int(*config.TopK)) + params.ExtraParams["top_k"] = *config.TopK } if config.MaxOutputTokens > 0 { - maxTokens := int(config.MaxOutputTokens) - params.MaxTokens = &maxTokens + params.MaxCompletionTokens = schemas.Ptr(int(config.MaxOutputTokens)) } if config.CandidateCount > 0 { params.ExtraParams["candidate_count"] = config.CandidateCount } if len(config.StopSequences) > 0 { - params.StopSequences = &config.StopSequences + params.Stop = &config.StopSequences } if config.PresencePenalty != nil { - params.PresencePenalty = schemas.Ptr(float64(*config.PresencePenalty)) + params.PresencePenalty = config.PresencePenalty } if config.FrequencyPenalty != nil { - params.FrequencyPenalty = schemas.Ptr(float64(*config.FrequencyPenalty)) + params.FrequencyPenalty = config.FrequencyPenalty } if config.Seed != nil { - params.ExtraParams["seed"] = *config.Seed + params.Seed = schemas.Ptr(int(*config.Seed)) } if config.ResponseMIMEType != "" { params.ExtraParams["response_mime_type"] = config.ResponseMIMEType @@ -97,8 +58,8 @@ func (r *GeminiGenerationRequest) convertGenerationConfigToParams() *schemas.Mod } // convertSchemaToFunctionParameters converts genai.Schema to schemas.FunctionParameters -func (r *GeminiGenerationRequest) convertSchemaToFunctionParameters(schema *Schema) schemas.FunctionParameters { - params := schemas.FunctionParameters{ 
+func (r *GeminiGenerationRequest) convertSchemaToFunctionParameters(schema *Schema) schemas.ToolFunctionParameters { + params := schemas.ToolFunctionParameters{ Type: string(schema.Type), } @@ -168,9 +129,9 @@ func isImageMimeType(mimeType string) bool { } // ensureExtraParams ensures that bifrostReq.Params and bifrostReq.Params.ExtraParams are initialized -func ensureExtraParams(bifrostReq *schemas.BifrostRequest) { +func ensureExtraParams(bifrostReq *schemas.BifrostChatRequest) { if bifrostReq.Params == nil { - bifrostReq.Params = &schemas.ModelParameters{ + bifrostReq.Params = &schemas.ChatParameters{ ExtraParams: make(map[string]interface{}), } } @@ -191,7 +152,7 @@ func (r *GenerateContentResponse) extractUsageMetadata() (int, int, int) { } // convertParamsToGenerationConfig converts Bifrost parameters to Gemini GenerationConfig -func convertParamsToGenerationConfig(params *schemas.ModelParameters, responseModalities []string) GenerationConfig { +func convertParamsToGenerationConfig(params *schemas.ChatParameters, responseModalities []string) GenerationConfig { config := GenerationConfig{} // Add response modalities if specified @@ -204,50 +165,60 @@ func convertParamsToGenerationConfig(params *schemas.ModelParameters, responseMo } // Map standard parameters - if params.StopSequences != nil { - config.StopSequences = *params.StopSequences + if params.Stop != nil { + config.StopSequences = *params.Stop } - if params.MaxTokens != nil { - config.MaxOutputTokens = int32(*params.MaxTokens) + if params.MaxCompletionTokens != nil { + config.MaxOutputTokens = int32(*params.MaxCompletionTokens) } if params.Temperature != nil { - temp := float32(*params.Temperature) + temp := float64(*params.Temperature) config.Temperature = &temp } if params.TopP != nil { - topP := float32(*params.TopP) + topP := float64(*params.TopP) config.TopP = &topP } - if params.TopK != nil { - topK := float32(*params.TopK) - config.TopK = &topK - } if params.PresencePenalty != nil { - penalty 
:= float32(*params.PresencePenalty) + penalty := float64(*params.PresencePenalty) config.PresencePenalty = &penalty } if params.FrequencyPenalty != nil { - penalty := float32(*params.FrequencyPenalty) + penalty := float64(*params.FrequencyPenalty) config.FrequencyPenalty = &penalty } + if params.ExtraParams != nil { + if topK, ok := params.ExtraParams["top_k"]; ok { + if val, success := schemas.SafeExtractInt(topK); success { + config.TopK = schemas.Ptr(val) + } + } + } + return config } // convertBifrostToolsToGemini converts Bifrost tools to Gemini format -func convertBifrostToolsToGemini(bifrostTools []schemas.Tool) []Tool { +func convertBifrostToolsToGemini(bifrostTools []schemas.ChatTool) []Tool { var geminiTools []Tool for _, tool := range bifrostTools { - if tool.Type == "function" { + if tool.Type == "" { + continue + } + if tool.Type == "function" && tool.Function != nil { + fd := &FunctionDeclaration{ + Name: tool.Function.Name, + } + if tool.Function.Parameters != nil { + fd.Parameters = convertFunctionParametersToSchema(*tool.Function.Parameters) + } + if tool.Function.Description != nil { + fd.Description = *tool.Function.Description + } geminiTool := Tool{ - FunctionDeclarations: []*FunctionDeclaration{ - { - Name: tool.Function.Name, - Description: tool.Function.Description, - Parameters: convertFunctionParametersToSchema(tool.Function.Parameters), - }, - }, + FunctionDeclarations: []*FunctionDeclaration{fd}, } geminiTools = append(geminiTools, geminiTool) } @@ -257,7 +228,7 @@ func convertBifrostToolsToGemini(bifrostTools []schemas.Tool) []Tool { } // convertFunctionParametersToSchema converts Bifrost function parameters to Gemini Schema -func convertFunctionParametersToSchema(params schemas.FunctionParameters) *Schema { +func convertFunctionParametersToSchema(params schemas.ToolFunctionParameters) *Schema { schema := &Schema{ Type: Type(params.Type), } @@ -292,15 +263,14 @@ func convertFunctionParametersToSchema(params schemas.FunctionParameters) 
*Schem return schema } - // convertToolChoiceToToolConfig converts Bifrost tool choice to Gemini tool config -func convertToolChoiceToToolConfig(toolChoice *schemas.ToolChoice) ToolConfig { +func convertToolChoiceToToolConfig(toolChoice *schemas.ChatToolChoice) ToolConfig { config := ToolConfig{} functionCallingConfig := FunctionCallingConfig{} - if toolChoice.ToolChoiceStr != nil { + if toolChoice.ChatToolChoiceStr != nil { // Map string values to Gemini's enum values - switch *toolChoice.ToolChoiceStr { + switch *toolChoice.ChatToolChoiceStr { case "none": functionCallingConfig.Mode = FunctionCallingConfigModeNone case "auto": @@ -310,21 +280,21 @@ func convertToolChoiceToToolConfig(toolChoice *schemas.ToolChoice) ToolConfig { default: functionCallingConfig.Mode = FunctionCallingConfigModeAuto } - } else if toolChoice.ToolChoiceStruct != nil { - switch toolChoice.ToolChoiceStruct.Type { - case schemas.ToolChoiceTypeNone: + } else if toolChoice.ChatToolChoiceStruct != nil { + switch toolChoice.ChatToolChoiceStruct.Type { + case schemas.ChatToolChoiceTypeNone: functionCallingConfig.Mode = FunctionCallingConfigModeNone - case schemas.ToolChoiceTypeAuto: - functionCallingConfig.Mode = FunctionCallingConfigModeAuto - case schemas.ToolChoiceTypeRequired, schemas.ToolChoiceTypeFunction: + case schemas.ChatToolChoiceTypeFunction: + functionCallingConfig.Mode = FunctionCallingConfigModeAny + case schemas.ChatToolChoiceTypeRequired: functionCallingConfig.Mode = FunctionCallingConfigModeAny default: functionCallingConfig.Mode = FunctionCallingConfigModeAuto } // Handle specific function selection - if toolChoice.ToolChoiceStruct.Function.Name != "" { - functionCallingConfig.AllowedFunctionNames = []string{toolChoice.ToolChoiceStruct.Function.Name} + if toolChoice.ChatToolChoiceStruct.Function.Name != "" { + functionCallingConfig.AllowedFunctionNames = []string{toolChoice.ChatToolChoiceStruct.Function.Name} } } @@ -368,7 +338,7 @@ func 
addSpeechConfigToGenerationConfig(config *GenerationConfig, voiceConfig sch } // convertBifrostMessagesToGemini converts Bifrost messages to Gemini format -func convertBifrostMessagesToGemini(messages []schemas.BifrostMessage) []CustomContent { +func convertBifrostMessagesToGemini(messages []schemas.ChatMessage) []CustomContent { var contents []CustomContent for _, message := range messages { @@ -391,8 +361,8 @@ func convertBifrostMessagesToGemini(messages []schemas.BifrostMessage) []CustomC } // Handle tool calls for assistant messages - if message.AssistantMessage != nil && message.AssistantMessage.ToolCalls != nil { - for _, toolCall := range *message.AssistantMessage.ToolCalls { + if message.ChatAssistantMessage != nil && message.ChatAssistantMessage.ToolCalls != nil { + for _, toolCall := range *message.ChatAssistantMessage.ToolCalls { // Convert tool call to function call part if toolCall.Function.Name != nil { // Create function call part - simplified implementation @@ -400,9 +370,14 @@ func convertBifrostMessagesToGemini(messages []schemas.BifrostMessage) []CustomC if toolCall.Function.Arguments != "" { sonic.Unmarshal([]byte(toolCall.Function.Arguments), &argsMap) } + // Handle ID: use it if available, otherwise fallback to function name + callID := *toolCall.Function.Name + if toolCall.ID != nil && strings.TrimSpace(*toolCall.ID) != "" { + callID = *toolCall.ID + } parts = append(parts, &CustomPart{ FunctionCall: &FunctionCall{ - ID: *toolCall.ID, + ID: callID, Name: *toolCall.Function.Name, Args: argsMap, }, @@ -411,11 +386,54 @@ func convertBifrostMessagesToGemini(messages []schemas.BifrostMessage) []CustomC } } - // Handle thinking content - if message.AssistantMessage != nil && message.AssistantMessage.Thought != nil && *message.AssistantMessage.Thought != "" { + // Handle tool response messages + if message.Role == schemas.ChatMessageRoleTool && message.ChatToolMessage != nil { + // Parse the response content + var responseData map[string]any + var 
contentStr string + + // Extract content string from ContentStr or ContentBlocks + if message.Content.ContentStr != nil && *message.Content.ContentStr != "" { + contentStr = *message.Content.ContentStr + } else if message.Content.ContentBlocks != nil { + // Fallback: try to extract text from content blocks + var textParts []string + for _, block := range *message.Content.ContentBlocks { + if block.Text != nil && *block.Text != "" { + textParts = append(textParts, *block.Text) + } + } + if len(textParts) > 0 { + contentStr = strings.Join(textParts, "\n") + } + } + + // Try to unmarshal as JSON + if contentStr != "" { + err := sonic.Unmarshal([]byte(contentStr), &responseData) + if err != nil { + // If unmarshaling fails, wrap the original string to preserve it + responseData = map[string]any{ + "content": contentStr, + } + } + } else { + // If no content at all, use empty map to avoid nil + responseData = map[string]any{} + } + + // Use ToolCallID if available, ensuring it's not nil + callID := "" + if message.ChatToolMessage.ToolCallID != nil { + callID = *message.ChatToolMessage.ToolCallID + } + parts = append(parts, &CustomPart{ - Text: *message.AssistantMessage.Thought, - Thought: true, + FunctionResponse: &FunctionResponse{ + ID: callID, + Name: callID, // Gemini uses name for correlation + Response: responseData, + }, }) } @@ -487,59 +505,6 @@ func detectAudioMimeType(audioData []byte) string { // Fallback within supported set return "audio/mp3" } -// safeFloat32Conversion safely converts various numeric types to float32 -func safeFloat32Conversion(value interface{}) (float32, bool) { - if value == nil { - return 0, false - } - - switch v := value.(type) { - case int: - return float32(v), true - case int64: - return float32(v), true - case float64: - return float32(v), true - case float32: - return v, true - case json.Number: - if val, err := v.Float64(); err == nil { - return float32(val), true - } - return 0, false - case string: - if val, err := 
strconv.ParseFloat(v, 32); err == nil { - return float32(val), true - } - return 0, false - default: - return 0, false - } -} - -// safeStringSliceConversion safely converts various types to []string -func safeStringSliceConversion(value interface{}) ([]string, bool) { - if value == nil { - return nil, false - } - - switch v := value.(type) { - case []string: - return v, true - case []interface{}: - var result []string - for _, item := range v { - if str, ok := item.(string); ok { - result = append(result, str) - } else { - return nil, false // If any item is not a string, fail - } - } - return result, true - default: - return nil, false - } -} // normalizeAudioMIMEType converts audio format tokens to proper MIME types func normalizeAudioMIMEType(format string) string { diff --git a/core/schemas/providers/mistral/embedding.go b/core/schemas/providers/mistral/embedding.go index d0c89d81f9..4e41f372a5 100644 --- a/core/schemas/providers/mistral/embedding.go +++ b/core/schemas/providers/mistral/embedding.go @@ -4,14 +4,34 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) -func ToMistralEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *MistralEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input.EmbeddingInput == nil { +func ToMistralEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *MistralEmbeddingRequest { + if bifrostReq == nil { return nil } - texts := bifrostReq.Input.EmbeddingInput.Texts - if len(texts) == 0 && bifrostReq.Input.EmbeddingInput.Text != nil { - texts = []string{*bifrostReq.Input.EmbeddingInput.Text} + var texts []string + + // Handle single Text input + if bifrostReq.Input.Text != nil { + // Treat empty string as nil/absent + if *bifrostReq.Input.Text != "" { + texts = []string{*bifrostReq.Input.Text} + } + } + + // Handle multiple Texts input (only if single Text wasn't valid) + if len(texts) == 0 && bifrostReq.Input.Texts != nil { + // Filter out empty strings from the slice + for _, text := range 
bifrostReq.Input.Texts { + if text != "" { + texts = append(texts, text) + } + } + } + + // Return nil immediately when no valid texts remain + if len(texts) == 0 { + return nil } mistralReq := &MistralEmbeddingRequest{ @@ -23,7 +43,11 @@ func ToMistralEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *MistralEmbed if bifrostReq.Params != nil { mistralReq.OutputDtype = bifrostReq.Params.EncodingFormat mistralReq.OutputDimension = bifrostReq.Params.Dimensions - mistralReq.User = bifrostReq.Params.User + if bifrostReq.Params.ExtraParams != nil { + if user, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["user"]); ok { + mistralReq.User = user + } + } } return mistralReq diff --git a/core/schemas/providers/openai/chat.go b/core/schemas/providers/openai/chat.go index 64b4fd0e70..97e9129b7b 100644 --- a/core/schemas/providers/openai/chat.go +++ b/core/schemas/providers/openai/chat.go @@ -3,56 +3,32 @@ package openai import "github.com/maximhq/bifrost/core/schemas" // ToBifrostRequest converts an OpenAI chat request to Bifrost format -func (r *OpenAIChatRequest) ToBifrostRequest() *schemas.BifrostRequest { +func (r *OpenAIChatRequest) ToBifrostRequest() *schemas.BifrostChatRequest { provider, model := schemas.ParseModelString(r.Model, schemas.OpenAI) - params := r.convertParameters() - - bifrostReq := &schemas.BifrostRequest{ + bifrostReq := &schemas.BifrostChatRequest{ Provider: provider, Model: model, - Input: schemas.RequestInput{ - ChatCompletionInput: &r.Messages, - }, - Params: filterParams(provider, params), + Input: r.Messages, + Params: &r.ChatParameters, } return bifrostReq } -// ToOpenAIChatCompletionResponse converts a Bifrost response to OpenAI format -func ToOpenAIChatCompletionResponse(bifrostResp *schemas.BifrostResponse) *OpenAIChatResponse { - if bifrostResp == nil { +// ToOpenAIChatRequest converts a Bifrost chat completion request to OpenAI format +func ToOpenAIChatRequest(bifrostReq *schemas.BifrostChatRequest) 
*OpenAIChatRequest { + if bifrostReq == nil || bifrostReq.Input == nil { return nil } - openaiResp := &OpenAIChatResponse{ - ID: bifrostResp.ID, - Object: bifrostResp.Object, - Created: bifrostResp.Created, - Model: bifrostResp.Model, - Choices: bifrostResp.Choices, - Usage: bifrostResp.Usage, - ServiceTier: bifrostResp.ServiceTier, - SystemFingerprint: bifrostResp.SystemFingerprint, - } - - return openaiResp -} - -// ToOpenAIChatCompletionRequest converts a Bifrost chat completion request to OpenAI format -func ToOpenAIChatCompletionRequest(bifrostReq *schemas.BifrostRequest) *OpenAIChatRequest { - if bifrostReq == nil || bifrostReq.Input.ChatCompletionInput == nil { - return nil + openaiReq := &OpenAIChatRequest{ + Model: bifrostReq.Model, + Messages: bifrostReq.Input, } - messages := *bifrostReq.Input.ChatCompletionInput - params := bifrostReq.Params - - openaiReq := &OpenAIChatRequest{ - Model: bifrostReq.Model, - Messages: messages, - ModelParameters: params, // Directly embed the parameters + if bifrostReq.Params != nil { + openaiReq.ChatParameters = *bifrostReq.Params } return openaiReq diff --git a/core/schemas/providers/openai/embedding.go b/core/schemas/providers/openai/embedding.go index 25986788e9..2f4261b644 100644 --- a/core/schemas/providers/openai/embedding.go +++ b/core/schemas/providers/openai/embedding.go @@ -1,80 +1,26 @@ package openai import ( - "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" ) // ToBifrostRequest converts an OpenAI embedding request to Bifrost format -func (r *OpenAIEmbeddingRequest) ToBifrostRequest() *schemas.BifrostRequest { +func (r *OpenAIEmbeddingRequest) ToBifrostRequest() *schemas.BifrostEmbeddingRequest { provider, model := schemas.ParseModelString(r.Model, schemas.OpenAI) - // Create embedding input - embeddingInput := &schemas.EmbeddingInput{} - - // Cleaner coercion: marshal input and try to unmarshal into supported shapes - if raw, err := sonic.Marshal(r.Input); err == nil { - // 1) string 
- var s string - if err := sonic.Unmarshal(raw, &s); err == nil { - embeddingInput.Text = &s - } else { - // 2) []string - var ss []string - if err := sonic.Unmarshal(raw, &ss); err == nil { - embeddingInput.Texts = ss - } else { - // 3) []int - var i []int - if err := sonic.Unmarshal(raw, &i); err == nil { - embeddingInput.Embedding = i - } else { - // 4) [][]int - var ii [][]int - if err := sonic.Unmarshal(raw, &ii); err == nil { - embeddingInput.Embeddings = ii - } - } - } - } - } - - bifrostReq := &schemas.BifrostRequest{ + bifrostReq := &schemas.BifrostEmbeddingRequest{ Provider: provider, Model: model, - Input: schemas.RequestInput{ - EmbeddingInput: embeddingInput, - }, + Input: r.Input, + Params: &r.EmbeddingParameters, } - // Convert parameters first - params := r.convertEmbeddingParameters() - - // Map parameters - bifrostReq.Params = filterParams(provider, params) - return bifrostReq } -// ToOpenAIEmbeddingResponse converts a Bifrost embedding response to OpenAI format -func ToOpenAIEmbeddingResponse(bifrostResp *schemas.BifrostResponse) *OpenAIEmbeddingResponse { - if bifrostResp == nil || bifrostResp.Data == nil { - return nil - } - - return &OpenAIEmbeddingResponse{ - Object: "list", - Data: bifrostResp.Data, - Model: bifrostResp.Model, - Usage: bifrostResp.Usage, - ServiceTier: bifrostResp.ServiceTier, - SystemFingerprint: bifrostResp.SystemFingerprint, - } -} - // ToOpenAIEmbeddingRequest converts a Bifrost embedding request to OpenAI format -func ToOpenAIEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *OpenAIEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input.EmbeddingInput == nil { +func ToOpenAIEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *OpenAIEmbeddingRequest { + if bifrostReq == nil { return nil } @@ -82,14 +28,12 @@ func ToOpenAIEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *OpenAIEmbeddi openaiReq := &OpenAIEmbeddingRequest{ Model: bifrostReq.Model, - Input: *bifrostReq.Input.EmbeddingInput, + Input: 
bifrostReq.Input, } // Map parameters if params != nil { - openaiReq.EncodingFormat = params.EncodingFormat - openaiReq.Dimensions = params.Dimensions - openaiReq.User = params.User + openaiReq.EmbeddingParameters = *params } return openaiReq diff --git a/core/schemas/providers/openai/error.go b/core/schemas/providers/openai/error.go deleted file mode 100644 index d173330259..0000000000 --- a/core/schemas/providers/openai/error.go +++ /dev/null @@ -1,50 +0,0 @@ -package openai - -import ( - "github.com/maximhq/bifrost/core/schemas" -) - -// ToOpenAIError converts a BifrostError to OpenAIChatError -func ToOpenAIError(bifrostErr *schemas.BifrostError) *OpenAIChatError { - if bifrostErr == nil { - return nil - } - - // Provide blank strings for nil pointer fields - eventID := "" - if bifrostErr.EventID != nil { - eventID = *bifrostErr.EventID - } - - errorType := "" - if bifrostErr.Type != nil { - errorType = *bifrostErr.Type - } - - // Handle nested error fields with nil checks - errorStruct := OpenAIChatErrorStruct{ - Type: "", - Code: "", - Message: bifrostErr.Error.Message, - Param: bifrostErr.Error.Param, - EventID: eventID, - } - - if bifrostErr.Error.Type != nil { - errorStruct.Type = *bifrostErr.Error.Type - } - - if bifrostErr.Error.Code != nil { - errorStruct.Code = *bifrostErr.Error.Code - } - - if bifrostErr.Error.EventID != nil { - errorStruct.EventID = *bifrostErr.Error.EventID - } - - return &OpenAIChatError{ - EventID: eventID, - Type: errorType, - Error: errorStruct, - } -} diff --git a/core/schemas/providers/openai/responses.go b/core/schemas/providers/openai/responses.go new file mode 100644 index 0000000000..0163618d56 --- /dev/null +++ b/core/schemas/providers/openai/responses.go @@ -0,0 +1,36 @@ +package openai + +import "github.com/maximhq/bifrost/core/schemas" + +func (r *OpenAIResponsesRequest) ToBifrostRequest() *schemas.BifrostResponsesRequest { + if r == nil { + return nil + } + + return &schemas.BifrostResponsesRequest{ + Provider: 
schemas.OpenAI, + Model: r.Model, + Input: r.Input, + Params: &r.ResponsesParameters, + } +} + +func ToOpenAIResponsesRequest(bifrostReq *schemas.BifrostResponsesRequest) *OpenAIResponsesRequest { + if bifrostReq == nil || bifrostReq.Input == nil { + return nil + } + + params := bifrostReq.Params + + // Create the responses request with properly mapped parameters + req := &OpenAIResponsesRequest{ + Model: bifrostReq.Model, + Input: bifrostReq.Input, + } + + if params != nil { + req.ResponsesParameters = *params + } + + return req +} diff --git a/core/schemas/providers/openai/speech.go b/core/schemas/providers/openai/speech.go index 8b471a2b16..9fe1cd5133 100644 --- a/core/schemas/providers/openai/speech.go +++ b/core/schemas/providers/openai/speech.go @@ -3,41 +3,16 @@ package openai import "github.com/maximhq/bifrost/core/schemas" // ToBifrostRequest converts an OpenAI speech request to Bifrost format -func (r *OpenAISpeechRequest) ToBifrostRequest() *schemas.BifrostRequest { +func (r *OpenAISpeechRequest) ToBifrostRequest() *schemas.BifrostSpeechRequest { provider, model := schemas.ParseModelString(r.Model, schemas.OpenAI) - // Create speech input - speechInput := &schemas.SpeechInput{ - Input: r.Input, - VoiceConfig: schemas.SpeechVoiceInput{ - Voice: &r.Voice, - }, - } - - // Set response format if provided - if r.ResponseFormat != nil { - speechInput.ResponseFormat = *r.ResponseFormat - } - - // Set instructions if provided - if r.Instructions != nil { - speechInput.Instructions = *r.Instructions - } - - bifrostReq := &schemas.BifrostRequest{ + bifrostReq := &schemas.BifrostSpeechRequest{ Provider: provider, Model: model, - Input: schemas.RequestInput{ - SpeechInput: speechInput, - }, + Input: schemas.SpeechInput{Input: r.Input}, + Params: &r.SpeechParameters, } - // Convert parameters first - params := r.convertSpeechParameters() - - // Map parameters - bifrostReq.Params = filterParams(provider, params) - return bifrostReq } @@ -51,12 +26,12 @@ func 
ToOpenAISpeechResponse(bifrostResp *schemas.BifrostResponse) *schemas.Bifro } // ToOpenAISpeechRequest converts a Bifrost speech request to OpenAI format -func ToOpenAISpeechRequest(bifrostReq *schemas.BifrostRequest) *OpenAISpeechRequest { - if bifrostReq == nil || bifrostReq.Input.SpeechInput == nil { +func ToOpenAISpeechRequest(bifrostReq *schemas.BifrostSpeechRequest) *OpenAISpeechRequest { + if bifrostReq == nil || bifrostReq.Input.Input == "" { return nil } - speechInput := bifrostReq.Input.SpeechInput + speechInput := bifrostReq.Input params := bifrostReq.Params openaiReq := &OpenAISpeechRequest{ @@ -64,27 +39,8 @@ func ToOpenAISpeechRequest(bifrostReq *schemas.BifrostRequest) *OpenAISpeechRequ Input: speechInput.Input, } - // Set voice - if speechInput.VoiceConfig.Voice != nil { - openaiReq.Voice = *speechInput.VoiceConfig.Voice - } - - // Set optional fields - if speechInput.ResponseFormat != "" { - openaiReq.ResponseFormat = &speechInput.ResponseFormat - } - if speechInput.Instructions != "" { - openaiReq.Instructions = &speechInput.Instructions - } - - // Map parameters - if params != nil && params.ExtraParams != nil { - if speed, ok := params.ExtraParams["speed"].(float64); ok { - openaiReq.Speed = &speed - } - if streamFormat, ok := params.ExtraParams["stream_format"].(string); ok { - openaiReq.StreamFormat = &streamFormat - } + if params != nil { + openaiReq.SpeechParameters = *params } return openaiReq diff --git a/core/schemas/providers/openai/stream.go b/core/schemas/providers/openai/stream.go deleted file mode 100644 index 8c30c68a99..0000000000 --- a/core/schemas/providers/openai/stream.go +++ /dev/null @@ -1,85 +0,0 @@ -package openai - -import "github.com/maximhq/bifrost/core/schemas" - -// ToOpenAIStreamResponse converts a Bifrost response to OpenAI streaming format -func ToOpenAIChatCompletionStreamResponse(bifrostResp *schemas.BifrostResponse) *OpenAIChatCompletionStreamResponse { - if bifrostResp == nil { - return nil - } - - streamResp := 
&OpenAIChatCompletionStreamResponse{ - ID: bifrostResp.ID, - Object: "chat.completion.chunk", - Created: bifrostResp.Created, - Model: bifrostResp.Model, - SystemFingerprint: bifrostResp.SystemFingerprint, - Usage: bifrostResp.Usage, - } - - // Convert choices to streaming format - for _, choice := range bifrostResp.Choices { - streamChoice := OpenAIChatCompletionStreamChoice{ - Index: choice.Index, - FinishReason: choice.FinishReason, - } - - var delta *OpenAIChatCompletionStreamDelta - - // Handle streaming vs non-streaming choices - if choice.BifrostStreamResponseChoice != nil { - // This is a streaming response - use the delta directly - delta = &OpenAIChatCompletionStreamDelta{} - - // Only set fields that are not nil - if choice.BifrostStreamResponseChoice.Delta.Role != nil { - delta.Role = choice.BifrostStreamResponseChoice.Delta.Role - } - if choice.BifrostStreamResponseChoice.Delta.Content != nil { - delta.Content = choice.BifrostStreamResponseChoice.Delta.Content - } - if len(choice.BifrostStreamResponseChoice.Delta.ToolCalls) > 0 { - delta.ToolCalls = &choice.BifrostStreamResponseChoice.Delta.ToolCalls - } - } else if choice.BifrostNonStreamResponseChoice != nil { - // This is a non-streaming response - convert message to delta format - delta = &OpenAIChatCompletionStreamDelta{} - - // Convert role - role := string(choice.BifrostNonStreamResponseChoice.Message.Role) - delta.Role = &role - - // Convert content - if choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr != nil { - delta.Content = choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr - } - - // Convert tool calls if present (from AssistantMessage) - if choice.BifrostNonStreamResponseChoice.Message.AssistantMessage != nil && - choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls != nil { - delta.ToolCalls = choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls - } - - // Set LogProbs from non-streaming choice - if choice.LogProbs != 
nil { - streamChoice.LogProbs = choice.LogProbs - } - } - - // Ensure we have a valid delta with at least one field set - // If all fields are nil, we should skip this chunk or set an empty content - if delta != nil { - hasValidField := (delta.Role != nil) || (delta.Content != nil) || (delta.ToolCalls != nil) - if !hasValidField { - // Set empty content to ensure we have at least one field - emptyContent := "" - delta.Content = &emptyContent - } - streamChoice.Delta = delta - } - - streamResp.Choices = append(streamResp.Choices, streamChoice) - } - - return streamResp -} diff --git a/core/schemas/providers/openai/text.go b/core/schemas/providers/openai/text.go index 4755cc9b8a..35cbc9ac0f 100644 --- a/core/schemas/providers/openai/text.go +++ b/core/schemas/providers/openai/text.go @@ -5,89 +5,34 @@ import ( ) // ToOpenAITextCompletionRequest converts a Bifrost text completion request to OpenAI format -func ToOpenAITextCompletionRequest(bifrostReq *schemas.BifrostRequest) *OpenAITextCompletionRequest { - if bifrostReq == nil || bifrostReq.Input.TextCompletionInput == nil { +func ToOpenAITextCompletionRequest(bifrostReq *schemas.BifrostTextCompletionRequest) *OpenAITextCompletionRequest { + if bifrostReq == nil { return nil } + params := bifrostReq.Params + openaiReq := &OpenAITextCompletionRequest{ - Model: bifrostReq.Model, - Prompt: *bifrostReq.Input.TextCompletionInput, - ModelParameters: bifrostReq.Params, // Directly embed the parameters + Model: bifrostReq.Model, + Prompt: bifrostReq.Input, } - // Handle OpenAI-specific parameters from ExtraParams - if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil { - // Echo prompt - if echo, ok := bifrostReq.Params.ExtraParams["echo"].(bool); ok { - openaiReq.Echo = &echo - } - - // Best of - if bestOf, ok := bifrostReq.Params.ExtraParams["best_of"].(int); ok { - openaiReq.BestOf = &bestOf - } - - // Suffix - if suffix, ok := bifrostReq.Params.ExtraParams["suffix"].(string); ok { - openaiReq.Suffix = 
&suffix - } + if params != nil { + openaiReq.TextCompletionParameters = *params } return openaiReq } -// ToOpenAITextCompletionResponse converts an OpenAI text completion response to Bifrost format -func (response *OpenAITextCompletionResponse) ToBifrostResponse() *schemas.BifrostResponse { - if response == nil { +func (r *OpenAITextCompletionRequest) ToBifrostRequest() *schemas.BifrostTextCompletionRequest { + if r == nil { return nil } - // Convert choices - choices := make([]schemas.BifrostResponseChoice, 0, len(response.Choices)) - for i, choice := range response.Choices { - // Create a copy of the text to avoid pointer issues - textCopy := choice.Text - - bifrostChoice := schemas.BifrostResponseChoice{ - Index: i, - BifrostTextCompletionResponseChoice: &schemas.BifrostTextCompletionResponseChoice{ - Text: &textCopy, - }, - FinishReason: choice.FinishReason, - } - - // Add log probabilities if available - if choice.Logprobs != nil { - bifrostChoice.LogProbs = &schemas.LogProbs{ - TextCompletionLogProb: choice.Logprobs, - } - } - - choices = append(choices, bifrostChoice) + return &schemas.BifrostTextCompletionRequest{ + Provider: schemas.OpenAI, + Model: r.Model, + Input: r.Prompt, + Params: &r.TextCompletionParameters, } - - // Create the Bifrost response - bifrostResponse := &schemas.BifrostResponse{ - ID: response.ID, - Object: "list", // Standard Bifrost object type for completions - Choices: choices, - Model: response.Model, - Created: response.Created, - // Set provider outside of this function - } - - // Set system fingerprint - if response.SystemFingerprint != nil { - bifrostResponse.SystemFingerprint = response.SystemFingerprint - } - - // Set usage information - if response.Usage != nil { - // Create a copy to avoid pointer issues - usageCopy := *response.Usage - bifrostResponse.Usage = &usageCopy - } - - return bifrostResponse } diff --git a/core/schemas/providers/openai/transcription.go b/core/schemas/providers/openai/transcription.go index 
e242e90b19..f2019984f2 100644 --- a/core/schemas/providers/openai/transcription.go +++ b/core/schemas/providers/openai/transcription.go @@ -3,49 +3,28 @@ package openai import "github.com/maximhq/bifrost/core/schemas" // ToBifrostRequest converts an OpenAI transcription request to Bifrost format -func (r *OpenAITranscriptionRequest) ToBifrostRequest() *schemas.BifrostRequest { +func (r *OpenAITranscriptionRequest) ToBifrostRequest() *schemas.BifrostTranscriptionRequest { provider, model := schemas.ParseModelString(r.Model, schemas.OpenAI) - // Create transcription input - transcriptionInput := &schemas.TranscriptionInput{ - File: r.File, - } - - // Set optional fields - if r.Language != nil { - transcriptionInput.Language = r.Language - } - if r.Prompt != nil { - transcriptionInput.Prompt = r.Prompt - } - if r.ResponseFormat != nil { - transcriptionInput.ResponseFormat = r.ResponseFormat - } - - bifrostReq := &schemas.BifrostRequest{ + bifrostReq := &schemas.BifrostTranscriptionRequest{ Provider: provider, Model: model, - Input: schemas.RequestInput{ - TranscriptionInput: transcriptionInput, + Input: schemas.TranscriptionInput{ + File: r.File, }, + Params: &r.TranscriptionParameters, } - // Convert parameters first - params := r.convertTranscriptionParameters() - - // Map parameters - bifrostReq.Params = filterParams(provider, params) - return bifrostReq } // ToOpenAITranscriptionRequest converts a Bifrost transcription request to OpenAI format -func ToOpenAITranscriptionRequest(bifrostReq *schemas.BifrostRequest) *OpenAITranscriptionRequest { - if bifrostReq == nil || bifrostReq.Input.TranscriptionInput == nil { +func ToOpenAITranscriptionRequest(bifrostReq *schemas.BifrostTranscriptionRequest) *OpenAITranscriptionRequest { + if bifrostReq == nil || bifrostReq.Input.File == nil { return nil } - transcriptionInput := bifrostReq.Input.TranscriptionInput + transcriptionInput := bifrostReq.Input params := bifrostReq.Params openaiReq := &OpenAITranscriptionRequest{ @@ 
-53,42 +32,9 @@ func ToOpenAITranscriptionRequest(bifrostReq *schemas.BifrostRequest) *OpenAITra File: transcriptionInput.File, } - // Set optional fields - openaiReq.Language = transcriptionInput.Language - openaiReq.Prompt = transcriptionInput.Prompt - openaiReq.ResponseFormat = transcriptionInput.ResponseFormat - - // Map parameters - if params != nil && params.ExtraParams != nil { - if temperature, ok := params.ExtraParams["temperature"].(float64); ok { - openaiReq.Temperature = &temperature - } - if include, ok := params.ExtraParams["include"].([]string); ok { - openaiReq.Include = include - } - if timestampGranularities, ok := params.ExtraParams["timestamp_granularities"].([]string); ok { - openaiReq.TimestampGranularities = timestampGranularities - } - if stream, ok := params.ExtraParams["stream"].(bool); ok { - openaiReq.Stream = &stream - } + if params != nil { + openaiReq.TranscriptionParameters = *params } return openaiReq } - -func ToOpenAITranscriptionResponse(bifrostResp *schemas.BifrostResponse) *OpenAITranscriptionResponse { - if bifrostResp == nil { - return nil - } - - return &OpenAITranscriptionResponse{ - ID: bifrostResp.ID, - Object: bifrostResp.Object, - Created: bifrostResp.Created, - Model: bifrostResp.Model, - Transcribe: bifrostResp.Transcribe, - Usage: bifrostResp.Usage, - SystemFingerprint: bifrostResp.SystemFingerprint, - } -} diff --git a/core/schemas/providers/openai/types.go b/core/schemas/providers/openai/types.go index 66b97c760f..d3ae1c9fdc 100644 --- a/core/schemas/providers/openai/types.go +++ b/core/schemas/providers/openai/types.go @@ -4,177 +4,74 @@ import "github.com/maximhq/bifrost/core/schemas" // REQUEST TYPES -// OpenAIChatRequest represents an OpenAI chat completion request -type OpenAIChatRequest struct { - Model string `json:"model"` - Messages []schemas.BifrostMessage `json:"messages"` - - // Embed ModelParameters to avoid duplication - *schemas.ModelParameters -} - -// OpenAISpeechRequest represents an OpenAI speech 
synthesis request -type OpenAISpeechRequest struct { - Model string `json:"model"` - Input string `json:"input"` - Voice string `json:"voice"` - ResponseFormat *string `json:"response_format,omitempty"` - Speed *float64 `json:"speed,omitempty"` - Instructions *string `json:"instructions,omitempty"` - StreamFormat *string `json:"stream_format,omitempty"` -} - -// OpenAITranscriptionRequest represents an OpenAI transcription request -// Note: This is used for JSON body parsing, actual form parsing is handled in the router -type OpenAITranscriptionRequest struct { - Model string `json:"model"` - File []byte `json:"file"` // Binary audio data - Language *string `json:"language,omitempty"` - Prompt *string `json:"prompt,omitempty"` - ResponseFormat *string `json:"response_format,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Include []string `json:"include,omitempty"` - TimestampGranularities []string `json:"timestamp_granularities,omitempty"` - Stream *bool `json:"stream,omitempty"` -} +// OpenAITextCompletionRequest represents an OpenAI text completion request +type OpenAITextCompletionRequest struct { + Model string `json:"model"` // Required: Model to use + Prompt schemas.TextCompletionInput `json:"prompt"` // Required: String or array of strings -// OpenAITranscriptionResponse represents an OpenAI transcription response -type OpenAITranscriptionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Transcribe *schemas.BifrostTranscribe `json:"transcribe"` - Usage *schemas.LLMUsage `json:"usage,omitempty"` - SystemFingerprint *string `json:"system_fingerprint,omitempty"` + schemas.TextCompletionParameters } // OpenAIEmbeddingRequest represents an OpenAI embedding request type OpenAIEmbeddingRequest struct { - Model string `json:"model"` - Input schemas.EmbeddingInput `json:"input"` // Can be string or []string - EncodingFormat *string 
`json:"encoding_format,omitempty"` - Dimensions *int `json:"dimensions,omitempty"` - User *string `json:"user,omitempty"` -} + Model string `json:"model"` + Input schemas.EmbeddingInput `json:"input"` // Can be string or []string -// OpenAITextCompletionRequest represents an OpenAI text completion request -type OpenAITextCompletionRequest struct { - Model string `json:"model"` // Required: Model to use - Prompt interface{} `json:"prompt"` // Required: String or array of strings - - // Embed ModelParameters to avoid duplication - *schemas.ModelParameters - - // OpenAI-specific text completion parameters not in core ModelParameters - Echo *bool `json:"echo,omitempty"` // Echo back the prompt - BestOf *int `json:"best_of,omitempty"` // Generate best_of completions server-side - Suffix *string `json:"suffix,omitempty"` // Suffix for completion + schemas.EmbeddingParameters } -// IsStreamingRequested implements the StreamingRequest interface for speech -func (r *OpenAISpeechRequest) IsStreamingRequested() bool { - return r.StreamFormat != nil && *r.StreamFormat == "sse" -} +// OpenAIChatRequest represents an OpenAI chat completion request +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []schemas.ChatMessage `json:"messages"` -// IsStreamingRequested implements the StreamingRequest interface for transcription -func (r *OpenAITranscriptionRequest) IsStreamingRequested() bool { - return r.Stream != nil && *r.Stream + schemas.ChatParameters + Stream *bool `json:"stream,omitempty"` } -// IsStreamingRequested implements the StreamingRequest interface for embeddings -// Note: Embeddings don't support streaming in OpenAI API -func (r *OpenAIEmbeddingRequest) IsStreamingRequested() bool { - return false +// IsStreamingRequested implements the StreamingRequest interface +func (r *OpenAIChatRequest) IsStreamingRequested() bool { + return r.Stream != nil && *r.Stream } -// RESPONSE TYPES -// OpenAIChatResponse represents an OpenAI chat completion response 
-type OpenAIChatResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Choices []schemas.BifrostResponseChoice `json:"choices"` - Usage *schemas.LLMUsage `json:"usage,omitempty"` // Reuse schema type - ServiceTier *string `json:"service_tier,omitempty"` - SystemFingerprint *string `json:"system_fingerprint,omitempty"` -} +type OpenAIResponsesRequest struct { + Model string `json:"model"` + Input []schemas.ResponsesMessage `json:"input"` -// OpenAIStreamResponse represents a single chunk in the OpenAI streaming response -type OpenAIChatCompletionStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - SystemFingerprint *string `json:"system_fingerprint,omitempty"` - Choices []OpenAIChatCompletionStreamChoice `json:"choices"` - Usage *schemas.LLMUsage `json:"usage,omitempty"` + schemas.ResponsesParameters + Stream *bool `json:"stream,omitempty"` } -// OpenAIStreamChoice represents a choice in a streaming response chunk -type OpenAIChatCompletionStreamChoice struct { - Index int `json:"index"` - Delta *OpenAIChatCompletionStreamDelta `json:"delta,omitempty"` - FinishReason *string `json:"finish_reason,omitempty"` - LogProbs *schemas.LogProbs `json:"logprobs,omitempty"` +// IsStreamingRequested implements the StreamingRequest interface +func (r *OpenAIResponsesRequest) IsStreamingRequested() bool { + return r.Stream != nil && *r.Stream } -// OpenAIStreamDelta represents the incremental content in a streaming chunk -type OpenAIChatCompletionStreamDelta struct { - Role *string `json:"role,omitempty"` - Content *string `json:"content,omitempty"` - ToolCalls *[]schemas.ToolCall `json:"tool_calls,omitempty"` -} +// OpenAISpeechRequest represents an OpenAI speech synthesis request +type OpenAISpeechRequest struct { + Model string `json:"model"` + Input string `json:"input"` -// OpenAIEmbeddingResponse 
represents an OpenAI embedding response -type OpenAIEmbeddingResponse struct { - Object string `json:"object"` - Data []schemas.BifrostEmbedding `json:"data"` - Model string `json:"model"` - Usage *schemas.LLMUsage `json:"usage,omitempty"` - ServiceTier *string `json:"service_tier,omitempty"` - SystemFingerprint *string `json:"system_fingerprint,omitempty"` + schemas.SpeechParameters + StreamFormat *string `json:"stream_format,omitempty"` } -// OpenAITextCompletionChoice represents a completion choice in the text completion response -type OpenAITextCompletionChoice struct { - Text string `json:"text"` // Generated completion text - Index int `json:"index"` // Index of this choice - FinishReason *string `json:"finish_reason,omitempty"` // Reason completion finished - Logprobs *schemas.TextCompletionLogProb `json:"logprobs,omitempty"` // Log probability information -} +// OpenAITranscriptionRequest represents an OpenAI transcription request +// Note: This is used for JSON body parsing, actual form parsing is handled in the router +type OpenAITranscriptionRequest struct { + Model string `json:"model"` + File []byte `json:"file"` // Binary audio data -// OpenAITextCompletionResponse represents an OpenAI text completion response -type OpenAITextCompletionResponse struct { - ID string `json:"id"` // Unique identifier - Object string `json:"object"` // Always "text_completion" - Created int `json:"created"` // Unix timestamp - Model string `json:"model"` // Model used - Choices []OpenAITextCompletionChoice `json:"choices"` // Completion choices - Usage *schemas.LLMUsage `json:"usage,omitempty"` // Token usage - SystemFingerprint *string `json:"system_fingerprint,omitempty"` // System fingerprint + schemas.TranscriptionParameters + Stream *bool `json:"stream,omitempty"` } -// ERROR TYPES -// OpenAIChatError represents an OpenAI chat completion error response -type OpenAIChatError struct { - EventID string `json:"event_id"` // Unique identifier for the error event - Type 
string `json:"type"` // Type of error - Error struct { - Type string `json:"type"` // Error type - Code string `json:"code"` // Error code - Message string `json:"message"` // Error message - Param interface{} `json:"param"` // Parameter that caused the error - EventID string `json:"event_id"` // Event ID for tracking - } `json:"error"` +// IsStreamingRequested implements the StreamingRequest interface for speech +func (r *OpenAISpeechRequest) IsStreamingRequested() bool { + return r.StreamFormat != nil && *r.StreamFormat == "sse" } -// OpenAIChatErrorStruct represents the error structure of an OpenAI chat completion error response -type OpenAIChatErrorStruct struct { - Type string `json:"type"` // Error type - Code string `json:"code"` // Error code - Message string `json:"message"` // Error message - Param interface{} `json:"param"` // Parameter that caused the error - EventID string `json:"event_id"` // Event ID for tracking +// IsStreamingRequested implements the StreamingRequest interface for transcription +func (r *OpenAITranscriptionRequest) IsStreamingRequested() bool { + return r.Stream != nil && *r.Stream } diff --git a/core/schemas/providers/openai/utils.go b/core/schemas/providers/openai/utils.go deleted file mode 100644 index 385b6e3611..0000000000 --- a/core/schemas/providers/openai/utils.go +++ /dev/null @@ -1,130 +0,0 @@ -package openai - -import "github.com/maximhq/bifrost/core/schemas" - -func filterParams(provider schemas.ModelProvider, p *schemas.ModelParameters) *schemas.ModelParameters { - if p == nil { - return nil - } - return schemas.ValidateAndFilterParamsForProvider(provider, p) -} - -// convertParameters converts OpenAI request parameters to Bifrost ModelParameters -// using direct field access for better performance and type safety. 
-func (r *OpenAIChatRequest) convertParameters() *schemas.ModelParameters { - params := &schemas.ModelParameters{ - ExtraParams: make(map[string]interface{}), - } - - params.Tools = r.Tools - params.ToolChoice = r.ToolChoice - - // Direct field mapping - if r.MaxTokens != nil { - params.MaxTokens = r.MaxTokens - } - if r.Temperature != nil { - params.Temperature = r.Temperature - } - if r.TopP != nil { - params.TopP = r.TopP - } - if r.PresencePenalty != nil { - params.PresencePenalty = r.PresencePenalty - } - if r.FrequencyPenalty != nil { - params.FrequencyPenalty = r.FrequencyPenalty - } - if r.N != nil { - params.N = r.N - } - if r.LogProbs != nil { - params.LogProbs = r.LogProbs - } - if r.TopLogProbs != nil { - params.TopLogProbs = r.TopLogProbs - } - if r.Stop != nil { - params.Stop = r.Stop - } - if r.LogitBias != nil { - params.LogitBias = r.LogitBias - } - if r.User != nil { - params.User = r.User - } - if r.Stream != nil { - params.Stream = r.Stream - } - if r.Seed != nil { - params.Seed = r.Seed - } - if r.StreamOptions != nil { - params.StreamOptions = r.StreamOptions - } - if r.ResponseFormat != nil { - params.ResponseFormat = r.ResponseFormat - } - if r.MaxCompletionTokens != nil { - params.MaxCompletionTokens = r.MaxCompletionTokens - } - if r.ReasoningEffort != nil { - params.ReasoningEffort = r.ReasoningEffort - } - - return params -} - -// convertSpeechParameters converts OpenAI speech request parameters to Bifrost ModelParameters -func (r *OpenAISpeechRequest) convertSpeechParameters() *schemas.ModelParameters { - params := &schemas.ModelParameters{ - ExtraParams: make(map[string]interface{}), - } - - // Add speech-specific parameters - if r.Speed != nil { - params.ExtraParams["speed"] = *r.Speed - } - - return params -} - -// convertTranscriptionParameters converts OpenAI transcription request parameters to Bifrost ModelParameters -func (r *OpenAITranscriptionRequest) convertTranscriptionParameters() *schemas.ModelParameters { - params := 
&schemas.ModelParameters{ - ExtraParams: make(map[string]interface{}), - } - - // Add transcription-specific parameters - if r.Temperature != nil { - params.ExtraParams["temperature"] = *r.Temperature - } - if len(r.TimestampGranularities) > 0 { - params.ExtraParams["timestamp_granularities"] = r.TimestampGranularities - } - if len(r.Include) > 0 { - params.ExtraParams["include"] = r.Include - } - - return params -} - -// convertEmbeddingParameters converts OpenAI embedding request parameters to Bifrost ModelParameters -func (r *OpenAIEmbeddingRequest) convertEmbeddingParameters() *schemas.ModelParameters { - params := &schemas.ModelParameters{ - ExtraParams: make(map[string]interface{}), - } - - // Add embedding-specific parameters - if r.EncodingFormat != nil { - params.EncodingFormat = r.EncodingFormat - } - if r.Dimensions != nil { - params.Dimensions = r.Dimensions - } - if r.User != nil { - params.User = r.User - } - - return params -} diff --git a/core/schemas/providers/vertex/embedding.go b/core/schemas/providers/vertex/embedding.go index 28b7ffaddc..9372d8c9c7 100644 --- a/core/schemas/providers/vertex/embedding.go +++ b/core/schemas/providers/vertex/embedding.go @@ -5,21 +5,16 @@ import ( ) // ToVertexEmbeddingRequest converts a Bifrost embedding request to Vertex AI format -func ToVertexEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *VertexEmbeddingRequest { - if bifrostReq == nil || bifrostReq.Input.EmbeddingInput == nil { +func ToVertexEmbeddingRequest(bifrostReq *schemas.BifrostEmbeddingRequest) *VertexEmbeddingRequest { + if bifrostReq == nil || (bifrostReq.Input.Text == nil && bifrostReq.Input.Texts == nil) { return nil } - embeddingInput := bifrostReq.Input.EmbeddingInput - texts := embeddingInput.Texts - - // Handle single text input - if len(texts) == 0 && embeddingInput.Text != nil { - texts = []string{*embeddingInput.Text} - } - - if len(texts) == 0 { - return nil + var texts []string + if bifrostReq.Input.Text != nil { + texts = 
[]string{*bifrostReq.Input.Text} + } else { + texts = bifrostReq.Input.Texts } // Create instances for each text @@ -30,12 +25,12 @@ func ToVertexEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *VertexEmbeddi } // Add optional task_type and title from params - if bifrostReq.Params != nil && bifrostReq.Params.ExtraParams != nil { - if taskTypeStr, exists := bifrostReq.Params.ExtraParams["task_type"].(string); exists { - instance.TaskType = &taskTypeStr + if bifrostReq.Params != nil { + if taskTypeStr, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["task_type"]); ok { + instance.TaskType = taskTypeStr } - if title, exists := bifrostReq.Params.ExtraParams["title"].(string); exists { - instance.Title = &title + if title, ok := schemas.SafeExtractStringPointer(bifrostReq.Params.ExtraParams["title"]); ok { + instance.Title = title } } @@ -54,7 +49,7 @@ func ToVertexEmbeddingRequest(bifrostReq *schemas.BifrostRequest) *VertexEmbeddi // Set autoTruncate (defaults to true) autoTruncate := true if bifrostReq.Params.ExtraParams != nil { - if autoTruncateVal, exists := bifrostReq.Params.ExtraParams["autoTruncate"].(bool); exists { + if autoTruncateVal, ok := schemas.SafeExtractBool(bifrostReq.Params.ExtraParams["autoTruncate"]); ok { autoTruncate = autoTruncateVal } } @@ -119,7 +114,8 @@ func (vertexResp *VertexEmbeddingResponse) ToBifrostResponse() *schemas.BifrostR Data: embeddings, Usage: usage, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Vertex, + RequestType: schemas.EmbeddingRequest, + Provider: schemas.Vertex, }, } diff --git a/core/schemas/responses.go b/core/schemas/responses.go new file mode 100644 index 0000000000..6872ce4d75 --- /dev/null +++ b/core/schemas/responses.go @@ -0,0 +1,1488 @@ +package schemas + +import ( + "fmt" + "maps" + + "github.com/bytedance/sonic" +) + +// ============================================================================= +// OPENAI RESPONSES API SCHEMAS +// 
============================================================================= +// +// This file contains all the schema definitions for the OpenAI Responses API. +// +// Structure: +// 1. Core API Request/Response Structures +// 2. Input Message Structures +// 3. Output Message Structures +// 4. Tool Call Structures (organized by tool type) +// 5. Tool Configuration Structures +// 6. Tool Choice Configuration +// +// Union Types: +// - Many structs use "union types" where only one field should be set +// - These are implemented with pointer fields and custom JSON marshaling +// ============================================================================= + +// ============================================================================= +// 1. CORE API REQUEST/RESPONSE STRUCTURES +// ============================================================================= + +type ResponsesParameters struct { + Background *bool `json:"background,omitempty"` + Conversation *string `json:"conversation,omitempty"` + Include *[]string `json:"include,omitempty"` // Supported values: "web_search_call.action.sources", "code_interpreter_call.outputs", "computer_call_output.output.image_url", "file_search_call.results", "message.input_image.image_url", "message.output_text.logprobs", "reasoning.encrypted_content" + Instructions *string `json:"instructions,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + MaxToolCalls *int `json:"max_tool_calls,omitempty"` + Metadata *map[string]any `json:"metadata,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + PreviousResponseID *string `json:"previous_response_id,omitempty"` + PromptCacheKey *string `json:"prompt_cache_key,omitempty"` // Prompt cache key + Reasoning *ResponsesParametersReasoning `json:"reasoning,omitempty"` // Configuration options for reasoning models + SafetyIdentifier *string `json:"safety_identifier,omitempty"` // Safety identifier + ServiceTier *string 
`json:"service_tier,omitempty"` + StreamOptions *ResponsesStreamOptions `json:"stream_options,omitempty"` + Store *bool `json:"store,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + Text *ResponsesTextConfig `json:"text,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` + TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling + ToolChoice *ResponsesToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool + Tools []ResponsesTool `json:"tools,omitempty"` // Tools to use + Truncation *string `json:"truncation,omitempty"` + + // Dynamic parameters that can be provider-specific, they are directly + // added to the request as is. + ExtraParams map[string]interface{} `json:"-"` +} + +type ResponsesStreamOptions struct { + IncludeObfuscation *bool `json:"include_obfuscation,omitempty"` +} + +type ResponsesTextConfig struct { + Format *ResponsesTextConfigFormat `json:"format,omitempty"` // An object specifying the format that the model must output + Verbosity *string `json:"verbosity,omitempty"` // "low" | "medium" | "high" or null +} + +type ResponsesTextConfigFormat struct { + Type string `json:"type"` // "text" | "json_schema" | "json_object" + JSONSchema *ResponsesTextConfigFormatJSONSchema `json:"json_schema,omitempty"` // when type == "json_schema" +} + +// ResponsesTextConfigFormatJSONSchema represents a JSON schema specification +type ResponsesTextConfigFormatJSONSchema struct { + Name string `json:"name"` + Schema map[string]any `json:"schema"` // JSON Schema (subset) + Type string `json:"type"` // always "json_schema" + Description *string `json:"description,omitempty"` + Strict *bool `json:"strict,omitempty"` +} + +type ResponsesResponse struct { + Tools *[]ResponsesTool `json:"tools,omitempty"` + ToolChoice *ResponsesToolChoice `json:"tool_choice,omitempty"` + + ResponsesParameters + + CreatedAt int `json:"created_at"` // Unix timestamp when Response was created + Conversation 
*ResponsesResponseConversation `json:"conversation,omitempty"` // The conversation that this response belongs to + IncompleteDetails *ResponsesResponseIncompleteDetails `json:"incomplete_details,omitempty"` // Details about why the response is incomplete + Instructions *[]ResponsesMessage `json:"instructions,omitempty"` + Output []ResponsesMessage `json:"output,omitempty"` + Prompt *ResponsesPrompt `json:"prompt,omitempty"` // Reference to a prompt template and variables + Reasoning *ResponsesParametersReasoning `json:"reasoning,omitempty"` // Configuration options for reasoning models +} + +type ResponsesPrompt struct { + ID string `json:"id"` + Variables map[string]any `json:"variables"` + Version *string `json:"version,omitempty"` +} + +type ResponsesParametersReasoning struct { + Effort *string `json:"effort,omitempty"` // "minimal" | "low" | "medium" | "high" + GenerateSummary *string `json:"generate_summary,omitempty"` // Deprecated: use summary instead + Summary *string `json:"summary,omitempty"` // "auto" | "concise" | "detailed" +} + +type ResponsesResponseConversation struct { + ID string `json:"id"` // The unique ID of the conversation +} + +type ResponsesResponseError struct { + Code string `json:"code"` // The error code for the response + Message string `json:"message"` // A human-readable description of the error +} + +type ResponsesResponseIncompleteDetails struct { + Reason string `json:"reason"` // The reason why the response is incomplete +} + +type ResponsesExtendedResponseUsage struct { + InputTokens int `json:"input_tokens"` // Number of input tokens + InputTokensDetails *ResponsesResponseInputTokens `json:"input_tokens_details"` // Detailed breakdown of input tokens + OutputTokens int `json:"output_tokens"` // Number of output tokens + OutputTokensDetails *ResponsesResponseOutputTokens `json:"output_tokens_details"` // Detailed breakdown of output tokens +} + +type ResponsesResponseUsage struct { + *ResponsesExtendedResponseUsage + 
TotalTokens int `json:"total_tokens"` // Total number of tokens used +} + +type ResponsesResponseInputTokens struct { + CachedTokens int `json:"cached_tokens"` // Tokens retrieved from cache +} + +type ResponsesResponseOutputTokens struct { + ReasoningTokens int `json:"reasoning_tokens"` // Number of reasoning tokens +} + +// ============================================================================= +// 2. INPUT MESSAGE STRUCTURES +// ============================================================================= + +type ResponsesMessageType string + +const ( + ResponsesMessageTypeMessage ResponsesMessageType = "message" + ResponsesMessageTypeFileSearchCall ResponsesMessageType = "file_search_call" + ResponsesMessageTypeComputerCall ResponsesMessageType = "computer_call" + ResponsesMessageTypeComputerCallOutput ResponsesMessageType = "computer_call_output" + ResponsesMessageTypeWebSearchCall ResponsesMessageType = "web_search_call" + ResponsesMessageTypeFunctionCall ResponsesMessageType = "function_call" + ResponsesMessageTypeFunctionCallOutput ResponsesMessageType = "function_call_output" + ResponsesMessageTypeCodeInterpreterCall ResponsesMessageType = "code_interpreter_call" + ResponsesMessageTypeLocalShellCall ResponsesMessageType = "local_shell_call" + ResponsesMessageTypeLocalShellCallOutput ResponsesMessageType = "local_shell_call_output" + ResponsesMessageTypeMCPCall ResponsesMessageType = "mcp_call" + ResponsesMessageTypeCustomToolCall ResponsesMessageType = "custom_tool_call" + ResponsesMessageTypeCustomToolCallOutput ResponsesMessageType = "custom_tool_call_output" + ResponsesMessageTypeImageGenerationCall ResponsesMessageType = "image_generation_call" + ResponsesMessageTypeMCPListTools ResponsesMessageType = "mcp_list_tools" + ResponsesMessageTypeMCPApprovalRequest ResponsesMessageType = "mcp_approval_request" + ResponsesMessageTypeMCPApprovalResponses ResponsesMessageType = "mcp_approval_responses" + ResponsesMessageTypeReasoning ResponsesMessageType = 
"reasoning" + ResponsesMessageTypeItemReference ResponsesMessageType = "item_reference" + ResponsesMessageTypeRefusal ResponsesMessageType = "refusal" +) + +// ResponsesMessage is a union type that can contain different types of input items +// Only one of the fields should be set at a time +type ResponsesMessage struct { + ID *string `json:"id,omitempty"` // Common ID field for most item types + Type *ResponsesMessageType `json:"type,omitempty"` + Status *string `json:"status,omitempty"` // "in_progress" | "completed" | "incomplete" | "interpreting" | "failed" + + Role *ResponsesMessageRoleType `json:"role,omitempty"` + Content *ResponsesMessageContent `json:"content,omitempty"` + + *ResponsesToolMessage // For Tool calls and outputs + + // Reasoning + *ResponsesReasoning +} + +// UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessage. +// It handles the embedded pointer fields by initializing them based on the message type. +func (rm *ResponsesMessage) UnmarshalJSON(data []byte) error { + // First unmarshal into a temporary struct to avoid recursion and get the type + type tempResponsesMessage struct { + ID *string `json:"id,omitempty"` + Type *ResponsesMessageType `json:"type,omitempty"` + Status *string `json:"status,omitempty"` + Role *ResponsesMessageRoleType `json:"role,omitempty"` + Content *ResponsesMessageContent `json:"content,omitempty"` + } + + var temp tempResponsesMessage + if err := sonic.Unmarshal(data, &temp); err != nil { + return err + } + + // Assign the basic fields + rm.ID = temp.ID + rm.Type = temp.Type + rm.Status = temp.Status + rm.Role = temp.Role + rm.Content = temp.Content + + // Based on the message type, initialize the appropriate embedded struct + if temp.Type != nil { + switch *temp.Type { + case ResponsesMessageTypeFileSearchCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesFileSearchToolCall: &ResponsesFileSearchToolCall{}, + } + case ResponsesMessageTypeComputerCall: + rm.ResponsesToolMessage = 
&ResponsesToolMessage{ + ResponsesComputerToolCall: &ResponsesComputerToolCall{}, + } + case ResponsesMessageTypeComputerCallOutput: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesComputerToolCallOutput: &ResponsesComputerToolCallOutput{}, + } + case ResponsesMessageTypeWebSearchCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesWebSearchToolCall: &ResponsesWebSearchToolCall{}, + } + case ResponsesMessageTypeFunctionCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{} + case ResponsesMessageTypeFunctionCallOutput: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesFunctionToolCallOutput: &ResponsesFunctionToolCallOutput{}, + } + case ResponsesMessageTypeCodeInterpreterCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesCodeInterpreterToolCall: &ResponsesCodeInterpreterToolCall{}, + } + case ResponsesMessageTypeLocalShellCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesLocalShellCall: &ResponsesLocalShellCall{}, + } + case ResponsesMessageTypeLocalShellCallOutput: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesLocalShellCallOutput: &ResponsesLocalShellCallOutput{}, + } + case ResponsesMessageTypeMCPCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesMCPToolCall: &ResponsesMCPToolCall{}, + } + case ResponsesMessageTypeCustomToolCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesCustomToolCall: &ResponsesCustomToolCall{}, + } + case ResponsesMessageTypeCustomToolCallOutput: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesCustomToolCallOutput: &ResponsesCustomToolCallOutput{}, + } + case ResponsesMessageTypeImageGenerationCall: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesImageGenerationCall: &ResponsesImageGenerationCall{}, + } + case ResponsesMessageTypeMCPListTools: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesMCPListTools: &ResponsesMCPListTools{}, + } + case 
ResponsesMessageTypeMCPApprovalRequest: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesMCPApprovalRequest: &ResponsesMCPApprovalRequest{}, + } + case ResponsesMessageTypeMCPApprovalResponses: + rm.ResponsesToolMessage = &ResponsesToolMessage{ + ResponsesMCPApprovalResponse: &ResponsesMCPApprovalResponse{}, + } + case ResponsesMessageTypeReasoning: + rm.ResponsesReasoning = &ResponsesReasoning{} + case ResponsesMessageTypeMessage, ResponsesMessageTypeItemReference, ResponsesMessageTypeRefusal: + // Regular message types, no embedded structs needed + return nil + default: + // Unknown type, try to unmarshal basic tool message fields if present + rm.ResponsesToolMessage = &ResponsesToolMessage{} + } + + // Now unmarshal the tool message fields + if rm.ResponsesToolMessage != nil { + // First unmarshal basic tool message fields (call_id, name, arguments) + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage); err != nil { + return fmt.Errorf("failed to unmarshal tool message: %v", err) + } + + // Then unmarshal into specific embedded structs based on message type + switch *temp.Type { + case ResponsesMessageTypeFileSearchCall: + if rm.ResponsesToolMessage.ResponsesFileSearchToolCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesFileSearchToolCall); err != nil { + return fmt.Errorf("failed to unmarshal file search tool call: %v", err) + } + } + case ResponsesMessageTypeComputerCall: + if rm.ResponsesToolMessage.ResponsesComputerToolCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesComputerToolCall); err != nil { + return fmt.Errorf("failed to unmarshal computer tool call: %v", err) + } + } + case ResponsesMessageTypeComputerCallOutput: + if rm.ResponsesToolMessage.ResponsesComputerToolCallOutput != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesComputerToolCallOutput); err != nil { + return fmt.Errorf("failed to unmarshal computer tool call output: %v", err) + } + } 
+ case ResponsesMessageTypeWebSearchCall: + if rm.ResponsesToolMessage.ResponsesWebSearchToolCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesWebSearchToolCall); err != nil { + return fmt.Errorf("failed to unmarshal web search tool call: %v", err) + } + } + case ResponsesMessageTypeFunctionCallOutput: + if rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesFunctionToolCallOutput); err != nil { + return fmt.Errorf("failed to unmarshal function tool call output: %v", err) + } + } + case ResponsesMessageTypeCodeInterpreterCall: + if rm.ResponsesToolMessage.ResponsesCodeInterpreterToolCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesCodeInterpreterToolCall); err != nil { + return fmt.Errorf("failed to unmarshal code interpreter tool call: %v", err) + } + } + case ResponsesMessageTypeLocalShellCall: + if rm.ResponsesToolMessage.ResponsesLocalShellCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesLocalShellCall); err != nil { + return fmt.Errorf("failed to unmarshal local shell call: %v", err) + } + } + case ResponsesMessageTypeLocalShellCallOutput: + if rm.ResponsesToolMessage.ResponsesLocalShellCallOutput != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesLocalShellCallOutput); err != nil { + return fmt.Errorf("failed to unmarshal local shell call output: %v", err) + } + } + case ResponsesMessageTypeMCPCall: + if rm.ResponsesToolMessage.ResponsesMCPToolCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesMCPToolCall); err != nil { + return fmt.Errorf("failed to unmarshal MCP tool call: %v", err) + } + } + case ResponsesMessageTypeCustomToolCall: + if rm.ResponsesToolMessage.ResponsesCustomToolCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesCustomToolCall); err != nil { + return fmt.Errorf("failed to unmarshal 
custom tool call: %v", err) + } + } + case ResponsesMessageTypeCustomToolCallOutput: + if rm.ResponsesToolMessage.ResponsesCustomToolCallOutput != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesCustomToolCallOutput); err != nil { + return fmt.Errorf("failed to unmarshal custom tool call output: %v", err) + } + } + case ResponsesMessageTypeImageGenerationCall: + if rm.ResponsesToolMessage.ResponsesImageGenerationCall != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesImageGenerationCall); err != nil { + return fmt.Errorf("failed to unmarshal image generation call: %v", err) + } + } + case ResponsesMessageTypeMCPListTools: + if rm.ResponsesToolMessage.ResponsesMCPListTools != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesMCPListTools); err != nil { + return fmt.Errorf("failed to unmarshal MCP list tools: %v", err) + } + } + case ResponsesMessageTypeMCPApprovalRequest: + if rm.ResponsesToolMessage.ResponsesMCPApprovalRequest != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesMCPApprovalRequest); err != nil { + return fmt.Errorf("failed to unmarshal MCP approval request: %v", err) + } + } + case ResponsesMessageTypeMCPApprovalResponses: + if rm.ResponsesToolMessage.ResponsesMCPApprovalResponse != nil { + if err := sonic.Unmarshal(data, rm.ResponsesToolMessage.ResponsesMCPApprovalResponse); err != nil { + return fmt.Errorf("failed to unmarshal MCP approval response: %v", err) + } + } + // Note: ResponsesMessageTypeFunctionCall only needs basic fields (handled above) + } + } + + if rm.ResponsesReasoning != nil { + if err := sonic.Unmarshal(data, rm.ResponsesReasoning); err != nil { + return fmt.Errorf("failed to unmarshal reasoning: %v", err) + } + } + } + + return nil +} + +// MarshalJSON implements custom JSON marshalling for ResponsesMessage. +// It handles the embedded pointer fields by only marshaling non-nil fields. 
+func (rm ResponsesMessage) MarshalJSON() ([]byte, error) { + // Start with the base fields + result := make(map[string]interface{}) + + if rm.ID != nil { + result["id"] = *rm.ID + } + if rm.Type != nil { + result["type"] = *rm.Type + } + if rm.Status != nil { + result["status"] = *rm.Status + } + if rm.Role != nil { + result["role"] = *rm.Role + } + if rm.Content != nil { + result["content"] = rm.Content + } + + // Add tool message fields if present + if rm.ResponsesToolMessage != nil { + toolData, err := sonic.Marshal(rm.ResponsesToolMessage) + if err != nil { + return nil, fmt.Errorf("failed to marshal tool message: %v", err) + } + + var toolFields map[string]interface{} + if err := sonic.Unmarshal(toolData, &toolFields); err != nil { + return nil, fmt.Errorf("failed to unmarshal tool data: %v", err) + } + + // Merge tool fields into result + maps.Copy(result, toolFields) + } + + // Add reasoning fields if present + if rm.ResponsesReasoning != nil { + reasoningData, err := sonic.Marshal(rm.ResponsesReasoning) + if err != nil { + return nil, fmt.Errorf("failed to marshal reasoning: %v", err) + } + + var reasoningFields map[string]interface{} + if err := sonic.Unmarshal(reasoningData, &reasoningFields); err != nil { + return nil, fmt.Errorf("failed to unmarshal reasoning data: %v", err) + } + + // Merge reasoning fields into result + maps.Copy(result, reasoningFields) + } + + return sonic.Marshal(result) +} + +type ResponsesMessageRoleType string + +const ( + ResponsesInputMessageRoleAssistant ResponsesMessageRoleType = "assistant" + ResponsesInputMessageRoleUser ResponsesMessageRoleType = "user" + ResponsesInputMessageRoleSystem ResponsesMessageRoleType = "system" + ResponsesInputMessageRoleDeveloper ResponsesMessageRoleType = "developer" +) + +// ResponsesInputMessageContent is a union type that can be either a string or array of content blocks +type ResponsesMessageContent struct { + ContentStr *string // Simple text content + ContentBlocks 
*[]ResponsesMessageContentBlock // Rich content with multiple media types +} + +// MarshalJSON implements custom JSON marshalling for ResponsesMessageContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (rc ResponsesMessageContent) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if rc.ContentStr != nil && rc.ContentBlocks != nil { + return nil, fmt.Errorf("both ResponsesMessageContentStr and ResponsesMessageContentBlocks are set; only one should be non-nil") + } + + if rc.ContentStr != nil { + return sonic.Marshal(*rc.ContentStr) + } + if rc.ContentBlocks != nil { + return sonic.Marshal(*rc.ContentBlocks) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for ResponsesMessageContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. 
+func (rc *ResponsesMessageContent) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + rc.ContentStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []ResponsesMessageContentBlock + if err := sonic.Unmarshal(data, &arrayContent); err == nil { + rc.ContentBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of Content blocks") +} + +type ResponsesMessageContentBlockType string + +const ( + ResponsesInputMessageContentBlockTypeText ResponsesMessageContentBlockType = "input_text" + ResponsesInputMessageContentBlockTypeImage ResponsesMessageContentBlockType = "input_image" + ResponsesInputMessageContentBlockTypeFile ResponsesMessageContentBlockType = "input_file" + ResponsesInputMessageContentBlockTypeAudio ResponsesMessageContentBlockType = "input_audio" + ResponsesOutputMessageContentTypeText ResponsesMessageContentBlockType = "output_text" + ResponsesOutputMessageContentTypeRefusal ResponsesMessageContentBlockType = "refusal" + ResponsesOutputMessageContentTypeReasoning ResponsesMessageContentBlockType = "reasoning_text" +) + +// ResponsesMessageContentBlock represents different types of content (text, image, file, audio) +// Only one of the content type fields should be set +type ResponsesMessageContentBlock struct { + Type ResponsesMessageContentBlockType `json:"type"` + FileID *string `json:"file_id,omitempty"` // Reference to uploaded file + Text *string `json:"text,omitempty"` + + *ResponsesInputMessageContentBlockImage + *ResponsesInputMessageContentBlockFile + Audio *ResponsesInputMessageContentBlockAudio `json:"input_audio,omitempty"` + + *ResponsesOutputMessageContentText // Normal text output from the model + *ResponsesOutputMessageContentRefusal // Model refusal to answer +} + +type 
ResponsesInputMessageContentBlockImage struct { + ImageURL *string `json:"image_url,omitempty"` + Detail *string `json:"detail,omitempty"` // "low" | "high" | "auto" +} + +type ResponsesInputMessageContentBlockFile struct { + FileData *string `json:"file_data,omitempty"` // Base64 encoded file data + FileURL *string `json:"file_url,omitempty"` // Direct URL to file + Filename *string `json:"filename,omitempty"` // Name of the file +} + +type ResponsesInputMessageContentBlockAudio struct { + Format string `json:"format"` // "mp3" or "wav" + Data string `json:"data"` // base64 encoded audio data +} + +// ============================================================================= +// 3. OUTPUT MESSAGE STRUCTURES +// ============================================================================= + +type ResponsesOutputMessageContentText struct { + Annotations *[]ResponsesOutputMessageContentTextAnnotation `json:"annotations,omitempty"` // Citations and references + LogProbs *[]ResponsesOutputMessageContentTextLogProb `json:"logprobs,omitempty"` // Token log probabilities +} + +type ResponsesOutputMessageContentTextAnnotation struct { + Type string `json:"type"` // "file_citation" | "url_citation" | "container_file_citation" | "file_path" + Index *int `json:"index,omitempty"` // Common index field (FileCitation, FilePath) + FileID *string `json:"file_id,omitempty"` // Common file ID field (FileCitation, ContainerFileCitation, FilePath) + StartIndex *int `json:"start_index,omitempty"` // Common start index field (URLCitation, ContainerFileCitation) + EndIndex *int `json:"end_index,omitempty"` // Common end index field (URLCitation, ContainerFileCitation) + Filename *string `json:"filename,omitempty"` + Title *string `json:"title,omitempty"` + URL *string `json:"url,omitempty"` + ContainerID *string `json:"container_id,omitempty"` +} + +// ResponsesOutputMessageContentTextLogProb represents log probability information for content. 
+type ResponsesOutputMessageContentTextLogProb struct { + Bytes []int `json:"bytes"` + LogProb float64 `json:"logprob"` + Token string `json:"token"` + TopLogProbs []LogProb `json:"top_logprobs"` +} +type ResponsesOutputMessageContentRefusal struct { + Refusal string `json:"refusal"` +} + +type ResponsesToolMessage struct { + CallID *string `json:"call_id,omitempty"` // Common call ID for tool calls and outputs + Name *string `json:"name,omitempty"` // Common name field for tool calls + Arguments *string `json:"arguments,omitempty"` + + // Tool calls and outputs + *ResponsesFileSearchToolCall + *ResponsesComputerToolCall + *ResponsesComputerToolCallOutput + *ResponsesWebSearchToolCall + *ResponsesFunctionToolCallOutput + *ResponsesCodeInterpreterToolCall + *ResponsesLocalShellCall + *ResponsesLocalShellCallOutput + *ResponsesMCPToolCall + *ResponsesCustomToolCall + *ResponsesCustomToolCallOutput + *ResponsesImageGenerationCall + + // MCP-specific + *ResponsesMCPListTools + *ResponsesMCPApprovalRequest + *ResponsesMCPApprovalResponse +} + +// UnmarshalJSON implements custom JSON unmarshalling for ResponsesToolMessage. +// This prevents embedded pointer fields from interfering with basic field unmarshaling. 
+func (rtm *ResponsesToolMessage) UnmarshalJSON(data []byte) error { + // Use a simple struct to unmarshal basic fields without embedded interference + type basicToolMessage struct { + CallID *string `json:"call_id,omitempty"` + Name *string `json:"name,omitempty"` + Arguments *string `json:"arguments,omitempty"` + } + + var basic basicToolMessage + if err := sonic.Unmarshal(data, &basic); err != nil { + return err + } + + // Assign the basic fields + rtm.CallID = basic.CallID + rtm.Name = basic.Name + rtm.Arguments = basic.Arguments + + // Embedded field unmarshaling is handled by the parent ResponsesMessage.UnmarshalJSON + // based on the message type - no need to duplicate logic here + + return nil +} + +// MarshalJSON implements custom JSON marshalling for ResponsesToolMessage. +// It only marshals the basic fields and skips nil embedded pointers to prevent auto-generated +// marshalling from dereferencing them. The parent ResponsesMessage.MarshalJSON already handles +// merging embedded struct fields using the same pattern. 
+func (rtm ResponsesToolMessage) MarshalJSON() ([]byte, error) { + result := make(map[string]interface{}) + + // Only marshal basic fields + if rtm.CallID != nil { + result["call_id"] = *rtm.CallID + } + if rtm.Name != nil { + result["name"] = *rtm.Name + } + if rtm.Arguments != nil { + result["arguments"] = *rtm.Arguments + } + + // Helper to marshal and merge embedded struct + mergeEmbedded := func(v interface{}) error { + data, err := sonic.Marshal(v) + if err != nil { + return err + } + var fields map[string]interface{} + if err := sonic.Unmarshal(data, &fields); err != nil { + return err + } + maps.Copy(result, fields) + return nil + } + + // Marshal each embedded pointer field only if non-nil + // Note: We check each field explicitly because nil pointers in interface{} don't compare to nil + if rtm.ResponsesFileSearchToolCall != nil { + if err := mergeEmbedded(rtm.ResponsesFileSearchToolCall); err != nil { + return nil, err + } + } + if rtm.ResponsesComputerToolCall != nil { + if err := mergeEmbedded(rtm.ResponsesComputerToolCall); err != nil { + return nil, err + } + } + if rtm.ResponsesComputerToolCallOutput != nil { + if err := mergeEmbedded(rtm.ResponsesComputerToolCallOutput); err != nil { + return nil, err + } + } + if rtm.ResponsesWebSearchToolCall != nil { + if err := mergeEmbedded(rtm.ResponsesWebSearchToolCall); err != nil { + return nil, err + } + } + if rtm.ResponsesFunctionToolCallOutput != nil { + // Special case: ResponsesFunctionToolCallOutput marshals to a raw value (string or array), + // not an object, so we need to add it as an "output" field directly + outputData, err := sonic.Marshal(rtm.ResponsesFunctionToolCallOutput) + if err != nil { + return nil, err + } + var output interface{} + if err := sonic.Unmarshal(outputData, &output); err != nil { + return nil, err + } + result["output"] = output + } + if rtm.ResponsesCodeInterpreterToolCall != nil { + if err := mergeEmbedded(rtm.ResponsesCodeInterpreterToolCall); err != nil { + return 
nil, err + } + } + if rtm.ResponsesLocalShellCall != nil { + if err := mergeEmbedded(rtm.ResponsesLocalShellCall); err != nil { + return nil, err + } + } + if rtm.ResponsesLocalShellCallOutput != nil { + if err := mergeEmbedded(rtm.ResponsesLocalShellCallOutput); err != nil { + return nil, err + } + } + if rtm.ResponsesMCPToolCall != nil { + if err := mergeEmbedded(rtm.ResponsesMCPToolCall); err != nil { + return nil, err + } + } + if rtm.ResponsesCustomToolCall != nil { + if err := mergeEmbedded(rtm.ResponsesCustomToolCall); err != nil { + return nil, err + } + } + if rtm.ResponsesCustomToolCallOutput != nil { + if err := mergeEmbedded(rtm.ResponsesCustomToolCallOutput); err != nil { + return nil, err + } + } + if rtm.ResponsesImageGenerationCall != nil { + if err := mergeEmbedded(rtm.ResponsesImageGenerationCall); err != nil { + return nil, err + } + } + if rtm.ResponsesMCPListTools != nil { + if err := mergeEmbedded(rtm.ResponsesMCPListTools); err != nil { + return nil, err + } + } + if rtm.ResponsesMCPApprovalRequest != nil { + if err := mergeEmbedded(rtm.ResponsesMCPApprovalRequest); err != nil { + return nil, err + } + } + if rtm.ResponsesMCPApprovalResponse != nil { + if err := mergeEmbedded(rtm.ResponsesMCPApprovalResponse); err != nil { + return nil, err + } + } + + return sonic.Marshal(result) +} + +// ============================================================================= +// 4. 
// 4. TOOL CALL STRUCTURES (organized by tool type)
// =============================================================================

// -----------------------------------------------------------------------------
// File Search Tool
// -----------------------------------------------------------------------------

// ResponsesFileSearchToolCall holds the queries issued by a file search tool
// call and, optionally, the results it returned.
type ResponsesFileSearchToolCall struct {
	Queries []string                            `json:"queries"`
	Results []ResponsesFileSearchToolCallResult `json:"results,omitempty"`
}

// ResponsesFileSearchToolCallResult is a single scored hit from a file search.
type ResponsesFileSearchToolCallResult struct {
	Attributes *map[string]any `json:"attributes,omitempty"`
	FileID     *string         `json:"file_id,omitempty"`
	Filename   *string         `json:"filename,omitempty"`
	Score      *float64        `json:"score,omitempty"`
	Text       *string         `json:"text,omitempty"`
}

// -----------------------------------------------------------------------------
// Computer Tool
// -----------------------------------------------------------------------------

// ResponsesComputerToolCall describes an action the model wants to perform on
// a computer, plus any safety checks awaiting acknowledgement.
type ResponsesComputerToolCall struct {
	Action              ResponsesComputerToolCallAction               `json:"action"`
	PendingSafetyChecks []ResponsesComputerToolCallPendingSafetyCheck `json:"pending_safety_checks"`
}

type ResponsesComputerToolCallPendingSafetyCheck struct {
	ID      string `json:"id"`
	Context string `json:"context"`
	Message string `json:"message"`
}

// ResponsesComputerToolCallAction is a flattened union of the computer action
// variants; which optional fields are meaningful depends on Type.
type ResponsesComputerToolCallAction struct {
	Type    string                                `json:"type"`             // "click" | "double_click" | "drag" | "keypress" | "move" | "screenshot" | "scroll" | "type" | "wait"
	X       *int                                  `json:"x,omitempty"`      // Common X coordinate field (Click, DoubleClick, Move, Scroll)
	Y       *int                                  `json:"y,omitempty"`      // Common Y coordinate field (Click, DoubleClick, Move, Scroll)
	Button  *string                               `json:"button,omitempty"` // "left" | "right" | "wheel" | "back" | "forward"
	Path    []ResponsesComputerToolCallActionPath `json:"path,omitempty"`
	Keys    []string                              `json:"keys,omitempty"`
	ScrollX *int                                  `json:"scroll_x,omitempty"`
	ScrollY *int                                  `json:"scroll_y,omitempty"`
	Text    *string                               `json:"text,omitempty"`
}

// ResponsesComputerToolCallActionPath is one point of a drag path.
type ResponsesComputerToolCallActionPath struct {
	X int `json:"x"`
	Y int `json:"y"`
}

// ResponsesComputerToolCallOutput contains the results from executing a
// computer tool call: a screenshot plus acknowledged safety checks.
type ResponsesComputerToolCallOutput struct {
	Output                   ResponsesComputerToolCallOutputData                `json:"output"`
	AcknowledgedSafetyChecks []ResponsesComputerToolCallAcknowledgedSafetyCheck `json:"acknowledged_safety_checks,omitempty"`
}

// ResponsesComputerToolCallOutputData - a computer screenshot image used with
// the computer use tool.
type ResponsesComputerToolCallOutputData struct {
	Type     string  `json:"type"` // always "computer_screenshot"
	FileID   *string `json:"file_id,omitempty"`
	ImageURL *string `json:"image_url,omitempty"`
}

// ResponsesComputerToolCallAcknowledgedSafetyCheck - a safety check reported by
// the API that has been acknowledged by the developer.
type ResponsesComputerToolCallAcknowledgedSafetyCheck struct {
	ID      string  `json:"id"`
	Code    *string `json:"code,omitempty"`
	Message *string `json:"message,omitempty"`
}

// -----------------------------------------------------------------------------
// Web Search Tool
// -----------------------------------------------------------------------------

// ResponsesWebSearchToolCall wraps the action performed by a web search tool.
type ResponsesWebSearchToolCall struct {
	Action ResponsesWebSearchAction `json:"action"`
}

// ResponsesWebSearchAction is a flattened union of the web search action variants.
type ResponsesWebSearchAction struct {
	Type    string                                 `json:"type"`          // "search" | "open_page" | "find"
	URL     *string                                `json:"url,omitempty"` // Common URL field (OpenPage, Find)
	Query   *string                                `json:"query,omitempty"`
	Sources []ResponsesWebSearchActionSearchSource `json:"sources,omitempty"`
	Pattern *string                                `json:"pattern,omitempty"`
}

// ResponsesWebSearchActionSearchSource - a source used in the search.
type ResponsesWebSearchActionSearchSource struct {
	Type string `json:"type"` // always "url"
	URL  string `json:"url"`
}
----------------------------------------------------------------------------- +// Function Tool +// ----------------------------------------------------------------------------- + +// Function Tool Call Output - contains the results from executing a function tool call +type ResponsesFunctionToolCallOutput struct { + ResponsesFunctionToolCallOutputStr *string //A JSON string of the output of the function tool call. + ResponsesFunctionToolCallOutputBlocks *[]ResponsesMessageContentBlock +} + +// MarshalJSON implements custom JSON marshalling for ResponsesFunctionToolCallOutput. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (rf ResponsesFunctionToolCallOutput) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if rf.ResponsesFunctionToolCallOutputStr != nil && rf.ResponsesFunctionToolCallOutputBlocks != nil { + return nil, fmt.Errorf("both ResponsesFunctionToolCallOutputStr and ResponsesFunctionToolCallOutputBlocks are set; only one should be non-nil") + } + + if rf.ResponsesFunctionToolCallOutputStr != nil { + return sonic.Marshal(*rf.ResponsesFunctionToolCallOutputStr) + } + if rf.ResponsesFunctionToolCallOutputBlocks != nil { + return sonic.Marshal(*rf.ResponsesFunctionToolCallOutputBlocks) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for ResponsesFunctionToolCallOutput. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. 
+func (rf *ResponsesFunctionToolCallOutput) UnmarshalJSON(data []byte) error { + // Parse as generic object to check if it contains content-like fields + var genericObj map[string]interface{} + if err := sonic.Unmarshal(data, &genericObj); err != nil { + return err + } + + // If the object doesn't contain typical content fields, it's probably not meant for this struct + // (e.g., it's a tool call, not a tool call output) + hasContentFields := false + for key := range genericObj { + if key == "content" || key == "output" || key == "result" { + hasContentFields = true + break + } + } + + if !hasContentFields { + return nil // Skip unmarshaling if no relevant content fields + } + + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + rf.ResponsesFunctionToolCallOutputStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []ResponsesMessageContentBlock + if err := sonic.Unmarshal(data, &arrayContent); err == nil { + rf.ResponsesFunctionToolCallOutputBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of Content blocks") +} + +// ----------------------------------------------------------------------------- +// Reasoning +// ----------------------------------------------------------------------------- + +type ResponsesReasoning struct { + Summary []ResponsesReasoningContent `json:"summary"` + EncryptedContent *string `json:"encrypted_content,omitempty"` +} + +type ResponsesReasoningContentBlockType string + +const ( + ResponsesReasoningContentBlockTypeSummaryText ResponsesReasoningContentBlockType = "summary_text" +) + +type ResponsesReasoningContent struct { + Type ResponsesReasoningContentBlockType `json:"type"` + Text string `json:"text"` +} + +// ----------------------------------------------------------------------------- +// Image Generation Tool +// 
----------------------------------------------------------------------------- +type ResponsesImageGenerationCall struct { + Result string `json:"result"` +} + +// ----------------------------------------------------------------------------- +// Code Interpreter Tool +// ----------------------------------------------------------------------------- +type ResponsesCodeInterpreterToolCall struct { + Code *string `json:"code"` // The code to run, or null if not available + ContainerID string `json:"container_id"` // The ID of the container used to run the code + Outputs []ResponsesCodeInterpreterOutput `json:"outputs"` // The outputs generated by the code interpreter, can be null +} + +// CodeInterpreterOutput represents the different types of code interpreter outputs +type ResponsesCodeInterpreterOutput struct { + *ResponsesCodeInterpreterOutputLogs + *ResponsesCodeInterpreterOutputImage +} + +// MarshalJSON implements custom JSON marshaling for ResponsesCodeInterpreterOutput +func (o ResponsesCodeInterpreterOutput) MarshalJSON() ([]byte, error) { + // Error if both variants are set + if o.ResponsesCodeInterpreterOutputLogs != nil && o.ResponsesCodeInterpreterOutputImage != nil { + return nil, fmt.Errorf("ResponsesCodeInterpreterOutput cannot have both Logs and Image set") + } + + // Marshal whichever one is present + if o.ResponsesCodeInterpreterOutputLogs != nil { + return sonic.Marshal(o.ResponsesCodeInterpreterOutputLogs) + } + if o.ResponsesCodeInterpreterOutputImage != nil { + return sonic.Marshal(o.ResponsesCodeInterpreterOutputImage) + } + + // Return null if neither is set + return []byte("null"), nil +} + +// UnmarshalJSON implements custom JSON unmarshaling for ResponsesCodeInterpreterOutput +func (o *ResponsesCodeInterpreterOutput) UnmarshalJSON(data []byte) error { + // Handle null case + if string(data) == "null" { + return nil + } + + // First, peek at the type field to determine which variant to unmarshal + var typeStruct struct { + Type string 
`json:"type"` + } + if err := sonic.Unmarshal(data, &typeStruct); err != nil { + return fmt.Errorf("failed to read type field: %w", err) + } + + // Unmarshal into the appropriate concrete type based on the type field + switch typeStruct.Type { + case "logs": + var logs ResponsesCodeInterpreterOutputLogs + if err := sonic.Unmarshal(data, &logs); err != nil { + return fmt.Errorf("failed to unmarshal logs output: %w", err) + } + o.ResponsesCodeInterpreterOutputLogs = &logs + o.ResponsesCodeInterpreterOutputImage = nil + return nil + + case "image": + var image ResponsesCodeInterpreterOutputImage + if err := sonic.Unmarshal(data, &image); err != nil { + return fmt.Errorf("failed to unmarshal image output: %w", err) + } + o.ResponsesCodeInterpreterOutputImage = &image + o.ResponsesCodeInterpreterOutputLogs = nil + return nil + + default: + return fmt.Errorf("unknown ResponsesCodeInterpreterOutput type: %s", typeStruct.Type) + } +} + +// CodeInterpreterOutputLogs - The logs output from the code interpreter +type ResponsesCodeInterpreterOutputLogs struct { + Logs string `json:"logs"` + Type string `json:"type"` // always "logs" +} + +// CodeInterpreterOutputImage - The image output from the code interpreter +type ResponsesCodeInterpreterOutputImage struct { + Type string `json:"type"` // always "image" + URL string `json:"url"` +} + +// ----------------------------------------------------------------------------- +// Local Shell Tool +// ----------------------------------------------------------------------------- +type ResponsesLocalShellCall struct { + Action ResponsesLocalShellCallAction `json:"action"` +} + +type ResponsesLocalShellCallAction struct { + Command []string `json:"command"` + Env []string `json:"env"` + Type string `json:"type"` // always "exec" + TimeoutMS *int `json:"timeout_ms,omitempty"` + User *string `json:"user,omitempty"` + WorkingDirectory *string `json:"working_directory,omitempty"` +} + +type ResponsesLocalShellCallOutput struct { + Output 
string `json:"output"`
}

// -----------------------------------------------------------------------------
// MCP (Model Context Protocol) Tools
// -----------------------------------------------------------------------------

// ResponsesMCPListTools is the payload of an MCP "list tools" output item:
// the set of tools advertised by a single MCP server.
type ResponsesMCPListTools struct {
	ServerLabel string             `json:"server_label"`    // Label of the MCP server that was queried
	Tools       []ResponsesMCPTool `json:"tools"`           // Tools exposed by the server
	Error       *string            `json:"error,omitempty"` // Error raised while listing tools, if any
}

// ResponsesMCPTool describes one tool exposed by an MCP server.
type ResponsesMCPTool struct {
	Name        string          `json:"name"`
	InputSchema map[string]any  `json:"input_schema"` // JSON schema describing the tool's input
	Description *string         `json:"description,omitempty"`
	Annotations *map[string]any `json:"annotations,omitempty"`
}

// MCP Approval Request - requests approval for a specific action within MCP
type ResponsesMCPApprovalRequest struct {
	Action ResponsesMCPApprovalRequestAction `json:"action"`
}

// ResponsesMCPApprovalRequestAction identifies the pending tool call that
// requires user approval.
type ResponsesMCPApprovalRequestAction struct {
	ID          string `json:"id"`
	Type        string `json:"type"` // always "mcp_approval_request"
	Name        string `json:"name"`
	ServerLabel string `json:"server_label"`
	Arguments   string `json:"arguments"` // JSON-encoded arguments of the pending call
}

// MCP Approval Response - contains the response to an approval request within MCP
type ResponsesMCPApprovalResponse struct {
	ApprovalResponseID string  `json:"approval_response_id"`
	Approve            bool    `json:"approve"`
	Reason             *string `json:"reason,omitempty"`
}

// MCP Tool Call - an invocation of a tool on an MCP server
type ResponsesMCPToolCall struct {
	ServerLabel string  `json:"server_label"`     // The label of the MCP server running the tool
	Error       *string `json:"error,omitempty"`  // The error from the tool call, if any
	Output      *string `json:"output,omitempty"` // The output from the tool call
}

// -----------------------------------------------------------------------------
// Custom Tools
// -----------------------------------------------------------------------------

// ResponsesCustomToolCallOutput carries the caller-produced result of a
// custom tool call back to the model.
type ResponsesCustomToolCallOutput struct {
	Output string `json:"output"` // The output from the custom tool call generated by your code
}

// Custom Tool Call - a call to a custom tool created by the model
type ResponsesCustomToolCall struct {
	Input string `json:"input"` // The input for the custom tool call generated by the model
}

// =============================================================================
// 5. TOOL CHOICE CONFIGURATION
// =============================================================================

// Combined tool choices for all providers, make sure to check the provider's
// documentation to see which tool choices are supported.
type ResponsesToolChoiceType string

const (
	// ResponsesToolChoiceTypeNone means the model must not call any tool
	ResponsesToolChoiceTypeNone ResponsesToolChoiceType = "none"
	// ResponsesToolChoiceTypeAuto lets the model decide whether to call a tool
	ResponsesToolChoiceTypeAuto ResponsesToolChoiceType = "auto"
	// ResponsesToolChoiceTypeAny means any of the provided tools may be called
	ResponsesToolChoiceTypeAny ResponsesToolChoiceType = "any"
	// ResponsesToolChoiceTypeRequired means the model must call at least one tool
	ResponsesToolChoiceTypeRequired ResponsesToolChoiceType = "required"
	// ResponsesToolChoiceTypeFunction forces a specific function tool to be called
	ResponsesToolChoiceTypeFunction ResponsesToolChoiceType = "function"
	// ResponsesToolChoiceTypeAllowedTools restricts the model to a listed subset of tools
	ResponsesToolChoiceTypeAllowedTools ResponsesToolChoiceType = "allowed_tools"
	// ResponsesToolChoiceTypeFileSearch forces the file-search built-in tool
	ResponsesToolChoiceTypeFileSearch ResponsesToolChoiceType = "file_search"
	// ResponsesToolChoiceTypeWebSearchPreview forces the web-search-preview built-in tool
	ResponsesToolChoiceTypeWebSearchPreview ResponsesToolChoiceType = "web_search_preview"
	// ResponsesToolChoiceTypeComputerUsePreview forces the computer-use-preview built-in tool
	ResponsesToolChoiceTypeComputerUsePreview ResponsesToolChoiceType = "computer_use_preview"
	// ResponsesToolChoiceTypeCodeInterpreter forces the code-interpreter built-in tool
	ResponsesToolChoiceTypeCodeInterpreter ResponsesToolChoiceType = "code_interpreter"
	// ResponsesToolChoiceTypeMCP forces an MCP tool to be called
	ResponsesToolChoiceTypeMCP ResponsesToolChoiceType = "mcp"
	// ResponsesToolChoiceTypeCustom forces a custom tool to be called
	ResponsesToolChoiceTypeCustom ResponsesToolChoiceType = "custom"
)

// ResponsesToolChoiceStruct is the object form of tool_choice.
type ResponsesToolChoiceStruct struct {
	Type        ResponsesToolChoiceType             `json:"type"`                   // Type of tool choice
	Mode        *string                             `json:"mode,omitempty"`         // "none" | "auto" | "required"
	Name        *string                             `json:"name,omitempty"`         // Common name field for function/MCP/custom tools
	ServerLabel *string                             `json:"server_label,omitempty"` // Common server label field for MCP tools
	Tools       []ResponsesToolChoiceAllowedToolDef `json:"tools,omitempty"`        // Tool subset for "allowed_tools"
}

// ResponsesToolChoice is a union: either a bare string ("auto", "none", ...)
// or the object form above. Exactly one member should be non-nil.
type ResponsesToolChoice struct {
	ResponsesToolChoiceStr    *string
	ResponsesToolChoiceStruct *ResponsesToolChoiceStruct
}

// MarshalJSON implements custom JSON marshalling for ResponsesToolChoice.
// It marshals either the string form or the struct form directly without wrapping.
+func (bc ResponsesToolChoice) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if bc.ResponsesToolChoiceStr != nil && bc.ResponsesToolChoiceStruct != nil { + return nil, fmt.Errorf("both ResponsesToolChoiceStr, ResponsesToolChoiceStruct are set; only one should be non-nil") + } + + if bc.ResponsesToolChoiceStr != nil { + return sonic.Marshal(bc.ResponsesToolChoiceStr) + } + if bc.ResponsesToolChoiceStruct != nil { + return sonic.Marshal(bc.ResponsesToolChoiceStruct) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for ChatMessageContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. +func (bc *ResponsesToolChoice) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var toolChoiceStr string + if err := sonic.Unmarshal(data, &toolChoiceStr); err == nil { + bc.ResponsesToolChoiceStr = &toolChoiceStr + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var responsesToolChoiceStruct ResponsesToolChoiceStruct + if err := sonic.Unmarshal(data, &responsesToolChoiceStruct); err == nil { + bc.ResponsesToolChoiceStruct = &responsesToolChoiceStruct + return nil + } + + return fmt.Errorf("tool_choice field is neither a string nor a ResponsesToolChoiceStruct object") +} + +// ToolChoiceAllowedToolDef - Definition of an allowed tool +type ResponsesToolChoiceAllowedToolDef struct { + Type string `json:"type"` // "function" | "mcp" | "image_generation" + Name *string `json:"name,omitempty"` // for function tools + ServerLabel *string `json:"server_label,omitempty"` // for MCP tools +} + +// ============================================================================= +// 7. 
// TOOL CONFIGURATION STRUCTURES
// =============================================================================

// Tool represents different types of tools the model can use.
// Exactly one of the embedded variant structs should be populated, matching Type.
// NOTE(review): several embedded variants declare same-named fields (e.g.
// web_search vs web_search_preview); ambiguous promoted fields are dropped by
// Go JSON encoding — confirm this is the intended wire behavior.
type ResponsesTool struct {
	Type        string  `json:"type"`                  // "function" | "file_search" | "computer_use_preview" | "web_search" | "web_search_2025_08_26" | "mcp" | "code_interpreter" | "image_generation" | "local_shell" | "custom" | "web_search_preview" | "web_search_preview_2025_03_11"
	Name        *string `json:"name,omitempty"`        // Common name field (Function, Custom tools)
	Description *string `json:"description,omitempty"` // Common description field (Function, Custom tools)

	*ResponsesToolFunction
	*ResponsesToolFileSearch
	*ResponsesToolComputerUsePreview
	*ResponsesToolWebSearch
	*ResponsesToolMCP
	*ResponsesToolCodeInterpreter
	*ResponsesToolImageGeneration
	*ResponsesToolLocalShell
	*ResponsesToolCustom
	*ResponsesToolWebSearchPreview
}

// ResponsesToolFunction holds the function-tool-specific fields.
type ResponsesToolFunction struct {
	Parameters *ToolFunctionParameters `json:"parameters,omitempty"` // A JSON schema object describing the parameters
	Strict     *bool                   `json:"strict,omitempty"`     // Whether to enforce strict parameter validation
}

// ToolFileSearch - A tool that searches for relevant content from uploaded files
type ResponsesToolFileSearch struct {
	VectorStoreIDs []string                               `json:"vector_store_ids"`          // The IDs of the vector stores to search
	Filters        *ResponsesToolFileSearchFilter         `json:"filters,omitempty"`         // A filter to apply
	MaxNumResults  *int                                   `json:"max_num_results,omitempty"` // Maximum results (1-50)
	RankingOptions *ResponsesToolFileSearchRankingOptions `json:"ranking_options,omitempty"` // Ranking options for search
}

// FileSearchFilter - A filter to apply to file search.
// It is a tagged union: Type selects whether the comparison or the compound
// embedded struct is populated; the custom (un)marshallers below enforce that
// exactly one is set.
type ResponsesToolFileSearchFilter struct {
	Type string `json:"type"` // "eq" | "ne" | "gt" | "gte" | "lt" | "lte" | "and" | "or"

	// Filter types - only one should be set
	*ResponsesToolFileSearchComparisonFilter
	*ResponsesToolFileSearchCompoundFilter
}

// MarshalJSON implements custom JSON marshaling for ResponsesToolFileSearchFilter.
// It flattens the active variant's fields next to "type" and rejects filters
// where zero or both variants are set, or where the variant disagrees with Type.
func (f *ResponsesToolFileSearchFilter) MarshalJSON() ([]byte, error) {
	// Validate that exactly one filter type is set
	if f.ResponsesToolFileSearchComparisonFilter != nil && f.ResponsesToolFileSearchCompoundFilter != nil {
		return nil, fmt.Errorf("both comparison and compound filters are set; only one should be non-nil")
	}
	if f.ResponsesToolFileSearchComparisonFilter == nil && f.ResponsesToolFileSearchCompoundFilter == nil {
		return nil, fmt.Errorf("neither comparison nor compound filter is set; exactly one must be non-nil")
	}

	// Create a map to hold the JSON data
	result := make(map[string]interface{})
	result["type"] = f.Type

	// Marshal the appropriate embedded struct based on type
	switch f.Type {
	case "eq", "ne", "gt", "gte", "lt", "lte":
		if f.ResponsesToolFileSearchComparisonFilter == nil {
			return nil, fmt.Errorf("comparison filter is nil but type is %s", f.Type)
		}
		// Copy fields from the embedded struct
		result["key"] = f.ResponsesToolFileSearchComparisonFilter.Key
		result["value"] = f.ResponsesToolFileSearchComparisonFilter.Value
	case "and", "or":
		if f.ResponsesToolFileSearchCompoundFilter == nil {
			return nil, fmt.Errorf("compound filter is nil but type is %s", f.Type)
		}
		// Copy fields from the embedded struct; nested filters re-enter this
		// marshaller recursively.
		result["filters"] = f.ResponsesToolFileSearchCompoundFilter.Filters
	default:
		return nil, fmt.Errorf("unknown filter type: %s", f.Type)
	}

	return sonic.Marshal(result)
}

// UnmarshalJSON implements custom JSON unmarshaling for ResponsesToolFileSearchFilter.
// It inspects the "type" discriminator first, then decodes the payload into the
// matching variant and validates that variant's required fields.
func (f *ResponsesToolFileSearchFilter) UnmarshalJSON(data []byte) error {
	// First, unmarshal into a map to inspect the type field
	var raw map[string]interface{}
	if err := sonic.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("failed to unmarshal filter JSON: %w", err)
	}

	// Extract the type field
	typeValue, ok := raw["type"]
	if !ok {
		return fmt.Errorf("missing required 'type' field in filter")
	}

	typeStr, ok := typeValue.(string)
	if !ok {
		return fmt.Errorf("'type' field must be a string, got %T", typeValue)
	}

	f.Type = typeStr

	// Initialize the appropriate embedded struct based on type
	switch typeStr {
	case "eq", "ne", "gt", "gte", "lt", "lte":
		// This is a comparison filter
		f.ResponsesToolFileSearchComparisonFilter = &ResponsesToolFileSearchComparisonFilter{}
		f.ResponsesToolFileSearchCompoundFilter = nil

		// Unmarshal into the comparison filter
		if err := sonic.Unmarshal(data, f.ResponsesToolFileSearchComparisonFilter); err != nil {
			return fmt.Errorf("failed to unmarshal comparison filter: %w", err)
		}

		// Validate required fields
		if f.ResponsesToolFileSearchComparisonFilter.Key == "" {
			return fmt.Errorf("comparison filter missing required 'key' field")
		}
		if f.ResponsesToolFileSearchComparisonFilter.Value == nil {
			return fmt.Errorf("comparison filter missing required 'value' field")
		}

	case "and", "or":
		// This is a compound filter
		f.ResponsesToolFileSearchCompoundFilter = &ResponsesToolFileSearchCompoundFilter{}
		f.ResponsesToolFileSearchComparisonFilter = nil

		// Unmarshal into the compound filter; nested elements re-enter this
		// method recursively.
		if err := sonic.Unmarshal(data, f.ResponsesToolFileSearchCompoundFilter); err != nil {
			return fmt.Errorf("failed to unmarshal compound filter: %w", err)
		}

		// Validate required fields
		if f.ResponsesToolFileSearchCompoundFilter.Filters == nil {
			return fmt.Errorf("compound filter missing required 'filters' field")
		}
		if len(f.ResponsesToolFileSearchCompoundFilter.Filters) == 0 {
			return fmt.Errorf("compound filter 'filters' array cannot be empty")
		}

	default:
		return fmt.Errorf("unknown filter type: %s (supported types: eq, ne, gt, gte, lt, lte, and, or)", typeStr)
	}

	return nil
}

// FileSearchComparisonFilter - Compare a specified attribute key to a value
type ResponsesToolFileSearchComparisonFilter struct {
	Key   string      `json:"key"`   // The key to compare against the value
	Type  string      `json:"type"`  // Comparison operator; mirrors the parent filter's "type" field when decoded
	Value interface{} `json:"value"` // The value to compare (string, number, or boolean)
}

// FileSearchCompoundFilter - Combine multiple filters using and or or
type ResponsesToolFileSearchCompoundFilter struct {
	Filters []ResponsesToolFileSearchFilter `json:"filters"` // Array of filters to combine
}

// FileSearchRankingOptions - Ranking options for search
type ResponsesToolFileSearchRankingOptions struct {
	Ranker         *string  `json:"ranker,omitempty"`          // The ranker to use
	ScoreThreshold *float64 `json:"score_threshold,omitempty"` // Score threshold (0-1)
}

// ToolComputerUsePreview - A tool that controls a virtual computer
type ResponsesToolComputerUsePreview struct {
	DisplayHeight int    `json:"display_height"` // The height of the computer display
	DisplayWidth  int    `json:"display_width"`  // The width of the computer display
	Environment   string `json:"environment"`    // The type of computer environment to control
}

// ToolWebSearch - Search the Internet for sources related to the prompt
type ResponsesToolWebSearch struct {
	Filters           *ResponsesToolWebSearchFilters      `json:"filters,omitempty"`             // Filters for the search
	SearchContextSize *string                             `json:"search_context_size,omitempty"` // "low" | "medium" | "high"
	UserLocation      *ResponsesToolWebSearchUserLocation `json:"user_location,omitempty"`       // The approximate location of the user
}

// ResponsesToolWebSearchFilters - Filters for web search
type ResponsesToolWebSearchFilters struct {
	AllowedDomains []string `json:"allowed_domains"` // Allowed domains for the search
}

// ResponsesToolWebSearchUserLocation - The approximate location of the user
type ResponsesToolWebSearchUserLocation struct {
	City     *string `json:"city,omitempty"`     // Free text input for the city
	Country  *string `json:"country,omitempty"`  // Two-letter ISO country code
	Region   *string `json:"region,omitempty"`   // Free text input for the region
	Timezone *string `json:"timezone,omitempty"` // IANA timezone
	Type     *string `json:"type,omitempty"`     // always "approximate"
}

// ResponsesToolMCP - Give the model access to additional tools via remote MCP servers
type ResponsesToolMCP struct {
	ServerLabel       string                                       `json:"server_label"`                 // A label for this MCP server
	AllowedTools      *ResponsesToolMCPAllowedTools                `json:"allowed_tools,omitempty"`      // List of allowed tool names or filter
	Authorization     *string                                      `json:"authorization,omitempty"`      // OAuth access token
	ConnectorID       *string                                      `json:"connector_id,omitempty"`       // Service connector ID
	Headers           *map[string]string                           `json:"headers,omitempty"`            // Optional HTTP headers
	RequireApproval   *ResponsesToolMCPAllowedToolsApprovalSetting `json:"require_approval,omitempty"`   // Tool approval settings
	ServerDescription *string                                      `json:"server_description,omitempty"` // Optional server description
	ServerURL         *string                                      `json:"server_url,omitempty"`         // The URL for the MCP server
}

// ResponsesToolMCPAllowedTools - List of allowed tool names or a filter object.
// NOTE(review): the empty-name tags (`json:",omitempty"`) do NOT inline these
// fields — Go JSON encoding falls back to the field names "ToolNames"/"Filter".
// A custom MarshalJSON/UnmarshalJSON pair appears to be required for this union
// to serialize as a bare array or object — confirm against the wire format.
type ResponsesToolMCPAllowedTools struct {
	// Either a simple array of tool names or a filter object
	ToolNames *[]string                           `json:",omitempty"`
	Filter    *ResponsesToolMCPAllowedToolsFilter `json:",omitempty"`
}

// ResponsesToolMCPAllowedToolsFilter - A filter object to specify which tools are allowed
type ResponsesToolMCPAllowedToolsFilter struct {
	ReadOnly  *bool     `json:"read_only,omitempty"`  // Whether tool is read-only
	ToolNames *[]string `json:"tool_names,omitempty"` // List of allowed tool names
}

// ResponsesToolMCPAllowedToolsApprovalSetting - Specify which tools require approval.
// NOTE(review): same empty-name-tag concern as ResponsesToolMCPAllowedTools —
// "Setting" will serialize under the field name rather than inline; verify.
type ResponsesToolMCPAllowedToolsApprovalSetting struct {
	// Either a string setting or filter objects
	Setting *string                                     `json:",omitempty"` // "always" | "never"
	Always  *ResponsesToolMCPAllowedToolsApprovalFilter `json:"always,omitempty"`
	Never   *ResponsesToolMCPAllowedToolsApprovalFilter `json:"never,omitempty"`
}

// ResponsesToolMCPAllowedToolsApprovalFilter - Filter for approval settings
type ResponsesToolMCPAllowedToolsApprovalFilter struct {
	ReadOnly  *bool     `json:"read_only,omitempty"`  // Whether tool is read-only
	ToolNames *[]string `json:"tool_names,omitempty"` // List of tool names
}

// ToolCodeInterpreter - A tool that runs Python code
type ResponsesToolCodeInterpreter struct {
	Container interface{} `json:"container"` // Container ID or object with file IDs
}

// ToolImageGeneration - A tool that generates images
type ResponsesToolImageGeneration struct {
	Background        *string                                     `json:"background,omitempty"`         // "transparent" | "opaque" | "auto"
	InputFidelity     *string                                     `json:"input_fidelity,omitempty"`     // "high" | "low"
	InputImageMask    *ResponsesToolImageGenerationInputImageMask `json:"input_image_mask,omitempty"`   // Optional mask for inpainting
	Model             *string                                     `json:"model,omitempty"`              // Image generation model
	Moderation        *string                                     `json:"moderation,omitempty"`         // Moderation level
	OutputCompression *int                                        `json:"output_compression,omitempty"` // Compression level (0-100)
	OutputFormat      *string                                     `json:"output_format,omitempty"`      // "png" | "webp" | "jpeg"
	PartialImages     *int                                        `json:"partial_images,omitempty"`     // Number of partial images (0-3)
	Quality           *string                                     `json:"quality,omitempty"`            // "low" | "medium" | "high" | "auto"
	Size              *string                                     `json:"size,omitempty"`               // Image size
}

// ImageGenerationInputMask - Optional mask for inpainting
type ResponsesToolImageGenerationInputImageMask struct {
	FileID   *string `json:"file_id,omitempty"`   // File ID for the mask image
	ImageURL *string `json:"image_url,omitempty"` // Base64-encoded mask image
}

// ToolLocalShell - A tool that allows executing shell commands locally
type ResponsesToolLocalShell struct {
	// No unique fields needed since Type is now in the top-level struct
}

// ToolCustom - A custom tool that processes input using a specified format
type ResponsesToolCustom struct {
	Format *ResponsesToolCustomFormat `json:"format,omitempty"` // The input format
}

// CustomToolFormat - The input format for the custom tool
type ResponsesToolCustomFormat struct {
	Type string `json:"type"` // always "text"

	// For Grammar
	Definition *string `json:"definition,omitempty"` // The grammar definition
	Syntax     *string `json:"syntax,omitempty"`     // "lark" | "regex"
}

// ToolWebSearchPreview - Web search tool preview variant
type ResponsesToolWebSearchPreview struct {
	SearchContextSize *string                             `json:"search_context_size,omitempty"` // "low" | "medium" | "high"
	UserLocation      *ResponsesToolWebSearchUserLocation `json:"user_location,omitempty"`       // The user's location
}
diff --git a/core/schemas/speech.go b/core/schemas/speech.go
new file mode 100644
index 0000000000..f301ada5dd
--- /dev/null
+++ b/core/schemas/speech.go
@@ -0,0 +1,84 @@
package schemas

import (
	"fmt"

	"github.com/bytedance/sonic"
)

// SpeechInput represents the input for a speech request.
type SpeechInput struct {
	Input string `json:"input"` // Text to synthesize
}

// SpeechParameters configures a text-to-speech request.
type SpeechParameters struct {
	VoiceConfig    SpeechVoiceInput `json:"voice"`                     // Single voice or multi-speaker configuration
	Instructions   string           `json:"instructions,omitempty"`    // Optional style/delivery instructions
	ResponseFormat string           `json:"response_format,omitempty"` // Default is "mp3"
	Speed          *float64         `json:"speed,omitempty"`           // Playback speed multiplier

	// Dynamic parameters that can be provider-specific, they are directly
	// added to the request as is.
	ExtraParams map[string]interface{} `json:"-"`
}

// SpeechVoiceInput is a union: either a single voice name or a multi-speaker
// voice configuration. Exactly one member should be set; the custom
// (un)marshallers below serialize it without a wrapper.
type SpeechVoiceInput struct {
	Voice            *string
	MultiVoiceConfig []VoiceConfig
}

// VoiceConfig binds one named speaker to a voice for multi-speaker synthesis.
type VoiceConfig struct {
	Speaker string `json:"speaker"`
	Voice   string `json:"voice"`
}

// MarshalJSON implements custom JSON marshalling for SpeechVoiceInput.
// It marshals either Voice or MultiVoiceConfig directly without wrapping.
// MarshalJSON serializes the union as a bare string, a bare array, or null.
// NOTE(review): defined on the pointer receiver — callers marshalling a
// non-addressable SpeechVoiceInput value will bypass this method; confirm
// all call sites marshal via a pointer or an addressable field.
func (vi *SpeechVoiceInput) MarshalJSON() ([]byte, error) {
	// Validation: ensure only one field is set at a time
	if vi.Voice != nil && len(vi.MultiVoiceConfig) > 0 {
		return nil, fmt.Errorf("both Voice and MultiVoiceConfig are set; only one should be non-nil")
	}

	if vi.Voice != nil {
		return sonic.Marshal(*vi.Voice)
	}
	if len(vi.MultiVoiceConfig) > 0 {
		return sonic.Marshal(vi.MultiVoiceConfig)
	}
	// If both are nil, return null
	return sonic.Marshal(nil)
}

// UnmarshalJSON implements custom JSON unmarshalling for SpeechVoiceInput.
// It determines whether "voice" is a string or a VoiceConfig array and assigns
// to the appropriate field, rejecting configs with an empty voice name.
func (vi *SpeechVoiceInput) UnmarshalJSON(data []byte) error {
	// Reset receiver state before attempting any decode to avoid stale data
	vi.Voice = nil
	vi.MultiVoiceConfig = nil

	// First, try to unmarshal as a direct string
	var stringContent string
	if err := sonic.Unmarshal(data, &stringContent); err == nil {
		vi.Voice = &stringContent
		return nil
	}

	// Try to unmarshal as an array of VoiceConfig objects
	var voiceConfigs []VoiceConfig
	if err := sonic.Unmarshal(data, &voiceConfigs); err == nil {
		// Validate each VoiceConfig and build a new slice deterministically
		validConfigs := make([]VoiceConfig, 0, len(voiceConfigs))
		for _, config := range voiceConfigs {
			if config.Voice == "" {
				return fmt.Errorf("voice config has empty voice field")
			}
			validConfigs = append(validConfigs, config)
		}
		vi.MultiVoiceConfig = validConfigs
		return nil
	}

	return fmt.Errorf("voice field is neither a string, nor an array of VoiceConfig objects")
}
diff --git a/core/schemas/textcompletions.go b/core/schemas/textcompletions.go
new file mode 100644
index 0000000000..e70f0d1bbc
--- /dev/null
+++ b/core/schemas/textcompletions.go
@@ -0,0 +1,69 @@
package schemas

import (
	"fmt"

	"github.com/bytedance/sonic"
)

// TextCompletionInput is a union: a single prompt string or an array of
// prompt strings. Exactly one member must be set.
type TextCompletionInput struct {
	PromptStr   *string
	PromptArray []string
}

// MarshalJSON serializes the union as a bare string or a bare array and
// rejects inputs where zero or both members are set.
func (t *TextCompletionInput) MarshalJSON() ([]byte, error) {
	set := 0
	if t.PromptStr != nil {
		set++
	}
	if t.PromptArray != nil {
		set++
	}
	if set == 0 {
		return nil, fmt.Errorf("text completion input is empty")
	}
	if set > 1 {
		return nil, fmt.Errorf("text completion input must set exactly one of: prompt_str or prompt_array")
	}
	if t.PromptStr != nil {
		return sonic.Marshal(*t.PromptStr)
	}
	return sonic.Marshal(t.PromptArray)
}

// UnmarshalJSON accepts either a bare string or an array of strings,
// clearing the unused member so the union stays consistent.
func (t *TextCompletionInput) UnmarshalJSON(data []byte) error {
	var prompt string
	if err := sonic.Unmarshal(data, &prompt); err == nil {
		t.PromptStr = &prompt
		t.PromptArray = nil
		return nil
	}
	var promptArray []string
	if err := sonic.Unmarshal(data, &promptArray); err == nil {
		t.PromptStr = nil
		t.PromptArray = promptArray
		return nil
	}
	return fmt.Errorf("invalid text completion input")
}

// TextCompletionParameters mirrors the OpenAI-style /completions parameters.
type TextCompletionParameters struct {
	BestOf           *int                `json:"best_of,omitempty"`
	Echo             *bool               `json:"echo,omitempty"`
	FrequencyPenalty *float64            `json:"frequency_penalty,omitempty"`
	LogitBias        *map[string]float64 `json:"logit_bias,omitempty"`
	LogProbs         *int                `json:"logprobs,omitempty"`
	MaxTokens        *int                `json:"max_tokens,omitempty"`
	N                *int                `json:"n,omitempty"`
	PresencePenalty  *float64            `json:"presence_penalty,omitempty"`
	Seed             *int                `json:"seed,omitempty"`
	Stop             *[]string           `json:"stop,omitempty"`
	Suffix           *string             `json:"suffix,omitempty"`
	Temperature      *float64            `json:"temperature,omitempty"`
	TopP             *float64            `json:"top_p,omitempty"`
	User             *string             `json:"user,omitempty"`

	// Dynamic parameters that can be provider-specific, they are directly
	// added to the request as is.
	ExtraParams map[string]interface{} `json:"-"`
}
diff --git a/core/schemas/transcriptions.go b/core/schemas/transcriptions.go
new file mode 100644
index 0000000000..451655c880
--- /dev/null
+++ b/core/schemas/transcriptions.go
@@ -0,0 +1,54 @@
package schemas

// TranscriptionInput carries the raw audio bytes to transcribe.
type TranscriptionInput struct {
	File []byte `json:"file"`
}

// TranscriptionParameters configures a speech-to-text request.
type TranscriptionParameters struct {
	Language       *string `json:"language,omitempty"`
	Prompt         *string `json:"prompt,omitempty"`
	ResponseFormat *string `json:"response_format,omitempty"` // Default is "json"
	Format         *string `json:"file_format,omitempty"`     // Type of file, not required in openai, but required in gemini

	// Dynamic parameters that can be provider-specific, they are directly
	// added to the request as is.
	ExtraParams map[string]interface{} `json:"-"`
}

// TranscriptionLogProb represents log probability information for transcription
type TranscriptionLogProb struct {
	Token   string  `json:"token"`
	LogProb float64 `json:"logprob"`
	Bytes   []int   `json:"bytes"`
}

// TranscriptionWord represents word-level timing information
type TranscriptionWord struct {
	Word  string  `json:"word"`
	Start float64 `json:"start"` // Start time in seconds
	End   float64 `json:"end"`   // End time in seconds
}

// TranscriptionSegment represents segment-level transcription information
type TranscriptionSegment struct {
	ID               int     `json:"id"`
	Seek             int     `json:"seek"`
	Start            float64 `json:"start"`
	End              float64 `json:"end"`
	Text             string  `json:"text"`
	Tokens           []int   `json:"tokens"`
	Temperature      float64 `json:"temperature"`
	AvgLogProb       float64 `json:"avg_logprob"`
	CompressionRatio float64 `json:"compression_ratio"`
	NoSpeechProb     float64 `json:"no_speech_prob"`
}

// TranscriptionUsage represents usage information for transcription
type TranscriptionUsage struct {
	Type              string             `json:"type"` // "tokens" or "duration"
	InputTokens       *int               `json:"input_tokens,omitempty"`
	InputTokenDetails *AudioTokenDetails `json:"input_token_details,omitempty"`
	OutputTokens      *int
`json:"output_tokens,omitempty"` + TotalTokens *int `json:"total_tokens,omitempty"` + Seconds *int `json:"seconds,omitempty"` // For duration-based usage +} diff --git a/core/schemas/utils.go b/core/schemas/utils.go index 9a34fe2d31..427b971854 100644 --- a/core/schemas/utils.go +++ b/core/schemas/utils.go @@ -1,9 +1,11 @@ package schemas import ( + "encoding/json" "fmt" "net/url" "regexp" + "strconv" "strings" ) @@ -60,371 +62,6 @@ func mapFinishReasonToAnthropic(finishReason string) string { } } -// ParameterSet represents a set of valid parameters using a map for O(1) lookup -type ParameterSet map[string]bool - -// Marker to allowe all params -const AllowAllParams = "*" - -// Pre-defined parameter groups (initialized once at startup) -var ( - - allowAllParams = ParameterSet{ - AllowAllParams: true, - } - // Core parameters supported by most providers - coreParams = ParameterSet{ - "max_tokens": true, - "temperature": true, - "top_p": true, - "stream": true, - "tools": true, - "tool_choice": true, - } - - // Extended parameter groups - openAIParams = ParameterSet{ - "frequency_penalty": true, - "presence_penalty": true, - "n": true, - "stop": true, - "logprobs": true, - "top_logprobs": true, - "logit_bias": true, - "seed": true, - "user": true, - "response_format": true, - "parallel_tool_calls": true, - "max_completion_tokens": true, - "metadata": true, - "modalities": true, - "prediction": true, - "reasoning_effort": true, - "service_tier": true, - "store": true, - "speed": true, - "language": true, - "prompt": true, - "include": true, - "timestamp_granularities": true, - "encoding_format": true, - "dimensions": true, - "stream_options": true, - } - - anthropicParams = ParameterSet{ - "stop_sequences": true, - "system": true, - "metadata": true, - "mcp_servers": true, - "service_tier": true, - "thinking": true, - "top_k": true, - } - - cohereParams = ParameterSet{ - "frequency_penalty": true, - "presence_penalty": true, - "k": true, - "p": true, - "truncate": 
true, - "return_likelihoods": true, - "logit_bias": true, - "stop_sequences": true, - } - - mistralParams = ParameterSet{ - "frequency_penalty": true, - "presence_penalty": true, - "safe_mode": true, - "n": true, - "parallel_tool_calls": true, - "prediction": true, - "prompt_mode": true, - "random_seed": true, - "response_format": true, - "safe_prompt": true, - "top_k": true, - } - - groqParams = ParameterSet{ - "n": true, - "reasoning_effort": true, - "reasoning_format": true, - "service_tier": true, - "stop": true, - } - - ollamaParams = ParameterSet{ - "num_ctx": true, - "num_gpu": true, - "num_thread": true, - "repeat_penalty": true, - "repeat_last_n": true, - "seed": true, - "tfs_z": true, - "mirostat": true, - "mirostat_tau": true, - "mirostat_eta": true, - "format": true, - "keep_alive": true, - "low_vram": true, - "main_gpu": true, - "min_p": true, - "num_batch": true, - "num_keep": true, - "num_predict": true, - "numa": true, - "penalize_newline": true, - "raw": true, - "typical_p": true, - "use_mlock": true, - "use_mmap": true, - "vocab_only": true, - } - - openRouterParams = ParameterSet{ - "transforms": true, - "models": true, - "route": true, - "provider": true, - "prediction": true, - "top_a": true, - "min_p": true, - } - - vertexParams = ParameterSet{ - "task_type": true, - "title": true, - "autoTruncate": true, - "outputDimensionality": true, - } - - bedrockParams = ParameterSet{ - "max_tokens_to_sample": true, - "toolConfig": true, - "input_type": true, - } -) - -// ParameterValidator provides fast parameter validation and filtering -type ParameterValidator struct { - providerSchemas map[ModelProvider]ParameterSet -} - -// NewParameterValidator creates a validator with pre-computed provider schemas -func NewParameterValidator() *ParameterValidator { - return &ParameterValidator{ - providerSchemas: buildProviderSchemas(), - } -} - -// FilterParameters filters parameters for a provider using manual field checks (no reflection) -func (v 
*ParameterValidator) FilterParameters(provider ModelProvider, params *ModelParameters) *ModelParameters { - if params == nil { - return nil - } - - schema, exists := v.providerSchemas[provider] - if !exists { - return params // Unknown provider, return all params - } - - filtered := &ModelParameters{ - ExtraParams: make(map[string]interface{}), - } - - // Return all params if the provider allows all params - if schema[AllowAllParams] { - return params - } - - // Manual field filtering - fast and memory efficient - v.filterStandardFields(schema, params, filtered) - v.filterExtraParams(schema, params, filtered) - - // Return nil if no valid parameters - if v.isEmpty(filtered) { - return nil - } - - return filtered -} - -// filterStandardFields manually filters each field - faster than reflection -func (v *ParameterValidator) filterStandardFields(schema ParameterSet, source, target *ModelParameters) { - if source.MaxTokens != nil && schema["max_tokens"] { - target.MaxTokens = source.MaxTokens - } - if source.Temperature != nil && schema["temperature"] { - target.Temperature = source.Temperature - } - if source.TopP != nil && schema["top_p"] { - target.TopP = source.TopP - } - if source.TopK != nil && schema["top_k"] { - target.TopK = source.TopK - } - if source.PresencePenalty != nil && schema["presence_penalty"] { - target.PresencePenalty = source.PresencePenalty - } - if source.FrequencyPenalty != nil && schema["frequency_penalty"] { - target.FrequencyPenalty = source.FrequencyPenalty - } - if source.StopSequences != nil && schema["stop_sequences"] { - target.StopSequences = source.StopSequences - } - if source.Tools != nil && schema["tools"] { - target.Tools = source.Tools - } - if source.ToolChoice != nil && schema["tool_choice"] { - target.ToolChoice = source.ToolChoice - } - if source.User != nil && schema["user"] { - target.User = source.User - } - if source.EncodingFormat != nil && schema["encoding_format"] { - target.EncodingFormat = source.EncodingFormat - } 
- if source.Dimensions != nil && schema["dimensions"] { - target.Dimensions = source.Dimensions - } - if source.ParallelToolCalls != nil && schema["parallel_tool_calls"] { - target.ParallelToolCalls = source.ParallelToolCalls - } - if source.N != nil && schema["n"] { - target.N = source.N - } - if source.Stop != nil && schema["stop"] { - target.Stop = source.Stop - } - if source.MaxCompletionTokens != nil && schema["max_completion_tokens"] { - target.MaxCompletionTokens = source.MaxCompletionTokens - } - if source.ReasoningEffort != nil && schema["reasoning_effort"] { - target.ReasoningEffort = source.ReasoningEffort - } - if source.StreamOptions != nil && schema["stream_options"] { - target.StreamOptions = source.StreamOptions - } - if source.Stream != nil && schema["stream"] { - target.Stream = source.Stream - } - if source.LogProbs != nil && schema["logprobs"] { - target.LogProbs = source.LogProbs - } - if source.TopLogProbs != nil && schema["top_logprobs"] { - target.TopLogProbs = source.TopLogProbs - } - if source.ResponseFormat != nil && schema["response_format"] { - target.ResponseFormat = source.ResponseFormat - } - if source.Seed != nil && schema["seed"] { - target.Seed = source.Seed - } - if source.LogitBias != nil && schema["logit_bias"] { - target.LogitBias = source.LogitBias - } -} - -// filterExtraParams filters the ExtraParams map -func (v *ParameterValidator) filterExtraParams(schema ParameterSet, source, target *ModelParameters) { - if source.ExtraParams == nil { - return - } - - for key, value := range source.ExtraParams { - if schema[key] { - target.ExtraParams[key] = value - } - } -} - -// isEmpty checks if all fields are nil/empty - manual check is faster -func (v *ParameterValidator) isEmpty(params *ModelParameters) bool { - return params.MaxTokens == nil && - params.Temperature == nil && - params.TopP == nil && - params.TopK == nil && - params.PresencePenalty == nil && - params.FrequencyPenalty == nil && - params.StopSequences == nil && - 
params.Tools == nil && - params.ToolChoice == nil && - params.User == nil && - params.EncodingFormat == nil && - params.Dimensions == nil && - params.ParallelToolCalls == nil && - len(params.ExtraParams) == 0 -} - -// IsValidParameter checks if a parameter is valid for a provider (O(1) lookup) -func (v *ParameterValidator) IsValidParameter(provider ModelProvider, paramName string) bool { - schema, exists := v.providerSchemas[provider] - if !exists { - return false - } - return schema[paramName] -} - -// GetSupportedParameters returns all supported parameters for a provider -func (v *ParameterValidator) GetSupportedParameters(provider ModelProvider) []string { - schema, exists := v.providerSchemas[provider] - if !exists { - return nil - } - - params := make([]string, 0, len(schema)) - for param := range schema { - params = append(params, param) - } - return params -} - -// Helper function to merge parameter sets -func mergeParameterSets(sets ...ParameterSet) ParameterSet { - totalSize := 0 - for _, set := range sets { - totalSize += len(set) - } - - result := make(ParameterSet, totalSize) - for _, set := range sets { - for param := range set { - result[param] = true - } - } - return result -} - -// buildProviderSchemas creates provider-specific parameter schemas -func buildProviderSchemas() map[ModelProvider]ParameterSet { - return map[ModelProvider]ParameterSet{ - OpenAI: mergeParameterSets(coreParams, openAIParams), - Azure: mergeParameterSets(coreParams, openAIParams), - Anthropic: mergeParameterSets(coreParams, anthropicParams), - Cohere: mergeParameterSets(coreParams, cohereParams), - Mistral: mergeParameterSets(coreParams, mistralParams), - Groq: mergeParameterSets(coreParams, groqParams), - Bedrock: mergeParameterSets(coreParams, anthropicParams, mistralParams, bedrockParams), - Vertex: mergeParameterSets(coreParams, openAIParams, anthropicParams, vertexParams), - Ollama: mergeParameterSets(coreParams, ollamaParams), - Cerebras: mergeParameterSets(coreParams, 
openAIParams), - SGL: mergeParameterSets(coreParams, openAIParams), - Parasail: mergeParameterSets(coreParams, openAIParams), - Gemini: mergeParameterSets(coreParams, openAIParams, ParameterSet{"top_k": true, "stop_sequences": true}), - OpenRouter: allowAllParams, - } -} - -// Global validator instance -var globalValidator = NewParameterValidator() - -// Public API functions using the global validator -func ValidateAndFilterParamsForProvider(provider ModelProvider, params *ModelParameters) *ModelParameters { - return globalValidator.FilterParameters(provider, params) -} - //* IMAGE UTILS *// // dataURIRegex is a precompiled regex for matching data URI format patterns. @@ -447,6 +84,21 @@ var fileExtensionToMediaType = map[string]string{ ".bmp": "image/bmp", } +// ImageContentType represents the type of image content +type ImageContentType string + +const ( + ImageContentTypeBase64 ImageContentType = "base64" + ImageContentTypeURL ImageContentType = "url" +) + +// URLTypeInfo contains extracted information about a URL +type URLTypeInfo struct { + Type ImageContentType + MediaType *string + DataURLWithoutPrefix *string // URL without the prefix (eg data:image/png;base64,iVBORw0KGgo...) +} + // SanitizeImageURL sanitizes and validates an image URL. // It handles both data URLs and regular HTTP/HTTPS URLs. // It also detects raw base64 image data and adds proper data URL headers. 
@@ -604,3 +256,276 @@ func isLikelyBase64(s string) bool { // Check if it contains only base64 characters using pre-compiled regex return base64Regex.MatchString(cleanData) } + +// Helper function to convert interface{} to JSON string +func JsonifyInput(input interface{}) string { + if input == nil { + return "{}" + } + jsonBytes, err := json.Marshal(input) + if err != nil { + return "{}" + } + return string(jsonBytes) +} + +//* SAFE EXTRACTION UTILITIES *// + +// SafeExtractString safely extracts a string value from an interface{} with type checking +func SafeExtractString(value interface{}) (string, bool) { + if value == nil { + return "", false + } + switch v := value.(type) { + case string: + return v, true + case *string: + if v != nil { + return *v, true + } + return "", false + case json.Number: + return string(v), true + default: + return "", false + } +} + +// SafeExtractInt safely extracts an int value from an interface{} with type checking +func SafeExtractInt(value interface{}) (int, bool) { + if value == nil { + return 0, false + } + switch v := value.(type) { + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + return int(v), true + case uint: + return int(v), true + case uint8: + return int(v), true + case uint16: + return int(v), true + case uint32: + return int(v), true + case uint64: + return int(v), true + case float32: + return int(v), true + case float64: + return int(v), true + case json.Number: + if intVal, err := v.Int64(); err == nil { + return int(intVal), true + } + return 0, false + case string: + if intVal, err := strconv.Atoi(v); err == nil { + return intVal, true + } + return 0, false + default: + return 0, false + } +} + +// SafeExtractFloat64 safely extracts a float64 value from an interface{} with type checking +func SafeExtractFloat64(value interface{}) (float64, bool) { + if value == nil { + return 0, false + } + switch v := 
value.(type) { + case float64: + return v, true + case float32: + return float64(v), true + case int: + return float64(v), true + case int8: + return float64(v), true + case int16: + return float64(v), true + case int32: + return float64(v), true + case int64: + return float64(v), true + case uint: + return float64(v), true + case uint8: + return float64(v), true + case uint16: + return float64(v), true + case uint32: + return float64(v), true + case uint64: + return float64(v), true + case json.Number: + if floatVal, err := v.Float64(); err == nil { + return floatVal, true + } + return 0, false + case string: + if floatVal, err := strconv.ParseFloat(v, 64); err == nil { + return floatVal, true + } + return 0, false + default: + return 0, false + } +} + +// SafeExtractBool safely extracts a bool value from an interface{} with type checking +func SafeExtractBool(value interface{}) (bool, bool) { + if value == nil { + return false, false + } + switch v := value.(type) { + case bool: + return v, true + case *bool: + if v != nil { + return *v, true + } + return false, false + case string: + if boolVal, err := strconv.ParseBool(v); err == nil { + return boolVal, true + } + return false, false + case int: + return v != 0, true + case int8: + return v != 0, true + case int16: + return v != 0, true + case int32: + return v != 0, true + case int64: + return v != 0, true + case uint: + return v != 0, true + case uint8: + return v != 0, true + case uint16: + return v != 0, true + case uint32: + return v != 0, true + case uint64: + return v != 0, true + case float32: + return v != 0, true + case float64: + return v != 0, true + default: + return false, false + } +} + +// SafeExtractStringSlice safely extracts a []string value from an interface{} with type checking +func SafeExtractStringSlice(value interface{}) ([]string, bool) { + if value == nil { + return nil, false + } + switch v := value.(type) { + case []string: + return v, true + case []interface{}: + var result 
[]string + for _, item := range v { + if str, ok := SafeExtractString(item); ok { + result = append(result, str) + } else { + return nil, false // If any item is not a string, fail + } + } + return result, true + case []*string: + var result []string + for _, item := range v { + if item != nil { + result = append(result, *item) + } + } + return result, true + default: + return nil, false + } +} + +// SafeExtractStringPointer safely extracts a *string value from an interface{} with type checking +func SafeExtractStringPointer(value interface{}) (*string, bool) { + if value == nil { + return nil, false + } + switch v := value.(type) { + case *string: + return v, true + case string: + return &v, true + case json.Number: + str := string(v) + return &str, true + default: + return nil, false + } +} + +// SafeExtractIntPointer safely extracts an *int value from an interface{} with type checking +func SafeExtractIntPointer(value interface{}) (*int, bool) { + if value == nil { + return nil, false + } + if intVal, ok := SafeExtractInt(value); ok { + return &intVal, true + } + return nil, false +} + +// SafeExtractFloat64Pointer safely extracts a *float64 value from an interface{} with type checking +func SafeExtractFloat64Pointer(value interface{}) (*float64, bool) { + if value == nil { + return nil, false + } + if floatVal, ok := SafeExtractFloat64(value); ok { + return &floatVal, true + } + return nil, false +} + +// SafeExtractBoolPointer safely extracts a *bool value from an interface{} with type checking +func SafeExtractBoolPointer(value interface{}) (*bool, bool) { + if value == nil { + return nil, false + } + if boolVal, ok := SafeExtractBool(value); ok { + return &boolVal, true + } + return nil, false +} + +// SafeExtractStringSlicePointer safely extracts a *[]string value from an interface{} with type checking +func SafeExtractStringSlicePointer(value interface{}) (*[]string, bool) { + if value == nil { + return nil, false + } + if sliceVal, ok := 
SafeExtractStringSlice(value); ok { + return &sliceVal, true + } + return nil, false +} + +// SafeExtractFromMap safely extracts a value from a map[string]interface{} with type checking +func SafeExtractFromMap(m map[string]interface{}, key string) (interface{}, bool) { + if m == nil { + return nil, false + } + value, exists := m[key] + return value, exists +} diff --git a/core/utils.go b/core/utils.go index e0bddf0aa6..6adc11e0ee 100644 --- a/core/utils.go +++ b/core/utils.go @@ -13,14 +13,6 @@ func Ptr[T any](v T) *T { return &v } -func attachContextKeys(ctx context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType) context.Context { - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, requestType) - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestProvider, req.Provider) - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestModel, req.Model) - - return ctx -} - // providerRequiresKey returns true if the given provider requires an API key for authentication. // Some providers like Ollama and SGL are keyless and don't require API keys. func providerRequiresKey(providerKey schemas.ModelProvider) bool { @@ -129,7 +121,7 @@ func IsStandardProvider(providerKey schemas.ModelProvider) bool { // IsStreamRequestType returns true if the given request type is a stream request. 
func IsStreamRequestType(reqType schemas.RequestType) bool { - return reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest + return reqType == schemas.ChatCompletionStreamRequest || reqType == schemas.ResponsesStreamRequest || reqType == schemas.SpeechStreamRequest || reqType == schemas.TranscriptionStreamRequest } func IsFinalChunk(ctx *context.Context) bool { @@ -148,3 +140,11 @@ func IsFinalChunk(ctx *context.Context) bool { return false } + +func GetRequestFields(result *schemas.BifrostResponse, err *schemas.BifrostError) (schemas.RequestType, schemas.ModelProvider, string) { + if result != nil { + return result.ExtraFields.RequestType, result.ExtraFields.Provider, result.ExtraFields.ModelRequested + } + + return err.ExtraFields.RequestType, err.ExtraFields.Provider, err.ExtraFields.ModelRequested +} diff --git a/docs/apis/openapi.json b/docs/apis/openapi.json index 350bf06869..af20666922 100644 --- a/docs/apis/openapi.json +++ b/docs/apis/openapi.json @@ -227,7 +227,7 @@ } } }, - "/v1/text/completions": { + "/v1/completions": { "post": { "summary": "Create Text Completion", "description": "Creates a text completion from a prompt. 
Useful for text generation, summarization, and other non-conversational tasks.", @@ -372,7 +372,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/BifrostMessage" + "$ref": "#/components/schemas/ChatMessage" }, "examples": { "search_result": { @@ -1803,7 +1803,7 @@ "messages": { "type": "array", "items": { - "$ref": "#/components/schemas/BifrostMessage" + "$ref": "#/components/schemas/ChatMessage" }, "description": "Array of chat messages", "minItems": 1 @@ -1892,7 +1892,23 @@ "description": "AI model provider", "example": "openai" }, - "BifrostMessage": { + "RequestType": { + "type": "string", + "enum": [ + "text_completion", + "chat_completion", + "chat_completion_stream", + "responses", + "responses_stream", + "embedding", + "speech", + "speech_stream", + "transcription", + "transcription_stream" + ], + "description": "Request type" + }, + "ChatMessage": { "type": "object", "required": ["role"], "properties": { @@ -2343,7 +2359,7 @@ "example": 0 }, "message": { - "$ref": "#/components/schemas/BifrostMessage" + "$ref": "#/components/schemas/ChatMessage" }, "finish_reason": { "type": "string", @@ -2416,6 +2432,13 @@ "provider": { "$ref": "#/components/schemas/ModelProvider" }, + "request_type": { + "$ref": "#/components/schemas/RequestType" + }, + "model_requested": { + "type": "string", + "description": "Model requested" + }, "model_params": { "$ref": "#/components/schemas/ModelParameters" }, @@ -2424,13 +2447,6 @@ "description": "Request latency in seconds", "example": 1.234 }, - "chat_history": { - "type": "array", - "items": { - "$ref": "#/components/schemas/BifrostMessage" - }, - "description": "Full conversation history" - }, "billed_usage": { "$ref": "#/components/schemas/BilledLLMUsage" }, diff --git a/docs/features/fallbacks.mdx b/docs/features/fallbacks.mdx index 315319c077..caa2001d3f 100644 --- a/docs/features/fallbacks.mdx +++ b/docs/features/fallbacks.mdx @@ -96,10 +96,10 @@ func chatWithFallbacks(client 
*bifrost.Bifrost) { response, err := client.ChatCompletion(ctx, &schemas.BifrostRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Messages: []schemas.BifrostMessage{ + Messages: []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Explain quantum computing in simple terms"), }, }, diff --git a/docs/features/mcp.mdx b/docs/features/mcp.mdx index fc2c744e31..b8e16f7886 100644 --- a/docs/features/mcp.mdx +++ b/docs/features/mcp.mdx @@ -298,9 +298,9 @@ func main() { }, }) - firstMessage := schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + firstMessage := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Read the contents of config.json file"), }, } @@ -310,7 +310,7 @@ func main() { Provider: schemas.OpenAI, Model: "gpt-4o-mini", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ firstMessage, }, }, @@ -326,7 +326,7 @@ func main() { } // Build conversation history for final response - conversationHistory := []schemas.BifrostMessage{ + conversationHistory := []schemas.ChatMessage{ firstMessage, } @@ -592,10 +592,10 @@ func main() { Provider: schemas.OpenAI, Model: "gpt-4o-mini", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Calculate 15.5 + 24.3"), }, }, diff --git a/docs/features/plugins/jsonparser.mdx b/docs/features/plugins/jsonparser.mdx index 379f09f9ad..3bad94ce89 100644 --- a/docs/features/plugins/jsonparser.mdx +++ b/docs/features/plugins/jsonparser.mdx @@ 
-208,16 +208,16 @@ func main() { if err != nil { panic(err) } - defer client.Cleanup() + defer client.Shutdown() // Request structured JSON response request := &schemas.BifrostRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Messages: []schemas.BifrostMessage{ + Messages: []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Return user profile as JSON: {\"name\": \"John Doe\", \"email\": \"john@example.com\"}"), }, }, diff --git a/docs/features/plugins/mocker.mdx b/docs/features/plugins/mocker.mdx index 7749b15d84..2262a70215 100644 --- a/docs/features/plugins/mocker.mdx +++ b/docs/features/plugins/mocker.mdx @@ -37,17 +37,17 @@ func main() { if err != nil { panic(err) } - defer client.Cleanup() + defer client.Shutdown() // All requests will now return: "This is a mock response from the Mocker plugin" response, _ := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ Provider: schemas.OpenAI, Model: "gpt-4", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Hello!"), }, }, diff --git a/docs/quickstart/go-sdk/multimodal.mdx b/docs/quickstart/go-sdk/multimodal.mdx index 3c3d8cf6e5..e8c444c45a 100644 --- a/docs/quickstart/go-sdk/multimodal.mdx +++ b/docs/quickstart/go-sdk/multimodal.mdx @@ -13,10 +13,10 @@ response, err := client.ChatCompletionRequest(context.Background(), &schemas.Bif Provider: schemas.OpenAI, Model: "gpt-4o", // Using vision-capable model Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: 
schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentBlocks: &[]schemas.ContentBlock{ { Type: schemas.ContentBlockTypeText, @@ -54,10 +54,10 @@ response, err := client.ChatCompletionRequest(context.Background(), &schemas.Bif Provider: schemas.OpenAI, Model: "gpt-4o-audio-preview", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentBlocks: &[]schemas.ContentBlock{ { Type: schemas.ContentBlockTypeText, @@ -157,10 +157,10 @@ response, err := client.ChatCompletionRequest(context.Background(), &schemas.Bif Provider: schemas.OpenAI, Model: "gpt-4o", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentBlocks: &[]schemas.ContentBlock{ { Type: schemas.ContentBlockTypeText, @@ -203,10 +203,10 @@ response, err := client.ChatCompletionRequest(context.Background(), &schemas.Bif Provider: schemas.OpenAI, Model: "gpt-4o", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentBlocks: &[]schemas.ContentBlock{ { Type: schemas.ContentBlockTypeText, diff --git a/docs/quickstart/go-sdk/provider-configuration.mdx b/docs/quickstart/go-sdk/provider-configuration.mdx index 142ad8ec52..1d7a8f5037 100644 --- a/docs/quickstart/go-sdk/provider-configuration.mdx +++ b/docs/quickstart/go-sdk/provider-configuration.mdx @@ -254,8 +254,14 @@ type 
BifrostResponse struct { } type BifrostResponseExtraFields struct { - Provider ModelProvider `json:"provider"` - RawResponse interface{} `json:"raw_response,omitempty"` // Original provider response + RequestType RequestType `json:"request_type"` + Provider ModelProvider `json:"provider"` + ModelRequested string `json:"model_requested"` + Latency *float64 `json:"latency,omitempty"` + BilledUsage *BilledLLMUsage `json:"billed_usage,omitempty"` + ChunkIndex int `json:"chunk_index"` // used for streaming responses to identify the chunk index, will be 0 for non-streaming responses + RawResponse interface{} `json:"raw_response,omitempty"` + CacheDebug *BifrostCacheDebug `json:"cache_debug,omitempty"` } ``` diff --git a/docs/quickstart/go-sdk/setting-up.mdx b/docs/quickstart/go-sdk/setting-up.mdx index be6d9165cf..aac02cbbd2 100644 --- a/docs/quickstart/go-sdk/setting-up.mdx +++ b/docs/quickstart/go-sdk/setting-up.mdx @@ -78,12 +78,12 @@ func main() { if initErr != nil { panic(initErr) } - defer client.Cleanup() + defer client.Shutdown() - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Hello, Bifrost!"), }, }, diff --git a/docs/quickstart/go-sdk/tool-calling.mdx b/docs/quickstart/go-sdk/tool-calling.mdx index f58549981a..9b592a4903 100644 --- a/docs/quickstart/go-sdk/tool-calling.mdx +++ b/docs/quickstart/go-sdk/tool-calling.mdx @@ -41,10 +41,10 @@ response, err := client.ChatCompletionRequest(context.Background(), &schemas.Bif Provider: schemas.OpenAI, Model: "gpt-4o-mini", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: 
bifrost.Ptr("What is 2+2? Use the calculator tool."), }, }, @@ -89,16 +89,16 @@ client, initErr := bifrost.Init(context.Background(), schemas.BifrostConfig{ if initErr != nil { panic(initErr) } -defer client.Cleanup() +defer client.Shutdown() response, err := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("What do you see when you search for 'bifrost' on youtube?"), }, }, @@ -214,10 +214,10 @@ response, err := client.ChatCompletionRequest(context.Background(), &schemas.Bif Provider: schemas.OpenAI, Model: "gpt-4o-mini", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("What's the weather in New York and calculate 15% tip for a $50 bill?"), }, }, diff --git a/framework/configstore/sqlite.go b/framework/configstore/sqlite.go index 76514bbba3..7d964c8d15 100644 --- a/framework/configstore/sqlite.go +++ b/framework/configstore/sqlite.go @@ -124,11 +124,17 @@ func (s *SQLiteConfigStore) UpdateProvidersConfig(providers map[schemas.ModelPro Value: key.Value, Models: key.Models, Weight: key.Weight, + OpenAIKeyConfig: key.OpenAIKeyConfig, AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, } + // Handle OpenAI config + if key.OpenAIKeyConfig != nil { + dbKey.OpenAIUseResponsesAPI = &key.OpenAIKeyConfig.UseResponsesAPI + } + // Handle Azure config if key.AzureKeyConfig != nil { 
dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint @@ -234,11 +240,17 @@ func (s *SQLiteConfigStore) UpdateProvider(provider schemas.ModelProvider, confi Value: key.Value, Models: key.Models, Weight: key.Weight, + OpenAIKeyConfig: key.OpenAIKeyConfig, AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, } + // Handle OpenAI config + if key.OpenAIKeyConfig != nil { + dbKey.OpenAIUseResponsesAPI = &key.OpenAIKeyConfig.UseResponsesAPI + } + // Handle Azure config if key.AzureKeyConfig != nil { dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint @@ -332,11 +344,17 @@ func (s *SQLiteConfigStore) AddProvider(provider schemas.ModelProvider, config P Value: key.Value, Models: key.Models, Weight: key.Weight, + OpenAIKeyConfig: key.OpenAIKeyConfig, AzureKeyConfig: key.AzureKeyConfig, VertexKeyConfig: key.VertexKeyConfig, BedrockKeyConfig: key.BedrockKeyConfig, } + // Handle OpenAI config + if key.OpenAIKeyConfig != nil { + dbKey.OpenAIUseResponsesAPI = &key.OpenAIKeyConfig.UseResponsesAPI + } + // Handle Azure config if key.AzureKeyConfig != nil { dbKey.AzureEndpoint = &key.AzureKeyConfig.Endpoint @@ -480,6 +498,7 @@ func (s *SQLiteConfigStore) GetProvidersConfig() (map[schemas.ModelProvider]Prov Value: processedValue, Models: dbKey.Models, Weight: dbKey.Weight, + OpenAIKeyConfig: dbKey.OpenAIKeyConfig, AzureKeyConfig: azureConfig, VertexKeyConfig: vertexConfig, BedrockKeyConfig: bedrockConfig, diff --git a/framework/configstore/tables.go b/framework/configstore/tables.go index 06d5b99fbe..c92d0720c0 100644 --- a/framework/configstore/tables.go +++ b/framework/configstore/tables.go @@ -66,6 +66,9 @@ type TableKey struct { CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + // OpenAI config fields (embedded) + OpenAIUseResponsesAPI *bool `gorm:"type:boolean" json:"openai_use_responses_api,omitempty"` + // Azure config fields (embedded instead 
of separate table for simplicity) AzureEndpoint *string `gorm:"type:text" json:"azure_endpoint,omitempty"` AzureAPIVersion *string `gorm:"type:varchar(50)" json:"azure_api_version,omitempty"` @@ -86,6 +89,7 @@ type TableKey struct { // Virtual fields for runtime use (not stored in DB) Models []string `gorm:"-" json:"models"` + OpenAIKeyConfig *schemas.OpenAIKeyConfig `gorm:"-" json:"openai_key_config,omitempty"` AzureKeyConfig *schemas.AzureKeyConfig `gorm:"-" json:"azure_key_config,omitempty"` VertexKeyConfig *schemas.VertexKeyConfig `gorm:"-" json:"vertex_key_config,omitempty"` BedrockKeyConfig *schemas.BedrockKeyConfig `gorm:"-" json:"bedrock_key_config,omitempty"` @@ -244,6 +248,12 @@ func (k *TableKey) BeforeSave(tx *gorm.DB) error { k.ModelsJSON = "[]" } + if k.OpenAIKeyConfig != nil { + k.OpenAIUseResponsesAPI = &k.OpenAIKeyConfig.UseResponsesAPI + } else { + k.OpenAIUseResponsesAPI = nil + } + if k.AzureKeyConfig != nil { if k.AzureKeyConfig.Endpoint != "" { k.AzureEndpoint = &k.AzureKeyConfig.Endpoint @@ -425,6 +435,13 @@ func (k *TableKey) AfterFind(tx *gorm.DB) error { } } + // Reconstruct OpenAI config if fields are present + if k.OpenAIUseResponsesAPI != nil { + k.OpenAIKeyConfig = &schemas.OpenAIKeyConfig{ + UseResponsesAPI: *k.OpenAIUseResponsesAPI, + } + } + // Reconstruct Azure config if fields are present if k.AzureEndpoint != nil { azureConfig := &schemas.AzureKeyConfig{ diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index 6e6286963c..079d3f65fc 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -73,8 +73,8 @@ type Log struct { Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding Provider string `gorm:"type:varchar(255);index;not null" json:"provider"` Model string `gorm:"type:varchar(255);index;not null" json:"model"` - InputHistory string `gorm:"type:text" json:"-"` // JSON serialized 
[]schemas.BifrostMessage - OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostMessage + InputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ChatMessage + OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ChatMessage EmbeddingOutput string `gorm:"type:text" json:"-"` // JSON serialized *[][]float32 Params string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ModelParameters Tools string `gorm:"type:text" json:"-"` // JSON serialized *[]schemas.Tool @@ -100,19 +100,19 @@ type Log struct { CreatedAt time.Time `gorm:"index;not null" json:"created_at"` // Virtual fields for JSON output - these will be populated when needed - InputHistoryParsed []schemas.BifrostMessage `gorm:"-" json:"input_history,omitempty"` - OutputMessageParsed *schemas.BifrostMessage `gorm:"-" json:"output_message,omitempty"` - EmbeddingOutputParsed *[]schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"` - ParamsParsed *schemas.ModelParameters `gorm:"-" json:"params,omitempty"` - ToolsParsed *[]schemas.Tool `gorm:"-" json:"tools,omitempty"` - ToolCallsParsed *[]schemas.ToolCall `gorm:"-" json:"tool_calls,omitempty"` - TokenUsageParsed *schemas.LLMUsage `gorm:"-" json:"token_usage,omitempty"` - ErrorDetailsParsed *schemas.BifrostError `gorm:"-" json:"error_details,omitempty"` - SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"` - TranscriptionInputParsed *schemas.TranscriptionInput `gorm:"-" json:"transcription_input,omitempty"` - SpeechOutputParsed *schemas.BifrostSpeech `gorm:"-" json:"speech_output,omitempty"` - TranscriptionOutputParsed *schemas.BifrostTranscribe `gorm:"-" json:"transcription_output,omitempty"` - CacheDebugParsed *schemas.BifrostCacheDebug `gorm:"-" json:"cache_debug,omitempty"` + InputHistoryParsed []schemas.ChatMessage `gorm:"-" json:"input_history,omitempty"` + OutputMessageParsed *schemas.ChatMessage `gorm:"-" 
json:"output_message,omitempty"` + EmbeddingOutputParsed *[]schemas.BifrostEmbedding `gorm:"-" json:"embedding_output,omitempty"` + ParamsParsed interface{} `gorm:"-" json:"params,omitempty"` + ToolsParsed *[]schemas.ChatTool `gorm:"-" json:"tools,omitempty"` + ToolCallsParsed *[]schemas.ChatAssistantMessageToolCall `gorm:"-" json:"tool_calls,omitempty"` + TokenUsageParsed *schemas.LLMUsage `gorm:"-" json:"token_usage,omitempty"` + ErrorDetailsParsed *schemas.BifrostError `gorm:"-" json:"error_details,omitempty"` + SpeechInputParsed *schemas.SpeechInput `gorm:"-" json:"speech_input,omitempty"` + TranscriptionInputParsed *schemas.TranscriptionInput `gorm:"-" json:"transcription_input,omitempty"` + SpeechOutputParsed *schemas.BifrostSpeech `gorm:"-" json:"speech_output,omitempty"` + TranscriptionOutputParsed *schemas.BifrostTranscribe `gorm:"-" json:"transcription_output,omitempty"` + CacheDebugParsed *schemas.BifrostCacheDebug `gorm:"-" json:"cache_debug,omitempty"` } // TableName sets the table name for GORM @@ -260,7 +260,7 @@ func (l *Log) DeserializeFields() error { if l.InputHistory != "" { if err := json.Unmarshal([]byte(l.InputHistory), &l.InputHistoryParsed); err != nil { // Log error but don't fail the operation - initialize as empty slice - l.InputHistoryParsed = []schemas.BifrostMessage{} + l.InputHistoryParsed = []schemas.ChatMessage{} } } diff --git a/framework/pricing/main.go b/framework/pricing/main.go index 45880cbd6b..475c4fafb2 100644 --- a/framework/pricing/main.go +++ b/framework/pricing/main.go @@ -84,7 +84,7 @@ func Init(configStore configstore.ConfigStore, logger schemas.Logger) (*PricingM if err := pm.syncPricing(); err != nil { return nil, fmt.Errorf("failed to sync pricing data: %w", err) } - + } else { // Load pricing data from config memory if err := pm.loadPricingIntoMemory(); err != nil { @@ -100,8 +100,8 @@ func Init(configStore configstore.ConfigStore, logger schemas.Logger) (*PricingM return pm, nil } -func (pm *PricingManager) 
CalculateCost(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType) float64 { - if result == nil || provider == "" || model == "" || requestType == "" { +func (pm *PricingManager) CalculateCost(result *schemas.BifrostResponse) float64 { + if result == nil { return 0.0 } @@ -157,14 +157,14 @@ func (pm *PricingManager) CalculateCost(result *schemas.BifrostResponse, provide cost := 0.0 if usage != nil || audioSeconds != nil || audioTokenDetails != nil { - cost = pm.CalculateCostFromUsage(string(provider), model, usage, requestType, isCacheRead, isBatch, audioSeconds, audioTokenDetails) + cost = pm.CalculateCostFromUsage(string(result.ExtraFields.Provider), result.ExtraFields.ModelRequested, usage, result.ExtraFields.RequestType, isCacheRead, isBatch, audioSeconds, audioTokenDetails) } return cost } -func (pm *PricingManager) CalculateCostWithCacheDebug(result *schemas.BifrostResponse, provider schemas.ModelProvider, model string, requestType schemas.RequestType) float64 { - if result == nil || provider == "" || model == "" || requestType == "" { +func (pm *PricingManager) CalculateCostWithCacheDebug(result *schemas.BifrostResponse) float64 { + if result == nil { return 0.0 } cacheDebug := result.ExtraFields.CacheDebug @@ -183,7 +183,7 @@ func (pm *PricingManager) CalculateCostWithCacheDebug(result *schemas.BifrostRes // Don't over-bill cache hits if fields are missing. 
return 0 } else { - baseCost := pm.CalculateCost(result, provider, model, requestType) + baseCost := pm.CalculateCost(result) var semanticCacheCost float64 if cacheDebug.ProviderUsed != nil && cacheDebug.ModelUsed != nil && cacheDebug.InputTokens != nil { semanticCacheCost = pm.CalculateCostFromUsage(*cacheDebug.ProviderUsed, *cacheDebug.ModelUsed, &schemas.LLMUsage{ @@ -197,7 +197,7 @@ func (pm *PricingManager) CalculateCostWithCacheDebug(result *schemas.BifrostRes } } - return pm.CalculateCost(result, provider, model, requestType) + return pm.CalculateCost(result) } func (pm *PricingManager) Cleanup() error { @@ -332,12 +332,22 @@ func (pm *PricingManager) getPricing(model, provider string, requestType schemas pricing, ok := pm.pricingData[makeKey(model, provider, normalizeRequestType(requestType))] if !ok { + // Lookup in vertex if gemini not found if provider == string(schemas.Gemini) { pricing, ok = pm.pricingData[makeKey(model, "vertex", normalizeRequestType(requestType))] if ok { return &pricing, true } } + + // Lookup in chat if responses not found + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + pricing, ok = pm.pricingData[makeKey(model, provider, normalizeRequestType(schemas.ChatCompletionRequest))] + if ok { + return &pricing, true + } + } + return nil, false } return &pricing, true diff --git a/framework/pricing/utils.go b/framework/pricing/utils.go index b841e3b202..28bbcb4806 100644 --- a/framework/pricing/utils.go +++ b/framework/pricing/utils.go @@ -30,7 +30,7 @@ func isCacheReadRequest(req *schemas.BifrostRequest, headers map[string]string) return true } - // TODO: Add message-level cache control detection when BifrostMessage schema supports it + // TODO: Add message-level cache control detection when ChatMessage schema supports it // For now, cache detection relies on headers only return false @@ -54,6 +54,8 @@ func normalizeRequestType(reqType schemas.RequestType) string { baseType = "completion" 
case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: baseType = "chat" + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + baseType = "responses" case schemas.EmbeddingRequest: baseType = "embedding" case schemas.SpeechRequest, schemas.SpeechStreamRequest: @@ -82,7 +84,7 @@ func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) modelName = strings.Join(parts[1:], "/") } } - + pricing := configstore.TableModelPricing{ Model: modelName, Provider: provider, diff --git a/plugins/governance/main.go b/plugins/governance/main.go index d3b51ea464..510df270d6 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -117,10 +117,6 @@ func (p *GovernancePlugin) PreHook(ctx *context.Context, req *schemas.BifrostReq provider := req.Provider model := req.Model - // Store original request provider/model and operation flags in context for PostHook - *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyRequestProvider, provider) - *ctx = context.WithValue(*ctx, schemas.BifrostContextKeyRequestModel, model) - // Create request context for evaluation evaluationRequest := &EvaluationRequest{ VirtualKey: virtualKey, @@ -208,32 +204,8 @@ func (p *GovernancePlugin) PostHook(ctx *context.Context, result *schemas.Bifros return result, err, nil } - // Extract provider and model from stored context values (set in PreHook) - var provider schemas.ModelProvider - var model string - var requestType schemas.RequestType - - if providerValue := (*ctx).Value(schemas.BifrostContextKeyRequestProvider); providerValue != nil { - if p, ok := providerValue.(schemas.ModelProvider); ok { - provider = p - } - } - if modelValue := (*ctx).Value(schemas.BifrostContextKeyRequestModel); modelValue != nil { - if m, ok := modelValue.(string); ok { - model = m - } - } - if requestTypeValue := (*ctx).Value(schemas.BifrostContextKeyRequestType); requestTypeValue != nil { - if r, ok := requestTypeValue.(schemas.RequestType); ok { - 
requestType = r - } - } - - // If we couldn't get provider/model from context, skip usage tracking - if provider == "" || model == "" { - p.logger.Debug("Could not extract provider/model from context, skipping usage tracking") - return result, err, nil - } + // Extract request type, provider, and model + requestType, provider, model := bifrost.GetRequestFields(result, err) // Extract cache and batch flags from context isCacheRead := false @@ -295,8 +267,8 @@ func (p *GovernancePlugin) postHookWorker(result *schemas.BifrostResponse, provi cost := 0.0 if !isStreaming || (isStreaming && isFinalChunk) { - if p.pricingManager != nil { - cost = p.pricingManager.CalculateCost(result, provider, model, requestType) + if p.pricingManager != nil && result != nil { + cost = p.pricingManager.CalculateCost(result) } } diff --git a/plugins/jsonparser/main.go b/plugins/jsonparser/main.go index 10289a6a0c..bd5ec3af9f 100644 --- a/plugins/jsonparser/main.go +++ b/plugins/jsonparser/main.go @@ -12,8 +12,7 @@ import ( ) const ( - PluginName = "streaming-json-parser" - EnableStreamingJSONParser = "enable-streaming-json-parser" + PluginName = "streaming-json-parser" ) type Usage string @@ -50,6 +49,12 @@ type PluginConfig struct { MaxAge time.Duration } +type ContextKey string + +const ( + EnableStreamingJSONParser ContextKey = "enable-streaming-json-parser" +) + // Init creates a new JSON parser plugin instance with custom configuration func Init(config PluginConfig) (*JsonParserPlugin, error) { // Set defaults if not provided @@ -89,13 +94,13 @@ func (p *JsonParserPlugin) PreHook(ctx *context.Context, req *schemas.BifrostReq // PostHook processes streaming responses by accumulating chunks and making accumulated content valid JSON func (p *JsonParserPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - // Check if plugin should run based on usage type - if !p.shouldRun(ctx) { + // 
If there's an error, don't process + if err != nil { return result, err, nil } - // If there's an error, don't process - if err != nil { + // Check if plugin should run based on usage type + if !p.shouldRun(ctx, result.ExtraFields.RequestType) { return result, err, nil } @@ -201,10 +206,9 @@ func (p *JsonParserPlugin) accumulateContent(requestID, newContent string) strin } // shouldRun determines if the plugin should process the request based on usage type -func (p *JsonParserPlugin) shouldRun(ctx *context.Context) bool { +func (p *JsonParserPlugin) shouldRun(ctx *context.Context, requestType schemas.RequestType) bool { // Run only for chat completion stream requests - requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) - if !ok || requestType != schemas.ChatCompletionStreamRequest { + if requestType != schemas.ChatCompletionStreamRequest { return false } diff --git a/plugins/jsonparser/plugin_test.go b/plugins/jsonparser/plugin_test.go index a4b3c9a498..1565deeb7f 100644 --- a/plugins/jsonparser/plugin_test.go +++ b/plugins/jsonparser/plugin_test.go @@ -81,23 +81,20 @@ func TestJsonParserPluginEndToEnd(t *testing.T) { } defer client.Shutdown() - // Make a test chat completion request with streaming enabled + // Make a test responses request with streaming enabled // Request JSON output to test the parser - request := &schemas.BifrostRequest{ + request := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Return a JSON object with name, age, and city fields. Example: {\"name\": \"John\", \"age\": 30, \"city\": \"New York\"}"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Return a JSON object with name, age, and city fields. 
Example: {\"name\": \"John\", \"age\": 30, \"city\": \"New York\"}"), }, }, }, - Params: &schemas.ModelParameters{ - + Params: &schemas.ChatParameters{ ExtraParams: map[string]any{ "stream": true, "response_format": map[string]any{ @@ -128,11 +125,13 @@ func TestJsonParserPluginEndToEnd(t *testing.T) { } if streamResponse.BifrostResponse != nil { - for _, choice := range streamResponse.BifrostResponse.Choices { - if choice.BifrostStreamResponseChoice != nil && choice.BifrostStreamResponseChoice.Delta.Content != nil { - content := *choice.BifrostStreamResponseChoice.Delta.Content - if content != "" { - t.Logf("Chunk %d: %s", responseCount, content) + if streamResponse.BifrostResponse.ResponsesResponse != nil { + for _, outputMsg := range streamResponse.BifrostResponse.ResponsesResponse.Output { + if outputMsg.Content != nil && outputMsg.Content.ContentStr != nil { + content := *outputMsg.Content.ContentStr + if content != "" { + t.Logf("Chunk %d: %s", responseCount, content) + } } } } @@ -183,20 +182,18 @@ func TestJsonParserPluginPerRequest(t *testing.T) { defer client.Shutdown() // Test request with plugin enabled via context - request := &schemas.BifrostRequest{ + request := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Return a JSON object with name and age fields."), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Return a JSON object with name and age fields."), }, }, }, - Params: &schemas.ModelParameters{ + Params: &schemas.ChatParameters{ ExtraParams: map[string]any{ "stream": true, "response_format": map[string]any{ diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 57f6e5330d..36de1da9d8 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go 
@@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "sync/atomic" "time" @@ -44,9 +45,9 @@ type UpdateLogData struct { Status string TokenUsage *schemas.LLMUsage Cost *float64 // Cost in dollars from pricing plugin - OutputMessage *schemas.BifrostMessage + OutputMessage *schemas.ChatMessage EmbeddingOutput *[]schemas.BifrostEmbedding - ToolCalls *[]schemas.ToolCall + ToolCalls *[]schemas.ChatAssistantMessageToolCall ErrorDetails *schemas.BifrostError Model string // May be different from request Object string // May be different from request @@ -82,11 +83,11 @@ type InitialLogData struct { Provider string Model string Object string - InputHistory []schemas.BifrostMessage - Params *schemas.ModelParameters + InputHistory []schemas.ChatMessage + Params interface{} SpeechInput *schemas.SpeechInput TranscriptionInput *schemas.TranscriptionInput - Tools *[]schemas.Tool + Tools *[]schemas.ChatTool } // LogCallback is a function that gets called when a new log entry is created @@ -273,28 +274,43 @@ func (p *LoggerPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest return req, nil, nil } - requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) - if !ok { - p.logger.Error("request type not found in context") - return req, nil, nil - } - // Prepare initial log data - objectType := p.determineObjectType(requestType) - inputHistory := p.extractInputHistory(req.Input) + objectType := p.determineObjectType(req.RequestType) + inputHistory := p.extractInputHistory(req) initialData := &InitialLogData{ - Provider: string(req.Provider), - Model: req.Model, - Object: objectType, - InputHistory: inputHistory, - Params: req.Params, - SpeechInput: req.Input.SpeechInput, - TranscriptionInput: req.Input.TranscriptionInput, + Provider: string(req.Provider), + Model: req.Model, + Object: objectType, + InputHistory: inputHistory, } - if req.Params != nil && req.Params.Tools != nil { - initialData.Tools = req.Params.Tools + 
switch req.RequestType { + case schemas.TextCompletionRequest: + initialData.Params = req.TextCompletionRequest.Params + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + initialData.Params = req.ChatRequest.Params + if req.ChatRequest.Params != nil && req.ChatRequest.Params.Tools != nil { + initialData.Tools = &req.ChatRequest.Params.Tools + } + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + initialData.Params = req.ResponsesRequest.Params + + if req.ResponsesRequest.Params != nil && req.ResponsesRequest.Params.Tools != nil { + var tools []schemas.ChatTool + for _, tool := range req.ResponsesRequest.Params.Tools { + tools = append(tools, *tool.ToChatTool()) + } + initialData.Tools = &tools + } + case schemas.EmbeddingRequest: + initialData.Params = req.EmbeddingRequest.Params + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + initialData.Params = req.SpeechRequest.Params + initialData.SpeechInput = &req.SpeechRequest.Input + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + initialData.Params = req.TranscriptionRequest.Params + initialData.TranscriptionInput = &req.TranscriptionRequest.Input } // Store created timestamp in context for latency calculation optimization @@ -361,23 +377,14 @@ func (p *LoggerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostRes return result, err, nil } - provider, ok := (*ctx).Value(schemas.BifrostContextKeyRequestProvider).(schemas.ModelProvider) - if !ok { - p.logger.Error("provider not found in context") - return result, err, nil - } + var requestType schemas.RequestType - model, ok := (*ctx).Value(schemas.BifrostContextKeyRequestModel).(string) - if !ok { - p.logger.Error("model not found in context") - return result, err, nil - } - // Check if this is a streaming response - requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) - if !ok { - p.logger.Error("request type missing/invalid in PostHook for 
request %s", requestID) - return result, err, nil + if result != nil { + requestType = result.ExtraFields.RequestType + } else { + requestType = err.ExtraFields.RequestType } + isAudioStreaming := requestType == schemas.SpeechStreamRequest || requestType == schemas.TranscriptionStreamRequest isChatStreaming := requestType == schemas.ChatCompletionStreamRequest @@ -403,7 +410,7 @@ func (p *LoggerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostRes streamUpdateData.ErrorDetails = err } else if result != nil { if result.Model != "" { - streamUpdateData.Model = model + streamUpdateData.Model = result.Model } // Update object type if available @@ -460,7 +467,7 @@ func (p *LoggerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostRes updateData.Status = "success" if result.Model != "" { - updateData.Model = model + updateData.Model = result.Model } // Update object type if available @@ -468,23 +475,55 @@ func (p *LoggerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostRes updateData.Object = result.Object } - // Token usage - if result.Usage != nil && result.Usage.TotalTokens > 0 { - updateData.TokenUsage = result.Usage + // Token usage - handle both regular usage and responses API usage + if result.Usage != nil { + // For responses API, TotalTokens might not be set, but we can calculate it + if result.Usage.TotalTokens > 0 { + updateData.TokenUsage = result.Usage + } else if result.Usage.ResponsesExtendedResponseUsage != nil { + // For responses API, calculate total from input + output tokens + totalTokens := result.Usage.ResponsesExtendedResponseUsage.InputTokens + + result.Usage.ResponsesExtendedResponseUsage.OutputTokens + + if totalTokens > 0 { + // Create a copy of usage with calculated total + usageCopy := *result.Usage + usageCopy.TotalTokens = totalTokens + usageCopy.PromptTokens = result.Usage.ResponsesExtendedResponseUsage.InputTokens + usageCopy.CompletionTokens = result.Usage.ResponsesExtendedResponseUsage.OutputTokens + 
updateData.TokenUsage = &usageCopy + } + } } - // Output message and tool calls - if len(result.Choices) > 0 { - choice := result.Choices[0] + // Output message and tool calls - handle both chat completions and responses API + // Check if this is a chat completions response (has ChatCompletionsExtendedResponse) + if result != nil && (len(result.Choices) > 0 || result.ResponsesResponse != nil) { + var choice schemas.BifrostChatResponseChoice + + if result.ResponsesResponse != nil { + if len(result.ResponsesResponse.Output) > 0 { + messages := schemas.ToChatMessages(result.ResponsesResponse.Output) + if len(messages) > 0 { + choice = schemas.BifrostChatResponseChoice{ + BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ + Message: messages[0], + }, + } + } + } + } else { + choice = result.Choices[0] + } - // Check if this is a non-stream response choice + // Check if this is a non-stream response choice (chat completions) if choice.BifrostNonStreamResponseChoice != nil { updateData.OutputMessage = &choice.BifrostNonStreamResponseChoice.Message // Extract tool calls if present - if choice.BifrostNonStreamResponseChoice.Message.AssistantMessage != nil && - choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls != nil { - updateData.ToolCalls = choice.BifrostNonStreamResponseChoice.Message.AssistantMessage.ToolCalls + if choice.BifrostNonStreamResponseChoice.Message.ChatAssistantMessage != nil && + choice.BifrostNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls != nil { + updateData.ToolCalls = choice.BifrostNonStreamResponseChoice.Message.ChatAssistantMessage.ToolCalls } } } @@ -547,11 +586,11 @@ func (p *LoggerPlugin) PostHook(ctx *context.Context, result *schemas.BifrostRes } if logMsg.UpdateData != nil && p.pricingManager != nil { - cost := p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + cost := p.pricingManager.CalculateCostWithCacheDebug(result) logMsg.UpdateData.Cost = &cost 
} if logMsg.StreamUpdateData != nil && isFinalChunk && p.pricingManager != nil { - cost := p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + cost := p.pricingManager.CalculateCostWithCacheDebug(result) logMsg.StreamUpdateData.Cost = &cost } @@ -636,43 +675,60 @@ func (p *LoggerPlugin) determineObjectType(requestType schemas.RequestType) stri } // extractInputHistory extracts input history from request input -func (p *LoggerPlugin) extractInputHistory(input schemas.RequestInput) []schemas.BifrostMessage { - if input.ChatCompletionInput != nil { - return *input.ChatCompletionInput +func (p *LoggerPlugin) extractInputHistory(request *schemas.BifrostRequest) []schemas.ChatMessage { + if request.ChatRequest != nil { + return request.ChatRequest.Input + } + if request.ResponsesRequest != nil { + messages := schemas.ToChatMessages(request.ResponsesRequest.Input) + if len(messages) > 0 { + return messages + } } - if input.TextCompletionInput != nil { - // Convert text completion to message format - return []schemas.BifrostMessage{ + if request.TextCompletionRequest != nil { + var text string + if request.TextCompletionRequest.Input.PromptStr != nil { + text = *request.TextCompletionRequest.Input.PromptStr + } else { + var stringBuilder strings.Builder + for _, prompt := range request.TextCompletionRequest.Input.PromptArray { + stringBuilder.WriteString(prompt) + } + text = stringBuilder.String() + } + return []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: input.TextCompletionInput, + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: &text, }, }, } } - if input.EmbeddingInput != nil { - texts := input.EmbeddingInput.Texts + if request.EmbeddingRequest != nil { + texts := request.EmbeddingRequest.Input.Texts - if len(texts) == 0 && input.EmbeddingInput.Text != nil { - texts = []string{*input.EmbeddingInput.Text} + if len(texts) == 0 && 
request.EmbeddingRequest.Input.Text != nil { + texts = []string{*request.EmbeddingRequest.Input.Text} } - contentBlocks := make([]schemas.ContentBlock, len(texts)) + contentBlocks := make([]schemas.ChatContentBlock, len(texts)) for i, text := range texts { - contentBlocks[i] = schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, - Text: &text, + // Create a per-iteration copy to avoid reusing the same memory address + t := text + contentBlocks[i] = schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, + Text: &t, } } - return []schemas.BifrostMessage{ + return []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentBlocks: &contentBlocks, }, }, } } - return []schemas.BifrostMessage{} + return []schemas.ChatMessage{} } diff --git a/plugins/logging/operations.go b/plugins/logging/operations.go index 7d9173f330..159ebf314b 100644 --- a/plugins/logging/operations.go +++ b/plugins/logging/operations.go @@ -299,30 +299,30 @@ func (p *LoggerPlugin) prepareDeltaUpdates(requestID string, delta *schemas.Bifr } // Parse existing message or create new one - var outputMessage *schemas.BifrostMessage + var outputMessage *schemas.ChatMessage if currentEntry.OutputMessage != "" { - outputMessage = &schemas.BifrostMessage{} + outputMessage = &schemas.ChatMessage{} // Attempt to deserialize; use parsed message only if successful if err := currentEntry.DeserializeFields(); err == nil && currentEntry.OutputMessageParsed != nil { outputMessage = currentEntry.OutputMessageParsed } else { // Create new message if parsing fails - outputMessage = &schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{}, + outputMessage = &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{}, } } } else { // Create new message - outputMessage = 
&schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{}, + outputMessage = &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{}, } } // Handle role (usually in first chunk) if delta.Role != nil { - outputMessage.Role = schemas.ModelChatMessageRole(*delta.Role) + outputMessage.Role = schemas.ChatMessageRole(*delta.Role) } // Append content @@ -332,13 +332,13 @@ func (p *LoggerPlugin) prepareDeltaUpdates(requestID string, delta *schemas.Bifr // Handle refusal if delta.Refusal != nil && *delta.Refusal != "" { - if outputMessage.AssistantMessage == nil { - outputMessage.AssistantMessage = &schemas.AssistantMessage{} + if outputMessage.ChatAssistantMessage == nil { + outputMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{} } - if outputMessage.AssistantMessage.Refusal == nil { - outputMessage.AssistantMessage.Refusal = delta.Refusal + if outputMessage.ChatAssistantMessage.Refusal == nil { + outputMessage.ChatAssistantMessage.Refusal = delta.Refusal } else { - *outputMessage.AssistantMessage.Refusal += *delta.Refusal + *outputMessage.ChatAssistantMessage.Refusal += *delta.Refusal } } @@ -351,8 +351,8 @@ func (p *LoggerPlugin) prepareDeltaUpdates(requestID string, delta *schemas.Bifr tempEntry := &logstore.Log{ OutputMessageParsed: outputMessage, } - if outputMessage.AssistantMessage != nil && outputMessage.AssistantMessage.ToolCalls != nil { - tempEntry.ToolCallsParsed = outputMessage.AssistantMessage.ToolCalls + if outputMessage.ChatAssistantMessage != nil && outputMessage.ChatAssistantMessage.ToolCalls != nil { + tempEntry.ToolCallsParsed = outputMessage.ChatAssistantMessage.ToolCalls } if err := tempEntry.SerializeFields(); err != nil { diff --git a/plugins/logging/streaming.go b/plugins/logging/streaming.go index 4cdaac72a3..2bb59e0841 100644 --- a/plugins/logging/streaming.go +++ b/plugins/logging/streaming.go @@ -13,7 +13,7 @@ import ( ) // 
appendContentToMessage efficiently appends content to a message -func (p *LoggerPlugin) appendContentToMessage(message *schemas.BifrostMessage, newContent string) { +func (p *LoggerPlugin) appendContentToMessage(message *schemas.ChatMessage, newContent string) { if message == nil { return } @@ -23,13 +23,13 @@ func (p *LoggerPlugin) appendContentToMessage(message *schemas.BifrostMessage, n } else if message.Content.ContentBlocks != nil { // Find the last text block and append, or create new one blocks := *message.Content.ContentBlocks - if len(blocks) > 0 && blocks[len(blocks)-1].Type == schemas.ContentBlockTypeText && blocks[len(blocks)-1].Text != nil { + if len(blocks) > 0 && blocks[len(blocks)-1].Type == schemas.ChatContentBlockTypeText && blocks[len(blocks)-1].Text != nil { // Append to last text block *blocks[len(blocks)-1].Text += newContent } else { // Create new text block - blocks = append(blocks, schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + blocks = append(blocks, schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: &newContent, }) message.Content.ContentBlocks = &blocks @@ -41,19 +41,19 @@ func (p *LoggerPlugin) appendContentToMessage(message *schemas.BifrostMessage, n } // accumulateToolCallsInMessage efficiently accumulates tool calls in a message -func (p *LoggerPlugin) accumulateToolCallsInMessage(message *schemas.BifrostMessage, deltaToolCalls []schemas.ToolCall) { +func (p *LoggerPlugin) accumulateToolCallsInMessage(message *schemas.ChatMessage, deltaToolCalls []schemas.ChatAssistantMessageToolCall) { if message == nil { return } - if message.AssistantMessage == nil { - message.AssistantMessage = &schemas.AssistantMessage{} + if message.ChatAssistantMessage == nil { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{} } - if message.AssistantMessage.ToolCalls == nil { - message.AssistantMessage.ToolCalls = &[]schemas.ToolCall{} + if message.ChatAssistantMessage.ToolCalls == nil { + 
message.ChatAssistantMessage.ToolCalls = &[]schemas.ChatAssistantMessageToolCall{} } - existingToolCalls := *message.AssistantMessage.ToolCalls + existingToolCalls := *message.ChatAssistantMessage.ToolCalls for _, deltaToolCall := range deltaToolCalls { // Find existing tool call with same ID or create new one @@ -72,7 +72,7 @@ func (p *LoggerPlugin) accumulateToolCallsInMessage(message *schemas.BifrostMess existingToolCalls = append(existingToolCalls, deltaToolCall) } } - message.AssistantMessage.ToolCalls = &existingToolCalls + message.ChatAssistantMessage.ToolCalls = &existingToolCalls } // Stream accumulator helper methods @@ -159,8 +159,8 @@ func (p *LoggerPlugin) processAccumulatedChunks(ctx context.Context, requestID s tempEntry := &logstore.Log{ OutputMessageParsed: completeMessage, } - if completeMessage.AssistantMessage != nil && completeMessage.AssistantMessage.ToolCalls != nil { - tempEntry.ToolCallsParsed = completeMessage.AssistantMessage.ToolCalls + if completeMessage.ChatAssistantMessage != nil && completeMessage.ChatAssistantMessage.ToolCalls != nil { + tempEntry.ToolCallsParsed = completeMessage.ChatAssistantMessage.ToolCalls } if err := tempEntry.SerializeFields(); err != nil { return fmt.Errorf("failed to serialize complete message: %w", err) @@ -236,10 +236,10 @@ func (p *LoggerPlugin) processAccumulatedChunks(ctx context.Context, requestID s } // buildCompleteMessageFromChunks builds a complete message from ordered chunks -func (p *LoggerPlugin) buildCompleteMessageFromChunks(chunks []*StreamChunk) *schemas.BifrostMessage { - completeMessage := &schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{}, +func (p *LoggerPlugin) buildCompleteMessageFromChunks(chunks []*StreamChunk) *schemas.ChatMessage { + completeMessage := &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{}, } for _, chunk := range chunks { @@ -249,7 +249,7 @@ func (p 
*LoggerPlugin) buildCompleteMessageFromChunks(chunks []*StreamChunk) *sc // Handle role (usually in first chunk) if chunk.Delta.Role != nil { - completeMessage.Role = schemas.ModelChatMessageRole(*chunk.Delta.Role) + completeMessage.Role = schemas.ChatMessageRole(*chunk.Delta.Role) } // Append content @@ -259,13 +259,13 @@ func (p *LoggerPlugin) buildCompleteMessageFromChunks(chunks []*StreamChunk) *sc // Handle refusal if chunk.Delta.Refusal != nil && *chunk.Delta.Refusal != "" { - if completeMessage.AssistantMessage == nil { - completeMessage.AssistantMessage = &schemas.AssistantMessage{} + if completeMessage.ChatAssistantMessage == nil { + completeMessage.ChatAssistantMessage = &schemas.ChatAssistantMessage{} } - if completeMessage.AssistantMessage.Refusal == nil { - completeMessage.AssistantMessage.Refusal = chunk.Delta.Refusal + if completeMessage.ChatAssistantMessage.Refusal == nil { + completeMessage.ChatAssistantMessage.Refusal = chunk.Delta.Refusal } else { - *completeMessage.AssistantMessage.Refusal += *chunk.Delta.Refusal + *completeMessage.ChatAssistantMessage.Refusal += *chunk.Delta.Refusal } } @@ -331,24 +331,6 @@ func (p *LoggerPlugin) handleStreamingResponse(ctx *context.Context, result *sch return result, err, nil } - provider, ok := (*ctx).Value(schemas.BifrostContextKeyRequestProvider).(schemas.ModelProvider) - if !ok { - p.logger.Error("provider not found in context") - return result, err, nil - } - - model, ok := (*ctx).Value(schemas.BifrostContextKeyRequestModel).(string) - if !ok { - p.logger.Error("model not found in context") - return result, err, nil - } - - requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) - if !ok { - p.logger.Error("request type not found in context") - return result, err, nil - } - // Create chunk from current response using pool chunk := p.getStreamChunk() chunk.Timestamp = time.Now() @@ -383,7 +365,7 @@ func (p *LoggerPlugin) handleStreamingResponse(ctx *context.Context, 
result *sch if result != nil { if isFinalChunk { if p.pricingManager != nil { - cost := p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + cost := p.pricingManager.CalculateCostWithCacheDebug(result) chunk.Cost = bifrost.Ptr(cost) } chunk.SemanticCacheDebug = result.ExtraFields.CacheDebug diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index 8988ba7dca..2f615207d8 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -6,6 +6,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "sync" "github.com/google/uuid" @@ -236,7 +237,6 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) } // Determine request type and set appropriate tags - var requestType string var messages []logging.CompletionRequest var latestMessage string @@ -248,23 +248,47 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) // Add model to tags tags["model"] = req.Model - if req.Input.ChatCompletionInput != nil { - requestType = "chat_completion" - for _, message := range *req.Input.ChatCompletionInput { + modelParams := make(map[string]interface{}) + + switch req.RequestType { + case schemas.TextCompletionRequest: + messages = append(messages, logging.CompletionRequest{ + Role: string(schemas.ChatMessageRoleUser), + Content: req.TextCompletionRequest.Input, + }) + if req.TextCompletionRequest.Input.PromptStr != nil { + latestMessage = *req.TextCompletionRequest.Input.PromptStr + } else { + var stringBuilder strings.Builder + for _, prompt := range req.TextCompletionRequest.Input.PromptArray { + stringBuilder.WriteString(prompt) + } + latestMessage = stringBuilder.String() + } + + if req.TextCompletionRequest.Params != nil { + // Convert the struct to a map using reflection or JSON marshaling + jsonData, err := json.Marshal(req.TextCompletionRequest.Params) + if err == nil { + json.Unmarshal(jsonData, &modelParams) + } + } + case schemas.ChatCompletionRequest: + for _, message 
:= range req.ChatRequest.Input { messages = append(messages, logging.CompletionRequest{ Role: string(message.Role), Content: message.Content, }) } - if len(*req.Input.ChatCompletionInput) > 0 { - lastMsg := (*req.Input.ChatCompletionInput)[len(*req.Input.ChatCompletionInput)-1] + if len(req.ChatRequest.Input) > 0 { + lastMsg := req.ChatRequest.Input[len(req.ChatRequest.Input)-1] if lastMsg.Content.ContentStr != nil { latestMessage = *lastMsg.Content.ContentStr } else if lastMsg.Content.ContentBlocks != nil { // Find the last text content block for i := len(*lastMsg.Content.ContentBlocks) - 1; i >= 0; i-- { block := (*lastMsg.Content.ContentBlocks)[i] - if block.Type == "text" && block.Text != nil { + if block.Type == schemas.ChatContentBlockTypeText && block.Text != nil { latestMessage = *block.Text break } @@ -275,21 +299,65 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) } } } - } else if req.Input.TextCompletionInput != nil { - requestType = "text_completion" - messages = append(messages, logging.CompletionRequest{ - Role: string(schemas.ModelChatMessageRoleUser), - Content: req.Input.TextCompletionInput, - }) - latestMessage = *req.Input.TextCompletionInput + + if req.ChatRequest.Params != nil { + // Convert the struct to a map using reflection or JSON marshaling + jsonData, err := json.Marshal(req.ChatRequest.Params) + if err == nil { + json.Unmarshal(jsonData, &modelParams) + } + } + case schemas.ResponsesRequest: + for _, message := range req.ResponsesRequest.Input { + if message.Content != nil { + role := schemas.ChatMessageRoleUser + if message.Role != nil { + role = schemas.ChatMessageRole(*message.Role) + } + messages = append(messages, logging.CompletionRequest{ + Role: string(role), + Content: message.Content, + }) + } + } + if len(req.ResponsesRequest.Input) > 0 { + lastMsg := req.ResponsesRequest.Input[len(req.ResponsesRequest.Input)-1] + // Initialize to placeholder in case content is missing or empty + 
latestMessage = "-" + + // Check if Content is nil before accessing its fields + if lastMsg.Content != nil { + if lastMsg.Content.ContentStr != nil { + latestMessage = *lastMsg.Content.ContentStr + } else if lastMsg.Content.ContentBlocks != nil { + // Find the last text content block + for i := len(*lastMsg.Content.ContentBlocks) - 1; i >= 0; i-- { + block := (*lastMsg.Content.ContentBlocks)[i] + if block.Text != nil { + latestMessage = *block.Text + break + } + } + // If no text block found, keep the placeholder + } + } + } + + if req.ResponsesRequest.Params != nil { + // Convert the struct to a map using reflection or JSON marshaling + jsonData, err := json.Marshal(req.ResponsesRequest.Params) + if err == nil { + json.Unmarshal(jsonData, &modelParams) + } + } } - tags["action"] = requestType + tags["action"] = string(req.RequestType) if traceID == "" { // If traceID is not set, create a new trace traceID = uuid.New().String() - name := fmt.Sprintf("bifrost_%s", requestType) + name := fmt.Sprintf("bifrost_%s", string(req.RequestType)) if traceName != "" { name = traceName } @@ -312,16 +380,6 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) } } - // Convert ModelParameters to map[string]interface{} - modelParams := make(map[string]interface{}) - if req.Params != nil { - // Convert the struct to a map using reflection or JSON marshaling - jsonData, err := json.Marshal(req.Params) - if err == nil { - json.Unmarshal(jsonData, &modelParams) - } - } - generationID := uuid.New().String() generationConfig := logging.GenerationConfig{ diff --git a/plugins/maxim/plugin_test.go b/plugins/maxim/plugin_test.go index 9d69fe404b..4f708bb40e 100644 --- a/plugins/maxim/plugin_test.go +++ b/plugins/maxim/plugin_test.go @@ -103,16 +103,14 @@ func TestMaximLoggerPlugin(t *testing.T) { } // Make a test chat completion request - _, bifrostErr := client.ChatCompletionRequest(context.Background(), &schemas.BifrostRequest{ + _, bifrostErr := 
client.ChatCompletionRequest(context.Background(), &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, how are you?"), - }, + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, how are you?"), }, }, }, diff --git a/plugins/mocker/benchmark_test.go b/plugins/mocker/benchmark_test.go index 5e26311c75..1d1fea9e53 100644 --- a/plugins/mocker/benchmark_test.go +++ b/plugins/mocker/benchmark_test.go @@ -37,16 +37,14 @@ func BenchmarkMockerPlugin_PreHook_SimpleRule(b *testing.B) { b.Fatal(err) } - req := &schemas.BifrostRequest{ + req := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, benchmark test"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, benchmark test"), }, }, }, @@ -57,8 +55,16 @@ func BenchmarkMockerPlugin_PreHook_SimpleRule(b *testing.B) { b.ResetTimer() b.ReportAllocs() + // Convert to BifrostRequest for PreHook compatibility + bifrostReq := &schemas.BifrostRequest{ + Provider: req.Provider, + Model: req.Model, + RequestType: schemas.ChatCompletionRequest, + ChatRequest: req, + } + for i := 0; i < b.N; i++ { - _, _, _ = plugin.PreHook(&ctx, req) + _, _, _ = plugin.PreHook(&ctx, bifrostReq) } } @@ -90,16 +96,14 @@ func BenchmarkMockerPlugin_PreHook_RegexRule(b *testing.B) { b.Fatal(err) } - req := &schemas.BifrostRequest{ + req := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: 
&[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, this should match the regex pattern"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, this should match the regex pattern"), }, }, }, @@ -110,8 +114,16 @@ func BenchmarkMockerPlugin_PreHook_RegexRule(b *testing.B) { b.ResetTimer() b.ReportAllocs() + // Convert to BifrostRequest for PreHook compatibility + bifrostReq := &schemas.BifrostRequest{ + Provider: req.Provider, + Model: req.Model, + RequestType: schemas.ChatCompletionRequest, + ChatRequest: req, + } + for i := 0; i < b.N; i++ { - _, _, _ = plugin.PreHook(&ctx, req) + _, _, _ = plugin.PreHook(&ctx, bifrostReq) } } @@ -165,16 +177,14 @@ func BenchmarkMockerPlugin_PreHook_MultipleRules(b *testing.B) { b.Fatal(err) } - req := &schemas.BifrostRequest{ + req := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Test message"), }, }, }, @@ -185,8 +195,16 @@ func BenchmarkMockerPlugin_PreHook_MultipleRules(b *testing.B) { b.ResetTimer() b.ReportAllocs() + // Convert to BifrostRequest for PreHook compatibility + bifrostReq := &schemas.BifrostRequest{ + Provider: req.Provider, + Model: req.Model, + RequestType: schemas.ChatCompletionRequest, + ChatRequest: req, + } + for i := 0; i < b.N; i++ { - _, _, _ = plugin.PreHook(&ctx, req) + _, _, _ = plugin.PreHook(&ctx, bifrostReq) } } @@ -219,16 +237,14 @@ func BenchmarkMockerPlugin_PreHook_NoMatch(b *testing.B) { b.Fatal(err) } - req := &schemas.BifrostRequest{ + req 
:= &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, // Different from rule condition Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Test message"), }, }, }, @@ -239,8 +255,16 @@ func BenchmarkMockerPlugin_PreHook_NoMatch(b *testing.B) { b.ResetTimer() b.ReportAllocs() + // Convert to BifrostRequest for PreHook compatibility + bifrostReq := &schemas.BifrostRequest{ + Provider: req.Provider, + Model: req.Model, + RequestType: schemas.ChatCompletionRequest, + ChatRequest: req, + } + for i := 0; i < b.N; i++ { - _, _, _ = plugin.PreHook(&ctx, req) + _, _, _ = plugin.PreHook(&ctx, bifrostReq) } } @@ -270,16 +294,14 @@ func BenchmarkMockerPlugin_PreHook_Template(b *testing.B) { b.Fatal(err) } - req := &schemas.BifrostRequest{ + req := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Test message"), }, }, }, @@ -290,7 +312,15 @@ func BenchmarkMockerPlugin_PreHook_Template(b *testing.B) { b.ResetTimer() b.ReportAllocs() + // Convert to BifrostRequest for PreHook compatibility + bifrostReq := &schemas.BifrostRequest{ + Provider: req.Provider, + Model: req.Model, + RequestType: schemas.ChatCompletionRequest, + ChatRequest: req, + } + for i := 0; i < b.N; i++ { - _, _, _ = plugin.PreHook(&ctx, req) + _, _, _ = plugin.PreHook(&ctx, bifrostReq) } } diff --git a/plugins/mocker/main.go 
b/plugins/mocker/main.go index cb850af7be..5b8627b4aa 100644 --- a/plugins/mocker/main.go +++ b/plugins/mocker/main.go @@ -633,37 +633,77 @@ func (p *MockerPlugin) matchesConditionsFast(req *schemas.BifrostRequest, condit // extractMessageContentFast extracts message content with optimized performance func (p *MockerPlugin) extractMessageContentFast(req *schemas.BifrostRequest) string { - // Handle text completion input - if req.Input.TextCompletionInput != nil { - return *req.Input.TextCompletionInput - } - - // Handle chat completion input - optimized for common cases - if req.Input.ChatCompletionInput != nil { - messages := *req.Input.ChatCompletionInput - if len(messages) == 0 { - return "" + switch req.RequestType { + case schemas.TextCompletionRequest: + // Handle text completion input + if req.TextCompletionRequest.Input.PromptStr != nil { + return *req.TextCompletionRequest.Input.PromptStr + } else { + var stringBuilder strings.Builder + for _, prompt := range req.TextCompletionRequest.Input.PromptArray { + stringBuilder.WriteString(prompt) + } + return stringBuilder.String() } + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + // Handle chat completion input - optimized for common cases + if req.ChatRequest.Input != nil { + messages := req.ChatRequest.Input + if len(messages) == 0 { + return "" + } - // Fast path for single message - if len(messages) == 1 { - if messages[0].Content.ContentStr != nil { - return *messages[0].Content.ContentStr + // Fast path for single message + if len(messages) == 1 { + if messages[0].Content.ContentStr != nil { + return *messages[0].Content.ContentStr + } + return "" } - return "" + + // Multiple messages - use string builder for efficiency + var builder strings.Builder + for i, message := range messages { + if message.Content.ContentStr != nil { + if i > 0 { + builder.WriteByte(' ') + } + builder.WriteString(*message.Content.ContentStr) + } + } + return builder.String() } + case 
schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + // Handle responses input - optimized for common cases + if req.ResponsesRequest.Input != nil { + messages := req.ResponsesRequest.Input + if len(messages) == 0 { + return "" + } - // Multiple messages - use string builder for efficiency - var builder strings.Builder - for i, message := range messages { - if message.Content.ContentStr != nil { + // Fast path for single message + if len(messages) == 1 { + if messages[0].Content != nil && messages[0].Content.ContentStr != nil { + return *messages[0].Content.ContentStr + } + return "" + } + + // Multiple messages - use string builder for efficiency + var builder strings.Builder + for i, message := range messages { + if message.Content == nil || message.Content.ContentStr == nil { + continue + } if i > 0 { builder.WriteByte(' ') } builder.WriteString(*message.Content.ContentStr) } + return builder.String() } - return builder.String() + default: + return "" } return "" @@ -675,12 +715,18 @@ func (p *MockerPlugin) calculateRequestSizeFast(req *schemas.BifrostRequest) int size := len(req.Model) + len(string(req.Provider)) // Add input size - if req.Input.TextCompletionInput != nil { - size += len(*req.Input.TextCompletionInput) + if req.TextCompletionRequest != nil { + if req.TextCompletionRequest.Input.PromptStr != nil { + size += len(*req.TextCompletionRequest.Input.PromptStr) + } else { + for _, prompt := range req.TextCompletionRequest.Input.PromptArray { + size += len(prompt) + } + } } - if req.Input.ChatCompletionInput != nil { - for _, message := range *req.Input.ChatCompletionInput { + if req.ChatRequest.Input != nil { + for _, message := range req.ChatRequest.Input { if message.Content.ContentStr != nil { size += len(*message.Content.ContentStr) } @@ -688,6 +734,15 @@ func (p *MockerPlugin) calculateRequestSizeFast(req *schemas.BifrostRequest) int } } + if req.ResponsesRequest.Input != nil { + for _, message := range req.ResponsesRequest.Input { + if 
message.Content != nil && message.Content.ContentStr != nil { + size += len(*message.Content.ContentStr) + } + size += 50 // Approximate overhead for message structure + } + } + return size } @@ -736,13 +791,13 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, mockResponse := &schemas.BifrostResponse{ Model: req.Model, Usage: &usage, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ + Message: schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{ ContentStr: &message, }, }, @@ -751,7 +806,9 @@ func (p *MockerPlugin) generateSuccessShortCircuit(req *schemas.BifrostRequest, }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: req.Provider, + RequestType: schemas.ChatCompletionRequest, + Provider: req.Provider, + ModelRequested: req.Model, }, } @@ -899,13 +956,13 @@ func (p *MockerPlugin) handleDefaultBehavior(req *schemas.BifrostRequest) (*sche CompletionTokens: 10, TotalTokens: 15, }, - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: 0, BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ + Message: schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Mock plugin default response"), }, }, @@ -914,7 +971,9 @@ func (p *MockerPlugin) handleDefaultBehavior(req *schemas.BifrostRequest) (*sche }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: req.Provider, + RequestType: schemas.ChatCompletionRequest, + Provider: req.Provider, + ModelRequested: req.Model, }, }, }, 
nil diff --git a/plugins/mocker/plugin_test.go b/plugins/mocker/plugin_test.go index 820019ae6b..bca2414a42 100644 --- a/plugins/mocker/plugin_test.go +++ b/plugins/mocker/plugin_test.go @@ -73,16 +73,14 @@ func TestMockerPlugin_Disabled(t *testing.T) { defer client.Shutdown() // This should pass through to the real provider (but will fail due to dummy key) - _, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + _, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), }, }, }, @@ -117,16 +115,14 @@ func TestMockerPlugin_DefaultMockRule(t *testing.T) { } defer client.Shutdown() - response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), }, }, }, @@ -141,11 +137,11 @@ func TestMockerPlugin_DefaultMockRule(t *testing.T) { if len(response.Choices) == 0 { t.Fatal("Expected at least one choice") } - if response.Choices[0].Message.Content.ContentStr == nil { + if response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr == nil { t.Fatal("Expected content string") } - 
if *response.Choices[0].Message.Content.ContentStr != "This is a mock response from the Mocker plugin" { - t.Errorf("Expected default mock message, got: %s", *response.Choices[0].Message.Content.ContentStr) + if *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr != "This is a mock response from the Mocker plugin" { + t.Errorf("Expected default mock message, got: %s", *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) } } @@ -195,16 +191,14 @@ func TestMockerPlugin_CustomSuccessRule(t *testing.T) { } defer client.Shutdown() - response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), }, }, }, @@ -219,11 +213,11 @@ func TestMockerPlugin_CustomSuccessRule(t *testing.T) { if len(response.Choices) == 0 { t.Fatal("Expected at least one choice") } - if response.Choices[0].Message.Content.ContentStr == nil { + if response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr == nil { t.Fatal("Expected content string") } - if *response.Choices[0].Message.Content.ContentStr != "Custom OpenAI mock response" { - t.Errorf("Expected custom message, got: %s", *response.Choices[0].Message.Content.ContentStr) + if *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr != "Custom OpenAI mock response" { + t.Errorf("Expected custom message, got: %s", *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) } if response.Usage.TotalTokens != 40 
{ t.Errorf("Expected 40 total tokens, got %d", response.Usage.TotalTokens) @@ -276,16 +270,14 @@ func TestMockerPlugin_ErrorResponse(t *testing.T) { } defer client.Shutdown() - _, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + _, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), }, }, }, @@ -341,16 +333,14 @@ func TestMockerPlugin_MessageTemplate(t *testing.T) { } defer client.Shutdown() - response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + response, bifrostErr := client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ Provider: schemas.Anthropic, Model: "claude-3", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), }, }, }, @@ -365,12 +355,12 @@ func TestMockerPlugin_MessageTemplate(t *testing.T) { if len(response.Choices) == 0 { t.Fatal("Expected at least one choice") } - if response.Choices[0].Message.Content.ContentStr == nil { + if response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr == nil { t.Fatal("Expected content string") } expectedMessage := "Hello from anthropic using model claude-3" - if *response.Choices[0].Message.Content.ContentStr != expectedMessage { - t.Errorf("Expected '%s', got: %s", 
expectedMessage, *response.Choices[0].Message.Content.ContentStr) + if *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr != expectedMessage { + t.Errorf("Expected '%s', got: %s", expectedMessage, *response.Choices[0].BifrostNonStreamResponseChoice.Message.Content.ContentStr) } } @@ -415,16 +405,14 @@ func TestMockerPlugin_Statistics(t *testing.T) { // Make multiple requests for i := 0; i < 3; i++ { - _, _ = client.ChatCompletionRequest(ctx, &schemas.BifrostRequest{ + _, _ = client.ChatCompletionRequest(ctx, &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Hello, test message"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Hello, test message"), }, }, }, diff --git a/plugins/semanticcache/main.go b/plugins/semanticcache/main.go index 8e03fa986d..1c1ccee0b5 100644 --- a/plugins/semanticcache/main.go +++ b/plugins/semanticcache/main.go @@ -366,11 +366,6 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) *ctx = context.WithValue(*ctx, requestModelKey, req.Model) *ctx = context.WithValue(*ctx, requestProviderKey, req.Provider) - requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) - if !ok { - return req, nil, nil - } - performDirectSearch, performSemanticSearch := true, true if (*ctx).Value(CacheTypeKey) != nil { cacheTypeVal, ok := (*ctx).Value(CacheTypeKey).(CacheType) @@ -383,7 +378,7 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) } if performDirectSearch { - shortCircuit, err := plugin.performDirectSearch(ctx, req, requestType, cacheKey) + shortCircuit, err := plugin.performDirectSearch(ctx, req, cacheKey) if err != nil { 
plugin.logger.Warn(PluginLoggerPrefix + " Direct search failed: " + err.Error()) // Don't return - continue to semantic search fallback @@ -396,13 +391,13 @@ func (plugin *Plugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) } if performSemanticSearch && plugin.client != nil { - if req.Input.EmbeddingInput != nil || req.Input.TranscriptionInput != nil { + if req.EmbeddingRequest != nil || req.TranscriptionRequest != nil { plugin.logger.Debug(PluginLoggerPrefix + " Skipping semantic search for embedding/transcription input") return req, nil, nil } // Try semantic search as fallback - shortCircuit, err := plugin.performSemanticSearch(ctx, req, requestType, cacheKey) + shortCircuit, err := plugin.performSemanticSearch(ctx, req, cacheKey) if err != nil { return req, nil, nil } @@ -461,12 +456,6 @@ func (plugin *Plugin) PostHook(ctx *context.Context, res *schemas.BifrostRespons } } - // Get the request type from context - requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) - if !ok { - return res, nil, nil - } - // Get the cache key from context cacheKey, ok := (*ctx).Value(CacheKey).(string) if !ok { @@ -499,6 +488,8 @@ func (plugin *Plugin) PostHook(ctx *context.Context, res *schemas.BifrostRespons } } + requestType := res.ExtraFields.RequestType + // Get embedding from context if available and needed if shouldStoreEmbeddings && requestType != schemas.EmbeddingRequest && requestType != schemas.TranscriptionRequest { embeddingValue := (*ctx).Value(requestEmbeddingKey) @@ -577,7 +568,7 @@ func (plugin *Plugin) PostHook(ctx *context.Context, res *schemas.BifrostRespons embeddingToStore = nil } - if plugin.isStreamingRequest(requestType) { + if bifrost.IsStreamRequestType(requestType) { if err := plugin.addStreamingResponse(cacheCtx, requestID, res, bifrostErr, embeddingToStore, unifiedMetadata, cacheTTL, isFinalChunk); err != nil { plugin.logger.Warn(fmt.Sprintf("%s Failed to cache streaming response: %v", 
PluginLoggerPrefix, err)) } diff --git a/plugins/semanticcache/plugin_conversation_config_test.go b/plugins/semanticcache/plugin_conversation_config_test.go index 28bffe3d8b..c842d67240 100644 --- a/plugins/semanticcache/plugin_conversation_config_test.go +++ b/plugins/semanticcache/plugin_conversation_config_test.go @@ -167,15 +167,15 @@ func TestConversationHistoryThresholdDifferentValues(t *testing.T) { ctx := CreateContextWithCacheKey("test-threshold-" + tc.name) // Build conversation with specified number of messages - var conversation []schemas.BifrostMessage + var conversation []schemas.ChatMessage for i := 0; i < tc.messages; i++ { - role := schemas.ModelChatMessageRoleUser + role := schemas.ChatMessageRoleUser if i%2 == 1 { - role = schemas.ModelChatMessageRoleAssistant + role = schemas.ChatMessageRoleAssistant } - message := schemas.BifrostMessage{ + message := schemas.ChatMessage{ Role: role, - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Message " + strconv.Itoa(i+1)), }, } @@ -328,41 +328,41 @@ func TestExcludeSystemPromptWithMultipleSystemMessages(t *testing.T) { ctx := CreateContextWithCacheKey("test-multiple-system-messages") // Manually create conversation with multiple system messages - conversation1 := []schemas.BifrostMessage{ + conversation1 := []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("You are helpful")}, + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("You are helpful")}, }, { - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Be concise")}, + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Be concise")}, }, { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hello")}, + Role: 
schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Hello")}, }, { - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hi!")}, + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Hi!")}, }, } - conversation2 := []schemas.BifrostMessage{ + conversation2 := []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("You are an expert")}, + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("You are an expert")}, }, { - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Be detailed")}, + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Be detailed")}, }, { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hello")}, + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Hello")}, }, { - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hi!")}, + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Hi!")}, }, } @@ -397,14 +397,14 @@ func TestExcludeSystemPromptWithNoSystemMessages(t *testing.T) { ctx := CreateContextWithCacheKey("test-no-system-messages") // Conversation with no system messages - conversation := []schemas.BifrostMessage{ + conversation := []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hello")}, + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Hello")}, }, { - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ContentStr: bifrost.Ptr("Hi 
there!")}, + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{ContentStr: bifrost.Ptr("Hi there!")}, }, } diff --git a/plugins/semanticcache/plugin_edge_cases_test.go b/plugins/semanticcache/plugin_edge_cases_test.go index 7136e6ec3c..173571f132 100644 --- a/plugins/semanticcache/plugin_edge_cases_test.go +++ b/plugins/semanticcache/plugin_edge_cases_test.go @@ -19,8 +19,8 @@ func TestParameterVariations(t *testing.T) { tests := []struct { name string - request1 *schemas.BifrostRequest - request2 *schemas.BifrostRequest + request1 *schemas.BifrostChatRequest + request2 *schemas.BifrostChatRequest shouldCache bool }{ { @@ -80,49 +80,45 @@ func TestToolVariations(t *testing.T) { ctx := CreateContextWithCacheKey("tool-variations-test") // Base request without tools - baseRequest := &schemas.BifrostRequest{ + baseRequest := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("What's the weather like today?"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("What's the weather like today?"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.5), - MaxTokens: bifrost.Ptr(100), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(100), + Temperature: bifrost.Ptr(0.5), }, } // Request with tools - requestWithTools := &schemas.BifrostRequest{ + requestWithTools := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("What's the weather like today?"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: 
schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("What's the weather like today?"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.5), - MaxTokens: bifrost.Ptr(100), - Tools: &[]schemas.Tool{ + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(100), + Temperature: bifrost.Ptr(0.5), + Tools: []schemas.ChatTool{ { - Type: "function", - Function: schemas.Function{ + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ Name: "get_weather", - Description: "Get the current weather", - Parameters: schemas.FunctionParameters{ + Description: bifrost.Ptr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ Type: "object", Properties: map[string]interface{}{ "location": map[string]interface{}{ @@ -131,6 +127,7 @@ func TestToolVariations(t *testing.T) { }, }, }, + Strict: bifrost.Ptr(false), }, }, }, @@ -138,29 +135,27 @@ func TestToolVariations(t *testing.T) { } // Request with different tools - requestWithDifferentTools := &schemas.BifrostRequest{ + requestWithDifferentTools := &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("What's the weather like today?"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("What's the weather like today?"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.5), - MaxTokens: bifrost.Ptr(100), - Tools: &[]schemas.Tool{ + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(100), + Temperature: bifrost.Ptr(0.5), + Tools: []schemas.ChatTool{ { - Type: "function", - Function: schemas.Function{ - Name: "get_current_weather", // Different name - Description: "Get current weather information", - Parameters: schemas.FunctionParameters{ + 
Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_current_weather", + Description: bifrost.Ptr("Get current weather information"), + Parameters: &schemas.ToolFunctionParameters{ Type: "object", Properties: map[string]interface{}{ "city": map[string]interface{}{ // Different parameter name @@ -169,6 +164,7 @@ func TestToolVariations(t *testing.T) { }, }, }, + Strict: bifrost.Ptr(false), }, }, }, @@ -225,150 +221,140 @@ func TestContentVariations(t *testing.T) { tests := []struct { name string - request *schemas.BifrostRequest + request *schemas.BifrostChatRequest }{ { name: "Unicode Content", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("🌟 Unicode test: Hello, 世界! مرحبا 🌍"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("🌟 Unicode test: Hello, 世界! 
مرحبا 🌍"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.1), - MaxTokens: bifrost.Ptr(50), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(50), + Temperature: bifrost.Ptr(0.1), }, }, }, { name: "Image URL Content", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentBlocks: &[]schemas.ContentBlock{ - { - Type: "text", - Text: bifrost.Ptr("Analyze this image"), - }, - { - Type: "image_url", - ImageURL: &schemas.ImageURLStruct{ - URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentBlocks: &[]schemas.ChatContentBlock{ + { + Type: schemas.ChatContentBlockTypeText, + Text: bifrost.Ptr("Analyze this image"), + }, + { + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ + URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", }, }, }, }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.3), - MaxTokens: bifrost.Ptr(200), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(200), + Temperature: bifrost.Ptr(0.3), }, }, }, { name: "Multiple Images", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentBlocks: &[]schemas.ContentBlock{ - { - Type: "text", - Text: bifrost.Ptr("Compare 
these images"), - }, - { - Type: "image_url", - ImageURL: &schemas.ImageURLStruct{ - URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentBlocks: &[]schemas.ChatContentBlock{ + { + Type: schemas.ChatContentBlockTypeText, + Text: bifrost.Ptr("Compare these images"), + }, + { + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ + URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", }, - { - Type: "image_url", - ImageURL: &schemas.ImageURLStruct{ - URL: "https://upload.wikimedia.org/wikipedia/commons/b/b5/Scenery_.jpg", - }, + }, + { + Type: schemas.ChatContentBlockTypeImage, + ImageURLStruct: &schemas.ChatInputImage{ + URL: "https://upload.wikimedia.org/wikipedia/commons/b/b5/Scenery_.jpg", }, }, }, }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.3), - MaxTokens: bifrost.Ptr(200), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(200), + Temperature: bifrost.Ptr(0.3), }, }, }, { name: "Very Long Content", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr(strings.Repeat("This is a very long prompt. ", 100)), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr(strings.Repeat("This is a very long prompt. 
", 100)), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.2), - MaxTokens: bifrost.Ptr(50), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(50), + Temperature: bifrost.Ptr(0.2), }, }, }, { name: "Multi-turn Conversation", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("What is AI?"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("What is AI?"), }, - { - Role: "assistant", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("AI stands for Artificial Intelligence..."), - }, + }, + { + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("AI stands for Artificial Intelligence..."), }, - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Can you give me examples?"), - }, + }, + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Can you give me examples?"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.5), - MaxTokens: bifrost.Ptr(150), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(150), + Temperature: bifrost.Ptr(0.5), }, }, }, @@ -409,76 +395,70 @@ func TestBoundaryParameterValues(t *testing.T) { tests := []struct { name string - request *schemas.BifrostRequest + request *schemas.BifrostChatRequest }{ { name: "Maximum Parameter Values", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - 
ContentStr: bifrost.Ptr("Test max parameters"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Test max parameters"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(2.0), - MaxTokens: bifrost.Ptr(4096), - TopP: bifrost.Ptr(1.0), - PresencePenalty: bifrost.Ptr(2.0), - FrequencyPenalty: bifrost.Ptr(2.0), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(4096), + PresencePenalty: bifrost.Ptr(2.0), + FrequencyPenalty: bifrost.Ptr(2.0), + Temperature: bifrost.Ptr(2.0), + TopP: bifrost.Ptr(1.0), }, }, }, { name: "Minimum Parameter Values", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Test min parameters"), - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Test min parameters"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.0), - MaxTokens: bifrost.Ptr(1), - TopP: bifrost.Ptr(0.01), - PresencePenalty: bifrost.Ptr(-2.0), - FrequencyPenalty: bifrost.Ptr(-2.0), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(1), + PresencePenalty: bifrost.Ptr(-2.0), + FrequencyPenalty: bifrost.Ptr(-2.0), + Temperature: bifrost.Ptr(0.0), + TopP: bifrost.Ptr(0.01), }, }, }, { name: "Edge Case Parameters", - request: &schemas.BifrostRequest{ + request: &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: bifrost.Ptr("Test edge case parameters"), - }, + Input: []schemas.ChatMessage{ + { + Role: 
schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("Test edge case parameters"), }, }, }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.0), - MaxTokens: bifrost.Ptr(1), - TopP: bifrost.Ptr(0.1), - User: bifrost.Ptr("test-user-id-12345"), + Params: &schemas.ChatParameters{ + MaxCompletionTokens: bifrost.Ptr(1), + User: bifrost.Ptr("test-user-id-12345"), + Temperature: bifrost.Ptr(0.0), + TopP: bifrost.Ptr(0.1), }, }, }, diff --git a/plugins/semanticcache/plugin_integration_test.go b/plugins/semanticcache/plugin_integration_test.go index 16c80e342c..f8d73660be 100644 --- a/plugins/semanticcache/plugin_integration_test.go +++ b/plugins/semanticcache/plugin_integration_test.go @@ -20,25 +20,25 @@ func TestSemanticCacheBasicFlow(t *testing.T) { // Add cache key to context ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) // Test request request := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Hello, world!"), }, }, }, - }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.7), - MaxTokens: bifrost.Ptr(100), + Params: &schemas.ChatParameters{ + Temperature: bifrost.Ptr(0.7), + MaxCompletionTokens: bifrost.Ptr(100), + }, }, } @@ -63,12 +63,12 @@ func TestSemanticCacheBasicFlow(t *testing.T) { // Simulate a response response := &schemas.BifrostResponse{ ID: uuid.New().String(), - Choices: []schemas.BifrostResponseChoice{ + Choices: 
[]schemas.BifrostChatResponseChoice{ { BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ + Message: schemas.ChatMessage{ Role: "assistant", - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Hello! How can I help you today?"), }, }, @@ -76,7 +76,9 @@ func TestSemanticCacheBasicFlow(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, + Provider: schemas.OpenAI, + ModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, } @@ -107,7 +109,6 @@ func TestSemanticCacheBasicFlow(t *testing.T) { // Reset context for second request ctx2 := context.Background() ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") - ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) modifiedReq2, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request) if err != nil { @@ -162,25 +163,25 @@ func TestSemanticCacheStrictFiltering(t *testing.T) { ctx := context.Background() ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) // Base request baseRequest := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("What is the weather like?"), }, }, }, - }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.7), - MaxTokens: bifrost.Ptr(100), + Params: &schemas.ChatParameters{ + Temperature: bifrost.Ptr(0.7), + 
MaxCompletionTokens: bifrost.Ptr(100), + }, }, } @@ -199,12 +200,12 @@ func TestSemanticCacheStrictFiltering(t *testing.T) { // Cache a response response := &schemas.BifrostResponse{ ID: uuid.New().String(), - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ + Message: schemas.ChatMessage{ Role: "assistant", - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("It's sunny today!"), }, }, @@ -212,7 +213,9 @@ func TestSemanticCacheStrictFiltering(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, + Provider: schemas.OpenAI, + ModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, } @@ -229,15 +232,28 @@ func TestSemanticCacheStrictFiltering(t *testing.T) { ctx2 := context.Background() ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") - ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) - modifiedRequest := *baseRequest - modifiedRequest.Params = &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.5), // Different temperature - MaxTokens: bifrost.Ptr(100), + modifiedRequest := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("What is the weather like?"), + }, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: bifrost.Ptr(0.5), // Different temperature + MaxCompletionTokens: bifrost.Ptr(100), + }, + }, } - _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, &modifiedRequest) + _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, modifiedRequest) if err != nil { t.Fatalf("Second 
PreHook failed: %v", err) } @@ -253,12 +269,28 @@ func TestSemanticCacheStrictFiltering(t *testing.T) { ctx3 := context.Background() ctx3 = context.WithValue(ctx3, CacheKey, "test-cache-enabled") - ctx3 = context.WithValue(ctx3, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) - modifiedRequest2 := *baseRequest - modifiedRequest2.Model = "gpt-3.5-turbo" // Different model + modifiedRequest2 := &schemas.BifrostRequest{ + Provider: schemas.OpenAI, + Model: "gpt-3.5-turbo", // Different model + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: bifrost.Ptr("What is the weather like?"), + }, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: bifrost.Ptr(0.7), + MaxCompletionTokens: bifrost.Ptr(100), + }, + }, + } - _, shortCircuit3, err := setup.Plugin.PreHook(&ctx3, &modifiedRequest2) + _, shortCircuit3, err := setup.Plugin.PreHook(&ctx3, modifiedRequest2) if err != nil { t.Fatalf("Third PreHook failed: %v", err) } @@ -278,23 +310,23 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { ctx := context.Background() ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionStreamRequest) request := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionStreamRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Tell me a short story"), }, }, }, - }, - Params: &schemas.ModelParameters{ - Temperature: bifrost.Ptr(0.8), + 
Params: &schemas.ChatParameters{ + Temperature: bifrost.Ptr(0.8), + }, }, } @@ -329,7 +361,7 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { chunkResponse := &schemas.BifrostResponse{ ID: uuid.New().String(), - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { Index: i, FinishReason: finishReason, @@ -341,8 +373,10 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, - ChunkIndex: i, + Provider: schemas.OpenAI, + ModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionStreamRequest, + ChunkIndex: i, }, } @@ -360,7 +394,6 @@ func TestSemanticCacheStreamingFlow(t *testing.T) { ctx2 := context.Background() ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") - ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionStreamRequest) _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request) if err != nil { @@ -405,16 +438,16 @@ func TestSemanticCache_NoCacheWhenKeyMissing(t *testing.T) { ctx := context.Background() // Don't set the cache key - cache should be disabled - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) request := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Test message"), }, }, @@ -444,16 +477,16 @@ func TestSemanticCache_CustomTTLHandling(t *testing.T) { ctx := context.Background() ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") ctx = context.WithValue(ctx, CacheTTLKey, 
1*time.Minute) // Custom TTL - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) request := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("TTL test message"), }, }, @@ -474,12 +507,12 @@ func TestSemanticCache_CustomTTLHandling(t *testing.T) { // Simulate response and cache it response := &schemas.BifrostResponse{ ID: "ttl-test-response", - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ + Message: schemas.ChatMessage{ Role: "assistant", - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("TTL test response"), }, }, @@ -487,7 +520,9 @@ func TestSemanticCache_CustomTTLHandling(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, + Provider: schemas.OpenAI, + ModelRequested: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, }, } @@ -510,16 +545,16 @@ func TestSemanticCache_CustomThresholdHandling(t *testing.T) { ctx := context.Background() ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") ctx = context.WithValue(ctx, CacheThresholdKey, 0.95) // Very high threshold - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) request := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ 
+ Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Threshold test message"), }, }, @@ -551,16 +586,16 @@ func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { ctx := context.Background() ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) request1 := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Provider model flags test"), }, }, @@ -581,12 +616,12 @@ func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { // Cache the response response := &schemas.BifrostResponse{ ID: "provider-model-test", - Choices: []schemas.BifrostResponseChoice{ + Choices: []schemas.BifrostChatResponseChoice{ { BifrostNonStreamResponseChoice: &schemas.BifrostNonStreamResponseChoice{ - Message: schemas.BifrostMessage{ + Message: schemas.ChatMessage{ Role: "assistant", - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Provider model test response"), }, }, @@ -594,7 +629,9 @@ func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { }, }, ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.OpenAI, + Provider: schemas.OpenAI, + ModelRequested: "gpt-4o-mini", + RequestType: 
schemas.ChatCompletionRequest, }, } @@ -607,13 +644,14 @@ func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { // Second request with different provider - should potentially hit cache since provider is not considered request2 := &schemas.BifrostRequest{ - Provider: schemas.Anthropic, // Different provider - Model: "claude-3-haiku", // Different model - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.Anthropic, // Different provider + Model: "claude-3-haiku", // Different model + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Provider model flags test"), // Same content }, }, @@ -623,7 +661,6 @@ func TestSemanticCache_ProviderModelCachingFlags(t *testing.T) { ctx2 := context.Background() ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") - ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request2) if err != nil { @@ -646,16 +683,16 @@ func TestSemanticCache_ConfigurationEdgeCases(t *testing.T) { ctx := context.Background() ctx = context.WithValue(ctx, CacheKey, "test-cache-enabled") ctx = context.WithValue(ctx, CacheTTLKey, "not-a-duration") // Invalid TTL type - ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) request := &schemas.BifrostRequest{ - Provider: schemas.OpenAI, - Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + RequestType: schemas.ChatCompletionRequest, + ChatRequest: &schemas.BifrostChatRequest{ + Input: []schemas.ChatMessage{ { - Role: "user", - Content: schemas.MessageContent{ + Role: 
schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Edge case test"), }, }, @@ -677,7 +714,6 @@ func TestSemanticCache_ConfigurationEdgeCases(t *testing.T) { ctx2 := context.Background() ctx2 = context.WithValue(ctx2, CacheKey, "test-cache-enabled") ctx2 = context.WithValue(ctx2, CacheThresholdKey, "not-a-float") // Invalid threshold type - ctx2 = context.WithValue(ctx2, schemas.BifrostContextKeyRequestType, schemas.ChatCompletionRequest) // Should handle invalid threshold gracefully _, shortCircuit2, err := setup.Plugin.PreHook(&ctx2, request) diff --git a/plugins/semanticcache/plugin_normalization_test.go b/plugins/semanticcache/plugin_normalization_test.go index 23bb7f30ed..80bb0d8465 100644 --- a/plugins/semanticcache/plugin_normalization_test.go +++ b/plugins/semanticcache/plugin_normalization_test.go @@ -63,30 +63,28 @@ func testChatCompletionNormalization(t *testing.T, setup *TestSetup) { } // Create chat completion requests for all test cases - requests := make([]*schemas.BifrostRequest, len(testCases)) + requests := make([]*schemas.BifrostChatRequest, len(testCases)) for i, tc := range testCases { - requests[i] = &schemas.BifrostRequest{ + requests[i] = &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ - ContentStr: &tc.systemMsg, - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ + ContentStr: &tc.systemMsg, }, - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentStr: &tc.userMsg, - }, + }, + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentStr: &tc.userMsg, }, }, }, - Params: &schemas.ModelParameters{ - Temperature: PtrFloat64(0.5), - MaxTokens: PtrInt(50), + Params: 
&schemas.ChatParameters{ + Temperature: PtrFloat64(0.5), + MaxCompletionTokens: PtrInt(50), }, } } @@ -144,7 +142,7 @@ func testSpeechNormalization(t *testing.T, setup *TestSetup) { } // Create speech requests for all test cases - requests := make([]*schemas.BifrostRequest, len(testCases)) + requests := make([]*schemas.BifrostSpeechRequest, len(testCases)) for i, tc := range testCases { requests[i] = CreateSpeechRequest(tc.input, "alloy") } @@ -214,33 +212,31 @@ func TestChatCompletionContentBlocksNormalization(t *testing.T) { } // Create chat completion requests with content blocks - requests := make([]*schemas.BifrostRequest, len(testCases)) + requests := make([]*schemas.BifrostChatRequest, len(testCases)) for i, tc := range testCases { // Create content blocks - contentBlocks := make([]schemas.ContentBlock, len(tc.textBlocks)) + contentBlocks := make([]schemas.ChatContentBlock, len(tc.textBlocks)) for j, text := range tc.textBlocks { - contentBlocks[j] = schemas.ContentBlock{ - Type: schemas.ContentBlockTypeText, + contentBlocks[j] = schemas.ChatContentBlock{ + Type: schemas.ChatContentBlockTypeText, Text: &text, } } - requests[i] = &schemas.BifrostRequest{ + requests[i] = &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ - ContentBlocks: &contentBlocks, - }, + Input: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ + ContentBlocks: &contentBlocks, }, }, }, - Params: &schemas.ModelParameters{ - Temperature: PtrFloat64(0.5), - MaxTokens: PtrInt(50), + Params: &schemas.ChatParameters{ + Temperature: PtrFloat64(0.5), + MaxCompletionTokens: PtrInt(50), }, } } diff --git a/plugins/semanticcache/plugin_responses_test.go b/plugins/semanticcache/plugin_responses_test.go new file mode 100644 index 0000000000..9d983da768 --- 
/dev/null +++ b/plugins/semanticcache/plugin_responses_test.go @@ -0,0 +1,415 @@ +package semanticcache + +import ( + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" +) + +// TestResponsesAPIBasicFunctionality tests the core caching functionality with Responses API +func TestResponsesAPIBasicFunctionality(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-responses-basic") + + // Create test request + testRequest := CreateBasicResponsesRequest( + "What is Bifrost? Answer in one short sentence.", + 0.7, + 50, + ) + + t.Log("Making first Responses API request (should go to OpenAI and be cached)...") + + // Make first request (will go to OpenAI and be cached) + start1 := time.Now() + response1, err1 := setup.Client.ResponsesRequest(ctx, testRequest) + duration1 := time.Since(start1) + + if err1 != nil { + t.Fatalf("First Responses request failed: %v", err1) + } + + if response1 == nil || len(response1.Output) == 0 { + t.Fatal("First Responses response is invalid") + } + + t.Logf("First request completed in %v", duration1) + t.Logf("Response contains %d output messages", len(response1.Output)) + + // Wait for cache to be written + WaitForCache() + + t.Log("Making second identical Responses API request (should be served from cache)...") + + // Make second identical request (should be cached) + start2 := time.Now() + response2, err2 := setup.Client.ResponsesRequest(ctx, testRequest) + duration2 := time.Since(start2) + + if err2 != nil { + t.Fatalf("Second Responses request failed: %v", err2) + } + + if response2 == nil || len(response2.Output) == 0 { + t.Fatal("Second Responses response is invalid") + } + + t.Logf("Second request completed in %v", duration2) + + // Verify cache hit + AssertCacheHit(t, response2, string(CacheTypeDirect)) + + // Performance comparison + t.Logf("Performance Summary:") + t.Logf("First request (OpenAI): %v", duration1) + t.Logf("Second request (Cache): %v", duration2) 
+ + if duration2 >= duration1 { + t.Log("⚠️ Cache doesn't seem faster, but this could be due to test environment") + } + + // Verify provider information is maintained in cached response + if response2.ExtraFields.Provider != testRequest.Provider { + t.Errorf("Provider mismatch in cached response: expected %s, got %s", + testRequest.Provider, response2.ExtraFields.Provider) + } + + t.Log("✅ Basic Responses API semantic caching test completed successfully!") +} + +// TestResponsesAPIDifferentParameters tests that different parameters produce different cache entries +func TestResponsesAPIDifferentParameters(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-responses-params") + basePrompt := "Explain quantum computing" + + tests := []struct { + name string + request1 *schemas.BifrostResponsesRequest + request2 *schemas.BifrostResponsesRequest + shouldCache bool + }{ + { + name: "Identical Requests", + request1: CreateBasicResponsesRequest(basePrompt, 0.5, 50), + request2: CreateBasicResponsesRequest(basePrompt, 0.5, 50), + shouldCache: true, + }, + { + name: "Different Temperature", + request1: CreateBasicResponsesRequest(basePrompt, 0.1, 50), + request2: CreateBasicResponsesRequest(basePrompt, 0.9, 50), + shouldCache: false, + }, + { + name: "Different MaxOutputTokens", + request1: CreateBasicResponsesRequest(basePrompt, 0.5, 50), + request2: CreateBasicResponsesRequest(basePrompt, 0.5, 200), + shouldCache: false, + }, + { + name: "Different Instructions", + request1: CreateResponsesRequestWithInstructions(basePrompt, "Be concise", 0.5, 50), + request2: CreateResponsesRequestWithInstructions(basePrompt, "Be detailed", 0.5, 50), + shouldCache: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear cache for this subtest + clearTestKeysWithStore(t, setup.Store) + + // Make first request + _, err1 := setup.Client.ResponsesRequest(ctx, tt.request1) + if err1 != nil { + 
t.Fatalf("First request failed: %v", err1) + } + + WaitForCache() + + // Make second request + response2, err2 := setup.Client.ResponsesRequest(ctx, tt.request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + if tt.shouldCache { + AssertCacheHit(t, response2, "direct") + t.Log("✓ Parameters match: cache hit as expected") + } else { + AssertNoCacheHit(t, response2) + t.Log("✓ Parameters differ: no cache hit as expected") + } + }) + } +} + +// TestResponsesAPISemanticMatching tests semantic similarity matching with Responses API +func TestResponsesAPISemanticMatching(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKeyAndType("test-responses-semantic", CacheTypeSemantic) + + // First request + originalRequest := CreateBasicResponsesRequest("What is machine learning?", 0.5, 50) + t.Log("Making first Responses request with original text...") + response1, err1 := setup.Client.ResponsesRequest(ctx, originalRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + AssertNoCacheHit(t, response1) + WaitForCache() + + // Test semantic match with similar but different text + semanticRequest := CreateBasicResponsesRequest("Can you explain machine learning concepts?", 0.5, 50) + t.Log("Making semantically similar Responses request...") + response2, err2 := setup.Client.ResponsesRequest(ctx, semanticRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // This should be a semantic cache hit + AssertCacheHit(t, response2, "semantic") + t.Log("✓ Semantic cache hit with similar content") +} + +// TestResponsesAPIWithInstructions tests caching with system instructions +func TestResponsesAPIWithInstructions(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-responses-instructions") + + // Create request with instructions + request1 := CreateResponsesRequestWithInstructions( + "Explain artificial 
intelligence", + "You are a helpful assistant. Be concise and accurate.", + 0.7, + 100, + ) + + t.Log("Making first Responses request with instructions...") + response1, err1 := setup.Client.ResponsesRequest(ctx, request1) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + AssertNoCacheHit(t, response1) + WaitForCache() + + // Make identical request + request2 := CreateResponsesRequestWithInstructions( + "Explain artificial intelligence", + "You are a helpful assistant. Be concise and accurate.", + 0.7, + 100, + ) + + t.Log("Making second identical Responses request with instructions...") + response2, err2 := setup.Client.ResponsesRequest(ctx, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // Should be a cache hit + AssertCacheHit(t, response2, "direct") + t.Log("✓ Responses API with instructions cached correctly") +} + +// TestResponsesAPICacheExpiration tests TTL functionality for Responses API requests +func TestResponsesAPICacheExpiration(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Set very short TTL for testing + shortTTL := 2 * time.Second + ctx := CreateContextWithCacheKeyAndTTL("test-responses-ttl", shortTTL) + + responsesRequest := CreateBasicResponsesRequest("TTL test for Responses API", 0.5, 50) + + t.Log("Making first Responses request with short TTL...") + response1, err1 := setup.Client.ResponsesRequest(ctx, responsesRequest) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + t.Log("Making second Responses request before TTL expiration...") + response2, err2 := setup.Client.ResponsesRequest(ctx, responsesRequest) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + AssertCacheHit(t, response2, "direct") + + t.Logf("Waiting for TTL expiration (%v)...", shortTTL) + time.Sleep(shortTTL + 1*time.Second) // Wait for TTL to expire + + t.Log("Making third Responses request after TTL 
expiration...") + response3, err3 := setup.Client.ResponsesRequest(ctx, responsesRequest) + if err3 != nil { + t.Fatalf("Third request failed: %v", err3) + } + // Should not be a cache hit since TTL expired + AssertNoCacheHit(t, response3) + + t.Log("✅ Responses API requests properly handle TTL expiration") +} + +// TestResponsesAPIWithoutCacheKey tests that Responses requests without cache key are not cached +func TestResponsesAPIWithoutCacheKey(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + // Don't set cache key in context + ctx := CreateContextWithCacheKey("") + + responsesRequest := CreateBasicResponsesRequest("Test Responses without cache key", 0.5, 50) + + t.Log("Making Responses request without cache key...") + + response, err := setup.Client.ResponsesRequest(ctx, responsesRequest) + if err != nil { + t.Fatalf("Responses request failed: %v", err) + } + + // Should not be cached + AssertNoCacheHit(t, response) + + t.Log("✅ Responses requests without cache key are properly not cached") +} + +// TestResponsesAPINoStoreFlag tests that Responses requests with no-store flag are not cached +func TestResponsesAPINoStoreFlag(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + responsesRequest := CreateBasicResponsesRequest("Test no-store with Responses API", 0.7, 50) + ctx := CreateContextWithCacheKeyAndNoStore("test-no-store-responses", true) + + t.Log("Testing no-store with Responses API...") + response1, err1 := setup.Client.ResponsesRequest(ctx, responsesRequest) + if err1 != nil { + t.Fatalf("Responses request failed: %v", err1) + } + AssertNoCacheHit(t, response1) + + WaitForCache() + + // Verify not cached + response2, err2 := setup.Client.ResponsesRequest(ctx, responsesRequest) + if err2 != nil { + t.Fatalf("Second Responses request failed: %v", err2) + } + AssertNoCacheHit(t, response2) // Should not be cached + + t.Log("✅ Responses API no-store flag working correctly") +} + +// TestResponsesAPIStreaming tests 
streaming Responses API requests +func TestResponsesAPIStreaming(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := CreateContextWithCacheKey("test-responses-streaming") + prompt := "Explain the basics of quantum computing in simple terms" + + // Make non-streaming request first + t.Log("Making non-streaming Responses request...") + nonStreamRequest := CreateBasicResponsesRequest(prompt, 0.5, 50) + _, err1 := setup.Client.ResponsesRequest(ctx, nonStreamRequest) + if err1 != nil { + t.Fatalf("Non-streaming Responses request failed: %v", err1) + } + + WaitForCache() + + // Make streaming request with same prompt and parameters + t.Log("Making streaming Responses request with same prompt...") + streamRequest := CreateStreamingResponsesRequest(prompt, 0.5, 50) + stream, err2 := setup.Client.ResponsesStreamRequest(ctx, streamRequest) + if err2 != nil { + t.Fatalf("Streaming Responses request failed: %v", err2) + } + + var streamResponses []schemas.BifrostResponse + for streamMsg := range stream { + if streamMsg.BifrostError != nil { + t.Fatalf("Error in Responses stream: %v", streamMsg.BifrostError) + } + streamResponses = append(streamResponses, *streamMsg.BifrostResponse) + } + + if len(streamResponses) == 0 { + t.Fatal("No streaming responses received") + } + + // Check if any of the streaming responses was served from cache + cacheHitFound := false + for _, resp := range streamResponses { + if resp.ExtraFields.CacheDebug != nil && resp.ExtraFields.CacheDebug.CacheHit { + cacheHitFound = true + break + } + } + + if !cacheHitFound { + t.Log("⚠️ No cache hit detected in streaming responses - this could be expected behavior") + } else { + t.Log("✓ Cache hit detected in streaming Responses API") + } + + t.Log("✅ Streaming Responses API test completed") +} + +// TestResponsesAPIComplexParameters tests complex parameter handling +func TestResponsesAPIComplexParameters(t *testing.T) { + setup := NewTestSetup(t) + defer setup.Cleanup() + + ctx := 
CreateContextWithCacheKey("test-responses-complex-params") + + // Create request with various complex parameters + request := CreateBasicResponsesRequest("Test complex parameters", 0.8, 150) + request.Params.TopP = PtrFloat64(0.9) + request.Params.Background = &[]bool{true}[0] + request.Params.ParallelToolCalls = &[]bool{false}[0] + request.Params.ServiceTier = &[]string{"default"}[0] + request.Params.Store = &[]bool{true}[0] + + t.Log("Making first Responses request with complex parameters...") + response1, err1 := setup.Client.ResponsesRequest(ctx, request) + if err1 != nil { + t.Fatalf("First request failed: %v", err1) + } + + AssertNoCacheHit(t, response1) + WaitForCache() + + // Create identical request + request2 := CreateBasicResponsesRequest("Test complex parameters", 0.8, 150) + request2.Params.TopP = PtrFloat64(0.9) + request2.Params.Background = &[]bool{true}[0] + request2.Params.ParallelToolCalls = &[]bool{false}[0] + request2.Params.ServiceTier = &[]string{"default"}[0] + request2.Params.Store = &[]bool{true}[0] + + t.Log("Making second identical Responses request with complex parameters...") + response2, err2 := setup.Client.ResponsesRequest(ctx, request2) + if err2 != nil { + t.Fatalf("Second request failed: %v", err2) + } + + // Should be a cache hit + AssertCacheHit(t, response2, "direct") + t.Log("✓ Responses API with complex parameters cached correctly") +} diff --git a/plugins/semanticcache/search.go b/plugins/semanticcache/search.go index 30dffeec90..90a43be37a 100644 --- a/plugins/semanticcache/search.go +++ b/plugins/semanticcache/search.go @@ -13,9 +13,9 @@ import ( "github.com/maximhq/bifrost/framework/vectorstore" ) -func (plugin *Plugin) performDirectSearch(ctx *context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType, cacheKey string) (*schemas.PluginShortCircuit, error) { +func (plugin *Plugin) performDirectSearch(ctx *context.Context, req *schemas.BifrostRequest, cacheKey string) (*schemas.PluginShortCircuit, 
error) { // Generate hash for the request - hash, err := plugin.generateRequestHash(req, requestType) + hash, err := plugin.generateRequestHash(req) if err != nil { return nil, fmt.Errorf("failed to generate request hash: %w", err) } @@ -23,7 +23,7 @@ func (plugin *Plugin) performDirectSearch(ctx *context.Context, req *schemas.Bif plugin.logger.Debug(PluginLoggerPrefix + " Generated Hash for Request: " + hash) // Extract metadata for strict filtering - _, paramsHash, err := plugin.extractTextForEmbedding(req, requestType) + _, paramsHash, err := plugin.extractTextForEmbedding(req) if err != nil { return nil, fmt.Errorf("failed to extract metadata for filtering: %w", err) } @@ -51,7 +51,7 @@ func (plugin *Plugin) performDirectSearch(ctx *context.Context, req *schemas.Bif // Make a full copy so we don't mutate the original backing array selectFields := append([]string(nil), SelectFields...) - if plugin.isStreamingRequest(requestType) { + if bifrost.IsStreamRequestType(req.RequestType) { selectFields = removeField(selectFields, "response") } else { selectFields = removeField(selectFields, "stream_chunks") @@ -81,9 +81,9 @@ func (plugin *Plugin) performDirectSearch(ctx *context.Context, req *schemas.Bif } // performSemanticSearch performs semantic similarity search and returns matching response if found. 
-func (plugin *Plugin) performSemanticSearch(ctx *context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType, cacheKey string) (*schemas.PluginShortCircuit, error) { +func (plugin *Plugin) performSemanticSearch(ctx *context.Context, req *schemas.BifrostRequest, cacheKey string) (*schemas.PluginShortCircuit, error) { // Extract text and metadata for embedding - text, paramsHash, err := plugin.extractTextForEmbedding(req, requestType) + text, paramsHash, err := plugin.extractTextForEmbedding(req) if err != nil { return nil, fmt.Errorf("failed to extract text for embedding: %w", err) } @@ -129,7 +129,7 @@ func (plugin *Plugin) performSemanticSearch(ctx *context.Context, req *schemas.B // Make a full copy so we don't mutate the original backing array selectFields := append([]string(nil), SelectFields...) - if plugin.isStreamingRequest(requestType) { + if bifrost.IsStreamRequestType(req.RequestType) { selectFields = removeField(selectFields, "response") } else { selectFields = removeField(selectFields, "stream_chunks") diff --git a/plugins/semanticcache/test_utils.go b/plugins/semanticcache/test_utils.go index 22c4b834aa..74e141953c 100644 --- a/plugins/semanticcache/test_utils.go +++ b/plugins/semanticcache/test_utils.go @@ -102,7 +102,10 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelPr RetryBackoffInitial: 100 * time.Millisecond, RetryBackoffMax: 10 * time.Second, }, - ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 5, + BufferSize: 10, + }, }, nil } @@ -122,9 +125,9 @@ func NewTestSetup(t *testing.T) *TestSetup { } return NewTestSetupWithConfig(t, Config{ - Provider: schemas.OpenAI, - EmbeddingModel: "text-embedding-3-small", - Threshold: 0.8, + Provider: schemas.OpenAI, + EmbeddingModel: "text-embedding-3-small", + Threshold: 0.8, CleanUpOnShutdown: true, Keys: []schemas.Key{ { @@ -193,44 +196,43 @@ func 
clearTestKeysWithStore(t *testing.T, store vectorstore.VectorStore) { } // CreateBasicChatRequest creates a basic chat completion request for testing -func CreateBasicChatRequest(content string, temperature float64, maxTokens int) *schemas.BifrostRequest { - return &schemas.BifrostRequest{ +func CreateBasicChatRequest(content string, temperature float64, maxTokens int) *schemas.BifrostChatRequest { + return &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ - { - Role: "user", - Content: schemas.MessageContent{ - ContentStr: &content, - }, + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: schemas.ChatMessageContent{ + ContentStr: &content, }, }, }, - Params: &schemas.ModelParameters{ - Temperature: &temperature, - MaxTokens: &maxTokens, + Params: &schemas.ChatParameters{ + Temperature: &temperature, + MaxCompletionTokens: &maxTokens, }, } } // CreateStreamingChatRequest creates a streaming chat completion request for testing -func CreateStreamingChatRequest(content string, temperature float64, maxTokens int) *schemas.BifrostRequest { +func CreateStreamingChatRequest(content string, temperature float64, maxTokens int) *schemas.BifrostChatRequest { return CreateBasicChatRequest(content, temperature, maxTokens) } // CreateSpeechRequest creates a speech synthesis request for testing -func CreateSpeechRequest(input string, voice string) *schemas.BifrostRequest { - return &schemas.BifrostRequest{ +func CreateSpeechRequest(input string, voice string) *schemas.BifrostSpeechRequest { + return &schemas.BifrostSpeechRequest{ Provider: schemas.OpenAI, Model: "tts-1", - Input: schemas.RequestInput{ - SpeechInput: &schemas.SpeechInput{ - Input: input, - VoiceConfig: schemas.SpeechVoiceInput{ - Voice: &voice, - }, + Input: schemas.SpeechInput{ + Input: input, + }, + Params: &schemas.SpeechParameters{ + VoiceConfig: schemas.SpeechVoiceInput{ + Voice: &voice, }, + 
ResponseFormat: "mp3", }, } } @@ -283,18 +285,56 @@ func WaitForCache() { } // CreateEmbeddingRequest creates an embedding request for testing -func CreateEmbeddingRequest(texts []string) *schemas.BifrostRequest { - return &schemas.BifrostRequest{ +func CreateEmbeddingRequest(texts []string) *schemas.BifrostEmbeddingRequest { + return &schemas.BifrostEmbeddingRequest{ Provider: schemas.OpenAI, Model: "text-embedding-3-small", - Input: schemas.RequestInput{ - EmbeddingInput: &schemas.EmbeddingInput{ - Texts: texts, + Input: schemas.EmbeddingInput{ + Texts: texts, + }, + } +} + +// CreateBasicResponsesRequest creates a basic Responses API request for testing +func CreateBasicResponsesRequest(content string, temperature float64, maxTokens int) *schemas.BifrostResponsesRequest { + userRole := schemas.ResponsesInputMessageRoleUser + return &schemas.BifrostResponsesRequest{ + Provider: schemas.OpenAI, + Model: "gpt-4o-mini", + Input: []schemas.ResponsesMessage{ + { + Role: &userRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &content, + }, }, }, + Params: &schemas.ResponsesParameters{ + Temperature: &temperature, + MaxOutputTokens: &maxTokens, + }, } } +// CreateResponsesRequestWithTools creates a Responses API request with tools for testing +func CreateResponsesRequestWithTools(content string, temperature float64, maxTokens int, tools []schemas.ResponsesTool) *schemas.BifrostResponsesRequest { + req := CreateBasicResponsesRequest(content, temperature, maxTokens) + req.Params.Tools = tools + return req +} + +// CreateResponsesRequestWithInstructions creates a Responses API request with system instructions +func CreateResponsesRequestWithInstructions(content string, instructions string, temperature float64, maxTokens int) *schemas.BifrostResponsesRequest { + req := CreateBasicResponsesRequest(content, temperature, maxTokens) + req.Params.Instructions = &instructions + return req +} + +// CreateStreamingResponsesRequest creates a streaming Responses API 
request for testing +func CreateStreamingResponsesRequest(content string, temperature float64, maxTokens int) *schemas.BifrostResponsesRequest { + return CreateBasicResponsesRequest(content, temperature, maxTokens) +} + // CreateContextWithCacheKey creates a context with the test cache key func CreateContextWithCacheKey(value string) context.Context { return context.WithValue(context.Background(), CacheKey, value) @@ -398,29 +438,27 @@ func CreateTestSetupWithThresholdAndExcludeSystem(t *testing.T, threshold int, e } // CreateConversationRequest creates a chat request with conversation history -func CreateConversationRequest(messages []schemas.BifrostMessage, temperature float64, maxTokens int) *schemas.BifrostRequest { - return &schemas.BifrostRequest{ +func CreateConversationRequest(messages []schemas.ChatMessage, temperature float64, maxTokens int) *schemas.BifrostChatRequest { + return &schemas.BifrostChatRequest{ Provider: schemas.OpenAI, Model: "gpt-4o-mini", - Input: schemas.RequestInput{ - ChatCompletionInput: &messages, - }, - Params: &schemas.ModelParameters{ - Temperature: &temperature, - MaxTokens: &maxTokens, + Input: messages, + Params: &schemas.ChatParameters{ + Temperature: &temperature, + MaxCompletionTokens: &maxTokens, }, } } // BuildConversationHistory creates a conversation history from pairs of user/assistant messages -func BuildConversationHistory(systemPrompt string, userAssistantPairs ...[]string) []schemas.BifrostMessage { - messages := []schemas.BifrostMessage{} +func BuildConversationHistory(systemPrompt string, userAssistantPairs ...[]string) []schemas.ChatMessage { + messages := []schemas.ChatMessage{} // Add system prompt if provided if systemPrompt != "" { - messages = append(messages, schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleSystem, + Content: schemas.ChatMessageContent{ ContentStr: 
&systemPrompt, }, }) @@ -430,18 +468,18 @@ func BuildConversationHistory(systemPrompt string, userAssistantPairs ...[]strin for _, pair := range userAssistantPairs { if len(pair) >= 1 && pair[0] != "" { userMsg := pair[0] - messages = append(messages, schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: &userMsg, }, }) } if len(pair) >= 2 && pair[1] != "" { assistantMsg := pair[1] - messages = append(messages, schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleAssistant, - Content: schemas.MessageContent{ + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: schemas.ChatMessageContent{ ContentStr: &assistantMsg, }, }) @@ -452,10 +490,10 @@ func BuildConversationHistory(systemPrompt string, userAssistantPairs ...[]strin } // AddUserMessage adds a user message to existing conversation -func AddUserMessage(messages []schemas.BifrostMessage, userMessage string) []schemas.BifrostMessage { - newMessage := schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ +func AddUserMessage(messages []schemas.ChatMessage, userMessage string) []schemas.ChatMessage { + newMessage := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: &userMessage, }, } diff --git a/plugins/semanticcache/utils.go b/plugins/semanticcache/utils.go index bcdaf99c9e..6715166658 100644 --- a/plugins/semanticcache/utils.go +++ b/plugins/semanticcache/utils.go @@ -22,13 +22,11 @@ func normalizeText(text string) string { // generateEmbedding generates an embedding for the given text using the configured provider. 
func (plugin *Plugin) generateEmbedding(ctx context.Context, text string) ([]float32, int, error) { // Create embedding request - embeddingReq := &schemas.BifrostRequest{ + embeddingReq := &schemas.BifrostEmbeddingRequest{ Provider: plugin.config.Provider, Model: plugin.config.EmbeddingModel, - Input: schemas.RequestInput{ - EmbeddingInput: &schemas.EmbeddingInput{ - Texts: []string{text}, - }, + Input: schemas.EmbeddingInput{ + Text: &text, }, } @@ -86,16 +84,32 @@ func (plugin *Plugin) generateEmbedding(ctx context.Context, text string) ([]flo // Returns: // - string: Hexadecimal representation of the xxhash // - error: Any error that occurred during request normalization or hashing -func (plugin *Plugin) generateRequestHash(req *schemas.BifrostRequest, requestType schemas.RequestType) (string, error) { +func (plugin *Plugin) generateRequestHash(req *schemas.BifrostRequest) (string, error) { // Create a hash input structure that includes both input and parameters hashInput := struct { - Input schemas.RequestInput `json:"input"` - Params *schemas.ModelParameters `json:"params,omitempty"` - Stream bool `json:"stream,omitempty"` + Input interface{} `json:"input"` + Params interface{} `json:"params,omitempty"` + Stream bool `json:"stream,omitempty"` }{ - Input: *plugin.getInputForCaching(req), - Params: req.Params, - Stream: plugin.isStreamingRequest(requestType), + Input: plugin.getInputForCaching(req), + Stream: bifrost.IsStreamRequestType(req.RequestType), + } + + switch req.RequestType { + case schemas.TextCompletionRequest: + hashInput.Params = req.TextCompletionRequest.Params + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + hashInput.Params = req.ChatRequest.Params + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + hashInput.Params = req.ResponsesRequest.Params + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + if req.SpeechRequest != nil { + hashInput.Params = req.SpeechRequest.Params + } + case 
schemas.EmbeddingRequest: + hashInput.Params = req.EmbeddingRequest.Params + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + hashInput.Params = req.TranscriptionRequest.Params } // Marshal to JSON for consistent hashing @@ -111,99 +125,141 @@ func (plugin *Plugin) generateRequestHash(req *schemas.BifrostRequest, requestTy // extractTextForEmbedding extracts meaningful text from different input types for embedding generation. // Returns the text to embed and metadata for storage. -func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest, requestType schemas.RequestType) (string, string, error) { +func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest) (string, string, error) { metadata := map[string]interface{}{} attachments := []string{} - // Add parameters as metadata if present - if req.Params != nil { - if req.Params.ToolChoice != nil { - if req.Params.ToolChoice.ToolChoiceStr != nil { - metadata["tool_choice"] = *req.Params.ToolChoice.ToolChoiceStr - } else if req.Params.ToolChoice.ToolChoiceStruct != nil { - metadata["tool_choice"] = (*req.Params.ToolChoice.ToolChoiceStruct).Function.Name - } - } - if req.Params.Temperature != nil { - metadata["temperature"] = *req.Params.Temperature + // Add parameters as metadata if present - handle segregated parameters + metadata["stream"] = bifrost.IsStreamRequestType(req.RequestType) + + // Extract parameters based on request type + switch req.RequestType { + case schemas.TextCompletionRequest: + if req.TextCompletionRequest != nil && req.TextCompletionRequest.Params != nil { + plugin.extractTextCompletionParametersToMetadata(req.TextCompletionRequest.Params, metadata) } - if req.Params.TopP != nil { - metadata["top_p"] = *req.Params.TopP + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + if req.ChatRequest != nil && req.ChatRequest.Params != nil { + plugin.extractChatParametersToMetadata(req.ChatRequest.Params, metadata) } - if 
req.Params.TopK != nil { - metadata["top_k"] = *req.Params.TopK + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + if req.ResponsesRequest != nil && req.ResponsesRequest.Params != nil { + plugin.extractResponsesParametersToMetadata(req.ResponsesRequest.Params, metadata) } - if req.Params.MaxTokens != nil { - metadata["max_tokens"] = *req.Params.MaxTokens + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + if req.SpeechRequest != nil && req.SpeechRequest.Params != nil { + plugin.extractSpeechParametersToMetadata(req.SpeechRequest.Params, metadata) } - if req.Params.StopSequences != nil { - metadata["stop_sequences"] = *req.Params.StopSequences + case schemas.EmbeddingRequest: + if req.EmbeddingRequest != nil && req.EmbeddingRequest.Params != nil { + plugin.extractEmbeddingParametersToMetadata(req.EmbeddingRequest.Params, metadata) } - if req.Params.PresencePenalty != nil { - metadata["presence_penalty"] = *req.Params.PresencePenalty + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + if req.TranscriptionRequest != nil && req.TranscriptionRequest.Params != nil { + plugin.extractTranscriptionParametersToMetadata(req.TranscriptionRequest.Params, metadata) } - if req.Params.FrequencyPenalty != nil { - metadata["frequency_penalty"] = *req.Params.FrequencyPenalty + } + + switch { + case req.TextCompletionRequest != nil: + metadataHash, err := getMetadataHash(metadata) + if err != nil { + return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } - if req.Params.ParallelToolCalls != nil { - metadata["parallel_tool_calls"] = *req.Params.ParallelToolCalls + + var textContent string + if req.TextCompletionRequest.Input.PromptStr != nil { + textContent = normalizeText(*req.TextCompletionRequest.Input.PromptStr) + } else if len(req.TextCompletionRequest.Input.PromptArray) > 0 { + textContent = normalizeText(strings.Join(req.TextCompletionRequest.Input.PromptArray, " ")) } - if req.Params.User != nil { - 
metadata["user"] = *req.Params.User + return textContent, metadataHash, nil + + case req.ChatRequest != nil: + reqInput, ok := plugin.getInputForCaching(req).([]schemas.ChatMessage) + if !ok { + return "", "", fmt.Errorf("failed to cast request input to chat messages") } - if len(req.Params.ExtraParams) > 0 { - maps.Copy(metadata, req.Params.ExtraParams) + // Serialize chat messages for embedding + var textParts []string + for _, msg := range reqInput { + // Extract content as string + var content string + if msg.Content.ContentStr != nil { + content = *msg.Content.ContentStr + } else if msg.Content.ContentBlocks != nil { + // For content blocks, extract text parts + var blockTexts []string + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + blockTexts = append(blockTexts, normalizeText(*block.Text)) + } + if block.ImageURLStruct != nil && block.ImageURLStruct.URL != "" { + attachments = append(attachments, block.ImageURLStruct.URL) + } + } + content = strings.Join(blockTexts, " ") + } + + if content != "" { + textParts = append(textParts, fmt.Sprintf("%s: %s", msg.Role, content)) + } } - } - metadata["stream"] = plugin.isStreamingRequest(requestType) + if len(textParts) == 0 { + return "", "", fmt.Errorf("no text content found in chat messages") + } - if req.Params != nil && req.Params.Tools != nil { - if toolsJSON, err := json.Marshal(*req.Params.Tools); err != nil { - plugin.logger.Warn(fmt.Sprintf("%s Failed to marshal tools for metadata: %v", PluginLoggerPrefix, err)) - } else { - toolHash := xxhash.Sum64(toolsJSON) - metadata["tools_hash"] = fmt.Sprintf("%x", toolHash) + if len(attachments) > 0 { + metadata["attachments"] = attachments } - } - switch { - case req.Input.TextCompletionInput != nil: metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } - return *req.Input.TextCompletionInput, metadataHash, nil + return strings.Join(textParts, 
"\n"), metadataHash, nil - case req.Input.ChatCompletionInput != nil: - reqInput := plugin.getInputForCaching(req) + case req.ResponsesRequest != nil: + reqInput, ok := plugin.getInputForCaching(req).([]schemas.ResponsesMessage) + if !ok { + return "", "", fmt.Errorf("failed to cast request input to responses messages") + } // Serialize chat messages for embedding var textParts []string - for _, msg := range *reqInput.ChatCompletionInput { + for _, msg := range reqInput { // Extract content as string var content string if msg.Content.ContentStr != nil { - content = *msg.Content.ContentStr + content = normalizeText(*msg.Content.ContentStr) } else if msg.Content.ContentBlocks != nil { // For content blocks, extract text parts var blockTexts []string for _, block := range *msg.Content.ContentBlocks { if block.Text != nil { - blockTexts = append(blockTexts, *block.Text) + blockTexts = append(blockTexts, normalizeText(*block.Text)) + } + if block.ResponsesInputMessageContentBlockImage != nil && block.ResponsesInputMessageContentBlockImage.ImageURL != nil { + attachments = append(attachments, *block.ResponsesInputMessageContentBlockImage.ImageURL) } - if block.ImageURL != nil && block.ImageURL.URL != "" { - attachments = append(attachments, block.ImageURL.URL) + if block.ResponsesInputMessageContentBlockFile != nil && block.ResponsesInputMessageContentBlockFile.FileURL != nil { + attachments = append(attachments, *block.ResponsesInputMessageContentBlockFile.FileURL) } } content = strings.Join(blockTexts, " ") } + role := "" + if msg.Role != nil { + role = string(*msg.Role) + } + if content != "" { - textParts = append(textParts, fmt.Sprintf("%s: %s", msg.Role, content)) + textParts = append(textParts, fmt.Sprintf("%s: %s", role, content)) } } @@ -222,31 +278,27 @@ func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest, reque return strings.Join(textParts, "\n"), metadataHash, nil - case req.Input.SpeechInput != nil: - if req.Input.SpeechInput.Input != 
"" { - if req.Input.SpeechInput.VoiceConfig.Voice != nil { - metadata["voice"] = *req.Input.SpeechInput.VoiceConfig.Voice - } - + case req.SpeechRequest != nil: + if req.SpeechRequest.Input.Input != "" { metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } - return req.Input.SpeechInput.Input, metadataHash, nil + return req.SpeechRequest.Input.Input, metadataHash, nil } return "", "", fmt.Errorf("no input text found in speech request") - case req.Input.EmbeddingInput != nil: + case req.EmbeddingRequest != nil: metadataHash, err := getMetadataHash(metadata) if err != nil { return "", "", fmt.Errorf("failed to marshal metadata for metadata hash: %w", err) } - texts := req.Input.EmbeddingInput.Texts + texts := req.EmbeddingRequest.Input.Texts - if len(texts) == 0 && req.Input.EmbeddingInput.Text != nil { - texts = []string{*req.Input.EmbeddingInput.Text} + if len(texts) == 0 && req.EmbeddingRequest.Input.Text != nil { + texts = []string{*req.EmbeddingRequest.Input.Text} } var text string @@ -256,7 +308,7 @@ func (plugin *Plugin) extractTextForEmbedding(req *schemas.BifrostRequest, reque return strings.TrimSpace(text), metadataHash, nil - case req.Input.TranscriptionInput != nil: + case req.TranscriptionRequest != nil: // Skip semantic caching for transcription requests return "", "", fmt.Errorf("transcription requests are not supported for semantic caching") @@ -273,13 +325,6 @@ func getMetadataHash(metadata map[string]interface{}) (string, error) { return fmt.Sprintf("%x", xxhash.Sum64(metadataJSON)), nil } -// isStreamingRequest checks if the request is a streaming request -func (plugin *Plugin) isStreamingRequest(requestType schemas.RequestType) bool { - return requestType == schemas.ChatCompletionStreamRequest || - requestType == schemas.SpeechStreamRequest || - requestType == schemas.TranscriptionStreamRequest -} - // buildUnifiedMetadata constructs the unified metadata 
structure for VectorEntry func (plugin *Plugin) buildUnifiedMetadata(provider schemas.ModelProvider, model string, paramsHash string, requestHash string, cacheKey string, ttl time.Duration) map[string]interface{} { unifiedMetadata := make(map[string]interface{}) @@ -381,23 +426,32 @@ func (plugin *Plugin) addStreamingResponse(ctx context.Context, responseID strin // getInputForCaching returns a normalized and sanitized copy of req.Input for hashing/embedding. // It applies text normalization (lowercase + trim) and optionally removes system messages. -func (plugin *Plugin) getInputForCaching(req *schemas.BifrostRequest) *schemas.RequestInput { - reqInput := req.Input - - // Handle text completion normalization - if reqInput.TextCompletionInput != nil { - normalizedText := normalizeText(*reqInput.TextCompletionInput) - reqInput.TextCompletionInput = &normalizedText - } - - // Handle chat completion normalization - if reqInput.ChatCompletionInput != nil { - originalMessages := *reqInput.ChatCompletionInput - normalizedMessages := make([]schemas.BifrostMessage, 0, len(originalMessages)) +func (plugin *Plugin) getInputForCaching(req *schemas.BifrostRequest) interface{} { + switch req.RequestType { + case schemas.TextCompletionRequest: + // Create a shallow copy of the input to avoid mutating the original request + copiedInput := req.TextCompletionRequest.Input + + if copiedInput.PromptStr != nil { + normalizedText := normalizeText(*copiedInput.PromptStr) + copiedInput.PromptStr = &normalizedText + } else if len(copiedInput.PromptArray) > 0 { + // Create a copy of the PromptArray and normalize each element + normalizedPromptArray := make([]string, len(copiedInput.PromptArray)) + copy(normalizedPromptArray, copiedInput.PromptArray) + for i, prompt := range normalizedPromptArray { + normalizedPromptArray[i] = normalizeText(prompt) + } + copiedInput.PromptArray = normalizedPromptArray + } + return copiedInput + case schemas.ChatCompletionRequest, 
schemas.ChatCompletionStreamRequest: + originalMessages := req.ChatRequest.Input + normalizedMessages := make([]schemas.ChatMessage, 0, len(originalMessages)) for _, msg := range originalMessages { // Skip system messages if configured to exclude them - if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role == schemas.ModelChatMessageRoleSystem { + if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role == schemas.ChatMessageRoleSystem { continue } @@ -410,7 +464,7 @@ func (plugin *Plugin) getInputForCaching(req *schemas.BifrostRequest) *schemas.R normalizedMsg.Content.ContentStr = &normalizedContent } else if msg.Content.ContentBlocks != nil { // Create a copy of content blocks with normalized text - normalizedBlocks := make([]schemas.ContentBlock, len(*msg.Content.ContentBlocks)) + normalizedBlocks := make([]schemas.ChatContentBlock, len(*msg.Content.ContentBlocks)) for i, block := range *msg.Content.ContentBlocks { normalizedBlocks[i] = block if block.Text != nil { @@ -423,16 +477,64 @@ func (plugin *Plugin) getInputForCaching(req *schemas.BifrostRequest) *schemas.R normalizedMessages = append(normalizedMessages, normalizedMsg) } + return normalizedMessages + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + originalMessages := req.ResponsesRequest.Input + normalizedMessages := make([]schemas.ResponsesMessage, 0, len(originalMessages)) - reqInput.ChatCompletionInput = &normalizedMessages - } + for _, msg := range originalMessages { + // Skip system messages if configured to exclude them + if plugin.config.ExcludeSystemPrompt != nil && *plugin.config.ExcludeSystemPrompt && msg.Role != nil && *msg.Role == schemas.ResponsesInputMessageRoleSystem { + continue + } - if reqInput.SpeechInput != nil { - normalizedInput := normalizeText(reqInput.SpeechInput.Input) - reqInput.SpeechInput.Input = normalizedInput - } + // Create a deep copy of the message with normalized content + 
normalizedMsg := msg - return &reqInput + // Create a deep copy of the Content to avoid modifying the original + if msg.Content != nil { + normalizedContent := &schemas.ResponsesMessageContent{} + if msg.Content.ContentStr != nil { + normalizedText := normalizeText(*msg.Content.ContentStr) + normalizedContent.ContentStr = &normalizedText + } else if msg.Content.ContentBlocks != nil { + // Create a copy of content blocks with normalized text + normalizedBlocks := make([]schemas.ResponsesMessageContentBlock, len(*msg.Content.ContentBlocks)) + for i, block := range *msg.Content.ContentBlocks { + normalizedBlocks[i] = block + if block.Text != nil { + normalizedText := normalizeText(*block.Text) + normalizedBlocks[i].Text = &normalizedText + } + } + normalizedContent.ContentBlocks = &normalizedBlocks + } + normalizedMsg.Content = normalizedContent + } + + normalizedMessages = append(normalizedMessages, normalizedMsg) + } + return normalizedMessages + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + return normalizeText(req.SpeechRequest.Input.Input) + case schemas.EmbeddingRequest: + input := req.EmbeddingRequest.Input + if input.Text != nil { + normalizedText := normalizeText(*input.Text) + return schemas.EmbeddingInput{Text: &normalizedText} + } else if len(input.Texts) > 0 { + normalizedTexts := make([]string, len(input.Texts)) + for i, text := range input.Texts { + normalizedTexts[i] = normalizeText(text) + } + return schemas.EmbeddingInput{Texts: normalizedTexts} + } + return input + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + return req.TranscriptionRequest.Input + default: + return nil + } } // removeField removes the first occurrence of target from the slice. 
@@ -446,12 +548,303 @@ func removeField(arr []string, target string) []string { return arr // unchanged if target not found } -// isConversationHistoryThresholdExceeded checks if the conversation history threshold is exceeded +// extractChatParametersToMetadata extracts Chat API parameters into metadata map +func (plugin *Plugin) extractChatParametersToMetadata(params *schemas.ChatParameters, metadata map[string]interface{}) { + if params.ToolChoice != nil { + if params.ToolChoice.ChatToolChoiceStr != nil { + metadata["tool_choice"] = *params.ToolChoice.ChatToolChoiceStr + } else if params.ToolChoice.ChatToolChoiceStruct != nil && params.ToolChoice.ChatToolChoiceStruct.Function.Name != "" { + metadata["tool_choice"] = params.ToolChoice.ChatToolChoiceStruct.Function.Name + } + } + if params.Temperature != nil { + metadata["temperature"] = *params.Temperature + } + if params.TopP != nil { + metadata["top_p"] = *params.TopP + } + if params.MaxCompletionTokens != nil { + metadata["max_tokens"] = *params.MaxCompletionTokens + } + if params.Stop != nil { + metadata["stop_sequences"] = *params.Stop + } + if params.PresencePenalty != nil { + metadata["presence_penalty"] = *params.PresencePenalty + } + if params.FrequencyPenalty != nil { + metadata["frequency_penalty"] = *params.FrequencyPenalty + } + if params.ParallelToolCalls != nil { + metadata["parallel_tool_calls"] = *params.ParallelToolCalls + } + if params.User != nil { + metadata["user"] = *params.User + } + if params.LogitBias != nil { + metadata["logit_bias"] = *params.LogitBias + } + if params.LogProbs != nil { + metadata["logprobs"] = *params.LogProbs + } + if params.Modalities != nil { + metadata["modalities"] = *params.Modalities + } + if params.PromptCacheKey != nil { + metadata["prompt_cache_key"] = *params.PromptCacheKey + } + if params.ReasoningEffort != nil { + metadata["reasoning_effort"] = *params.ReasoningEffort + } + if params.ResponseFormat != nil { + metadata["response_format"] = 
params.ResponseFormat + } + if params.SafetyIdentifier != nil { + metadata["safety_identifier"] = *params.SafetyIdentifier + } + if params.Seed != nil { + metadata["seed"] = *params.Seed + } + if params.ServiceTier != nil { + metadata["service_tier"] = *params.ServiceTier + } + if params.Store != nil { + metadata["store"] = *params.Store + } + if params.TopLogProbs != nil { + metadata["top_logprobs"] = *params.TopLogProbs + } + if params.Verbosity != nil { + metadata["verbosity"] = *params.Verbosity + } + if len(params.ExtraParams) > 0 { + maps.Copy(metadata, params.ExtraParams) + } + if len(params.Tools) > 0 { + if toolsJSON, err := json.Marshal(params.Tools); err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to marshal tools for metadata: %v", PluginLoggerPrefix, err)) + } else { + toolHash := xxhash.Sum64(toolsJSON) + metadata["tools_hash"] = fmt.Sprintf("%x", toolHash) + } + } +} + +// extractResponsesParametersToMetadata extracts Responses API parameters into metadata map +func (plugin *Plugin) extractResponsesParametersToMetadata(params *schemas.ResponsesParameters, metadata map[string]interface{}) { + if params.ToolChoice != nil { + if params.ToolChoice.ResponsesToolChoiceStr != nil { + metadata["tool_choice"] = *params.ToolChoice.ResponsesToolChoiceStr + } else if params.ToolChoice.ResponsesToolChoiceStruct != nil && params.ToolChoice.ResponsesToolChoiceStruct.Name != nil { + metadata["tool_choice"] = *params.ToolChoice.ResponsesToolChoiceStruct.Name + } + } + if params.Temperature != nil { + metadata["temperature"] = *params.Temperature + } + if params.TopP != nil { + metadata["top_p"] = *params.TopP + } + if params.MaxOutputTokens != nil { + metadata["max_tokens"] = *params.MaxOutputTokens + } + if params.ParallelToolCalls != nil { + metadata["parallel_tool_calls"] = *params.ParallelToolCalls + } + if params.Background != nil { + metadata["background"] = *params.Background + } + if params.Conversation != nil { + metadata["conversation"] = 
*params.Conversation + } + if params.Include != nil { + metadata["include"] = *params.Include + } + if params.Instructions != nil { + metadata["instructions"] = *params.Instructions + } + if params.MaxToolCalls != nil { + metadata["max_tool_calls"] = *params.MaxToolCalls + } + if params.PreviousResponseID != nil { + metadata["previous_response_id"] = *params.PreviousResponseID + } + if params.PromptCacheKey != nil { + metadata["prompt_cache_key"] = *params.PromptCacheKey + } + if params.Reasoning != nil { + if params.Reasoning.Effort != nil { + metadata["reasoning_effort"] = *params.Reasoning.Effort + } + if params.Reasoning.Summary != nil { + metadata["reasoning_summary"] = *params.Reasoning.Summary + } + } + if params.SafetyIdentifier != nil { + metadata["safety_identifier"] = *params.SafetyIdentifier + } + if params.ServiceTier != nil { + metadata["service_tier"] = *params.ServiceTier + } + if params.Store != nil { + metadata["store"] = *params.Store + } + if params.Text != nil { + if params.Text.Verbosity != nil { + metadata["text_verbosity"] = *params.Text.Verbosity + } + if params.Text.Format != nil { + metadata["text_format_type"] = params.Text.Format.Type + } + } + if params.TopLogProbs != nil { + metadata["top_logprobs"] = *params.TopLogProbs + } + if params.Truncation != nil { + metadata["truncation"] = *params.Truncation + } + if len(params.ExtraParams) > 0 { + maps.Copy(metadata, params.ExtraParams) + } + if len(params.Tools) > 0 { + if toolsJSON, err := json.Marshal(params.Tools); err != nil { + plugin.logger.Warn(fmt.Sprintf("%s Failed to marshal tools for metadata: %v", PluginLoggerPrefix, err)) + } else { + toolHash := xxhash.Sum64(toolsJSON) + metadata["tools_hash"] = fmt.Sprintf("%x", toolHash) + } + } +} + +// extractTextCompletionParametersToMetadata extracts Text Completion parameters into metadata map +func (plugin *Plugin) extractTextCompletionParametersToMetadata(params *schemas.TextCompletionParameters, metadata map[string]interface{}) { + 
if params.Temperature != nil { + metadata["temperature"] = *params.Temperature + } + if params.TopP != nil { + metadata["top_p"] = *params.TopP + } + if params.MaxTokens != nil { + metadata["max_tokens"] = *params.MaxTokens + } + if params.Stop != nil { + metadata["stop_sequences"] = *params.Stop + } + if params.PresencePenalty != nil { + metadata["presence_penalty"] = *params.PresencePenalty + } + if params.FrequencyPenalty != nil { + metadata["frequency_penalty"] = *params.FrequencyPenalty + } + if params.User != nil { + metadata["user"] = *params.User + } + if params.BestOf != nil { + metadata["best_of"] = *params.BestOf + } + if params.Echo != nil { + metadata["echo"] = *params.Echo + } + if params.LogitBias != nil { + metadata["logit_bias"] = *params.LogitBias + } + if params.LogProbs != nil { + metadata["logprobs"] = *params.LogProbs + } + if params.N != nil { + metadata["n"] = *params.N + } + if params.Seed != nil { + metadata["seed"] = *params.Seed + } + if params.Suffix != nil { + metadata["suffix"] = *params.Suffix + } + if len(params.ExtraParams) > 0 { + maps.Copy(metadata, params.ExtraParams) + } +} + +// extractSpeechParametersToMetadata extracts Speech parameters into metadata map +func (plugin *Plugin) extractSpeechParametersToMetadata(params *schemas.SpeechParameters, metadata map[string]interface{}) { + if params == nil { + return + } + + if params.Speed != nil { + metadata["speed"] = *params.Speed + } + if params.ResponseFormat != "" { + metadata["response_format"] = params.ResponseFormat + } + if params.Instructions != "" { + metadata["instructions"] = params.Instructions + } + // Check if VoiceConfig.Voice is non-nil before accessing it + if params.VoiceConfig.Voice != nil { + metadata["voice"] = *params.VoiceConfig.Voice + } + if len(params.VoiceConfig.MultiVoiceConfig) > 0 { + flattenedVC := make([]string, len(params.VoiceConfig.MultiVoiceConfig)) + for i, vc := range params.VoiceConfig.MultiVoiceConfig { + flattenedVC[i] = 
fmt.Sprintf("%s:%s", vc.Speaker, vc.Voice) + } + metadata["multi_voice_count"] = flattenedVC + } + if len(params.ExtraParams) > 0 { + maps.Copy(metadata, params.ExtraParams) + } +} + +// extractEmbeddingParametersToMetadata extracts Embedding parameters into metadata map +func (plugin *Plugin) extractEmbeddingParametersToMetadata(params *schemas.EmbeddingParameters, metadata map[string]interface{}) { + if params.EncodingFormat != nil { + metadata["encoding_format"] = *params.EncodingFormat + } + if params.Dimensions != nil { + metadata["dimensions"] = *params.Dimensions + } + if len(params.ExtraParams) > 0 { + maps.Copy(metadata, params.ExtraParams) + } +} + +// extractTranscriptionParametersToMetadata extracts Transcription parameters into metadata map +func (plugin *Plugin) extractTranscriptionParametersToMetadata(params *schemas.TranscriptionParameters, metadata map[string]interface{}) { + if params.Language != nil { + metadata["language"] = *params.Language + } + if params.ResponseFormat != nil { + metadata["response_format"] = *params.ResponseFormat + } + if params.Prompt != nil { + metadata["prompt"] = *params.Prompt + } + if params.Format != nil { + metadata["file_format"] = *params.Format + } + if len(params.ExtraParams) > 0 { + maps.Copy(metadata, params.ExtraParams) + } +} + func (plugin *Plugin) isConversationHistoryThresholdExceeded(req *schemas.BifrostRequest) bool { switch { - case req.Input.ChatCompletionInput != nil: - input := plugin.getInputForCaching(req) - if len(*input.ChatCompletionInput) > plugin.config.ConversationHistoryThreshold { + case req.ChatRequest != nil: + input, ok := plugin.getInputForCaching(req).([]schemas.ChatMessage) + if !ok { + return false + } + if len(input) > plugin.config.ConversationHistoryThreshold { + return true + } + return false + case req.ResponsesRequest != nil: + input, ok := plugin.getInputForCaching(req).([]schemas.ResponsesMessage) + if !ok { + return false + } + if len(input) > 
plugin.config.ConversationHistoryThreshold { return true } return false diff --git a/plugins/telemetry/main.go b/plugins/telemetry/main.go index ce99ceb8f8..494ef61ffb 100644 --- a/plugins/telemetry/main.go +++ b/plugins/telemetry/main.go @@ -81,13 +81,11 @@ func (p *PrometheusPlugin) PreHook(ctx *context.Context, req *schemas.BifrostReq // - Request latency // - Total request count func (p *PrometheusPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { - if result == nil { - return result, bifrostErr, nil - } + requestType, provider, model := bifrost.GetRequestFields(result, bifrostErr) - requestType, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) + startTime, ok := (*ctx).Value(startTimeKey).(time.Time) if !ok { - log.Println("Warning: request type not found in context for Prometheus PostHook") + log.Println("Warning: startTime not found in context for Prometheus PostHook") return result, bifrostErr, nil } @@ -107,41 +105,17 @@ func (p *PrometheusPlugin) PostHook(ctx *context.Context, result *schemas.Bifros // This is the final chunk - continue with metrics recording } - startTime, ok := (*ctx).Value(startTimeKey).(time.Time) - if !ok { - log.Println("Warning: startTime not found in context for Prometheus PostHook") - return result, bifrostErr, nil - } - - provider, ok := (*ctx).Value(schemas.BifrostContextKeyRequestProvider).(schemas.ModelProvider) - if !ok { - log.Println("Warning: provider not found in context for Prometheus PostHook") - return result, bifrostErr, nil - } - - model, ok := (*ctx).Value(schemas.BifrostContextKeyRequestModel).(string) - if !ok { - log.Println("Warning: model not found in context for Prometheus PostHook") - return result, bifrostErr, nil - } - - method, ok := (*ctx).Value(schemas.BifrostContextKeyRequestType).(schemas.RequestType) - if !ok { - log.Println("Warning: method not found in 
context for Prometheus PostHook") - return result, bifrostErr, nil - } - // Calculate cost and record metrics in a separate goroutine to avoid blocking the main thread go func() { cost := 0.0 - if p.pricingManager != nil { - cost = p.pricingManager.CalculateCostWithCacheDebug(result, provider, model, requestType) + if p.pricingManager != nil && result != nil { + cost = p.pricingManager.CalculateCostWithCacheDebug(result) } labelValues := map[string]string{ "provider": string(provider), "model": model, - "method": string(method), + "method": string(requestType), } // Get all prometheus labels from context @@ -167,29 +141,34 @@ func (p *PrometheusPlugin) PostHook(ctx *context.Context, result *schemas.Bifros // Record error and success counts if bifrostErr != nil { - p.ErrorRequestsTotal.WithLabelValues(promLabelValues...).Inc() + // Add reason to label values + errorLabelValues := append(promLabelValues[:3], bifrostErr.Error.Message) // provider, model, method, reason + errorLabelValues = append(errorLabelValues, promLabelValues[3:]...) 
// then custom labels + p.ErrorRequestsTotal.WithLabelValues(errorLabelValues...).Inc() } else { p.SuccessRequestsTotal.WithLabelValues(promLabelValues...).Inc() } - // Record input and output tokens - if result.Usage != nil { - p.InputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(result.Usage.PromptTokens)) - p.OutputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(result.Usage.CompletionTokens)) - } - - // Record cache hits with cache type - if result.ExtraFields.CacheDebug != nil && result.ExtraFields.CacheDebug.CacheHit { - cacheType := "unknown" - if result.ExtraFields.CacheDebug.HitType != nil { - cacheType = *result.ExtraFields.CacheDebug.HitType + if result != nil { + // Record input and output tokens + if result.Usage != nil { + p.InputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(result.Usage.PromptTokens)) + p.OutputTokensTotal.WithLabelValues(promLabelValues...).Add(float64(result.Usage.CompletionTokens)) } - // Add cache_type to label values - cacheHitLabelValues := append(promLabelValues[:3], cacheType) // provider, model, method, cache_type - cacheHitLabelValues = append(cacheHitLabelValues, promLabelValues[3:]...) // then custom labels + // Record cache hits with cache type + if result.ExtraFields.CacheDebug != nil && result.ExtraFields.CacheDebug.CacheHit { + cacheType := "unknown" + if result.ExtraFields.CacheDebug.HitType != nil { + cacheType = *result.ExtraFields.CacheDebug.HitType + } + + // Add cache_type to label values + cacheHitLabelValues := append(promLabelValues[:3], cacheType) // provider, model, method, cache_type + cacheHitLabelValues = append(cacheHitLabelValues, promLabelValues[3:]...) 
// then custom labels - p.CacheHitsTotal.WithLabelValues(cacheHitLabelValues...).Inc() + p.CacheHitsTotal.WithLabelValues(cacheHitLabelValues...).Inc() + } } }() diff --git a/plugins/telemetry/setup.go b/plugins/telemetry/setup.go index 70ae1ed4e2..ed7576d24b 100644 --- a/plugins/telemetry/setup.go +++ b/plugins/telemetry/setup.go @@ -136,7 +136,7 @@ func InitPrometheusMetrics(labels []string) { Name: "bifrost_error_requests_total", Help: "Total number of error requests forwarded to upstream providers by Bifrost.", }, - append(bifrostDefaultLabels, labels...), + append(append(bifrostDefaultLabels, "reason"), labels...), ) bifrostInputTokensTotal = promauto.NewCounterVec( diff --git a/tests/core-chatbot/main.go b/tests/core-chatbot/main.go index c98d410f56..b212383827 100644 --- a/tests/core-chatbot/main.go +++ b/tests/core-chatbot/main.go @@ -30,7 +30,7 @@ type ChatbotConfig struct { // ChatSession manages the conversation state type ChatSession struct { - history []schemas.BifrostMessage + history []schemas.ChatMessage client *bifrost.Bifrost config ChatbotConfig systemPrompt string @@ -280,7 +280,7 @@ func NewChatSession(config ChatbotConfig) (*ChatSession, error) { } session := &ChatSession{ - history: make([]schemas.BifrostMessage, 0), + history: make([]schemas.ChatMessage, 0), client: client, config: config, account: account, @@ -290,9 +290,9 @@ func NewChatSession(config ChatbotConfig) (*ChatSession, error) { // Add system message to history if session.systemPrompt != "" { - session.history = append(session.history, schemas.BifrostMessage{ + session.history = append(session.history, schemas.ChatMessage{ Role: schemas.ModelChatMessageRoleSystem, - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: &session.systemPrompt, }, }) @@ -457,9 +457,9 @@ func (s *ChatSession) showCurrentConfig() { // AddUserMessage adds a user message to the conversation history func (s *ChatSession) AddUserMessage(message string) { - userMessage := 
schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + userMessage := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: &message, }, } @@ -540,7 +540,7 @@ func (s *ChatSession) SendMessage(message string) (string, error) { } // handleToolCalls handles tool execution using the new Bifrost MCP integration -func (s *ChatSession) handleToolCalls(assistantMessage schemas.BifrostMessage) (string, error) { +func (s *ChatSession) handleToolCalls(assistantMessage schemas.ChatMessage) (string, error) { toolCalls := *assistantMessage.ToolCalls // Display tools to user for approval @@ -568,7 +568,7 @@ func (s *ChatSession) handleToolCalls(assistantMessage schemas.BifrostMessage) ( fmt.Println("✅ Executing tools...") // Execute each tool using Bifrost's ExecuteMCPTool method - toolResults := make([]schemas.BifrostMessage, 0) + toolResults := make([]schemas.ChatMessage, 0) for _, toolCall := range toolCalls { // Start loading animation for this tool stopChan, wg := startLoader() @@ -582,9 +582,9 @@ func (s *ChatSession) handleToolCalls(assistantMessage schemas.BifrostMessage) ( if err != nil { fmt.Printf("❌ Error executing tool %s: %v\n", *toolCall.Function.Name, err) // Create error message for this tool - errorResult := schemas.BifrostMessage{ + errorResult := schemas.ChatMessage{ Role: schemas.ModelChatMessageRoleTool, - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: stringPtr(fmt.Sprintf("Error executing tool: %v", err)), }, ToolMessage: &schemas.ToolMessage{ @@ -622,9 +622,9 @@ func (s *ChatSession) handleToolCalls(assistantMessage schemas.BifrostMessage) ( // synthesizeToolResults sends the conversation with tool results back to LLM for synthesis func (s *ChatSession) synthesizeToolResults() (string, error) { // Add synthesis prompt - synthesisPrompt := schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleUser, - Content: 
schemas.MessageContent{ + synthesisPrompt := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: stringPtr("Please provide a comprehensive response based on the tool results above."), }, } @@ -907,7 +907,7 @@ func main() { case "/clear": // Keep system prompt but clear conversation history systemPrompt := session.history[0] // Assuming first message is system - session.history = []schemas.BifrostMessage{systemPrompt} + session.history = []schemas.ChatMessage{systemPrompt} fmt.Println("🧹 Conversation history cleared!") continue case "/config": diff --git a/tests/core-providers/custom_test.go b/tests/core-providers/custom_test.go index 2a650cc06b..0cb14c7dc5 100644 --- a/tests/core-providers/custom_test.go +++ b/tests/core-providers/custom_test.go @@ -21,9 +21,9 @@ func TestCustomProvider(t *testing.T) { defer client.Shutdown() testConfig := config.ComprehensiveTestConfig{ - Provider: config.ProviderOpenAICustom, - ChatModel: "llama-3.3-70b-versatile", - TextModel: "", // OpenAI doesn't support text completion in newer models + Provider: config.ProviderOpenAICustom, + ChatModel: "llama-3.3-70b-versatile", + TextModel: "", // OpenAI doesn't support text completion in newer models EmbeddingModel: "", // groq custom base: embeddings not supported Scenarios: config.TestScenarios{ TextCompletion: false, // Not supported @@ -62,12 +62,11 @@ func TestCustomProvider_DisallowedOperation(t *testing.T) { defer cancel() defer client.Shutdown() - // Create a speech request to the custom provider prompt := "The future of artificial intelligence is" request := &schemas.BifrostRequest{ Provider: config.ProviderOpenAICustom, // Use the custom provider - Model: "llama-3.3-70b-versatile", // Use a model that exists for this provider + Model: "llama-3.3-70b-versatile", // Use a model that exists for this provider Input: schemas.RequestInput{ SpeechInput: &schemas.SpeechInput{ Input: prompt, @@ -107,10 +106,10 @@ func 
TestCustomProvider_MismatchedIdentity(t *testing.T) { Provider: wrongProvider, Model: "llama-3.3-70b-versatile", Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{ + ChatCompletionInput: &[]schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr("Hello! What's the capital of France?"), }, }, diff --git a/tests/core-providers/openai_test.go b/tests/core-providers/openai_test.go index 69142fe41f..4321f18255 100644 --- a/tests/core-providers/openai_test.go +++ b/tests/core-providers/openai_test.go @@ -22,7 +22,8 @@ func TestOpenAI(t *testing.T) { TextModel: "", // OpenAI doesn't support text completion in newer models EmbeddingModel: "text-embedding-3-small", TranscriptionModel: "whisper-1", - SpeechSynthesisModel: "tts-1", + SpeechSynthesisModel: "gpt-4o-mini-tts", + ReasoningModel: "gpt-5", Scenarios: config.TestScenarios{ TextCompletion: false, // Not supported SimpleChat: true, diff --git a/tests/core-providers/scenarios/automatic_function_calling.go b/tests/core-providers/scenarios/automatic_function_calling.go index 08d9b41916..8d331bda57 100644 --- a/tests/core-providers/scenarios/automatic_function_calling.go +++ b/tests/core-providers/scenarios/automatic_function_calling.go @@ -19,7 +19,7 @@ func RunAutomaticFunctionCallingTest(t *testing.T, client *bifrost.Bifrost, ctx } t.Run("AutomaticFunctionCalling", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateBasicChatMessage("Get the current time in UTC timezone"), } diff --git a/tests/core-providers/scenarios/chat_completion_stream.go b/tests/core-providers/scenarios/chat_completion_stream.go index 672a5fd85f..e77b9805fb 100644 --- a/tests/core-providers/scenarios/chat_completion_stream.go +++ b/tests/core-providers/scenarios/chat_completion_stream.go @@ -22,7 +22,7 @@ func 
RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont } t.Run("ChatCompletionStream", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateBasicChatMessage("Tell me a short story about a robot learning to paint. Keep it under 200 words."), } @@ -191,7 +191,7 @@ func RunChatCompletionStreamTest(t *testing.T, client *bifrost.Bifrost, ctx cont // Test streaming with tool calls if supported if testConfig.Scenarios.ToolCalls { t.Run("ChatCompletionStreamWithTools", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateBasicChatMessage("What's the weather like in San Francisco? Please use the get_weather function."), } diff --git a/tests/core-providers/scenarios/complete_end_to_end.go b/tests/core-providers/scenarios/complete_end_to_end.go index 880485a066..0d10a785ad 100644 --- a/tests/core-providers/scenarios/complete_end_to_end.go +++ b/tests/core-providers/scenarios/complete_end_to_end.go @@ -27,7 +27,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C Provider: testConfig.Provider, Model: testConfig.ChatModel, Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{userMessage1}, + ChatCompletionInput: &[]schemas.ChatMessage{userMessage1}, }, Params: MergeModelParameters(&schemas.ModelParameters{ Tools: &[]schemas.Tool{WeatherToolDefinition}, @@ -44,7 +44,7 @@ func RunCompleteEnd2EndTest(t *testing.T, client *bifrost.Bifrost, ctx context.C t.Logf("✅ First response: %s", GetResultContent(response1)) // If tool was called, simulate result and continue conversation - var conversationHistory []schemas.BifrostMessage + var conversationHistory []schemas.ChatMessage conversationHistory = append(conversationHistory, userMessage1) // Add all choice messages to conversation history diff --git a/tests/core-providers/scenarios/end_to_end_tool_calling.go 
b/tests/core-providers/scenarios/end_to_end_tool_calling.go index 9995b61cc7..d93844f7bd 100644 --- a/tests/core-providers/scenarios/end_to_end_tool_calling.go +++ b/tests/core-providers/scenarios/end_to_end_tool_calling.go @@ -33,7 +33,7 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex Provider: testConfig.Provider, Model: testConfig.ChatModel, Input: schemas.RequestInput{ - ChatCompletionInput: &[]schemas.BifrostMessage{userMessage}, + ChatCompletionInput: &[]schemas.ChatMessage{userMessage}, }, Params: params, Fallbacks: testConfig.Fallbacks, @@ -78,7 +78,7 @@ func RunEnd2EndToolCallingTest(t *testing.T, client *bifrost.Bifrost, ctx contex require.NotEmpty(t, toolCallID, "toolCallID must not be empty") // Build conversation history with all choice messages from first response - conversationMessages := []schemas.BifrostMessage{ + conversationMessages := []schemas.ChatMessage{ userMessage, } diff --git a/tests/core-providers/scenarios/image_base64.go b/tests/core-providers/scenarios/image_base64.go index b10655d6f0..f7d97a2210 100644 --- a/tests/core-providers/scenarios/image_base64.go +++ b/tests/core-providers/scenarios/image_base64.go @@ -20,7 +20,7 @@ func RunImageBase64Test(t *testing.T, client *bifrost.Bifrost, ctx context.Conte } t.Run("ImageBase64", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateImageMessage("Describe this image briefly", TestImageBase64), } diff --git a/tests/core-providers/scenarios/image_url.go b/tests/core-providers/scenarios/image_url.go index 9f11d63289..a0ae9c841a 100644 --- a/tests/core-providers/scenarios/image_url.go +++ b/tests/core-providers/scenarios/image_url.go @@ -21,7 +21,7 @@ func RunImageURLTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, } t.Run("ImageURL", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateImageMessage("What do you see in this image?", 
TestImageURL), } diff --git a/tests/core-providers/scenarios/multi_turn_conversation.go b/tests/core-providers/scenarios/multi_turn_conversation.go index 10512578bc..e246097911 100644 --- a/tests/core-providers/scenarios/multi_turn_conversation.go +++ b/tests/core-providers/scenarios/multi_turn_conversation.go @@ -23,7 +23,7 @@ func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx con t.Run("MultiTurnConversation", func(t *testing.T) { // First message userMessage1 := CreateBasicChatMessage("My name is Alice. Remember this.") - messages1 := []schemas.BifrostMessage{ + messages1 := []schemas.ChatMessage{ userMessage1, } @@ -46,7 +46,7 @@ func RunMultiTurnConversationTest(t *testing.T, client *bifrost.Bifrost, ctx con // Second message with conversation history // Build conversation history with all choice messages - messages2 := []schemas.BifrostMessage{ + messages2 := []schemas.ChatMessage{ userMessage1, } diff --git a/tests/core-providers/scenarios/multiple_images.go b/tests/core-providers/scenarios/multiple_images.go index ba8d70a2b4..77e5300109 100644 --- a/tests/core-providers/scenarios/multiple_images.go +++ b/tests/core-providers/scenarios/multiple_images.go @@ -20,10 +20,10 @@ func RunMultipleImagesTest(t *testing.T, client *bifrost.Bifrost, ctx context.Co } t.Run("MultipleImages", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ { - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentBlocks: &[]schemas.ContentBlock{ { Type: schemas.ContentBlockTypeText, diff --git a/tests/core-providers/scenarios/multiple_tool_calls.go b/tests/core-providers/scenarios/multiple_tool_calls.go index dd65e663c3..75dea144a7 100644 --- a/tests/core-providers/scenarios/multiple_tool_calls.go +++ b/tests/core-providers/scenarios/multiple_tool_calls.go @@ -30,7 +30,7 @@ func RunMultipleToolCallsTest(t 
*testing.T, client *bifrost.Bifrost, ctx context } t.Run("MultipleToolCalls", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateBasicChatMessage("I need to know the weather in London and also calculate 15 * 23. Can you help with both?"), } diff --git a/tests/core-providers/scenarios/provider_specific.go b/tests/core-providers/scenarios/provider_specific.go index 0c4b3bff0c..189563b92f 100644 --- a/tests/core-providers/scenarios/provider_specific.go +++ b/tests/core-providers/scenarios/provider_specific.go @@ -22,7 +22,7 @@ func RunProviderSpecificTest(t *testing.T, client *bifrost.Bifrost, ctx context. t.Run("ProviderSpecific", func(t *testing.T) { // This would contain provider-specific tests // For now, we'll do a basic functionality test - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateBasicChatMessage("Test provider-specific functionality. What makes you unique?"), } diff --git a/tests/core-providers/scenarios/simple_chat.go b/tests/core-providers/scenarios/simple_chat.go index 5665e4fc63..c2940316e8 100644 --- a/tests/core-providers/scenarios/simple_chat.go +++ b/tests/core-providers/scenarios/simple_chat.go @@ -20,7 +20,7 @@ func RunSimpleChatTest(t *testing.T, client *bifrost.Bifrost, ctx context.Contex } t.Run("SimpleChat", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateBasicChatMessage("Hello! 
What's the capital of France?"), } diff --git a/tests/core-providers/scenarios/tool_calls.go b/tests/core-providers/scenarios/tool_calls.go index 486b1e094a..fddd955c6f 100644 --- a/tests/core-providers/scenarios/tool_calls.go +++ b/tests/core-providers/scenarios/tool_calls.go @@ -20,7 +20,7 @@ func RunToolCallsTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context } t.Run("ToolCalls", func(t *testing.T) { - messages := []schemas.BifrostMessage{ + messages := []schemas.ChatMessage{ CreateBasicChatMessage("What's the weather like in New York? answer in celsius"), } diff --git a/tests/core-providers/scenarios/utils.go b/tests/core-providers/scenarios/utils.go index 27937762f9..098797437e 100644 --- a/tests/core-providers/scenarios/utils.go +++ b/tests/core-providers/scenarios/utils.go @@ -146,19 +146,19 @@ func CreateTranscriptionInput(audioData []byte, language, responseFormat *string } // Helper functions for creating requests -func CreateBasicChatMessage(content string) schemas.BifrostMessage { - return schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ +func CreateBasicChatMessage(content string) schemas.ChatMessage { + return schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr(content), }, } } -func CreateImageMessage(text, imageURL string) schemas.BifrostMessage { - return schemas.BifrostMessage{ - Role: schemas.ModelChatMessageRoleUser, - Content: schemas.MessageContent{ +func CreateImageMessage(text, imageURL string) schemas.ChatMessage { + return schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: schemas.ChatMessageContent{ ContentBlocks: &[]schemas.ContentBlock{ { Type: schemas.ContentBlockTypeText, @@ -175,10 +175,10 @@ func CreateImageMessage(text, imageURL string) schemas.BifrostMessage { } } -func CreateToolMessage(content string, toolCallID string) schemas.BifrostMessage { - return schemas.BifrostMessage{ 
+func CreateToolMessage(content string, toolCallID string) schemas.ChatMessage { + return schemas.ChatMessage{ Role: schemas.ModelChatMessageRoleTool, - Content: schemas.MessageContent{ + Content: schemas.ChatMessageContent{ ContentStr: bifrost.Ptr(content), }, ToolMessage: &schemas.ToolMessage{ diff --git a/transports/bifrost-http/handlers/completions.go b/transports/bifrost-http/handlers/completions.go index 7333c90742..d12d9a35a5 100644 --- a/transports/bifrost-http/handlers/completions.go +++ b/transports/bifrost-http/handlers/completions.go @@ -39,138 +39,195 @@ func NewCompletionHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore } // Known fields for CompletionRequest -var completionRequestKnownFields = map[string]bool{ - "model": true, - "messages": true, - "text": true, - "fallbacks": true, - "stream": true, - "input": true, - "voice": true, - "instructions": true, - "response_format": true, - "stream_format": true, - "tool_choice": true, - "tools": true, - "temperature": true, - "top_p": true, - "top_k": true, - "max_tokens": true, - "stop_sequences": true, - "presence_penalty": true, - "frequency_penalty": true, - "parallel_tool_calls": true, - "encoding_format": true, - "dimensions": true, - "user": true, +var textParamsKnownFields = map[string]bool{ + "model": true, + "text": true, + "fallbacks": true, + "best_of": true, + "echo": true, + "frequency_penalty": true, + "logit_bias": true, + "logprobs": true, + "max_tokens": true, + "n": true, + "presence_penalty": true, + "seed": true, + "stop": true, + "suffix": true, + "temperature": true, + "top_p": true, + "user": true, } -// CompletionRequest represents a request for either text or chat completion -type CompletionRequest struct { - Model string `json:"model"` // Model to use in "provider/model" format - Messages []schemas.BifrostMessage `json:"messages"` // Chat messages (for chat completion) - Text string `json:"text"` // Text input (for text completion) - Fallbacks []string 
`json:"fallbacks"` // Fallback providers and models in "provider/model" format - Stream *bool `json:"stream"` // Whether to stream the response - - // Speech inputs - Input schemas.EmbeddingInput `json:"input"` // string can be used for voice input as well - Voice schemas.SpeechVoiceInput `json:"voice"` - Instructions string `json:"instructions"` - ResponseFormat string `json:"response_format"` - StreamFormat *string `json:"stream_format,omitempty"` - - ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool - Tools *[]schemas.Tool `json:"tools,omitempty"` // Tools to use - Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output - TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling - TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling - MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation - PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens - ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls - EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") - Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output - User *string `json:"user,omitempty"` // User identifier for tracking - - // Dynamic parameters that can be provider-specific, they are directly - // added to the request as is. 
- ExtraParams map[string]interface{} `json:"-"` +// Known fields for CompletionRequest +var chatParamsKnownFields = map[string]bool{ + "model": true, + "messages": true, + "fallbacks": true, + "stream": true, + "frequency_penalty": true, + "logit_bias": true, + "logprobs": true, + "max_completion_tokens": true, + "metadata": true, + "modalities": true, + "parallel_tool_calls": true, + "presence_penalty": true, + "prompt_cache_key": true, + "reasoning_effort": true, + "response_format": true, + "safety_identifier": true, + "service_tier": true, + "stream_options": true, + "store": true, + "temperature": true, + "tool_choice": true, + "tools": true, + "truncation": true, + "user": true, + "verbosity": true, } -func (cr *CompletionRequest) UnmarshalJSON(data []byte) error { - // Use type alias to avoid infinite recursion - type Alias CompletionRequest - aux := (*Alias)(cr) +var responsesParamsKnownFields = map[string]bool{ + "model": true, + "input": true, + "fallbacks": true, + "stream": true, + "background": true, + "conversation": true, + "include": true, + "instructions": true, + "max_output_tokens": true, + "max_tool_calls": true, + "metadata": true, + "parallel_tool_calls": true, + "previous_response_id": true, + "prompt_cache_key": true, + "reasoning": true, + "safety_identifier": true, + "service_tier": true, + "stream_options": true, + "store": true, + "temperature": true, + "text": true, + "top_logprobs": true, + "top_p": true, + "tool_choice": true, + "tools": true, + "truncation": true, +} - // First unmarshal known fields - if err := sonic.Unmarshal(data, aux); err != nil { - return err - } +var embeddingParamsKnownFields = map[string]bool{ + "model": true, + "input": true, + "fallbacks": true, + "encoding_format": true, + "dimensions": true, +} - // Then unmarshal to map for unknown fields - var rawData map[string]json.RawMessage - if err := sonic.Unmarshal(data, &rawData); err != nil { - return err +var speechParamsKnownFields = map[string]bool{ + 
"model": true, + "input": true, + "fallbacks": true, + "stream_format": true, + "voice": true, + "instructions": true, + "response_format": true, + "speed": true, +} + +var transcriptionParamsKnownFields = map[string]bool{ + "model": true, + "file": true, + "fallbacks": true, + "stream": true, + "language": true, + "prompt": true, + "response_format": true, + "file_format": true, +} + +type BifrostParams struct { + Model string `json:"model"` // Model to use in "provider/model" format + Fallbacks []string `json:"fallbacks"` // Fallback providers and models in "provider/model" format + Stream *bool `json:"stream"` // Whether to stream the response + StreamFormat *string `json:"stream_format,omitempty"` // For speech +} + +type TextRequest struct { + Prompt schemas.TextCompletionInput `json:"prompt"` + BifrostParams + *schemas.TextCompletionParameters +} + +type ChatRequest struct { + Messages []schemas.ChatMessage `json:"messages"` + BifrostParams + *schemas.ChatParameters +} + +type ResponsesRequest struct { + Input []schemas.ResponsesMessage `json:"input"` + BifrostParams + *schemas.ResponsesParameters +} + +type EmbeddingRequest struct { + Input schemas.EmbeddingInput `json:"input"` + BifrostParams + *schemas.EmbeddingParameters +} + +type SpeechRequest struct { + *schemas.SpeechInput + BifrostParams + *schemas.SpeechParameters +} + +type TranscriptionRequest struct { + *schemas.TranscriptionInput + BifrostParams + *schemas.TranscriptionParameters +} + +// Helper functions + +// parseFallbacks extracts fallbacks from string array and converts to Fallback structs +func parseFallbacks(fallbackStrings []string) ([]schemas.Fallback, error) { + fallbacks := make([]schemas.Fallback, len(fallbackStrings)) + for i, fallback := range fallbackStrings { + fallbackProvider, fallbackModelName := schemas.ParseModelString(fallback, "") + fallbacks[i] = schemas.Fallback{ + Provider: fallbackProvider, + Model: fallbackModelName, + } } + return fallbacks, nil +} - // Initialize 
ExtraParams - if cr.ExtraParams == nil { - cr.ExtraParams = make(map[string]interface{}) +// extractExtraParams processes unknown fields from JSON data into ExtraParams +func extractExtraParams(data []byte, knownFields map[string]bool) (map[string]interface{}, error) { + // Parse JSON to extract unknown fields + var rawData map[string]json.RawMessage + if err := json.Unmarshal(data, &rawData); err != nil { + return nil, err } // Extract unknown fields + extraParams := make(map[string]interface{}) for key, value := range rawData { - if !completionRequestKnownFields[key] { + if !knownFields[key] { var v interface{} - if err := sonic.Unmarshal(value, &v); err != nil { + if err := json.Unmarshal(value, &v); err != nil { continue // Skip fields that can't be unmarshaled } - cr.ExtraParams[key] = v + extraParams[key] = v } } - return nil + return extraParams, nil } -func (cr *CompletionRequest) GetModelParameters() *schemas.ModelParameters { - params := &schemas.ModelParameters{ - ExtraParams: make(map[string]interface{}), - ToolChoice: cr.ToolChoice, - Tools: cr.Tools, - Temperature: cr.Temperature, - TopP: cr.TopP, - TopK: cr.TopK, - MaxTokens: cr.MaxTokens, - StopSequences: cr.StopSequences, - PresencePenalty: cr.PresencePenalty, - FrequencyPenalty: cr.FrequencyPenalty, - ParallelToolCalls: cr.ParallelToolCalls, - EncodingFormat: cr.EncodingFormat, - Dimensions: cr.Dimensions, - User: cr.User, - } - - if cr.ExtraParams != nil { - for k, v := range cr.ExtraParams { - params.ExtraParams[k] = v - } - } - - return params -} - -type CompletionType string - -const ( - CompletionTypeText CompletionType = "text" - CompletionTypeChat CompletionType = "chat" - CompletionTypeEmbeddings CompletionType = "embeddings" - CompletionTypeSpeech CompletionType = "speech" - CompletionTypeTranscription CompletionType = "transcription" -) - const ( // Maximum file size (25MB) MaxFileSize = 25 * 1024 * 1024 @@ -186,115 +243,356 @@ const ( AudioMimeFLAC2 = "audio/x-flac" // Alternative FLAC 
) -// validateAudioFile checks if the file size and format are valid -func (h *CompletionHandler) validateAudioFile(fileHeader *multipart.FileHeader) error { - // Check file size - if fileHeader.Size > MaxFileSize { - return fmt.Errorf("file size exceeds maximum limit of %d MB", MaxFileSize/1024/1024) +// RegisterRoutes registers all completion-related routes +func (h *CompletionHandler) RegisterRoutes(r *router.Router) { + // Completion endpoints + r.POST("/v1/completions", h.textCompletion) + r.POST("/v1/chat/completions", h.chatCompletion) + r.POST("/v1/responses", h.responses) + r.POST("/v1/embeddings", h.embeddings) + r.POST("/v1/audio/speech", h.speech) + r.POST("/v1/audio/transcriptions", h.transcription) +} + +// textCompletion handles POST /v1/completions - Process text completion requests +func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { + var req TextRequest + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return } - // Get file extension - ext := strings.ToLower(filepath.Ext(fileHeader.Filename)) + // Create BifrostTextCompletionRequest directly using segregated structure + provider, modelName := schemas.ParseModelString(req.Model, "") - // Check file extension - validExtensions := map[string]bool{ - ".flac": true, - ".mp3": true, - ".mp4": true, - ".mpeg": true, - ".mpga": true, - ".m4a": true, - ".ogg": true, - ".wav": true, - ".webm": true, + // Parse fallbacks using helper function + fallbacks, err := parseFallbacks(req.Fallbacks) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error(), h.logger) + return } - if !validExtensions[ext] { - return fmt.Errorf("unsupported file format: %s. 
Supported formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", ext) + if req.Prompt.PromptStr == nil && req.Prompt.PromptArray == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Text is required for text completion", h.logger) + return } - // Open file to check MIME type - file, err := fileHeader.Open() + // Extract extra params + if req.TextCompletionParameters == nil { + req.TextCompletionParameters = &schemas.TextCompletionParameters{} + } + + extraParams, err := extractExtraParams(ctx.PostBody(), textParamsKnownFields) if err != nil { - return fmt.Errorf("failed to open file: %v", err) + h.logger.Warn(fmt.Sprintf("Failed to extract extra params: %v", err)) + } else { + req.TextCompletionParameters.ExtraParams = extraParams } - defer file.Close() - // Read first 512 bytes for MIME type detection - buffer := make([]byte, 512) - _, err = file.Read(buffer) - if err != nil && err != io.EOF { - return fmt.Errorf("failed to read file header: %v", err) + // Create segregated BifrostTextCompletionRequest + bifrostTextReq := &schemas.BifrostTextCompletionRequest{ + Provider: schemas.ModelProvider(provider), + Model: modelName, + Input: req.Prompt, + Params: req.TextCompletionParameters, + Fallbacks: fallbacks, } - // Check MIME type - mimeType := http.DetectContentType(buffer) - validMimeTypes := map[string]bool{ - // Primary MIME types - AudioMimeMP3: true, // Covers MP3, MPEG, MPGA - AudioMimeMP4: true, - AudioMimeM4A: true, - AudioMimeOGG: true, - AudioMimeWAV: true, - AudioMimeWEBM: true, - AudioMimeFLAC: true, - AudioMimeFLAC2: true, + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } - // Alternative MIME types - "audio/mpeg3": true, - "audio/x-wav": true, - "audio/vnd.wave": true, - "audio/x-mpeg": true, - "audio/x-mpeg3": true, - "audio/x-mpg": true, - 
"audio/x-mpegaudio": true, + resp, bifrostErr := h.client.TextCompletionRequest(*bifrostCtx, bifrostTextReq) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return } - if !validMimeTypes[mimeType] { - return fmt.Errorf("invalid file type: %s. Supported audio formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", mimeType) + // Send successful response + SendJSON(ctx, resp, h.logger) +} + +// chatCompletion handles POST /v1/chat/completions - Process chat completion requests +func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { + var req ChatRequest + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return } - // Reset file pointer for subsequent reads - _, err = file.Seek(0, 0) + // Create BifrostChatRequest directly using segregated structure + provider, modelName := schemas.ParseModelString(req.Model, "") + + // Parse fallbacks using helper function + fallbacks, err := parseFallbacks(req.Fallbacks) if err != nil { - return fmt.Errorf("failed to reset file pointer: %v", err) + SendError(ctx, fasthttp.StatusBadRequest, err.Error(), h.logger) + return } - return nil -} + if len(req.Messages) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Messages is required for chat completion", h.logger) + return + } -// RegisterRoutes registers all completion-related routes -func (h *CompletionHandler) RegisterRoutes(r *router.Router) { - // Completion endpoints - r.POST("/v1/text/completions", h.textCompletion) - r.POST("/v1/chat/completions", h.chatCompletion) - r.POST("/v1/embeddings", h.embeddings) - r.POST("/v1/audio/speech", h.speechCompletion) - r.POST("/v1/audio/transcriptions", h.transcriptionCompletion) -} + // Extract extra params + if req.ChatParameters == nil { + req.ChatParameters = &schemas.ChatParameters{} + } -// textCompletion handles POST /v1/text/completions - Process text completion requests 
-func (h *CompletionHandler) textCompletion(ctx *fasthttp.RequestCtx) { - h.handleRequest(ctx, CompletionTypeText) + extraParams, err := extractExtraParams(ctx.PostBody(), chatParamsKnownFields) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to extract extra params: %v", err)) + } else { + req.ChatParameters.ExtraParams = extraParams + } + + // Create segregated BifrostChatRequest + bifrostChatReq := &schemas.BifrostChatRequest{ + Provider: schemas.ModelProvider(provider), + Model: modelName, + Input: req.Messages, + Params: req.ChatParameters, + Fallbacks: fallbacks, + } + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + if req.Stream != nil && *req.Stream { + h.handleStreamingChatCompletion(ctx, bifrostChatReq, bifrostCtx) + return + } + + resp, bifrostErr := h.client.ChatCompletionRequest(*bifrostCtx, bifrostChatReq) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + // Send successful response + SendJSON(ctx, resp, h.logger) } -// chatCompletion handles POST /v1/chat/completions - Process chat completion requests -func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { - h.handleRequest(ctx, CompletionTypeChat) +// responses handles POST /v1/responses - Process responses requests +func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { + var req ResponsesRequest + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return + } + + // Create BifrostResponsesRequest directly using segregated structure + provider, modelName := schemas.ParseModelString(req.Model, "") + + // Parse fallbacks using helper function + fallbacks, err := parseFallbacks(req.Fallbacks) + if err != nil { + SendError(ctx, 
fasthttp.StatusBadRequest, err.Error(), h.logger) + return + } + + if len(req.Input) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Input is required for responses", h.logger) + return + } + + // Extract extra params + if req.ResponsesParameters == nil { + req.ResponsesParameters = &schemas.ResponsesParameters{} + } + + extraParams, err := extractExtraParams(ctx.PostBody(), responsesParamsKnownFields) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to extract extra params: %v", err)) + } else { + req.ResponsesParameters.ExtraParams = extraParams + } + + // Create segregated BifrostResponsesRequest + bifrostResponsesReq := &schemas.BifrostResponsesRequest{ + Provider: schemas.ModelProvider(provider), + Model: modelName, + Input: req.Input, + Params: req.ResponsesParameters, + Fallbacks: fallbacks, + } + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + if req.Stream != nil && *req.Stream { + h.handleStreamingResponses(ctx, bifrostResponsesReq, bifrostCtx) + return + } + + resp, bifrostErr := h.client.ResponsesRequest(*bifrostCtx, bifrostResponsesReq) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + // Send successful response + SendJSON(ctx, resp, h.logger) } // embeddings handles POST /v1/embeddings - Process embeddings requests func (h *CompletionHandler) embeddings(ctx *fasthttp.RequestCtx) { - h.handleRequest(ctx, CompletionTypeEmbeddings) + var req EmbeddingRequest + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) + return + } + + // Create BifrostEmbeddingRequest directly using segregated structure + provider, modelName := schemas.ParseModelString(req.Model, "") + + // Parse fallbacks using helper function 
+ fallbacks, err := parseFallbacks(req.Fallbacks) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error(), h.logger) + return + } + + if req.Input.Text == nil && req.Input.Texts == nil && req.Input.Embedding == nil && req.Input.Embeddings == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Input is required for embeddings", h.logger) + return + } + + // Extract extra params + if req.EmbeddingParameters == nil { + req.EmbeddingParameters = &schemas.EmbeddingParameters{} + } + + extraParams, err := extractExtraParams(ctx.PostBody(), embeddingParamsKnownFields) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to extract extra params: %v", err)) + } else { + req.EmbeddingParameters.ExtraParams = extraParams + } + + // Create segregated BifrostEmbeddingRequest + bifrostEmbeddingReq := &schemas.BifrostEmbeddingRequest{ + Provider: schemas.ModelProvider(provider), + Model: modelName, + Input: req.Input, + Params: req.EmbeddingParameters, + Fallbacks: fallbacks, + } + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + resp, bifrostErr := h.client.EmbeddingRequest(*bifrostCtx, bifrostEmbeddingReq) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + // Send successful response + SendJSON(ctx, resp, h.logger) } -// speechCompletion handles POST /v1/audio/speech - Process speech completion requests -func (h *CompletionHandler) speechCompletion(ctx *fasthttp.RequestCtx) { - h.handleRequest(ctx, CompletionTypeSpeech) +// speech handles POST /v1/audio/speech - Process speech completion requests +func (h *CompletionHandler) speech(ctx *fasthttp.RequestCtx) { + var req SpeechRequest + if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", 
err), h.logger) + return + } + + // Create BifrostSpeechRequest directly using segregated structure + provider, modelName := schemas.ParseModelString(req.Model, "") + + // Parse fallbacks using helper function + fallbacks, err := parseFallbacks(req.Fallbacks) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error(), h.logger) + return + } + + if req.Input == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Input is required for speech completion", h.logger) + return + } + if req.VoiceConfig.Voice == nil && len(req.VoiceConfig.MultiVoiceConfig) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "Voice is required for speech completion", h.logger) + return + } + + // Extract extra params + if req.SpeechParameters == nil { + req.SpeechParameters = &schemas.SpeechParameters{} + } + + // Extract extra params + if req.SpeechParameters == nil { + req.SpeechParameters = &schemas.SpeechParameters{} + } + + extraParams, err := extractExtraParams(ctx.PostBody(), speechParamsKnownFields) + if err != nil { + h.logger.Warn(fmt.Sprintf("Failed to extract extra params: %v", err)) + } else { + req.SpeechParameters.ExtraParams = extraParams + } + + // Create segregated BifrostSpeechRequest + bifrostSpeechReq := &schemas.BifrostSpeechRequest{ + Provider: schemas.ModelProvider(provider), + Model: modelName, + Input: *req.SpeechInput, + Params: req.SpeechParameters, + Fallbacks: fallbacks, + } + + // Convert context + bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) + if bifrostCtx == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) + return + } + + if req.StreamFormat != nil && *req.StreamFormat == "sse" { + h.handleStreamingSpeech(ctx, bifrostSpeechReq, bifrostCtx) + return + } + + resp, bifrostErr := h.client.SpeechRequest(*bifrostCtx, bifrostSpeechReq) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr, h.logger) + return + } + + // Send successful response + if 
resp.Speech.Audio == nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Speech response is missing audio data", h.logger) + return + } + + ctx.Response.Header.Set("Content-Type", "audio/mpeg") + ctx.Response.Header.Set("Content-Disposition", "attachment; filename=speech.mp3") + ctx.Response.Header.Set("Content-Length", strconv.Itoa(len(resp.Speech.Audio))) + ctx.Response.SetBody(resp.Speech.Audio) } -// transcriptionCompletion handles POST /v1/audio/transcriptions - Process transcription requests -func (h *CompletionHandler) transcriptionCompletion(ctx *fasthttp.RequestCtx) { +// transcription handles POST /v1/audio/transcriptions - Process transcription requests +func (h *CompletionHandler) transcription(ctx *fasthttp.RequestCtx) { // Parse multipart form form, err := ctx.MultipartForm() if err != nil { @@ -309,11 +607,7 @@ func (h *CompletionHandler) transcriptionCompletion(ctx *fasthttp.RequestCtx) { return } - provider, modelName, err := ParseModel(modelValues[0]) - if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Model must be in the format of 'provider/model': %v", err), h.logger) - return - } + provider, modelName := schemas.ParseModelString(modelValues[0], "") // Extract file (required) fileHeaders := form.File["file"] @@ -349,26 +643,38 @@ func (h *CompletionHandler) transcriptionCompletion(ctx *fasthttp.RequestCtx) { File: fileData, } + // Create transcription parameters + transcriptionParams := &schemas.TranscriptionParameters{} + // Extract optional parameters if languageValues := form.Value["language"]; len(languageValues) > 0 && languageValues[0] != "" { - transcriptionInput.Language = &languageValues[0] + transcriptionParams.Language = &languageValues[0] } if promptValues := form.Value["prompt"]; len(promptValues) > 0 && promptValues[0] != "" { - transcriptionInput.Prompt = &promptValues[0] + transcriptionParams.Prompt = &promptValues[0] } if responseFormatValues := form.Value["response_format"]; 
len(responseFormatValues) > 0 && responseFormatValues[0] != "" { - transcriptionInput.ResponseFormat = &responseFormatValues[0] + transcriptionParams.ResponseFormat = &responseFormatValues[0] } - // Create BifrostRequest - bifrostReq := &schemas.BifrostRequest{ + if transcriptionParams.ExtraParams == nil { + transcriptionParams.ExtraParams = make(map[string]interface{}) + } + + for key, value := range form.Value { + if len(value) > 0 && value[0] != "" && !transcriptionParamsKnownFields[key] { + transcriptionParams.ExtraParams[key] = value[0] + } + } + + // Create BifrostTranscriptionRequest + bifrostTranscriptionReq := &schemas.BifrostTranscriptionRequest{ Model: modelName, Provider: schemas.ModelProvider(provider), - Input: schemas.RequestInput{ - TranscriptionInput: transcriptionInput, - }, + Input: *transcriptionInput, + Params: transcriptionParams, } // Convert context @@ -381,13 +687,13 @@ func (h *CompletionHandler) transcriptionCompletion(ctx *fasthttp.RequestCtx) { if streamValues := form.Value["stream"]; len(streamValues) > 0 && streamValues[0] != "" { stream := streamValues[0] if stream == "true" { - h.handleStreamingTranscriptionRequest(ctx, bifrostReq, bifrostCtx) + h.handleStreamingTranscriptionRequest(ctx, bifrostTranscriptionReq, bifrostCtx) return } } // Make transcription request - resp, bifrostErr := h.client.TranscriptionRequest(*bifrostCtx, bifrostReq) + resp, bifrostErr := h.client.TranscriptionRequest(*bifrostCtx, bifrostTranscriptionReq) // Handle response if bifrostErr != nil { @@ -399,150 +705,62 @@ func (h *CompletionHandler) transcriptionCompletion(ctx *fasthttp.RequestCtx) { SendJSON(ctx, resp, h.logger) } -// handleCompletion processes both text and chat completion requests -// It handles request parsing, validation, and response formatting -func (h *CompletionHandler) handleRequest(ctx *fasthttp.RequestCtx, completionType CompletionType) { - var req CompletionRequest - if err := sonic.Unmarshal(ctx.PostBody(), &req); err != nil { - 
SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) - return +// handleStreamingChatCompletion handles streaming chat completion requests using Server-Sent Events (SSE) +func (h *CompletionHandler) handleStreamingChatCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostChatRequest, bifrostCtx *context.Context) { + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return h.client.ChatCompletionStreamRequest(*bifrostCtx, req) } - if req.Model == "" { - SendError(ctx, fasthttp.StatusBadRequest, "Model is required", h.logger) - return + extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { + return response, true } - provider, modelName, err := ParseModel(req.Model) - if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Model must be in the format of 'provider/model': %v", err), h.logger) - return - } + h.handleStreamingResponse(ctx, getStream, extractResponse) +} - fallbacks := make([]schemas.Fallback, len(req.Fallbacks)) - for i, fallback := range req.Fallbacks { - fallbackProvider, fallbackModelName, err := ParseModel(fallback) - if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Fallback must be in the format of 'provider/model': %v", err), h.logger) - return - } - if fallbackProvider == "" || fallbackModelName == "" { - SendError(ctx, fasthttp.StatusBadRequest, "Fallback must be in the format of 'provider/model'", h.logger) - return - } - fallbacks[i] = schemas.Fallback{ - Provider: schemas.ModelProvider(fallbackProvider), - Model: fallbackModelName, - } +// handleStreamingResponses handles streaming responses requests using Server-Sent Events (SSE) +func (h *CompletionHandler) handleStreamingResponses(ctx *fasthttp.RequestCtx, req *schemas.BifrostResponsesRequest, bifrostCtx *context.Context) { + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return 
h.client.ResponsesStreamRequest(*bifrostCtx, req) } - // Create BifrostRequest - bifrostReq := &schemas.BifrostRequest{ - Model: modelName, - Provider: schemas.ModelProvider(provider), - Params: req.GetModelParameters(), - Fallbacks: fallbacks, + extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { + return response, true } - // Validate and set input based on completion type - switch completionType { - case CompletionTypeText: - if req.Text == "" { - SendError(ctx, fasthttp.StatusBadRequest, "Text is required for text completion", h.logger) - return - } - bifrostReq.Input = schemas.RequestInput{ - TextCompletionInput: &req.Text, - } - case CompletionTypeChat: - if len(req.Messages) == 0 { - SendError(ctx, fasthttp.StatusBadRequest, "Messages array is required for chat completion", h.logger) - return - } - bifrostReq.Input = schemas.RequestInput{ - ChatCompletionInput: &req.Messages, - } - case CompletionTypeEmbeddings: - bifrostReq.Input = schemas.RequestInput{ - EmbeddingInput: &req.Input, - } - case CompletionTypeSpeech: - if req.Input.Text == nil { - SendError(ctx, fasthttp.StatusBadRequest, "Input is required for speech completion", h.logger) - return - } - if req.Voice.Voice == nil && len(req.Voice.MultiVoiceConfig) == 0 { - SendError(ctx, fasthttp.StatusBadRequest, "Voice is required for speech completion", h.logger) - return - } - bifrostReq.Input = schemas.RequestInput{ - SpeechInput: &schemas.SpeechInput{ - Input: *req.Input.Text, - VoiceConfig: req.Voice, - Instructions: req.Instructions, - ResponseFormat: req.ResponseFormat, - }, - } - } + h.handleStreamingResponse(ctx, getStream, extractResponse) +} - // Convert context - bifrostCtx := lib.ConvertToBifrostContext(ctx, h.handlerStore.ShouldAllowDirectKeys()) - if bifrostCtx == nil { - SendError(ctx, fasthttp.StatusInternalServerError, "Failed to convert context", h.logger) - return +// handleStreamingSpeech handles streaming speech requests using Server-Sent Events (SSE) +func 
(h *CompletionHandler) handleStreamingSpeech(ctx *fasthttp.RequestCtx, req *schemas.BifrostSpeechRequest, bifrostCtx *context.Context) { + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + return h.client.SpeechStreamRequest(*bifrostCtx, req) } - // Check if streaming is requested - isStreaming := req.Stream != nil && *req.Stream || req.StreamFormat != nil && *req.StreamFormat == "sse" - - // Handle streaming for chat completions only - if isStreaming { - switch completionType { - case CompletionTypeChat: - h.handleStreamingChatCompletion(ctx, bifrostReq, bifrostCtx) - return - case CompletionTypeSpeech: - h.handleStreamingSpeech(ctx, bifrostReq, bifrostCtx) - return + extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { + if response.Speech == nil || response.Speech.BifrostSpeechStreamResponse == nil { + return nil, false } + return response.Speech, true } - // Handle non-streaming requests - var resp *schemas.BifrostResponse - var bifrostErr *schemas.BifrostError - - switch completionType { - case CompletionTypeText: - resp, bifrostErr = h.client.TextCompletionRequest(*bifrostCtx, bifrostReq) - case CompletionTypeChat: - resp, bifrostErr = h.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) - case CompletionTypeEmbeddings: - resp, bifrostErr = h.client.EmbeddingRequest(*bifrostCtx, bifrostReq) - case CompletionTypeSpeech: - resp, bifrostErr = h.client.SpeechRequest(*bifrostCtx, bifrostReq) - } + h.handleStreamingResponse(ctx, getStream, extractResponse) +} - // Handle response - if bifrostErr != nil { - SendBifrostError(ctx, bifrostErr, h.logger) - return +// handleStreamingTranscriptionRequest handles streaming transcription requests using Server-Sent Events (SSE) +func (h *CompletionHandler) handleStreamingTranscriptionRequest(ctx *fasthttp.RequestCtx, req *schemas.BifrostTranscriptionRequest, bifrostCtx *context.Context) { + getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { + 
return h.client.TranscriptionStreamRequest(*bifrostCtx, req) } - if completionType == CompletionTypeSpeech { - if resp.Speech.Audio == nil { - SendError(ctx, fasthttp.StatusInternalServerError, "Speech response is missing audio data", h.logger) - return + extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { + if response.Transcribe == nil || response.Transcribe.BifrostTranscribeStreamResponse == nil { + return nil, false } - - ctx.Response.Header.Set("Content-Type", "audio/mpeg") - ctx.Response.Header.Set("Content-Disposition", "attachment; filename=speech.mp3") - ctx.Response.Header.Set("Content-Length", strconv.Itoa(len(resp.Speech.Audio))) - ctx.Response.SetBody(resp.Speech.Audio) - return + return response.Transcribe, true } - // Send successful response - SendJSON(ctx, resp, h.logger) + h.handleStreamingResponse(ctx, getStream, extractResponse) } // handleStreamingResponse is a generic function to handle streaming responses using Server-Sent Events (SSE) @@ -604,47 +822,79 @@ func (h *CompletionHandler) handleStreamingResponse(ctx *fasthttp.RequestCtx, ge }) } -// handleStreamingChatCompletion handles streaming chat completion requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingChatCompletion(ctx *fasthttp.RequestCtx, req *schemas.BifrostRequest, bifrostCtx *context.Context) { - getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.ChatCompletionStreamRequest(*bifrostCtx, req) +// validateAudioFile checks if the file size and format are valid +func (h *CompletionHandler) validateAudioFile(fileHeader *multipart.FileHeader) error { + // Check file size + if fileHeader.Size > MaxFileSize { + return fmt.Errorf("file size exceeds maximum limit of %d MB", MaxFileSize/1024/1024) } - extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { - return response, true + // Get file extension + ext := strings.ToLower(filepath.Ext(fileHeader.Filename)) + + 
// Check file extension + validExtensions := map[string]bool{ + ".flac": true, + ".mp3": true, + ".mp4": true, + ".mpeg": true, + ".mpga": true, + ".m4a": true, + ".ogg": true, + ".wav": true, + ".webm": true, } - h.handleStreamingResponse(ctx, getStream, extractResponse) -} + if !validExtensions[ext] { + return fmt.Errorf("unsupported file format: %s. Supported formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", ext) + } -// handleStreamingSpeech handles streaming speech requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingSpeech(ctx *fasthttp.RequestCtx, req *schemas.BifrostRequest, bifrostCtx *context.Context) { - getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.SpeechStreamRequest(*bifrostCtx, req) + // Open file to check MIME type + file, err := fileHeader.Open() + if err != nil { + return fmt.Errorf("failed to open file: %v", err) } + defer file.Close() - extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { - if response.Speech == nil || response.Speech.BifrostSpeechStreamResponse == nil { - return nil, false - } - return response.Speech, true + // Read first 512 bytes for MIME type detection + buffer := make([]byte, 512) + _, err = file.Read(buffer) + if err != nil && err != io.EOF { + return fmt.Errorf("failed to read file header: %v", err) } - h.handleStreamingResponse(ctx, getStream, extractResponse) -} + // Check MIME type + mimeType := http.DetectContentType(buffer) + validMimeTypes := map[string]bool{ + // Primary MIME types + AudioMimeMP3: true, // Covers MP3, MPEG, MPGA + AudioMimeMP4: true, + AudioMimeM4A: true, + AudioMimeOGG: true, + AudioMimeWAV: true, + AudioMimeWEBM: true, + AudioMimeFLAC: true, + AudioMimeFLAC2: true, -// handleStreamingTranscriptionRequest handles streaming transcription requests using Server-Sent Events (SSE) -func (h *CompletionHandler) handleStreamingTranscriptionRequest(ctx *fasthttp.RequestCtx, req 
*schemas.BifrostRequest, bifrostCtx *context.Context) { - getStream := func() (chan *schemas.BifrostStream, *schemas.BifrostError) { - return h.client.TranscriptionStreamRequest(*bifrostCtx, req) + // Alternative MIME types + "audio/mpeg3": true, + "audio/x-wav": true, + "audio/vnd.wave": true, + "audio/x-mpeg": true, + "audio/x-mpeg3": true, + "audio/x-mpg": true, + "audio/x-mpegaudio": true, } - extractResponse := func(response *schemas.BifrostStream) (interface{}, bool) { - if response.Transcribe == nil || response.Transcribe.BifrostTranscribeStreamResponse == nil { - return nil, false - } - return response.Transcribe, true + if !validMimeTypes[mimeType] { + return fmt.Errorf("invalid file type: %s. Supported audio formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, webm", mimeType) } - h.handleStreamingResponse(ctx, getStream, extractResponse) + // Reset file pointer for subsequent reads + _, err = file.Seek(0, 0) + if err != nil { + return fmt.Errorf("failed to reset file pointer: %v", err) + } + + return nil } diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index 37fedf8fce..7ccbf90f63 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -42,7 +42,7 @@ func (h *MCPHandler) RegisterRoutes(r *router.Router) { // executeTool handles POST /v1/mcp/tool/execute - Execute MCP tool func (h *MCPHandler) executeTool(ctx *fasthttp.RequestCtx) { - var req schemas.ToolCall + var req schemas.ChatAssistantMessageToolCall if err := json.Unmarshal(ctx.PostBody(), &req); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid request format: %v", err), h.logger) return diff --git a/transports/bifrost-http/integrations/anthropic.go b/transports/bifrost-http/integrations/anthropic.go index a413a2b662..f18d06bac9 100644 --- a/transports/bifrost-http/integrations/anthropic.go +++ b/transports/bifrost-http/integrations/anthropic.go @@ -25,7 +25,9 @@ func 
CreateAnthropicRouteConfigs(pathPrefix string) []RouteConfig { }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { if anthropicReq, ok := req.(*anthropic.AnthropicMessageRequest); ok { - return anthropicReq.ToBifrostRequest(), nil + return &schemas.BifrostRequest{ + ChatRequest: anthropicReq.ToBifrostRequest(), + }, nil } return nil, errors.New("invalid request type") }, diff --git a/transports/bifrost-http/integrations/genai.go b/transports/bifrost-http/integrations/genai.go index 9371a51126..5ffbdee7d3 100644 --- a/transports/bifrost-http/integrations/genai.go +++ b/transports/bifrost-http/integrations/genai.go @@ -30,7 +30,9 @@ func CreateGenAIRouteConfigs(pathPrefix string) []RouteConfig { }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { if geminiReq, ok := req.(*gemini.GeminiGenerationRequest); ok { - return geminiReq.ToBifrostRequest(), nil + return &schemas.BifrostRequest{ + ChatRequest: geminiReq.ToBifrostRequest(), + }, nil } return nil, errors.New("invalid request type") }, diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index 22247547fc..6c218e92d5 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -92,6 +92,36 @@ func AzureEndpointPreHook(handlerStore lib.HandlerStore) func(ctx *fasthttp.Requ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) []RouteConfig { var routes []RouteConfig + // Text completions endpoint + for _, path := range []string{ + "/v1/completions", + "/completions", + "/openai/deployments/{deployment-id}/completions", + } { + routes = append(routes, RouteConfig{ + Path: pathPrefix + path, + Method: "POST", + GetRequestTypeInstance: func() interface{} { + return &openai.OpenAITextCompletionRequest{} + }, + RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { + if openaiReq, ok := 
req.(*openai.OpenAITextCompletionRequest); ok { + return &schemas.BifrostRequest{ + TextCompletionRequest: openaiReq.ToBifrostRequest(), + }, nil + } + return nil, errors.New("invalid request type") + }, + ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { + return resp, nil + }, + ErrorConverter: func(err *schemas.BifrostError) interface{} { + return err + }, + PreCallback: AzureEndpointPreHook(handlerStore), + }) + } + // Chat completions endpoint for _, path := range []string{ "/v1/chat/completions", @@ -106,22 +136,24 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { if openaiReq, ok := req.(*openai.OpenAIChatRequest); ok { - return openaiReq.ToBifrostRequest(), nil + return &schemas.BifrostRequest{ + ChatRequest: openaiReq.ToBifrostRequest(), + }, nil } return nil, errors.New("invalid request type") }, ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { - return openai.ToOpenAIChatCompletionResponse(resp), nil + return resp, nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { - return openai.ToOpenAIError(err) + return err }, StreamConfig: &StreamConfig{ ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { - return openai.ToOpenAIChatCompletionStreamResponse(resp), nil + return resp, nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { - return openai.ToOpenAIError(err) + return err }, }, PreCallback: AzureEndpointPreHook(handlerStore), @@ -142,15 +174,17 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { if embeddingReq, ok := req.(*openai.OpenAIEmbeddingRequest); ok { - return embeddingReq.ToBifrostRequest(), nil + return &schemas.BifrostRequest{ + EmbeddingRequest: embeddingReq.ToBifrostRequest(), + }, nil } return nil, 
errors.New("invalid embedding request type") }, ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { - return openai.ToOpenAIEmbeddingResponse(resp), nil + return resp, nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { - return openai.ToOpenAIError(err) + return err }, PreCallback: AzureEndpointPreHook(handlerStore), }) @@ -170,7 +204,9 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { if speechReq, ok := req.(*openai.OpenAISpeechRequest); ok { - return speechReq.ToBifrostRequest(), nil + return &schemas.BifrostRequest{ + SpeechRequest: speechReq.ToBifrostRequest(), + }, nil } return nil, errors.New("invalid speech request type") }, @@ -179,14 +215,14 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) return speechResp.Audio, nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { - return openai.ToOpenAIError(err) + return err }, StreamConfig: &StreamConfig{ ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { return openai.ToOpenAISpeechResponse(resp), nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { - return openai.ToOpenAIError(err) + return err }, }, PreCallback: AzureEndpointPreHook(handlerStore), @@ -208,22 +244,24 @@ func CreateOpenAIRouteConfigs(pathPrefix string, handlerStore lib.HandlerStore) RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { if transcriptionReq, ok := req.(*openai.OpenAITranscriptionRequest); ok { - return transcriptionReq.ToBifrostRequest(), nil + return &schemas.BifrostRequest{ + TranscriptionRequest: transcriptionReq.ToBifrostRequest(), + }, nil } return nil, errors.New("invalid transcription request type") }, ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { 
- return openai.ToOpenAITranscriptionResponse(resp), nil + return resp, nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { - return openai.ToOpenAIError(err) + return err }, StreamConfig: &StreamConfig{ ResponseConverter: func(resp *schemas.BifrostResponse) (interface{}, error) { - return openai.ToOpenAITranscriptionResponse(resp), nil + return resp, nil }, ErrorConverter: func(err *schemas.BifrostError) interface{} { - return openai.ToOpenAIError(err) + return err }, }, PreCallback: AzureEndpointPreHook(handlerStore), @@ -283,51 +321,17 @@ func parseTranscriptionMultipartRequest(ctx *fasthttp.RequestCtx, req interface{ // Extract optional parameters if languageValues := form.Value["language"]; len(languageValues) > 0 && languageValues[0] != "" { language := languageValues[0] - transcriptionReq.Language = &language + transcriptionReq.TranscriptionParameters.Language = &language } if promptValues := form.Value["prompt"]; len(promptValues) > 0 && promptValues[0] != "" { prompt := promptValues[0] - transcriptionReq.Prompt = &prompt + transcriptionReq.TranscriptionParameters.Prompt = &prompt } if responseFormatValues := form.Value["response_format"]; len(responseFormatValues) > 0 && responseFormatValues[0] != "" { responseFormat := responseFormatValues[0] - transcriptionReq.ResponseFormat = &responseFormat - } - - if temperatureValues := form.Value["temperature"]; len(temperatureValues) > 0 && temperatureValues[0] != "" { - temp, err := strconv.ParseFloat(temperatureValues[0], 64) - if err != nil { - return errors.New("invalid temperature value") - } - transcriptionReq.Temperature = &temp - } - - // Handle include[] array format used by OpenAI - if includeValues := form.Value["include[]"]; len(includeValues) > 0 { - transcriptionReq.Include = includeValues - } else if includeValues := form.Value["include"]; len(includeValues) > 0 && includeValues[0] != "" { - // Fallback: Handle comma-separated values for backwards compatibility - includes := 
strings.Split(includeValues[0], ",") - // Trim whitespace from each value - for i, v := range includes { - includes[i] = strings.TrimSpace(v) - } - transcriptionReq.Include = includes - } - - // Handle timestamp_granularities[] array format used by OpenAI - if timestampValues := form.Value["timestamp_granularities[]"]; len(timestampValues) > 0 { - transcriptionReq.TimestampGranularities = timestampValues - } else if timestampValues := form.Value["timestamp_granularities"]; len(timestampValues) > 0 && timestampValues[0] != "" { - // Fallback: Handle comma-separated values for backwards compatibility - granularities := strings.Split(timestampValues[0], ",") - // Trim whitespace from each value - for i, v := range granularities { - granularities[i] = strings.TrimSpace(v) - } - transcriptionReq.TimestampGranularities = granularities + transcriptionReq.TranscriptionParameters.ResponseFormat = &responseFormat } if streamValues := form.Value["stream"]; len(streamValues) > 0 && streamValues[0] != "" { diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go index 8621b807a3..71dfa55c3a 100644 --- a/transports/bifrost-http/integrations/utils.go +++ b/transports/bifrost-http/integrations/utils.go @@ -52,7 +52,7 @@ import ( "encoding/json" "fmt" "log" - "regexp" + "reflect" "strconv" "strings" @@ -271,8 +271,10 @@ func (g *GenericRouter) createHandler(config RouteConfig) fasthttp.RequestHandle g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Invalid request")) return } - if bifrostReq.Model == "" { - g.sendError(ctx, config.ErrorConverter, newBifrostError(nil, "Model parameter is required")) + + // Extract and parse fallbacks from the request if present + if err := g.extractAndParseFallbacks(req, bifrostReq); err != nil { + g.sendError(ctx, config.ErrorConverter, newBifrostError(err, "failed to parse fallbacks: "+err.Error())) return } @@ -306,16 +308,16 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx 
*fasthttp.RequestCtx, conf var bifrostErr *schemas.BifrostError // Handle different request types - if bifrostReq.Input.TextCompletionInput != nil { - result, bifrostErr = g.client.TextCompletionRequest(*bifrostCtx, bifrostReq) - } else if bifrostReq.Input.ChatCompletionInput != nil { - result, bifrostErr = g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq) - } else if bifrostReq.Input.EmbeddingInput != nil { - result, bifrostErr = g.client.EmbeddingRequest(*bifrostCtx, bifrostReq) - } else if bifrostReq.Input.SpeechInput != nil { - result, bifrostErr = g.client.SpeechRequest(*bifrostCtx, bifrostReq) - } else if bifrostReq.Input.TranscriptionInput != nil { - result, bifrostErr = g.client.TranscriptionRequest(*bifrostCtx, bifrostReq) + if bifrostReq.TextCompletionRequest != nil { + result, bifrostErr = g.client.TextCompletionRequest(*bifrostCtx, bifrostReq.TextCompletionRequest) + } else if bifrostReq.ChatRequest != nil { + result, bifrostErr = g.client.ChatCompletionRequest(*bifrostCtx, bifrostReq.ChatRequest) + } else if bifrostReq.EmbeddingRequest != nil { + result, bifrostErr = g.client.EmbeddingRequest(*bifrostCtx, bifrostReq.EmbeddingRequest) + } else if bifrostReq.SpeechRequest != nil { + result, bifrostErr = g.client.SpeechRequest(*bifrostCtx, bifrostReq.SpeechRequest) + } else if bifrostReq.TranscriptionRequest != nil { + result, bifrostErr = g.client.TranscriptionRequest(*bifrostCtx, bifrostReq.TranscriptionRequest) } // Handle errors @@ -360,7 +362,7 @@ func (g *GenericRouter) handleNonStreamingRequest(ctx *fasthttp.RequestCtx, conf } // handleStreamingRequest handles streaming requests using Server-Sent Events (SSE) -func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, bifrostReq *schemas.BifrostRequest, bifrostCtx *context.Context) { +func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config RouteConfig, bifrostReq *schemas.BifrostRequest, bifrostCtx *context.Context) { // Set common SSE 
headers ctx.SetContentType("text/event-stream") ctx.Response.Header.Set("Cache-Control", "no-cache") @@ -371,12 +373,12 @@ func (g *GenericRouter) handleStreamingRequest(ctx *fasthttp.RequestCtx, config var bifrostErr *schemas.BifrostError // Handle different request types - if bifrostReq.Input.ChatCompletionInput != nil { - stream, bifrostErr = g.client.ChatCompletionStreamRequest(*bifrostCtx, bifrostReq) - } else if bifrostReq.Input.SpeechInput != nil { - stream, bifrostErr = g.client.SpeechStreamRequest(*bifrostCtx, bifrostReq) - } else if bifrostReq.Input.TranscriptionInput != nil { - stream, bifrostErr = g.client.TranscriptionStreamRequest(*bifrostCtx, bifrostReq) + if bifrostReq.ChatRequest != nil { + stream, bifrostErr = g.client.ChatCompletionStreamRequest(*bifrostCtx, bifrostReq.ChatRequest) + } else if bifrostReq.SpeechRequest != nil { + stream, bifrostErr = g.client.SpeechStreamRequest(*bifrostCtx, bifrostReq.SpeechRequest) + } else if bifrostReq.TranscriptionRequest != nil { + stream, bifrostErr = g.client.TranscriptionStreamRequest(*bifrostCtx, bifrostReq.TranscriptionRequest) } // Get the streaming channel from Bifrost @@ -630,234 +632,131 @@ func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, errorConverter Err ctx.SetBody(responseBody) } -// ValidProviders is a pre-computed map for efficient O(1) provider validation. -var ValidProviders = map[schemas.ModelProvider]bool{ - schemas.OpenAI: true, - schemas.Azure: true, - schemas.Anthropic: true, - schemas.Bedrock: true, - schemas.Cohere: true, - schemas.Vertex: true, - schemas.Mistral: true, - schemas.Ollama: true, - schemas.Groq: true, - schemas.SGL: true, - schemas.Parasail: true, - schemas.Cerebras: true, - schemas.Gemini: true, - schemas.OpenRouter: true, -} - -// ParseModelString extracts provider and model from a model string. -// For model strings like "anthropic/claude", it returns ("anthropic", "claude"). -// For model strings like "claude", it returns ("", "claude"). 
-func ParseModelString(model string, defaultProvider schemas.ModelProvider, checkProviderFromModel bool) (schemas.ModelProvider, string) { - // Check if model contains a provider prefix (only split on first "/" to preserve model names with "/") - if strings.Contains(model, "/") { - parts := strings.SplitN(model, "/", 2) - if len(parts) == 2 { - extractedProvider := parts[0] - extractedModel := parts[1] - - return schemas.ModelProvider(extractedProvider), extractedModel - } - } - - //TODO add model wise check for provider - - // No provider prefix found, return empty provider and the original model - return defaultProvider, model -} - -// GetProviderFromModel determines the appropriate provider based on model name patterns -// This function uses comprehensive pattern matching to identify the correct provider -// for various model naming conventions used across different AI providers. -func GetProviderFromModel(model string) schemas.ModelProvider { - // Check if model contains a provider prefix (only split on first "/" to preserve model names with "/") - if strings.Contains(model, "/") { - parts := strings.SplitN(model, "/", 2) - if len(parts) > 1 { - extractedProvider := parts[0] - - if ValidProviders[schemas.ModelProvider(extractedProvider)] { - return schemas.ModelProvider(extractedProvider) - } +// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false. +// This helper function reduces code duplication when handling non-Bifrost errors. 
+func newBifrostError(err error, message string) *schemas.BifrostError { + if err == nil { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + }, } } - // Normalize model name for case-insensitive matching - modelLower := strings.ToLower(strings.TrimSpace(model)) - - // Azure OpenAI Models - check first to prevent false positives from OpenAI "gpt" patterns - if isAzureModel(modelLower) { - return schemas.Azure - } - - // OpenAI Models - comprehensive pattern matching - if isOpenAIModel(modelLower) { - return schemas.OpenAI - } - - // Anthropic Models - Claude family - if isAnthropicModel(modelLower) { - return schemas.Anthropic - } - - // Google Vertex AI Models - Gemini and Palm family - if isVertexModel(modelLower) { - return schemas.Vertex - } - - // AWS Bedrock Models - various model providers through Bedrock - if isBedrockModel(modelLower) { - return schemas.Bedrock - } - - // Cohere Models - Command and Embed family - if isCohereModel(modelLower) { - return schemas.Cohere - } - - // Google GenAI Models - Gemini and Palm family - if isGeminiModel(modelLower) { - return schemas.Gemini + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + Error: err, + }, } - - // Default to OpenAI for unknown models (most LiteLLM compatible) - return schemas.OpenAI } -// isOpenAIModel checks for OpenAI model patterns -func isOpenAIModel(model string) bool { - // Exclude Azure models to prevent overlap - if strings.Contains(model, "azure/") { - return false +// extractAndParseFallbacks extracts fallbacks from the integration request and adds them to the BifrostRequest +func (g *GenericRouter) extractAndParseFallbacks(req interface{}, bifrostReq *schemas.BifrostRequest) error { + // Check if the request has a fallbacks field ([]string) + fallbacks, err := g.extractFallbacksFromRequest(req) + if err != nil { + return fmt.Errorf("failed to extract fallbacks: %w", err) } - 
openaiPatterns := []string{ - "gpt", "davinci", "curie", "babbage", "ada", "o1", "o3", "o4", - "text-embedding", "dall-e", "whisper", "tts", "chatgpt", + if len(fallbacks) == 0 { + return nil // No fallbacks to process } - return matchesAnyPattern(model, openaiPatterns) -} - -// isAzureModel checks for Azure OpenAI specific patterns -func isAzureModel(model string) bool { - azurePatterns := []string{ - "azure", "model-router", "computer-use-preview", - } + // Parse fallbacks from strings to Fallback structs + parsedFallbacks := make([]schemas.Fallback, 0, len(fallbacks)) + for _, fallbackStr := range fallbacks { + if fallbackStr == "" { + continue // Skip empty strings + } - return matchesAnyPattern(model, azurePatterns) -} + // Use ParseModelString to extract provider and model + provider, model := schemas.ParseModelString(fallbackStr, bifrostReq.Provider) -// isAnthropicModel checks for Anthropic Claude model patterns -func isAnthropicModel(model string) bool { - anthropicPatterns := []string{ - "claude", "anthropic/", + parsedFallback := schemas.Fallback{ + Provider: provider, + Model: model, + } + parsedFallbacks = append(parsedFallbacks, parsedFallback) } - return matchesAnyPattern(model, anthropicPatterns) -} - -var geminiRegexp = regexp.MustCompile(`\b(gemini|gemini-embedding|palm|bison|gecko)\b`) - -// isGeminiModel checks for Google Gemini model patterns using strict regex matching -func isGeminiModel(model string) bool { - return geminiRegexp.MatchString(model) -} - -// isVertexModel checks for Google Vertex AI model patterns -func isVertexModel(model string) bool { - vertexPatterns := []string{ - "gemini", "palm", "bison", "gecko", "vertex/", "google/", + if len(parsedFallbacks) == 0 { + return nil // No valid fallbacks found } - return matchesAnyPattern(model, vertexPatterns) -} + // Add fallbacks to the main BifrostRequest + bifrostReq.Fallbacks = parsedFallbacks -// isBedrockModel checks for AWS Bedrock model patterns -func isBedrockModel(model 
string) bool { - bedrockPatterns := []string{ - "bedrock", "bedrock.amazonaws.com/", "bedrock/", - "amazon.titan", "amazon.nova", "aws/amazon.", - "ai21.jamba", "ai21.j2", "aws/ai21.", - "meta.llama", "aws/meta.", - "stability.stable-diffusion", "stability.sd3", "aws/stability.", - "anthropic.claude", "aws/anthropic.", - "cohere.command", "cohere.embed", "aws/cohere.", - "mistral.mistral", "mistral.mixtral", "aws/mistral.", - "titan-text", "titan-embed", "nova-micro", "nova-lite", "nova-pro", - "jamba-instruct", "j2-ultra", "j2-mid", - "llama-2", "llama-3", "llama-3.1", "llama-3.2", - "stable-diffusion-xl", "sd3-large", + // Also add fallbacks to the specific request type if it exists + switch bifrostReq.RequestType { + case schemas.TextCompletionRequest: + if bifrostReq.TextCompletionRequest != nil { + bifrostReq.TextCompletionRequest.Fallbacks = parsedFallbacks + } + case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: + if bifrostReq.ChatRequest != nil { + bifrostReq.ChatRequest.Fallbacks = parsedFallbacks + } + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + if bifrostReq.ResponsesRequest != nil { + bifrostReq.ResponsesRequest.Fallbacks = parsedFallbacks + } + case schemas.EmbeddingRequest: + if bifrostReq.EmbeddingRequest != nil { + bifrostReq.EmbeddingRequest.Fallbacks = parsedFallbacks + } + case schemas.SpeechRequest, schemas.SpeechStreamRequest: + if bifrostReq.SpeechRequest != nil { + bifrostReq.SpeechRequest.Fallbacks = parsedFallbacks + } + case schemas.TranscriptionRequest, schemas.TranscriptionStreamRequest: + if bifrostReq.TranscriptionRequest != nil { + bifrostReq.TranscriptionRequest.Fallbacks = parsedFallbacks + } } - return matchesAnyPattern(model, bedrockPatterns) + return nil } -// isCohereModel checks for Cohere model patterns -func isCohereModel(model string) bool { - coherePatterns := []string{ - "command-", "embed-", "cohere", +// extractFallbacksFromRequest uses reflection to extract fallbacks field 
from any request type +func (g *GenericRouter) extractFallbacksFromRequest(req interface{}) ([]string, error) { + if req == nil { + return nil, nil } - return matchesAnyPattern(model, coherePatterns) -} - -// matchesAnyPattern checks if the model matches any of the given patterns -func matchesAnyPattern(model string, patterns []string) bool { - for _, pattern := range patterns { - if strings.Contains(model, pattern) { - return true - } + // Try to use reflection to find a "fallbacks" field + reqValue := reflect.ValueOf(req) + if reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() } - return false -} -// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false. -// This helper function reduces code duplication when handling non-Bifrost errors. -func newBifrostError(err error, message string) *schemas.BifrostError { - if err == nil { - return &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: message, - }, - } + if reqValue.Kind() != reflect.Struct { + return nil, nil // Not a struct, no fallbacks } - return &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: message, - Error: err, - }, + // Look for the "fallbacks" field + fallbacksField := reqValue.FieldByName("fallbacks") + if !fallbacksField.IsValid() { + return nil, nil // No fallbacks field found } -} -// MapFinishReasonToProvider maps OpenAI-compatible finish reasons to provider-specific format -func MapFinishReasonToProvider(finishReason string, targetProvider schemas.ModelProvider) string { - switch targetProvider { - case schemas.Anthropic: - return mapFinishReasonToAnthropic(finishReason) - default: - // For OpenAI, Azure, and other providers, pass through as-is - return finishReason + // Handle different types of fallbacks field + switch fallbacksField.Kind() { + case reflect.Slice: + if fallbacksField.Type().Elem().Kind() == reflect.String { + // []string case + fallbacks := 
make([]string, fallbacksField.Len()) + for i := 0; i < fallbacksField.Len(); i++ { + fallbacks[i] = fallbacksField.Index(i).String() + } + return fallbacks, nil + } + case reflect.String: + // Single string case - treat as one fallback + return []string{fallbacksField.String()}, nil } -} -// mapFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic format -func mapFinishReasonToAnthropic(finishReason string) string { - switch finishReason { - case "stop": - return "end_turn" - case "length": - return "max_tokens" - case "tool_calls": - return "tool_use" - default: - // Pass through other reasons like "pause_turn", "refusal", "stop_sequence", etc. - return finishReason - } + return nil, nil } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index c68fa83653..3c4edca4be 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -263,6 +263,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { Value: dbKey.Value, Models: dbKey.Models, Weight: dbKey.Weight, + OpenAIKeyConfig: dbKey.OpenAIKeyConfig, AzureKeyConfig: dbKey.AzureKeyConfig, VertexKeyConfig: dbKey.VertexKeyConfig, BedrockKeyConfig: dbKey.BedrockKeyConfig, @@ -500,7 +501,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { // Process Azure key config if present if key.AzureKeyConfig != nil { - if err := config.processAzureKeyConfigEnvVars(&cfg.Keys[i], provider, i, newEnvKeys); err != nil { + if err := config.processAzureKeyConfigEnvVars(&cfg.Keys[i], provider, newEnvKeys); err != nil { config.cleanupEnvKeys(provider, "", newEnvKeys) logger.Warn("failed to process Azure key config env vars for %s: %v", provider, err) continue @@ -509,7 +510,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { // Process Vertex key config if present if key.VertexKeyConfig != nil { - if err := config.processVertexKeyConfigEnvVars(&cfg.Keys[i], 
provider, i, newEnvKeys); err != nil { + if err := config.processVertexKeyConfigEnvVars(&cfg.Keys[i], provider, newEnvKeys); err != nil { config.cleanupEnvKeys(provider, "", newEnvKeys) logger.Warn("failed to process Vertex key config env vars for %s: %v", provider, err) continue @@ -518,7 +519,7 @@ func LoadConfig(ctx context.Context, configDirPath string) (*Config, error) { // Process Bedrock key config if present if key.BedrockKeyConfig != nil { - if err := config.processBedrockKeyConfigEnvVars(&cfg.Keys[i], provider, i, newEnvKeys); err != nil { + if err := config.processBedrockKeyConfigEnvVars(&cfg.Keys[i], provider, newEnvKeys); err != nil { config.cleanupEnvKeys(provider, "", newEnvKeys) logger.Warn("failed to process Bedrock key config env vars for %s: %v", provider, err) continue @@ -880,9 +881,10 @@ func (s *Config) GetProviderConfigRedacted(provider schemas.ModelProvider) (*con redactedConfig.Keys = make([]schemas.Key, len(config.Keys)) for i, key := range config.Keys { redactedConfig.Keys[i] = schemas.Key{ - ID: key.ID, - Models: key.Models, // Copy slice reference - read-only so safe - Weight: key.Weight, + ID: key.ID, + Models: key.Models, // Copy slice reference - read-only so safe + Weight: key.Weight, + OpenAIKeyConfig: key.OpenAIKeyConfig, } // Redact API key value @@ -1069,7 +1071,7 @@ func (s *Config) AddProvider(provider schemas.ModelProvider, config configstore. // Process Azure key config if present if key.AzureKeyConfig != nil { - if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, newEnvKeys); err != nil { s.cleanupEnvKeys(provider, "", newEnvKeys) return fmt.Errorf("failed to process Azure key config env vars: %w", err) } @@ -1077,7 +1079,7 @@ func (s *Config) AddProvider(provider schemas.ModelProvider, config configstore. 
// Process Vertex key config if present if key.VertexKeyConfig != nil { - if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, newEnvKeys); err != nil { s.cleanupEnvKeys(provider, "", newEnvKeys) return fmt.Errorf("failed to process Vertex key config env vars: %w", err) } @@ -1085,7 +1087,7 @@ func (s *Config) AddProvider(provider schemas.ModelProvider, config configstore. // Process Bedrock key config if present if key.BedrockKeyConfig != nil { - if err := s.processBedrockKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + if err := s.processBedrockKeyConfigEnvVars(&config.Keys[i], provider, newEnvKeys); err != nil { s.cleanupEnvKeys(provider, "", newEnvKeys) return fmt.Errorf("failed to process Bedrock key config env vars: %w", err) } @@ -1174,7 +1176,7 @@ func (s *Config) UpdateProviderConfig(provider schemas.ModelProvider, config con // Process Azure key config if present if key.AzureKeyConfig != nil { - if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, newEnvKeys); err != nil { s.cleanupEnvKeys(provider, "", newEnvKeys) return fmt.Errorf("failed to process Azure key config env vars: %w", err) } @@ -1182,7 +1184,7 @@ func (s *Config) UpdateProviderConfig(provider schemas.ModelProvider, config con // Process Vertex key config if present if key.VertexKeyConfig != nil { - if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, newEnvKeys); err != nil { s.cleanupEnvKeys(provider, "", newEnvKeys) return fmt.Errorf("failed to process Vertex key config env vars: %w", err) } @@ -1190,7 +1192,7 @@ func (s *Config) UpdateProviderConfig(provider schemas.ModelProvider, config con // Process Bedrock key config 
if present if key.BedrockKeyConfig != nil { - if err := s.processBedrockKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + if err := s.processBedrockKeyConfigEnvVars(&config.Keys[i], provider, newEnvKeys); err != nil { s.cleanupEnvKeys(provider, "", newEnvKeys) return fmt.Errorf("failed to process Bedrock key config env vars: %w", err) } @@ -1837,7 +1839,7 @@ func (s *Config) autoDetectProviders() { } // processAzureKeyConfigEnvVars processes environment variables in Azure key configuration -func (s *Config) processAzureKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { +func (s *Config) processAzureKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, newEnvKeys map[string]struct{}) error { azureConfig := key.AzureKeyConfig // Process Endpoint @@ -1880,7 +1882,7 @@ func (s *Config) processAzureKeyConfigEnvVars(key *schemas.Key, provider schemas } // processVertexKeyConfigEnvVars processes environment variables in Vertex key configuration -func (s *Config) processVertexKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { +func (s *Config) processVertexKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, newEnvKeys map[string]struct{}) error { vertexConfig := key.VertexKeyConfig // Process ProjectID @@ -1938,7 +1940,7 @@ func (s *Config) processVertexKeyConfigEnvVars(key *schemas.Key, provider schema } // processBedrockKeyConfigEnvVars processes environment variables in Bedrock key configuration -func (s *Config) processBedrockKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { +func (s *Config) processBedrockKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, newEnvKeys map[string]struct{}) error { bedrockConfig := key.BedrockKeyConfig // Process AccessKey diff --git 
a/transports/bifrost-http/main.go b/transports/bifrost-http/main.go index a7947640e6..d8bfa6470e 100644 --- a/transports/bifrost-http/main.go +++ b/transports/bifrost-http/main.go @@ -2,7 +2,7 @@ // for text and chat completions using various AI model providers (OpenAI, Anthropic, Bedrock, Mistral, Ollama, etc.). // // The HTTP service provides the following main endpoints: -// - /v1/text/completions: For text completion requests +// - /v1/completions: For text completion requests // - /v1/chat/completions: For chat completion requests // - /v1/mcp/tool/execute: For MCP tool execution requests // - /providers/*: For provider configuration management @@ -330,7 +330,7 @@ func getDefaultConfigDir(appDir string) string { // 5. Starts the HTTP server on the specified host and port // // The server exposes the following endpoints: -// - POST /v1/text/completions: For text completion requests +// - POST /v1/completions: For text completion requests // - POST /v1/chat/completions: For chat completion requests // - GET /metrics: For Prometheus metrics func main() { @@ -411,7 +411,7 @@ func main() { // Eventually same flow will be used for third party plugins for _, plugin := range config.Plugins { if !plugin.Enabled { - logger.Debug("plugin %s is disabled, skipping initialization", plugin.Name) + logger.Debug("plugin %s is disabled, skipping initialization", plugin.Name) continue } switch strings.ToLower(plugin.Name) { diff --git a/transports/go.mod b/transports/go.mod index d9b324f232..6633073f15 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -18,13 +18,10 @@ require ( github.com/maximhq/bifrost/plugins/telemetry v1.2.16 github.com/prometheus/client_golang v1.23.0 github.com/valyala/fasthttp v1.65.0 - google.golang.org/genai v1.22.0 gorm.io/gorm v1.30.1 ) require ( - cloud.google.com/go v0.121.6 // indirect - cloud.google.com/go/auth v0.16.5 // indirect cloud.google.com/go/compute/metadata v0.8.0 // indirect github.com/andybalholm/brotli v1.2.0 // indirect 
github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect @@ -48,9 +45,6 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/go-logr/logr v1.4.3 // indirect - github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/analysis v0.23.0 // indirect github.com/go-openapi/errors v0.22.0 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect @@ -61,11 +55,6 @@ require ( github.com/go-openapi/strfmt v0.23.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect github.com/go-openapi/validate v0.24.0 // indirect - github.com/google/go-cmp v0.7.0 // indirect - github.com/google/s2a-go v0.1.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect - github.com/googleapis/gax-go/v2 v2.15.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect @@ -97,13 +86,7 @@ require ( github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect - go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 // indirect - go.opentelemetry.io/otel v1.37.0 // indirect - go.opentelemetry.io/otel/metric v1.37.0 // indirect - go.opentelemetry.io/otel/trace v1.37.0 // indirect golang.org/x/arch v0.20.0 // indirect - golang.org/x/crypto v0.41.0 // indirect golang.org/x/net v0.43.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/sys v0.35.0 // indirect @@ -116,4 +99,4 @@ require ( gorm.io/driver/sqlite v1.6.0 // indirect ) -replace github.com/maximhq/bifrost/core => ../core \ No newline at end of file +replace 
github.com/maximhq/bifrost/core => ../core diff --git a/transports/go.sum b/transports/go.sum index d3063c18a8..a4c2c57dce 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -1,7 +1,3 @@ -cloud.google.com/go v0.121.6 h1:waZiuajrI28iAf40cWgycWNgaXPO06dupuS+sgibK6c= -cloud.google.com/go v0.121.6/go.mod h1:coChdst4Ea5vUpiALcYKXEpR1S9ZgXbhEzzMcMR66vI= -cloud.google.com/go/auth v0.16.5 h1:mFWNQ2FEVWAliEQWpAdH80omXFokmrnbDhUS9cBywsI= -cloud.google.com/go/auth v0.16.5/go.mod h1:utzRfHMP+Vv0mpOkTRQoWD2q3BatTOoWbA7gCc2dUhQ= cloud.google.com/go/compute/metadata v0.8.0 h1:HxMRIbao8w17ZX6wBnjhcDkW6lTFpgcaobyVfZWqRLA= cloud.google.com/go/compute/metadata v0.8.0/go.mod h1:sYOGTp851OV9bOFJ9CH7elVvyzopvWQFNNghtDQ/Biw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -70,11 +66,8 @@ github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8 github.com/fasthttp/router v1.5.4/go.mod h1:3/hysWq6cky7dTfzaaEPZGdptwjwx0qzTgFCKEWRjgc= github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE= github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -148,17 +141,9 @@ github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW github.com/google/go-cmp 
v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= -github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= -github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= -github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= -github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= -github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= -github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= @@ -206,8 +191,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.1.38 h1:d5B7n5oibBO9f5wMBxyymTewK017nzS15ZzJILRAE6k= -github.com/maximhq/bifrost/core 
v1.1.38/go.mod h1:tf2pFTpoM53UGXXMFYxsaUjMqnCqYDOd9glFgMJvA0c= github.com/maximhq/bifrost/framework v1.0.24 h1:pxunQTl70q1GwOmxcTIwsEoeRTJfAcJZosa8C6KMPjI= github.com/maximhq/bifrost/framework v1.0.24/go.mod h1:94045IOmEISTQsaIFuNIn9ZfJ0lJS8uP4+YnzkpCnFo= github.com/maximhq/bifrost/plugins/governance v1.2.17 h1:i+9ZDYhuJBOS5hSkWf3v9dnrqUOtKLxrzw1H52YxDT8= @@ -313,8 +296,6 @@ go.mongodb.org/mongo-driver v1.14.0 h1:P98w8egYRjYe3XDjxhYJagTokP/H6HzlsnojRgZRd go.mongodb.org/mongo-driver v1.14.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0 h1:Hf9xI/XLML9ElpiHVDNwvqI0hIFlzV8dgIr35kV1kRU= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.62.0/go.mod h1:NfchwuyNoMcZ5MLHwPrODwUF1HWCXWrL31s8gSAdIKY= go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= @@ -334,8 +315,6 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20190422162423-af44ce270edf/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201216223049-8b5274cf687f/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod 
h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= @@ -348,8 +327,6 @@ golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190412183630-56d357773e84/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -382,8 +359,6 @@ golang.org/x/tools v0.0.0-20190416151739-9c9e1878f421/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190420181800-aa740d480789/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190531172133-b3315ee88b7d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genai v1.22.0 h1:5hrEhXXWJQZa3tdPocl4vQ/0w6myEAxdNns2Kmx0f4Y= -google.golang.org/genai v1.22.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M= google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a h1:tPE/Kp+x9dMSwUm/uM0JKK0IfdiJkwAbSMSeZBXXJXc= google.golang.org/genproto/googleapis/rpc v0.0.0-20250811230008-5f3141c8851a/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= diff 
--git a/ui/app/logs/page.tsx b/ui/app/logs/page.tsx index 653e5eba86..0e2e334355 100644 --- a/ui/app/logs/page.tsx +++ b/ui/app/logs/page.tsx @@ -9,7 +9,7 @@ import { Alert, AlertDescription } from "@/components/ui/alert"; import { Card, CardContent } from "@/components/ui/card"; import { useWebSocket } from "@/hooks/useWebSocket"; import { getErrorMessage, useLazyGetLogsQuery } from "@/lib/store"; -import type { BifrostMessage, ContentBlock, LogEntry, LogFilters, LogStats, MessageContent, Pagination } from "@/lib/types/logs"; +import type { ChatMessage, ContentBlock, LogEntry, LogFilters, LogStats, ChatMessageContent, Pagination } from "@/lib/types/logs"; import { AlertCircle, BarChart, CheckCircle, Clock, DollarSign, Hash } from "lucide-react"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; @@ -244,7 +244,7 @@ export default function LogsPage() { setInitialLoading(false); }, []); - const getMessageText = (content: MessageContent): string => { + const getMessageText = (content: ChatMessageContent): string => { if (typeof content === "string") { return content; } @@ -291,7 +291,7 @@ export default function LogsPage() { if (filters.content_search) { const search = filters.content_search.toLowerCase(); const content = [ - ...(log.input_history || []).map((msg: BifrostMessage) => getMessageText(msg.content)), + ...(log.input_history || []).map((msg: ChatMessage) => getMessageText(msg.content)), log.output_message ? 
getMessageText(log.output_message.content) : "", ] .join(" ") diff --git a/ui/app/logs/views/logMessageView.tsx b/ui/app/logs/views/logMessageView.tsx index 0b5cbcece9..32e02e2f85 100644 --- a/ui/app/logs/views/logMessageView.tsx +++ b/ui/app/logs/views/logMessageView.tsx @@ -1,8 +1,8 @@ -import { BifrostMessage } from "@/lib/types/logs"; +import { ChatMessage } from "@/lib/types/logs"; import { CodeEditor } from "./codeEditor"; interface LogMessageViewProps { - message: BifrostMessage; + message: ChatMessage; } const isJson = (text: string) => { diff --git a/ui/app/providers/fragments/apiKeysFormFragment.tsx b/ui/app/providers/fragments/apiKeysFormFragment.tsx index 1ebda00584..b0a20fd4e2 100644 --- a/ui/app/providers/fragments/apiKeysFormFragment.tsx +++ b/ui/app/providers/fragments/apiKeysFormFragment.tsx @@ -3,6 +3,7 @@ import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { FormControl, FormDescription, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form"; import { Input } from "@/components/ui/input"; +import { Switch } from "@/components/ui/switch"; import { Separator } from "@/components/ui/separator"; import { TagInput } from "@/components/ui/tagInput"; import { Textarea } from "@/components/ui/textarea"; @@ -27,6 +28,7 @@ const MODEL_PLACEHOLDERS = { }; export function ApiKeyFormFragment({ control, providerName, form }: Props) { + const isOpenAI = providerName === "openai"; const isBedrock = providerName === "bedrock"; const isVertex = providerName === "vertex"; const isAzure = providerName === "azure"; @@ -150,6 +152,30 @@ export function ApiKeyFormFragment({ control, providerName, form }: Props) { )} /> + {isOpenAI && ( +
+ ( + + +
+
+ +

Use the Responses API instead of the Chat Completion API.

+
+ +
+
+ +
+ )} + /> +
+ )} {isAzure && (
return KnownProvidersNames.includes(provider.toLowerCase() as KnownProvider); }; +export interface OpenAIKeyConfig { + use_responses_api: boolean; +} + +export const DefaultOpenAIKeyConfig: OpenAIKeyConfig = { + use_responses_api: false, +} as const satisfies Required; + // AzureKeyConfig matching Go's schemas.AzureKeyConfig export interface AzureKeyConfig { endpoint: string; @@ -68,6 +76,7 @@ export interface ModelProviderKey { value?: string; models?: string[]; weight: number; + openai_key_config?: OpenAIKeyConfig; azure_key_config?: AzureKeyConfig; vertex_key_config?: VertexKeyConfig; bedrock_key_config?: BedrockKeyConfig; diff --git a/ui/lib/types/logs.ts b/ui/lib/types/logs.ts index a3105f3b46..dc31d5926b 100644 --- a/ui/lib/types/logs.ts +++ b/ui/lib/types/logs.ts @@ -100,11 +100,11 @@ export interface ContentBlock { }; } -export type MessageContent = string | ContentBlock[]; +export type ChatMessageContent = string | ContentBlock[]; -export interface BifrostMessage { +export interface ChatMessage { role: "assistant" | "user" | "system" | "chatbot" | "tool"; - content: MessageContent; + content: ChatMessageContent; tool_call_id?: string; refusal?: string; annotations?: Annotation[]; @@ -237,8 +237,8 @@ export interface LogEntry { timestamp: string; // ISO string format from Go time.Time provider: string; model: string; - input_history: BifrostMessage[]; - output_message?: BifrostMessage; + input_history: ChatMessage[]; + output_message?: ChatMessage; embedding_output?: BifrostEmbedding[]; params?: ModelParameters; speech_input?: SpeechInput; diff --git a/ui/lib/types/schemas.ts b/ui/lib/types/schemas.ts index cea2666ef2..3b1c379806 100644 --- a/ui/lib/types/schemas.ts +++ b/ui/lib/types/schemas.ts @@ -12,6 +12,11 @@ export const customProviderNameSchema = z.string().min(1, "Custom provider name // Model provider name schema (union of known and custom providers) export const modelProviderNameSchema = z.union([knownProviderSchema, customProviderNameSchema]); 
+// OpenAI key config schema +export const openaiKeyConfigSchema = z.object({ + use_responses_api: z.boolean(), +}); + // Azure key config schema export const azureKeyConfigSchema = z.object({ endpoint: z.url("Must be a valid URL"), @@ -62,6 +67,7 @@ export const modelProviderKeySchema = z }) .pipe(z.number().min(0.1, "Weight must be greater than 0.1").max(1, "Weight must be less than 1")), ]), + openai_key_config: openaiKeyConfigSchema.optional(), azure_key_config: azureKeyConfigSchema.optional(), vertex_key_config: vertexKeyConfigSchema.optional(), bedrock_key_config: bedrockKeyConfigSchema.optional(),