diff --git a/bifrost.go b/bifrost.go
index ae7454a31d..78c03c4920 100644
--- a/bifrost.go
+++ b/bifrost.go
@@ -22,7 +22,7 @@ const (
type ChannelMessage struct {
interfaces.BifrostRequest
- Response chan *interfaces.CompletionResult
+ Response chan *interfaces.BifrostResponse
Err chan error
Type RequestType
}
@@ -179,7 +179,7 @@ func (bifrost *Bifrost) processRequests(provider interfaces.Provider, queue chan
defer bifrost.wg[provider.GetProviderKey()].Done()
for req := range queue {
- var result *interfaces.CompletionResult
+ var result *interfaces.BifrostResponse
var err error
key, err := bifrost.SelectKeyFromProviderForModel(provider, req.Model)
@@ -234,13 +234,13 @@ func (bifrost *Bifrost) GetProviderQueue(providerKey interfaces.SupportedModelPr
return queue, nil
}
-func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) {
+func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) {
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
}
- responseChan := make(chan *interfaces.CompletionResult)
+ responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan error)
for _, plugin := range bifrost.plugins {
@@ -273,13 +273,13 @@ func (bifrost *Bifrost) TextCompletionRequest(providerKey interfaces.SupportedMo
}
}
-func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.CompletionResult, error) {
+func (bifrost *Bifrost) ChatCompletionRequest(providerKey interfaces.SupportedModelProvider, req *interfaces.BifrostRequest, ctx context.Context) (*interfaces.BifrostResponse, error) {
queue, err := bifrost.GetProviderQueue(providerKey)
if err != nil {
return nil, err
}
- responseChan := make(chan *interfaces.CompletionResult)
+ responseChan := make(chan *interfaces.BifrostResponse)
errorChan := make(chan error)
for _, plugin := range bifrost.plugins {
diff --git a/interfaces/account.go b/interfaces/account.go
index 47ef6b0c14..252c0358e9 100644
--- a/interfaces/account.go
+++ b/interfaces/account.go
@@ -11,6 +11,7 @@ type Key struct {
Weight float64 `json:"weight"`
}
+// TODO: consolidate these per-provider getters into a single GetConfig method on Account
type Account interface {
GetInitiallyConfiguredProviderKeys() ([]SupportedModelProvider, error)
GetKeysForProvider(provider Provider) ([]Key, error)
diff --git a/interfaces/plugin.go b/interfaces/plugin.go
index 0de9d9144a..1d842408e8 100644
--- a/interfaces/plugin.go
+++ b/interfaces/plugin.go
@@ -15,5 +15,5 @@ type BifrostRequest struct {
type Plugin interface {
PreHook(ctx context.Context, req *BifrostRequest) (context.Context, *BifrostRequest, error)
- PostHook(ctx context.Context, result *CompletionResult) (*CompletionResult, error)
+ PostHook(ctx context.Context, result *BifrostResponse) (*BifrostResponse, error)
}
diff --git a/interfaces/provider.go b/interfaces/provider.go
index 4d08b32505..e5512e7d02 100644
--- a/interfaces/provider.go
+++ b/interfaces/provider.go
@@ -2,83 +2,6 @@ package interfaces
import "encoding/json"
-// LLMUsage represents token usage information
-type LLMUsage struct {
- PromptTokens int `json:"prompt_tokens"`
- CompletionTokens int `json:"completion_tokens"`
- TotalTokens int `json:"total_tokens"`
- Latency *float64 `json:"latency"`
-}
-
-type BilledLLMUsage struct {
- PromptTokens *float64 `json:"prompt_tokens"`
- CompletionTokens *float64 `json:"completion_tokens"`
- SearchUnits *float64 `json:"search_units"`
- Classifications *float64 `json:"classifications"`
-}
-
-// LLMInteractionCost represents cost information for LLM interactions
-type LLMInteractionCost struct {
- Input float64 `json:"input"`
- Output float64 `json:"output"`
- Total float64 `json:"total"`
-}
-
-// Function represents a function definition for tool calls
-type Function struct {
- Name string `json:"name"`
- Description string `json:"description"`
- Parameters interface{} `json:"parameters"`
-}
-
-// Tool represents a tool that can be used with the model
-type Tool struct {
- Type string `json:"type"`
- Function Function `json:"function"`
-}
-
-// ModelParameters represents the parameters for model requests
-type ModelParameters struct {
- TestRunEntryID *string `json:"testRunEntryId"`
- PromptTools *[]string `json:"promptTools"`
- ToolChoice *string `json:"toolChoice"`
- Tools *[]Tool `json:"tools"`
- FunctionCall *string `json:"functionCall"`
- Functions *[]Function `json:"functions"`
- // Dynamic parameters
- ExtraParams map[string]interface{} `json:"-"`
-}
-
-// RequestOptions represents options for model requests
-type RequestOptions struct {
- UseCache *bool `json:"useCache"`
- WaitForModel *bool `json:"waitForModel"`
- CompletionType *string `json:"CompletionType"`
-}
-
-// FunctionCall represents a function call in a tool call
-type FunctionCall struct {
- Name string `json:"name"`
- Arguments string `json:"arguments"`
-}
-
-// ToolCall represents a tool call in a message
-type ToolCall struct {
- Type *string `json:"type"`
- ID string `json:"id"`
- Name *string `json:"name"`
- Input json.RawMessage `json:"input"`
- Function *FunctionCall `json:"function"`
-}
-
-type Citation struct {
- Start *int `json:"start"`
- End *int `json:"end"`
- Text *string `json:"text"`
- Sources *interface{} `json:"sources"`
- Type *string `json:"type"`
-}
-
// ModelChatMessageRole represents the role of a chat message
type ModelChatMessageRole string
@@ -86,74 +9,10 @@ const (
RoleAssistant ModelChatMessageRole = "assistant"
RoleUser ModelChatMessageRole = "user"
RoleSystem ModelChatMessageRole = "system"
- RoleModel ModelChatMessageRole = "model"
+ RoleChatbot ModelChatMessageRole = "chatbot"
RoleTool ModelChatMessageRole = "tool"
)
-// CompletionResponseChoice represents a choice in the completion response
-type CompletionResponseChoice struct {
- Role ModelChatMessageRole `json:"role"`
- Content string `json:"content"`
- Image json.RawMessage `json:"image"`
- ToolCalls *[]ToolCall `json:"tool_calls"`
- Citations *[]Citation `json:"citation"`
-}
-
-// CompletionResultChoice represents a choice in the completion result
-type CompletionResultChoice struct {
- Index int `json:"index"`
- Message CompletionResponseChoice `json:"message"`
- StopReason *string `json:"stop_reason"`
- Stop *string `json:"stop"`
- LogProbs *interface{} `json:"logprobs"`
-}
-
-// ToolResult represents the result of a tool call
-type ToolResult struct {
- Role ModelChatMessageRole `json:"role"`
- Content string `json:"content"`
- ToolCallID string `json:"tool_call_id"`
-}
-
-// ToolCallResult represents a single tool call result
-type ToolCallResult struct {
- Name string `json:"name"`
- Result interface{} `json:"result"`
- Type string `json:"type"`
- ID string `json:"id"`
-}
-
-// ToolCallResults represents a collection of tool call results
-type ToolCallResults struct {
- Version int `json:"version"`
- Results []ToolCallResult `json:"results"`
-}
-
-// CompletionResult represents the complete result from a model completion
-type CompletionResult struct {
- Error *struct {
- Code string `json:"code"`
- Message string `json:"message"`
- Type string `json:"type"`
- } `json:"error"`
- ID string `json:"id"`
- Choices []CompletionResultChoice `json:"choices"`
- ChatHistory *[]CompletionResponseChoice `json:"chat_history"`
- ToolCallResult *interface{} `json:"tool_call_result"`
- ToolCallResults *ToolCallResults `json:"toolCallResults"`
- Provider SupportedModelProvider `json:"provider"`
- Usage LLMUsage `json:"usage"`
- BilledUsage *BilledLLMUsage `json:"billed_usage"`
- Cost *LLMInteractionCost `json:"cost"`
- Model string `json:"model"`
- Created string `json:"created"`
- Params *interface{} `json:"modelParams"`
- Trace *struct {
- Input interface{} `json:"input"`
- Output interface{} `json:"output"`
- } `json:"trace"`
-}
-
type SupportedModelProvider string
const (
@@ -170,13 +29,55 @@ const (
Lmstudio SupportedModelProvider = "lmstudio"
)
+//* Request Structs
+
+// ModelParameters represents the parameters for model requests
+type ModelParameters struct {
+ TestRunEntryID *string `json:"test_run_entry_id"`
+ PromptTools *[]string `json:"prompt_tools"`
+ ToolChoice *string `json:"tool_choice"`
+ Tools *[]Tool `json:"tools"`
+
+ // Common model parameters
+ Temperature *float64 `json:"temperature"`
+ TopP *float64 `json:"top_p"`
+ TopK *int `json:"top_k"`
+ MaxTokens *int `json:"max_tokens"`
+ StopSequences *[]string `json:"stop_sequences"`
+ PresencePenalty *float64 `json:"presence_penalty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty"`
+
+ // Dynamic parameters
+ ExtraParams map[string]interface{} `json:"-"`
+}
+
+type FunctionParameters struct {
+ Type string `json:"type"`
+ Required []string `json:"required"`
+ Properties map[string]interface{} `json:"properties"`
+}
+
+// Function represents a function definition for tool calls
+type Function struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters FunctionParameters `json:"parameters"`
+}
+
+// Tool represents a tool that can be used with the model
+type Tool struct {
+ ID *string `json:"id"`
+ Type string `json:"type"`
+ Function Function `json:"function"`
+}
+
type Message struct {
//* strict check for roles
Role ModelChatMessageRole `json:"role"`
//* need to make sure either content or imagecontent is provided
Content *string `json:"content"`
- ImageContent *ImageContent `json:"imageContent"`
- ToolCalls *[]ToolCall `json:"toolCall"`
+ ImageContent *ImageContent `json:"image_content"`
+ ToolCalls *[]Tool `json:"tool_calls"`
}
type ImageContent struct {
@@ -185,62 +86,100 @@ type ImageContent struct {
MediaType string `json:"media_type"`
}
-// type Content struct {
-// Content *string `json:"content"`
-// ImageContent *ImageContent `json:"imageContent"`
-// }
-
-// func (content *Content) MarshalJSON() ([]byte, error) {
-// if content.Content != nil {
-// return []byte(*content.Content), nil
-// } else if content.ImageContent != nil {
-// return json.Marshal(content.ImageContent)
-// }
-
-// return nil, fmt.Errorf("invalid content")
-// }
+type NetworkConfig struct {
+ DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"`
+}
-// func (content *Content) UnmarshalJSON(val []byte) error {
-// var s any
-// json.Unmarshal(val, &s)
+type MetaConfig struct {
+ SecretAccessKey string `json:"secret_access_key"`
+ Region *string `json:"region"`
+ SessionToken *string `json:"session_token"`
+ ARN *string `json:"arn"`
+ InferenceProfiles map[string]string `json:"inference_profiles"`
+}
-// switch s := s.(type) {
-// case string:
-// content.Content = &s
-// case ImageContent:
-// content.ImageContent = &s
+type ProviderConfig struct {
+ NetworkConfig NetworkConfig `json:"network_config"`
+ MetaConfig *MetaConfig `json:"meta_config"`
+}
-// default:
-// return fmt.Errorf("invalid stop")
-// }
+//* Response Structs
-// return nil
-// }
+// LLMUsage represents token usage information
+type LLMUsage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ Latency *float64 `json:"latency"`
+}
-type NetworkConfig struct {
- DefaultRequestTimeoutInSeconds int `json:"defaultRequestTimeoutInSeconds"`
+type BilledLLMUsage struct {
+ PromptTokens *float64 `json:"prompt_tokens"`
+ CompletionTokens *float64 `json:"completion_tokens"`
+ SearchUnits *float64 `json:"search_units"`
+ Classifications *float64 `json:"classifications"`
}
-type MetaConfig struct {
- BedrockMetaConfig *BedrockMetaConfig `json:"bedrockMetaConfig"`
+// LLMInteractionCost represents cost information for LLM interactions
+type LLMInteractionCost struct {
+ Input float64 `json:"input"`
+ Output float64 `json:"output"`
+ Total float64 `json:"total"`
}
-type ProviderConfig struct {
- NetworkConfig NetworkConfig `json:"networkConfig"`
- MetaConfig *MetaConfig `json:"metaConfig"`
+// ToolCall represents a tool call in a message
+type ToolCall struct {
+ Type *string `json:"type"`
+ ID *string `json:"id"`
+ Name *string `json:"name"`
+ Arguments json.RawMessage `json:"arguments"`
}
-type BedrockMetaConfig struct {
- SecretAccessKey string `json:"secretAccessKey"`
- Region *string `json:"region"`
- SessionToken *string `json:"sessionToken"`
- ARN *string `json:"arn"`
- InferenceProfiles map[string]string `json:"inferenceProfiles"`
+type Citation struct {
+ Start *int `json:"start"`
+ End *int `json:"end"`
+ Text *string `json:"text"`
+ Sources *interface{} `json:"sources"`
+ Type *string `json:"type"`
}
+// BifrostResponseChoiceMessage represents the message payload of a single response choice
+type BifrostResponseChoiceMessage struct {
+ Role ModelChatMessageRole `json:"role"`
+ Content string `json:"content"`
+ Image json.RawMessage `json:"image"`
+ ToolCalls *[]ToolCall `json:"tool_calls"`
+ Citations *[]Citation `json:"citations"`
+}
+
+// BifrostResponseChoice represents a single choice within a BifrostResponse
+type BifrostResponseChoice struct {
+ Index int `json:"index"`
+ Message BifrostResponseChoiceMessage `json:"message"`
+ StopReason *string `json:"stop_reason"`
+ Stop *string `json:"stop"`
+ LogProbs *interface{} `json:"log_probs"`
+}
+
+// BifrostResponse represents the complete result from a model completion
+type BifrostResponse struct {
+ ID string `json:"id"`
+ Choices []BifrostResponseChoice `json:"choices"`
+ ChatHistory *[]BifrostResponseChoiceMessage `json:"chat_history"`
+ Provider SupportedModelProvider `json:"provider"`
+ Usage LLMUsage `json:"usage"`
+ BilledUsage *BilledLLMUsage `json:"billed_usage"`
+ Cost *LLMInteractionCost `json:"cost"`
+ Model string `json:"model"`
+ Created string `json:"created"`
+ Params *interface{} `json:"model_params"`
+ RawResponse interface{} `json:"raw_response"`
+}
+
+// TODO: add support for third-party (custom/self-hosted) providers
// Provider defines the interface for AI model providers
type Provider interface {
GetProviderKey() SupportedModelProvider
- TextCompletion(model, key, text string, params *ModelParameters) (*CompletionResult, error)
- ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*CompletionResult, error)
+ TextCompletion(model, key, text string, params *ModelParameters) (*BifrostResponse, error)
+ ChatCompletion(model, key string, messages []Message, params *ModelParameters) (*BifrostResponse, error)
}
diff --git a/providers/anthropic.go b/providers/anthropic.go
index 0fe2394e11..d3fdf66c6a 100644
--- a/providers/anthropic.go
+++ b/providers/anthropic.go
@@ -28,14 +28,12 @@ type AnthropicChatResponse struct {
Type string `json:"type"`
Role string `json:"role"`
Content []struct {
- Type string `json:"type"`
- Text string `json:"text,omitempty"`
- Thinking string `json:"thinking,omitempty"`
- ToolUse *struct {
- ID string `json:"id"`
- Name string `json:"name"`
- Input map[string]interface{} `json:"input"`
- } `json:"tool_use,omitempty"`
+ Type string `json:"type"`
+ Text string `json:"text,omitempty"`
+ Thinking string `json:"thinking,omitempty"`
+ ID string `json:"id"`
+ Name string `json:"name"`
+ Input map[string]interface{} `json:"input"`
} `json:"content"`
Model string `json:"model"`
StopReason string `json:"stop_reason,omitempty"`
@@ -62,11 +60,25 @@ func (provider *AnthropicProvider) GetProviderKey() interfaces.SupportedModelPro
return interfaces.Anthropic
}
+func (provider *AnthropicProvider) PrepareTextCompletionParams(params map[string]interface{}) map[string]interface{} {
+ // Check if there is a key entry for max_tokens
+ if maxTokens, exists := params["max_tokens"]; exists {
+ // Check if max_tokens_to_sample is already present
+ if _, exists := params["max_tokens_to_sample"]; !exists {
+ // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample
+ params["max_tokens_to_sample"] = maxTokens
+ }
+ delete(params, "max_tokens")
+ }
+ return params
+
+}
+
// TextCompletion implements text completion using Anthropic's API
-func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+func (provider *AnthropicProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
startTime := time.Now()
- preparedParams := PrepareParams(params)
+ preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params))
// Merge additional parameters
requestBody := MergeConfig(map[string]interface{}{
@@ -98,6 +110,9 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
}
defer resp.Body.Close()
+ // Calculate latency
+ latency := time.Since(startTime).Seconds()
+
// Read the response body
body, err := io.ReadAll(resp.Body)
if err != nil {
@@ -106,36 +121,28 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
// Check for error response
if resp.StatusCode != http.StatusOK {
- var errorResp struct {
- Type string `json:"type"`
- Error struct {
- Type string `json:"type"`
- Message string `json:"message"`
- } `json:"error"`
- }
- if err := json.Unmarshal(body, &errorResp); err != nil {
- return nil, fmt.Errorf("error response: %s", string(body))
- }
- return nil, fmt.Errorf("anthropic error: %s", errorResp.Error.Message)
+ return nil, fmt.Errorf("anthropic error: %s", string(body))
}
// Parse the response
var response AnthropicTextResponse
-
if err := json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("error parsing response: %v", err)
}
- // Calculate latency
- latency := time.Since(startTime).Seconds()
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("error parsing raw response: %v", err)
+ }
// Create the completion result
- completionResult := &interfaces.CompletionResult{
+ completionResult := &interfaces.BifrostResponse{
ID: response.ID,
- Choices: []interfaces.CompletionResultChoice{
+ Choices: []interfaces.BifrostResponseChoice{
{
Index: 0,
- Message: interfaces.CompletionResponseChoice{
+ Message: interfaces.BifrostResponseChoiceMessage{
Role: interfaces.RoleAssistant,
Content: response.Completion,
},
@@ -147,17 +154,16 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
Latency: &latency,
},
- Model: response.Model,
- Provider: interfaces.Anthropic,
+ Model: response.Model,
+ Provider: interfaces.Anthropic,
+ RawResponse: rawResponse,
}
return completionResult, nil
}
// ChatCompletion implements chat completion using Anthropic's API
-func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
- startTime := time.Now()
-
+func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
// Format messages for Anthropic API
var formattedMessages []map[string]interface{}
for _, msg := range messages {
@@ -169,6 +175,20 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
preparedParams := PrepareParams(params)
+ // Transform tools if present
+ if params != nil && params.Tools != nil && len(*params.Tools) > 0 {
+ var tools []map[string]interface{}
+ for _, tool := range *params.Tools {
+ tools = append(tools, map[string]interface{}{
+ "name": tool.Function.Name,
+ "description": tool.Function.Description,
+ "input_schema": tool.Function.Parameters,
+ })
+ }
+
+ preparedParams["tools"] = tools
+ }
+
// Merge additional parameters
requestBody := MergeConfig(map[string]interface{}{
"model": model,
@@ -209,67 +229,64 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []
return nil, fmt.Errorf("API error: %s", string(body))
}
- // Calculate latency
- latency := time.Since(startTime).Seconds()
-
// Decode response
var anthropicResponse AnthropicChatResponse
-
if err := json.Unmarshal(body, &anthropicResponse); err != nil {
return nil, fmt.Errorf("error decoding response: %v", err)
}
- // Process the response into our CompletionResult format
- var content string
- var toolCalls []interfaces.ToolCall
- var stopReason string
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("error parsing raw response: %v", err)
+ }
+
+ // Process the response into our BifrostResponse format
+ var choices []interfaces.BifrostResponseChoice
// Process content and tool calls
- for _, c := range anthropicResponse.Content {
+ for i, c := range anthropicResponse.Content {
+ var content string
+ var toolCalls []interfaces.ToolCall
+
switch c.Type {
case "thinking":
- if content == "" {
- content = fmt.Sprintf("\n%s\n\n\n", c.Thinking)
- }
+ content = c.Thinking
case "text":
- content += c.Text
+ content = c.Text
case "tool_use":
- if c.ToolUse != nil {
- toolCalls = append(toolCalls, interfaces.ToolCall{
- Type: maxim.StrPtr("function"),
- ID: c.ToolUse.ID,
- Function: &interfaces.FunctionCall{
- Name: c.ToolUse.Name,
- Arguments: string(must(json.Marshal(c.ToolUse.Input))),
- },
- })
- stopReason = "tool_calls"
- }
+ toolCalls = append(toolCalls, interfaces.ToolCall{
+ Type: maxim.StrPtr("function"),
+ ID: &c.ID,
+ Name: &c.Name,
+ Arguments: json.RawMessage(must(json.Marshal(c.Input))),
+ })
}
+
+ choices = append(choices, interfaces.BifrostResponseChoice{
+ Index: i,
+ Message: interfaces.BifrostResponseChoiceMessage{
+ Role: interfaces.RoleAssistant,
+ Content: content,
+ ToolCalls: &toolCalls,
+ },
+ StopReason: &anthropicResponse.StopReason,
+ Stop: anthropicResponse.StopSequence,
+ })
}
// Create the completion result
- result := &interfaces.CompletionResult{
- ID: anthropicResponse.ID,
- Choices: []interfaces.CompletionResultChoice{
- {
- Index: 0,
- Message: interfaces.CompletionResponseChoice{
- Role: interfaces.RoleAssistant,
- Content: content,
- ToolCalls: &toolCalls,
- },
- StopReason: &stopReason,
- },
- },
+ result := &interfaces.BifrostResponse{
+ ID: anthropicResponse.ID,
+ Choices: choices,
Usage: interfaces.LLMUsage{
PromptTokens: anthropicResponse.Usage.InputTokens,
CompletionTokens: anthropicResponse.Usage.OutputTokens,
TotalTokens: anthropicResponse.Usage.InputTokens + anthropicResponse.Usage.OutputTokens,
- Latency: &latency,
},
- Model: anthropicResponse.Model,
- Provider: interfaces.Anthropic,
+ Model: anthropicResponse.Model,
+ Provider: interfaces.Anthropic,
+ RawResponse: rawResponse,
}
return result, nil
diff --git a/providers/bedrock.go b/providers/bedrock.go
index 1f5309df8d..2a839215f7 100644
--- a/providers/bedrock.go
+++ b/providers/bedrock.go
@@ -78,34 +78,46 @@ type BedrockAnthropicImageSource struct {
}
type BedrockMistralToolCall struct {
- ID string `json:"id"`
- Function interfaces.FunctionCall `json:"function"`
+ ID string `json:"id"`
+ Function interfaces.Function `json:"function"`
+}
+
+type BedrockAnthropicToolCall struct {
+ ToolSpec BedrockAnthropicToolSpec `json:"toolSpec"`
+}
+
+type BedrockAnthropicToolSpec struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ InputSchema struct {
+ Json interface{} `json:"json"`
+ } `json:"inputSchema"`
}
type BedrockProvider struct {
client *http.Client
- meta *interfaces.BedrockMetaConfig
+ meta *interfaces.MetaConfig
}
func NewBedrockProvider(config *interfaces.ProviderConfig) *BedrockProvider {
return &BedrockProvider{
client: &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)},
- meta: config.MetaConfig.BedrockMetaConfig,
+ meta: config.MetaConfig,
}
}
-func (p *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvider {
+func (provider *BedrockProvider) GetProviderKey() interfaces.SupportedModelProvider {
return interfaces.Bedrock
}
-func (p *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) {
- if p.meta == nil {
+func (provider *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey string) (*http.Request, error) {
+ if provider.meta == nil {
return nil, errors.New("meta config for bedrock is not provided")
}
region := "us-east-1"
- if p.meta.Region != nil {
- region = *p.meta.Region
+ if provider.meta.Region != nil {
+ region = *provider.meta.Region
}
// Create the request with the JSON body
@@ -114,14 +126,14 @@ func (p *BedrockProvider) PrepareReq(path string, jsonData []byte, accessKey str
return nil, fmt.Errorf("error creating request: %v", err)
}
- if err := SignAWSRequest(req, accessKey, p.meta.SecretAccessKey, p.meta.SessionToken, region, "bedrock"); err != nil {
+ if err := SignAWSRequest(req, accessKey, provider.meta.SecretAccessKey, provider.meta.SessionToken, region, "bedrock"); err != nil {
return nil, err
}
return req, nil
}
-func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.CompletionResult, error) {
+func (provider *BedrockProvider) GetTextCompletionResult(result []byte, model string) (*interfaces.BifrostResponse, error) {
switch model {
case "anthropic.claude-instant-v1:2":
fallthrough
@@ -133,11 +145,11 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) (
return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
}
- return &interfaces.CompletionResult{
- Choices: []interfaces.CompletionResultChoice{
+ return &interfaces.BifrostResponse{
+ Choices: []interfaces.BifrostResponseChoice{
{
Index: 0,
- Message: interfaces.CompletionResponseChoice{
+ Message: interfaces.BifrostResponseChoiceMessage{
Role: interfaces.RoleAssistant,
Content: response.Completion,
},
@@ -161,11 +173,11 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) (
return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
}
- var choices []interfaces.CompletionResultChoice
+ var choices []interfaces.BifrostResponseChoice
for i, output := range response.Outputs {
- choices = append(choices, interfaces.CompletionResultChoice{
+ choices = append(choices, interfaces.BifrostResponseChoice{
Index: i,
- Message: interfaces.CompletionResponseChoice{
+ Message: interfaces.BifrostResponseChoiceMessage{
Role: interfaces.RoleAssistant,
Content: output.Text,
},
@@ -173,7 +185,7 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) (
})
}
- return &interfaces.CompletionResult{
+ return &interfaces.BifrostResponse{
Choices: choices,
}, nil
}
@@ -181,7 +193,7 @@ func (p *BedrockProvider) GetTextCompletionResult(result []byte, model string) (
return nil, fmt.Errorf("invalid model choice: %s", model)
}
-func (p *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) {
+func (provider *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Message, model string) (map[string]interface{}, error) {
switch model {
case "anthropic.claude-instant-v1:2":
fallthrough
@@ -263,8 +275,8 @@ func (p *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Me
if msg.ToolCalls != nil {
for _, toolCall := range *msg.ToolCalls {
filteredToolCalls = append(filteredToolCalls, BedrockMistralToolCall{
- ID: toolCall.ID,
- Function: *toolCall.Function,
+ ID: *toolCall.ID,
+ Function: toolCall.Function,
})
}
}
@@ -293,10 +305,67 @@ func (p *BedrockProvider) PrepareChatCompletionMessages(messages []interfaces.Me
return nil, fmt.Errorf("invalid model choice: %s", model)
}
-func (p *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
- startTime := time.Now()
+func (provider *BedrockProvider) GetChatCompletionTools(params *interfaces.ModelParameters, model string) []BedrockAnthropicToolCall {
+ var tools []BedrockAnthropicToolCall
- preparedParams := PrepareParams(params)
+ switch model {
+ case "anthropic.claude-instant-v1:2":
+ fallthrough
+ case "anthropic.claude-v2":
+ fallthrough
+ case "anthropic.claude-v2:1":
+ fallthrough
+ case "anthropic.claude-3-sonnet-20240229-v1:0":
+ fallthrough
+ case "anthropic.claude-3-5-sonnet-20240620-v1:0":
+ fallthrough
+ case "anthropic.claude-3-5-sonnet-20241022-v2:0":
+ fallthrough
+ case "anthropic.claude-3-5-haiku-20241022-v1:0":
+ fallthrough
+ case "anthropic.claude-3-opus-20240229-v1:0":
+ fallthrough
+ case "anthropic.claude-3-7-sonnet-20250219-v1:0":
+ for _, tool := range *params.Tools {
+ tools = append(tools, BedrockAnthropicToolCall{
+ ToolSpec: BedrockAnthropicToolSpec{
+ Name: tool.Function.Name,
+ Description: tool.Function.Description,
+ InputSchema: struct {
+ Json interface{} `json:"json"`
+ }{
+ Json: tool.Function.Parameters,
+ },
+ },
+ })
+ }
+ }
+
+ return tools
+}
+
+func (provider *BedrockProvider) PrepareTextCompletionParams(params map[string]interface{}, model string) map[string]interface{} {
+ switch model {
+ case "anthropic.claude-instant-v1:2":
+ fallthrough
+ case "anthropic.claude-v2":
+ fallthrough
+ case "anthropic.claude-v2:1":
+ // Check if there is a key entry for max_tokens
+ if maxTokens, exists := params["max_tokens"]; exists {
+ // Check if max_tokens_to_sample is already present
+ if _, exists := params["max_tokens_to_sample"]; !exists {
+ // If max_tokens_to_sample is not present, rename max_tokens to max_tokens_to_sample
+ params["max_tokens_to_sample"] = maxTokens
+ }
+ delete(params, "max_tokens")
+ }
+ }
+ return params
+}
+
+func (provider *BedrockProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+ preparedParams := provider.PrepareTextCompletionParams(PrepareParams(params), model)
requestBody := MergeConfig(map[string]interface{}{
"prompt": text,
@@ -309,13 +378,13 @@ func (p *BedrockProvider) TextCompletion(model, key, text string, params *interf
}
// Create the signed request with correct operation name
- req, err := p.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key)
+ req, err := provider.PrepareReq(fmt.Sprintf("%s/invoke", model), jsonData, key)
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
// Execute the request
- resp, err := p.client.Do(req)
+ resp, err := provider.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %v", err)
}
@@ -331,24 +400,35 @@ func (p *BedrockProvider) TextCompletion(model, key, text string, params *interf
return nil, fmt.Errorf("bedrock API error: %s", string(body))
}
- result, err := p.GetTextCompletionResult(body, model)
+ result, err := provider.GetTextCompletionResult(body, model)
if err != nil {
return nil, fmt.Errorf("failed to parse response body: %v", err)
}
- // Calculate latency
- latency := time.Since(startTime).Seconds()
- result.Usage.Latency = &latency
+
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse raw response: %v", err)
+ }
+
+ result.RawResponse = rawResponse
return result, nil
}
-func (p *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
- messageBody, err := p.PrepareChatCompletionMessages(messages, model)
+func (provider *BedrockProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
+ messageBody, err := provider.PrepareChatCompletionMessages(messages, model)
if err != nil {
return nil, fmt.Errorf("error preparing messages: %v", err)
}
preparedParams := PrepareParams(params)
+
+ // Transform tools if present
+ if params != nil && params.Tools != nil && len(*params.Tools) > 0 {
+ preparedParams["tools"] = provider.GetChatCompletionTools(params, model)
+ }
+
requestBody := MergeConfig(messageBody, preparedParams)
// Marshal the request body
@@ -360,23 +440,23 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface
// Format the path with proper model identifier
path := fmt.Sprintf("%s/converse", model)
- if p.meta != nil && p.meta.InferenceProfiles != nil {
- if inferenceProfileId, ok := p.meta.InferenceProfiles[model]; ok {
- if p.meta.ARN != nil {
- encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *p.meta.ARN, inferenceProfileId))
+ if provider.meta != nil && provider.meta.InferenceProfiles != nil {
+ if inferenceProfileId, ok := provider.meta.InferenceProfiles[model]; ok {
+ if provider.meta.ARN != nil {
+ encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.ARN, inferenceProfileId))
path = fmt.Sprintf("%s/converse", encodedModelIdentifier)
}
}
}
// Create the signed request
- req, err := p.PrepareReq(path, jsonData, key)
+ req, err := provider.PrepareReq(path, jsonData, key)
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
// Execute the request
- resp, err := p.client.Do(req)
+ resp, err := provider.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %v", err)
}
@@ -397,11 +477,17 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface
return nil, fmt.Errorf("failed to parse Bedrock response: %v", err)
}
- var choices []interfaces.CompletionResultChoice
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse raw response: %v", err)
+ }
+
+ var choices []interfaces.BifrostResponseChoice
for i, choice := range response.Output.Message.Content {
- choices = append(choices, interfaces.CompletionResultChoice{
+ choices = append(choices, interfaces.BifrostResponseChoice{
Index: i,
- Message: interfaces.CompletionResponseChoice{
+ Message: interfaces.BifrostResponseChoiceMessage{
Role: interfaces.RoleAssistant,
Content: choice.Text,
},
@@ -411,7 +497,7 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface
latency := float64(response.Metrics.Latency)
- result := &interfaces.CompletionResult{
+ result := &interfaces.BifrostResponse{
Choices: choices,
Usage: interfaces.LLMUsage{
PromptTokens: response.Usage.InputTokens,
@@ -419,8 +505,9 @@ func (p *BedrockProvider) ChatCompletion(model, key string, messages []interface
TotalTokens: response.Usage.TotalTokens,
Latency: &latency,
},
- Model: model,
- Provider: interfaces.Bedrock,
+ Model: model,
+ Provider: interfaces.Bedrock,
+ RawResponse: rawResponse,
}
return result, nil
diff --git a/providers/cohere.go b/providers/cohere.go
index 04062671f0..4030c7091a 100644
--- a/providers/cohere.go
+++ b/providers/cohere.go
@@ -5,6 +5,7 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "io"
"net/http"
"slices"
"time"
@@ -24,14 +25,20 @@ type CohereTool struct {
ParameterDefinitions map[string]CohereParameterDefinition `json:"parameter_definitions"`
}
+type CohereToolCall struct {
+ Name string `json:"name"`
+ Parameters interface{} `json:"parameters"`
+}
+
// CohereChatResponse represents the response from Cohere's chat API
type CohereChatResponse struct {
ResponseID string `json:"response_id"`
Text string `json:"text"`
GenerationID string `json:"generation_id"`
ChatHistory []struct {
- Role interfaces.ModelChatMessageRole `json:"role"`
- Message string `json:"message"`
+ Role interfaces.ModelChatMessageRole `json:"role"`
+ Message string `json:"message"`
+ ToolCalls []CohereToolCall `json:"tool_calls"`
} `json:"chat_history"`
FinishReason string `json:"finish_reason"`
Meta struct {
@@ -47,6 +54,7 @@ type CohereChatResponse struct {
OutputTokens float64 `json:"output_tokens"`
} `json:"tokens"`
} `json:"meta"`
+ ToolCalls []CohereToolCall `json:"tool_calls"`
}
// OpenAIProvider implements the Provider interface for OpenAI
@@ -65,13 +73,11 @@ func (provider *CohereProvider) GetProviderKey() interfaces.SupportedModelProvid
return interfaces.Cohere
}
-func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+func (provider *CohereProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
return nil, fmt.Errorf("text completion is not supported by Cohere")
}
-func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
- startTime := time.Now()
-
+func (provider *CohereProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
// Get the last message and chat history
lastMessage := messages[len(messages)-1]
chatHistory := messages[:len(messages)-1]
@@ -99,29 +105,23 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int
var tools []CohereTool
for _, tool := range *params.Tools {
parameterDefinitions := make(map[string]CohereParameterDefinition)
- if tool.Function.Parameters != nil {
- paramsMap, ok := tool.Function.Parameters.(map[string]interface{})
+ params := tool.Function.Parameters
+ for name, prop := range tool.Function.Parameters.Properties {
+ propMap, ok := prop.(map[string]interface{})
if ok {
- if properties, ok := paramsMap["properties"].(map[string]interface{}); ok {
- for name, prop := range properties {
- propMap, ok := prop.(map[string]interface{})
- if ok {
- paramDef := CohereParameterDefinition{
- Required: slices.Contains(paramsMap["required"].([]string), name),
- }
-
- if typeStr, ok := propMap["type"].(string); ok {
- paramDef.Type = typeStr
- }
-
- if desc, ok := propMap["description"].(string); ok {
- paramDef.Description = &desc
- }
-
- parameterDefinitions[name] = paramDef
- }
- }
+ paramDef := CohereParameterDefinition{
+ Required: slices.Contains(params.Required, name),
}
+
+ if typeStr, ok := propMap["type"].(string); ok {
+ paramDef.Type = typeStr
+ }
+
+ if desc, ok := propMap["description"].(string); ok {
+ paramDef.Description = &desc
+ }
+
+ parameterDefinitions[name] = paramDef
}
}
@@ -157,28 +157,40 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int
}
defer resp.Body.Close()
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %v", err)
+ }
+
// Handle error response
if resp.StatusCode != http.StatusOK {
- var errorResp struct {
- Message string `json:"message"`
- }
- if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil {
- return nil, fmt.Errorf("error decoding error response: %v", err)
- }
- return nil, fmt.Errorf("cohere error: %s", errorResp.Message)
+ return nil, fmt.Errorf("cohere error: %s", string(body))
}
// Decode response
var response CohereChatResponse
- if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
- return nil, fmt.Errorf("error decoding response: %v", err)
+ if err := json.Unmarshal(body, &response); err != nil {
+ return nil, fmt.Errorf("failed to parse Cohere response: %v", err)
}
- // Transform tool calls if present
- var toolCalls *[]interfaces.ToolCall
+ // Parse raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("failed to parse raw response: %v", err)
+ }
- // Calculate latency
- latency := time.Since(startTime).Seconds()
+ // Transform tool calls if present
+ var toolCalls []interfaces.ToolCall
+ if response.ToolCalls != nil {
+ for _, tool := range response.ToolCalls {
+ args := json.RawMessage(must(json.Marshal(tool.Parameters)))
+ toolCalls = append(toolCalls, interfaces.ToolCall{
+ Name: &tool.Name,
+ Arguments: args,
+ })
+ }
+ }
// Get role and content from the last message in chat history
var role interfaces.ModelChatMessageRole
@@ -188,20 +200,20 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int
role = lastMsg.Role
content = lastMsg.Message
} else {
- role = interfaces.ModelChatMessageRole("assistant")
+ role = interfaces.RoleChatbot
content = response.Text
}
// Create completion result
- result := &interfaces.CompletionResult{
+ result := &interfaces.BifrostResponse{
ID: response.ResponseID,
- Choices: []interfaces.CompletionResultChoice{
+ Choices: []interfaces.BifrostResponseChoice{
{
Index: 0,
- Message: interfaces.CompletionResponseChoice{
+ Message: interfaces.BifrostResponseChoiceMessage{
Role: role,
Content: content,
- ToolCalls: toolCalls,
+ ToolCalls: &toolCalls,
},
StopReason: &response.FinishReason,
},
@@ -211,14 +223,14 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int
PromptTokens: int(response.Meta.Tokens.InputTokens),
CompletionTokens: int(response.Meta.Tokens.OutputTokens),
TotalTokens: int(response.Meta.Tokens.InputTokens + response.Meta.Tokens.OutputTokens),
- Latency: &latency,
},
BilledUsage: &interfaces.BilledLLMUsage{
PromptTokens: float64Ptr(response.Meta.BilledUnits.InputTokens),
CompletionTokens: float64Ptr(response.Meta.BilledUnits.OutputTokens),
},
- Model: model,
- Provider: interfaces.Cohere,
+ Model: model,
+ Provider: interfaces.Cohere,
+ RawResponse: rawResponse,
}
return result, nil
@@ -226,14 +238,26 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []int
// Helper function to convert chat history to the correct type
func convertChatHistory(history []struct {
- Role interfaces.ModelChatMessageRole `json:"role"`
- Message string `json:"message"`
-}) *[]interfaces.CompletionResponseChoice {
- converted := make([]interfaces.CompletionResponseChoice, len(history))
+ Role interfaces.ModelChatMessageRole `json:"role"`
+ Message string `json:"message"`
+ ToolCalls []CohereToolCall `json:"tool_calls"`
+}) *[]interfaces.BifrostResponseChoiceMessage {
+ converted := make([]interfaces.BifrostResponseChoiceMessage, len(history))
for i, msg := range history {
- converted[i] = interfaces.CompletionResponseChoice{
- Role: msg.Role,
- Content: msg.Message,
+ var toolCalls []interfaces.ToolCall
+ if msg.ToolCalls != nil {
+ for _, tool := range msg.ToolCalls {
+ args := json.RawMessage(must(json.Marshal(tool.Parameters)))
+ toolCalls = append(toolCalls, interfaces.ToolCall{
+ Name: &tool.Name,
+ Arguments: args,
+ })
+ }
+ }
+ converted[i] = interfaces.BifrostResponseChoiceMessage{
+ Role: msg.Role,
+ Content: msg.Message,
+ ToolCalls: &toolCalls,
}
}
return &converted
diff --git a/providers/openai.go b/providers/openai.go
index b97a86dcdd..16d0f818b7 100644
--- a/providers/openai.go
+++ b/providers/openai.go
@@ -5,16 +5,57 @@ import (
"bytes"
"encoding/json"
"fmt"
+ "io"
"net/http"
"time"
)
+type OpenAIToolCall struct {
+ Type *string `json:"type"`
+ ID *string `json:"id"`
+ Function struct {
+ Name string `json:"name"`
+ Arguments string `json:"arguments"`
+ } `json:"function"`
+}
+
+type OpenAIMessage struct {
+ Role interfaces.ModelChatMessageRole `json:"role"`
+ Content string `json:"content"`
+ ToolCalls *[]OpenAIToolCall `json:"tool_calls,omitempty"`
+}
+
+type OpenAIChoice struct {
+ Index int `json:"index"`
+ Message OpenAIMessage `json:"message"`
+ FinishReason *string `json:"finish_reason"`
+ LogProbs *interface{} `json:"logprobs"`
+}
+
type OpenAIResponse struct {
- ID string `json:"id"`
- Choices []interfaces.CompletionResultChoice `json:"choices"`
- Usage interfaces.LLMUsage `json:"usage"`
- Model string `json:"model"`
- Created interface{} `json:"created"`
+ ID string `json:"id"`
+ Object string `json:"object"`
+ Choices []OpenAIChoice `json:"choices"`
+ Usage struct {
+ PromptTokens int `json:"prompt_tokens"`
+ CompletionTokens int `json:"completion_tokens"`
+ TotalTokens int `json:"total_tokens"`
+ PromptTokenDetails struct {
+ CachedToken int `json:"cached_tokens"`
+ AudioToken int `json:"audio_tokens"`
+ } `json:"prompt_tokens_details"`
+ CompletionTokenDetails struct {
+ ReasoningTokens int `json:"reasoning_tokens"`
+ AudioTokens int `json:"audio_tokens"`
+ AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
+ RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
+ } `json:"completion_tokens_details"`
+ Latency float64 `json:"latency"`
+ } `json:"usage"`
+ Model string `json:"model"`
+ Created interface{} `json:"created"`
+ ServiceTier string `json:"service_tier"`
+ SystemFingerprint string `json:"system_fingerprint"`
}
// OpenAIProvider implements the Provider interface for OpenAI
@@ -34,7 +75,7 @@ func (provider *OpenAIProvider) GetProviderKey() interfaces.SupportedModelProvid
}
// TextCompletion performs text completion
-func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+func (provider *OpenAIProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
return nil, fmt.Errorf("text completion is not supported by OpenAI")
}
@@ -55,8 +96,7 @@ func (provider *OpenAIProvider) sanitizeParameters(params *interfaces.ModelParam
return sanitized
}
-// ChatCompletion implements chat completion using OpenAI's API
-func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.CompletionResult, error) {
+func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, error) {
startTime := time.Now()
// Format messages for OpenAI API
@@ -108,35 +148,69 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int
latency := time.Since(startTime).Seconds()
+ // Read response body
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return nil, fmt.Errorf("error reading response: %v", err)
+ }
+
// Handle error response
if resp.StatusCode != http.StatusOK {
- var errorResp struct {
- Error struct {
- Message string `json:"message"`
- Type string `json:"type"`
- Param any `json:"param"`
- Code string `json:"code"`
- } `json:"error"`
- }
- if err := json.NewDecoder(resp.Body).Decode(&errorResp); err != nil {
- return nil, fmt.Errorf("error decoding error response: %v", err)
- }
- return nil, fmt.Errorf("OpenAI error: %s", errorResp.Error.Message)
+ return nil, fmt.Errorf("OpenAI error: %s", string(body))
}
- // Decode response
+ // Decode structured response
var response OpenAIResponse
+ if err := json.Unmarshal(body, &response); err != nil {
+ return nil, fmt.Errorf("error decoding structured response: %v", err)
+ }
- if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
- return nil, fmt.Errorf("error decoding response: %v", err)
+ // Decode raw response
+ var rawResponse interface{}
+ if err := json.Unmarshal(body, &rawResponse); err != nil {
+ return nil, fmt.Errorf("error decoding raw response: %v", err)
}
- // Convert the raw result to CompletionResult
- result := &interfaces.CompletionResult{
+ // Transform choices to include tool calls
+ var choices []interfaces.BifrostResponseChoice
+ for i, c := range response.Choices {
+ // Transform tool calls if present
+ var toolCalls []interfaces.ToolCall
+ if c.Message.ToolCalls != nil {
+ for _, tool := range *c.Message.ToolCalls {
+ toolCalls = append(toolCalls, interfaces.ToolCall{
+ ID: tool.ID,
+ Type: tool.Type,
+ Name: &tool.Function.Name,
+ Arguments: json.RawMessage(tool.Function.Arguments),
+ })
+ }
+ }
+
+ choices = append(choices, interfaces.BifrostResponseChoice{
+ Index: i,
+ Message: interfaces.BifrostResponseChoiceMessage{
+ Role: c.Message.Role,
+ Content: c.Message.Content,
+ ToolCalls: &toolCalls,
+ },
+ StopReason: c.FinishReason,
+ LogProbs: c.LogProbs,
+ })
+ }
+
+ result := &interfaces.BifrostResponse{
ID: response.ID,
- Choices: response.Choices,
- Usage: response.Usage,
- Model: response.Model,
+ Choices: choices,
+ Usage: interfaces.LLMUsage{
+ PromptTokens: response.Usage.PromptTokens,
+ CompletionTokens: response.Usage.CompletionTokens,
+ TotalTokens: response.Usage.TotalTokens,
+ Latency: &latency,
+ },
+ Model: response.Model,
+ Provider: interfaces.OpenAI,
+ RawResponse: rawResponse,
}
// Handle the created field conversion
@@ -150,9 +224,5 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []int
}
}
- // Add provider-specific information
- result.Provider = interfaces.OpenAI
- result.Usage.Latency = &latency
-
return result, nil
}
diff --git a/providers/utils.go b/providers/utils.go
index 56c5908089..a3a964d9a5 100644
--- a/providers/utils.go
+++ b/providers/utils.go
@@ -12,6 +12,8 @@ import (
"reflect"
"time"
+ "maps"
+
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
@@ -47,7 +49,7 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} {
typ := val.Type()
// Iterate through all fields
- for i := 0; i < val.NumField(); i++ {
+ for i := range val.NumField() {
field := val.Field(i)
fieldType := typ.Field(i)
@@ -69,9 +71,7 @@ func PrepareParams(params *interfaces.ModelParameters) map[string]interface{} {
}
// Handle ExtraParams
- for k, v := range params.ExtraParams {
- flatParams[k] = v
- }
+ maps.Copy(flatParams, params.ExtraParams)
return flatParams
}
diff --git a/tests/anthropic_test.go b/tests/anthropic_test.go
index f4790f8b71..4791d8ffc8 100644
--- a/tests/anthropic_test.go
+++ b/tests/anthropic_test.go
@@ -11,21 +11,16 @@ import (
// setupAnthropicRequests sends multiple test requests to Anthropic
func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
- anthropicMessages := []string{
- "What's your favorite programming language?",
- "Can you help me write a Go function?",
- "What's the best way to learn programming?",
- "Tell me about artificial intelligence.",
- }
-
ctx := context.Background()
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ MaxTokens: &maxTokens,
+ }
+
+ // Text completion request
go func() {
- params := interfaces.ModelParameters{
- ExtraParams: map[string]interface{}{
- "max_tokens_to_sample": 4096,
- },
- }
text := "Hello world!"
result, err := bifrost.TextCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{
@@ -42,10 +37,11 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
}
}()
- params := interfaces.ModelParameters{
- ExtraParams: map[string]interface{}{
- "max_tokens": 4096,
- },
+ // Regular chat completion requests
+ anthropicMessages := []string{
+ "Hello! How are you today?",
+ "Tell me a joke!",
+ "What's your favorite programming language?",
}
for i, message := range anthropicMessages {
@@ -73,6 +69,70 @@ func setupAnthropicRequests(bifrost *bifrost.Bifrost) {
}
}(message, delay, i)
}
+
+ // Tool calls test
+ setupAnthropicToolCalls(bifrost, ctx)
+}
+
+// setupAnthropicToolCalls tests Anthropic's function calling capability
+func setupAnthropicToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ anthropicMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ MaxTokens: &maxTokens,
+ }
+
+ for i, message := range anthropicMessages {
+ delay := time.Duration(500+100*i) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Anthropic, &interfaces.BifrostRequest{
+ Model: "claude-3-7-sonnet-20250219",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+
+ if err != nil {
+ fmt.Printf("Error in Anthropic tool call request %d: %v\n", index+1, err)
+ } else {
+ toolCall := *result.Choices[1].Message.ToolCalls
+ fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments)
+ }
+ }(message, delay, i)
+ }
}
func TestAnthropic(t *testing.T) {
diff --git a/tests/bedrock_test.go b/tests/bedrock_test.go
index 4fb3ba53c9..168c44693f 100644
--- a/tests/bedrock_test.go
+++ b/tests/bedrock_test.go
@@ -11,21 +11,16 @@ import (
// setupBedrockRequests sends multiple test requests to Bedrock
func setupBedrockRequests(bifrost *bifrost.Bifrost) {
- bedrockMessages := []string{
- "What's your favorite programming language?",
- "Can you help me write a Go function?",
- "What's the best way to learn programming?",
- "Tell me about artificial intelligence.",
- }
-
ctx := context.Background()
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ MaxTokens: &maxTokens,
+ }
+
+ // Text completion request
go func() {
- params := interfaces.ModelParameters{
- ExtraParams: map[string]interface{}{
- "max_tokens_to_sample": 4096,
- },
- }
text := "\n\nHuman:\n\nAssistant:"
result, err := bifrost.TextCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{
@@ -42,10 +37,11 @@ func setupBedrockRequests(bifrost *bifrost.Bifrost) {
}
}()
- params := interfaces.ModelParameters{
- ExtraParams: map[string]interface{}{
- "max_tokens": 4096,
- },
+ // Regular chat completion requests
+ bedrockMessages := []string{
+ "Hello! How are you today?",
+ "Tell me a joke!",
+ "What's your favorite programming language?",
}
for i, message := range bedrockMessages {
@@ -73,6 +69,75 @@ func setupBedrockRequests(bifrost *bifrost.Bifrost) {
}
}(message, delay, i)
}
+
+ // Tool calls test
+ setupBedrockToolCalls(bifrost, ctx)
+}
+
+// setupBedrockToolCalls tests Bedrock's function calling capability
+func setupBedrockToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ bedrockMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ maxTokens := 4096
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ MaxTokens: &maxTokens,
+ }
+
+ for i, message := range bedrockMessages {
+ delay := time.Duration(500+100*i) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Bedrock, &interfaces.BifrostRequest{
+ Model: "anthropic.claude-3-sonnet-20240229-v1:0",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+
+ if err != nil {
+ fmt.Printf("Error in Bedrock tool call request %d: %v\n", index+1, err)
+ } else {
+ if result.Choices[0].Message.ToolCalls != nil && len(*result.Choices[0].Message.ToolCalls) > 0 {
+ toolCall := *result.Choices[0].Message.ToolCalls
+ fmt.Printf("🤖 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments)
+ } else {
+ fmt.Printf("🤖 No tool calls in response %d\n", index+1)
+ fmt.Println("Raw JSON Response", result.RawResponse)
+ }
+ }
+ }(message, delay, i)
+ }
}
func TestBedrock(t *testing.T) {
diff --git a/tests/cohere_test.go b/tests/cohere_test.go
index 0fcb6ec7b3..e9346bc60d 100644
--- a/tests/cohere_test.go
+++ b/tests/cohere_test.go
@@ -12,7 +12,6 @@ import (
// setupCohereRequests sends multiple test requests to Cohere
func setupCohereRequests(bifrost *bifrost.Bifrost) {
text := "Hello world!"
-
ctx := context.Background()
// Text completion request
@@ -31,15 +30,14 @@ func setupCohereRequests(bifrost *bifrost.Bifrost) {
}
}()
- // Chat completion requests with different messages and delays
- CohereMessages := []string{
+ // Regular chat completion requests
+ cohereMessages := []string{
"Hello! How are you today?",
- "What's the weather like?",
"Tell me a joke!",
"What's your favorite programming language?",
}
- for i, message := range CohereMessages {
+ for i, message := range cohereMessages {
delay := time.Duration(100*(i+1)) * time.Millisecond
go func(msg string, delay time.Duration, index int) {
time.Sleep(delay)
@@ -63,6 +61,66 @@ func setupCohereRequests(bifrost *bifrost.Bifrost) {
}
}(message, delay, i)
}
+
+ // Tool calls test
+ setupCohereToolCalls(bifrost, ctx)
+}
+
+// setupCohereToolCalls tests Cohere's function calling capability
+func setupCohereToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ cohereMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ }
+
+ for i, message := range cohereMessages {
+ delay := time.Duration(100*(i+1)) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.Cohere, &interfaces.BifrostRequest{
+ Model: "command-a-03-2025",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+ if err != nil {
+ fmt.Printf("Error in Cohere tool call request %d: %v\n", index+1, err)
+ } else {
+ toolCall := *result.Choices[0].Message.ToolCalls
+ fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments)
+ }
+ }(message, delay, i)
+ }
}
func TestCohere(t *testing.T) {
diff --git a/tests/openai_test.go b/tests/openai_test.go
index 246f1a4114..0c7a214170 100644
--- a/tests/openai_test.go
+++ b/tests/openai_test.go
@@ -12,7 +12,6 @@ import (
// setupOpenAIRequests sends multiple test requests to OpenAI
func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
text := "Hello world!"
-
ctx := context.Background()
// Text completion request
@@ -31,10 +30,9 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
}
}()
- // Chat completion requests with different messages and delays
+ // Regular chat completion requests
openAIMessages := []string{
"Hello! How are you today?",
- "What's the weather like?",
"Tell me a joke!",
"What's your favorite programming language?",
}
@@ -63,6 +61,66 @@ func setupOpenAIRequests(bifrost *bifrost.Bifrost) {
}
}(message, delay, i)
}
+
+ // Tool calls test
+ setupOpenAIToolCalls(bifrost, ctx)
+}
+
+// setupOpenAIToolCalls tests OpenAI's function calling capability
+func setupOpenAIToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) {
+ openAIMessages := []string{
+ "What's the weather like in Mumbai?",
+ }
+
+ params := interfaces.ModelParameters{
+ Tools: &[]interfaces.Tool{{
+ Type: "function",
+ Function: interfaces.Function{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: interfaces.FunctionParameters{
+ Type: "object",
+ Properties: map[string]interface{}{
+ "location": map[string]interface{}{
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ "unit": map[string]interface{}{
+ "type": "string",
+ "enum": []string{"celsius", "fahrenheit"},
+ },
+ },
+ Required: []string{"location"},
+ },
+ },
+ }},
+ }
+
+ for i, message := range openAIMessages {
+ delay := time.Duration(100*(i+1)) * time.Millisecond
+ go func(msg string, delay time.Duration, index int) {
+ time.Sleep(delay)
+ messages := []interfaces.Message{
+ {
+ Role: interfaces.RoleUser,
+ Content: &msg,
+ },
+ }
+ result, err := bifrost.ChatCompletionRequest(interfaces.OpenAI, &interfaces.BifrostRequest{
+ Model: "gpt-4o-mini",
+ Input: interfaces.RequestInput{
+ ChatInput: &messages,
+ },
+ Params: ¶ms,
+ }, ctx)
+ if err != nil {
+ fmt.Printf("Error in OpenAI tool call request %d: %v\n", index+1, err)
+ } else {
+ toolCall := *result.Choices[0].Message.ToolCalls
+ fmt.Printf("🐒 Tool Call Result %d: %s\n", index+1, toolCall[0].Arguments)
+ }
+ }(message, delay, i)
+ }
}
func TestOpenAI(t *testing.T) {
diff --git a/tests/plugin.go b/tests/plugin.go
index dace4ff9e9..2cc435e3d7 100644
--- a/tests/plugin.go
+++ b/tests/plugin.go
@@ -37,7 +37,7 @@ func (plugin *Plugin) PreHook(ctx context.Context, req *interfaces.BifrostReques
return ctx, req, nil
}
-func (plugin *Plugin) PostHook(ctx context.Context, res *interfaces.CompletionResult) (*interfaces.CompletionResult, error) {
+func (plugin *Plugin) PostHook(ctx context.Context, res *interfaces.BifrostResponse) (*interfaces.BifrostResponse, error) {
// Get traceID from context
traceID, ok := ctx.Value(traceIDKey).(string)
if !ok {
diff --git a/tests/setup.go b/tests/setup.go
index fc1923e15c..f8f5bddf5b 100644
--- a/tests/setup.go
+++ b/tests/setup.go
@@ -60,10 +60,8 @@ func getBifrost() (*bifrost.Bifrost, error) {
DefaultRequestTimeoutInSeconds: 30,
},
MetaConfig: &interfaces.MetaConfig{
- BedrockMetaConfig: &interfaces.BedrockMetaConfig{
- SecretAccessKey: "AMpq95pNadM2fD1GlcNvjbMiGhizwYaGKJxv+nti",
- Region: maxim.StrPtr("us-east-1"),
- },
+ SecretAccessKey: "AMpq95pNadM2fD1GlcNvjbMiGhizwYaGKJxv+nti",
+ Region: maxim.StrPtr("us-east-1"),
},
},
interfaces.Cohere: {