diff --git a/.gitignore b/.gitignore index 48303bc0a5..9c7a2150b7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .env .vscode .DS_Store +*_creds* \ No newline at end of file diff --git a/core/bifrost.go b/core/bifrost.go index 4e296b0cec..f3f4f67dc8 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -60,11 +60,13 @@ func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelP case schemas.Anthropic: return providers.NewAnthropicProvider(config, bifrost.logger), nil case schemas.Bedrock: - return providers.NewBedrockProvider(config, bifrost.logger), nil + return providers.NewBedrockProvider(config, bifrost.logger) case schemas.Cohere: return providers.NewCohereProvider(config, bifrost.logger), nil case schemas.Azure: - return providers.NewAzureProvider(config, bifrost.logger), nil + return providers.NewAzureProvider(config, bifrost.logger) + case schemas.Vertex: + return providers.NewVertexProvider(config, bifrost.logger) default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } @@ -78,10 +80,12 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi return fmt.Errorf("failed to get config for provider: %v", err) } - // Check if the provider has any keys - keys, err := bifrost.account.GetKeysForProvider(providerKey) - if err != nil || len(keys) == 0 { - return fmt.Errorf("failed to get keys for provider: %v", err) + // Check if the provider has any keys (skip vertex) + if providerKey != schemas.Vertex { + keys, err := bifrost.account.GetKeysForProvider(providerKey) + if err != nil || len(keys) == 0 { + return fmt.Errorf("failed to get keys for provider: %v", err) + } } queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider @@ -93,7 +97,7 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi provider, err := bifrost.createProviderFromProviderKey(providerKey, config) if err != nil { - return fmt.Errorf("failed to get provider for the given key: %v", err) + return fmt.Errorf("failed to create provider for the given key: %v", err) } for range providerConfig.ConcurrencyAndBufferSize.Concurrency { @@ -166,7 +170,7 @@ func Init(config schemas.BifrostConfig) (*Bifrost, error) { } if err := bifrost.prepareProvider(providerKey, config); err != nil { - bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider: %v", err)) + bifrost.logger.Warn(fmt.Sprintf("failed to prepare provider %s: %v", providerKey, err)) } } @@ -291,18 +295,22 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan for req := range queue { var result *schemas.BifrostResponse var bifrostError *schemas.BifrostError + var err error - key, err := bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) - if err != nil { - bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err)) - req.Err <- schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - Error: err, - }, + key := "" + if provider.GetProviderKey() != schemas.Vertex { + key, err = bifrost.SelectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err)) + req.Err <- schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + } + continue } - continue } config, err := bifrost.account.GetConfigForProvider(provider.GetProviderKey()) diff --git a/core/go.mod b/core/go.mod index b35f8bf61b..114449f328 100644 --- a/core/go.mod +++ b/core/go.mod @@ -10,9 +10,11 @@ require ( github.com/goccy/go-json v0.10.5 github.com/maximhq/bifrost/plugins v1.0.0 github.com/valyala/fasthttp v1.60.0 + golang.org/x/oauth2 v0.30.0 ) require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/andybalholm/brotli v1.1.1 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect diff --git a/core/go.sum b/core/go.sum index d0f8edd171..bef8cd14b0 100644 --- a/core/go.sum +++ b/core/go.sum @@ -1,3 +1,5 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= @@ -44,5 +46,7 @@ github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index f9c4fb601b..987bc7be3b 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -120,7 +120,7 @@ func releaseAnthropicTextResponse(resp *AnthropicTextResponse) { // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AnthropicProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), @@ -207,6 +207,8 @@ func (provider *AnthropicProvider) completeRequest(requestBody map[string]interf // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from anthropic provider: %s", string(resp.Body()))) + var errorResp AnthropicError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -280,6 +282,46 @@ func (provider *AnthropicProvider) TextCompletion(model, key, text string, param // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *AnthropicProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + formattedMessages, preparedParams := prepareAnthropicChatRequest(model, messages, params) + + // Merge additional parameters + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key) + if err != nil { + return nil, err + } + + // Create response object from pool + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + // Create Bifrost response from pool + bifrostResponse := acquireBifrostResponse() + defer releaseBifrostResponse(bifrostResponse) + + rawResponse, bifrostErr := handleProviderResponse(responseBody, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse) + if err != nil { + return nil, err + } + + bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: schemas.Anthropic, + RawResponse: rawResponse, + } + + return bifrostResponse, nil +} + +func prepareAnthropicChatRequest(model string, messages []schemas.Message, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { // Add system messages if present var systemMessages []BedrockAnthropicSystemMessage for _, msg := range messages { @@ -352,39 +394,19 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] preparedParams["tools"] = tools } - // Merge additional parameters - requestBody := mergeConfig(map[string]interface{}{ - "model": model, - "messages": formattedMessages, - }, preparedParams) - if len(systemMessages) > 0 { var messages []string for _, message := range systemMessages { messages = append(messages, message.Text) } - requestBody["system"] = strings.Join(messages, " ") - } - - responseBody, err := provider.completeRequest(requestBody, "https://api.anthropic.com/v1/messages", key) - if err != nil { - return nil, err + preparedParams["system"] = strings.Join(messages, " ") } - // Create response object from pool - response := acquireAnthropicChatResponse() - defer releaseAnthropicChatResponse(response) - - // Create Bifrost response from pool - bifrostResponse := acquireBifrostResponse() - defer releaseBifrostResponse(bifrostResponse) - - rawResponse, bifrostErr := handleProviderResponse(responseBody, response) - if bifrostErr != nil { - return nil, bifrostErr - } + return formattedMessages, preparedParams +} +func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *schemas.BifrostResponse) (*schemas.BifrostResponse, *schemas.BifrostError) { // Process the response into our BifrostResponse format var choices []schemas.BifrostResponseChoice @@ -437,10 +459,6 @@ func (provider *AnthropicProvider) ChatCompletion(model, key string, messages [] TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens, } bifrostResponse.Model = response.Model - bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ - Provider: schemas.Anthropic, - RawResponse: rawResponse, - } return bifrostResponse, nil } diff --git a/core/providers/azure.go b/core/providers/azure.go index 13e8e2ee1b..0788d0196f 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -103,8 +103,12 @@ type AzureProvider struct { // NewAzureProvider creates a new Azure provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. -func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) *AzureProvider { - setConfigDefaults(config) +func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*AzureProvider, error) { + config.CheckAndSetDefaults() + + if config.MetaConfig == nil { + return nil, fmt.Errorf("meta config is not set") + } client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), @@ -126,7 +130,7 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) *Az logger: logger, client: client, meta: config.MetaConfig, - } + }, nil } // GetProviderKey returns the provider identifier for Azure. @@ -212,6 +216,8 @@ func (provider *AzureProvider) completeRequest(requestBody map[string]interface{ // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from azure provider: %s", string(resp.Body()))) + var errorResp AzureError bifrostErr := handleProviderAPIError(resp, &errorResp) diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 3a9b35c39b..3eb6bc1206 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -159,8 +159,12 @@ func releaseBedrockChatResponse(resp *BedrockChatResponse) { // NewBedrockProvider creates a new Bedrock provider instance. // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts and AWS-specific settings. -func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) *BedrockProvider { - setConfigDefaults(config) +func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*BedrockProvider, error) { + config.CheckAndSetDefaults() + + if config.MetaConfig == nil { + return nil, fmt.Errorf("meta config is not set") + } client := &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)} @@ -174,7 +178,7 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) * logger: logger, client: client, meta: config.MetaConfig, - } + }, nil } // GetProviderKey returns the provider identifier for Bedrock. @@ -258,6 +262,7 @@ func (provider *BedrockProvider) completeRequest(requestBody map[string]interfac if err := json.Unmarshal(body, &errorResp); err != nil { return nil, &schemas.BifrostError{ IsBifrostError: true, + StatusCode: &resp.StatusCode, Error: schemas.ErrorField{ Message: schemas.ErrProviderResponseUnmarshal, Error: err, diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 0240af086b..c8e6f41b19 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -101,7 +101,7 @@ type CohereProvider struct { // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts and connection limits. func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *CohereProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), @@ -234,6 +234,8 @@ func (provider *CohereProvider) ChatCompletion(model, key string, messages []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from cohere provider: %s", string(resp.Body()))) + var errorResp CohereError bifrostErr := handleProviderAPIError(resp, &errorResp) diff --git a/core/providers/openai.go b/core/providers/openai.go index c4be162bcb..9dff479d98 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -3,6 +3,7 @@ package providers import ( + "fmt" "sync" "time" @@ -70,7 +71,7 @@ type OpenAIProvider struct { // It initializes the HTTP client with the provided configuration and sets up response pools. // The client is configured with timeouts, concurrency limits, and optional proxy settings. func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *OpenAIProvider { - setConfigDefaults(config) + config.CheckAndSetDefaults() client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), @@ -113,46 +114,7 @@ func (provider *OpenAIProvider) TextCompletion(model, key, text string, params * // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { - // Format messages for OpenAI API - var formattedMessages []map[string]interface{} - for _, msg := range messages { - if msg.ImageContent != nil { - var content []map[string]interface{} - - // Add text content if present - if msg.Content != nil { - content = append(content, map[string]interface{}{ - "type": "text", - "text": msg.Content, - }) - } - - imageContent := map[string]interface{}{ - "type": "image_url", - "image_url": map[string]interface{}{ - "url": msg.ImageContent.URL, - }, - } - - if msg.ImageContent.Detail != nil { - imageContent["image_url"].(map[string]interface{})["detail"] = msg.ImageContent.Detail - } - - content = append(content, imageContent) - - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": content, - }) - } else { - formattedMessages = append(formattedMessages, map[string]interface{}{ - "role": msg.Role, - "content": msg.Content, - }) - } - } - - preparedParams := prepareParams(params) + formattedMessages, preparedParams := prepareOpenAIChatRequest(model, messages, params) requestBody := mergeConfig(map[string]interface{}{ "model": model, @@ -195,6 +157,8 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch // Handle error response if resp.StatusCode() != fasthttp.StatusOK { + provider.logger.Debug(fmt.Sprintf("error from openai provider: %s", string(resp.Body()))) + var errorResp OpenAIError bifrostErr := handleProviderAPIError(resp, &errorResp) @@ -244,3 +208,48 @@ func (provider *OpenAIProvider) ChatCompletion(model, key string, messages []sch return result, nil } + +func prepareOpenAIChatRequest(model string, messages []schemas.Message, params *schemas.ModelParameters) ([]map[string]interface{}, map[string]interface{}) { + // Format messages for OpenAI API + var formattedMessages []map[string]interface{} + for _, msg := range messages { + if msg.ImageContent != nil { + var content []map[string]interface{} + + // Add text content if present + if msg.Content != nil { + content = append(content, map[string]interface{}{ + "type": "text", + "text": msg.Content, + }) + } + + imageContent := map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": msg.ImageContent.URL, + }, + } + + if msg.ImageContent.Detail != nil { + imageContent["image_url"].(map[string]interface{})["detail"] = msg.ImageContent.Detail + } + + content = append(content, imageContent) + + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": content, + }) + } else { + formattedMessages = append(formattedMessages, map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + }) + } + } + + preparedParams := prepareParams(params) + + return formattedMessages, preparedParams +} diff --git a/core/providers/utils.go b/core/providers/utils.go index 2988fac358..9f1482ee64 100644 --- a/core/providers/utils.go +++ b/core/providers/utils.go @@ -161,9 +161,12 @@ func configureProxy(client *fasthttp.Client, proxyConfig *schemas.ProxyConfig, l // It attempts to unmarshal the error response and returns a BifrostError // with the appropriate status code and error information. func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.BifrostError { + statusCode := resp.StatusCode() + if err := json.Unmarshal(resp.Body(), &errorResp); err != nil { return &schemas.BifrostError{ IsBifrostError: true, + StatusCode: &statusCode, Error: schemas.ErrorField{ Message: schemas.ErrProviderResponseUnmarshal, Error: err, @@ -171,8 +174,6 @@ func handleProviderAPIError(resp *fasthttp.Response, errorResp any) *schemas.Bif } } - statusCode := resp.StatusCode() - return &schemas.BifrostError{ IsBifrostError: false, StatusCode: &statusCode, @@ -229,32 +230,6 @@ func float64Ptr(f float64) *float64 { return &f } -func setConfigDefaults(config *schemas.ProviderConfig) { - if config.ConcurrencyAndBufferSize.Concurrency == 0 { - config.ConcurrencyAndBufferSize.Concurrency = schemas.DefaultConcurrency - } - - if config.ConcurrencyAndBufferSize.BufferSize == 0 { - config.ConcurrencyAndBufferSize.BufferSize = schemas.DefaultBufferSize - } - - if config.NetworkConfig.DefaultRequestTimeoutInSeconds == 0 { - config.NetworkConfig.DefaultRequestTimeoutInSeconds = schemas.DefaultRequestTimeoutInSeconds - } - - if config.NetworkConfig.MaxRetries == 0 { - config.NetworkConfig.MaxRetries = schemas.DefaultMaxRetries - } - - if config.NetworkConfig.RetryBackoffInitial == 0 { - config.NetworkConfig.RetryBackoffInitial = schemas.DefaultRetryBackoffInitial - } - - if config.NetworkConfig.RetryBackoffMax == 0 { - config.NetworkConfig.RetryBackoffMax = schemas.DefaultRetryBackoffMax - } -} - func StrPtr(s string) *string { return &s } diff --git a/core/providers/vertex.go b/core/providers/vertex.go new file mode 100644 index 0000000000..a8acbecf3e --- /dev/null +++ b/core/providers/vertex.go @@ -0,0 +1,294 @@ +// Package providers implements various LLM providers and their utility functions. +// This file contains the Vertex provider implementation. +package providers + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "os" + "strings" + + "github.com/goccy/go-json" + "golang.org/x/oauth2/google" + + schemas "github.com/maximhq/bifrost/core/schemas" +) + +type VertexError struct { + Error struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + } `json:"error"` +} + +// VertexProvider implements the Provider interface for Vertex's API. +type VertexProvider struct { + logger schemas.Logger // Logger for provider operations + client *http.Client // HTTP client for API requests + meta schemas.MetaConfig // Vertex-specific configuration +} + +// NewVertexProvider creates a new Vertex provider instance. +// It initializes the HTTP client with the provided configuration and sets up response pools. +// The client is configured with timeouts, concurrency limits, and optional proxy settings. +func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*VertexProvider, error) { + config.CheckAndSetDefaults() + + if config.MetaConfig == nil { + return nil, fmt.Errorf("meta config is not set") + } + + authCredentialPath := config.MetaConfig.GetAuthCredentialPath() + if authCredentialPath == nil { + return nil, fmt.Errorf("auth credential path is not set") + } + + data, err := os.ReadFile(*authCredentialPath) + if err != nil { + return nil, fmt.Errorf("failed to read auth credentials: %w", err) + } + + // Get a Google JWT Config for the correct scope + conf, err := google.JWTConfigFromJSON(data, "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("failed to create JWT config: %w", err) + } + + // Get an access token + client := conf.Client(context.Background()) + + // Pre-warm response pools + for range config.ConcurrencyAndBufferSize.Concurrency { + openAIResponsePool.Put(&OpenAIResponse{}) + anthropicChatResponsePool.Put(&AnthropicChatResponse{}) + bifrostResponsePool.Put(&schemas.BifrostResponse{}) + } + + return &VertexProvider{ + logger: logger, + client: client, + meta: config.MetaConfig, + }, nil +} + +// GetProviderKey returns the provider identifier for Vertex. +func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { + return schemas.Vertex +} + +// TextCompletion is not supported by the Vertex provider. +// Returns an error indicating that text completion is not available. +func (provider *VertexProvider) TextCompletion(model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "text completion is not supported by vertex provider", + }, + } +} + +// ChatCompletion performs a chat completion request to the Vertex API. +// It supports both text and image content in messages. +// Returns a BifrostResponse containing the completion results or an error if the request fails. +func (provider *VertexProvider) ChatCompletion(model, key string, messages []schemas.Message, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + // Format messages for Vertex API + var formattedMessages []map[string]interface{} + var preparedParams map[string]interface{} + + if strings.Contains(model, "claude") { + formattedMessages, preparedParams = prepareAnthropicChatRequest(model, messages, params) + } else { + formattedMessages, preparedParams = prepareOpenAIChatRequest(model, messages, params) + } + + requestBody := mergeConfig(map[string]interface{}{ + "model": model, + "messages": formattedMessages, + }, preparedParams) + + if strings.Contains(model, "claude") { + if _, exists := requestBody["anthropic_version"]; !exists { + requestBody["anthropic_version"] = "vertex-2023-10-16" + } + + delete(requestBody, "model") + } + + delete(requestBody, "region") + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderJSONMarshaling, + Error: err, + }, + } + } + + projectID := provider.meta.GetProjectID() + if projectID == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "project ID is not set", + }, + } + } + + region := params.Region + if region == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "region is not set in model params", + }, + } + } + + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", *region, *projectID, *region) + + if strings.Contains(model, "claude") { + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", *region, *projectID, *region, model) + } + + // Create request + req, err := http.NewRequest("POST", url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderRequest, + Error: err, + }, + } + } + req.Header.Set("Content-Type", "application/json") + + // Make request + resp, err := provider.client.Do(req) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "error creating request", + Error: err, + }, + } + } + defer resp.Body.Close() + + // Handle error response + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, &schemas.BifrostError{ + IsBifrostError: true, + Error: schemas.ErrorField{ + Message: "error reading request", + Error: err, + }, + } + } + + if resp.StatusCode != http.StatusOK { + var openAIErr OpenAIError + var vertexErr []VertexError + + provider.logger.Debug(fmt.Sprintf("error from vertex provider: %s", string(body))) + + if err := json.Unmarshal(body, &openAIErr); err != nil { + // Try Vertex error format if OpenAI format fails + if err := json.Unmarshal(body, &vertexErr); err != nil { + fmt.Printf("error unmarshalling vertex error: %s, body: %s", err, string(body)) + return nil, &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: &resp.StatusCode, + Error: schemas.ErrorField{ + Message: schemas.ErrProviderResponseUnmarshal, + Error: err, + }, + } + } + + if len(vertexErr) > 0 { + return nil, &schemas.BifrostError{ + StatusCode: &resp.StatusCode, + Type: &vertexErr[0].Error.Status, + Error: schemas.ErrorField{ + Message: vertexErr[0].Error.Message, + }, + } + } + } + + return nil, &schemas.BifrostError{ + StatusCode: &resp.StatusCode, + Error: schemas.ErrorField{ + Message: openAIErr.Error.Message, + }, + } + } + + if strings.Contains(model, "claude") { + // Create response object from pool + response := acquireAnthropicChatResponse() + defer releaseAnthropicChatResponse(response) + + // Create Bifrost response from pool + bifrostResponse := acquireBifrostResponse() + defer releaseBifrostResponse(bifrostResponse) + + rawResponse, bifrostErr := handleProviderResponse(body, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + var err *schemas.BifrostError + bifrostResponse, err = parseAnthropicResponse(response, bifrostResponse) + if err != nil { + return nil, err + } + + bifrostResponse.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: schemas.Vertex, + RawResponse: rawResponse, + } + + return bifrostResponse, nil + } else { + // Pre-allocate response structs from pools + response := acquireOpenAIResponse() + defer releaseOpenAIResponse(response) + + result := acquireBifrostResponse() + defer releaseBifrostResponse(result) + + // Use enhanced response handler with pre-allocated response + rawResponse, bifrostErr := handleProviderResponse(body, response) + if bifrostErr != nil { + return nil, bifrostErr + } + + // Populate result from response + result.ID = response.ID + result.Choices = response.Choices + result.Object = response.Object + result.Usage = response.Usage + result.ServiceTier = response.ServiceTier + result.SystemFingerprint = response.SystemFingerprint + result.Model = response.Model + result.Created = response.Created + result.ExtraFields = schemas.BifrostResponseExtraFields{ + Provider: schemas.Vertex, + RawResponse: rawResponse, + } + + return result, nil + } +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 5b95b5d796..821a608271 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -36,6 +36,7 @@ const ( Anthropic ModelProvider = "anthropic" Bedrock ModelProvider = "bedrock" Cohere ModelProvider = "cohere" + Vertex ModelProvider = "vertex" ) //* Request Structs @@ -81,7 +82,7 @@ type ModelParameters struct { PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Penalizes repeated tokens FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Penalizes frequent tokens ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` // Enables parallel tool calls - + Region *string `json:"region"` // Dynamic parameters that can be provider-specific, they are directly // added to the request as is. ExtraParams map[string]interface{} `json:"-"` diff --git a/core/schemas/meta/azure.go b/core/schemas/meta/azure.go index df5fd163b9..d1a98a1bae 100644 --- a/core/schemas/meta/azure.go +++ b/core/schemas/meta/azure.go @@ -54,3 +54,13 @@ func (c *AzureMetaConfig) GetDeployments() map[string]string { func (c *AzureMetaConfig) GetAPIVersion() *string { return c.APIVersion } + +// This is not used for Azure. +func (c *AzureMetaConfig) GetProjectID() *string { + return nil +} + +// This is not used for Azure. +func (c *AzureMetaConfig) GetAuthCredentialPath() *string { + return nil +} diff --git a/core/schemas/meta/bedrock.go b/core/schemas/meta/bedrock.go index 1a875d3f65..3b9fe9329e 100644 --- a/core/schemas/meta/bedrock.go +++ b/core/schemas/meta/bedrock.go @@ -57,3 +57,13 @@ func (c *BedrockMetaConfig) GetDeployments() map[string]string { func (c *BedrockMetaConfig) GetAPIVersion() *string { return nil } + +// This is not used for Bedrock. +func (c *BedrockMetaConfig) GetProjectID() *string { + return nil +} + +// This is not used for Bedrock. +func (c *BedrockMetaConfig) GetAuthCredentialPath() *string { + return nil +} diff --git a/core/schemas/meta/vertex.go b/core/schemas/meta/vertex.go new file mode 100644 index 0000000000..d43146ebc6 --- /dev/null +++ b/core/schemas/meta/vertex.go @@ -0,0 +1,63 @@ +// Package meta provides provider-specific configuration structures and schemas. +// This file contains the AWS Vertex-specific configuration implementation. + +package meta + +// VertexMetaConfig represents the Vertex-specific configuration. +// It contains Vertex-specific settings required for authentication and service access. +type VertexMetaConfig struct { + ProjectID string `json:"project_id,omitempty"` + AuthCredentialPath string `json:"auth_credential_path,omitempty"` +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetSecretAccessKey() *string { + return nil +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetRegion() *string { + return nil +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetSessionToken() *string { + return nil +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetARN() *string { + return nil +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetInferenceProfiles() map[string]string { + return nil +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetEndpoint() *string { + return nil +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetDeployments() map[string]string { + return nil +} + +// This is not used for Vertex. +func (c *VertexMetaConfig) GetAPIVersion() *string { + return nil +} + +// GetProjectID returns the Vertex project ID. +// This is the project ID for the Vertex project. +func (c *VertexMetaConfig) GetProjectID() *string { + return &c.ProjectID +} + +// GetAuthCredentialPath returns the path to the authentication credentials for the provider. +// This is the path to the authentication credentials for the google cloud api. +func (c *VertexMetaConfig) GetAuthCredentialPath() *string { + return &c.AuthCredentialPath +} diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 56376b730f..25631887d0 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -49,6 +49,10 @@ type MetaConfig interface { GetDeployments() map[string]string // GetAPIVersion returns the API version GetAPIVersion() *string + // GetProjectID returns the project ID + GetProjectID() *string + // GetAuthCredentialPath returns the path to the authentication credentials for the provider + GetAuthCredentialPath() *string } // ConcurrencyAndBufferSize represents configuration for concurrent operations and buffer sizes. @@ -91,6 +95,32 @@ type ProviderConfig struct { ProxyConfig *ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration } +func (config *ProviderConfig) CheckAndSetDefaults() { + if config.ConcurrencyAndBufferSize.Concurrency == 0 { + config.ConcurrencyAndBufferSize.Concurrency = DefaultConcurrency + } + + if config.ConcurrencyAndBufferSize.BufferSize == 0 { + config.ConcurrencyAndBufferSize.BufferSize = DefaultBufferSize + } + + if config.NetworkConfig.DefaultRequestTimeoutInSeconds == 0 { + config.NetworkConfig.DefaultRequestTimeoutInSeconds = DefaultRequestTimeoutInSeconds + } + + if config.NetworkConfig.MaxRetries == 0 { + config.NetworkConfig.MaxRetries = DefaultMaxRetries + } + + if config.NetworkConfig.RetryBackoffInitial == 0 { + config.NetworkConfig.RetryBackoffInitial = DefaultRetryBackoffInitial + } + + if config.NetworkConfig.RetryBackoffMax == 0 { + config.NetworkConfig.RetryBackoffMax = DefaultRetryBackoffMax + } +} + // Provider defines the interface for AI model providers. type Provider interface { // GetProviderKey returns the provider's identifier diff --git a/core/tests/account.go b/core/tests/account.go index b5a86a57be..6c5335206a 100644 --- a/core/tests/account.go +++ b/core/tests/account.go @@ -27,7 +27,7 @@ type BaseAccount struct{} // - []schemas.SupportedModelProvider: A slice containing the supported provider identifiers // - error: Always returns nil as this implementation doesn't produce errors func (baseAccount *BaseAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { - return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, schemas.Bedrock, schemas.Cohere, schemas.Azure}, nil + return []schemas.ModelProvider{schemas.OpenAI, schemas.Anthropic, schemas.Bedrock, schemas.Cohere, schemas.Azure, schemas.Vertex}, nil } // GetKeysForProvider returns the API keys and associated models for a given provider. @@ -197,6 +197,23 @@ func (baseAccount *BaseAccount) GetConfigForProvider(providerKey schemas.ModelPr BufferSize: 10, }, }, nil + case schemas.Vertex: + return &schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + MaxRetries: 1, + RetryBackoffInitial: 100 * time.Millisecond, + RetryBackoffMax: 2 * time.Second, + }, + MetaConfig: &meta.VertexMetaConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + AuthCredentialPath: os.Getenv("VERTEX_CREDENTIALS_PATH"), + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 3, + BufferSize: 10, + }, + }, nil default: return nil, fmt.Errorf("unsupported provider: %s", providerKey) } diff --git a/core/tests/vertex_test.go b/core/tests/vertex_test.go new file mode 100644 index 0000000000..bd947d81d0 --- /dev/null +++ b/core/tests/vertex_test.go @@ -0,0 +1,34 @@ +// Package tests provides test utilities and configurations for the Bifrost system. +// It includes test implementations of schemas, mock objects, and helper functions +// for testing the Bifrost functionality with various AI providers. +package tests + +import ( + "testing" + + bifrost "github.com/maximhq/bifrost/core" + schemas "github.com/maximhq/bifrost/core/schemas" +) + +func TestVertex(t *testing.T) { + bifrostClient, err := getBifrost() + if err != nil { + t.Fatalf("Error initializing bifrost: %v", err) + return + } + + config := TestConfig{ + Provider: schemas.Vertex, + ChatModel: "google/gemini-2.0-flash-001", + SetupText: false, // Vertex does not support text completion + SetupToolCalls: false, + SetupImage: false, + SetupBaseImage: false, + CustomParams: &schemas.ModelParameters{ + Region: bifrost.Ptr("us-central1"), + }, + } + + SetupAllRequests(bifrostClient, config) + bifrostClient.Cleanup() +}