diff --git a/core/mcp.go b/core/mcp.go index 43b98f17e7..21ee1661fc 100644 --- a/core/mcp.go +++ b/core/mcp.go @@ -899,9 +899,9 @@ func (m *MCPManager) addMCPToolsToBifrostRequest(ctx context.Context, req *schem req.Params = &schemas.ModelParameters{} } if req.Params.Tools == nil { - req.Params.Tools = &[]schemas.Tool{} + req.Params.Tools = []schemas.Tool{} } - tools := *req.Params.Tools + tools := req.Params.Tools // Create a map of existing tool names for O(1) lookup existingToolsMap := make(map[string]bool) @@ -917,7 +917,7 @@ func (m *MCPManager) addMCPToolsToBifrostRequest(ctx context.Context, req *schem existingToolsMap[mcpTool.Function.Name] = true } } - req.Params.Tools = &tools + req.Params.Tools = tools } return req diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index e8b01b9a74..fc69765b31 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -14,134 +14,13 @@ import ( "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/valyala/fasthttp" ) -// AnthropicToolChoice represents the tool choice configuration for Anthropic's API. -// It specifies how tools should be used in the completion request. -type AnthropicToolChoice struct { - Type schemas.ToolChoiceType `json:"type"` // Type of tool choice - Name *string `json:"name"` // Name of the tool to use - DisableParallelToolUse *bool `json:"disable_parallel_tool_use"` // Whether to disable parallel tool use -} - -// AnthropicTextResponse represents the response structure from Anthropic's text completion API. -// It includes the completion text, model information, and token usage statistics. 
-type AnthropicTextResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Type string `json:"type"` // Type of completion - Completion string `json:"completion"` // Generated completion text - Model string `json:"model"` // Model used for the completion - Usage struct { - InputTokens int `json:"input_tokens"` // Number of input tokens used - OutputTokens int `json:"output_tokens"` // Number of output tokens generated - } `json:"usage"` // Token usage statistics -} - -// AnthropicChatResponse represents the response structure from Anthropic's chat completion API. -// It includes message content, model information, and token usage statistics. -type AnthropicChatResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Type string `json:"type"` // Type of completion - Role string `json:"role"` // Role of the message sender - Content []struct { - Type string `json:"type"` // Type of content - Text string `json:"text,omitempty"` // Text content - Thinking string `json:"thinking,omitempty"` // Thinking process - ID string `json:"id"` // Content identifier - Name string `json:"name"` // Name of the content - Input map[string]interface{} `json:"input"` // Input parameters - } `json:"content"` // Array of content items - Model string `json:"model"` // Model used for the completion - StopReason string `json:"stop_reason,omitempty"` // Reason for completion termination - StopSequence *string `json:"stop_sequence,omitempty"` // Sequence that caused completion to stop - Usage struct { - InputTokens int `json:"input_tokens"` // Number of input tokens used - OutputTokens int `json:"output_tokens"` // Number of output tokens generated - } `json:"usage"` // Token usage statistics -} - -// AnthropicStreamEvent represents a single event in the Anthropic streaming response. -// It corresponds to the various event types defined in Anthropic's Messages API streaming documentation. 
-type AnthropicStreamEvent struct { - Type string `json:"type"` - Message *AnthropicStreamMessage `json:"message,omitempty"` - Index *int `json:"index,omitempty"` - ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` - Delta *AnthropicDelta `json:"delta,omitempty"` - Usage *schemas.LLMUsage `json:"usage,omitempty"` - Error *AnthropicStreamError `json:"error,omitempty"` -} - -// AnthropicStreamMessage represents the message structure in streaming events. -// This appears in message_start events and contains the initial message structure. -type AnthropicStreamMessage struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []AnthropicContentBlock `json:"content"` - Model string `json:"model"` - StopReason *string `json:"stop_reason"` - StopSequence *string `json:"stop_sequence"` - Usage *schemas.LLMUsage `json:"usage"` -} - -// AnthropicContentBlock represents a content block in Anthropic responses. -// This includes text, tool_use, thinking, and web_search_tool_result blocks. -type AnthropicContentBlock struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input map[string]interface{} `json:"input,omitempty"` - Thinking string `json:"thinking,omitempty"` - // Web search tool result specific fields - ToolUseID string `json:"tool_use_id,omitempty"` - Content []AnthropicToolContent `json:"content,omitempty"` -} - -// AnthropicToolContent represents content within tool result blocks -type AnthropicToolContent struct { - Type string `json:"type"` - Title string `json:"title,omitempty"` - URL string `json:"url,omitempty"` - EncryptedContent string `json:"encrypted_content,omitempty"` - PageAge *string `json:"page_age,omitempty"` -} - -// AnthropicDelta represents incremental updates to content blocks during streaming. -// This includes all delta types: text_delta, input_json_delta, thinking_delta, and signature_delta. 
-type AnthropicDelta struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - PartialJSON string `json:"partial_json,omitempty"` - Thinking string `json:"thinking,omitempty"` - Signature string `json:"signature,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` - StopSequence *string `json:"stop_sequence,omitempty"` -} - -// AnthropicStreamError represents error events in the streaming response. -type AnthropicStreamError struct { - Type string `json:"type"` - Message string `json:"message"` -} - -// AnthropicError represents the error response structure from Anthropic's API. -// It includes error type and message information. -type AnthropicError struct { - Type string `json:"type"` // always "error" - Error struct { - Type string `json:"type"` // Error type - Message string `json:"message"` // Error message - } `json:"error"` // Error details -} - -type AnthropicImageContent struct { - Type ImageContentType `json:"type"` - URL string `json:"url"` - MediaType string `json:"media_type,omitempty"` -} +const ( + DEFAULT_MAX_TOKENS = 4096 +) // AnthropicProvider implements the Provider interface for Anthropic's Claude API. type AnthropicProvider struct { @@ -156,40 +35,40 @@ type AnthropicProvider struct { // anthropicChatResponsePool provides a pool for Anthropic chat response objects. var anthropicChatResponsePool = sync.Pool{ New: func() interface{} { - return &AnthropicChatResponse{} + return &api.AnthropicChatResponse{} }, } // anthropicTextResponsePool provides a pool for Anthropic text response objects. var anthropicTextResponsePool = sync.Pool{ New: func() interface{} { - return &AnthropicTextResponse{} + return &api.AnthropicTextResponse{} }, } // acquireAnthropicChatResponse gets an Anthropic chat response from the pool and resets it. 
-func acquireAnthropicChatResponse() *AnthropicChatResponse { - resp := anthropicChatResponsePool.Get().(*AnthropicChatResponse) - *resp = AnthropicChatResponse{} // Reset the struct +func acquireAnthropicChatResponse() *api.AnthropicChatResponse { + resp := anthropicChatResponsePool.Get().(*api.AnthropicChatResponse) + *resp = api.AnthropicChatResponse{} // Reset the struct return resp } // releaseAnthropicChatResponse returns an Anthropic chat response to the pool. -func releaseAnthropicChatResponse(resp *AnthropicChatResponse) { +func releaseAnthropicChatResponse(resp *api.AnthropicChatResponse) { if resp != nil { anthropicChatResponsePool.Put(resp) } } // acquireAnthropicTextResponse gets an Anthropic text response from the pool and resets it. -func acquireAnthropicTextResponse() *AnthropicTextResponse { - resp := anthropicTextResponsePool.Get().(*AnthropicTextResponse) - *resp = AnthropicTextResponse{} // Reset the struct +func acquireAnthropicTextResponse() *api.AnthropicTextResponse { + resp := anthropicTextResponsePool.Get().(*api.AnthropicTextResponse) + *resp = api.AnthropicTextResponse{} // Reset the struct return resp } // releaseAnthropicTextResponse returns an Anthropic text response to the pool. 
-func releaseAnthropicTextResponse(resp *AnthropicTextResponse) { +func releaseAnthropicTextResponse(resp *api.AnthropicTextResponse) { if resp != nil { anthropicTextResponsePool.Put(resp) } @@ -214,8 +93,8 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger) // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { - anthropicTextResponsePool.Put(&AnthropicTextResponse{}) - anthropicChatResponsePool.Put(&AnthropicChatResponse{}) + anthropicTextResponsePool.Put(&api.AnthropicTextResponse{}) + anthropicChatResponsePool.Put(&api.AnthropicChatResponse{}) } @@ -243,28 +122,208 @@ func (provider *AnthropicProvider) GetProviderKey() schemas.ModelProvider { return schemas.Anthropic } -// prepareTextCompletionParams prepares text completion parameters for Anthropic's API. -// It handles parameter mapping and conversion to the format expected by Anthropic. -// Returns the modified parameters map. -func (provider *AnthropicProvider) prepareTextCompletionParams(params map[string]interface{}) map[string]interface{} { - // Check if there is a key entry for max_tokens - if maxTokens, exists := params["max_tokens"]; exists { - // Check if max_tokens_to_sample is already present - if _, exists := params["max_tokens_to_sample"]; !exists { - // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample - params["max_tokens_to_sample"] = maxTokens +// buildAnthropicTextRequest creates a type-safe Anthropic text completion request +// from Bifrost text input and parameters. 
+func buildAnthropicTextRequest(model string, text string, params *schemas.ModelParameters) *api.AnthropicTextRequest { + // Format the prompt with Anthropic's expected format + prompt := fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text) + + // Build the request + request := &api.AnthropicTextRequest{ + Model: model, + Prompt: prompt, + } + + // Add parameters if provided + if params != nil { + request.MaxTokensToSample = DEFAULT_MAX_TOKENS + if params.MaxTokens != nil { + request.MaxTokensToSample = *params.MaxTokens + } + request.Temperature = params.Temperature + request.TopP = params.TopP + request.TopK = params.TopK + request.StopSequences = params.StopSequences + + if params.ExtraParams != nil { + request.ExtraParams = params.ExtraParams } - delete(params, "max_tokens") } - return params + + return request +} + +// buildAnthropicChatRequest creates a type-safe Anthropic chat completion request +// from Bifrost messages and parameters. +func buildAnthropicChatRequest(model string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) *api.AnthropicMessageRequest { + // Convert Bifrost messages to Anthropic format + var anthropicMessages []api.AnthropicMessage + var systemContent *api.AnthropicContent + + for _, msg := range messages { + if msg.Role == schemas.ModelChatMessageRoleSystem { + // Handle system messages separately + if msg.Content.ContentStr != nil { + systemContent = &api.AnthropicContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.Content.ContentBlocks != nil { + // Convert content blocks to Anthropic format + contentBlocks := make([]api.AnthropicContentBlock, 0, len(*msg.Content.ContentBlocks)) + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + contentBlocks = append(contentBlocks, api.AnthropicContentBlock{ + Type: "text", + Text: block.Text, + }) + } + } + if len(contentBlocks) > 0 { + systemContent = &api.AnthropicContent{ + ContentBlocks: &contentBlocks, + } + } + } + } else { + // 
Convert regular messages + anthropicMsg := api.AnthropicMessage{ + Role: string(msg.Role), + } + + if msg.Content.ContentStr != nil { + anthropicMsg.Content = api.AnthropicContent{ + ContentStr: msg.Content.ContentStr, + } + } else if msg.Content.ContentBlocks != nil { + // Convert content blocks to Anthropic format + contentBlocks := make([]api.AnthropicContentBlock, 0, len(*msg.Content.ContentBlocks)) + for _, block := range *msg.Content.ContentBlocks { + if block.Text != nil { + contentBlocks = append(contentBlocks, api.AnthropicContentBlock{ + Type: "text", + Text: block.Text, + }) + } + if block.ImageURL != nil { + // Handle image content + imageSource := buildAnthropicImageSourceMap(block.ImageURL) + contentBlocks = append(contentBlocks, api.AnthropicContentBlock{ + Type: "image", + Source: imageSource, + }) + } + } + if len(contentBlocks) > 0 { + anthropicMsg.Content = api.AnthropicContent{ + ContentBlocks: &contentBlocks, + } + } + } + + anthropicMessages = append(anthropicMessages, anthropicMsg) + } + } + + // Build the request + request := &api.AnthropicMessageRequest{ + Model: model, + Messages: anthropicMessages, + } + + // Add system content if present + if systemContent != nil { + request.System = systemContent + } + + // Add parameters if provided + if params != nil { + request.MaxTokens = DEFAULT_MAX_TOKENS // Default value, will be set below if provided + if params.MaxTokens != nil { + request.MaxTokens = *params.MaxTokens + } + request.Temperature = params.Temperature + request.TopP = params.TopP + request.TopK = params.TopK + request.StopSequences = params.StopSequences + + // Handle tools if present + if params.Tools != nil { + tools := make([]api.AnthropicTool, 0, len(params.Tools)) + for _, tool := range params.Tools { + anthropicTool := api.AnthropicTool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + } + + // Convert function parameters to input schema + if tool.Function.Parameters.Type != "" { + 
anthropicTool.InputSchema = &struct { + Type string `json:"type"` + Properties map[string]interface{} `json:"properties"` + Required []string `json:"required"` + }{ + Type: tool.Function.Parameters.Type, + Properties: tool.Function.Parameters.Properties, + Required: tool.Function.Parameters.Required, + } + } + + tools = append(tools, anthropicTool) + } + request.Tools = tools + } + + // Handle tool choice if present + if params.ToolChoice != nil { + if params.ToolChoice.ToolChoiceStr != nil { + request.ToolChoice = &api.AnthropicToolChoice{ + Type: schemas.ToolChoiceType(*params.ToolChoice.ToolChoiceStr), + } + } else if params.ToolChoice.ToolChoiceStruct != nil { + toolChoice := &api.AnthropicToolChoice{ + Type: params.ToolChoice.ToolChoiceStruct.Type, + } + if params.ToolChoice.ToolChoiceStruct.Function != nil && params.ToolChoice.ToolChoiceStruct.Function.Name != "" { + toolChoice.Name = ¶ms.ToolChoice.ToolChoiceStruct.Function.Name + } + request.ToolChoice = toolChoice + } + } + + // Handle extra parameters by mapping them to specific fields + if params.ExtraParams != nil { + // Extract known fields and set them on the typed struct + if anthropicVersion, ok := params.ExtraParams["anthropic_version"].(string); ok { + request.AnthropicVersion = &anthropicVersion + delete(params.ExtraParams, "anthropic_version") + } + if region, ok := params.ExtraParams["region"].(string); ok { + request.Region = ®ion + delete(params.ExtraParams, "region") + } + + // Pass all ExtraParams to the request - MarshalJSON will handle conflict detection + // This ensures unknown fields can still be passed through to the API + request.ExtraParams = params.ExtraParams + } + } + + return request } // completeRequest sends a request to Anthropic's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. 
-func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, url string, key string) ([]byte, *schemas.BifrostError) { - // Marshal the request body - jsonData, err := sonic.Marshal(requestBody) +func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestBody api.AnthropicRequestConfig, key string) ([]byte, *schemas.BifrostError) { + + var jsonData []byte + var err error + if requestBody.AnthropicTextRequest != nil { + jsonData, err = sonic.Marshal(requestBody.AnthropicTextRequest) + } else if requestBody.AnthropicMessageRequest != nil { + jsonData, err = sonic.Marshal(requestBody.AnthropicMessageRequest) + } + if err != nil { return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Anthropic) } @@ -278,7 +337,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB // Set any extra headers from network config setExtraHeaders(req, provider.networkConfig.ExtraHeaders, nil) - req.SetRequestURI(url) + req.SetRequestURI(requestBody.URL) req.Header.SetMethod("POST") req.Header.SetContentType("application/json") req.Header.Set("x-api-key", key) @@ -296,7 +355,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB if resp.StatusCode() != fasthttp.StatusOK { provider.logger.Debug(fmt.Sprintf("error from anthropic provider: %s", string(resp.Body()))) - var errorResp AnthropicError + var errorResp api.AnthropicError bifrostErr := handleProviderAPIError(resp, &errorResp) bifrostErr.Error.Type = &errorResp.Error.Type @@ -315,15 +374,14 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - preparedParams := provider.prepareTextCompletionParams(prepareParams(params)) + request := buildAnthropicTextRequest(model, text, params) - // Merge additional parameters - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text), - }, preparedParams) + requestBody := api.AnthropicRequestConfig{ + URL: provider.networkConfig.BaseURL + "/v1/complete", + AnthropicTextRequest: request, + } - responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value) + responseBody, err := provider.completeRequest(ctx, requestBody, key.Value) if err != nil { return nil, err } @@ -380,15 +438,15 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model str // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) + // Build type-safe request + request := buildAnthropicChatRequest(model, messages, params) - // Merge additional parameters - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + requestBody := api.AnthropicRequestConfig{ + URL: provider.networkConfig.BaseURL + "/v1/messages", + AnthropicMessageRequest: request, + } - responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value) + responseBody, err := provider.completeRequest(ctx, requestBody, key.Value) if err != nil { return nil, err } @@ -426,7 +484,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model str } // buildAnthropicImageSourceMap creates the "source" map for an Anthropic image content part. 
-func buildAnthropicImageSourceMap(imgContent *schemas.ImageURLStruct) map[string]interface{} { +func buildAnthropicImageSourceMap(imgContent *schemas.ImageURLStruct) *api.AnthropicImageSource { if imgContent == nil { return nil } @@ -434,273 +492,24 @@ func buildAnthropicImageSourceMap(imgContent *schemas.ImageURLStruct) map[string sanitizedURL, _ := SanitizeImageURL(imgContent.URL) urlTypeInfo := ExtractURLTypeInfo(sanitizedURL) - formattedImgContent := AnthropicImageContent{ - Type: urlTypeInfo.Type, + imageSource := &api.AnthropicImageSource{ + Type: string(urlTypeInfo.Type), } if urlTypeInfo.MediaType != nil { - formattedImgContent.MediaType = *urlTypeInfo.MediaType + imageSource.MediaType = urlTypeInfo.MediaType } if urlTypeInfo.DataURLWithoutPrefix != nil { - formattedImgContent.URL = *urlTypeInfo.DataURLWithoutPrefix + imageSource.Data = urlTypeInfo.DataURLWithoutPrefix } else { - formattedImgContent.URL = sanitizedURL - } - - sourceMap := map[string]interface{}{ - "type": string(formattedImgContent.Type), // "base64" or "url" + imageSource.URL = &sanitizedURL } - if formattedImgContent.Type == ImageContentTypeURL { - sourceMap["url"] = formattedImgContent.URL - } else { - if formattedImgContent.MediaType != "" { - sourceMap["media_type"] = formattedImgContent.MediaType - } - sourceMap["data"] = formattedImgContent.URL // URL field contains base64 data string - } - return sourceMap + return imageSource } -func prepareAnthropicChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { - // Add system messages if present - var systemMessages []BedrockAnthropicSystemMessage - for _, msg := range messages { - if msg.Role == schemas.ModelChatMessageRoleSystem { - if msg.Content.ContentStr != nil { - systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ - Text: *msg.Content.ContentStr, - }) - } else if msg.Content.ContentBlocks != nil { - for _, block := range 
*msg.Content.ContentBlocks { - if block.Text != nil { - systemMessages = append(systemMessages, BedrockAnthropicSystemMessage{ - Text: *block.Text, - }) - } - } - } - } - } - - // Format messages for Anthropic API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - var content []interface{} - - if msg.Role != schemas.ModelChatMessageRoleSystem { - if msg.Role == schemas.ModelChatMessageRoleTool && msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { - toolCallResult := map[string]interface{}{ - "type": "tool_result", - "tool_use_id": *msg.ToolMessage.ToolCallID, - } - - var toolCallResultContent []map[string]interface{} - - if msg.Content.ContentStr != nil { - toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ - "type": "text", - "text": *msg.Content.ContentStr, - }) - } else if msg.Content.ContentBlocks != nil { - for _, block := range *msg.Content.ContentBlocks { - if block.Text != nil { - toolCallResultContent = append(toolCallResultContent, map[string]interface{}{ - "type": "text", - "text": *block.Text, - }) - } - } - } - - toolCallResult["content"] = toolCallResultContent - content = append(content, toolCallResult) - } else { - // Add text content if present - if msg.Content.ContentStr != nil && *msg.Content.ContentStr != "" { - content = append(content, map[string]interface{}{ - "type": "text", - "text": *msg.Content.ContentStr, - }) - } else if msg.Content.ContentBlocks != nil { - for _, block := range *msg.Content.ContentBlocks { - if block.Text != nil && *block.Text != "" { - content = append(content, map[string]interface{}{ - "type": "text", - "text": *block.Text, - }) - } - if block.ImageURL != nil { - imageSource := buildAnthropicImageSourceMap(block.ImageURL) - if imageSource != nil { - content = append(content, map[string]interface{}{ - "type": "image", - "source": imageSource, - }) - } - } - } - } - - // Add thinking content if present in AssistantMessage - if msg.AssistantMessage 
!= nil && msg.AssistantMessage.Thought != nil { - content = append(content, map[string]interface{}{ - "type": "thinking", - "thinking": *msg.AssistantMessage.Thought, - }) - } - - // Add tool calls as content if present - if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { - for _, toolCall := range *msg.AssistantMessage.ToolCalls { - if toolCall.Function.Name != nil { - var input map[string]interface{} - if toolCall.Function.Arguments != "" { - if err := sonic.Unmarshal([]byte(toolCall.Function.Arguments), &input); err != nil { - // If unmarshaling fails, use a simple string representation - input = map[string]interface{}{"arguments": toolCall.Function.Arguments} - } - } - - toolUseContent := map[string]interface{}{ - "type": "tool_use", - "name": *toolCall.Function.Name, - "input": input, - } - - if toolCall.ID != nil { - toolUseContent["id"] = *toolCall.ID - } - - content = append(content, toolUseContent) - } - } - } - } - - if len(content) > 0 { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } - } - } - - preparedParams := prepareParams(params) - - // Transform tools if present - if params != nil && params.Tools != nil && len(*params.Tools) > 0 { - var tools []map[string]interface{} - for _, tool := range *params.Tools { - tools = append(tools, map[string]interface{}{ - "name": tool.Function.Name, - "description": tool.Function.Description, - "input_schema": tool.Function.Parameters, - }) - } - - preparedParams["tools"] = tools - } - - // Transform tool choice if present - if params != nil && params.ToolChoice != nil { - if params.ToolChoice.ToolChoiceStr != nil { - preparedParams["tool_choice"] = map[string]interface{}{ - "type": *params.ToolChoice.ToolChoiceStr, - } - } else if params.ToolChoice.ToolChoiceStruct != nil { - switch toolChoice := params.ToolChoice.ToolChoiceStruct.Type; toolChoice { - case schemas.ToolChoiceTypeFunction: - fallthrough - case "tool": - 
preparedParams["tool_choice"] = map[string]interface{}{ - "type": "tool", - "name": params.ToolChoice.ToolChoiceStruct.Function.Name, - } - default: - preparedParams["tool_choice"] = map[string]interface{}{ - "type": toolChoice, - } - } - } - } - - if len(systemMessages) > 0 { - var messages []string - for _, message := range systemMessages { - messages = append(messages, message.Text) - } - - preparedParams["system"] = strings.Join(messages, " ") - } - - // Post-process formattedMessages for tool call results - processedFormattedMessages := []map[string]interface{}{} // Use a new slice - i := 0 - for i < len(formattedMessages) { - currentMsg := formattedMessages[i] - currentRole, roleOk := getRoleFromMessage(currentMsg) - - if !roleOk || currentRole == "" { - // If role is of an unexpected type, missing, or empty, treat as non-tool message - processedFormattedMessages = append(processedFormattedMessages, currentMsg) - i++ - continue - } - - if currentRole == schemas.ModelChatMessageRoleTool { - // Content of a tool message is the toolCallResult map - // Initialize accumulatedToolResults with the content of the current tool message. 
- var accumulatedToolResults []interface{} - - // Safely extract content from current message - if content, ok := currentMsg["content"].([]interface{}); ok { - accumulatedToolResults = content - } else { - // If content is not the expected type, skip this message - processedFormattedMessages = append(processedFormattedMessages, currentMsg) - i++ - continue - } - - // Look ahead for more sequential tool messages - j := i + 1 - for j < len(formattedMessages) { - nextMsg := formattedMessages[j] - nextRole, nextRoleOk := getRoleFromMessage(nextMsg) - - if !nextRoleOk || nextRole == "" || nextRole != schemas.ModelChatMessageRoleTool { - break // Not a sequential tool message or role is invalid/missing/empty - } - - // Safely extract content from next message - if nextContent, ok := nextMsg["content"].([]interface{}); ok { - accumulatedToolResults = append(accumulatedToolResults, nextContent...) - } - j++ - } - - // Create a new message with role User and accumulated content - mergedMsg := map[string]interface{}{ - "role": schemas.ModelChatMessageRoleUser, // Final role is User - "content": accumulatedToolResults, - } - processedFormattedMessages = append(processedFormattedMessages, mergedMsg) - i = j // Advance main loop index past all merged messages - } else { - // Not a tool message, add it as is - processedFormattedMessages = append(processedFormattedMessages, currentMsg) - i++ - } - } - formattedMessages = processedFormattedMessages // Update with processed messages - - return formattedMessages, preparedParams -} - -func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *schemas.BifrostResponse) (*schemas.BifrostResponse, *schemas.BifrostError) { +func parseAnthropicResponse(response *api.AnthropicChatResponse, bifrostResponse *schemas.BifrostResponse) (*schemas.BifrostResponse, *schemas.BifrostError) { // Collect all content and tool calls into a single message var toolCalls []schemas.ToolCall var thinking string @@ -729,7 +538,7 @@ func 
parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *sc } toolCalls = append(toolCalls, schemas.ToolCall{ - Type: StrPtr("function"), + Type: ptr("function"), ID: &c.ID, Function: function, }) @@ -787,14 +596,12 @@ func (provider *AnthropicProvider) Embedding(ctx context.Context, model string, // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) + // Build type-safe request + requestBody := buildAnthropicChatRequest(model, messages, params) - // Merge additional parameters and set stream to true - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + // Set streaming flag + stream := true + requestBody.Stream = &stream // Prepare Anthropic headers headers := map[string]string{ @@ -826,7 +633,7 @@ func handleAnthropicStreaming( ctx context.Context, httpClient *http.Client, url string, - requestBody map[string]interface{}, + requestBody *api.AnthropicMessageRequest, headers map[string]string, extraHeaders map[string]string, providerType schemas.ModelProvider, @@ -911,7 +718,7 @@ func handleAnthropicStreaming( // Handle different event types switch eventType { case "message_start": - var event AnthropicStreamEvent + var event api.AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_start event: %v", err)) continue @@ -922,7 +729,7 @@ func handleAnthropicStreaming( } case 
"content_block_start": - var event AnthropicStreamEvent + var event api.AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse content_block_start event: %v", err)) continue @@ -933,7 +740,7 @@ func handleAnthropicStreaming( switch event.ContentBlock.Type { case "tool_use": // Tool use content block initialization - if event.ContentBlock.Name != "" && event.ContentBlock.ID != "" { + if event.ContentBlock.Name != nil && event.ContentBlock.ID != nil { // Create streaming response for tool start streamResponse := &schemas.BifrostResponse{ ID: messageID, @@ -947,9 +754,9 @@ func handleAnthropicStreaming( ToolCalls: []schemas.ToolCall{ { Type: func() *string { s := "function"; return &s }(), - ID: &event.ContentBlock.ID, + ID: event.ContentBlock.ID, Function: schemas.FunctionCall{ - Name: &event.ContentBlock.Name, + Name: event.ContentBlock.Name, }, }, }, @@ -970,13 +777,9 @@ func handleAnthropicStreaming( processAndSendResponse(ctx, postHookRunner, streamResponse, responseChan) } default: - thought := "" - if event.ContentBlock.Thinking != "" { - thought = event.ContentBlock.Thinking - } content := "" - if event.ContentBlock.Text != "" { - content = event.ContentBlock.Text + if event.ContentBlock.Text != nil { + content = *event.ContentBlock.Text } // Send empty message for other content block types @@ -989,7 +792,7 @@ func handleAnthropicStreaming( Index: *event.Index, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ Delta: schemas.BifrostStreamDelta{ - Thought: &thought, + Thought: &content, Content: &content, }, }, @@ -1010,7 +813,7 @@ func handleAnthropicStreaming( } case "content_block_delta": - var event AnthropicStreamEvent + var event api.AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse content_block_delta event: %v", err)) continue @@ -1130,7 +933,7 @@ func handleAnthropicStreaming( continue 
case "message_delta": - var event AnthropicStreamEvent + var event api.AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_delta event: %v", err)) continue @@ -1168,7 +971,7 @@ func handleAnthropicStreaming( } case "message_stop": - var event AnthropicStreamEvent + var event api.AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse message_stop event: %v", err)) continue @@ -1211,7 +1014,7 @@ func handleAnthropicStreaming( continue case "error": - var event AnthropicStreamEvent + var event api.AnthropicStreamEvent if err := sonic.Unmarshal([]byte(eventData), &event); err != nil { logger.Warn(fmt.Sprintf("Failed to parse error event: %v", err)) continue diff --git a/core/providers/azure.go b/core/providers/azure.go index e41015d3e2..d94c852b2e 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -31,18 +31,6 @@ type AzureTextResponse struct { Usage schemas.LLMUsage `json:"usage"` // Token usage statistics } -// AzureChatResponse represents the response structure from Azure's chat completion API. -// It includes completion choices, model information, and usage statistics. -type AzureChatResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Object string `json:"object"` // Type of completion (always "chat.completion") - Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices - Model string `json:"model"` // Model used for the completion - Created int `json:"created"` // Unix timestamp of completion creation - SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request - Usage schemas.LLMUsage `json:"usage"` // Token usage statistics -} - // AzureEmbeddingResponse represents the response structure from Azure's embedding API. 
type AzureEmbeddingResponse struct { Object string `json:"object"` @@ -79,19 +67,19 @@ var azureTextCompletionResponsePool = sync.Pool{ // azureChatResponsePool provides a pool for Azure chat response objects. var azureChatResponsePool = sync.Pool{ New: func() interface{} { - return &AzureChatResponse{} + return &schemas.BifrostResponse{} }, } // acquireAzureChatResponse gets an Azure chat response from the pool and resets it. -func acquireAzureChatResponse() *AzureChatResponse { - resp := azureChatResponsePool.Get().(*AzureChatResponse) - *resp = AzureChatResponse{} // Reset the struct +func acquireAzureChatResponse() *schemas.BifrostResponse { + resp := azureChatResponsePool.Get().(*schemas.BifrostResponse) + *resp = schemas.BifrostResponse{} // Reset the struct return resp } // releaseAzureChatResponse returns an Azure chat response to the pool. -func releaseAzureChatResponse(resp *AzureChatResponse) { +func releaseAzureChatResponse(resp *schemas.BifrostResponse) { if resp != nil { azureChatResponsePool.Put(resp) } @@ -139,7 +127,7 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { - azureChatResponsePool.Put(&AzureChatResponse{}) + azureChatResponsePool.Put(&schemas.BifrostResponse{}) azureTextCompletionResponsePool.Put(&AzureTextResponse{}) } @@ -164,7 +152,7 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider { // completeRequest sends a request to Azure's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. 
-func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key, model string) ([]byte, *schemas.BifrostError) { +func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody interface{}, path string, key schemas.Key, model string) ([]byte, *schemas.BifrostError) { if key.AzureKeyConfig == nil { return nil, newConfigurationError("azure key config not set", schemas.Azure) } @@ -189,7 +177,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = StrPtr("2024-02-01") + apiVersion = ptr("2024-02-01") } url = fmt.Sprintf("%s/openai/deployments/%s/%s?api-version=%s", url, deployment, path, *apiVersion) @@ -320,13 +308,7 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model string, // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) - - // Merge additional parameters - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) responseBody, err := provider.completeRequest(ctx, requestBody, "chat/completions", key, model) if err != nil { @@ -342,29 +324,18 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string, return nil, bifrostErr } - // Create final response - bifrostResponse := &schemas.BifrostResponse{ - ID: response.ID, - Choices: response.Choices, - Model: response.Model, - Created: response.Created, - SystemFingerprint: response.SystemFingerprint, - Usage: &response.Usage, - ExtraFields: schemas.BifrostResponseExtraFields{ - Provider: schemas.Azure, - }, - } + response.ExtraFields.Provider = schemas.Azure // Set raw response if enabled if provider.sendBackRawResponse { - bifrostResponse.ExtraFields.RawResponse = rawResponse + response.ExtraFields.RawResponse = rawResponse } if params != nil { - bifrostResponse.ExtraFields.Params = *params + response.ExtraFields.Params = *params } - return bifrostResponse, nil + return response, nil } // Embedding generates embeddings for the given input text(s) using Azure OpenAI. @@ -461,18 +432,14 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) if key.AzureKeyConfig == nil { return nil, newConfigurationError("azure key config not set", schemas.Azure) } + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) - // Merge additional parameters and set stream to true - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + stream := true + requestBody.Stream = &stream // Construct Azure-specific URL with deployment if key.AzureKeyConfig.Endpoint == "" { @@ -490,7 +457,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { - apiVersion = StrPtr("2024-02-01") + apiVersion = ptr("2024-02-01") } fullURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", baseURL, deployment, *apiVersion) diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 8e2f3a38e3..aac6a5a15a 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -25,6 +25,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" ) // BedrockAnthropicTextResponse represents the response structure from Bedrock's Anthropic text completion API. 
@@ -277,7 +278,7 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Type: StrPtr(schemas.RequestCancelled), + Type: ptr(schemas.RequestCancelled), Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), Error: err, }, @@ -583,7 +584,7 @@ func (provider *BedrockProvider) prepareChatCompletionMessages(messages []schema sanitizedURL, _ := SanitizeImageURL(block.ImageURL.URL) urlTypeInfo := ExtractURLTypeInfo(sanitizedURL) - formattedImgContent := AnthropicImageContent{ + formattedImgContent := api.AnthropicImageContent{ Type: urlTypeInfo.Type, } @@ -777,7 +778,7 @@ func (provider *BedrockProvider) getChatCompletionTools(params *schemas.ModelPar case "anthropic.claude-3-opus-20240229-v1:0": fallthrough case "anthropic.claude-3-7-sonnet-20250219-v1:0": - for _, tool := range *params.Tools { + for _, tool := range params.Tools { tools = append(tools, BedrockAnthropicToolCall{ ToolSpec: BedrockAnthropicToolSpec{ Name: tool.Function.Name, @@ -916,7 +917,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model strin preparedParams := prepareParams(params) // Transform tools if present - if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + if params != nil && params.Tools != nil && len(params.Tools) > 0 { preparedParams["toolConfig"] = map[string]interface{}{ "tools": provider.getChatCompletionTools(params, model), } @@ -987,7 +988,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model strin } toolCalls = append(toolCalls, schemas.ToolCall{ - Type: StrPtr("function"), + Type: ptr("function"), ID: &choice.ToolUse.ToolUseID, Function: schemas.FunctionCall{ Name: &choice.ToolUse.Name, @@ -1257,7 +1258,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH preparedParams := prepareParams(params) // Transform tools if present - if params != 
nil && params.Tools != nil && len(*params.Tools) > 0 { + if params != nil && params.Tools != nil && len(params.Tools) > 0 { preparedParams["toolConfig"] = map[string]interface{}{ "tools": provider.getChatCompletionTools(params, model), } diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 0135cccc59..d9607b8ee4 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -325,8 +325,8 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model string ExtraFields: schemas.BifrostResponseExtraFields{ Provider: schemas.Cohere, BilledUsage: &schemas.BilledLLMUsage{ - PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens), - CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens), + PromptTokens: ptr(response.Meta.BilledUnits.InputTokens), + CompletionTokens: ptr(response.Meta.BilledUnits.OutputTokens), }, ChatHistory: convertChatHistory(response.ChatHistory), }, @@ -491,9 +491,9 @@ func prepareCohereChatRequest(messages []schemas.BifrostMessage, params *schemas } // Add tools if present - if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + if params != nil && params.Tools != nil && len(params.Tools) > 0 { var tools []CohereTool - for _, tool := range *params.Tools { + for _, tool := range params.Tools { parameterDefinitions := make(map[string]CohereParameterDefinition) params := tool.Function.Parameters for name, prop := range tool.Function.Parameters.Properties { @@ -808,7 +808,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ Delta: schemas.BifrostStreamDelta{ - Role: StrPtr(string(schemas.ModelChatMessageRoleAssistant)), + Role: ptr(string(schemas.ModelChatMessageRoleAssistant)), }, }, }, @@ -939,7 +939,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo Index: 0, BifrostStreamResponseChoice: &schemas.BifrostStreamResponseChoice{ Delta: 
schemas.BifrostStreamDelta{ - Role: StrPtr(string(schemas.ModelChatMessageRoleAssistant)), + Role: ptr(string(schemas.ModelChatMessageRoleAssistant)), Content: &stopEvent.Response.Text, ToolCalls: toolCalls, }, diff --git a/core/providers/groq.go b/core/providers/groq.go index 1c03fe7462..c714907d93 100644 --- a/core/providers/groq.go +++ b/core/providers/groq.go @@ -97,12 +97,7 @@ func (provider *GroqProvider) TextCompletion(ctx context.Context, model string, // ChatCompletion performs a chat completion request to the Groq API. func (provider *GroqProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) - - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) jsonBody, err := sonic.Marshal(requestBody) if err != nil { @@ -177,13 +172,10 @@ func (provider *GroqProvider) Embedding(ctx context.Context, model string, key s // Uses Groq's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + stream := true + requestBody.Stream = &stream // Prepare Groq headers (Groq typically doesn't require authorization, but we include it if provided) headers := map[string]string{ diff --git a/core/providers/mistral.go b/core/providers/mistral.go index a32c30879b..55e86a8303 100644 --- a/core/providers/mistral.go +++ b/core/providers/mistral.go @@ -111,12 +111,7 @@ func (provider *MistralProvider) TextCompletion(ctx context.Context, model strin // ChatCompletion performs a chat completion request to the Mistral API. func (provider *MistralProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) - - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) jsonBody, err := sonic.Marshal(requestBody) if err != nil { @@ -301,13 +296,10 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke // Uses Mistral's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + stream := true + requestBody.Stream = &stream // Prepare Mistral headers headers := map[string]string{ diff --git a/core/providers/ollama.go b/core/providers/ollama.go index fe994cecab..43dc5918c6 100644 --- a/core/providers/ollama.go +++ b/core/providers/ollama.go @@ -98,12 +98,7 @@ func (provider *OllamaProvider) TextCompletion(ctx context.Context, model string // ChatCompletion performs a chat completion request to the Ollama API. func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) - - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) jsonBody, err := sonic.Marshal(requestBody) if err != nil { @@ -179,13 +174,10 @@ func (provider *OllamaProvider) Embedding(ctx context.Context, model string, key // Uses Ollama's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + stream := true + requestBody.Stream = &stream // Prepare Ollama headers (Ollama typically doesn't require authorization, but we include it if provided) headers := map[string]string{ diff --git a/core/providers/openai.go b/core/providers/openai.go index dd5cef215d..31d7e323cd 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -19,27 +19,10 @@ import ( "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/valyala/fasthttp" ) -// OpenAIResponse represents the response structure from the OpenAI API. -// It includes completion choices, model information, and usage statistics. 
-type OpenAIResponse struct { - ID string `json:"id"` // Unique identifier for the completion - Object string `json:"object"` // Type of completion (text.completion, chat.completion, or embedding) - Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices - Data []struct { // Embedding data - Object string `json:"object"` - Embedding any `json:"embedding"` - Index int `json:"index"` - } `json:"data,omitempty"` - Model string `json:"model"` // Model used for the completion - Created int `json:"created"` // Unix timestamp of completion creation - ServiceTier *string `json:"service_tier"` // Service tier used for the request - SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request - Usage schemas.LLMUsage `json:"usage"` // Token usage statistics -} - // openAIResponsePool provides a pool for OpenAI response objects. var openAIResponsePool = sync.Pool{ New: func() interface{} { @@ -121,16 +104,152 @@ func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model string return nil, newUnsupportedOperationError("text completion", "openai") } +// buildChatCompletionRequest creates a type-safe OpenAI chat completion request +// from Bifrost messages and parameters. 
+func buildOpenAIChatCompletionRequest(model string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) *api.OpenAIChatRequest { + // Process messages to sanitize image URLs + processedMessages := make([]schemas.BifrostMessage, len(messages)) + for i, msg := range messages { + processedMessages[i] = msg // Copy the message + + // Sanitize image URLs in content blocks + if msg.Content.ContentBlocks != nil { + contentBlocks := *msg.Content.ContentBlocks + for j := range contentBlocks { + if contentBlocks[j].Type == schemas.ContentBlockTypeImage && contentBlocks[j].ImageURL != nil { + sanitizedURL, _ := SanitizeImageURL(contentBlocks[j].ImageURL.URL) + contentBlocks[j].ImageURL.URL = sanitizedURL + } + } + processedMessages[i].Content.ContentBlocks = &contentBlocks + } + } + + // Build the request + request := &api.OpenAIChatRequest{ + Model: model, + Messages: processedMessages, + Stream: ptr(false), + } + + // Add parameters if provided + if params != nil { + request.MaxTokens = params.MaxTokens + request.Temperature = params.Temperature + request.TopP = params.TopP + request.PresencePenalty = params.PresencePenalty + request.FrequencyPenalty = params.FrequencyPenalty + request.Tools = params.Tools + request.ToolChoice = params.ToolChoice + request.LogProbs = params.Logprobs + request.User = params.User + request.N = params.N + + // Handle extra parameters + if params.ExtraParams != nil { + if stop, ok := params.ExtraParams["stop"]; ok { + request.Stop = stop + delete(params.ExtraParams, "stop") + } + if logitBias, ok := params.ExtraParams["logit_bias"].(map[string]float64); ok { + request.LogitBias = logitBias + delete(params.ExtraParams, "logit_bias") + } + if topLogProbs, ok := params.ExtraParams["top_logprobs"].(int); ok { + request.TopLogProbs = &topLogProbs + delete(params.ExtraParams, "top_logprobs") + } + if responseFormat, ok := params.ExtraParams["response_format"]; ok { + request.ResponseFormat = responseFormat + 
delete(params.ExtraParams, "response_format") + } + if seed, ok := params.ExtraParams["seed"].(int); ok { + request.Seed = &seed + delete(params.ExtraParams, "seed") + } + + request.ExtraParams = params.ExtraParams + } + } + + return request +} + +// buildOpenAISpeechRequest creates a type-safe OpenAI speech synthesis request +// from Bifrost speech input and parameters. +func buildOpenAISpeechRequest(model string, input *schemas.SpeechInput, params *schemas.ModelParameters) *api.OpenAISpeechRequest { + // Set default response format if not provided + responseFormat := input.ResponseFormat + if responseFormat == "" { + responseFormat = "mp3" + } + + // Validate voice is provided + if input.VoiceConfig.Voice == nil { + return nil + } + + // Build the request + request := &api.OpenAISpeechRequest{ + Model: model, + Input: input.Input, + Voice: *input.VoiceConfig.Voice, + ResponseFormat: &responseFormat, + } + + // Set instructions if provided + if input.Instructions != "" { + request.Instructions = &input.Instructions + } + + // Add parameters if provided + if params != nil && params.ExtraParams != nil { + if speed, ok := params.ExtraParams["speed"].(float64); ok { + request.Speed = &speed + delete(params.ExtraParams, "speed") + } + if streamFormat, ok := params.ExtraParams["stream_format"].(string); ok { + request.StreamFormat = &streamFormat + delete(params.ExtraParams, "stream_format") + } + + request.ExtraParams = params.ExtraParams + } + + return request +} + +func buildOpenAIEmbeddingRequest(model string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) *api.OpenAIEmbeddingRequest { + + // Build the request + request := &api.OpenAIEmbeddingRequest{ + Model: model, + Input: input.Texts, + } + + // Add parameters if provided + if params != nil { + request.EncodingFormat = params.EncodingFormat + request.Dimensions = params.Dimensions + + // Handle extra parameters + if params.ExtraParams != nil { + if user, ok := 
params.ExtraParams["user"].(string); ok { + request.User = &user + delete(params.ExtraParams, "user") + } + request.ExtraParams = params.ExtraParams + } + } + + return request +} + // ChatCompletion performs a chat completion request to the OpenAI API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) - - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) jsonBody, err := sonic.Marshal(requestBody) if err != nil { @@ -189,53 +308,6 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model string return response, nil } -// prepareOpenAIChatRequest formats messages for the OpenAI API. -// It handles both text and image content in messages. -// Returns a slice of formatted messages and any additional parameters. 
-func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { - // Format messages for OpenAI API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - if msg.Role == schemas.ModelChatMessageRoleAssistant { - assistantMessage := map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - } - if msg.AssistantMessage != nil && msg.AssistantMessage.ToolCalls != nil { - assistantMessage["tool_calls"] = *msg.AssistantMessage.ToolCalls - } - formattedMessages = append(formattedMessages, assistantMessage) - } else { - message := map[string]interface{}{ - "role": msg.Role, - } - - if msg.Content.ContentStr != nil { - message["content"] = *msg.Content.ContentStr - } else if msg.Content.ContentBlocks != nil { - contentBlocks := *msg.Content.ContentBlocks - for i := range contentBlocks { - if contentBlocks[i].Type == schemas.ContentBlockTypeImage && contentBlocks[i].ImageURL != nil { - sanitizedURL, _ := SanitizeImageURL(contentBlocks[i].ImageURL.URL) - contentBlocks[i].ImageURL.URL = sanitizedURL - } - } - - message["content"] = contentBlocks - } - - if msg.ToolMessage != nil && msg.ToolMessage.ToolCallID != nil { - message["tool_call_id"] = *msg.ToolMessage.ToolCallID - } - - formattedMessages = append(formattedMessages, message) - } - } - - preparedParams := prepareParams(params) - - return formattedMessages, preparedParams -} // Embedding generates embeddings for the given input text(s). // The input can be either a single string or a slice of strings for batch embedding. 
@@ -246,28 +318,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key return nil, newBifrostOperationError("input texts cannot be empty", nil, schemas.OpenAI) } - // Prepare request body with base parameters - requestBody := map[string]interface{}{ - "model": model, - "input": input.Texts, - } - - // Merge any additional parameters - if params != nil { - // Map standard parameters - if params.EncodingFormat != nil { - requestBody["encoding_format"] = *params.EncodingFormat - } - if params.Dimensions != nil { - requestBody["dimensions"] = *params.Dimensions - } - if params.User != nil { - requestBody["user"] = *params.User - } - - // Merge any extra parameters - requestBody = mergeConfig(requestBody, params.ExtraParams) - } + requestBody := buildOpenAIEmbeddingRequest(model, input, params) jsonBody, err := sonic.Marshal(requestBody) if err != nil { @@ -303,7 +354,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key } // Parse response - var response OpenAIResponse + var response api.OpenAIResponse if err := sonic.Unmarshal(resp.Body(), &response); err != nil { return nil, newBifrostOperationError(schemas.ErrProviderResponseUnmarshal, err, schemas.OpenAI) } @@ -376,13 +427,11 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key // It formats messages, prepares request body, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. 
func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + // Set streaming flag + stream := true + requestBody.Stream = &stream // Prepare OpenAI headers headers := map[string]string{ @@ -413,7 +462,7 @@ func handleOpenAIStreaming( ctx context.Context, httpClient *http.Client, url string, - requestBody map[string]interface{}, + requestBody *api.OpenAIChatRequest, headers map[string]string, extraHeaders map[string]string, providerType schemas.ModelProvider, @@ -576,21 +625,9 @@ func handleOpenAIStreaming( // It formats the request body, makes the API call, and returns the response. // Returns the response and any error that occurred. 
func (provider *OpenAIProvider) Speech(ctx context.Context, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - responseFormat := input.ResponseFormat - if responseFormat == "" { - responseFormat = "mp3" - } - - requestBody := map[string]interface{}{ - "input": input.Input, - "model": model, - "voice": input.VoiceConfig.Voice, - "instructions": input.Instructions, - "response_format": responseFormat, - } - - if params != nil { - requestBody = mergeConfig(requestBody, params.ExtraParams) + requestBody := buildOpenAISpeechRequest(model, input, params) + if requestBody == nil { + return nil, newBifrostOperationError("invalid speech input: voice is required", nil, schemas.OpenAI) } jsonBody, err := sonic.Marshal(requestBody) @@ -654,23 +691,12 @@ func (provider *OpenAIProvider) Speech(ctx context.Context, model string, key sc // It formats the request body, creates HTTP request, and uses shared streaming logic. // Returns a channel for streaming responses and any error that occurred. 
func (provider *OpenAIProvider) SpeechStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.SpeechInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - responseFormat := input.ResponseFormat - if responseFormat == "" { - responseFormat = "mp3" + requestBody := buildOpenAISpeechRequest(model, input, params) + if requestBody == nil { + return nil, newBifrostOperationError("invalid speech input: voice is required", nil, schemas.OpenAI) } - requestBody := map[string]interface{}{ - "input": input.Input, - "model": model, - "voice": input.VoiceConfig.Voice, - "instructions": input.Instructions, - "response_format": responseFormat, - "stream_format": "sse", - } - - if params != nil { - requestBody = mergeConfig(requestBody, params.ExtraParams) - } + requestBody.StreamFormat = ptr("sse") jsonBody, err := sonic.Marshal(requestBody) if err != nil { @@ -1053,8 +1079,6 @@ func parseTranscriptionFormDataBody(writer *multipart.Writer, input *schemas.Tra } } - // Note: Temperature and TimestampGranularities can be added via params.ExtraParams if needed - // Add extra params if provided if params != nil && params.ExtraParams != nil { for key, value := range params.ExtraParams { diff --git a/core/providers/sgl.go b/core/providers/sgl.go index ee092d7983..f4541dd5f4 100644 --- a/core/providers/sgl.go +++ b/core/providers/sgl.go @@ -98,12 +98,7 @@ func (provider *SGLProvider) TextCompletion(ctx context.Context, model string, k // ChatCompletion performs a chat completion request to the SGL API. 
func (provider *SGLProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) - - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) jsonBody, err := sonic.Marshal(requestBody) if err != nil { @@ -185,13 +180,11 @@ func (provider *SGLProvider) Embedding(ctx context.Context, model string, key sc // Uses SGL's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + requestBody := buildOpenAIChatCompletionRequest(model, messages, params) - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + // Set streaming flag + stream := true + requestBody.Stream = &stream // Prepare SGL headers (SGL typically doesn't require authorization, but we include it if provided) headers := map[string]string{ diff --git a/core/providers/utils.go b/core/providers/utils.go index 91750e201e..e0de10fed9 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -16,6 +16,7 @@ import ( "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpproxy" @@ -42,21 +43,6 @@ 
var fileExtensionToMediaType = map[string]string{ ".bmp": "image/bmp", } -// ImageContentType represents the type of image content -type ImageContentType string - -const ( - ImageContentTypeBase64 ImageContentType = "base64" - ImageContentTypeURL ImageContentType = "url" -) - -// URLTypeInfo contains extracted information about a URL -type URLTypeInfo struct { - Type ImageContentType - MediaType *string - DataURLWithoutPrefix *string // URL without the prefix (eg data:image/png;base64,iVBORw0KGgo...) -} - // ContextKey is a custom type for context keys to prevent key collisions in the context. // It provides type safety for context values and ensures that context keys are unique // across different packages. @@ -148,7 +134,7 @@ func makeRequestWithContext(ctx context.Context, client *fasthttp.Client, req *f return &schemas.BifrostError{ IsBifrostError: true, Error: schemas.ErrorField{ - Type: StrPtr(schemas.RequestCancelled), + Type: ptr(schemas.RequestCancelled), Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), Error: ctx.Err(), }, @@ -377,16 +363,10 @@ func getRoleFromMessage(msg map[string]interface{}) (schemas.ModelChatMessageRol return "", false // Role is of an unexpected or invalid type } -// float64Ptr creates a pointer to a float64 value. -// This is a helper function for creating pointers to float64 values. -func float64Ptr(f float64) *float64 { - return &f -} - -// StrPtr creates a pointer to a string value. -// This is a helper function for creating pointers to string values. -func StrPtr(s string) *string { - return &s +// ptr creates a pointer to a value. +// This is a helper function for creating pointers to values. +func ptr[T any](v T) *T { + return &v } //* IMAGE UTILS *// @@ -445,7 +425,7 @@ func SanitizeImageURL(rawURL string) (string, error) { // ExtractURLTypeInfo extracts type and media type information from a sanitized URL. // For data URLs, it parses the media type and encoding. 
// For regular URLs, it attempts to infer the media type from the file extension. -func ExtractURLTypeInfo(sanitizedURL string) URLTypeInfo { +func ExtractURLTypeInfo(sanitizedURL string) api.URLTypeInfo { if strings.HasPrefix(sanitizedURL, "data:") { return extractDataURLInfo(sanitizedURL) } @@ -453,12 +433,12 @@ func ExtractURLTypeInfo(sanitizedURL string) URLTypeInfo { } // extractDataURLInfo extracts information from a data URL -func extractDataURLInfo(dataURL string) URLTypeInfo { +func extractDataURLInfo(dataURL string) api.URLTypeInfo { // Parse data URL: data:[][;base64], matches := dataURIRegex.FindStringSubmatch(dataURL) if len(matches) != 4 { - return URLTypeInfo{Type: ImageContentTypeBase64} + return api.URLTypeInfo{Type: api.ImageContentTypeBase64} } mediaType := matches[1] @@ -469,24 +449,24 @@ func extractDataURLInfo(dataURL string) URLTypeInfo { dataURLWithoutPrefix = dataURL[len("data:")+len(mediaType)+len(";base64,"):] } - info := URLTypeInfo{ + info := api.URLTypeInfo{ MediaType: &mediaType, DataURLWithoutPrefix: &dataURLWithoutPrefix, } if isBase64 { - info.Type = ImageContentTypeBase64 + info.Type = api.ImageContentTypeBase64 } else { - info.Type = ImageContentTypeURL // Non-base64 data URL + info.Type = api.ImageContentTypeURL // Non-base64 data URL } return info } // extractRegularURLInfo extracts information from a regular HTTP/HTTPS URL -func extractRegularURLInfo(regularURL string) URLTypeInfo { - info := URLTypeInfo{ - Type: ImageContentTypeURL, +func extractRegularURLInfo(regularURL string) api.URLTypeInfo { + info := api.URLTypeInfo{ + Type: api.ImageContentTypeURL, } // Try to infer media type from file extension diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 88bd5e65b7..afcd8cd376 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go @@ -18,6 +18,7 @@ import ( "github.com/bytedance/sonic" schemas "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" ) type 
VertexError struct { @@ -66,8 +67,8 @@ func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { - openAIResponsePool.Put(&OpenAIResponse{}) - anthropicChatResponsePool.Put(&AnthropicChatResponse{}) + openAIResponsePool.Put(&api.OpenAIResponse{}) + anthropicChatResponsePool.Put(&api.AnthropicChatResponse{}) } @@ -136,36 +137,6 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string return nil, newConfigurationError("vertex key config is not set", schemas.Vertex) } - // Format messages for Vertex API - var formattedMessages []map[string]interface{} - var preparedParams map[string]interface{} - - if strings.Contains(model, "claude") { - formattedMessages, preparedParams = prepareAnthropicChatRequest(messages, params) - } else { - formattedMessages, preparedParams = prepareOpenAIChatRequest(messages, params) - } - - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) - - if strings.Contains(model, "claude") { - if _, exists := requestBody["anthropic_version"]; !exists { - requestBody["anthropic_version"] = "vertex-2023-10-16" - } - - delete(requestBody, "model") - } - - delete(requestBody, "region") - - jsonBody, err := sonic.Marshal(requestBody) - if err != nil { - return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex) - } - projectID := key.VertexKeyConfig.ProjectID if projectID == "" { return nil, newConfigurationError("project ID is not set", schemas.Vertex) @@ -176,10 +147,32 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string return nil, newConfigurationError("region is not set in meta config", schemas.Vertex) } - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) + // Determine request type 
based on model + var requestBody interface{} + var url string - if strings.Contains(model, "claude") { + if api.IsAnthropicModel(model) { + // Use Anthropic-style request for Claude models + anthropicRequest := buildAnthropicChatRequest(model, messages, params) + + // Set Vertex-specific fields + if anthropicRequest.AnthropicVersion == nil { + version := "vertex-2023-10-16" + anthropicRequest.AnthropicVersion = &version + } + + requestBody = anthropicRequest url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, model) + } else { + // Use OpenAI-style request for non-Claude models + openaiRequest := buildOpenAIChatCompletionRequest(model, messages, params) + requestBody = openaiRequest + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) + } + + jsonBody, err := sonic.Marshal(requestBody) + if err != nil { + return nil, newBifrostOperationError(schemas.ErrProviderJSONMarshaling, err, schemas.Vertex) } // Create request @@ -213,7 +206,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ - Type: StrPtr(schemas.RequestCancelled), + Type: ptr(schemas.RequestCancelled), Message: fmt.Sprintf("Request cancelled or timed out by context: %v", ctx.Err()), Error: err, }, @@ -350,22 +343,19 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo return nil, newBifrostOperationError("error creating auth client", err, schemas.Vertex) } - if strings.Contains(model, "claude") { + if api.IsAnthropicModel(model) { // Use Anthropic-style streaming for Claude models - formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) + anthropicRequest := buildAnthropicChatRequest(model, messages, params) - requestBody := 
mergeConfig(map[string]interface{}{ - "messages": formattedMessages, - "stream": true, - }, preparedParams) + // Set streaming and Vertex-specific fields + stream := true + anthropicRequest.Stream = &stream - if _, exists := requestBody["anthropic_version"]; !exists { - requestBody["anthropic_version"] = "vertex-2023-10-16" + if anthropicRequest.AnthropicVersion == nil { + version := "vertex-2023-10-16" + anthropicRequest.AnthropicVersion = &version } - delete(requestBody, "model") - delete(requestBody, "region") - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, model) // Prepare headers for Vertex Anthropic @@ -380,7 +370,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo ctx, client, url, - requestBody, + anthropicRequest, headers, provider.networkConfig.ExtraHeaders, schemas.Vertex, @@ -390,15 +380,11 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo ) } else { // Use OpenAI-style streaming for non-Claude models - formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) - - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - "stream": true, - }, preparedParams) + openaiRequest := buildOpenAIChatCompletionRequest(model, messages, params) - delete(requestBody, "region") + // Set streaming + stream := true + openaiRequest.Stream = &stream url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) @@ -414,7 +400,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo ctx, client, url, - requestBody, + openaiRequest, headers, provider.networkConfig.ExtraHeaders, schemas.Vertex, diff --git a/core/schemas/api/anthropic.go b/core/schemas/api/anthropic.go new file mode 100644 index 
0000000000..a6977784e0 --- /dev/null +++ b/core/schemas/api/anthropic.go @@ -0,0 +1,434 @@ +package api + +import ( + "fmt" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +type AnthropicRequestConfig struct { + URL string `json:"url"` + AnthropicTextRequest *AnthropicTextRequest `json:"anthropic_text_request,omitempty"` + AnthropicMessageRequest *AnthropicMessageRequest `json:"anthropic_message_request,omitempty"` +} + +// AnthropicToolChoice represents the tool choice configuration for Anthropic's API. +// It specifies how tools should be used in the completion request. +type AnthropicToolChoice struct { + Type schemas.ToolChoiceType `json:"type"` // Type of tool choice + Name *string `json:"name"` // Name of the tool to use + DisableParallelToolUse *bool `json:"disable_parallel_tool_use"` // Whether to disable parallel tool use +} + +// AnthropicTextResponse represents the response structure from Anthropic's text completion API. +// It includes the completion text, model information, and token usage statistics. +type AnthropicTextResponse struct { + ID string `json:"id"` // Unique identifier for the completion + Type string `json:"type"` // Type of completion + Completion string `json:"completion"` // Generated completion text + Model string `json:"model"` // Model used for the completion + Usage *AnthropicUsage `json:"usage"` // Token usage statistics +} + +// AnthropicChatResponse represents the response structure from Anthropic's chat completion API. +// It includes message content, model information, and token usage statistics. 
+type AnthropicChatResponse struct { + ID string `json:"id"` // Unique identifier for the completion + Type string `json:"type"` // Type of completion + Role string `json:"role"` // Role of the message sender + Content []AnthropicResponseContent `json:"content"` // Array of content items + Model string `json:"model"` // Model used for the completion + StopReason string `json:"stop_reason,omitempty"` // Reason for completion termination + StopSequence *string `json:"stop_sequence,omitempty"` // Sequence that caused completion to stop + Usage *AnthropicUsage `json:"usage"` // Token usage statistics +} + +type AnthropicResponseContent struct { + Type string `json:"type"` // Type of content + Text string `json:"text,omitempty"` // Text content + Thinking string `json:"thinking,omitempty"` // Thinking process + ID string `json:"id"` // Content identifier + Name string `json:"name"` // Name of the content + Input map[string]interface{} `json:"input"` // Input parameters +} + +// AnthropicStreamEvent represents a single event in the Anthropic streaming response. +// It corresponds to the various event types defined in Anthropic's Messages API streaming documentation. +type AnthropicStreamEvent struct { + Type string `json:"type"` + Message *AnthropicStreamMessage `json:"message,omitempty"` + Index *int `json:"index,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + Delta *AnthropicDelta `json:"delta,omitempty"` + Usage *schemas.LLMUsage `json:"usage,omitempty"` + Error *AnthropicStreamError `json:"error,omitempty"` +} + +// AnthropicStreamMessage represents the message structure in streaming events. +// This appears in message_start events and contains the initial message structure. 
+type AnthropicStreamMessage struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason"` + StopSequence *string `json:"stop_sequence"` + Usage *schemas.LLMUsage `json:"usage"` +} + +// AnthropicContentBlock represents content in Anthropic message format +type AnthropicContentBlock struct { + Type string `json:"type"` // "text", "image", "tool_use", "tool_result" + Text *string `json:"text,omitempty"` // For text content + ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content + ID *string `json:"id,omitempty"` // For tool_use content + Name *string `json:"name,omitempty"` // For tool_use content + Input interface{} `json:"input,omitempty"` // For tool_use content + Content AnthropicContent `json:"content,omitempty"` // For tool_result content + Source *AnthropicImageSource `json:"source,omitempty"` // For image content +} + +// AnthropicImageSource represents image source in Anthropic format +type AnthropicImageSource struct { + Type string `json:"type"` // "base64" or "url" + MediaType *string `json:"media_type,omitempty"` // "image/jpeg", "image/png", etc. + Data *string `json:"data,omitempty"` // Base64-encoded image data + URL *string `json:"url,omitempty"` // URL of the image +} + +// AnthropicToolContent represents content within tool result blocks +type AnthropicToolContent struct { + Type string `json:"type"` + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` + EncryptedContent string `json:"encrypted_content,omitempty"` + PageAge *string `json:"page_age,omitempty"` +} + +// AnthropicDelta represents incremental updates to content blocks during streaming. +// This includes all delta types: text_delta, input_json_delta, thinking_delta, and signature_delta. 
+type AnthropicDelta struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + PartialJSON string `json:"partial_json,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` +} + +// AnthropicStreamError represents error events in the streaming response. +type AnthropicStreamError struct { + Type string `json:"type"` + Message string `json:"message"` +} + +// AnthropicError represents the error response structure from Anthropic's API. +// It includes error type and message information. +type AnthropicError struct { + Type string `json:"type"` // always "error" + Error struct { + Type string `json:"type"` // Error type + Message string `json:"message"` // Error message + } `json:"error"` // Error details +} + +// URLTypeInfo contains extracted information about a URL +type URLTypeInfo struct { + Type ImageContentType + MediaType *string + DataURLWithoutPrefix *string // URL without the prefix (eg data:image/png;base64,iVBORw0KGgo...) 
+} + +// ImageContentType represents the type of image content +type ImageContentType string + +const ( + ImageContentTypeBase64 ImageContentType = "base64" + ImageContentTypeURL ImageContentType = "url" +) + +type AnthropicImageContent struct { + Type ImageContentType `json:"type"` + URL string `json:"url"` + MediaType string `json:"media_type,omitempty"` +} + +// AnthropicMessage represents a message in Anthropic format +type AnthropicMessage struct { + Role string `json:"role"` // "user", "assistant" + Content AnthropicContent `json:"content"` // Array of content blocks +} + +type AnthropicContent struct { + ContentStr *string + ContentBlocks *[]AnthropicContentBlock +} + +// AnthropicTool represents a tool in Anthropic format +type AnthropicTool struct { + Name string `json:"name"` + Type *string `json:"type,omitempty"` + Description string `json:"description"` + InputSchema *struct { + Type string `json:"type"` // "object" + Properties map[string]interface{} `json:"properties"` + Required []string `json:"required"` + } `json:"input_schema,omitempty"` +} + +// AnthropicMessageResponse represents an Anthropic messages API response +type AnthropicMessageResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []AnthropicContentBlock `json:"content"` + Model string `json:"model"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicUsage represents usage information in Anthropic format +type AnthropicUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// AnthropicStreamResponse represents a single chunk in the Anthropic streaming response +// This matches the format expected by Anthropic's streaming API clients +type AnthropicStreamResponse struct { + Type string `json:"type"` + ID *string `json:"id,omitempty"` + Model *string 
`json:"model,omitempty"` + Index *int `json:"index,omitempty"` + Message *AnthropicStreamMessage `json:"message,omitempty"` + ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` + Delta *AnthropicStreamDelta `json:"delta,omitempty"` + Usage *AnthropicUsage `json:"usage,omitempty"` +} + +// AnthropicStreamDelta represents the incremental content in a streaming chunk +type AnthropicStreamDelta struct { + Type string `json:"type"` + Text *string `json:"text,omitempty"` + Thinking *string `json:"thinking,omitempty"` + PartialJSON *string `json:"partial_json,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + StopSequence *string `json:"stop_sequence,omitempty"` +} + +// AnthropicMessageError represents an Anthropic messages API error response +type AnthropicMessageError struct { + Type string `json:"type"` // always "error" + Error AnthropicMessageErrorStruct `json:"error"` // Error details +} + +// AnthropicMessageErrorStruct represents the error structure of an Anthropic messages API error response +type AnthropicMessageErrorStruct struct { + Type string `json:"type"` // Error type + Message string `json:"message"` // Error message +} + +// AnthropicMessageRequest represents an Anthropic messages API request +type AnthropicMessageRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + Messages []AnthropicMessage `json:"messages"` + System *AnthropicContent `json:"system,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Stream *bool `json:"stream,omitempty"` + Tools []AnthropicTool `json:"tools,omitempty"` + ToolChoice *AnthropicToolChoice `json:"tool_choice,omitempty"` + AnthropicVersion *string `json:"anthropic_version,omitempty"` + Region *string `json:"region,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +func (mr 
*AnthropicMessageRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(mr.ExtraParams) == 0 { + type Alias AnthropicMessageRequest + return sonic.Marshal((*Alias)(mr)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 13+len(mr.ExtraParams)) + + // Add all fields directly - no reflection overhead + result["model"] = mr.Model + result["max_tokens"] = mr.MaxTokens + result["messages"] = mr.Messages + result["system"] = mr.System + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["model"] = true + setFields["max_tokens"] = true + setFields["messages"] = true + setFields["system"] = true + + if mr.Temperature != nil { + result["temperature"] = *mr.Temperature + setFields["temperature"] = true + } + if mr.TopP != nil { + result["top_p"] = *mr.TopP + setFields["top_p"] = true + } + if mr.TopK != nil { + result["top_k"] = *mr.TopK + setFields["top_k"] = true + } + if mr.StopSequences != nil { + result["stop_sequences"] = mr.StopSequences + setFields["stop_sequences"] = true + } + if mr.Stream != nil { + result["stream"] = *mr.Stream + setFields["stream"] = true + } + if mr.Tools != nil { + result["tools"] = mr.Tools + setFields["tools"] = true + } + if mr.ToolChoice != nil { + result["tool_choice"] = mr.ToolChoice + setFields["tool_choice"] = true + } + if mr.AnthropicVersion != nil { + result["anthropic_version"] = *mr.AnthropicVersion + setFields["anthropic_version"] = true + } + if mr.Region != nil { + result["region"] = *mr.Region + setFields["region"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range mr.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass 
through + } + + return sonic.Marshal(result) +} + +// AnthropicTextRequest represents an Anthropic text completion API request +type AnthropicTextRequest struct { + Model string `json:"model"` // Required: Model identifier + Prompt string `json:"prompt"` // Required: Text prompt for completion + MaxTokensToSample int `json:"max_tokens_to_sample,omitempty"` // Optional: Maximum tokens to generate + Temperature *float64 `json:"temperature,omitempty"` // Optional: Sampling temperature (0-1) + TopP *float64 `json:"top_p,omitempty"` // Optional: Nucleus sampling (0-1) + TopK *int `json:"top_k,omitempty"` // Optional: Top K sampling + StopSequences []string `json:"stop_sequences,omitempty"` // Optional: Sequences that stop generation + Stream *bool `json:"stream,omitempty"` // Optional: Enable streaming + ExtraParams map[string]interface{} `json:"-"` +} + + +func (r *AnthropicTextRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias AnthropicTextRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 8+len(r.ExtraParams)) + + result["model"] = r.Model + result["prompt"] = r.Prompt + result["max_tokens_to_sample"] = r.MaxTokensToSample + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["model"] = true + setFields["prompt"] = true + setFields["max_tokens_to_sample"] = true + + if r.Temperature != nil { + result["temperature"] = *r.Temperature + setFields["temperature"] = true + } + if r.TopP != nil { + result["top_p"] = *r.TopP + setFields["top_p"] = true + } + if r.TopK != nil { + result["top_k"] = *r.TopK + setFields["top_k"] = true + } + if r.StopSequences != nil { + result["stop_sequences"] = r.StopSequences + setFields["stop_sequences"] = true + } + if r.Stream != nil { + 
result["stream"] = *r.Stream + setFields["stream"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass through + } + + return sonic.Marshal(result) +} + +// IsStreamingRequested implements the StreamingRequest interface +func (r *AnthropicMessageRequest) IsStreamingRequested() bool { + return r.Stream != nil && *r.Stream +} + +// MarshalJSON implements custom JSON marshalling for AnthropicContent. +// It marshals either ContentStr or ContentBlocks directly without wrapping. +func (mc AnthropicContent) MarshalJSON() ([]byte, error) { + // Validation: ensure only one field is set at a time + if mc.ContentStr != nil && mc.ContentBlocks != nil { + return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") + } + + if mc.ContentStr != nil { + return sonic.Marshal(*mc.ContentStr) + } + if mc.ContentBlocks != nil { + return sonic.Marshal(*mc.ContentBlocks) + } + // If both are nil, return null + return sonic.Marshal(nil) +} + +// UnmarshalJSON implements custom JSON unmarshalling for AnthropicContent. +// It determines whether "content" is a string or array and assigns to the appropriate field. +// It also handles direct string/array content without a wrapper object. 
+func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { + // First, try to unmarshal as a direct string + var stringContent string + if err := sonic.Unmarshal(data, &stringContent); err == nil { + mc.ContentStr = &stringContent + return nil + } + + // Try to unmarshal as a direct array of ContentBlock + var arrayContent []AnthropicContentBlock + if err := sonic.Unmarshal(data, &arrayContent); err == nil { + mc.ContentBlocks = &arrayContent + return nil + } + + return fmt.Errorf("content field is neither a string nor an array of ContentBlock") +} diff --git a/core/schemas/api/bedrock.go b/core/schemas/api/bedrock.go new file mode 100644 index 0000000000..3c682de1c9 --- /dev/null +++ b/core/schemas/api/bedrock.go @@ -0,0 +1,424 @@ +package api + +import ( + "encoding/json" + + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// BedrockTextRequest represents the unified request structure for Bedrock's text completion API. +// This typed struct optimizes JSON marshalling performance and supports both Anthropic and Mistral models. 
+type BedrockTextRequest struct { + Prompt string `json:"prompt"` // Required: The prompt to complete + MaxTokensToSample *int `json:"max_tokens_to_sample,omitempty"` // Anthropic: Maximum tokens to generate (0-4096, default 200) + MaxTokens *int `json:"max_tokens,omitempty"` // Mistral: Maximum tokens to generate + Temperature *float64 `json:"temperature,omitempty"` // Optional: Amount of randomness (0-1, default 1) + TopP *float64 `json:"top_p,omitempty"` // Optional: Nucleus sampling (0-1, default 1) + TopK *int `json:"top_k,omitempty"` // Optional: Top K sampling (0-500, default 250) + StopSequences []string `json:"stop_sequences,omitempty"` // Optional: Sequences that cause generation to stop + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *BedrockTextRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias BedrockTextRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 7+len(r.ExtraParams)) + + // Add all fields directly - no reflection overhead + result["prompt"] = r.Prompt + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["prompt"] = true + + if r.MaxTokensToSample != nil { + result["max_tokens_to_sample"] = *r.MaxTokensToSample + setFields["max_tokens_to_sample"] = true + } + if r.MaxTokens != nil { + result["max_tokens"] = *r.MaxTokens + setFields["max_tokens"] = true + } + if r.Temperature != nil { + result["temperature"] = *r.Temperature + setFields["temperature"] = true + } + if r.TopP != nil { + result["top_p"] = *r.TopP + setFields["top_p"] = true + } + if r.TopK != nil { + result["top_k"] = *r.TopK + setFields["top_k"] = true + } + if len(r.StopSequences) > 0 { + result["stop_sequences"] = r.StopSequences + setFields["stop_sequences"] = true + } + + 
// Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass through + } + + return sonic.Marshal(result) +} + +// BedrockAnthropicTextResponse represents the response structure from Bedrock's Anthropic text completion API. +// It includes the completion text and stop reason information. +type BedrockAnthropicTextResponse struct { + Completion string `json:"completion"` // Generated completion text + StopReason string `json:"stop_reason"` // Reason for completion termination + Stop string `json:"stop"` // Stop sequence that caused completion to stop +} + +// BedrockMistralTextResponse represents the response structure from Bedrock's Mistral text completion API. +// It includes multiple output choices with their text and stop reasons. +type BedrockMistralTextResponse struct { + Outputs []struct { + Text string `json:"text"` // Generated text + StopReason string `json:"stop_reason"` // Reason for completion termination + } `json:"outputs"` // Array of output choices +} + +// BedrockChatResponse represents the response structure from Bedrock's chat completion API. +// It includes message content, metrics, and token usage statistics. 
type BedrockChatResponse struct {
	Metrics struct {
		Latency int `json:"latencyMs"` // Response latency in milliseconds
	} `json:"metrics"` // Performance metrics
	Output struct {
		Message struct {
			Content []struct {
				Text *string `json:"text"` // Text content; nil for non-text blocks
				// Bedrock returns a union type where either Text or ToolUse is present (mutually exclusive)
				BedrockAnthropicToolUseMessage
			} `json:"content"` // Array of message content blocks
			Role string `json:"role"` // Role of the message sender
		} `json:"message"` // Generated message
	} `json:"output"` // Output wrapper
	StopReason string `json:"stopReason"` // Reason for completion termination
	Usage struct {
		InputTokens  int `json:"inputTokens"`  // Number of input tokens used
		OutputTokens int `json:"outputTokens"` // Number of output tokens generated
		TotalTokens  int `json:"totalTokens"`  // Total number of tokens used
	} `json:"usage"` // Token usage statistics
}

// BedrockAnthropicToolUseMessage wraps the optional tool-use half of the
// content union; ToolUse is nil for plain text blocks.
type BedrockAnthropicToolUseMessage struct {
	ToolUse *BedrockAnthropicToolUse `json:"toolUse"` // Tool invocation details, if any
}

// BedrockAnthropicToolUse describes a single tool invocation requested by the model.
type BedrockAnthropicToolUse struct {
	ToolUseID string                 `json:"toolUseId"` // Identifier of this tool invocation
	Name      string                 `json:"name"`      // Name of the tool being invoked
	Input     map[string]interface{} `json:"input"`     // Arguments passed to the tool
}

// BedrockError represents the error response structure from Bedrock's API.
type BedrockError struct {
	Message string `json:"message"` // Error message
}

// BedrockAnthropicSystemMessage represents a system message for Anthropic models.
type BedrockAnthropicSystemMessage struct {
	Text string `json:"text"` // System message text
}

// BedrockAnthropicTextMessage represents a text message for Anthropic models.
type BedrockAnthropicTextMessage struct {
	Type string `json:"type"` // Type of message
	Text string `json:"text"` // Message text
}

// BedrockMistralContent represents content for Mistral models.
type BedrockMistralContent struct {
	Text string `json:"text"` // Content text
}

// BedrockMistralChatMessage represents a chat message for Mistral models.
type BedrockMistralChatMessage struct {
	Role       schemas.ModelChatMessageRole `json:"role"`                   // Role of the message sender
	Content    []BedrockMistralContent      `json:"content"`                // Array of message content
	ToolCalls  *[]BedrockMistralToolCall    `json:"tool_calls,omitempty"`   // Optional tool calls made by the assistant
	ToolCallID *string                      `json:"tool_call_id,omitempty"` // Optional ID linking a tool result back to its call
}

// BedrockAnthropicImageMessage represents an image message for Anthropic models.
type BedrockAnthropicImageMessage struct {
	Type  string                `json:"type"`  // Type of message
	Image BedrockAnthropicImage `json:"image"` // Image data
}

// BedrockAnthropicImage represents image data for Anthropic models.
type BedrockAnthropicImage struct {
	Format string                      `json:"format,omitempty"` // Image format
	Source BedrockAnthropicImageSource `json:"source,omitempty"` // Image source
}

// BedrockAnthropicImageSource represents the source of an image for Anthropic models.
type BedrockAnthropicImageSource struct {
	Bytes string `json:"bytes"` // Base64 encoded image data
}

// BedrockMistralToolCall represents a tool call for Mistral models.
type BedrockMistralToolCall struct {
	ID       string               `json:"id"`       // Tool call ID
	Function schemas.FunctionCall `json:"function"` // Function to call
}

// BedrockAnthropicToolCall represents a tool call for Anthropic models.
type BedrockAnthropicToolCall struct {
	ToolSpec BedrockAnthropicToolSpec `json:"toolSpec"` // Tool specification
}

// BedrockAnthropicToolSpec represents a tool specification for Anthropic models.
type BedrockAnthropicToolSpec struct {
	Name        string `json:"name"`        // Tool name
	Description string `json:"description"` // Tool description
	InputSchema struct {
		Json interface{} `json:"json"` // Input schema in JSON format
	} `json:"inputSchema"` // Input schema structure
}

// BedrockStreamMessageStartEvent is emitted when the assistant message starts.
type BedrockStreamMessageStartEvent struct {
	MessageStart struct {
		Role string `json:"role"` // e.g. "assistant"
	} `json:"messageStart"`
}

// BedrockStreamContentBlockDeltaEvent is sent for each content delta chunk (text, reasoning, tool use).
// Reasoning and tool-use payloads are kept as raw JSON so callers decode them
// only when needed.
type BedrockStreamContentBlockDeltaEvent struct {
	ContentBlockDelta struct {
		Delta struct {
			Text             string          `json:"text,omitempty"`             // Incremental text content
			ReasoningContent json.RawMessage `json:"reasoningContent,omitempty"` // Raw reasoning payload, decoded lazily
			ToolUse          json.RawMessage `json:"toolUse,omitempty"`          // Raw tool-use payload, decoded lazily
		} `json:"delta"`
		ContentBlockIndex int `json:"contentBlockIndex"` // Index of the content block this delta belongs to
	} `json:"contentBlockDelta"`
}

// BedrockStreamContentBlockStopEvent indicates the end of a content block.
type BedrockStreamContentBlockStopEvent struct {
	ContentBlockStop struct {
		ContentBlockIndex int `json:"contentBlockIndex"` // Index of the block that finished
	} `json:"contentBlockStop"`
}

// BedrockStreamMessageStopEvent marks the end of the assistant message.
type BedrockStreamMessageStopEvent struct {
	MessageStop struct {
		StopReason string `json:"stopReason"` // e.g. "stop", "max_tokens", "tool_use"
	} `json:"messageStop"`
}

// BedrockStreamMetadataEvent contains metadata after streaming ends.
type BedrockStreamMetadataEvent struct {
	Metadata struct {
		Usage struct {
			InputTokens  int `json:"inputTokens"`  // Number of input tokens used
			OutputTokens int `json:"outputTokens"` // Number of output tokens generated
			TotalTokens  int `json:"totalTokens"`  // Total number of tokens used
		} `json:"usage"` // Token usage statistics
		Metrics struct {
			LatencyMs float64 `json:"latencyMs"` // Response latency in milliseconds
		} `json:"metrics"` // Performance metrics
	} `json:"metadata"`
}

// BedrockChatRequest represents the unified request structure for Bedrock's chat completion API.
+// This typed struct optimizes JSON marshalling performance and supports various models. +type BedrockChatRequest struct { + Messages []BedrockMistralChatMessage `json:"messages"` // Formatted messages + Tools []BedrockAnthropicToolCall `json:"tools,omitempty"` // Optional tool definitions + ToolChoice *string `json:"tool_choice,omitempty"` // Optional tool choice ("auto", "any", "none") + MaxTokens *int `json:"max_tokens,omitempty"` // Maximum tokens to generate + Temperature *float64 `json:"temperature,omitempty"` // Sampling temperature + TopP *float64 `json:"top_p,omitempty"` // Nucleus sampling + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *BedrockChatRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias BedrockChatRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 6+len(r.ExtraParams)) + + // Add all fields directly - no reflection overhead + result["messages"] = r.Messages + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["messages"] = true + + if r.MaxTokens != nil { + result["max_tokens"] = *r.MaxTokens + setFields["max_tokens"] = true + } + if r.Temperature != nil { + result["temperature"] = *r.Temperature + setFields["temperature"] = true + } + if r.TopP != nil { + result["top_p"] = *r.TopP + setFields["top_p"] = true + } + if r.Tools != nil { + result["tools"] = r.Tools + setFields["tools"] = true + } + if r.ToolChoice != nil { + result["tool_choice"] = *r.ToolChoice + setFields["tool_choice"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // 
while still allowing unknown fields to pass through + } + + return sonic.Marshal(result) +} + +// BedrockTool represents a tool definition for Bedrock models. +type BedrockTool struct { + Type string `json:"type"` // Tool type (e.g., "function") + Function BedrockFunction `json:"function"` // Function definition +} + +// BedrockFunction represents a function definition for tools. +type BedrockFunction struct { + Name string `json:"name"` // Function name + Description string `json:"description"` // Function description + Parameters map[string]interface{} `json:"parameters"` // Function parameters schema +} + +// BedrockToolConfig represents tool configuration for Bedrock requests. +type BedrockToolConfig struct { + Tools []BedrockAnthropicToolCall `json:"tools"` // Array of tool specifications +} + +// BedrockTitanEmbeddingRequest represents the request structure for Titan embedding API. +type BedrockTitanEmbeddingRequest struct { + InputText string `json:"inputText"` // Text to embed + Dimensions *int `json:"dimensions,omitempty"` // Dimensions to embed + Normalize *bool `json:"normalize,omitempty"` // Normalize the embedding + EmbeddingTypes []interface{} `json:"embeddingTypes,omitempty"` // Embedding types to embed + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *BedrockTitanEmbeddingRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias BedrockTitanEmbeddingRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 4+len(r.ExtraParams)) + + // Add all fields directly - no reflection overhead + result["inputText"] = r.InputText + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["inputText"] = true + + if r.Dimensions != nil { + result["dimensions"] = *r.Dimensions + 
setFields["dimensions"] = true + } + if r.Normalize != nil { + result["normalize"] = *r.Normalize + setFields["normalize"] = true + } + if len(r.EmbeddingTypes) > 0 { + result["embeddingTypes"] = r.EmbeddingTypes + setFields["embeddingTypes"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass through + } + + return sonic.Marshal(result) +} + +// BedrockCohereEmbeddingRequest represents the request structure for Cohere embedding API. +type BedrockCohereEmbeddingRequest struct { + Texts []string `json:"texts"` // Texts to embed + InputType string `json:"input_type"` // Input type (e.g., "search_document") + Images []string `json:"images,omitempty"` // Images to embed + Truncate *string `json:"truncate,omitempty"` // Truncate the embedding + EmbeddingTypes []string `json:"embedding_types,omitempty"` // Embedding types to embed + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *BedrockCohereEmbeddingRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias BedrockCohereEmbeddingRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 5+len(r.ExtraParams)) + + // Add all fields directly - no reflection overhead + result["texts"] = r.Texts + result["input_type"] = r.InputType + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["texts"] = true + setFields["input_type"] = true + + if r.Truncate != nil { + result["truncate"] = *r.Truncate + setFields["truncate"] = true + } + if r.Images != nil { + result["images"] = r.Images + 
setFields["images"] = true + } + if len(r.EmbeddingTypes) > 0 { + result["embedding_types"] = r.EmbeddingTypes + setFields["embedding_types"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass through + } + + return sonic.Marshal(result) +} diff --git a/core/schemas/api/openai.go b/core/schemas/api/openai.go new file mode 100644 index 0000000000..c6cc55751b --- /dev/null +++ b/core/schemas/api/openai.go @@ -0,0 +1,350 @@ +package api + +import ( + "github.com/bytedance/sonic" + "github.com/maximhq/bifrost/core/schemas" +) + +// OpenAIResponse represents the response structure from the OpenAI API. +// It includes completion choices, model information, and usage statistics. +type OpenAIResponse struct { + ID string `json:"id"` // Unique identifier for the completion + Object string `json:"object"` // Type of completion (text.completion, chat.completion, or embedding) + Choices []schemas.BifrostResponseChoice `json:"choices"` // Array of completion choices + Data []struct { // Embedding data + Object string `json:"object"` + Embedding any `json:"embedding"` + Index int `json:"index"` + } `json:"data,omitempty"` + Model string `json:"model"` // Model used for the completion + Created int `json:"created"` // Unix timestamp of completion creation + ServiceTier *string `json:"service_tier"` // Service tier used for the request + SystemFingerprint *string `json:"system_fingerprint"` // System fingerprint for the request + Usage schemas.LLMUsage `json:"usage"` // Token usage statistics +} + +// OpenAIError represents the error response structure from the OpenAI API. +// It includes detailed error information and event tracking. 
+type OpenAIError struct { + EventID string `json:"event_id"` // Unique identifier for the error event + Type string `json:"type"` // Type of error + Error struct { + Type string `json:"type"` // Error type + Code string `json:"code"` // Error code + Message string `json:"message"` // Error message + Param interface{} `json:"param"` // Parameter that caused the error + EventID string `json:"event_id"` // Event ID for tracking + } `json:"error"` +} + +// OpenAIChatRequest represents an OpenAI chat completion request +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []schemas.BifrostMessage `json:"messages"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + N *int `json:"n,omitempty"` + Stop interface{} `json:"stop,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + LogitBias map[string]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Tools []schemas.Tool `json:"tools,omitempty"` // Reuse schema type + ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` + Stream *bool `json:"stream,omitempty"` + LogProbs *bool `json:"logprobs,omitempty"` + TopLogProbs *int `json:"top_logprobs,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` + Seed *int `json:"seed,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *OpenAIChatRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias OpenAIChatRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 18+len(r.ExtraParams)) + + result["model"] = r.Model + result["messages"] = r.Messages + + // Track which JSON field names are 
set to avoid conflicts + setFields := make(map[string]bool) + setFields["model"] = true + setFields["messages"] = true + + if r.MaxTokens != nil { + result["max_tokens"] = *r.MaxTokens + setFields["max_tokens"] = true + } + if r.Temperature != nil { + result["temperature"] = *r.Temperature + setFields["temperature"] = true + } + if r.TopP != nil { + result["top_p"] = *r.TopP + setFields["top_p"] = true + } + if r.N != nil { + result["n"] = *r.N + setFields["n"] = true + } + if r.Stop != nil { + result["stop"] = r.Stop + setFields["stop"] = true + } + if r.PresencePenalty != nil { + result["presence_penalty"] = *r.PresencePenalty + setFields["presence_penalty"] = true + } + if r.FrequencyPenalty != nil { + result["frequency_penalty"] = *r.FrequencyPenalty + setFields["frequency_penalty"] = true + } + if r.LogitBias != nil { + result["logit_bias"] = r.LogitBias + setFields["logit_bias"] = true + } + if r.User != nil { + result["user"] = *r.User + setFields["user"] = true + } + if r.Tools != nil { + result["tools"] = r.Tools + setFields["tools"] = true + } + if r.ToolChoice != nil { + result["tool_choice"] = *r.ToolChoice + setFields["tool_choice"] = true + } + if r.LogProbs != nil { + result["logprobs"] = *r.LogProbs + setFields["logprobs"] = true + } + if r.TopLogProbs != nil { + result["top_logprobs"] = *r.TopLogProbs + setFields["top_logprobs"] = true + } + if r.ResponseFormat != nil { + result["response_format"] = r.ResponseFormat + setFields["response_format"] = true + } + if r.Seed != nil { + result["seed"] = *r.Seed + setFields["seed"] = true + } + if r.Stream != nil { + result["stream"] = *r.Stream + setFields["stream"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass through + } + + return 
sonic.Marshal(result) +} + +type OpenAIEmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` // Array of strings to embed + EncodingFormat *string `json:"encoding_format,omitempty"` + Dimensions *int `json:"dimensions,omitempty"` + User *string `json:"user,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *OpenAIEmbeddingRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias OpenAIEmbeddingRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 5+len(r.ExtraParams)) + + result["model"] = r.Model + result["input"] = r.Input + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["model"] = true + setFields["input"] = true + + if r.EncodingFormat != nil { + result["encoding_format"] = *r.EncodingFormat + setFields["encoding_format"] = true + } + if r.Dimensions != nil { + result["dimensions"] = *r.Dimensions + setFields["dimensions"] = true + } + if r.User != nil { + result["user"] = *r.User + setFields["user"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass through + } + + return sonic.Marshal(result) +} + +// OpenAISpeechRequest represents an OpenAI speech synthesis request +type OpenAISpeechRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` + ResponseFormat *string `json:"response_format,omitempty"` + Speed *float64 `json:"speed,omitempty"` + Instructions *string `json:"instructions,omitempty"` + StreamFormat *string 
`json:"stream_format,omitempty"` + ExtraParams map[string]interface{} `json:"-"` +} + +func (r *OpenAISpeechRequest) MarshalJSON() ([]byte, error) { + // Use standard marshaling when no extra params - gives us type safety and performance + if len(r.ExtraParams) == 0 { + type Alias OpenAISpeechRequest + return sonic.Marshal((*Alias)(r)) + } + + // When ExtraParams exist, use dynamic approach with conflict detection + result := make(map[string]interface{}, 7+len(r.ExtraParams)) + + result["model"] = r.Model + result["input"] = r.Input + result["voice"] = r.Voice + + // Track which JSON field names are set to avoid conflicts + setFields := make(map[string]bool) + setFields["model"] = true + setFields["input"] = true + setFields["voice"] = true + + if r.ResponseFormat != nil { + result["response_format"] = *r.ResponseFormat + setFields["response_format"] = true + } + if r.Speed != nil { + result["speed"] = *r.Speed + setFields["speed"] = true + } + if r.Instructions != nil { + result["instructions"] = *r.Instructions + setFields["instructions"] = true + } + if r.StreamFormat != nil { + result["stream_format"] = *r.StreamFormat + setFields["stream_format"] = true + } + + // Add ExtraParams only if they don't conflict with existing fields + for key, value := range r.ExtraParams { + if !setFields[key] { + result[key] = value + } + // Silently skip conflicting fields - this prevents overwriting typed fields + // while still allowing unknown fields to pass through + } + + return sonic.Marshal(result) +} + +// OpenAITranscriptionRequest represents an OpenAI transcription request +// Note: This is used for JSON body parsing, actual form parsing is handled in the router +type OpenAITranscriptionRequest struct { + Model string `json:"model"` + File []byte `json:"file"` // Binary audio data + Language *string `json:"language,omitempty"` + Prompt *string `json:"prompt,omitempty"` + ResponseFormat *string `json:"response_format,omitempty"` + Temperature *float64 
`json:"temperature,omitempty"` + Include []string `json:"include,omitempty"` + TimestampGranularities []string `json:"timestamp_granularities,omitempty"` + Stream *bool `json:"stream,omitempty"` +} + +//response types + +// OpenAIChatResponse represents an OpenAI chat completion response +type OpenAIChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Data []struct { // Embedding data + Object string `json:"object"` + Embedding any `json:"embedding"` + Index int `json:"index"` + } `json:"data,omitempty"` + Created int `json:"created"` + Model string `json:"model"` + Choices []schemas.BifrostResponseChoice `json:"choices"` + Usage *schemas.LLMUsage `json:"usage,omitempty"` // Reuse schema type + ServiceTier *string `json:"service_tier,omitempty"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` +} + +// OpenAIChatError represents an OpenAI chat completion error response +type OpenAIChatError struct { + EventID string `json:"event_id"` // Unique identifier for the error event + Type string `json:"type"` // Type of error + Error struct { + Type string `json:"type"` // Error type + Code string `json:"code"` // Error code + Message string `json:"message"` // Error message + Param interface{} `json:"param"` // Parameter that caused the error + EventID string `json:"event_id"` // Event ID for tracking + } `json:"error"` +} + +// OpenAIChatErrorStruct represents the error structure of an OpenAI chat completion error response +type OpenAIChatErrorStruct struct { + Type string `json:"type"` // Error type + Code string `json:"code"` // Error code + Message string `json:"message"` // Error message + Param interface{} `json:"param"` // Parameter that caused the error + EventID string `json:"event_id"` // Event ID for tracking +} + +// OpenAIStreamChoice represents a choice in a streaming response chunk +type OpenAIStreamChoice struct { + Index int `json:"index"` + Delta *OpenAIStreamDelta `json:"delta,omitempty"` + FinishReason 
*string `json:"finish_reason,omitempty"` + LogProbs *schemas.LogProbs `json:"logprobs,omitempty"` +} + +// OpenAIStreamDelta represents the incremental content in a streaming chunk +type OpenAIStreamDelta struct { + Role *string `json:"role,omitempty"` + Content *string `json:"content,omitempty"` + ToolCalls *[]schemas.ToolCall `json:"tool_calls,omitempty"` +} + +// OpenAIStreamResponse represents a single chunk in the OpenAI streaming response +type OpenAIStreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + Model string `json:"model"` + SystemFingerprint *string `json:"system_fingerprint,omitempty"` + Choices []OpenAIStreamChoice `json:"choices"` + Usage *schemas.LLMUsage `json:"usage,omitempty"` +} diff --git a/core/schemas/api/utils.go b/core/schemas/api/utils.go new file mode 100644 index 0000000000..a4b68d4b5a --- /dev/null +++ b/core/schemas/api/utils.go @@ -0,0 +1,87 @@ +package api + +import ( + "strings" +) + +// IsOpenAIModel checks for OpenAI model patterns +func IsOpenAIModel(model string) bool { + // Exclude Azure models to prevent overlap + if strings.Contains(model, "azure/") { + return false + } + + openaiPatterns := []string{ + "gpt", "davinci", "curie", "babbage", "ada", "o1", "o3", "o4", + "text-embedding", "dall-e", "whisper", "tts", "chatgpt", + } + + return matchesAnyPattern(model, openaiPatterns) +} + +// IsAzureModel checks for Azure OpenAI specific patterns +func IsAzureModel(model string) bool { + azurePatterns := []string{ + "azure", "model-router", "computer-use-preview", + } + + return matchesAnyPattern(model, azurePatterns) +} + +// IsAnthropicModel checks for Anthropic Claude model patterns +func IsAnthropicModel(model string) bool { + anthropicPatterns := []string{ + "claude", "anthropic/", + } + + return matchesAnyPattern(model, anthropicPatterns) +} + +// IsVertexModel checks for Google Vertex AI model patterns +func IsVertexModel(model string) bool { + vertexPatterns := 
[]string{ + "gemini", "palm", "bison", "gecko", "vertex/", "google/", + } + + return matchesAnyPattern(model, vertexPatterns) +} + +// IsBedrockModel checks for AWS Bedrock model patterns +func IsBedrockModel(model string) bool { + bedrockPatterns := []string{ + "bedrock", "bedrock.amazonaws.com/", "bedrock/", + "amazon.titan", "amazon.nova", "aws/amazon.", + "ai21.jamba", "ai21.j2", "aws/ai21.", + "meta.llama", "aws/meta.", + "stability.stable-diffusion", "stability.sd3", "aws/stability.", + "anthropic.claude", "aws/anthropic.", + "cohere.command", "cohere.embed", "aws/cohere.", + "mistral.mistral", "mistral.mixtral", "aws/mistral.", + "titan-text", "titan-embed", "nova-micro", "nova-lite", "nova-pro", + "jamba-instruct", "j2-ultra", "j2-mid", + "llama-2", "llama-3", "llama-3.1", "llama-3.2", + "stable-diffusion-xl", "sd3-large", + } + + return matchesAnyPattern(model, bedrockPatterns) +} + +// IsCohereModel checks for Cohere model patterns +func IsCohereModel(model string) bool { + coherePatterns := []string{ + "command-", "embed-", "cohere", + } + + return matchesAnyPattern(model, coherePatterns) +} + +// matchesAnyPattern checks if the model matches any of the given patterns +func matchesAnyPattern(model string, patterns []string) bool { + model = strings.ToLower(model) // <- normalise once + for _, pattern := range patterns { + if strings.Contains(model, pattern) { + return true + } + } + return false +} \ No newline at end of file diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 734796d195..72d6cde06c 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -163,18 +163,20 @@ type Fallback struct { // mapped to the provider's parameters. 
type ModelParameters struct { ToolChoice *ToolChoice `json:"tool_choice,omitempty"` // Whether to call a tool - Tools *[]Tool `json:"tools,omitempty"` // Tools to use + Tools []Tool `json:"tools,omitempty"` // Tools to use Temperature *float64 `json:"temperature,omitempty"` // Controls randomness in the output TopP *float64 `json:"top_p,omitempty"` // Controls diversity via nucleus sampling TopK *int `json:"top_k,omitempty"` // Controls diversity via top-k sampling MaxTokens *int `json:"max_tokens,omitempty"` // Maximum number of tokens to generate - StopSequences *[]string `json:"stop_sequences,omitempty"` // Sequences that stop generation + StopSequences []string `json:"stop_sequences,omitempty"` // Sequences that stop generation PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls EncodingFormat *string `json:"encoding_format,omitempty"` // Format for embedding output (e.g., "float", "base64") Dimensions *int `json:"dimensions,omitempty"` // Number of dimensions for embedding output - User *string `json:"user,omitempty"` // User identifier for tracking + User *string `json:"user,omitempty"` + N *int `json:"n,omitempty"` + Logprobs *bool `json:"logprobs,omitempty"` // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. ExtraParams map[string]interface{} `json:"-"` @@ -228,7 +230,7 @@ type ToolChoiceFunction struct { // ToolChoiceStruct represents a specific tool choice. 
type ToolChoiceStruct struct { Type ToolChoiceType `json:"type"` // Type of tool choice - Function ToolChoiceFunction `json:"function,omitempty"` // Function to call if type is ToolChoiceTypeFunction + Function *ToolChoiceFunction `json:"function,omitempty"` // Function to call if type is ToolChoiceTypeFunction } // ToolChoice represents how a tool should be chosen for a request. (either a string or a struct) diff --git a/transports/bifrost-http/integrations/anthropic/router.go b/transports/bifrost-http/integrations/anthropic/router.go index e7d13ca8ed..2e389113f8 100644 --- a/transports/bifrost-http/integrations/anthropic/router.go +++ b/transports/bifrost-http/integrations/anthropic/router.go @@ -5,6 +5,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" ) @@ -20,11 +21,11 @@ func NewAnthropicRouter(client *bifrost.Bifrost) *AnthropicRouter { Path: "/anthropic/v1/messages", Method: "POST", GetRequestTypeInstance: func() interface{} { - return &AnthropicMessageRequest{} + return &api.AnthropicMessageRequest{} }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { - if anthropicReq, ok := req.(*AnthropicMessageRequest); ok { - return anthropicReq.ConvertToBifrostRequest(), nil + if anthropicReq, ok := req.(*api.AnthropicMessageRequest); ok { + return ConvertToBifrostRequest(anthropicReq), nil } return nil, errors.New("invalid request type") }, diff --git a/transports/bifrost-http/integrations/anthropic/types.go b/transports/bifrost-http/integrations/anthropic/types.go index 0833a3824e..8df0280da9 100644 --- a/transports/bifrost-http/integrations/anthropic/types.go +++ b/transports/bifrost-http/integrations/anthropic/types.go @@ -3,189 +3,17 @@ package anthropic import ( "encoding/json" "fmt" + "log" + "github.com/bytedance/sonic" bifrost "github.com/maximhq/bifrost/core" 
"github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" ) -var fnTypePtr = bifrost.Ptr(string(schemas.ToolChoiceTypeFunction)) - -// AnthropicContentBlock represents content in Anthropic message format -type AnthropicContentBlock struct { - Type string `json:"type"` // "text", "image", "tool_use", "tool_result" - Text *string `json:"text,omitempty"` // For text content - ToolUseID *string `json:"tool_use_id,omitempty"` // For tool_result content - ID *string `json:"id,omitempty"` // For tool_use content - Name *string `json:"name,omitempty"` // For tool_use content - Input interface{} `json:"input,omitempty"` // For tool_use content - Content AnthropicContent `json:"content,omitempty"` // For tool_result content - Source *AnthropicImageSource `json:"source,omitempty"` // For image content -} - -// AnthropicImageSource represents image source in Anthropic format -type AnthropicImageSource struct { - Type string `json:"type"` // "base64" or "url" - MediaType *string `json:"media_type,omitempty"` // "image/jpeg", "image/png", etc. 
- Data *string `json:"data,omitempty"` // Base64-encoded image data - URL *string `json:"url,omitempty"` // URL of the image -} - -// AnthropicMessage represents a message in Anthropic format -type AnthropicMessage struct { - Role string `json:"role"` // "user", "assistant" - Content AnthropicContent `json:"content"` // Array of content blocks -} - -type AnthropicContent struct { - ContentStr *string - ContentBlocks *[]AnthropicContentBlock -} - -// AnthropicTool represents a tool in Anthropic format -type AnthropicTool struct { - Name string `json:"name"` - Type *string `json:"type,omitempty"` - Description string `json:"description"` - InputSchema *struct { - Type string `json:"type"` // "object" - Properties map[string]interface{} `json:"properties"` - Required []string `json:"required"` - } `json:"input_schema,omitempty"` -} - -// AnthropicToolChoice represents tool choice in Anthropic format -type AnthropicToolChoice struct { - Type string `json:"type"` // "auto", "any", "tool" - Name string `json:"name,omitempty"` // For type "tool" -} - -// AnthropicMessageRequest represents an Anthropic messages API request -type AnthropicMessageRequest struct { - Model string `json:"model"` - MaxTokens int `json:"max_tokens"` - Messages []AnthropicMessage `json:"messages"` - System *AnthropicContent `json:"system,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - TopK *int `json:"top_k,omitempty"` - StopSequences *[]string `json:"stop_sequences,omitempty"` - Stream *bool `json:"stream,omitempty"` - Tools *[]AnthropicTool `json:"tools,omitempty"` - ToolChoice *AnthropicToolChoice `json:"tool_choice,omitempty"` -} - -// IsStreamingRequested implements the StreamingRequest interface -func (r *AnthropicMessageRequest) IsStreamingRequested() bool { - return r.Stream != nil && *r.Stream -} - -// AnthropicMessageResponse represents an Anthropic messages API response -type AnthropicMessageResponse struct { - ID string 
`json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []AnthropicContentBlock `json:"content"` - Model string `json:"model"` - StopReason *string `json:"stop_reason,omitempty"` - StopSequence *string `json:"stop_sequence,omitempty"` - Usage *AnthropicUsage `json:"usage,omitempty"` -} - -// AnthropicUsage represents usage information in Anthropic format -type AnthropicUsage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` -} - -// AnthropicMessageError represents an Anthropic messages API error response -type AnthropicMessageError struct { - Type string `json:"type"` // always "error" - Error AnthropicMessageErrorStruct `json:"error"` // Error details -} - -// AnthropicMessageErrorStruct represents the error structure of an Anthropic messages API error response -type AnthropicMessageErrorStruct struct { - Type string `json:"type"` // Error type - Message string `json:"message"` // Error message -} - -// AnthropicStreamResponse represents a single chunk in the Anthropic streaming response -// This matches the format expected by Anthropic's streaming API clients -type AnthropicStreamResponse struct { - Type string `json:"type"` - ID *string `json:"id,omitempty"` - Model *string `json:"model,omitempty"` - Index *int `json:"index,omitempty"` - Message *AnthropicStreamMessage `json:"message,omitempty"` - ContentBlock *AnthropicContentBlock `json:"content_block,omitempty"` - Delta *AnthropicStreamDelta `json:"delta,omitempty"` - Usage *AnthropicUsage `json:"usage,omitempty"` -} - -// AnthropicStreamMessage represents the message structure in streaming events -type AnthropicStreamMessage struct { - ID string `json:"id"` - Type string `json:"type"` - Role string `json:"role"` - Content []AnthropicContentBlock `json:"content"` - Model string `json:"model"` - StopReason *string `json:"stop_reason,omitempty"` - StopSequence *string `json:"stop_sequence,omitempty"` - Usage *AnthropicUsage `json:"usage,omitempty"` 
-} - -// AnthropicStreamDelta represents the incremental content in a streaming chunk -type AnthropicStreamDelta struct { - Type string `json:"type"` - Text *string `json:"text,omitempty"` - Thinking *string `json:"thinking,omitempty"` - PartialJSON *string `json:"partial_json,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` - StopSequence *string `json:"stop_sequence,omitempty"` -} - -// MarshalJSON implements custom JSON marshalling for MessageContent. -// It marshals either ContentStr or ContentBlocks directly without wrapping. -func (mc AnthropicContent) MarshalJSON() ([]byte, error) { - // Validation: ensure only one field is set at a time - if mc.ContentStr != nil && mc.ContentBlocks != nil { - return nil, fmt.Errorf("both ContentStr and ContentBlocks are set; only one should be non-nil") - } - - if mc.ContentStr != nil { - return json.Marshal(*mc.ContentStr) - } - if mc.ContentBlocks != nil { - return json.Marshal(*mc.ContentBlocks) - } - // If both are nil, return null - return json.Marshal(nil) -} - -// UnmarshalJSON implements custom JSON unmarshalling for MessageContent. -// It determines whether "content" is a string or array and assigns to the appropriate field. -// It also handles direct string/array content without a wrapper object. 
-func (mc *AnthropicContent) UnmarshalJSON(data []byte) error { - // First, try to unmarshal as a direct string - var stringContent string - if err := json.Unmarshal(data, &stringContent); err == nil { - mc.ContentStr = &stringContent - return nil - } - - // Try to unmarshal as a direct array of ContentBlock - var arrayContent []AnthropicContentBlock - if err := json.Unmarshal(data, &arrayContent); err == nil { - mc.ContentBlocks = &arrayContent - return nil - } - - return fmt.Errorf("content field is neither a string nor an array of ContentBlock") -} - // ConvertToBifrostRequest converts an Anthropic messages request to Bifrost format -func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { +func ConvertToBifrostRequest(r *api.AnthropicMessageRequest) *schemas.BifrostRequest { provider, model := integrations.ParseModelString(r.Model, schemas.Anthropic) bifrostReq := &schemas.BifrostRequest{ @@ -268,7 +96,7 @@ func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequ case "tool_use": if content.ID != nil && content.Name != nil { tc := schemas.ToolCall{ - Type: fnTypePtr, + Type: bifrost.Ptr(string(schemas.ToolChoiceTypeFunction)), ID: content.ID, Function: schemas.FunctionCall{ Name: content.Name, @@ -364,7 +192,7 @@ func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequ // Convert tools if r.Tools != nil { tools := []schemas.Tool{} - for _, tool := range *r.Tools { + for _, tool := range r.Tools { // Convert input_schema to FunctionParameters params := schemas.FunctionParameters{ Type: "object", @@ -387,7 +215,7 @@ func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequ if bifrostReq.Params == nil { bifrostReq.Params = &schemas.ModelParameters{} } - bifrostReq.Params.Tools = &tools + bifrostReq.Params.Tools = tools } // Convert tool choice @@ -405,9 +233,9 @@ func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequ }(), }, } - if 
r.ToolChoice.Type == "tool" && r.ToolChoice.Name != "" { - toolChoice.ToolChoiceStruct.Function = schemas.ToolChoiceFunction{ - Name: r.ToolChoice.Name, + if r.ToolChoice.Type == "tool" && r.ToolChoice.Name != nil { + toolChoice.ToolChoiceStruct.Function = &schemas.ToolChoiceFunction{ + Name: *r.ToolChoice.Name, } } bifrostReq.Params.ToolChoice = toolChoice @@ -416,25 +244,13 @@ func (r *AnthropicMessageRequest) ConvertToBifrostRequest() *schemas.BifrostRequ return bifrostReq } -// Helper function to convert interface{} to JSON string -func jsonifyInput(input interface{}) string { - if input == nil { - return "{}" - } - jsonBytes, err := json.Marshal(input) - if err != nil { - return "{}" - } - return string(jsonBytes) -} - // DeriveAnthropicFromBifrostResponse converts a Bifrost response to Anthropic format -func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *AnthropicMessageResponse { +func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *api.AnthropicMessageResponse { if bifrostResp == nil { return nil } - anthropicResp := &AnthropicMessageResponse{ + anthropicResp := &api.AnthropicMessageResponse{ ID: bifrostResp.ID, Type: "message", Role: string(schemas.ModelChatMessageRoleAssistant), @@ -443,14 +259,14 @@ func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *A // Convert usage information if bifrostResp.Usage != nil { - anthropicResp.Usage = &AnthropicUsage{ + anthropicResp.Usage = &api.AnthropicUsage{ InputTokens: bifrostResp.Usage.PromptTokens, OutputTokens: bifrostResp.Usage.CompletionTokens, } } // Convert choices to content - var content []AnthropicContentBlock + var content []api.AnthropicContentBlock if len(bifrostResp.Choices) > 0 { choice := bifrostResp.Choices[0] // Anthropic typically returns one choice @@ -463,7 +279,7 @@ func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *A // Add thinking content if present if choice.Message.AssistantMessage 
!= nil && choice.Message.AssistantMessage.Thought != nil && *choice.Message.AssistantMessage.Thought != "" { - content = append(content, AnthropicContentBlock{ + content = append(content, api.AnthropicContentBlock{ Type: "thinking", Text: choice.Message.AssistantMessage.Thought, }) @@ -471,14 +287,14 @@ func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *A // Add text content if choice.Message.Content.ContentStr != nil && *choice.Message.Content.ContentStr != "" { - content = append(content, AnthropicContentBlock{ + content = append(content, api.AnthropicContentBlock{ Type: "text", Text: choice.Message.Content.ContentStr, }) } else if choice.Message.Content.ContentBlocks != nil { for _, block := range *choice.Message.Content.ContentBlocks { if block.Text != nil { - content = append(content, AnthropicContentBlock{ + content = append(content, api.AnthropicContentBlock{ Type: "text", Text: block.Text, }) @@ -499,7 +315,7 @@ func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *A input = map[string]interface{}{} } - content = append(content, AnthropicContentBlock{ + content = append(content, api.AnthropicContentBlock{ Type: "tool_use", ID: toolCall.ID, Name: toolCall.Function.Name, @@ -510,7 +326,7 @@ func DeriveAnthropicFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *A } if content == nil { - content = []AnthropicContentBlock{} + content = []api.AnthropicContentBlock{} } anthropicResp.Content = content @@ -523,7 +339,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon return "" } - streamResp := &AnthropicStreamResponse{} + streamResp := &api.AnthropicStreamResponse{} // Handle different streaming event types based on the response content if len(bifrostResp.Choices) > 0 { @@ -537,7 +353,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon if delta.Content != nil { streamResp.Type = "content_block_delta" streamResp.Index = &choice.Index - 
streamResp.Delta = &AnthropicStreamDelta{ + streamResp.Delta = &api.AnthropicStreamDelta{ Type: "text_delta", Text: delta.Content, } @@ -545,7 +361,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon // Handle thinking content deltas streamResp.Type = "content_block_delta" streamResp.Index = &choice.Index - streamResp.Delta = &AnthropicStreamDelta{ + streamResp.Delta = &api.AnthropicStreamDelta{ Type: "thinking_delta", Thinking: delta.Thought, } @@ -557,7 +373,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon // Tool use start event streamResp.Type = "content_block_start" streamResp.Index = &choice.Index - streamResp.ContentBlock = &AnthropicContentBlock{ + streamResp.ContentBlock = &api.AnthropicContentBlock{ Type: "tool_use", ID: toolCall.ID, Name: toolCall.Function.Name, @@ -566,7 +382,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon // Tool input delta streamResp.Type = "content_block_delta" streamResp.Index = &choice.Index - streamResp.Delta = &AnthropicStreamDelta{ + streamResp.Delta = &api.AnthropicStreamDelta{ Type: "input_json_delta", PartialJSON: &toolCall.Function.Arguments, } @@ -574,7 +390,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon } else if choice.FinishReason != nil && *choice.FinishReason != "" { // Handle finish reason streamResp.Type = "message_delta" - streamResp.Delta = &AnthropicStreamDelta{ + streamResp.Delta = &api.AnthropicStreamDelta{ Type: "message_delta", StopReason: choice.FinishReason, } @@ -585,7 +401,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon streamResp.Type = "message_start" // Create message start event - streamMessage := &AnthropicStreamMessage{ + streamMessage := &api.AnthropicStreamMessage{ ID: bifrostResp.ID, Type: "message", Role: string(choice.BifrostNonStreamResponseChoice.Message.Role), @@ -593,9 +409,9 @@ func 
DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon } // Convert content - var content []AnthropicContentBlock + var content []api.AnthropicContentBlock if choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr != nil { - content = append(content, AnthropicContentBlock{ + content = append(content, api.AnthropicContentBlock{ Type: "text", Text: choice.BifrostNonStreamResponseChoice.Message.Content.ContentStr, }) @@ -611,7 +427,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon if streamResp.Type == "" { streamResp.Type = "message_delta" } - streamResp.Usage = &AnthropicUsage{ + streamResp.Usage = &api.AnthropicUsage{ InputTokens: bifrostResp.Usage.PromptTokens, OutputTokens: bifrostResp.Usage.CompletionTokens, } @@ -629,7 +445,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon if streamResp.Type == "" { streamResp.Type = "content_block_delta" streamResp.Index = bifrost.Ptr(0) - streamResp.Delta = &AnthropicStreamDelta{ + streamResp.Delta = &api.AnthropicStreamDelta{ Type: "text_delta", Text: bifrost.Ptr(""), } @@ -646,7 +462,7 @@ func DeriveAnthropicStreamFromBifrostResponse(bifrostResp *schemas.BifrostRespon } // DeriveAnthropicErrorFromBifrostError derives a AnthropicMessageError from a BifrostError -func DeriveAnthropicErrorFromBifrostError(bifrostErr *schemas.BifrostError) *AnthropicMessageError { +func DeriveAnthropicErrorFromBifrostError(bifrostErr *schemas.BifrostError) *api.AnthropicMessageError { if bifrostErr == nil { return nil } @@ -658,7 +474,7 @@ func DeriveAnthropicErrorFromBifrostError(bifrostErr *schemas.BifrostError) *Ant } // Handle nested error fields with nil checks - errorStruct := AnthropicMessageErrorStruct{ + errorStruct := api.AnthropicMessageErrorStruct{ Type: "", Message: bifrostErr.Error.Message, } @@ -667,7 +483,7 @@ func DeriveAnthropicErrorFromBifrostError(bifrostErr *schemas.BifrostError) *Ant errorStruct.Type = 
*bifrostErr.Error.Type } - return &AnthropicMessageError{ + return &api.AnthropicMessageError{ Type: errorType, Error: errorStruct, } @@ -689,3 +505,16 @@ func DeriveAnthropicStreamFromBifrostError(bifrostErr *schemas.BifrostError) str // Format as Anthropic SSE error event return fmt.Sprintf("event: error\ndata: %s\n\n", jsonData) } + +// Helper function to convert interface{} to JSON string +func jsonifyInput(input interface{}) string { + if input == nil { + return "{}" + } + jsonBytes, err := sonic.Marshal(input) + if err != nil { + log.Printf("Failed to marshal tool input: %v", err) + return "{}" + } + return string(jsonBytes) +} diff --git a/transports/bifrost-http/integrations/genai/types.go b/transports/bifrost-http/integrations/genai/types.go index d0131d4e38..0e03fae755 100644 --- a/transports/bifrost-http/integrations/genai/types.go +++ b/transports/bifrost-http/integrations/genai/types.go @@ -380,7 +380,7 @@ func (r *GeminiChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { } if len(tools) > 0 { - bifrostReq.Params.Tools = &tools + bifrostReq.Params.Tools = tools } } @@ -420,7 +420,7 @@ func (r *GeminiChatRequest) convertGenerationConfigToParams() *schemas.ModelPara params.ExtraParams["candidate_count"] = config.CandidateCount } if len(config.StopSequences) > 0 { - params.StopSequences = &config.StopSequences + params.StopSequences = config.StopSequences } if config.PresencePenalty != nil { params.PresencePenalty = bifrost.Ptr(float64(*config.PresencePenalty)) diff --git a/transports/bifrost-http/integrations/litellm/router.go b/transports/bifrost-http/integrations/litellm/router.go index cede7a42f7..74feafa294 100644 --- a/transports/bifrost-http/integrations/litellm/router.go +++ b/transports/bifrost-http/integrations/litellm/router.go @@ -7,6 +7,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" 
"github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/integrations/anthropic" "github.com/maximhq/bifrost/transports/bifrost-http/integrations/genai" @@ -89,9 +90,9 @@ func NewLiteLLMRouter(client *bifrost.Bifrost) *LiteLLMRouter { var actualReq interface{} switch provider { case schemas.OpenAI, schemas.Azure: - actualReq = &openai.OpenAIChatRequest{} + actualReq = &api.OpenAIChatRequest{} case schemas.Anthropic: - actualReq = &anthropic.AnthropicMessageRequest{} + actualReq = &api.AnthropicMessageRequest{} case schemas.Vertex: actualReq = &genai.GeminiChatRequest{} default: @@ -122,13 +123,13 @@ func NewLiteLLMRouter(client *bifrost.Bifrost) *LiteLLMRouter { // Handle different provider-specific request types switch actualReq := wrapper.ActualRequest.(type) { - case *openai.OpenAIChatRequest: - bifrostReq := actualReq.ConvertToBifrostRequest() + case *api.OpenAIChatRequest: + bifrostReq := openai.ConvertChatRequestToBifrostRequest(actualReq) bifrostReq.Provider = wrapper.Provider return bifrostReq, nil - case *anthropic.AnthropicMessageRequest: - bifrostReq := actualReq.ConvertToBifrostRequest() + case *api.AnthropicMessageRequest: + bifrostReq := anthropic.ConvertToBifrostRequest(actualReq) bifrostReq.Provider = wrapper.Provider return bifrostReq, nil diff --git a/transports/bifrost-http/integrations/openai/router.go b/transports/bifrost-http/integrations/openai/router.go index 226e8abbb9..889a2beec2 100644 --- a/transports/bifrost-http/integrations/openai/router.go +++ b/transports/bifrost-http/integrations/openai/router.go @@ -7,6 +7,7 @@ import ( bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/valyala/fasthttp" ) @@ -30,11 +31,11 @@ func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter { Path: path, Method: "POST", 
GetRequestTypeInstance: func() interface{} { - return &OpenAIChatRequest{} + return &api.OpenAIChatRequest{} }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { - if openaiReq, ok := req.(*OpenAIChatRequest); ok { - return openaiReq.ConvertToBifrostRequest(), nil + if openaiReq, ok := req.(*api.OpenAIChatRequest); ok { + return ConvertChatRequestToBifrostRequest(openaiReq), nil } return nil, errors.New("invalid request type") }, @@ -64,11 +65,11 @@ func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter { Path: path, Method: "POST", GetRequestTypeInstance: func() interface{} { - return &OpenAISpeechRequest{} + return &api.OpenAISpeechRequest{} }, RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { - if speechReq, ok := req.(*OpenAISpeechRequest); ok { - return speechReq.ConvertToBifrostRequest(), nil + if speechReq, ok := req.(*api.OpenAISpeechRequest); ok { + return ConvertSpeechRequestToBifrostRequest(speechReq), nil } return nil, errors.New("invalid speech request type") }, @@ -103,12 +104,12 @@ func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter { Path: path, Method: "POST", GetRequestTypeInstance: func() interface{} { - return &OpenAITranscriptionRequest{} + return &api.OpenAITranscriptionRequest{} }, RequestParser: parseTranscriptionMultipartRequest, // Handle multipart form parsing RequestConverter: func(req interface{}) (*schemas.BifrostRequest, error) { - if transcriptionReq, ok := req.(*OpenAITranscriptionRequest); ok { - return transcriptionReq.ConvertToBifrostRequest(), nil + if transcriptionReq, ok := req.(*api.OpenAITranscriptionRequest); ok { + return ConvertTranscriptionRequestToBifrostRequest(transcriptionReq), nil } return nil, errors.New("invalid transcription request type") }, @@ -136,7 +137,7 @@ func NewOpenAIRouter(client *bifrost.Bifrost) *OpenAIRouter { // parseTranscriptionMultipartRequest is a RequestParser that handles multipart/form-data for transcription requests func 
parseTranscriptionMultipartRequest(ctx *fasthttp.RequestCtx, req interface{}) error { - transcriptionReq, ok := req.(*OpenAITranscriptionRequest) + transcriptionReq, ok := req.(*api.OpenAITranscriptionRequest) if !ok { return errors.New("invalid request type for transcription") } diff --git a/transports/bifrost-http/integrations/openai/types.go b/transports/bifrost-http/integrations/openai/types.go index b11ae1594f..90538222d7 100644 --- a/transports/bifrost-http/integrations/openai/types.go +++ b/transports/bifrost-http/integrations/openai/types.go @@ -2,133 +2,12 @@ package openai import ( "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/maximhq/bifrost/transports/bifrost-http/integrations" ) -// OpenAIChatRequest represents an OpenAI chat completion request -type OpenAIChatRequest struct { - Model string `json:"model"` - Messages []schemas.BifrostMessage `json:"messages"` - MaxTokens *int `json:"max_tokens,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - TopP *float64 `json:"top_p,omitempty"` - N *int `json:"n,omitempty"` - Stop interface{} `json:"stop,omitempty"` - PresencePenalty *float64 `json:"presence_penalty,omitempty"` - FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` - LogitBias map[string]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Tools *[]schemas.Tool `json:"tools,omitempty"` // Reuse schema type - ToolChoice *schemas.ToolChoice `json:"tool_choice,omitempty"` - Stream *bool `json:"stream,omitempty"` - LogProbs *bool `json:"logprobs,omitempty"` - TopLogProbs *int `json:"top_logprobs,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` - Seed *int `json:"seed,omitempty"` -} - -// OpenAISpeechRequest represents an OpenAI speech synthesis request -type OpenAISpeechRequest struct { - Model string `json:"model"` - Input string `json:"input"` - Voice string `json:"voice"` - ResponseFormat *string 
`json:"response_format,omitempty"` - Speed *float64 `json:"speed,omitempty"` - Instructions *string `json:"instructions,omitempty"` - StreamFormat *string `json:"stream_format,omitempty"` -} - -// OpenAITranscriptionRequest represents an OpenAI transcription request -// Note: This is used for JSON body parsing, actual form parsing is handled in the router -type OpenAITranscriptionRequest struct { - Model string `json:"model"` - File []byte `json:"file"` // Binary audio data - Language *string `json:"language,omitempty"` - Prompt *string `json:"prompt,omitempty"` - ResponseFormat *string `json:"response_format,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - Include []string `json:"include,omitempty"` - TimestampGranularities []string `json:"timestamp_granularities,omitempty"` - Stream *bool `json:"stream,omitempty"` -} - -// IsStreamingRequested implements the StreamingRequest interface -func (r *OpenAIChatRequest) IsStreamingRequested() bool { - return r.Stream != nil && *r.Stream -} - -// IsStreamingRequested implements the StreamingRequest interface for speech -func (r *OpenAISpeechRequest) IsStreamingRequested() bool { - return r.StreamFormat != nil && *r.StreamFormat == "sse" -} - -// IsStreamingRequested implements the StreamingRequest interface for transcription -func (r *OpenAITranscriptionRequest) IsStreamingRequested() bool { - return r.Stream != nil && *r.Stream -} - -// OpenAIChatResponse represents an OpenAI chat completion response -type OpenAIChatResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - Choices []schemas.BifrostResponseChoice `json:"choices"` - Usage *schemas.LLMUsage `json:"usage,omitempty"` // Reuse schema type - ServiceTier *string `json:"service_tier,omitempty"` - SystemFingerprint *string `json:"system_fingerprint,omitempty"` -} - -// OpenAIChatError represents an OpenAI chat completion error response -type OpenAIChatError struct { 
- EventID string `json:"event_id"` // Unique identifier for the error event - Type string `json:"type"` // Type of error - Error struct { - Type string `json:"type"` // Error type - Code string `json:"code"` // Error code - Message string `json:"message"` // Error message - Param interface{} `json:"param"` // Parameter that caused the error - EventID string `json:"event_id"` // Event ID for tracking - } `json:"error"` -} - -// OpenAIChatErrorStruct represents the error structure of an OpenAI chat completion error response -type OpenAIChatErrorStruct struct { - Type string `json:"type"` // Error type - Code string `json:"code"` // Error code - Message string `json:"message"` // Error message - Param interface{} `json:"param"` // Parameter that caused the error - EventID string `json:"event_id"` // Event ID for tracking -} - -// OpenAIStreamChoice represents a choice in a streaming response chunk -type OpenAIStreamChoice struct { - Index int `json:"index"` - Delta *OpenAIStreamDelta `json:"delta,omitempty"` - FinishReason *string `json:"finish_reason,omitempty"` - LogProbs *schemas.LogProbs `json:"logprobs,omitempty"` -} - -// OpenAIStreamDelta represents the incremental content in a streaming chunk -type OpenAIStreamDelta struct { - Role *string `json:"role,omitempty"` - Content *string `json:"content,omitempty"` - ToolCalls *[]schemas.ToolCall `json:"tool_calls,omitempty"` -} - -// OpenAIStreamResponse represents a single chunk in the OpenAI streaming response -type OpenAIStreamResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - Model string `json:"model"` - SystemFingerprint *string `json:"system_fingerprint,omitempty"` - Choices []OpenAIStreamChoice `json:"choices"` - Usage *schemas.LLMUsage `json:"usage,omitempty"` -} - // ConvertToBifrostRequest converts an OpenAI chat request to Bifrost format -func (r *OpenAIChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { +func 
ConvertChatRequestToBifrostRequest(r *api.OpenAIChatRequest) *schemas.BifrostRequest { provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI) bifrostReq := &schemas.BifrostRequest{ @@ -140,13 +19,13 @@ func (r *OpenAIChatRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { } // Map extra parameters and tool settings - bifrostReq.Params = r.convertParameters() + bifrostReq.Params = convertParameters(r) return bifrostReq } // ConvertToBifrostRequest converts an OpenAI speech request to Bifrost format -func (r *OpenAISpeechRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { +func ConvertSpeechRequestToBifrostRequest(r *api.OpenAISpeechRequest) *schemas.BifrostRequest { provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI) // Create speech input @@ -176,13 +55,13 @@ func (r *OpenAISpeechRequest) ConvertToBifrostRequest() *schemas.BifrostRequest } // Map parameters - bifrostReq.Params = r.convertSpeechParameters() + bifrostReq.Params = convertSpeechParameters(r) return bifrostReq } // ConvertToBifrostRequest converts an OpenAI transcription request to Bifrost format -func (r *OpenAITranscriptionRequest) ConvertToBifrostRequest() *schemas.BifrostRequest { +func ConvertTranscriptionRequestToBifrostRequest(r *api.OpenAITranscriptionRequest) *schemas.BifrostRequest { provider, model := integrations.ParseModelString(r.Model, schemas.OpenAI) // Create transcription input @@ -210,14 +89,14 @@ func (r *OpenAITranscriptionRequest) ConvertToBifrostRequest() *schemas.BifrostR } // Map parameters - bifrostReq.Params = r.convertTranscriptionParameters() + bifrostReq.Params = convertTranscriptionParameters(r) return bifrostReq } // convertParameters converts OpenAI request parameters to Bifrost ModelParameters // using direct field access for better performance and type safety. 
-func (r *OpenAIChatRequest) convertParameters() *schemas.ModelParameters { +func convertParameters(r *api.OpenAIChatRequest) *schemas.ModelParameters { params := &schemas.ModelParameters{ ExtraParams: make(map[string]interface{}), } @@ -270,7 +149,7 @@ func (r *OpenAIChatRequest) convertParameters() *schemas.ModelParameters { } // convertSpeechParameters converts OpenAI speech request parameters to Bifrost ModelParameters -func (r *OpenAISpeechRequest) convertSpeechParameters() *schemas.ModelParameters { +func convertSpeechParameters(r *api.OpenAISpeechRequest) *schemas.ModelParameters { params := &schemas.ModelParameters{ ExtraParams: make(map[string]interface{}), } @@ -284,7 +163,7 @@ func (r *OpenAISpeechRequest) convertSpeechParameters() *schemas.ModelParameters } // convertTranscriptionParameters converts OpenAI transcription request parameters to Bifrost ModelParameters -func (r *OpenAITranscriptionRequest) convertTranscriptionParameters() *schemas.ModelParameters { +func convertTranscriptionParameters(r *api.OpenAITranscriptionRequest) *schemas.ModelParameters { params := &schemas.ModelParameters{ ExtraParams: make(map[string]interface{}), } @@ -304,12 +183,12 @@ func (r *OpenAITranscriptionRequest) convertTranscriptionParameters() *schemas.M } // DeriveOpenAIFromBifrostResponse converts a Bifrost response to OpenAI format -func DeriveOpenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *OpenAIChatResponse { +func DeriveOpenAIFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *api.OpenAIChatResponse { if bifrostResp == nil { return nil } - openaiResp := &OpenAIChatResponse{ + openaiResp := &api.OpenAIChatResponse{ ID: bifrostResp.ID, Object: bifrostResp.Object, Created: bifrostResp.Created, @@ -341,7 +220,7 @@ func DeriveOpenAITranscriptionFromBifrostResponse(bifrostResp *schemas.BifrostRe } // DeriveOpenAIErrorFromBifrostError derives a OpenAIChatError from a BifrostError -func DeriveOpenAIErrorFromBifrostError(bifrostErr 
*schemas.BifrostError) *OpenAIChatError { +func DeriveOpenAIErrorFromBifrostError(bifrostErr *schemas.BifrostError) *api.OpenAIChatError { if bifrostErr == nil { return nil } @@ -358,7 +237,7 @@ func DeriveOpenAIErrorFromBifrostError(bifrostErr *schemas.BifrostError) *OpenAI } // Handle nested error fields with nil checks - errorStruct := OpenAIChatErrorStruct{ + errorStruct := api.OpenAIChatErrorStruct{ Type: "", Code: "", Message: bifrostErr.Error.Message, @@ -378,7 +257,7 @@ func DeriveOpenAIErrorFromBifrostError(bifrostErr *schemas.BifrostError) *OpenAI errorStruct.EventID = *bifrostErr.Error.EventID } - return &OpenAIChatError{ + return &api.OpenAIChatError{ EventID: eventID, Type: errorType, Error: errorStruct, @@ -386,18 +265,18 @@ func DeriveOpenAIErrorFromBifrostError(bifrostErr *schemas.BifrostError) *OpenAI } // DeriveOpenAIStreamFromBifrostError derives an OpenAI streaming error from a BifrostError -func DeriveOpenAIStreamFromBifrostError(bifrostErr *schemas.BifrostError) *OpenAIChatError { +func DeriveOpenAIStreamFromBifrostError(bifrostErr *schemas.BifrostError) *api.OpenAIChatError { // For streaming, we use the same error format as regular OpenAI errors return DeriveOpenAIErrorFromBifrostError(bifrostErr) } // DeriveOpenAIStreamFromBifrostResponse converts a Bifrost response to OpenAI streaming format -func DeriveOpenAIStreamFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *OpenAIStreamResponse { +func DeriveOpenAIStreamFromBifrostResponse(bifrostResp *schemas.BifrostResponse) *api.OpenAIStreamResponse { if bifrostResp == nil { return nil } - streamResp := &OpenAIStreamResponse{ + streamResp := &api.OpenAIStreamResponse{ ID: bifrostResp.ID, Object: "chat.completion.chunk", Created: bifrostResp.Created, @@ -408,17 +287,17 @@ func DeriveOpenAIStreamFromBifrostResponse(bifrostResp *schemas.BifrostResponse) // Convert choices to streaming format for _, choice := range bifrostResp.Choices { - streamChoice := OpenAIStreamChoice{ + streamChoice := 
api.OpenAIStreamChoice{ Index: choice.Index, FinishReason: choice.FinishReason, } - var delta *OpenAIStreamDelta + var delta *api.OpenAIStreamDelta // Handle streaming vs non-streaming choices if choice.BifrostStreamResponseChoice != nil { // This is a streaming response - use the delta directly - delta = &OpenAIStreamDelta{} + delta = &api.OpenAIStreamDelta{} // Only set fields that are not nil if choice.BifrostStreamResponseChoice.Delta.Role != nil { @@ -432,7 +311,7 @@ func DeriveOpenAIStreamFromBifrostResponse(bifrostResp *schemas.BifrostResponse) } } else if choice.BifrostNonStreamResponseChoice != nil { // This is a non-streaming response - convert message to delta format - delta = &OpenAIStreamDelta{} + delta = &api.OpenAIStreamDelta{} // Convert role role := string(choice.BifrostNonStreamResponseChoice.Message.Role) diff --git a/transports/bifrost-http/integrations/utils.go b/transports/bifrost-http/integrations/utils.go index d81ba8260e..a0d2cacb70 100644 --- a/transports/bifrost-http/integrations/utils.go +++ b/transports/bifrost-http/integrations/utils.go @@ -60,6 +60,7 @@ import ( "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/core/schemas/api" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -618,42 +619,6 @@ func (g *GenericRouter) sendSuccess(ctx *fasthttp.RequestCtx, errorConverter Err ctx.SetBody(responseBody) } -// validProviders is a pre-computed map for efficient O(1) provider validation. -var validProviders = map[schemas.ModelProvider]bool{ - schemas.OpenAI: true, - schemas.Azure: true, - schemas.Anthropic: true, - schemas.Bedrock: true, - schemas.Cohere: true, - schemas.Vertex: true, - schemas.Mistral: true, - schemas.Ollama: true, -} - -// ParseModelString extracts provider and model from a model string. -// For model strings like "anthropic/claude", it returns ("anthropic", "claude"). 
-// For model strings like "claude", it returns ("", "claude"). -// If the extracted provider is not valid, it treats the whole string as a model name. -func ParseModelString(model string, defaultProvider schemas.ModelProvider) (schemas.ModelProvider, string) { - // Check if model contains a provider prefix (only split on first "/" to preserve model names with "/") - if strings.Contains(model, "/") { - parts := strings.SplitN(model, "/", 2) - if len(parts) == 2 { - extractedProvider := parts[0] - extractedModel := parts[1] - - // Validate that the extracted provider is actually a valid provider - if validProviders[schemas.ModelProvider(extractedProvider)] { - return schemas.ModelProvider(extractedProvider), extractedModel - } - // If extracted provider is not valid, treat the whole string as model name - // This prevents corrupting model names that happen to contain "/" - } - } - // No provider prefix found or invalid provider, return empty provider and the original model - return defaultProvider, model -} - // GetProviderFromModel determines the appropriate provider based on model name patterns // This function uses comprehensive pattern matching to identify the correct provider // for various model naming conventions used across different AI providers. 
@@ -662,32 +627,32 @@ func GetProviderFromModel(model string) schemas.ModelProvider { modelLower := strings.ToLower(strings.TrimSpace(model)) // Azure OpenAI Models - check first to prevent false positives from OpenAI "gpt" patterns - if isAzureModel(modelLower) { + if api.IsAzureModel(modelLower) { return schemas.Azure } // OpenAI Models - comprehensive pattern matching - if isOpenAIModel(modelLower) { + if api.IsOpenAIModel(modelLower) { return schemas.OpenAI } // Anthropic Models - Claude family - if isAnthropicModel(modelLower) { + if api.IsAnthropicModel(modelLower) { return schemas.Anthropic } // Google Vertex AI Models - Gemini and Palm family - if isVertexModel(modelLower) { + if api.IsVertexModel(modelLower) { return schemas.Vertex } // AWS Bedrock Models - various model providers through Bedrock - if isBedrockModel(modelLower) { + if api.IsBedrockModel(modelLower) { return schemas.Bedrock } // Cohere Models - Command and Embed family - if isCohereModel(modelLower) { + if api.IsCohereModel(modelLower) { return schemas.Cohere } @@ -695,87 +660,6 @@ func GetProviderFromModel(model string) schemas.ModelProvider { return schemas.OpenAI } -// isOpenAIModel checks for OpenAI model patterns -func isOpenAIModel(model string) bool { - // Exclude Azure models to prevent overlap - if strings.Contains(model, "azure/") { - return false - } - - openaiPatterns := []string{ - "gpt", "davinci", "curie", "babbage", "ada", "o1", "o3", "o4", - "text-embedding", "dall-e", "whisper", "tts", "chatgpt", - } - - return matchesAnyPattern(model, openaiPatterns) -} - -// isAzureModel checks for Azure OpenAI specific patterns -func isAzureModel(model string) bool { - azurePatterns := []string{ - "azure", "model-router", "computer-use-preview", - } - - return matchesAnyPattern(model, azurePatterns) -} - -// isAnthropicModel checks for Anthropic Claude model patterns -func isAnthropicModel(model string) bool { - anthropicPatterns := []string{ - "claude", "anthropic/", - } - - return 
matchesAnyPattern(model, anthropicPatterns) -} - -// isVertexModel checks for Google Vertex AI model patterns -func isVertexModel(model string) bool { - vertexPatterns := []string{ - "gemini", "palm", "bison", "gecko", "vertex/", "google/", - } - - return matchesAnyPattern(model, vertexPatterns) -} - -// isBedrockModel checks for AWS Bedrock model patterns -func isBedrockModel(model string) bool { - bedrockPatterns := []string{ - "bedrock", "bedrock.amazonaws.com/", "bedrock/", - "amazon.titan", "amazon.nova", "aws/amazon.", - "ai21.jamba", "ai21.j2", "aws/ai21.", - "meta.llama", "aws/meta.", - "stability.stable-diffusion", "stability.sd3", "aws/stability.", - "anthropic.claude", "aws/anthropic.", - "cohere.command", "cohere.embed", "aws/cohere.", - "mistral.mistral", "mistral.mixtral", "aws/mistral.", - "titan-text", "titan-embed", "nova-micro", "nova-lite", "nova-pro", - "jamba-instruct", "j2-ultra", "j2-mid", - "llama-2", "llama-3", "llama-3.1", "llama-3.2", - "stable-diffusion-xl", "sd3-large", - } - - return matchesAnyPattern(model, bedrockPatterns) -} - -// isCohereModel checks for Cohere model patterns -func isCohereModel(model string) bool { - coherePatterns := []string{ - "command-", "embed-", "cohere", - } - - return matchesAnyPattern(model, coherePatterns) -} - -// matchesAnyPattern checks if the model matches any of the given patterns -func matchesAnyPattern(model string, patterns []string) bool { - for _, pattern := range patterns { - if strings.Contains(model, pattern) { - return true - } - } - return false -} - // newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false. // This helper function reduces code duplication when handling non-Bifrost errors. 
func newBifrostError(err error, message string) *schemas.BifrostError { @@ -796,3 +680,37 @@ func newBifrostError(err error, message string) *schemas.BifrostError { }, } } + +var validProviders = map[schemas.ModelProvider]bool{ + schemas.OpenAI: true, + schemas.Anthropic: true, + schemas.Bedrock: true, + schemas.Cohere: true, + schemas.Vertex: true, + schemas.Mistral: true, + schemas.Ollama: true, +} + +// ParseModelString extracts provider and model from a model string. +// For model strings like "anthropic/claude", it returns ("anthropic", "claude"). +// For model strings like "claude", it returns ("", "claude"). +// If the extracted provider is not valid, it treats the whole string as a model name. +func ParseModelString(model string, defaultProvider schemas.ModelProvider) (schemas.ModelProvider, string) { + // Check if model contains a provider prefix (only split on first "/" to preserve model names with "/") + if strings.Contains(model, "/") { + parts := strings.SplitN(model, "/", 2) + if len(parts) == 2 { + extractedProvider := parts[0] + extractedModel := parts[1] + + // Validate that the extracted provider is actually a valid provider + if validProviders[schemas.ModelProvider(extractedProvider)] { + return schemas.ModelProvider(extractedProvider), extractedModel + } + // If extracted provider is not valid, treat the whole string as model name + // This prevents corrupting model names that happen to contain "/" + } + } + // No provider prefix found or invalid provider, return empty provider and the original model + return defaultProvider, model +} diff --git a/transports/bifrost-http/plugins/logging/main.go b/transports/bifrost-http/plugins/logging/main.go index 0ce7381039..3b5c647137 100644 --- a/transports/bifrost-http/plugins/logging/main.go +++ b/transports/bifrost-http/plugins/logging/main.go @@ -82,7 +82,7 @@ type InitialLogData struct { Params *schemas.ModelParameters SpeechInput *schemas.SpeechInput TranscriptionInput *schemas.TranscriptionInput - 
Tools *[]schemas.Tool + Tools []schemas.Tool } // LogEntry represents a complete log entry for a request/response cycle @@ -99,7 +99,7 @@ type LogEntry struct { TranscriptionInput *schemas.TranscriptionInput `json:"transcription_input,omitempty"` SpeechOutput *schemas.BifrostSpeech `json:"speech_output,omitempty"` TranscriptionOutput *schemas.BifrostTranscribe `json:"transcription_output,omitempty"` - Tools *[]schemas.Tool `json:"tools,omitempty"` + Tools []schemas.Tool `json:"tools,omitempty"` ToolCalls *[]schemas.ToolCall `json:"tool_calls,omitempty"` Latency *float64 `json:"latency,omitempty"` TokenUsage *schemas.LLMUsage `json:"token_usage,omitempty"` diff --git a/transports/go.mod b/transports/go.mod index c458370df9..3903208941 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -14,6 +14,8 @@ require ( google.golang.org/genai v1.4.0 ) +replace github.com/maximhq/bifrost/core => ../core + require ( cloud.google.com/go v0.121.0 // indirect cloud.google.com/go/auth v0.16.0 // indirect @@ -33,17 +35,20 @@ require ( github.com/aws/aws-sdk-go-v2/service/sts v1.33.19 // indirect github.com/aws/smithy-go v1.22.3 // indirect github.com/beorn7/perks v1.0.1 // indirect + github.com/bytedance/sonic v1.14.0 // indirect + github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cloudwego/base64x v0.1.5 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/goccy/go-json v0.10.5 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.9 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect github.com/googleapis/gax-go/v2 v2.14.1 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect github.com/mark3labs/mcp-go v0.32.0 // indirect 
github.com/maximhq/maxim-go v0.1.3 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect @@ -52,6 +57,7 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 // indirect github.com/spf13/cast v1.7.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect @@ -59,6 +65,7 @@ require ( go.opentelemetry.io/otel v1.35.0 // indirect go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect + golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect golang.org/x/crypto v0.38.0 // indirect golang.org/x/net v0.40.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect diff --git a/transports/go.sum b/transports/go.sum index fa306d13e4..66c71ac657 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -34,8 +34,17 @@ github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= +github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= +github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= +github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod 
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= +github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= +github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fasthttp/router v1.5.4 h1:oxdThbBwQgsDIYZ3wR1IavsNl6ZS9WdjKukeMikOnC8= @@ -51,8 +60,6 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= -github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= -github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -69,6 +76,9 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 
+github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -79,8 +89,6 @@ github.com/mark3labs/mcp-go v0.32.0 h1:fgwmbfL2gbd67obg57OfV2Dnrhs1HtSdlY/i5fn7M github.com/mark3labs/mcp-go v0.32.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/maximhq/bifrost/core v1.1.11 h1:ir/LLG7xFbavocsa60VWMzDM7uK9E1GKVvyQA27WWF0= -github.com/maximhq/bifrost/core v1.1.11/go.mod h1:yMRCncTgKYBIrECSRVxMbY3BL8CjLbipJlc644jryxc= github.com/maximhq/bifrost/plugins/maxim v1.0.6 h1:m1tWjbmxW9Lz4mDhXclQhZdFt/TrRPbZwFcoWY9ZAEk= github.com/maximhq/bifrost/plugins/maxim v1.0.6/go.mod h1:+D/E498VB4JNTEzG4fYyFJf9WQaq/9FgYrmzl49mLNc= github.com/maximhq/maxim-go v0.1.3 h1:nVzdz3hEjZVxmWHARWIM+Yrn1Jp50qrsK4BA/sz2jj8= @@ -103,8 +111,17 @@ github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38 h1:D0vL7YNisV2yqE55 github.com/savsgio/gotils v0.0.0-20240704082632-aef3928b8a38/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.62.0 h1:8dKRBX/y2rCzyc6903Zu1+3qN0H/d2MsxPPmVNamiH0= @@ -127,6 +144,8 @@ go.opentelemetry.io/otel/sdk/metric v1.35.0 h1:1RriWBmCKgkeHEhM7a2uMjMUfP7MsOF5J go.opentelemetry.io/otel/sdk/metric v1.35.0/go.mod h1:is6XYCUMpcKi+ZsOvfluY5YstFnhW0BidkR+gL+qN+w= go.opentelemetry.io/otel/trace v1.35.0 h1:dPpEfJu1sDIqruz7BHFG3c7528f6ddfSWfFDVt/xgMs= go.opentelemetry.io/otel/trace v1.35.0/go.mod h1:WUk7DtFp1Aw2MkvqGdwiXYDZZNvA/1J8o6xRXLrIkyc= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= @@ -147,5 +166,8 @@ google.golang.org/grpc v1.72.0 h1:S7UkcVa60b5AAQTaO6ZKamFp1zMZSU0fGDK2WZLbBnM= google.golang.org/grpc v1.72.0/go.mod h1:wH5Aktxcg25y1I3w7H69nHfXdOG3UiadoBtjh3izSDM= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod 
h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=