diff --git a/bifrost.go b/bifrost.go index ae7454a31d..78c03c4920 100644 --- a/bifrost.go +++ b/bifrost.go @@ -22,7 +22,7 @@ const ( type ChannelMessage struct { interfaces.BifrostRequest - Response chan *interfaces.CompletionResult + Response chan *interfaces.BifrostResponse Err chan error Type RequestType } @@ -179,7 +179,7 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan defer bifrost.wg[provider.GetProviderKey()].Done() for req := range queue { - var result *interfaces.CompletionResult + var result *interfaces.BifrostResponse var err error key, err := bifrost.SelectKeyFromProviderForModel(provider, req.Model) @@ -234,13 +234,13 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr return queue, nil } -func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) { queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err } - responseChan := make(chan *interfaces.CompletionResult) + responseChan := make(chan *interfaces.BifrostResponse) errorChan := make(chan error) for _, plugin := range bifrost.plugins { @@ -273,13 +273,13 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo } } -func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) { +func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) { queue, err := bifrost.GetProviderQueue(providerKey) if err != nil { return nil, err } - responseChan := make(chan *interfaces.CompletionResult) + responseChan := make(chan *interfaces.BifrostResponse) errorChan := make(chan error) for _, plugin := range bifrost.plugins { diff --git a/interfaces/account.go b/interfaces/account.go index 47ef6b0c14..252c0358e9 100644 --- a/interfaces/account.go +++ b/interfaces/account.go @@ -11,6 +11,7 @@ type Key struct { Weight float64 `json:"weight"` } +// TODO one get config method type Account interface { GetInitiallyConfiguredProviderKeys() ([]SupportedModelProvider, error) GetKeysForProvider(provider Provider) ([]Key, error) diff --git a/interfaces/plugin.go b/interfaces/plugin.go index 0de9d9144a..1d842408e8 100644 --- a/interfaces/plugin.go +++ b/interfaces/plugin.go @@ -15,5 +15,5 @@ type BifrostRequest struct { type Plugin interface { PreHook(ctx context.Context, req *BifrostRequest) (context.Context, *BifrostRequest, error) - PostHook(ctx context.Context, result *CompletionResult) (*CompletionResult, error) + PostHook(ctx context.Context, result *BifrostResponse) (*BifrostResponse, error) } diff --git a/interfaces/provider.go b/interfaces/provider.go index 4d08b32505..e5512e7d02 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -2,83 +2,6 @@ package interfaces import "encoding/json" -// LLMUsage represents token usage information -type LLMUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` - Latency *float64 `json:"latency"` -} - -type BilledLLMUsage struct { - PromptTokens *float64 `json:"prompt_tokens"` - CompletionTokens *float64 `json:"completion_tokens"` - SearchUnits *float64 `json:"search_units"` - Classifications *float64 `json:"classifications"` -} - -// LLMInteractionCost represents cost information for LLM interactions -type LLMInteractionCost struct { - Input float64 `json:"input"` - Output float64 `json:"output"` - Total float64 `json:"total"` -} - -// Function represents a function definition for tool calls -type Function struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters interface{} `json:"parameters"` -} - -// Tool represents a tool that can be used with the model -type Tool struct { - Type string `json:"type"` - Function Function `json:"function"` -} - -// ModelParameters represents the parameters for model requests -type ModelParameters struct { - TestRunEntryID *string `json:"testRunEntryId"` - PromptTools *[]string `json:"promptTools"` - ToolChoice *string `json:"toolChoice"` - Tools *[]Tool `json:"tools"` - FunctionCall *string `json:"functionCall"` - Functions *[]Function `json:"functions"` - // Dynamic parameters - ExtraParams map[string]interface{} `json:"-"` -} - -// RequestOptions represents options for model requests -type RequestOptions struct { - UseCache *bool `json:"useCache"` - WaitForModel *bool `json:"waitForModel"` - CompletionType *string `json:"CompletionType"` -} - -// FunctionCall represents a function call in a tool call -type FunctionCall struct { - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -// ToolCall represents a tool call in a message -type ToolCall struct { - Type *string `json:"type"` - ID string `json:"id"` - Name *string `json:"name"` - Input json.RawMessage `json:"input"` - Function *FunctionCall `json:"function"` -} - -type Citation struct { - Start *int `json:"start"` - End *int `json:"end"` - Text *string `json:"text"` - Sources *interface{} `json:"sources"` - Type *string `json:"type"` -} - // ModelChatMessageRole represents the role of a chat message type ModelChatMessageRole string @@ -86,74 +9,10 @@ const ( RoleAssistant ModelChatMessageRole = "assistant" RoleUser ModelChatMessageRole = "user" RoleSystem ModelChatMessageRole = "system" - RoleModel ModelChatMessageRole = "model" + RoleChatbot ModelChatMessageRole = "chatbot" RoleTool ModelChatMessageRole = "tool" ) -// CompletionResponseChoice represents a choice in the completion response -type CompletionResponseChoice struct { - Role ModelChatMessageRole `json:"role"` - Content string `json:"content"` - Image json.RawMessage `json:"image"` - ToolCalls *[]ToolCall `json:"tool_calls"` - Citations *[]Citation `json:"citation"` -} - -// CompletionResultChoice represents a choice in the completion result -type CompletionResultChoice struct { - Index int `json:"index"` - Message CompletionResponseChoice `json:"message"` - StopReason *string `json:"stop_reason"` - Stop *string `json:"stop"` - LogProbs *interface{} `json:"logprobs"` -} - -// ToolResult represents the result of a tool call -type ToolResult struct { - Role ModelChatMessageRole `json:"role"` - Content string `json:"content"` - ToolCallID string `json:"tool_call_id"` -} - -// ToolCallResult represents a single tool call result -type ToolCallResult struct { - Name string `json:"name"` - Result interface{} `json:"result"` - Type string `json:"type"` - ID string `json:"id"` -} - -// ToolCallResults represents a collection of tool call results -type ToolCallResults struct { - Version int `json:"version"` - Results []ToolCallResult `json:"results"` -} - -// CompletionResult represents the complete result from a model completion -type CompletionResult struct { - Error *struct { - Code string `json:"code"` - Message string `json:"message"` - Type string `json:"type"` - } `json:"error"` - ID string `json:"id"` - Choices []CompletionResultChoice `json:"choices"` - ChatHistory *[]CompletionResponseChoice `json:"chat_history"` - ToolCallResult *interface{} `json:"tool_call_result"` - ToolCallResults *ToolCallResults `json:"toolCallResults"` - Provider SupportedModelProvider `json:"provider"` - Usage LLMUsage `json:"usage"` - BilledUsage *BilledLLMUsage `json:"billed_usage"` - Cost *LLMInteractionCost `json:"cost"` - Model string `json:"model"` - Created string `json:"created"` - Params *interface{} `json:"modelParams"` - Trace *struct { - Input interface{} `json:"input"` - Output interface{} `json:"output"` - } `json:"trace"` -} - type SupportedModelProvider string const ( @@ -170,13 +29,55 @@ const ( Lmstudio SupportedModelProvider = "lmstudio" ) +//* Request Structs + +// ModelParameters represents the parameters for model requests +type ModelParameters struct { + TestRunEntryID *string `json:"test_run_entry_id"` + PromptTools *[]string `json:"prompt_tools"` + ToolChoice *string `json:"tool_choice"` + Tools *[]Tool `json:"tools"` + + // Common model parameters + Temperature *float64 `json:"temperature"` + TopP *float64 `json:"top_p"` + TopK *int `json:"top_k"` + MaxTokens *int `json:"max_tokens"` + StopSequences *[]string `json:"stop_sequences"` + PresencePenalty *float64 `json:"presence_penalty"` + FrequencyPenalty *float64 `json:"frequency_penalty"` + + // Dynamic parameters + ExtraParams map[string]interface{} `json:"-"` +} + +type FunctionParameters struct { + Type string `json:"type"` + Required []string `json:"required"` + Properties map[string]interface{} `json:"properties"` +} + +// Function represents a function definition for tool calls +type Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters FunctionParameters `json:"parameters"` +} + +// Tool represents a tool that can be used with the model +type Tool struct { + ID *string `json:"id"` + Type string `json:"type"` + Function Function `json:"function"` +} + type Message struct { //* strict check for roles Role ModelChatMessageRole `json:"role"` //* need to make sure either content or imagecontent is provided Content *string `json:"content"` - ImageContent *ImageContent `json:"imageContent"` - ToolCalls *[]ToolCall `json:"toolCall"` + ImageContent *ImageContent `json:"image_content"` + ToolCalls *[]Tool `json:"tool_calls"` } type ImageContent struct { @@ -185,62 +86,100 @@ type ImageContent struct { MediaType string `json:"media_type"` } -// type Content struct { -// Content *string `json:"content"` -// ImageContent *ImageContent `json:"imageContent"` -// } - -// func (content *Content) MarshalJSON() ([]byte, error) { -// if content.Content != nil { -// return []byte(*content.Content), nil -// } else if content.ImageContent != nil { -// return json.Marshal(content.ImageContent) -// } - -// return nil, fmt.Errorf("invalid content") -// } +type NetworkConfig struct { + DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` +} -// func (content *Content) UnmarshalJSON(val []byte) error { -// var s any -// json.Unmarshal(val, &s) +type MetaConfig struct { + SecretAccessKey string `json:"secret_access_key"` + Region *string `json:"region"` + SessionToken *string `json:"session_token"` + ARN *string `json:"arn"` + InferenceProfiles map[string]string `json:"inference_profiles"` +} -// switch s := s.(type) { -// case string: -// content.Content = &s -// case ImageContent: -// content.ImageContent = &s +type ProviderConfig struct { + NetworkConfig NetworkConfig `json:"network_config"` + MetaConfig *MetaConfig `json:"meta_config"` +} -// default: -// return fmt.Errorf("invalid stop") -// } +//* Response Structs -// return nil -// } +// LLMUsage represents token usage information +type LLMUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + Latency *float64 `json:"latency"` +} -type NetworkConfig struct { - DefaultRequestTimeoutInSeconds int `json:"defaultRequestTimeoutInSeconds"` +type BilledLLMUsage struct { + PromptTokens *float64 `json:"prompt_tokens"` + CompletionTokens *float64 `json:"completion_tokens"` + SearchUnits *float64 `json:"search_units"` + Classifications *float64 `json:"classifications"` } -type MetaConfig struct { - BedrockMetaConfig *BedrockMetaConfig `json:"bedrockMetaConfig"` +// LLMInteractionCost represents cost information for LLM interactions +type LLMInteractionCost struct { + Input float64 `json:"input"` + Output float64 `json:"output"` + Total float64 `json:"total"` } -type ProviderConfig struct { - NetworkConfig NetworkConfig `json:"networkConfig"` - MetaConfig *MetaConfig `json:"metaConfig"` +// ToolCall represents a tool call in a message +type ToolCall struct { + Type *string `json:"type"` + ID *string `json:"id"` + Name *string `json:"name"` + Arguments json.RawMessage `json:"arguments"` } -type BedrockMetaConfig struct { - SecretAccessKey string `json:"secretAccessKey"` - Region *string `json:"region"` - SessionToken *string `json:"sessionToken"` - ARN *string `json:"arn"` - InferenceProfiles map[string]string `json:"inferenceProfiles"` +type Citation struct { + Start *int `json:"start"` + End *int `json:"end"` + Text *string `json:"text"` + Sources *interface{} `json:"sources"` + Type *string `json:"type"` } +// BifrostResponseChoiceMessage represents a choice in the completion response +type BifrostResponseChoiceMessage struct { + Role ModelChatMessageRole `json:"role"` + Content string `json:"content"` + Image json.RawMessage `json:"image"` + ToolCalls *[]ToolCall `json:"tool_calls"` + Citations *[]Citation `json:"citations"` +} + +// BifrostResponseChoice represents a choice in the completion result +type BifrostResponseChoice struct { + Index int `json:"index"` + Message BifrostResponseChoiceMessage `json:"message"` + StopReason *string `json:"stop_reason"` + Stop *string `json:"stop"` + LogProbs *interface{} `json:"log_probs"` +} + +// BifrostResponse represents the complete result from a model completion +type BifrostResponse struct { + ID string `json:"id"` + Choices []BifrostResponseChoice `json:"choices"` + ChatHistory *[]BifrostResponseChoiceMessage `json:"chat_history"` + Provider SupportedModelProvider `json:"provider"` + Usage LLMUsage `json:"usage"` + BilledUsage *BilledLLMUsage `json:"billed_usage"` + Cost *LLMInteractionCost `json:"cost"` + Model string `json:"model"` + Created string `json:"created"` + Params *interface{} `json:"model_params"` + RawResponse interface{} `json:"raw_response"` +} + +// TODO third party providers // Provider defines the interface for AI model providers type Provider interface { GetProviderKey() SupportedModelProvider - TextCompletion(model, key, text string, params *ModelParameters) (*CompletionResult, error) - ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*CompletionResult, error) + TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, error) + ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, error) } diff --git a/providers/anthropic.go b/providers/anthropic.go index 0fe2394e11..d3fdf66c6a 100644 --- a/providers/anthropic.go +++ b/providers/anthropic.go @@ -28,14 +28,12 @@ type AnthropicChatResponse struct { Type string `json:"type"` Role string `json:"role"` Content []struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Thinking string `json:"thinking,omitempty"` - ToolUse *struct { - ID string `json:"id"` - Name string `json:"name"` - Input map[string]interface{} `json:"input"` - } `json:"tool_use,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Thinking string `json:"thinking,omitempty"` + ID string `json:"id"` + Name string `json:"name"` + Input map[string]interface{} `json:"input"` } `json:"content"` Model string `json:"model"` StopReason string `json:"stop_reason,omitempty"` @@ -62,11 +60,25 @@ func (provider *AnthropicProvider) GetProviderKey() interfaces.SupportedModelPro return interfaces.Anthropic } +func (provider *AnthropicProvider) PrepareTextCompletionParams(params map[string]interface{}) map[string]interface{} { + // Check if there is a key entry for max_tokens + if maxTokens, exists := params["max_tokens"]; exists { + // Check if max_tokens_to_sample is already present + if _, exists := params["max_tokens_to_sample"]; !exists { + // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + params["max_tokens_to_sample"] = maxTokens + } + delete(params, "max_tokens") + } + return params + +} + // TextCompletion implements text completion using Anthropic's API -func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { startTime := time.Now() - preparedParams := PrepareParams(params) + preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params)) // Merge additional parameters requestBody := MergeConfig(map[string]interface{}{ @@ -98,6 +110,9 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param } defer resp.Body.Close() + // Calculate latency + latency := time.Since(startTime).Seconds() + // Read the response body body, err := io.ReadAll(resp.Body) if err != nil { @@ -106,36 +121,28 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param // Check for error response if resp.StatusCode != http.StatusOK { - var errorResp struct { - Type string `json:"type"` - Error struct { - Type string `json:"type"` - Message string `json:"message"` - } `json:"error"` - } - if err := json.Unmarshal(body, &errorResp); err != nil { - return nil, fmt.Errorf("error response: %s", string(body)) - } - return nil, fmt.Errorf("anthropic error: %s", errorResp.Error.Message) + return nil, fmt.Errorf("anthropic error: %s", string(body)) } // Parse the response var response AnthropicTextResponse - if err := json.Unmarshal(body, &response); err != nil { return nil, fmt.Errorf("error parsing response: %v", err) } - // Calculate latency - latency := time.Since(startTime).Seconds() + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("error parsing raw response: %v", err) + } // Create the completion result - completionResult := &interfaces.CompletionResult{ + completionResult := &interfaces.BifrostResponse{ ID: response.ID, - Choices: []interfaces.CompletionResultChoice{ + Choices: []interfaces.BifrostResponseChoice{ { Index: 0, - Message: interfaces.CompletionResponseChoice{ + Message: interfaces.BifrostResponseChoiceMessage{ Role: interfaces.RoleAssistant, Content: response.Completion, }, @@ -147,17 +154,16 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, Latency: &latency, }, - Model: response.Model, - Provider: interfaces.Anthropic, + Model: response.Model, + Provider: interfaces.Anthropic, + RawResponse: rawResponse, } return completionResult, nil } // ChatCompletion implements chat completion using Anthropic's API -func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { - startTime := time.Now() - +func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { // Format messages for Anthropic API var formattedMessages []map[string]interface{} for _, msg := range messages { @@ -169,6 +175,20 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] preparedParams := PrepareParams(params) + // Transform tools if present + if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + var tools []map[string]interface{} + for _, tool := range *params.Tools { + tools = append(tools, map[string]interface{}{ + "name": tool.Function.Name, + "description": tool.Function.Description, + "input_schema": tool.Function.Parameters, + }) + } + + preparedParams["tools"] = tools + } + // Merge additional parameters requestBody := MergeConfig(map[string]interface{}{ "model": model, @@ -209,67 +229,64 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] return nil, fmt.Errorf("API error: %s", string(body)) } - // Calculate latency - latency := time.Since(startTime).Seconds() - // Decode response var anthropicResponse AnthropicChatResponse - if err := json.Unmarshal(body, &anthropicResponse); err != nil { return nil, fmt.Errorf("error decoding response: %v", err) } - // Process the response into our CompletionResult format - var content string - var toolCalls []interfaces.ToolCall - var stopReason string + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("error parsing raw response: %v", err) + } + + // Process the response into our BifrostResponse format + var choices []interfaces.BifrostResponseChoice // Process content and tool calls - for _, c := range anthropicResponse.Content { + for i, c := range anthropicResponse.Content { + var content string + var toolCalls []interfaces.ToolCall + switch c.Type { case "thinking": - if content == "" { - content = fmt.Sprintf("\n%s\n\n\n", c.Thinking) - } + content = c.Thinking case "text": - content += c.Text + content = c.Text case "tool_use": - if c.ToolUse != nil { - toolCalls = append(toolCalls, interfaces.ToolCall{ - Type: maxim.StrPtr("function"), - ID: c.ToolUse.ID, - Function: &interfaces.FunctionCall{ - Name: c.ToolUse.Name, - Arguments: string(must(json.Marshal(c.ToolUse.Input))), - }, - }) - stopReason = "tool_calls" - } + toolCalls = append(toolCalls, interfaces.ToolCall{ + Type: maxim.StrPtr("function"), + ID: &c.ID, + Name: &c.Name, + Arguments: json.RawMessage(must(json.Marshal(c.Input))), + }) } + + choices = append(choices, interfaces.BifrostResponseChoice{ + Index: i, + Message: interfaces.BifrostResponseChoiceMessage{ + Role: interfaces.RoleAssistant, + Content: content, + ToolCalls: &toolCalls, + }, + StopReason: &anthropicResponse.StopReason, + Stop: anthropicResponse.StopSequence, + }) } // Create the completion result - result := &interfaces.CompletionResult{ - ID: anthropicResponse.ID, - Choices: []interfaces.CompletionResultChoice{ - { - Index: 0, - Message: interfaces.CompletionResponseChoice{ - Role: interfaces.RoleAssistant, - Content: content, - ToolCalls: &toolCalls, - }, - StopReason: &stopReason, - }, - }, + result := &interfaces.BifrostResponse{ + ID: anthropicResponse.ID, + Choices: choices, Usage: interfaces.LLMUsage{ PromptTokens: anthropicResponse.Usage.InputTokens, CompletionTokens: anthropicResponse.Usage.OutputTokens, TotalTokens: anthropicResponse.Usage.InputTokens + anthropicResponse.Usage.OutputTokens, - Latency: &latency, }, - Model: anthropicResponse.Model, - Provider: interfaces.Anthropic, + Model: anthropicResponse.Model, + Provider: interfaces.Anthropic, + RawResponse: rawResponse, } return result, nil diff --git a/providers/bedrock.go b/providers/bedrock.go index 1f5309df8d..2a839215f7 100644 --- a/providers/bedrock.go +++ b/providers/bedrock.go @@ -78,34 +78,46 @@ type BedrockAnthropicImageSource struct { } type BedrockMistralToolCall struct { - ID string `json:"id"` - Function interfaces.FunctionCall `json:"function"` + ID string `json:"id"` + Function interfaces.Function `json:"function"` +} + +type BedrockAnthropicToolCall struct { + ToolSpec BedrockAnthropicToolSpec `json:"toolSpec"` +} + +type BedrockAnthropicToolSpec struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema struct { + Json interface{} `json:"json"` + } `json:"inputSchema"` } type BedrockProvider struct { client *http.Client - meta *interfaces.BedrockMetaConfig + meta *interfaces.MetaConfig } func NewBedrockProvider(config *interfaces.ProviderConfig) *BedrockProvider { return &BedrockProvider{ client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)}, - meta: config.MetaConfig.BedrockMetaConfig, + meta: config.MetaConfig, } } -func (p *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvider { +func (provider *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvider { return interfaces.Bedrock } -func (p *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) { - if p.meta == nil { +func (provider *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) { + if provider.meta == nil { return nil, errors.New("meta config for bedrock is not provided") } region := "us-east-1" - if p.meta.Region != nil { - region = *p.meta.Region + if provider.meta.Region != nil { + region = *provider.meta.Region } // Create the request with the JSON body @@ -114,14 +126,14 @@ func (p *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey str return nil, fmt.Errorf("error creating request: %v", err) } - if err := SignAWSRequest(req, accessKey, p.meta.SecretAccessKey, p.meta.SessionToken, region, "bedrock"); err != nil { + if err := SignAWSRequest(req, accessKey, provider.meta.SecretAccessKey, provider.meta.SessionToken, region, "bedrock"); err != nil { return nil, err } return req, nil } -func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.CompletionResult, error) { +func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.BifrostResponse, error) { switch model { case "anthropic.claude-instant-v1:2": fallthrough @@ -133,11 +145,11 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) ( return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) } - return &interfaces.CompletionResult{ - Choices: []interfaces.CompletionResultChoice{ + return &interfaces.BifrostResponse{ + Choices: []interfaces.BifrostResponseChoice{ { Index: 0, - Message: interfaces.CompletionResponseChoice{ + Message: interfaces.BifrostResponseChoiceMessage{ Role: interfaces.RoleAssistant, Content: response.Completion, }, @@ -161,11 +173,11 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) ( return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) } - var choices []interfaces.CompletionResultChoice + var choices []interfaces.BifrostResponseChoice for i, output := range response.Outputs { - choices = append(choices, interfaces.CompletionResultChoice{ + choices = append(choices, interfaces.BifrostResponseChoice{ Index: i, - Message: interfaces.CompletionResponseChoice{ + Message: interfaces.BifrostResponseChoiceMessage{ Role: interfaces.RoleAssistant, Content: output.Text, }, @@ -173,7 +185,7 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) ( }) } - return &interfaces.CompletionResult{ + return &interfaces.BifrostResponse{ Choices: choices, }, nil } @@ -181,7 +193,7 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) ( return nil, fmt.Errorf("invalid model choice: %s", model) } -func (p *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) { +func (provider *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) { switch model { case "anthropic.claude-instant-v1:2": fallthrough @@ -263,8 +275,8 @@ func (p *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Me if msg.ToolCalls != nil { for _, toolCall := range *msg.ToolCalls { filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{ - ID: toolCall.ID, - Function: *toolCall.Function, + ID: *toolCall.ID, + Function: toolCall.Function, }) } } @@ -293,10 +305,67 @@ func (p *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Me return nil, fmt.Errorf("invalid model choice: %s", model) } -func (p *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { - startTime := time.Now() +func (provider *BedrockProvider) GetChatCompletionTools(params *interfaces.ModelParameters, model string) []BedrockAnthropicToolCall { + var tools []BedrockAnthropicToolCall - preparedParams := PrepareParams(params) + switch model { + case "anthropic.claude-instant-v1:2": + fallthrough + case "anthropic.claude-v2": + fallthrough + case "anthropic.claude-v2:1": + fallthrough + case "anthropic.claude-3-sonnet-20240229-v1:0": + fallthrough + case "anthropic.claude-3-5-sonnet-20240620-v1:0": + fallthrough + case "anthropic.claude-3-5-sonnet-20241022-v2:0": + fallthrough + case "anthropic.claude-3-5-haiku-20241022-v1:0": + fallthrough + case "anthropic.claude-3-opus-20240229-v1:0": + fallthrough + case "anthropic.claude-3-7-sonnet-20250219-v1:0": + for _, tool := range *params.Tools { + tools = append(tools, BedrockAnthropicToolCall{ + ToolSpec: BedrockAnthropicToolSpec{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: struct { + Json interface{} `json:"json"` + }{ + Json: tool.Function.Parameters, + }, + }, + }) + } + } + + return tools +} + +func (provider *BedrockProvider) PrepareTextCompletionParams(params map[string]interface{}, model string) map[string]interface{} { + switch model { + case "anthropic.claude-instant-v1:2": + fallthrough + case "anthropic.claude-v2": + fallthrough + case "anthropic.claude-v2:1": + // Check if there is a key entry for max_tokens + if maxTokens, exists := params["max_tokens"]; exists { + // Check if max_tokens_to_sample is already present + if _, exists := params["max_tokens_to_sample"]; !exists { + // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample + params["max_tokens_to_sample"] = maxTokens + } + delete(params, "max_tokens") + } + } + return params +} + +func (provider *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { + preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params), model) requestBody := MergeConfig(map[string]interface{}{ "prompt": text, @@ -309,13 +378,13 @@ func (p *BedrockProvider) TextCompletion(model, key, text string, params *interf } // Create the signed request with correct operation name - req, err := p.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key) + req, err := provider.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key) if err != nil { return nil, fmt.Errorf("error creating request: %v", err) } // Execute the request - resp, err := p.client.Do(req) + resp, err := provider.client.Do(req) if err != nil { return nil, fmt.Errorf("failed to execute request: %v", err) } @@ -331,24 +400,35 @@ func (p *BedrockProvider) TextCompletion(model, key, text string, params *interf return nil, fmt.Errorf("bedrock API error: %s", string(body)) } - result, err := p.GetTextCompletionResult(body, model) + result, err := provider.GetTextCompletionResult(body, model) if err != nil { return nil, fmt.Errorf("failed to parse response body: %v", err) } - // Calculate latency - latency := time.Since(startTime).Seconds() - result.Usage.Latency = &latency + + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("failed to parse raw response: %v", err) + } + + result.RawResponse = rawResponse return result, nil } -func (p *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { - messageBody, err := p.PrepareChatCompletionMessages(messages, model) +func (provider *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { + messageBody, err := provider.PrepareChatCompletionMessages(messages, model) if err != nil { return nil, fmt.Errorf("error preparing messages: %v", err) } preparedParams := PrepareParams(params) + + // Transform tools if present + if params != nil && params.Tools != nil && len(*params.Tools) > 0 { + preparedParams["tools"] = provider.GetChatCompletionTools(params, model) + } + requestBody := MergeConfig(messageBody, preparedParams) // Marshal the request body @@ -360,23 +440,23 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface // Format the path with proper model identifier path := fmt.Sprintf("%s/converse", model) - if p.meta != nil && p.meta.InferenceProfiles != nil { - if inferenceProfileId, ok := p.meta.InferenceProfiles[model]; ok { - if p.meta.ARN != nil { - encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *p.meta.ARN, inferenceProfileId)) + if provider.meta != nil && provider.meta.InferenceProfiles != nil { + if inferenceProfileId, ok := provider.meta.InferenceProfiles[model]; ok { + if provider.meta.ARN != nil { + encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.ARN, inferenceProfileId)) path = fmt.Sprintf("%s/converse", encodedModelIdentifier) } } } // Create the signed request - req, err := p.PrepareReq(path, jsonData, key) + req, err := provider.PrepareReq(path, jsonData, key) if err != nil { return nil, fmt.Errorf("error creating request: %v", err) } // Execute the request - resp, err := p.client.Do(req) + resp, err := provider.client.Do(req) if err != nil { return nil, fmt.Errorf("failed to execute request: %v", err) } @@ -397,11 +477,17 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) } - var choices []interfaces.CompletionResultChoice + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("failed to parse raw response: %v", err) + } + + var choices []interfaces.BifrostResponseChoice for i, choice := range response.Output.Message.Content { - choices = append(choices, interfaces.CompletionResultChoice{ + choices = append(choices, interfaces.BifrostResponseChoice{ Index: i, - Message: interfaces.CompletionResponseChoice{ + Message: interfaces.BifrostResponseChoiceMessage{ Role: interfaces.RoleAssistant, Content: choice.Text, }, @@ -411,7 +497,7 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface latency := float64(response.Metrics.Latency) - result := &interfaces.CompletionResult{ + result := &interfaces.BifrostResponse{ Choices: choices, Usage: interfaces.LLMUsage{ PromptTokens: response.Usage.InputTokens, @@ -419,8 +505,9 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface TotalTokens: response.Usage.TotalTokens, Latency: &latency, }, - Model: model, - Provider: interfaces.Bedrock, + Model: model, + Provider: interfaces.Bedrock, + RawResponse: rawResponse, } return result, nil diff --git a/providers/cohere.go b/providers/cohere.go index 04062671f0..4030c7091a 100644 --- a/providers/cohere.go +++ b/providers/cohere.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "slices" "time" @@ -24,14 +25,20 @@ type CohereTool struct { ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"` } +type CohereToolCall struct { + Name string `json:"name"` + Parameters interface{} `json:"parameters"` +} + // CohereChatResponse represents the response from Cohere's chat API type CohereChatResponse struct { ResponseID string `json:"response_id"` Text string `json:"text"` GenerationID string `json:"generation_id"` ChatHistory []struct { - Role interfaces.ModelChatMessageRole `json:"role"` - Message string `json:"message"` + Role interfaces.ModelChatMessageRole `json:"role"` + Message string `json:"message"` + ToolCalls []CohereToolCall `json:"tool_calls"` } `json:"chat_history"` FinishReason string `json:"finish_reason"` Meta struct { @@ -47,6 +54,7 @@ type CohereChatResponse struct { OutputTokens float64 `json:"output_tokens"` } `json:"tokens"` } `json:"meta"` + ToolCalls []CohereToolCall `json:"tool_calls"` } // OpenAIProvider implements the Provider interface for OpenAI @@ -65,13 +73,11 @@ func (provider *CohereProvider) GetProviderKey() interfaces.SupportedModelProvid return interfaces.Cohere } -func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { return nil, fmt.Errorf("text completion is not supported by Cohere") } -func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { - startTime := time.Now() - +func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { // Get the last message and chat history lastMessage := messages[len(messages)-1] chatHistory := messages[:len(messages)-1] @@ -99,29 +105,23 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int var tools []CohereTool for _, tool := range *params.Tools { parameterDefinitions := make(map[string]CohereParameterDefinition) - if tool.Function.Parameters != nil { - paramsMap, ok := tool.Function.Parameters.(map[string]interface{}) + params := tool.Function.Parameters + for name, prop := range tool.Function.Parameters.Properties { + propMap, ok := prop.(map[string]interface{}) if ok { - if properties, ok := paramsMap["properties"].(map[string]interface{}); ok { - for name, prop := range properties { - propMap, ok := prop.(map[string]interface{}) - if ok { - paramDef := CohereParameterDefinition{ - Required: slices.Contains(paramsMap["required"].([]string), name), - } - - if typeStr, ok := propMap["type"].(string); ok { - paramDef.Type = typeStr - } - - if desc, ok := propMap["description"].(string); ok { - paramDef.Description = &desc - } - - parameterDefinitions[name] = paramDef - } - } + paramDef := CohereParameterDefinition{ + Required: slices.Contains(params.Required, name), } + + if typeStr, ok := propMap["type"].(string); ok { + paramDef.Type = typeStr + } + + if desc, ok := propMap["description"].(string); ok { + paramDef.Description = &desc + } + + parameterDefinitions[name] = paramDef } } @@ -157,28 +157,40 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int } defer resp.Body.Close() + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %v", err) + } + // Handle error response if resp.StatusCode != http.StatusOK { - var errorResp struct { - Message string `json:"message"` - } - if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil { - return nil, fmt.Errorf("error decoding error response: %v", err) - } - return nil, fmt.Errorf("cohere error: %s", errorResp.Message) + return nil, fmt.Errorf("cohere error: %s", string(body)) } // Decode response var response CohereChatResponse - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - return nil, fmt.Errorf("error decoding response: %v", err) + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("failed to parse Bedrock response: %v", err) } - // Transform tool calls if present - var toolCalls *[]interfaces.ToolCall + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("failed to parse raw response: %v", err) + } - // Calculate latency - latency := time.Since(startTime).Seconds() + // Transform tool calls if present + var toolCalls []interfaces.ToolCall + if response.ToolCalls != nil { + for _, tool := range response.ToolCalls { + args := json.RawMessage(must(json.Marshal(tool.Parameters))) + toolCalls = append(toolCalls, interfaces.ToolCall{ + Name: &tool.Name, + Arguments: args, + }) + } + } // Get role and content from the last message in chat history var role interfaces.ModelChatMessageRole @@ -188,20 +200,20 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int role = lastMsg.Role content = lastMsg.Message } else { - role = interfaces.ModelChatMessageRole("assistant") + role = interfaces.RoleChatbot content = response.Text } // Create completion result - result := &interfaces.CompletionResult{ + result := &interfaces.BifrostResponse{ ID: response.ResponseID, - Choices: []interfaces.CompletionResultChoice{ + Choices: []interfaces.BifrostResponseChoice{ { Index: 0, - Message: interfaces.CompletionResponseChoice{ + Message: interfaces.BifrostResponseChoiceMessage{ Role: role, Content: content, - ToolCalls: toolCalls, + ToolCalls: &toolCalls, }, StopReason: &response.FinishReason, }, @@ -211,14 +223,14 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int PromptTokens: int(response.Meta.Tokens.InputTokens), CompletionTokens: int(response.Meta.Tokens.OutputTokens), TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens), - Latency: &latency, }, BilledUsage: &interfaces.BilledLLMUsage{ PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens), CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens), }, - Model: model, - Provider: interfaces.Cohere, + Model: model, + Provider: interfaces.Cohere, + RawResponse: rawResponse, } return result, nil @@ -226,14 +238,26 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int // Helper function to convert chat history to the correct type func convertChatHistory(history []struct { - Role interfaces.ModelChatMessageRole `json:"role"` - Message string `json:"message"` -}) *[]interfaces.CompletionResponseChoice { - converted := make([]interfaces.CompletionResponseChoice, len(history)) + Role interfaces.ModelChatMessageRole `json:"role"` + Message string `json:"message"` + ToolCalls []CohereToolCall `json:"tool_calls"` +}) *[]interfaces.BifrostResponseChoiceMessage { + converted := make([]interfaces.BifrostResponseChoiceMessage, len(history)) for i, msg := range history { - converted[i] = interfaces.CompletionResponseChoice{ - Role: msg.Role, - Content: msg.Message, + var toolCalls []interfaces.ToolCall + if msg.ToolCalls != nil { + for _, tool := range msg.ToolCalls { + args := json.RawMessage(must(json.Marshal(tool.Parameters))) + toolCalls = append(toolCalls, interfaces.ToolCall{ + Name: &tool.Name, + Arguments: args, + }) + } + } + converted[i] = interfaces.BifrostResponseChoiceMessage{ + Role: msg.Role, + Content: msg.Message, + ToolCalls: &toolCalls, } } return &converted diff --git a/providers/openai.go b/providers/openai.go index b97a86dcdd..16d0f818b7 100644 --- a/providers/openai.go +++ b/providers/openai.go @@ -5,16 +5,57 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "time" ) +type OpenAIToolCall struct { + Type *string `json:"type"` + ID *string `json:"id"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` +} + +type OpenAIMessage struct { + Role interfaces.ModelChatMessageRole `json:"role"` + Content string `json:"content"` + ToolCalls *[]OpenAIToolCall `json:"tool_calls,omitempty"` +} + +type OpenAIChoice struct { + Index int `json:"index"` + Message OpenAIMessage `json:"message"` + FinishReason *string `json:"finish_reason"` + LogProbs *interface{} `json:"logprobs"` +} + type OpenAIResponse struct { - ID string `json:"id"` - Choices []interfaces.CompletionResultChoice `json:"choices"` - Usage interfaces.LLMUsage `json:"usage"` - Model string `json:"model"` - Created interface{} `json:"created"` + ID string `json:"id"` + Object string `json:"object"` + Choices []OpenAIChoice `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokenDetails struct { + CachedToken int `json:"cached_tokens"` + AudioToken int `json:"audio_tokens"` + } `json:"prompt_tokens_details"` + CompletionTokenDetails struct { + ReasoningTokens int `json:"reasoning_tokens"` + AudioTokens int `json:"audio_tokens"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` + } `json:"completion_tokens_details"` + Latency float64 `json:"latency"` + } `json:"usage"` + Model string `json:"model"` + Created interface{} `json:"created"` + ServiceTier string `json:"service_tier"` + SystemFingerprint string `json:"system_fingerprint"` } // OpenAIProvider implements the Provider interface for OpenAI @@ -34,7 +75,7 @@ func (provider *OpenAIProvider) GetProviderKey() interfaces.SupportedModelProvid } // TextCompletion performs text completion -func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { return nil, fmt.Errorf("text completion is not supported by OpenAI") } @@ -55,8 +96,7 @@ func (provider *OpenAIProvider) sanitizeParameters(params *interfaces.ModelParam return sanitized } -// ChatCompletion implements chat completion using OpenAI's API -func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) { +func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) { startTime := time.Now() // Format messages for OpenAI API @@ -108,35 +148,69 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int latency := time.Since(startTime).Seconds() + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading response: %v", err) + } + // Handle error response if resp.StatusCode != http.StatusOK { - var errorResp struct { - Error struct { - Message string `json:"message"` - Type string `json:"type"` - Param any `json:"param"` - Code string `json:"code"` - } `json:"error"` - } - if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil { - return nil, fmt.Errorf("error decoding error response: %v", err) - } - return nil, fmt.Errorf("OpenAI error: %s", errorResp.Error.Message) + return nil, fmt.Errorf("OpenAI error: %s", string(body)) } - // Decode response + // Decode structured response var response OpenAIResponse + if err := json.Unmarshal(body, &response); err != nil { + return nil, fmt.Errorf("error decoding structured response: %v", err) + } - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - return nil, fmt.Errorf("error decoding response: %v", err) + // Decode raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, fmt.Errorf("error decoding raw response: %v", err) } - // Convert the raw result to CompletionResult - result := &interfaces.CompletionResult{ + // Transform choices to include tool calls + var choices []interfaces.BifrostResponseChoice + for i, c := range response.Choices { + // Transform tool calls if present + var toolCalls []interfaces.ToolCall + if c.Message.ToolCalls != nil { + for _, tool := range *c.Message.ToolCalls { + toolCalls = append(toolCalls, interfaces.ToolCall{ + ID: tool.ID, + Type: tool.Type, + Name: &tool.Function.Name, + Arguments: json.RawMessage(tool.Function.Arguments), + }) + } + } + + choices = append(choices, interfaces.BifrostResponseChoice{ + Index: i, + Message: interfaces.BifrostResponseChoiceMessage{ + Role: c.Message.Role, + Content: c.Message.Content, + ToolCalls: &toolCalls, + }, + StopReason: c.FinishReason, + LogProbs: c.LogProbs, + }) + } + + result := &interfaces.BifrostResponse{ ID: response.ID, - Choices: response.Choices, - Usage: response.Usage, - Model: response.Model, + Choices: choices, + Usage: interfaces.LLMUsage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + Latency: &latency, + }, + Model: response.Model, + Provider: interfaces.OpenAI, + RawResponse: rawResponse, } // Handle the created field conversion @@ -150,9 +224,5 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int } } - // Add provider-specific information - result.Provider = interfaces.OpenAI - result.Usage.Latency = &latency - return result, nil } diff --git a/providers/utils.go b/providers/utils.go index 56c5908089..a3a964d9a5 100644 --- a/providers/utils.go +++ b/providers/utils.go @@ -12,6 +12,8 @@ import ( "reflect" "time" + "maps" + "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/config" @@ -47,7 +49,7 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} { typ := val.Type() // Iterate through all fields - for i := 0; i < val.NumField(); i++ { + for i := range val.NumField() { field := val.Field(i) fieldType := typ.Field(i) @@ -69,9 +71,7 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} { } // Handle ExtraParams - for k, v := range params.ExtraParams { - flatParams[k] = v - } + maps.Copy(flatParams, params.ExtraParams) return flatParams } diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go index f4790f8b71..4791d8ffc8 100644 --- a/tests/anthropic_test.go +++ b/tests/anthropic_test.go @@ -11,21 +11,16 @@ import ( // setupAnthropicRequests sends multiple test requests to Anthropic func setupAnthropicRequests(bifrost *bifrost.Bifrost) { - anthropicMessages := []string{ - "What's your favorite programming language?", - "Can you help me write a Go function?", - "What's the best way to learn programming?", - "Tell me about artificial intelligence.", - } - ctx := context.Background() + maxTokens := 4096 + + params := interfaces.ModelParameters{ + MaxTokens: &maxTokens, + } + + // Text completion request go func() { - params := interfaces.ModelParameters{ - ExtraParams: map[string]interface{}{ - "max_tokens_to_sample": 4096, - }, - } text := "Hello world!" result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{ @@ -42,10 +37,11 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { } }() - params := interfaces.ModelParameters{ - ExtraParams: map[string]interface{}{ - "max_tokens": 4096, - }, + // Regular chat completion requests + anthropicMessages := []string{ + "Hello! How are you today?", + "Tell me a joke!", + "What's your favorite programming language?", } for i, message := range anthropicMessages { @@ -73,6 +69,70 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) { } }(message, delay, i) } + + // Tool calls test + setupAnthropicToolCalls(bifrost, ctx) +} + +// setupAnthropicToolCalls tests Anthropic's function calling capability +func setupAnthropicToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + anthropicMessages := []string{ + "What's the weather like in Mumbai?", + } + + maxTokens := 4096 + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + MaxTokens: &maxTokens, + } + + for i, message := range anthropicMessages { + delay := time.Duration(500+100*i) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{ + Model: "claude-3-7-sonnet-20250219", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + + if err != nil { + fmt.Printf("Error in Anthropic tool call request %d: %v\n", index+1, err) + } else { + toolCall := *result.Choices[1].Message.ToolCalls + fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments) + } + }(message, delay, i) + } } func TestAnthropic(t *testing.T) { diff --git a/tests/bedrock_test.go b/tests/bedrock_test.go index 4fb3ba53c9..168c44693f 100644 --- a/tests/bedrock_test.go +++ b/tests/bedrock_test.go @@ -11,21 +11,16 @@ import ( // setupBedrockRequests sends multiple test requests to Bedrock func setupBedrockRequests(bifrost *bifrost.Bifrost) { - bedrockMessages := []string{ - "What's your favorite programming language?", - "Can you help me write a Go function?", - "What's the best way to learn programming?", - "Tell me about artificial intelligence.", - } - ctx := context.Background() + maxTokens := 4096 + + params := interfaces.ModelParameters{ + MaxTokens: &maxTokens, + } + + // Text completion request go func() { - params := interfaces.ModelParameters{ - ExtraParams: map[string]interface{}{ - "max_tokens_to_sample": 4096, - }, - } text := "\n\nHuman:\n\nAssistant:" result, err := bifrost.TextCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{ @@ -42,10 +37,11 @@ func setupBedrockRequests(bifrost *bifrost.Bifrost) { } }() - params := interfaces.ModelParameters{ - ExtraParams: map[string]interface{}{ - "max_tokens": 4096, - }, + // Regular chat completion requests + bedrockMessages := []string{ + "Hello! How are you today?", + "Tell me a joke!", + "What's your favorite programming language?", } for i, message := range bedrockMessages { @@ -73,6 +69,75 @@ func setupBedrockRequests(bifrost *bifrost.Bifrost) { } }(message, delay, i) } + + // Tool calls test + setupBedrockToolCalls(bifrost, ctx) +} + +// setupBedrockToolCalls tests Bedrock's function calling capability +func setupBedrockToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + bedrockMessages := []string{ + "What's the weather like in Mumbai?", + } + + maxTokens := 4096 + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + MaxTokens: &maxTokens, + } + + for i, message := range bedrockMessages { + delay := time.Duration(500+100*i) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{ + Model: "anthropic.claude-3-sonnet-20240229-v1:0", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + + if err != nil { + fmt.Printf("Error in Bedrock tool call request %d: %v\n", index+1, err) + } else { + if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 { + toolCall := *result.Choices[0].Message.ToolCalls + fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments) + } else { + fmt.Printf("🤖 No tool calls in response %d\n", index+1) + fmt.Println("Raw JSON Response", result.RawResponse) + } + } + }(message, delay, i) + } } func TestBedrock(t *testing.T) { diff --git a/tests/cohere_test.go b/tests/cohere_test.go index 0fcb6ec7b3..e9346bc60d 100644 --- a/tests/cohere_test.go +++ b/tests/cohere_test.go @@ -12,7 +12,6 @@ import ( // setupCohereRequests sends multiple test requests to Cohere func setupCohereRequests(bifrost *bifrost.Bifrost) { text := "Hello world!" - ctx := context.Background() // Text completion request @@ -31,15 +30,14 @@ func setupCohereRequests(bifrost *bifrost.Bifrost) { } }() - // Chat completion requests with different messages and delays - CohereMessages := []string{ + // Regular chat completion requests + cohereMessages := []string{ "Hello! How are you today?", - "What's the weather like?", "Tell me a joke!", "What's your favorite programming language?", } - for i, message := range CohereMessages { + for i, message := range cohereMessages { delay := time.Duration(100*(i+1)) * time.Millisecond go func(msg string, delay time.Duration, index int) { time.Sleep(delay) @@ -63,6 +61,66 @@ func setupCohereRequests(bifrost *bifrost.Bifrost) { } }(message, delay, i) } + + // Tool calls test + setupCohereToolCalls(bifrost, ctx) +} + +// setupCohereToolCalls tests Cohere's function calling capability +func setupCohereToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + cohereMessages := []string{ + "What's the weather like in Mumbai?", + } + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + } + + for i, message := range cohereMessages { + delay := time.Duration(100*(i+1)) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{ + Model: "command-a-03-2025", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + if err != nil { + fmt.Printf("Error in Cohere tool call request %d: %v\n", index+1, err) + } else { + toolCall := *result.Choices[0].Message.ToolCalls + fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments) + } + }(message, delay, i) + } } func TestCohere(t *testing.T) { diff --git a/tests/openai_test.go b/tests/openai_test.go index 246f1a4114..0c7a214170 100644 --- a/tests/openai_test.go +++ b/tests/openai_test.go @@ -12,7 +12,6 @@ import ( // setupOpenAIRequests sends multiple test requests to OpenAI func setupOpenAIRequests(bifrost *bifrost.Bifrost) { text := "Hello world!" - ctx := context.Background() // Text completion request @@ -31,10 +30,9 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { } }() - // Chat completion requests with different messages and delays + // Regular chat completion requests openAIMessages := []string{ "Hello! How are you today?", - "What's the weather like?", "Tell me a joke!", "What's your favorite programming language?", } @@ -63,6 +61,66 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) { } }(message, delay, i) } + + // Tool calls test + setupOpenAIToolCalls(bifrost, ctx) +} + +// setupOpenAIToolCalls tests OpenAI's function calling capability +func setupOpenAIToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + openAIMessages := []string{ + "What's the weather like in Mumbai?", + } + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + } + + for i, message := range openAIMessages { + delay := time.Duration(100*(i+1)) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{ + Model: "gpt-4o-mini", + Input: interfaces.RequestInput{ + ChatInput: &messages, + }, + Params: ¶ms, + }, ctx) + if err != nil { + fmt.Printf("Error in OpenAI tool call request %d: %v\n", index+1, err) + } else { + toolCall := *result.Choices[0].Message.ToolCalls + fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments) + } + }(message, delay, i) + } } func TestOpenAI(t *testing.T) { diff --git a/tests/plugin.go b/tests/plugin.go index dace4ff9e9..2cc435e3d7 100644 --- a/tests/plugin.go +++ b/tests/plugin.go @@ -37,7 +37,7 @@ func (plugin *Plugin) PreHook(ctx context.Context, req *interfaces.BifrostReques return ctx, req, nil } -func (plugin *Plugin) PostHook(ctx context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) { +func (plugin *Plugin) PostHook(ctx context.Context, res *interfaces.BifrostResponse) (*interfaces.BifrostResponse, error) { // Get traceID from context traceID, ok := ctx.Value(traceIDKey).(string) if !ok { diff --git a/tests/setup.go b/tests/setup.go index fc1923e15c..f8f5bddf5b 100644 --- a/tests/setup.go +++ b/tests/setup.go @@ -60,10 +60,8 @@ func getBifrost() (*bifrost.Bifrost, error) { DefaultRequestTimeoutInSeconds: 30, }, MetaConfig: &interfaces.MetaConfig{ - BedrockMetaConfig: &interfaces.BedrockMetaConfig{ - SecretAccessKey: "AMpq95pNadM2fD1GlcNvjbMiGhizwYaGKJxv+nti", - Region: maxim.StrPtr("us-east-1"), - }, + SecretAccessKey: "AMpq95pNadM2fD1GlcNvjbMiGhizwYaGKJxv+nti", + Region: maxim.StrPtr("us-east-1"), }, }, interfaces.Cohere: {