diff --git a/bifrost.go b/bifrost.go index 2d3cd61d5a..14c681bde2 100644 --- a/bifrost.go +++ b/bifrost.go @@ -52,6 +52,8 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey interfaces.Sup return providers.NewBedrockProvider(config), nil case interfaces.Cohere: return providers.NewCohereProvider(config), nil + case interfaces.Azure: + return providers.NewAzureProvider(config, bifrost.logger), nil default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } diff --git a/interfaces/bifrost.go b/interfaces/bifrost.go index c2e1733e1e..2ba967762c 100644 --- a/interfaces/bifrost.go +++ b/interfaces/bifrost.go @@ -67,9 +67,10 @@ type ModelParameters struct { } type FunctionParameters struct { - Type string `json:"type,"` - Required []string `json:"required"` - Properties map[string]interface{} `json:"properties"` + Type string `json:"type,"` + Description *string `json:"description,omitempty"` + Required []string `json:"required"` + Properties map[string]interface{} `json:"properties"` } // Function represents a function definition for tool calls @@ -165,9 +166,17 @@ type ContentLogProb struct { TopLogProbs []LogProb `json:"top_logprobs"` } +type TextCompletionLogProb struct { + TextOffset []int `json:"text_offset"` + TokenLogProbs []float64 `json:"token_logprobs"` + Tokens []string `json:"tokens"` + TopLogProbs []map[string]float64 `json:"top_logprobs"` +} + type LogProbs struct { - Content []ContentLogProb `json:"content"` - Refusal []LogProb `json:"refusal"` + Content []ContentLogProb `json:"content,omitempty"` + Refusal []LogProb `json:"refusal,omitempty"` + Text TextCompletionLogProb `json:"text,omitempty"` } type FunctionCall struct { diff --git a/interfaces/meta/azure.go b/interfaces/meta/azure.go new file mode 100644 index 0000000000..700cd897ae --- /dev/null +++ b/interfaces/meta/azure.go @@ -0,0 +1,39 @@ +package meta + +type AzureMetaConfig struct { + Endpoint string `json:"endpoint"` + Deployments map[string]string 
`json:"deployments,omitempty"` + APIVersion *string `json:"api_version,omitempty"` +} + +func (c *AzureMetaConfig) GetSecretAccessKey() *string { + return nil +} + +func (c *AzureMetaConfig) GetRegion() *string { + return nil +} + +func (c *AzureMetaConfig) GetSessionToken() *string { + return nil +} + +func (c *AzureMetaConfig) GetARN() *string { + return nil +} + +func (c *AzureMetaConfig) GetInferenceProfiles() map[string]string { + return nil +} + +func (c *AzureMetaConfig) GetEndpoint() *string { + return &c.Endpoint +} + +func (c *AzureMetaConfig) GetDeployments() map[string]string { + return c.Deployments +} + +func (c *AzureMetaConfig) GetAPIVersion() *string { + return c.APIVersion +} diff --git a/interfaces/meta/bedrock.go b/interfaces/meta/bedrock.go index c4caf56a89..b9047f73bb 100644 --- a/interfaces/meta/bedrock.go +++ b/interfaces/meta/bedrock.go @@ -27,3 +27,15 @@ func (c *BedrockMetaConfig) GetARN() *string { func (c *BedrockMetaConfig) GetInferenceProfiles() map[string]string { return c.InferenceProfiles } + +func (c *BedrockMetaConfig) GetEndpoint() *string { + return nil +} + +func (c *BedrockMetaConfig) GetDeployments() map[string]string { + return nil +} + +func (c *BedrockMetaConfig) GetAPIVersion() *string { + return nil +} diff --git a/interfaces/provider.go b/interfaces/provider.go index 20d54c3539..484fee4795 100644 --- a/interfaces/provider.go +++ b/interfaces/provider.go @@ -17,6 +17,9 @@ type MetaConfig interface { GetSessionToken() *string GetARN() *string GetInferenceProfiles() map[string]string + GetEndpoint() *string + GetDeployments() map[string]string + GetAPIVersion() *string } type ConcurrencyAndBufferSize struct { diff --git a/providers/azure.go b/providers/azure.go new file mode 100644 index 0000000000..359bf8d3d7 --- /dev/null +++ b/providers/azure.go @@ -0,0 +1,369 @@ +package providers + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/maximhq/bifrost/interfaces" + "github.com/valyala/fasthttp" + + 
"github.com/maximhq/maxim-go" +) + +type AzureTextResponse struct { + ID string `json:"id"` + Object string `json:"object"` // text.completion or chat.completion + Choices []struct { + FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index"` + Text string `json:"text"` + LogProbs interfaces.TextCompletionLogProb `json:"logprobs"` + } `json:"choices"` + Model string `json:"model"` + Created int `json:"created"` // The Unix timestamp (in seconds). + SystemFingerprint *string `json:"system_fingerprint"` + Usage interfaces.LLMUsage `json:"usage"` +} + +type AzureChatResponse struct { + ID string `json:"id"` + Object string `json:"object"` // text.completion or chat.completion + Choices []interfaces.BifrostResponseChoice `json:"choices"` + Model string `json:"model"` + Created int `json:"created"` // The Unix timestamp (in seconds). + SystemFingerprint *string `json:"system_fingerprint"` + Usage interfaces.LLMUsage `json:"usage"` +} + +type AzureError struct { + Error struct { + Code string `json:"code"` + Message string `json:"message"` + } `json:"error"` +} + +// AzureProvider implements the Provider interface for Azure API +type AzureProvider struct { + logger interfaces.Logger + client *fasthttp.Client + meta interfaces.MetaConfig +} + +// NewAzureProvider creates a new AzureProvider instance +func NewAzureProvider(config *interfaces.ProviderConfig, logger interfaces.Logger) *AzureProvider { + client := &fasthttp.Client{ + ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), + MaxConnsPerHost: config.ConcurrencyAndBufferSize.BufferSize, + } + + // Configure proxy if provided + client = configureProxy(client, config.ProxyConfig, logger) + + return &AzureProvider{ + logger: logger, + client: client, + meta: config.MetaConfig, + } +} + +func (provider *AzureProvider) GetProviderKey() 
interfaces.SupportedModelProvider { + return interfaces.Azure +} + +func (provider *AzureProvider) PrepareToolChoices(params map[string]interface{}) map[string]interface{} { + toolChoice, exists := params["tool_choice"] + if !exists { + return params + } + + switch tc := toolChoice.(type) { + case interfaces.ToolChoice: + anthropicToolChoice := AnthropicToolChoice{ + Type: tc.Type, + Name: &tc.Function.Name, + } + + parallelToolCalls, exists := params["parallel_tool_calls"] + if !exists { + return params + } + + switch parallelTC := parallelToolCalls.(type) { + case bool: + disableParallel := !parallelTC + anthropicToolChoice.DisableParallelToolUse = &disableParallel + + delete(params, "parallel_tool_calls") + } + + params["tool_choice"] = anthropicToolChoice + } + + return params +} + +func (provider *AzureProvider) CompleteRequest(requestBody map[string]interface{}, path string, key string, model string) ([]byte, *interfaces.BifrostError) { + // Marshal the request body + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error marshaling request", + Error: err, + }, + } + } + + if provider.meta.GetEndpoint() == nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "endpoint not set", + }, + } + } + + url := *provider.meta.GetEndpoint() + + if provider.meta.GetDeployments() != nil { + deployment := provider.meta.GetDeployments()[model] + if deployment == "" { + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: fmt.Sprintf("deployment not found for model %s", model), + }, + } + } + + apiVersion := provider.meta.GetAPIVersion() + if apiVersion == nil { + apiVersion = maxim.StrPtr("2024-02-01") + } + + url = fmt.Sprintf("%s/openai/deployments/%s/%s?api-version=%s", url, deployment, path, *apiVersion) + } else { + return nil,
&interfaces.BifrostError{ + IsBifrostError: false, + Error: interfaces.ErrorField{ + Message: "deployments not set", + }, + } + } + + // Create the request with the JSON body + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(url) + req.Header.SetMethod("POST") + req.Header.SetContentType("application/json") + req.Header.Set("api-key", key) + req.SetBody(jsonData) + + // Send the request + if err := provider.client.Do(req, resp); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error sending request", + Error: err, + }, + } + } + + // Handle error response + if resp.StatusCode() != fasthttp.StatusOK { + var errorResp AzureError + if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing error response", + Error: err, + }, + } + } + + statusCode := resp.StatusCode() + + return nil, &interfaces.BifrostError{ + IsBifrostError: false, + StatusCode: &statusCode, + Error: interfaces.ErrorField{ + Type: &errorResp.Error.Code, + Message: errorResp.Error.Message, + }, + } + } + + // Read the response body + body := resp.Body() + + return body, nil +} + +// TextCompletion implements text completion using Azure's API +func (provider *AzureProvider) TextCompletion(model, key, text string, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { + preparedParams := PrepareParams(params) + + // Merge additional parameters + requestBody := MergeConfig(map[string]interface{}{ + "model": model, + "prompt": text, + }, preparedParams) + + body, err := provider.CompleteRequest(requestBody, "completions", key, model) + if err != nil { + return nil, err + } + + // Parse the response + var response AzureTextResponse + if err :=
json.Unmarshal(body, &response); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing response", + Error: err, + }, + } + } + + // Parse raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } + } + + choices := []interfaces.BifrostResponseChoice{} + + // Create the completion result + if len(response.Choices) > 0 { + choices = append(choices, interfaces.BifrostResponseChoice{ + Index: 0, + Message: interfaces.BifrostResponseChoiceMessage{ + Role: interfaces.RoleAssistant, + Content: &response.Choices[0].Text, + }, + FinishReason: response.Choices[0].FinishReason, + LogProbs: &interfaces.LogProbs{ + Text: response.Choices[0].LogProbs, + }, + }) + } + + completionResult := &interfaces.BifrostResponse{ + ID: response.ID, + Choices: choices, + Usage: response.Usage, + Model: response.Model, + ExtraFields: interfaces.BifrostResponseExtraFields{ + Provider: interfaces.Azure, + RawResponse: rawResponse, + }, + } + + return completionResult, nil +} + +// ChatCompletion implements chat completion using Azure's API +func (provider *AzureProvider) ChatCompletion(model, key string, messages []interfaces.Message, params *interfaces.ModelParameters) (*interfaces.BifrostResponse, *interfaces.BifrostError) { + // Format messages for Azure API + var formattedMessages []map[string]interface{} + for _, msg := range messages { + if msg.ImageContent != nil { + var content []map[string]interface{} + + imageContent := map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{"url": msg.ImageContent.URL}, + } + + content = append(content, imageContent) + + // Add text content if present + if msg.Content != nil { + content = append(content, map[string]interface{}{ + "type": 
"text", + "text": msg.Content, + }) + } + + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) + } else { + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + }) + } + } + + preparedParams := PrepareParams(params) + + // Merge additional parameters + requestBody := MergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + body, err := provider.CompleteRequest(requestBody, "chat/completions", key, model) + if err != nil { + return nil, err + } + + // Decode response + var response AzureChatResponse + if err := json.Unmarshal(body, &response); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error decoding response", + Error: err, + }, + } + } + + // Decode raw response + var rawResponse interface{} + if err := json.Unmarshal(body, &rawResponse); err != nil { + return nil, &interfaces.BifrostError{ + IsBifrostError: true, + Error: interfaces.ErrorField{ + Message: "error parsing raw response", + Error: err, + }, + } + } + + // Create the completion result + result := &interfaces.BifrostResponse{ + ID: response.ID, + Object: response.Object, + Choices: response.Choices, + Model: response.Model, + Created: response.Created, + SystemFingerprint: response.SystemFingerprint, + Usage: response.Usage, + ExtraFields: interfaces.BifrostResponseExtraFields{ + Provider: interfaces.Azure, + RawResponse: rawResponse, + }, + } + + return result, nil +} diff --git a/tests/account.go b/tests/account.go index ce1ecc2576..afd6e32bf0 100644 --- a/tests/account.go +++ b/tests/account.go @@ -54,6 +54,15 @@ func (baseAccount *BaseAccount) GetKeysForProvider(providerKey interfaces.Suppor Weight: 1.0, }, }, nil + case interfaces.Azure: + return []interfaces.Key{ + { + Value: os.Getenv("AZURE_API_KEY"), + Models: []string{"gpt-4o"}, + 
Weight: 1.0, + }, + }, nil + default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } @@ -118,6 +127,26 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey interfaces.Supp BufferSize: 10, }, }, nil + case interfaces.Azure: + return &interfaces.ProviderConfig{ + NetworkConfig: interfaces.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.AzureMetaConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-aug", + }, + APIVersion: maxim.StrPtr("2024-08-01-preview"), + }, + ConcurrencyAndBufferSize: interfaces.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } diff --git a/tests/azure_test.go b/tests/azure_test.go new file mode 100644 index 0000000000..e21726ee5b --- /dev/null +++ b/tests/azure_test.go @@ -0,0 +1,171 @@ +package tests + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/maximhq/bifrost" + "github.com/maximhq/bifrost/interfaces" + + "github.com/maximhq/maxim-go" +) + +// setupAzureRequests sends multiple test requests to Azure +func setupAzureRequests(bifrost *bifrost.Bifrost) { + text := "Hello world!" + ctx := context.Background() + + // Text completion request + go func() { + result, err := bifrost.TextCompletionRequest(interfaces.Azure, &interfaces.BifrostRequest{ + Model: "gpt-4o", + Input: interfaces.RequestInput{ + TextCompletionInput: &text, + }, + Params: nil, + }, ctx) + if err != nil { + fmt.Println("Error:", err.Error.Message) + } else { + fmt.Println("🐒 Azure Text Completion Result:", result.Choices[0].Message.Content) + } + }() + + // Regular chat completion requests + azureMessages := []string{ + "Hello! 
How are you today?", + "Tell me a joke!", + "What's your favorite programming language?", + } + + for i, message := range azureMessages { + delay := time.Duration(100*(i+1)) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Azure, &interfaces.BifrostRequest{ + Model: "gpt-4o", + Input: interfaces.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: nil, + }, ctx) + if err != nil { + fmt.Printf("Error in Azure request %d: %v\n", index+1, err.Error.Message) + } else { + fmt.Printf("🐒 Azure Chat Completion Result %d: %s\n", index+1, *result.Choices[0].Message.Content) + } + }(message, delay, i) + } + + // Image input tests + setupAzureImageTests(bifrost, ctx) + + // Tool calls test + setupAzureToolCalls(bifrost, ctx) +} + +// setupAzureImageTests tests Azure's image input capabilities +func setupAzureImageTests(bifrost *bifrost.Bifrost, ctx context.Context) { + // Test with URL image + urlImageMessages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: maxim.StrPtr("What is Happening in this picture?"), + ImageContent: &interfaces.ImageContent{ + URL: "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg", + }, + }, + } + + go func() { + result, err := bifrost.ChatCompletionRequest(interfaces.Azure, &interfaces.BifrostRequest{ + Model: "gpt-4o", + Input: interfaces.RequestInput{ + ChatCompletionInput: &urlImageMessages, + }, + Params: nil, + }, ctx) + if err != nil { + fmt.Printf("Error in Azure URL image request: %v\n", err.Error.Message) + } else { + fmt.Printf("🐒 Azure URL Image Result: %s\n", *result.Choices[0].Message.Content) + } + }() +} + +// setupAzureToolCalls tests Azure's function calling capability +func setupAzureToolCalls(bifrost *bifrost.Bifrost, ctx context.Context) { + azureMessages := []string{ 
+ "What's the weather like in Mumbai?", + } + + params := interfaces.ModelParameters{ + Tools: &[]interfaces.Tool{{ + Type: "function", + Function: interfaces.Function{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: interfaces.FunctionParameters{ + Type: "object", + Properties: map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": map[string]interface{}{ + "type": "string", + "enum": []string{"celsius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }}, + } + + for i, message := range azureMessages { + delay := time.Duration(100*(i+1)) * time.Millisecond + go func(msg string, delay time.Duration, index int) { + time.Sleep(delay) + messages := []interfaces.Message{ + { + Role: interfaces.RoleUser, + Content: &msg, + }, + } + result, err := bifrost.ChatCompletionRequest(interfaces.Azure, &interfaces.BifrostRequest{ + Model: "gpt-4o", + Input: interfaces.RequestInput{ + ChatCompletionInput: &messages, + }, + Params: ¶ms, + }, ctx) + if err != nil { + fmt.Printf("Error in Azure tool call request %d: %v\n", index+1, err.Error.Message) + } else { + toolCall := *result.Choices[0].Message.ToolCalls + fmt.Printf("🐒 Azure Tool Call Result %d: %s\n", index+1, toolCall[0].Function.Arguments) + } + }(message, delay, i) + } +} + +func TestAzure(t *testing.T) { + bifrost, err := getBifrost() + if err != nil { + t.Fatalf("Error initializing bifrost: %v", err) + return + } + + setupAzureRequests(bifrost) + + bifrost.Cleanup() +}