diff --git a/core/bifrost.go b/core/bifrost.go index aa53cdd5cb..c0de1d3889 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -28,7 +28,7 @@ const ( ) // executor is a function type that handles specific request types. -type executor func(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) +type executor func(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) // messageExecutors is a factory map for handling different request types. var messageExecutors = map[RequestType]executor{ @@ -1092,7 +1092,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan var bifrostError *schemas.BifrostError var err error - key := "" + key := schemas.Key{} if providerRequiresKey(provider.GetProviderKey()) { key, err = bifrost.selectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) if err != nil { @@ -1207,7 +1207,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan } // handleTextCompletion executes a text completion request -func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) { +func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { if req.Input.TextCompletionInput == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1220,7 +1220,7 @@ func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key st } // handleChatCompletion executes a chat completion request -func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) { +func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { if req.Input.ChatCompletionInput == nil { 
return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1233,7 +1233,7 @@ func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key st } // handleEmbedding executes an embedding request -func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) { +func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { if req.Input.EmbeddingInput == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1246,7 +1246,7 @@ func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key string) } // handleChatCompletionStream executes a chat completion stream request -func handleChatCompletionStream(provider schemas.Provider, req *ChannelMessage, key string, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func handleChatCompletionStream(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStream, *schemas.BifrostError) { if req.Input.ChatCompletionInput == nil { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1384,30 +1384,30 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { // selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. // It uses weighted random selection if multiple keys are available. 
-func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) { +func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (schemas.Key, error) { keys, err := bifrost.account.GetKeysForProvider(providerKey) if err != nil { - return "", err + return schemas.Key{}, err } if len(keys) == 0 { - return "", fmt.Errorf("no keys found for provider: %v", providerKey) + return schemas.Key{}, fmt.Errorf("no keys found for provider: %v", providerKey) } // filter out keys which dont support the model, if the key has no models, it is supported for all models var supportedKeys []schemas.Key for _, key := range keys { - if (slices.Contains(key.Models, model) && strings.TrimSpace(key.Value) != "") || len(key.Models) == 0 { + if (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || providerKey == schemas.Vertex)) || len(key.Models) == 0 { supportedKeys = append(supportedKeys, key) } } if len(supportedKeys) == 0 { - return "", fmt.Errorf("no keys found that support model: %s", model) + return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model) } if len(supportedKeys) == 1 { - return supportedKeys[0].Value, nil + return supportedKeys[0], nil } // Use a weighted random selection based on key weights @@ -1425,12 +1425,12 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelP for _, key := range supportedKeys { currentWeight += int(key.Weight * 100) if randomValue < currentWeight { - return key.Value, nil + return key, nil } } // Fallback to first key if something goes wrong - return supportedKeys[0].Value, nil + return supportedKeys[0], nil } // CLEANUP diff --git a/core/providers/anthropic.go b/core/providers/anthropic.go index 02ab13bd87..d4b9f25901 100644 --- a/core/providers/anthropic.go +++ b/core/providers/anthropic.go @@ -318,7 +318,7 @@ func (provider *AnthropicProvider) completeRequest(ctx 
context.Context, requestB // TextCompletion performs a text completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := provider.prepareTextCompletionParams(prepareParams(params)) // Merge additional parameters @@ -327,7 +327,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke "prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text), }, preparedParams) - responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key) + responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value) if err != nil { return nil, err } @@ -379,7 +379,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke // ChatCompletion performs a chat completion request to Anthropic's API. // It formats the request, sends it to Anthropic, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) // Merge additional parameters @@ -388,7 +388,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke "messages": formattedMessages, }, preparedParams) - responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key) + responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value) if err != nil { return nil, err } @@ -775,14 +775,14 @@ func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *sc } // Embedding is not supported by the Anthropic provider. -func (provider *AnthropicProvider) Embedding(ctx context.Context, model, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AnthropicProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "anthropic") } // ChatCompletionStream performs a streaming chat completion request to the Anthropic API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) // Merge additional parameters and set stream to true @@ -795,7 +795,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, pos // Prepare Anthropic headers headers := map[string]string{ "Content-Type": "application/json", - "x-api-key": key, + "x-api-key": key.Value, "anthropic-version": provider.apiVersion, "Accept": "text/event-stream", "Cache-Control": "no-cache", diff --git a/core/providers/azure.go b/core/providers/azure.go index 8637ccf05e..6ef52b3593 100644 --- a/core/providers/azure.go +++ b/core/providers/azure.go @@ -117,7 +117,6 @@ type AzureProvider struct { logger schemas.Logger // Logger for provider operations client *fasthttp.Client // HTTP client for API requests streamClient *http.Client // HTTP client for streaming requests - meta schemas.MetaConfig // Azure-specific configuration networkConfig schemas.NetworkConfig // Network configuration including extra headers } @@ -127,10 +126,6 @@ type AzureProvider struct { func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*AzureProvider, error) { config.CheckAndSetDefaults() - if config.MetaConfig == nil { - return nil, fmt.Errorf("meta config is not set") - } - client := &fasthttp.Client{ ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), WriteTimeout: time.Second * 
time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds), @@ -156,7 +151,6 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A logger: logger, client: client, streamClient: streamClient, - meta: config.MetaConfig, networkConfig: config.NetworkConfig, }, nil } @@ -169,7 +163,16 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider { // completeRequest sends a request to Azure's API and handles the response. // It constructs the API URL, sets up authentication, and processes the response. // Returns the response body or an error if the request fails. -func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) { +func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key, model string) ([]byte, *schemas.BifrostError) { + if key.AzureKeyConfig == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "azure key config not set", + }, + } + } + // Marshal the request body jsonData, err := json.Marshal(requestBody) if err != nil { @@ -182,7 +185,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody } } - if provider.meta.GetEndpoint() == nil { + if key.AzureKeyConfig.Endpoint == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -191,10 +194,10 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody } } - url := *provider.meta.GetEndpoint() + url := key.AzureKeyConfig.Endpoint - if provider.meta.GetDeployments() != nil { - deployment := provider.meta.GetDeployments()[model] + if key.AzureKeyConfig.Deployments != nil { + deployment := key.AzureKeyConfig.Deployments[model] if deployment == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -204,7 +207,7 @@ func 
(provider *AzureProvider) completeRequest(ctx context.Context, requestBody } } - apiVersion := provider.meta.GetAPIVersion() + apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = StrPtr("2024-02-01") } @@ -236,7 +239,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody // Ensure api-key is not accidentally present (from extra headers, etc.) req.Header.Del("api-key") } else { - req.Header.Set("api-key", key) + req.Header.Set("api-key", key.Value) } req.SetBody(jsonData) @@ -269,7 +272,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody // TextCompletion performs a text completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := prepareParams(params) // Merge additional parameters @@ -337,7 +340,7 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, t // ChatCompletion performs a chat completion request to Azure's API. // It formats the request, sends it to Azure, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) // Merge additional parameters @@ -384,7 +387,7 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key st // Embedding generates embeddings for the given input text(s) using Azure OpenAI. // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. -func (provider *AzureProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *AzureProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if len(input.Texts) == 0 { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -493,9 +496,18 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) + if key.AzureKeyConfig == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "azure key config not set", + }, + } + } + // Merge additional parameters and set stream to true requestBody := mergeConfig(map[string]interface{}{ "model": model, @@ -504,7 +516,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo }, preparedParams) // Construct Azure-specific URL with deployment - if provider.meta.GetEndpoint() == nil { + if key.AzureKeyConfig.Endpoint == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -513,11 +525,11 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo } } - baseURL := *provider.meta.GetEndpoint() + baseURL := key.AzureKeyConfig.Endpoint var fullURL string - if provider.meta.GetDeployments() != nil { - deployment := provider.meta.GetDeployments()[model] + if key.AzureKeyConfig.Deployments != nil { + deployment := key.AzureKeyConfig.Deployments[model] if deployment == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -527,7 +539,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo } } - apiVersion := provider.meta.GetAPIVersion() + apiVersion := key.AzureKeyConfig.APIVersion if apiVersion == nil { apiVersion = StrPtr("2024-02-01") } @@ -552,7 +564,7 @@ 
func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok { headers["Authorization"] = fmt.Sprintf("Bearer %s", authToken) } else { - headers["api-key"] = key + headers["api-key"] = key.Value } // Use shared streaming logic from OpenAI diff --git a/core/providers/bedrock.go b/core/providers/bedrock.go index 1727ef8b4d..b60a0b456f 100644 --- a/core/providers/bedrock.go +++ b/core/providers/bedrock.go @@ -829,14 +829,14 @@ func (provider *BedrockProvider) prepareTextCompletionParams(params map[string]i // TextCompletion performs a text completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *BedrockProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { preparedParams := provider.prepareTextCompletionParams(prepareParams(params), model) requestBody := mergeConfig(map[string]interface{}{ "prompt": text, }, preparedParams) - body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), key) + body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), key.Value) if err != nil { return nil, err } @@ -920,7 +920,7 @@ func (provider *BedrockProvider) extractToolsFromHistory(messages []schemas.Bifr // ChatCompletion performs a chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. 
-func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { messageBody, err := provider.prepareChatCompletionMessages(messages, model) if err != nil { return nil, err @@ -962,7 +962,7 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model, key } // Create the signed request - responseBody, err := provider.completeRequest(ctx, requestBody, path, key) + responseBody, err := provider.completeRequest(ctx, requestBody, path, key.Value) if err != nil { return nil, err } @@ -1150,12 +1150,12 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken // Embedding generates embeddings for the given input text(s) using Amazon Bedrock. // Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
-func (provider *BedrockProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *BedrockProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { switch { case strings.HasPrefix(model, "amazon.titan-embed-text"): - return provider.handleTitanEmbedding(ctx, model, key, input, params) + return provider.handleTitanEmbedding(ctx, model, key.Value, input, params) case strings.HasPrefix(model, "cohere.embed"): - return provider.handleCohereEmbedding(ctx, model, key, input, params) + return provider.handleCohereEmbedding(ctx, model, key.Value, input, params) default: return nil, &schemas.BifrostError{ IsBifrostError: false, @@ -1308,7 +1308,7 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode // ChatCompletionStream performs a streaming chat completion request to Bedrock's API. // It formats the request, sends it to Bedrock, and processes the streaming response. // Returns a channel for streaming BifrostResponse objects or an error if the request fails. 
-func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { messageBody, err := provider.prepareChatCompletionMessages(messages, model) if err != nil { return nil, err @@ -1392,7 +1392,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH // Sign the request for AWS if provider.meta.GetSecretAccessKey() != nil { - if signErr := signAWSRequest(req, key, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); signErr != nil { + if signErr := signAWSRequest(req, key.Value, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); signErr != nil { return nil, signErr } } else { diff --git a/core/providers/cohere.go b/core/providers/cohere.go index 4be00be190..15f2491385 100644 --- a/core/providers/cohere.go +++ b/core/providers/cohere.go @@ -183,14 +183,14 @@ func (provider *CohereProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Cohere provider. // Returns an error indicating that text completion is not supported. 
-func (provider *CohereProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "cohere") } // ChatCompletion performs a chat completion request to the Cohere API. // It formats the request, sends it to Cohere, and processes the response. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Prepare request body using shared function requestBody, err := prepareCohereChatRequest(messages, params, model, false) if err != nil { @@ -227,7 +227,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) req.SetBody(jsonBody) @@ -590,7 +590,7 @@ func convertChatHistory(history []struct { // Embedding generates embeddings for the given input text(s) using the Cohere API. // Supports Cohere's embedding models and returns a BifrostResponse containing the embedding(s). 
-func (provider *CohereProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *CohereProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if len(input.Texts) == 0 { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -654,7 +654,7 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key req.SetRequestURI(provider.networkConfig.BaseURL + "/v2/embed") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) req.SetBody(jsonBody) @@ -727,7 +727,7 @@ func (provider *CohereProvider) Embedding(ctx context.Context, model string, key // ChatCompletionStream performs a streaming chat completion request to the Cohere API. // It supports real-time streaming of responses using Server-Sent Events (SSE). // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { // Prepare request body using shared function requestBody, err := prepareCohereChatRequest(messages, params, model, true) if err != nil { @@ -765,7 +765,7 @@ func (provider *CohereProvider) ChatCompletionStream(ctx context.Context, postHo // Set headers req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) req.Header.Set("Accept", "text/event-stream") req.Header.Set("Cache-Control", "no-cache") diff --git a/core/providers/groq.go b/core/providers/groq.go index 04a28a6e85..b898d07e86 100644 --- a/core/providers/groq.go +++ b/core/providers/groq.go @@ -100,12 +100,12 @@ func (provider *GroqProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the Groq provider. -func (provider *GroqProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GroqProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "groq") } // ChatCompletion performs a chat completion request to the Groq API. 
-func (provider *GroqProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GroqProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -136,9 +136,7 @@ func (provider *GroqProvider) ChatCompletion(ctx context.Context, model, key str req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - if key != "" { - req.Header.Set("Authorization", "Bearer "+key) - } + req.Header.Set("Authorization", "Bearer "+key.Value) req.SetBody(jsonBody) @@ -192,7 +190,7 @@ func (provider *GroqProvider) ChatCompletion(ctx context.Context, model, key str } // Embedding is not supported by the Groq provider. -func (provider *GroqProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *GroqProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "groq") } @@ -200,7 +198,7 @@ func (provider *GroqProvider) Embedding(ctx context.Context, model string, key s // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Groq's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -216,10 +214,7 @@ func (provider *GroqProvider) ChatCompletionStream(ctx context.Context, postHook "Cache-Control": "no-cache", } - // Only add Authorization header if key is provided (Groq can run without auth) - if key != "" { - headers["Authorization"] = "Bearer " + key - } + headers["Authorization"] = "Bearer " + key.Value // Use shared OpenAI-compatible streaming logic return handleOpenAIStreaming( diff --git a/core/providers/mistral.go b/core/providers/mistral.go index 27377ffe37..1fd5fadf4d 100644 --- a/core/providers/mistral.go +++ b/core/providers/mistral.go @@ -114,12 +114,12 @@ func (provider *MistralProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the Mistral provider. -func (provider *MistralProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *MistralProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "mistral") } // ChatCompletion performs a chat completion request to the Mistral API. 
-func (provider *MistralProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *MistralProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -150,7 +150,7 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, model, key req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) req.SetBody(jsonBody) @@ -205,7 +205,7 @@ func (provider *MistralProvider) ChatCompletion(ctx context.Context, model, key // Embedding generates embeddings for the given input text(s) using the Mistral API. // Supports Mistral's embedding models and returns a BifrostResponse containing the embedding(s). 
-func (provider *MistralProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *MistralProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { if len(input.Texts) == 0 { return nil, &schemas.BifrostError{ IsBifrostError: true, @@ -271,7 +271,7 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/embeddings") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) req.SetBody(jsonBody) @@ -349,7 +349,7 @@ func (provider *MistralProvider) Embedding(ctx context.Context, model string, ke // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Mistral's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -361,7 +361,7 @@ func (provider *MistralProvider) ChatCompletionStream(ctx context.Context, postH // Prepare Mistral headers headers := map[string]string{ "Content-Type": "application/json", - "Authorization": "Bearer " + key, + "Authorization": "Bearer " + key.Value, "Accept": "text/event-stream", "Cache-Control": "no-cache", } diff --git a/core/providers/ollama.go b/core/providers/ollama.go index 4412459df6..f60c261f97 100644 --- a/core/providers/ollama.go +++ b/core/providers/ollama.go @@ -101,12 +101,12 @@ func (provider *OllamaProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the Ollama provider. -func (provider *OllamaProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "ollama") } // ChatCompletion performs a chat completion request to the Ollama API. 
-func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -137,8 +137,8 @@ func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model, key s req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - if key != "" { - req.Header.Set("Authorization", "Bearer "+key) + if key.Value != "" { + req.Header.Set("Authorization", "Bearer "+key.Value) } req.SetBody(jsonBody) @@ -193,7 +193,7 @@ func (provider *OllamaProvider) ChatCompletion(ctx context.Context, model, key s } // Embedding is not supported by the Ollama provider. -func (provider *OllamaProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OllamaProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "ollama") } @@ -201,7 +201,7 @@ func (provider *OllamaProvider) Embedding(ctx context.Context, model string, key // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses Ollama's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. 
-func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -218,8 +218,8 @@ func (provider *OllamaProvider) ChatCompletionStream(ctx context.Context, postHo } // Only add Authorization header if key is provided (Ollama can run without auth) - if key != "" { - headers["Authorization"] = "Bearer " + key + if key.Value != "" { + headers["Authorization"] = "Bearer " + key.Value } // Use shared OpenAI-compatible streaming logic diff --git a/core/providers/openai.go b/core/providers/openai.go index 91451aec3f..b6814a5422 100644 --- a/core/providers/openai.go +++ b/core/providers/openai.go @@ -127,14 +127,14 @@ func (provider *OpenAIProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the OpenAI provider. // Returns an error indicating that text completion is not available. -func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "openai") } // ChatCompletion performs a chat completion request to the OpenAI API. // It supports both text and image content in messages. 
// Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -165,7 +165,7 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model, key s req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) req.SetBody(jsonBody) @@ -280,7 +280,7 @@ func prepareOpenAIChatRequest(messages []schemas.BifrostMessage, params *schemas // Embedding generates embeddings for the given input text(s). // The input can be either a single string or a slice of strings for batch embedding. // Returns a BifrostResponse containing the embedding(s) and any error that occurred. 
-func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { // Validate input texts are not empty if len(input.Texts) == 0 { return nil, &schemas.BifrostError{ @@ -337,7 +337,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/embeddings") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - req.Header.Set("Authorization", "Bearer "+key) + req.Header.Set("Authorization", "Bearer "+key.Value) req.SetBody(jsonBody) @@ -466,7 +466,7 @@ func (provider *OpenAIProvider) Embedding(ctx context.Context, model string, key return bifrostResponse, nil } -func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -478,7 +478,7 @@ func (provider *OpenAIProvider) ChatCompletionStream(ctx context.Context, postHo // Prepare OpenAI headers headers := map[string]string{ "Content-Type": "application/json", - "Authorization": "Bearer " + key, + "Authorization": "Bearer " + key.Value, "Accept": "text/event-stream", "Cache-Control": 
"no-cache", } diff --git a/core/providers/sgl.go b/core/providers/sgl.go index 74479cce8a..846d0b8584 100644 --- a/core/providers/sgl.go +++ b/core/providers/sgl.go @@ -101,12 +101,12 @@ func (provider *SGLProvider) GetProviderKey() schemas.ModelProvider { } // TextCompletion is not supported by the SGL provider. -func (provider *SGLProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *SGLProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "sgl") } // ChatCompletion performs a chat completion request to the SGL API. -func (provider *SGLProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *SGLProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -137,8 +137,8 @@ func (provider *SGLProvider) ChatCompletion(ctx context.Context, model, key stri req.SetRequestURI(provider.networkConfig.BaseURL + "/v1/chat/completions") req.Header.SetMethod("POST") req.Header.SetContentType("application/json") - if key != "" { - req.Header.Set("Authorization", "Bearer "+key) + if key.Value != "" { + req.Header.Set("Authorization", "Bearer "+key.Value) } req.SetBody(jsonBody) @@ -193,7 +193,7 @@ func (provider *SGLProvider) ChatCompletion(ctx context.Context, model, key stri } // Embedding is not supported by the SGL provider. 
-func (provider *SGLProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *SGLProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "sgl") } @@ -201,7 +201,7 @@ func (provider *SGLProvider) Embedding(ctx context.Context, model string, key st // It supports real-time streaming of responses using Server-Sent Events (SSE). // Uses SGL's OpenAI-compatible streaming format. // Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails. -func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { +func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params) requestBody := mergeConfig(map[string]interface{}{ @@ -218,8 +218,8 @@ func (provider *SGLProvider) ChatCompletionStream(ctx context.Context, postHookR } // Only add Authorization header if key is provided (SGL can run without auth) - if key != "" { - headers["Authorization"] = "Bearer " + key + if key.Value != "" { + headers["Authorization"] = "Bearer " + key.Value } // Use shared OpenAI-compatible streaming logic diff --git a/core/providers/vertex.go b/core/providers/vertex.go index 71c5c937cc..7528773c09 100644 --- a/core/providers/vertex.go +++ b/core/providers/vertex.go 
@@ -5,11 +5,14 @@ package providers import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "io" "net/http" "strings" + "sync" "github.com/goccy/go-json" "golang.org/x/oauth2/google" @@ -25,11 +28,32 @@ type VertexError struct { } `json:"error"` } +// vertexClientPool provides a pool/cache for authenticated Vertex HTTP clients. +// This avoids creating and authenticating clients for every request. +// Uses sync.Map for atomic operations without explicit locking. +var vertexClientPool sync.Map + +// getClientKey generates a unique key for caching authenticated clients. +// It uses a hash of the auth credentials for security. +func getClientKey(authCredentials string) string { + hash := sha256.Sum256([]byte(authCredentials)) + return hex.EncodeToString(hash[:]) +} + +// removeVertexClient removes a specific client from the pool. +// This should be called when: +// - API returns authentication/authorization errors (401, 403) +// - Auth client creation fails +// - Network errors that might indicate credential issues +// This ensures we don't keep using potentially invalid clients. +func removeVertexClient(authCredentials string) { + clientKey := getClientKey(authCredentials) + vertexClientPool.Delete(clientKey) +} + // VertexProvider implements the Provider interface for Google's Vertex AI API. 
type VertexProvider struct { logger schemas.Logger // Logger for provider operations - client *http.Client // HTTP client for API requests - meta schemas.MetaConfig // Vertex-specific configuration networkConfig schemas.NetworkConfig // Network configuration including extra headers } @@ -39,24 +63,6 @@ type VertexProvider struct { func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*VertexProvider, error) { config.CheckAndSetDefaults() - if config.MetaConfig == nil { - return nil, fmt.Errorf("meta config is not set") - } - - authCredentials := config.MetaConfig.GetAuthCredentials() - if authCredentials == nil { - return nil, fmt.Errorf("auth credentials are not set") - } - - // Get a Google JWT Config for the correct scope - conf, err := google.JWTConfigFromJSON([]byte(*authCredentials), "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("failed to create JWT config: %w", err) - } - - // Get an access token - client := conf.Client(context.Background()) - // Pre-warm response pools for range config.ConcurrencyAndBufferSize.Concurrency { openAIResponsePool.Put(&OpenAIResponse{}) @@ -66,12 +72,49 @@ func NewVertexProvider(config *schemas.ProviderConfig, logger schemas.Logger) (* return &VertexProvider{ logger: logger, - client: client, - meta: config.MetaConfig, networkConfig: config.NetworkConfig, }, nil } +// getAuthClient returns an authenticated HTTP client for Vertex AI API requests. 
+// This function implements client pooling to avoid creating and authenticating +// clients for every request, which significantly improves performance by: +// - Avoiding repeated JWT config creation +// - Reusing OAuth2 token refresh logic +// - Reducing authentication overhead +func getAuthClient(key schemas.Key) (*http.Client, error) { + if key.VertexKeyConfig == nil { + return nil, fmt.Errorf("vertex key config is not set") + } + + authCredentials := key.VertexKeyConfig.AuthCredentials + + if authCredentials == "" { + return nil, fmt.Errorf("auth credentials are not set") + } + + // Generate cache key from credentials + clientKey := getClientKey(authCredentials) + + // Try to get existing client from pool + if value, exists := vertexClientPool.Load(clientKey); exists { + return value.(*http.Client), nil + } + + // Create new authenticated client + conf, err := google.JWTConfigFromJSON([]byte(authCredentials), "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("failed to create JWT config: %w", err) + } + + client := conf.Client(context.Background()) + + // Store the client using LoadOrStore to handle race conditions + // If another goroutine stored a client while we were creating ours, use theirs + actual, _ := vertexClientPool.LoadOrStore(clientKey, client) + return actual.(*http.Client), nil +} + // GetProviderKey returns the provider identifier for Vertex. func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { return schemas.Vertex @@ -79,14 +122,23 @@ func (provider *VertexProvider) GetProviderKey() schemas.ModelProvider { // TextCompletion is not supported by the Vertex provider. // Returns an error indicating that text completion is not available. 
-func (provider *VertexProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("text completion", "vertex") } // ChatCompletion performs a chat completion request to the Vertex API. // It supports both text and image content in messages. // Returns a BifrostResponse containing the completion results or an error if the request fails. -func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { + if key.VertexKeyConfig == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "vertex key config is not set", + }, + } + } + // Format messages for Vertex API var formattedMessages []map[string]interface{} var preparedParams map[string]interface{} @@ -123,8 +175,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s } } - projectID := provider.meta.GetProjectID() - if projectID == nil { + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -133,8 +185,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s } } - region := provider.meta.GetRegion() - if region == nil { + region := key.VertexKeyConfig.Region + if region == "" { return nil, &schemas.BifrostError{ 
IsBifrostError: false, Error: schemas.ErrorField{ @@ -143,10 +195,10 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s } } - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", *region, *projectID, *region) + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) if strings.Contains(model, "claude") { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", *region, *projectID, *region, model) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, model) } // Create request @@ -166,8 +218,21 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s req.Header.Set("Content-Type", "application/json") + client, err := getAuthClient(key) + if err != nil { + // Remove client from pool if auth client creation fails + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "error creating auth client", + Error: err, + }, + } + } + // Make request - resp, err := provider.client.Do(req) + resp, err := client.Do(req) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, &schemas.BifrostError{ @@ -179,6 +244,8 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s }, } } + // Remove client from pool for non-context errors (could be auth/network issues) + removeVertexClient(key.VertexKeyConfig.AuthCredentials) return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -203,6 +270,11 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, 
model, key s } if resp.StatusCode != http.StatusOK { + // Remove client from pool for authentication/authorization errors + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + } + var openAIErr OpenAIError var vertexErr []VertexError @@ -304,16 +376,25 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model, key s } // Embedding is not supported by the Vertex provider. -func (provider *VertexProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { +func (provider *VertexProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) { return nil, newUnsupportedOperationError("embedding", "vertex") } // ChatCompletionStream performs a streaming chat completion request to the Vertex API. // It supports both OpenAI-style streaming (for non-Claude models) and Anthropic-style streaming (for Claude models). // Returns a channel of BifrostResponse objects for streaming results or an error if the request fails. 
-func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { - projectID := provider.meta.GetProjectID() - if projectID == nil { +func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) { + if key.VertexKeyConfig == nil { + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "vertex key config is not set", + }, + } + } + + projectID := key.VertexKeyConfig.ProjectID + if projectID == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -322,8 +403,8 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo } } - region := provider.meta.GetRegion() - if region == nil { + region := key.VertexKeyConfig.Region + if region == "" { return nil, &schemas.BifrostError{ IsBifrostError: false, Error: schemas.ErrorField{ @@ -332,6 +413,19 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo } } + client, err := getAuthClient(key) + if err != nil { + // Remove client from pool if auth client creation fails + removeVertexClient(key.VertexKeyConfig.AuthCredentials) + return nil, &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: "error creating auth client", + Error: err, + }, + } + } + if strings.Contains(model, "claude") { // Use Anthropic-style streaming for Claude models formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params) @@ -348,7 +442,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo delete(requestBody, "model") delete(requestBody, 
"region") - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", *region, *projectID, *region, model) + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:streamRawPredict", region, projectID, region, model) // Prepare headers for Vertex Anthropic headers := map[string]string{ @@ -360,7 +454,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo // Use shared Anthropic streaming logic return handleAnthropicStreaming( ctx, - provider.client, + client, url, requestBody, headers, @@ -382,7 +476,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo delete(requestBody, "region") - url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", *region, *projectID, *region) + url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region) // Prepare headers for Vertex OpenAI-compatible headers := map[string]string{ @@ -394,7 +488,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo // Use shared OpenAI streaming logic return handleOpenAIStreaming( ctx, - provider.client, + client, url, requestBody, headers, diff --git a/core/schemas/account.go b/core/schemas/account.go index 7800c2dd3c..9ab2431017 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -4,9 +4,28 @@ package schemas // Key represents an API key and its associated configuration for a provider. // It contains the key value, supported models, and a weight for load balancing. 
type Key struct { - Value string `json:"value"` // The actual API key value - Models []string `json:"models"` // List of models this key can access - Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + ID string `json:"id"` // The unique identifier for the key (not used by bifrost, but can be used by users to identify the key) + Value string `json:"value"` // The actual API key value + Models []string `json:"models"` // List of models this key can access + Weight float64 `json:"weight"` // Weight for load balancing between multiple keys + AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration + VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration +} + +// AzureKeyConfig represents the Azure-specific configuration. +// It contains Azure-specific settings required for service access and deployment management. +type AzureKeyConfig struct { + Endpoint string `json:"endpoint"` // Azure service endpoint URL + Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names + APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-02-01" +} + +// VertexKeyConfig represents the Vertex-specific configuration. +// It contains Vertex-specific settings required for authentication and service access. +type VertexKeyConfig struct { + ProjectID string `json:"project_id,omitempty"` + Region string `json:"region,omitempty"` + AuthCredentials string `json:"auth_credentials,omitempty"` } // Account defines the interface for managing provider accounts and their configurations. diff --git a/core/schemas/meta/azure.go b/core/schemas/meta/azure.go deleted file mode 100644 index 58abbd071c..0000000000 --- a/core/schemas/meta/azure.go +++ /dev/null @@ -1,40 +0,0 @@ -// Package meta provides provider-specific configuration structures and schemas. 
-// This file contains the Azure-specific configuration implementation. - -package meta - -// AzureMetaConfig represents the Azure-specific configuration. -// It contains Azure-specific settings required for service access and deployment management. -type AzureMetaConfig struct { - Endpoint string `json:"endpoint"` // Azure service endpoint URL - Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model names to deployment names - APIVersion *string `json:"api_version,omitempty"` // Azure API version to use; defaults to "2024-02-01" -} - -// GetEndpoint returns the Azure service endpoint. -// This specifies the base URL for Azure API requests. -func (c *AzureMetaConfig) GetEndpoint() *string { - return &c.Endpoint -} - -// GetDeployments returns the deployment configurations. -// This maps model names to their corresponding Azure deployment names. -// E.g. "gpt-4o": "your-deployment-name-for-gpt-4o" -func (c *AzureMetaConfig) GetDeployments() map[string]string { - return c.Deployments -} - -// GetAPIVersion returns the Azure API version. -// This specifies which version of the Azure API to use. -func (c *AzureMetaConfig) GetAPIVersion() *string { - return c.APIVersion -} - -// These are not used for Azure. 
-func (c *AzureMetaConfig) GetARN() *string { return nil } -func (c *AzureMetaConfig) GetAuthCredentials() *string { return nil } -func (c *AzureMetaConfig) GetInferenceProfiles() map[string]string { return nil } -func (c *AzureMetaConfig) GetProjectID() *string { return nil } -func (c *AzureMetaConfig) GetRegion() *string { return nil } -func (c *AzureMetaConfig) GetSecretAccessKey() *string { return nil } -func (c *AzureMetaConfig) GetSessionToken() *string { return nil } diff --git a/core/schemas/meta/bedrock.go b/core/schemas/meta/bedrock.go index bdff19e76a..fe3561c11b 100644 --- a/core/schemas/meta/bedrock.go +++ b/core/schemas/meta/bedrock.go @@ -42,10 +42,3 @@ func (c *BedrockMetaConfig) GetARN() *string { func (c *BedrockMetaConfig) GetInferenceProfiles() map[string]string { return c.InferenceProfiles } - -// These are not used for Bedrock. -func (c *BedrockMetaConfig) GetAPIVersion() *string { return nil } -func (c *BedrockMetaConfig) GetAuthCredentials() *string { return nil } -func (c *BedrockMetaConfig) GetDeployments() map[string]string { return nil } -func (c *BedrockMetaConfig) GetEndpoint() *string { return nil } -func (c *BedrockMetaConfig) GetProjectID() *string { return nil } diff --git a/core/schemas/meta/vertex.go b/core/schemas/meta/vertex.go deleted file mode 100644 index a82e46380d..0000000000 --- a/core/schemas/meta/vertex.go +++ /dev/null @@ -1,39 +0,0 @@ -// Package meta provides provider-specific configuration structures and schemas. -// This file contains the AWS Vertex-specific configuration implementation. - -package meta - -// VertexMetaConfig represents the Vertex-specific configuration. -// It contains Vertex-specific settings required for authentication and service access. -type VertexMetaConfig struct { - ProjectID string `json:"project_id,omitempty"` - Region string `json:"region,omitempty"` - AuthCredentials string `json:"auth_credentials,omitempty"` -} - -// GetRegion returns the Vertex region. 
-// This is the region for the Vertex project. -func (c *VertexMetaConfig) GetRegion() *string { - return &c.Region -} - -// GetProjectID returns the Vertex project ID. -// This is the project ID for the Vertex project. -func (c *VertexMetaConfig) GetProjectID() *string { - return &c.ProjectID -} - -// GetAuthCredentials returns the authentication credentials for the provider. -// This is the authentication credentials for the google cloud api. -func (c *VertexMetaConfig) GetAuthCredentials() *string { - return &c.AuthCredentials -} - -// These are not used for Vertex. -func (c *VertexMetaConfig) GetAPIVersion() *string { return nil } -func (c *VertexMetaConfig) GetARN() *string { return nil } -func (c *VertexMetaConfig) GetDeployments() map[string]string { return nil } -func (c *VertexMetaConfig) GetEndpoint() *string { return nil } -func (c *VertexMetaConfig) GetInferenceProfiles() map[string]string { return nil } -func (c *VertexMetaConfig) GetSecretAccessKey() *string { return nil } -func (c *VertexMetaConfig) GetSessionToken() *string { return nil } diff --git a/core/schemas/provider.go b/core/schemas/provider.go index 8b13b3c259..f2df45e17c 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -60,16 +60,6 @@ type MetaConfig interface { GetARN() *string // GetInferenceProfiles returns the inference profiles GetInferenceProfiles() map[string]string - // GetEndpoint returns the provider endpoint - GetEndpoint() *string - // GetDeployments returns the deployment configurations - GetDeployments() map[string]string - // GetAPIVersion returns the API version - GetAPIVersion() *string - // GetProjectID returns the project ID - GetProjectID() *string - // GetAuthCredentials returns the authentication credentials for the provider - GetAuthCredentials() *string } // ConcurrencyAndBufferSize represents configuration for concurrent operations and buffer sizes. 
@@ -158,11 +148,11 @@ type Provider interface { // GetProviderKey returns the provider's identifier GetProviderKey() ModelProvider // TextCompletion performs a text completion request - TextCompletion(ctx context.Context, model, key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) + TextCompletion(ctx context.Context, model string, key Key, text string, params *ModelParameters) (*BifrostResponse, *BifrostError) // ChatCompletion performs a chat completion request - ChatCompletion(ctx context.Context, model, key string, messages []BifrostMessage, params *ModelParameters) (*BifrostResponse, *BifrostError) + ChatCompletion(ctx context.Context, model string, key Key, messages []BifrostMessage, params *ModelParameters) (*BifrostResponse, *BifrostError) // ChatCompletionStream performs a chat completion stream request - ChatCompletionStream(ctx context.Context, postHookRunner PostHookRunner, model, key string, messages []BifrostMessage, params *ModelParameters) (chan *BifrostStream, *BifrostError) + ChatCompletionStream(ctx context.Context, postHookRunner PostHookRunner, model string, key Key, messages []BifrostMessage, params *ModelParameters) (chan *BifrostStream, *BifrostError) // Embedding performs an embedding request - Embedding(ctx context.Context, model string, key string, input *EmbeddingInput, params *ModelParameters) (*BifrostResponse, *BifrostError) + Embedding(ctx context.Context, model string, key Key, input *EmbeddingInput, params *ModelParameters) (*BifrostResponse, *BifrostError) } diff --git a/core/utils.go b/core/utils.go index 8aa329d5c5..fb0c4256df 100644 --- a/core/utils.go +++ b/core/utils.go @@ -12,9 +12,9 @@ func Ptr[T any](v T) *T { } // providerRequiresKey returns true if the given provider requires an API key for authentication. -// Some providers like Vertex and Ollama are keyless and don't require API keys. +// Some providers like Ollama and SGL are keyless and don't require API keys. 
func providerRequiresKey(providerKey schemas.ModelProvider) bool { - return providerKey != schemas.Vertex && providerKey != schemas.Ollama && providerKey != schemas.SGL + return providerKey != schemas.Ollama && providerKey != schemas.SGL } // calculateBackoff implements exponential backoff with jitter for retry attempts. diff --git a/docs/usage/go-package/account.md b/docs/usage/go-package/account.md index 9ccde39273..535b334bcc 100644 --- a/docs/usage/go-package/account.md +++ b/docs/usage/go-package/account.md @@ -131,6 +131,13 @@ func (a *MultiProviderAccount) GetKeysForProvider(provider schemas.ModelProvider Value: os.Getenv("AZURE_API_KEY"), Models: []string{"gpt-4o"}, Weight: 1.0, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), + APIVersion: bifrost.Ptr("2024-08-01-preview"), + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-deployment", + }, + }, }}, nil case schemas.Bedrock: @@ -141,8 +148,15 @@ func (a *MultiProviderAccount) GetKeysForProvider(provider schemas.ModelProvider }}, nil case schemas.Vertex: - // Vertex is keyless (uses Google Cloud credentials) - return []schemas.Key{}, nil + return []schemas.Key{{ + Models: []string{"google/gemini-2.0-flash-001"}, + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + Region: "us-central1", + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), + }, + }}, nil } return nil, fmt.Errorf("provider %s not supported", provider) @@ -171,13 +185,6 @@ func (a *MultiProviderAccount) GetConfigForProvider(provider schemas.ModelProvid RetryBackoffMax: 10 * time.Second, }, ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, - MetaConfig: &meta.AzureMetaConfig{ - Endpoint: os.Getenv("AZURE_ENDPOINT"), - APIVersion: bifrost.Ptr("2024-08-01-preview"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-deployment", - }, - }, }, nil case schemas.Bedrock: @@ -194,11 +201,6 @@ func (a *MultiProviderAccount) 
GetConfigForProvider(provider schemas.ModelProvid return &schemas.ProviderConfig{ NetworkConfig: schemas.DefaultNetworkConfig, ConcurrencyAndBufferSize: schemas.DefaultConcurrencyAndBufferSize, - MetaConfig: &meta.VertexMetaConfig{ - ProjectID: os.Getenv("VERTEX_PROJECT_ID"), - Region: "us-central1", - AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), - }, }, nil } diff --git a/docs/usage/http-transport/configuration/providers.md b/docs/usage/http-transport/configuration/providers.md index eab055b82d..f7ae62ae44 100644 --- a/docs/usage/http-transport/configuration/providers.md +++ b/docs/usage/http-transport/configuration/providers.md @@ -158,7 +158,14 @@ Provider configuration in `config.json` defines: { "value": "env.AZURE_API_KEY", "models": ["gpt-4o"], - "weight": 1.0 + "weight": 1.0, + "azure_key_config": { + "endpoint": "env.AZURE_ENDPOINT", + "deployments": { + "gpt-4o": "gpt-4o-aug" + }, + "api_version": "2024-08-01-preview" + } } ], "network_config": { @@ -167,13 +174,6 @@ Provider configuration in `config.json` defines: "retry_backoff_initial_ms": 100, "retry_backoff_max_ms": 2000 }, - "meta_config": { - "endpoint": "env.AZURE_ENDPOINT", - "deployments": { - "gpt-4o": "gpt-4o-aug" - }, - "api_version": "2024-08-01-preview" - }, "concurrency_and_buffer_size": { "concurrency": 3, "buffer_size": 10 @@ -189,12 +189,18 @@ Provider configuration in `config.json` defines: { "providers": { "vertex": { - "keys": [], - "meta_config": { - "project_id": "env.VERTEX_PROJECT_ID", - "region": "us-central1", - "auth_credentials": "env.VERTEX_CREDENTIALS" - }, + "keys": [ + { + "value": "env.VERTEX_API_KEY", + "models": ["gemini-2.0-flash-001"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + } + ], "concurrency_and_buffer_size": { "concurrency": 3, "buffer_size": 10 diff --git a/docs/usage/http-transport/integrations/genai-compatible.md 
b/docs/usage/http-transport/integrations/genai-compatible.md index 8e986c9e15..f45df1ac4b 100644 --- a/docs/usage/http-transport/integrations/genai-compatible.md +++ b/docs/usage/http-transport/integrations/genai-compatible.md @@ -544,12 +544,17 @@ response3 = model3.generate_content("Hello!") { "providers": { "vertex": { - "keys": [], - "meta_config": { - "project_id": "env.VERTEX_PROJECT_ID", - "region": "us-central1", - "auth_credentials": "env.VERTEX_CREDENTIALS" - }, + "keys": [ + { + "models": ["gemini-2.0-flash-001"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + } + ], "network_config": { "default_request_timeout_in_seconds": 30, "max_retries": 2, @@ -609,10 +614,17 @@ genai.configure( { "providers": { "vertex": { - "meta_config": { - "project_id": "env.VERTEX_PROJECT_ID", - "region": "us-central1" - } + "keys": [ + { + "models": ["gemini-2.0-flash-001"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "env.VERTEX_PROJECT_ID", + "region": "us-central1", + "auth_credentials": "env.VERTEX_CREDENTIALS" + } + } + ] }, "openai": { "keys": [ diff --git a/docs/usage/http-transport/openapi.json b/docs/usage/http-transport/openapi.json index 3f935e6dd4..810eacb5b6 100644 --- a/docs/usage/http-transport/openapi.json +++ b/docs/usage/http-transport/openapi.json @@ -2334,6 +2334,50 @@ }, "description": "Models this key can access", "example": ["gpt-4o", "gpt-4o-mini"] + }, + "azure_key_config": { + "type": "object", + "properties": { + "endpoint": { + "type": "string", + "description": "Azure endpoint", + "example": "https://your-resource.openai.azure.com" + }, + "deployments": { + "type": "object", + "description": "Azure deployments", + "example": { + "gpt-4o": "gpt-4o-deployment" + } + }, + "api_version": { + "type": "string", + "description": "Azure API version", + "example": "2024-02-15-preview" + } + }, + "description": "Azure key 
configuration" + }, + "vertex_key_config": { + "type": "object", + "properties": { + "project_id": { + "type": "string", + "description": "Vertex project ID", + "example": "your-project-id" + }, + "region": { + "type": "string", + "description": "Vertex region", + "example": "us-central1" + }, + "auth_credentials": { + "type": "string", + "description": "Vertex auth credentials", + "example": "env.VERTEX_AUTH_CREDENTIALS" + } + }, + "description": "Vertex key configuration" } } }, diff --git a/docs/usage/providers.md b/docs/usage/providers.md index e5b68e564f..3cb5d79d5b 100644 --- a/docs/usage/providers.md +++ b/docs/usage/providers.md @@ -314,19 +314,24 @@ echo "$response" **Go Package:** ```go -func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { +func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.Azure { - return &schemas.ProviderConfig{ - NetworkConfig: schemas.NetworkConfig{ - BaseURL: "https://your-resource.openai.azure.com", - }, - MetaConfig: map[string]interface{}{ - "api_version": "2024-02-15-preview", - "deployment": "gpt-4o-deployment", + return []schemas.Key{ + { + Value: "your-azure-api-key", + Models: []string{"gpt-4o"}, // These models are mapped to the deployment + Weight: 1.0, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: "https://your-resource.openai.azure.com", + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-deployment", + }, + APIVersion: StrPtr("2024-02-15-preview"), + }, }, }, nil } - return &schemas.ProviderConfig{}, nil + return nil, fmt.Errorf("provider not configured") } ``` @@ -340,16 +345,16 @@ func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schem { "value": "env.AZURE_OPENAI_API_KEY", "models": ["gpt-4o"], - "weight": 1.0 + "weight": 1.0, + "azure_key_config": { + "endpoint": "https://your-resource.openai.azure.com", + "deployments": { + "gpt-4o": "gpt-4o-deployment" + }, 
+ "api_version": "2024-02-15-preview" + } } - ], - "network_config": { - "base_url": "https://your-resource.openai.azure.com" - }, - "meta_config": { - "api_version": "2024-02-15-preview", - "deployment": "gpt-4o-deployment" - } + ] } } } @@ -363,17 +368,21 @@ func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schem **Go Package:** ```go -func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schemas.ProviderConfig, error) { +func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.Vertex { - return &schemas.ProviderConfig{ - MetaConfig: map[string]interface{}{ - "project_id": "your-project-id", - "location": "us-central1", - "credentials_path": "/path/to/service-account.json", + return []schemas.Key{ + { + Models: []string{"gemini-pro"}, // These models are just for mapping to keys + Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: "your-project-id", + Location: "us-central1", + AuthCredentials: os.Getenv("VERTEX_AUTH_CREDENTIALS"), // Or read from file + }, }, }, nil } - return &schemas.ProviderConfig{}, nil + return nil, fmt.Errorf("provider not configured") } ``` @@ -385,15 +394,15 @@ func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schem "vertex": { "keys": [ { - "value": "file:/path/to/service-account.json", - "models": ["gemini-pro"], - "weight": 1.0 + "models": ["google/gemini-2.0-flash-001"], + "weight": 1.0, + "vertex_key_config": { + "project_id": "your-project-id", + "region": "us-central1", + "auth_credentials": "env.VERTEX_AUTH_CREDENTIALS" + } } - ], - "meta_config": { - "project_id": "your-project-id", - "location": "us-central1" - } + ] } } } diff --git a/tests/core-providers/config/account.go b/tests/core-providers/config/account.go index f8bbf414cb..e3cf907dbf 100644 --- a/tests/core-providers/config/account.go +++ b/tests/core-providers/config/account.go @@ -109,14 +109,29 @@ func (account 
*ComprehensiveTestAccount) GetKeysForProvider(providerKey schemas. Value: os.Getenv("AZURE_API_KEY"), Models: []string{"gpt-4o"}, Weight: 1.0, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: os.Getenv("AZURE_ENDPOINT"), + Deployments: map[string]string{ + "gpt-4o": "gpt-4o-aug", + }, + // Use environment variable for API version with fallback to current preview version + // Note: This is a preview API version that may change over time. Update as needed. + // Set AZURE_API_VERSION environment variable to override the default. + APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), + }, }, }, nil case schemas.Vertex: return []schemas.Key{ { Value: os.Getenv("VERTEX_API_KEY"), - Models: []string{"gemini-pro"}, + Models: []string{}, Weight: 1.0, + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: os.Getenv("VERTEX_PROJECT_ID"), + Region: getEnvWithDefault("VERTEX_REGION", "us-central1"), + AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), + }, }, }, nil case schemas.Mistral: @@ -191,16 +206,6 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema RetryBackoffInitial: 100 * time.Millisecond, RetryBackoffMax: 2 * time.Second, }, - MetaConfig: &meta.AzureMetaConfig{ - Endpoint: os.Getenv("AZURE_ENDPOINT"), - Deployments: map[string]string{ - "gpt-4o": "gpt-4o-aug", - }, - // Use environment variable for API version with fallback to current preview version - // Note: This is a preview API version that may change over time. Update as needed. - // Set AZURE_API_VERSION environment variable to override the default. 
- APIVersion: bifrost.Ptr(getEnvWithDefault("AZURE_API_VERSION", "2024-08-01-preview")), - }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ Concurrency: 3, BufferSize: 10, @@ -214,11 +219,6 @@ func (account *ComprehensiveTestAccount) GetConfigForProvider(providerKey schema RetryBackoffInitial: 100 * time.Millisecond, RetryBackoffMax: 2 * time.Second, }, - MetaConfig: &meta.VertexMetaConfig{ - ProjectID: os.Getenv("VERTEX_PROJECT_ID"), - Region: getEnvWithDefault("VERTEX_REGION", "us-central1"), - AuthCredentials: os.Getenv("VERTEX_CREDENTIALS"), - }, ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ Concurrency: 3, BufferSize: 10, diff --git a/tests/core-providers/go.mod b/tests/core-providers/go.mod index bfb845fb66..2033e99b8f 100644 --- a/tests/core-providers/go.mod +++ b/tests/core-providers/go.mod @@ -38,3 +38,5 @@ require ( golang.org/x/text v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/maximhq/bifrost/core => ../../core \ No newline at end of file diff --git a/tests/core-providers/vertex_test.go b/tests/core-providers/vertex_test.go index b2698381b3..ac61cfb144 100644 --- a/tests/core-providers/vertex_test.go +++ b/tests/core-providers/vertex_test.go @@ -35,9 +35,6 @@ func TestVertex(t *testing.T) { CompleteEnd2End: true, ProviderSpecific: true, }, - Fallbacks: []schemas.Fallback{ - {Provider: schemas.Anthropic, Model: "claude-3-7-sonnet-20250219"}, - }, } runAllComprehensiveTests(t, client, ctx, testConfig) diff --git a/transports/bifrost-http/handlers/completions.go b/transports/bifrost-http/handlers/completions.go index 935b28e908..4e06bbb193 100644 --- a/transports/bifrost-http/handlers/completions.go +++ b/transports/bifrost-http/handlers/completions.go @@ -78,8 +78,8 @@ func (h *CompletionHandler) handleCompletion(ctx *fasthttp.RequestCtx, completio return } - model := strings.Split(req.Model, "/") - if len(model) != 2 { + model := strings.SplitN(req.Model, "/", 2) + if len(model) < 2 { 
SendError(ctx, fasthttp.StatusBadRequest, "Model must be in the format of 'provider/model'", h.logger) return } diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 8e83f2879e..87c28d5ee4 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -5,7 +5,9 @@ package handlers import ( "encoding/json" "fmt" + "slices" "sort" + "strings" "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" @@ -153,7 +155,7 @@ func (h *ProviderHandler) AddProvider(ctx *fasthttp.RequestCtx) { } // Validate required keys - if len(req.Keys) == 0 && req.Provider != schemas.Vertex && req.Provider != schemas.Ollama && req.Provider != schemas.SGL { + if len(req.Keys) == 0 && req.Provider != schemas.Ollama && req.Provider != schemas.SGL { SendError(ctx, fasthttp.StatusBadRequest, "At least one API key is required", h.logger) return } @@ -237,79 +239,51 @@ func (h *ProviderHandler) UpdateProvider(ctx *fasthttp.RequestCtx) { return } + oldConfigRedacted, err := h.store.GetProviderConfigRedacted(provider) + if err != nil { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err), h.logger) + return + } + // Construct ProviderConfig from individual fields config := lib.ProviderConfig{ Keys: oldConfigRaw.Keys, NetworkConfig: oldConfigRaw.NetworkConfig, ConcurrencyAndBufferSize: oldConfigRaw.ConcurrencyAndBufferSize, + ProxyConfig: oldConfigRaw.ProxyConfig, } - // For now, don't replace any environment keys - preserve all existing ones - // TODO: Implement proper tracking of which env keys should be dropped - envKeysToReplace := make(map[string]struct{}) + // Environment variable cleanup is now handled automatically by mergeKeys function - // Validate and process keys - if req.Keys != nil { - if len(req.Keys) == 0 && provider != schemas.Vertex && provider != schemas.Ollama && provider != schemas.SGL { - SendError(ctx, 
fasthttp.StatusBadRequest, "At least one API key is required", h.logger) - return - } + var keysToAdd []schemas.Key + var keysToUpdate []schemas.Key - // Create a map of old keys by model patterns for quick lookup - oldKeysByModels := make(map[string][]schemas.Key) - for _, oldKey := range oldConfigRaw.Keys { - for _, model := range oldKey.Models { - oldKeysByModels[model] = append(oldKeysByModels[model], oldKey) - } + for _, key := range req.Keys { + if !slices.ContainsFunc(oldConfigRaw.Keys, func(k schemas.Key) bool { + return k.ID == key.ID + }) { + keysToAdd = append(keysToAdd, key) + } else { + keysToUpdate = append(keysToUpdate, key) } + } - // Process each key in the request - for i, newKey := range req.Keys { - // If the key is redacted, try to find and use the old key for the same models - if lib.IsRedacted(newKey.Value) { - // Look for matching old keys - var matchingKeys []schemas.Key - for _, model := range newKey.Models { - if oldKeys, exists := oldKeysByModels[model]; exists { - matchingKeys = append(matchingKeys, oldKeys...) 
- } - } - - // If we found matching keys, use the most appropriate one - if len(matchingKeys) > 0 { - // Try to find a key that matches all the same models - var bestMatch schemas.Key - bestMatchScore := 0 - - for _, oldKey := range matchingKeys { - // Calculate how many models match between the old and new key - matchCount := 0 - oldModelsMap := make(map[string]bool) - for _, m := range oldKey.Models { - oldModelsMap[m] = true - } - - for _, m := range newKey.Models { - if oldModelsMap[m] { - matchCount++ - } - } - - // Update best match if this key has more matching models - if matchCount > bestMatchScore { - bestMatch = oldKey - bestMatchScore = matchCount - } - } - - // Use the best matching key's value - req.Keys[i].Value = bestMatch.Value - } - } + var keysToDelete []schemas.Key + for _, key := range oldConfigRaw.Keys { + if !slices.ContainsFunc(req.Keys, func(k schemas.Key) bool { + return k.ID == key.ID + }) { + keysToDelete = append(keysToDelete, key) } - config.Keys = req.Keys } + keys, err := h.mergeKeys(provider, oldConfigRaw.Keys, oldConfigRedacted.Keys, keysToAdd, keysToDelete, keysToUpdate) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid keys: %v", err), h.logger) + return + } + config.Keys = keys + // Handle meta config if provided if req.MetaConfig != nil && len(*req.MetaConfig) > 0 { // Merge new meta config with old, preserving redacted values @@ -340,7 +314,7 @@ func (h *ProviderHandler) UpdateProvider(ctx *fasthttp.RequestCtx) { config.ProxyConfig = req.ProxyConfig // Update provider config in store (env vars will be processed by store) - if err := h.store.UpdateProviderConfig(provider, config, envKeysToReplace); err != nil { + if err := h.store.UpdateProviderConfig(provider, config); err != nil { h.logger.Warn(fmt.Sprintf("Failed to update provider %s: %v", provider, err)) SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider: %v", err), h.logger) return @@ -415,14 +389,6 
@@ func (h *ProviderHandler) convertToProviderMetaConfig(provider schemas.ModelProv } switch provider { - case schemas.Azure: - var azureMetaConfig meta.AzureMetaConfig - if err := json.Unmarshal(metaConfigJSON, &azureMetaConfig); err != nil { - return nil, fmt.Errorf("failed to unmarshal Azure meta config: %w", err) - } - var metaConfig schemas.MetaConfig = &azureMetaConfig - return &metaConfig, nil - case schemas.Bedrock: var bedrockMetaConfig meta.BedrockMetaConfig if err := json.Unmarshal(metaConfigJSON, &bedrockMetaConfig); err != nil { @@ -431,50 +397,108 @@ func (h *ProviderHandler) convertToProviderMetaConfig(provider schemas.ModelProv var metaConfig schemas.MetaConfig = &bedrockMetaConfig return &metaConfig, nil - case schemas.Vertex: - var vertexMetaConfig meta.VertexMetaConfig - if err := json.Unmarshal(metaConfigJSON, &vertexMetaConfig); err != nil { - return nil, fmt.Errorf("failed to unmarshal Vertex meta config: %w", err) - } - var metaConfig schemas.MetaConfig = &vertexMetaConfig - return &metaConfig, nil - default: // For providers that don't support meta config, return nil return nil, nil } } -// mergeMetaConfig merges new meta config with old, preserving values that are redacted in the new config -func (h *ProviderHandler) mergeMetaConfig(provider schemas.ModelProvider, oldConfig *schemas.MetaConfig, newConfigMap map[string]interface{}) (*schemas.MetaConfig, error) { - if oldConfig == nil || len(newConfigMap) == 0 { - return h.convertToProviderMetaConfig(provider, newConfigMap) +// mergeKeys merges new keys with old, preserving values that are redacted in the new config +func (h *ProviderHandler) mergeKeys(provider schemas.ModelProvider, oldRawKeys []schemas.Key, oldRedactedKeys []schemas.Key, keysToAdd []schemas.Key, keysToDelete []schemas.Key, keysToUpdate []schemas.Key) ([]schemas.Key, error) { + // Clean up environment variables for deleted and updated keys + h.store.CleanupEnvKeysForKeys(string(provider), keysToDelete) + 
h.store.CleanupEnvKeysForUpdatedKeys(string(provider), keysToUpdate) + // Create a map of indices to delete + toDelete := make(map[int]bool) + for _, key := range keysToDelete { + for i, oldKey := range oldRawKeys { + if oldKey.ID == key.ID { + toDelete[i] = true + break + } + } } - switch provider { - case schemas.Azure: - var newAzureConfig meta.AzureMetaConfig - newConfigJSON, _ := json.Marshal(newConfigMap) - if err := json.Unmarshal(newConfigJSON, &newAzureConfig); err != nil { - return nil, fmt.Errorf("failed to unmarshal new Azure meta config: %w", err) - } + // Create a map of updates by ID for quick lookup + updates := make(map[string]schemas.Key) + for _, key := range keysToUpdate { + updates[key.ID] = key + } - oldAzureConfig, ok := (*oldConfig).(*meta.AzureMetaConfig) - if !ok { - return nil, fmt.Errorf("existing meta config type mismatch: expected AzureMetaConfig") + // Process existing keys (handle updates and deletions) + var resultKeys []schemas.Key + for i, oldRawKey := range oldRawKeys { + // Skip if this key should be deleted + if toDelete[i] { + continue } - // Preserve old values if new ones are redacted - if lib.IsRedacted(newAzureConfig.Endpoint) { - newAzureConfig.Endpoint = oldAzureConfig.Endpoint - } - if newAzureConfig.APIVersion != nil && oldAzureConfig.APIVersion != nil && lib.IsRedacted(*newAzureConfig.APIVersion) { - newAzureConfig.APIVersion = oldAzureConfig.APIVersion + // Check if this key should be updated + if updateKey, exists := updates[oldRawKey.ID]; exists { + mergedKey := updateKey + + // Handle redacted values + if lib.IsRedacted(updateKey.Value) && + (!strings.HasPrefix(updateKey.Value, "env.") || + !strings.EqualFold(updateKey.Value, oldRedactedKeys[i].Value)) { + mergedKey.Value = oldRawKey.Value + } + + // Handle Azure config redacted values + if updateKey.AzureKeyConfig != nil && oldRedactedKeys[i].AzureKeyConfig != nil { + if lib.IsRedacted(updateKey.AzureKeyConfig.Endpoint) && + 
(!strings.HasPrefix(updateKey.AzureKeyConfig.Endpoint, "env.") || + !strings.EqualFold(updateKey.AzureKeyConfig.Endpoint, oldRedactedKeys[i].AzureKeyConfig.Endpoint)) { + mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint + } + if updateKey.AzureKeyConfig.APIVersion != nil { + if lib.IsRedacted(*updateKey.AzureKeyConfig.APIVersion) && + (!strings.HasPrefix(*updateKey.AzureKeyConfig.APIVersion, "env.") || + !strings.EqualFold(*updateKey.AzureKeyConfig.APIVersion, *oldRedactedKeys[i].AzureKeyConfig.APIVersion)) { + mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion + } + } + } + + // Handle Vertex config redacted values + if updateKey.VertexKeyConfig != nil && oldRedactedKeys[i].VertexKeyConfig != nil { + if lib.IsRedacted(updateKey.VertexKeyConfig.ProjectID) && + (!strings.HasPrefix(updateKey.VertexKeyConfig.ProjectID, "env.") || + !strings.EqualFold(updateKey.VertexKeyConfig.ProjectID, oldRedactedKeys[i].VertexKeyConfig.ProjectID)) { + mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID + } + if lib.IsRedacted(updateKey.VertexKeyConfig.Region) && + (!strings.HasPrefix(updateKey.VertexKeyConfig.Region, "env.") || + !strings.EqualFold(updateKey.VertexKeyConfig.Region, oldRedactedKeys[i].VertexKeyConfig.Region)) { + mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region + } + if lib.IsRedacted(updateKey.VertexKeyConfig.AuthCredentials) && + (!strings.HasPrefix(updateKey.VertexKeyConfig.AuthCredentials, "env.") || + !strings.EqualFold(updateKey.VertexKeyConfig.AuthCredentials, oldRedactedKeys[i].VertexKeyConfig.AuthCredentials)) { + mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials + } + } + + resultKeys = append(resultKeys, mergedKey) + } else { + // Keep unchanged key + resultKeys = append(resultKeys, oldRawKey) } + } - var metaConfig schemas.MetaConfig = &newAzureConfig - return &metaConfig, nil + // Add new keys + resultKeys = append(resultKeys, 
keysToAdd...) + + return resultKeys, nil +} +// mergeMetaConfig merges new meta config with old, preserving values that are redacted in the new config +func (h *ProviderHandler) mergeMetaConfig(provider schemas.ModelProvider, oldConfig *schemas.MetaConfig, newConfigMap map[string]interface{}) (*schemas.MetaConfig, error) { + if oldConfig == nil || len(newConfigMap) == 0 { + return h.convertToProviderMetaConfig(provider, newConfigMap) + } + + switch provider { case schemas.Bedrock: var newBedrockConfig meta.BedrockMetaConfig newConfigJSON, _ := json.Marshal(newConfigMap) @@ -503,33 +527,6 @@ func (h *ProviderHandler) mergeMetaConfig(provider schemas.ModelProvider, oldCon var metaConfig schemas.MetaConfig = &newBedrockConfig return &metaConfig, nil - - case schemas.Vertex: - var newVertexConfig meta.VertexMetaConfig - newConfigJSON, _ := json.Marshal(newConfigMap) - if err := json.Unmarshal(newConfigJSON, &newVertexConfig); err != nil { - return nil, fmt.Errorf("failed to unmarshal new Vertex meta config: %w", err) - } - - oldVertexConfig, ok := (*oldConfig).(*meta.VertexMetaConfig) - if !ok { - return nil, fmt.Errorf("existing meta config type mismatch: expected VertexMetaConfig") - } - - // Preserve old values if new ones are redacted - if lib.IsRedacted(newVertexConfig.ProjectID) { - newVertexConfig.ProjectID = oldVertexConfig.ProjectID - } - if lib.IsRedacted(newVertexConfig.Region) { - newVertexConfig.Region = oldVertexConfig.Region - } - if lib.IsRedacted(newVertexConfig.AuthCredentials) { - newVertexConfig.AuthCredentials = oldVertexConfig.AuthCredentials - } - - var metaConfig schemas.MetaConfig = &newVertexConfig - return &metaConfig, nil - default: return nil, nil } diff --git a/transports/bifrost-http/lib/account.go b/transports/bifrost-http/lib/account.go index e611616345..85bbe0c477 100644 --- a/transports/bifrost-http/lib/account.go +++ b/transports/bifrost-http/lib/account.go @@ -9,10 +9,10 @@ import ( ) // BaseAccount implements the Account interface 
for Bifrost. -// It manages provider configurations using a bbolt store for persistent storage. +// It manages provider configurations using an in-memory store for configuration storage. // All data processing (environment variables, meta configs) is done upfront in the store. type BaseAccount struct { - store *ConfigStore // bbolt store for persistent configuration + store *ConfigStore // store for in-memory configuration } // NewBaseAccount creates a new BaseAccount with the given store diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 81cd7553b4..d224713719 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -18,7 +18,7 @@ type ClientConfig struct { // ProviderConfig represents the configuration for a specific AI model provider. // It includes API keys, network settings, provider-specific metadata, and concurrency settings. type ProviderConfig struct { - Keys []schemas.Key `json:"keys"` // API keys for the provider + Keys []schemas.Key `json:"keys"` // API keys for the provider with UUIDs NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings MetaConfig *schemas.MetaConfig `json:"-"` // Provider-specific metadata ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings diff --git a/transports/bifrost-http/lib/store.go b/transports/bifrost-http/lib/store.go index aaae7d77d7..6ca88f88bc 100644 --- a/transports/bifrost-http/lib/store.go +++ b/transports/bifrost-http/lib/store.go @@ -9,6 +9,7 @@ import ( "strings" "sync" + "github.com/google/uuid" bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/core/schemas/meta" @@ -20,11 +21,11 @@ import ( // // Features: // - Pure in-memory storage for ultra-fast access -// - Environment variable processing for API keys and meta configurations +// - Environment 
variable processing for API keys and key-level configurations // - Thread-safe operations with read-write mutexes // - Real-time configuration updates via HTTP API // - Explicit persistence control via WriteConfigToFile() -// - Support for all provider-specific meta configurations (Azure, Bedrock, Vertex) +// - Support for provider-specific key configurations (Azure, Vertex) and meta configurations (Bedrock) type ConfigStore struct { mu sync.RWMutex muMCP sync.RWMutex @@ -45,8 +46,9 @@ type ConfigStore struct { type EnvKeyInfo struct { EnvVar string // The environment variable name (without env. prefix) Provider string // The provider this key belongs to (empty for core/mcp configs) - KeyType string // Type of key (e.g., "api_key", "meta_config", "connection_string") + KeyType string // Type of key (e.g., "api_key", "azure_config", "vertex_config", "meta_config", "connection_string") ConfigPath string // Path in config where this env var is used + KeyID string // The key ID this env var belongs to (empty for non-key configs like meta_config, connection_string) } var DefaultClientConfig = ClientConfig{ @@ -66,7 +68,7 @@ func NewConfigStore(logger schemas.Logger) (*ConfigStore, error) { } // LoadFromConfig loads initial configuration from a JSON config file into memory -// with full preprocessing including environment variable resolution and meta config parsing. +// with full preprocessing including environment variable resolution and key config parsing. // All processing is done upfront to ensure zero latency when retrieving data. 
// // If the config file doesn't exist, the system starts with default configuration @@ -75,7 +77,8 @@ func NewConfigStore(logger schemas.Logger) (*ConfigStore, error) { // This method handles: // - JSON config file parsing // - Environment variable substitution for API keys (env.VARIABLE_NAME) -// - Provider-specific meta config processing (Azure, Bedrock, Vertex) +// - Key-level config processing for Azure and Vertex (Endpoint, APIVersion, ProjectID, Region, AuthCredentials) +// - Provider-specific meta config processing (Bedrock only) // - Case conversion for provider names (e.g., "OpenAI" -> "openai") // - In-memory storage for ultra-fast access during request processing // - Graceful handling of missing config files @@ -178,8 +181,13 @@ func (s *ConfigStore) LoadFromConfig(configPath string) error { } } - // Process environment variables in keys + // Process environment variables in keys (including key-level configs) for i, key := range cfg.Keys { + if key.ID == "" { + cfg.Keys[i].ID = uuid.NewString() + } + + // Process API key value processedValue, envVar, err := s.processEnvValue(key.Value) if err != nil { s.cleanupEnvKeys(string(provider), "", newEnvKeys) @@ -195,9 +203,28 @@ func (s *ConfigStore) LoadFromConfig(configPath string) error { EnvVar: envVar, Provider: string(provider), KeyType: "api_key", - ConfigPath: fmt.Sprintf("providers.%s.keys[%d]", provider, i), + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, key.ID), + KeyID: key.ID, }) } + + // Process Azure key config if present + if key.AzureKeyConfig != nil { + if err := s.processAzureKeyConfigEnvVars(&cfg.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(string(provider), "", newEnvKeys) + s.logger.Warn(fmt.Sprintf("failed to process Azure key config env vars for %s: %v", provider, err)) + continue + } + } + + // Process Vertex key config if present + if key.VertexKeyConfig != nil { + if err := s.processVertexKeyConfigEnvVars(&cfg.Keys[i], provider, i, newEnvKeys); 
err != nil { + s.cleanupEnvKeys(string(provider), "", newEnvKeys) + s.logger.Warn(fmt.Sprintf("failed to process Vertex key config env vars for %s: %v", provider, err)) + continue + } + } } processedProviders[provider] = cfg @@ -280,16 +307,77 @@ func (s *ConfigStore) writeConfigToFile(configPath string) error { redactedKeys := make([]schemas.Key, len(config.Keys)) for i, key := range config.Keys { redactedKeys[i] = schemas.Key{ + ID: key.ID, Models: key.Models, Weight: key.Weight, } - path := fmt.Sprintf("providers.%s.keys[%d]", provider, i) + // Restore API key value + path := fmt.Sprintf("providers.%s.keys[%s]", provider, key.ID) if envVar, ok := envVarsByPath[path]; ok { redactedKeys[i].Value = "env." + envVar } else { redactedKeys[i].Value = key.Value // Keep actual value, no asterisk redaction } + + // Restore Azure key config if present + if key.AzureKeyConfig != nil { + azureConfig := &schemas.AzureKeyConfig{ + Deployments: key.AzureKeyConfig.Deployments, + } + + // Restore Endpoint + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.endpoint", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + azureConfig.Endpoint = "env." + envVar + } else { + azureConfig.Endpoint = key.AzureKeyConfig.Endpoint + } + + // Restore APIVersion if present + if key.AzureKeyConfig.APIVersion != nil { + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.api_version", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + apiVersion := "env." + envVar + azureConfig.APIVersion = &apiVersion + } else { + azureConfig.APIVersion = key.AzureKeyConfig.APIVersion + } + } + + redactedKeys[i].AzureKeyConfig = azureConfig + } + + // Restore Vertex key config if present + if key.VertexKeyConfig != nil { + vertexConfig := &schemas.VertexKeyConfig{} + + // Restore ProjectID + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.project_id", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.ProjectID = "env." 
+ envVar + } else { + vertexConfig.ProjectID = key.VertexKeyConfig.ProjectID + } + + // Restore Region + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.region", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.Region = "env." + envVar + } else { + vertexConfig.Region = key.VertexKeyConfig.Region + } + + // Restore AuthCredentials + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.auth_credentials", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.AuthCredentials = "env." + envVar + } else { + vertexConfig.AuthCredentials = key.VertexKeyConfig.AuthCredentials + } + + redactedKeys[i].VertexKeyConfig = vertexConfig + } } // Create provider config with restored env references @@ -370,28 +458,6 @@ func (s *ConfigStore) getRestoredMCPConfig(envVarsByPath map[string]string) *sch // restoreMetaConfigEnvVars creates a copy of meta config with env variable references restored func (s *ConfigStore) restoreMetaConfigEnvVars(provider schemas.ModelProvider, metaConfig schemas.MetaConfig, envVarsByPath map[string]string) interface{} { switch m := metaConfig.(type) { - case *meta.AzureMetaConfig: - azureConfig := *m // Copy the struct - - // Restore endpoint if it came from env var - path := fmt.Sprintf("providers.%s.meta_config.endpoint", provider) - if envVar, ok := envVarsByPath[path]; ok { - azureConfig.Endpoint = "env." + envVar - } - // Otherwise keep actual value (no asterisk redaction) - - // Restore API version if it came from env var - if azureConfig.APIVersion != nil { - path = fmt.Sprintf("providers.%s.meta_config.api_version", provider) - if envVar, ok := envVarsByPath[path]; ok { - apiVersion := "env." 
+ envVar - azureConfig.APIVersion = &apiVersion - } - // Otherwise keep actual value (no asterisk redaction) - } - - return azureConfig - case *meta.BedrockMetaConfig: bedrockConfig := *m // Copy the struct @@ -434,32 +500,6 @@ func (s *ConfigStore) restoreMetaConfigEnvVars(provider schemas.ModelProvider, m return bedrockConfig - case *meta.VertexMetaConfig: - vertexConfig := *m // Copy the struct - - // Restore project ID if it came from env var - path := fmt.Sprintf("providers.%s.meta_config.project_id", provider) - if envVar, ok := envVarsByPath[path]; ok { - vertexConfig.ProjectID = "env." + envVar - } - // Otherwise keep actual value (no asterisk redaction) - - // Restore region if it came from env var - path = fmt.Sprintf("providers.%s.meta_config.region", provider) - if envVar, ok := envVarsByPath[path]; ok { - vertexConfig.Region = "env." + envVar - } - // Otherwise keep actual value (no asterisk redaction) - - // Restore auth credentials if it came from env var - path = fmt.Sprintf("providers.%s.meta_config.auth_credentials", provider) - if envVar, ok := envVarsByPath[path]; ok { - vertexConfig.AuthCredentials = "env." 
+ envVar - } - // Otherwise keep actual value (no asterisk redaction) - - return vertexConfig - default: return metaConfig } @@ -476,14 +516,6 @@ func (s *ConfigStore) SaveConfig() error { // parseMetaConfig converts raw JSON to the appropriate provider-specific meta config interface func (s *ConfigStore) parseMetaConfig(rawMetaConfig json.RawMessage, provider schemas.ModelProvider) (*schemas.MetaConfig, error) { switch provider { - case schemas.Azure: - var azureMetaConfig meta.AzureMetaConfig - if err := json.Unmarshal(rawMetaConfig, &azureMetaConfig); err != nil { - return nil, fmt.Errorf("failed to unmarshal Azure meta config: %w", err) - } - var metaConfig schemas.MetaConfig = &azureMetaConfig - return &metaConfig, nil - case schemas.Bedrock: var bedrockMetaConfig meta.BedrockMetaConfig if err := json.Unmarshal(rawMetaConfig, &bedrockMetaConfig); err != nil { @@ -491,14 +523,6 @@ func (s *ConfigStore) parseMetaConfig(rawMetaConfig json.RawMessage, provider sc } var metaConfig schemas.MetaConfig = &bedrockMetaConfig return &metaConfig, nil - - case schemas.Vertex: - var vertexMetaConfig meta.VertexMetaConfig - if err := json.Unmarshal(rawMetaConfig, &vertexMetaConfig); err != nil { - return nil, fmt.Errorf("failed to unmarshal Vertex meta config: %w", err) - } - var metaConfig schemas.MetaConfig = &vertexMetaConfig - return &metaConfig, nil } return nil, fmt.Errorf("unsupported provider for meta config: %s", provider) @@ -509,9 +533,7 @@ func (s *ConfigStore) parseMetaConfig(rawMetaConfig json.RawMessage, provider sc // variables in their fields, ensuring type safety and proper field handling. // // Supported providers and their processed fields: -// - Azure: Endpoint, APIVersion // - Bedrock: SecretAccessKey, Region, SessionToken, ARN -// - Vertex: ProjectID, Region, AuthCredentials // // For unsupported providers, the meta config is returned unchanged. // This approach ensures type safety while supporting environment variable substitution. 
@@ -520,50 +542,6 @@ func (s *ConfigStore) processMetaConfigEnvVars(rawMetaConfig json.RawMessage, pr newEnvKeys := make(map[string]struct{}) switch provider { - case schemas.Azure: - var azureMetaConfig meta.AzureMetaConfig - if err := json.Unmarshal(rawMetaConfig, &azureMetaConfig); err != nil { - return nil, newEnvKeys, fmt.Errorf("failed to unmarshal Azure meta config: %w", err) - } - - endpoint, envVar, err := s.processEnvValue(azureMetaConfig.Endpoint) - if err != nil { - return nil, newEnvKeys, err - } - if envVar != "" { - newEnvKeys[envVar] = struct{}{} - s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ - EnvVar: envVar, - Provider: string(provider), - KeyType: "meta_config", - ConfigPath: fmt.Sprintf("providers.%s.meta_config.endpoint", provider), - }) - } - azureMetaConfig.Endpoint = endpoint - - if azureMetaConfig.APIVersion != nil { - apiVersion, envVar, err := s.processEnvValue(*azureMetaConfig.APIVersion) - if err != nil { - return nil, newEnvKeys, err - } - if envVar != "" { - newEnvKeys[envVar] = struct{}{} - s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ - EnvVar: envVar, - Provider: string(provider), - KeyType: "meta_config", - ConfigPath: fmt.Sprintf("providers.%s.meta_config.api_version", provider), - }) - } - azureMetaConfig.APIVersion = &apiVersion - } - - processedJSON, err := json.Marshal(azureMetaConfig) - if err != nil { - return nil, newEnvKeys, fmt.Errorf("failed to marshal processed Azure meta config: %w", err) - } - return processedJSON, newEnvKeys, nil - case schemas.Bedrock: var bedrockMetaConfig meta.BedrockMetaConfig if err := json.Unmarshal(rawMetaConfig, &bedrockMetaConfig); err != nil { @@ -581,6 +559,7 @@ func (s *ConfigStore) processMetaConfigEnvVars(rawMetaConfig json.RawMessage, pr Provider: string(provider), KeyType: "meta_config", ConfigPath: fmt.Sprintf("providers.%s.meta_config.secret_access_key", provider), + KeyID: "", // Empty for meta config entries }) } bedrockMetaConfig.SecretAccessKey = 
secretAccessKey @@ -597,6 +576,7 @@ func (s *ConfigStore) processMetaConfigEnvVars(rawMetaConfig json.RawMessage, pr Provider: string(provider), KeyType: "meta_config", ConfigPath: fmt.Sprintf("providers.%s.meta_config.region", provider), + KeyID: "", // Empty for meta config entries }) } bedrockMetaConfig.Region = ®ion @@ -614,6 +594,7 @@ func (s *ConfigStore) processMetaConfigEnvVars(rawMetaConfig json.RawMessage, pr Provider: string(provider), KeyType: "meta_config", ConfigPath: fmt.Sprintf("providers.%s.meta_config.session_token", provider), + KeyID: "", // Empty for meta config entries }) } bedrockMetaConfig.SessionToken = &sessionToken @@ -631,6 +612,7 @@ func (s *ConfigStore) processMetaConfigEnvVars(rawMetaConfig json.RawMessage, pr Provider: string(provider), KeyType: "meta_config", ConfigPath: fmt.Sprintf("providers.%s.meta_config.arn", provider), + KeyID: "", // Empty for meta config entries }) } bedrockMetaConfig.ARN = &arn @@ -641,63 +623,6 @@ func (s *ConfigStore) processMetaConfigEnvVars(rawMetaConfig json.RawMessage, pr return nil, newEnvKeys, fmt.Errorf("failed to marshal processed Bedrock meta config: %w", err) } return processedJSON, newEnvKeys, nil - - case schemas.Vertex: - var vertexMetaConfig meta.VertexMetaConfig - if err := json.Unmarshal(rawMetaConfig, &vertexMetaConfig); err != nil { - return nil, newEnvKeys, fmt.Errorf("failed to unmarshal Vertex meta config: %w", err) - } - - projectID, envVar, err := s.processEnvValue(vertexMetaConfig.ProjectID) - if err != nil { - return nil, newEnvKeys, err - } - if envVar != "" { - newEnvKeys[envVar] = struct{}{} - s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ - EnvVar: envVar, - Provider: string(provider), - KeyType: "meta_config", - ConfigPath: fmt.Sprintf("providers.%s.meta_config.project_id", provider), - }) - } - vertexMetaConfig.ProjectID = projectID - - region, envVar, err := s.processEnvValue(vertexMetaConfig.Region) - if err != nil { - return nil, newEnvKeys, err - } - if envVar 
!= "" { - newEnvKeys[envVar] = struct{}{} - s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ - EnvVar: envVar, - Provider: string(provider), - KeyType: "meta_config", - ConfigPath: fmt.Sprintf("providers.%s.meta_config.region", provider), - }) - } - vertexMetaConfig.Region = region - - authCredentials, envVar, err := s.processEnvValue(vertexMetaConfig.AuthCredentials) - if err != nil { - return nil, newEnvKeys, err - } - if envVar != "" { - newEnvKeys[envVar] = struct{}{} - s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ - EnvVar: envVar, - Provider: string(provider), - KeyType: "meta_config", - ConfigPath: fmt.Sprintf("providers.%s.meta_config.auth_credentials", provider), - }) - } - vertexMetaConfig.AuthCredentials = authCredentials - - processedJSON, err := json.Marshal(vertexMetaConfig) - if err != nil { - return nil, newEnvKeys, fmt.Errorf("failed to marshal processed Vertex meta config: %w", err) - } - return processedJSON, newEnvKeys, nil } return rawMetaConfig, newEnvKeys, nil @@ -763,19 +688,81 @@ func (s *ConfigStore) GetProviderConfigRedacted(provider schemas.ModelProvider) redactedConfig.Keys = make([]schemas.Key, len(config.Keys)) for i, key := range config.Keys { redactedConfig.Keys[i] = schemas.Key{ + ID: key.ID, Models: key.Models, // Copy slice reference - read-only so safe Weight: key.Weight, } - path := fmt.Sprintf("providers.%s.keys[%d]", provider, i) + // Redact API key value + path := fmt.Sprintf("providers.%s.keys[%s]", provider, key.ID) if envVar, ok := envVarsByPath[path]; ok { redactedConfig.Keys[i].Value = "env." 
+ envVar } else { redactedConfig.Keys[i].Value = RedactKey(key.Value) } + + // Redact Azure key config if present + if key.AzureKeyConfig != nil { + azureConfig := &schemas.AzureKeyConfig{ + Deployments: key.AzureKeyConfig.Deployments, + } + + // Redact Endpoint + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.endpoint", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + azureConfig.Endpoint = "env." + envVar + } else { + azureConfig.Endpoint = RedactKey(key.AzureKeyConfig.Endpoint) + } + + // Redact APIVersion if present + if key.AzureKeyConfig.APIVersion != nil { + path = fmt.Sprintf("providers.%s.keys[%s].azure_key_config.api_version", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + apiVersion := "env." + envVar + azureConfig.APIVersion = &apiVersion + } else { + // APIVersion is not sensitive, keep as-is + azureConfig.APIVersion = key.AzureKeyConfig.APIVersion + } + } + + redactedConfig.Keys[i].AzureKeyConfig = azureConfig + } + + // Redact Vertex key config if present + if key.VertexKeyConfig != nil { + vertexConfig := &schemas.VertexKeyConfig{} + + // Redact ProjectID + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.project_id", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.ProjectID = "env." + envVar + } else { + vertexConfig.ProjectID = RedactKey(key.VertexKeyConfig.ProjectID) + } + + // Region is not sensitive, handle env vars only + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.region", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.Region = "env." + envVar + } else { + vertexConfig.Region = key.VertexKeyConfig.Region + } + + // Redact AuthCredentials + path = fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.auth_credentials", provider, key.ID) + if envVar, ok := envVarsByPath[path]; ok { + vertexConfig.AuthCredentials = "env." 
+ envVar + } else { + vertexConfig.AuthCredentials = RedactKey(key.VertexKeyConfig.AuthCredentials) + } + + redactedConfig.Keys[i].VertexKeyConfig = vertexConfig + } } - // Handle meta config redaction if present + // Handle meta config redaction if present (Bedrock only) if config.MetaConfig != nil { redactedMetaConfig := s.redactMetaConfig(provider, *config.MetaConfig, envVarsByPath) redactedConfig.MetaConfig = &redactedMetaConfig @@ -785,25 +772,9 @@ func (s *ConfigStore) GetProviderConfigRedacted(provider schemas.ModelProvider) } // redactMetaConfig creates a redacted copy of meta config based on provider type +// Note: Only Bedrock is supported for meta config now, Azure and Vertex moved to key level func (s *ConfigStore) redactMetaConfig(provider schemas.ModelProvider, metaConfig schemas.MetaConfig, envVarsByPath map[string]string) schemas.MetaConfig { switch m := metaConfig.(type) { - case *meta.AzureMetaConfig: - azureConfig := *m // Copy the struct - path := fmt.Sprintf("providers.%s.meta_config.endpoint", provider) - if envVar, ok := envVarsByPath[path]; ok { - azureConfig.Endpoint = "env." + envVar - } else { - azureConfig.Endpoint = RedactKey(azureConfig.Endpoint) - } - if azureConfig.APIVersion != nil { - path = fmt.Sprintf("providers.%s.meta_config.api_version", provider) - if envVar, ok := envVarsByPath[path]; ok { - apiVersion := "env." + envVar - azureConfig.APIVersion = &apiVersion - } - } - return &azureConfig - case *meta.BedrockMetaConfig: bedrockConfig := *m // Copy the struct path := fmt.Sprintf("providers.%s.meta_config.secret_access_key", provider) @@ -838,24 +809,6 @@ func (s *ConfigStore) redactMetaConfig(provider schemas.ModelProvider, metaConfi } return &bedrockConfig - case *meta.VertexMetaConfig: - vertexConfig := *m // Copy the struct - path := fmt.Sprintf("providers.%s.meta_config.project_id", provider) - if envVar, ok := envVarsByPath[path]; ok { - vertexConfig.ProjectID = "env." 
+ envVar - } - path = fmt.Sprintf("providers.%s.meta_config.region", provider) - if envVar, ok := envVarsByPath[path]; ok { - vertexConfig.Region = "env." + envVar - } - path = fmt.Sprintf("providers.%s.meta_config.auth_credentials", provider) - if envVar, ok := envVarsByPath[path]; ok { - vertexConfig.AuthCredentials = "env." + envVar - } else { - vertexConfig.AuthCredentials = RedactKey(vertexConfig.AuthCredentials) - } - return &vertexConfig - default: return metaConfig } @@ -879,7 +832,7 @@ func (s *ConfigStore) GetAllProviders() ([]schemas.ModelProvider, error) { // // The method: // - Validates that the provider doesn't already exist -// - Processes environment variables in API keys and meta configurations +// - Processes environment variables in API keys, key-level configs, and meta configurations // - Stores the processed configuration in memory // - Updates metadata and timestamps func (s *ConfigStore) AddProvider(provider schemas.ModelProvider, config ProviderConfig) error { @@ -916,8 +869,13 @@ func (s *ConfigStore) AddProvider(provider schemas.ModelProvider, config Provide config.MetaConfig = metaConfig } - // Process environment variables in keys + // Process environment variables in keys (including key-level configs) for i, key := range config.Keys { + if key.ID == "" { + config.Keys[i].ID = uuid.NewString() + } + + // Process API key value processedValue, envVar, err := s.processEnvValue(key.Value) if err != nil { s.cleanupEnvKeys(string(provider), "", newEnvKeys) @@ -932,9 +890,26 @@ func (s *ConfigStore) AddProvider(provider schemas.ModelProvider, config Provide EnvVar: envVar, Provider: string(provider), KeyType: "api_key", - ConfigPath: fmt.Sprintf("providers.%s.keys[%d]", provider, i), + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, key.ID), + KeyID: key.ID, }) } + + // Process Azure key config if present + if key.AzureKeyConfig != nil { + if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil 
{ + s.cleanupEnvKeys(string(provider), "", newEnvKeys) + return fmt.Errorf("failed to process Azure key config env vars: %w", err) + } + } + + // Process Vertex key config if present + if key.VertexKeyConfig != nil { + if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(string(provider), "", newEnvKeys) + return fmt.Errorf("failed to process Vertex key config env vars: %w", err) + } + } } s.Providers[provider] = config @@ -948,25 +923,24 @@ func (s *ConfigStore) AddProvider(provider schemas.ModelProvider, config Provide // via the HTTP API and ensures all data processing is done upfront. // // The method: -// - Processes environment variables in API keys and meta configurations +// - Processes environment variables in API keys, key-level configs, and meta configurations // - Stores the processed configuration in memory // - Updates metadata and timestamps // - Thread-safe operation with write locks // +// Note: Environment variable cleanup for deleted/updated keys is now handled automatically +// by the mergeKeys function before this method is called. 
+// // Parameters: // - provider: The provider to update // - config: The new configuration -// - envKeysToReplace: Map of environment keys that should be replaced (only these will be cleaned up) -func (s *ConfigStore) UpdateProviderConfig(provider schemas.ModelProvider, config ProviderConfig, envKeysToReplace map[string]struct{}) error { +func (s *ConfigStore) UpdateProviderConfig(provider schemas.ModelProvider, config ProviderConfig) error { s.mu.Lock() defer s.mu.Unlock() // Track new environment variables being added newEnvKeys := make(map[string]struct{}) - // Track which old env vars will be replaced (only those specified in envKeysToReplace) - oldEnvKeys := make(map[string]struct{}) - // Process environment variables in meta config if present if config.MetaConfig != nil { rawMetaData, err := json.Marshal(*config.MetaConfig) @@ -974,17 +948,6 @@ func (s *ConfigStore) UpdateProviderConfig(provider schemas.ModelProvider, confi return fmt.Errorf("failed to marshal meta config: %w", err) } - // Find old meta config env vars that should be replaced - for envVar, infos := range s.EnvKeys { - for _, info := range infos { - if info.Provider == string(provider) && info.KeyType == "meta_config" { - if _, shouldReplace := envKeysToReplace[envVar]; shouldReplace { - oldEnvKeys[envVar] = struct{}{} - } - } - } - } - processedMetaData, envKeys, err := s.processMetaConfigEnvVars(rawMetaData, provider) if err != nil { s.cleanupEnvKeys(string(provider), "", envKeys) // Clean up only new vars on failure @@ -1004,19 +967,13 @@ func (s *ConfigStore) UpdateProviderConfig(provider schemas.ModelProvider, confi } } - // Find old API key env vars that should be replaced - for envVar, infos := range s.EnvKeys { - for _, info := range infos { - if info.Provider == string(provider) && info.KeyType == "api_key" { - if _, shouldReplace := envKeysToReplace[envVar]; shouldReplace { - oldEnvKeys[envVar] = struct{}{} - } - } + // Process environment variables in keys (including key-level 
configs) + for i, key := range config.Keys { + if key.ID == "" { + config.Keys[i].ID = uuid.NewString() } - } - // Process environment variables in keys - for i, key := range config.Keys { + // Process API key value processedValue, envVar, err := s.processEnvValue(key.Value) if err != nil { s.cleanupEnvKeys(string(provider), "", newEnvKeys) // Clean up only new vars on failure @@ -1031,16 +988,30 @@ func (s *ConfigStore) UpdateProviderConfig(provider schemas.ModelProvider, confi EnvVar: envVar, Provider: string(provider), KeyType: "api_key", - ConfigPath: fmt.Sprintf("providers.%s.keys[%d]", provider, i), + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, key.ID), + KeyID: key.ID, }) } + + // Process Azure key config if present + if key.AzureKeyConfig != nil { + if err := s.processAzureKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(string(provider), "", newEnvKeys) + return fmt.Errorf("failed to process Azure key config env vars: %w", err) + } + } + + // Process Vertex key config if present + if key.VertexKeyConfig != nil { + if err := s.processVertexKeyConfigEnvVars(&config.Keys[i], provider, i, newEnvKeys); err != nil { + s.cleanupEnvKeys(string(provider), "", newEnvKeys) + return fmt.Errorf("failed to process Vertex key config env vars: %w", err) + } + } } s.Providers[provider] = config - // Clean up old env vars that were replaced - s.cleanupEnvKeys(string(provider), "", oldEnvKeys) - s.logger.Info(fmt.Sprintf("Updated configuration for provider: %s", provider)) return nil } @@ -1089,6 +1060,7 @@ func (s *ConfigStore) processMCPEnvVars() error { Provider: "", KeyType: "connection_string", ConfigPath: fmt.Sprintf("mcp.client_configs[%d].connection_string", i), + KeyID: "", // Empty for MCP connection strings }) } s.MCPConfig.ClientConfigs[i].ConnectionString = &newValue @@ -1150,6 +1122,7 @@ func (s *ConfigStore) AddMCPClient(clientConfig schemas.MCPClientConfig) error { Provider: "", KeyType: 
"connection_string", ConfigPath: fmt.Sprintf("mcp.client_configs.%s.connection_string", clientConfig.Name), + KeyID: "", // Empty for MCP connection strings }) } s.MCPConfig.ClientConfigs[len(s.MCPConfig.ClientConfigs)-1].ConnectionString = &processedValue @@ -1365,6 +1338,87 @@ func (s *ConfigStore) cleanupEnvVar(envVar, provider, mcpClientName string) { } } +// CleanupEnvKeysForKeys removes environment variable entries for specific keys that are being deleted. +// This function targets key-specific environment variables based on key IDs. +// +// Parameters: +// - provider: Provider name the keys belong to +// - keysToDelete: List of keys being deleted (uses their IDs to identify env vars to clean up) +func (s *ConfigStore) CleanupEnvKeysForKeys(provider string, keysToDelete []schemas.Key) { + // Create a set of key IDs to delete for efficient lookup + keyIDsToDelete := make(map[string]bool) + for _, key := range keysToDelete { + keyIDsToDelete[key.ID] = true + } + + // Iterate through all environment variables and remove entries for deleted keys + for envVar, infos := range s.EnvKeys { + filteredInfos := make([]EnvKeyInfo, 0, len(infos)) + + for _, info := range infos { + // Keep entries that either: + // 1. Don't belong to this provider, OR + // 2. Don't have a KeyID (meta config, MCP), OR + // 3. Have a KeyID that's not being deleted + shouldKeep := info.Provider != provider || + info.KeyID == "" || + !keyIDsToDelete[info.KeyID] + + if shouldKeep { + filteredInfos = append(filteredInfos, info) + } + } + + // Update or delete the environment variable entry + if len(filteredInfos) == 0 { + delete(s.EnvKeys, envVar) + } else { + s.EnvKeys[envVar] = filteredInfos + } + } +} + +// CleanupEnvKeysForUpdatedKeys removes environment variable entries for keys that are being updated +// but whose environment variables are changing. This prevents stale env var references. 
+// +// Parameters: +// - provider: Provider name the keys belong to +// - keysToUpdate: List of keys being updated (uses their IDs to identify env vars to clean up) +func (s *ConfigStore) CleanupEnvKeysForUpdatedKeys(provider string, keysToUpdate []schemas.Key) { + // Create a set of key IDs to update for efficient lookup + keyIDsToUpdate := make(map[string]bool) + for _, key := range keysToUpdate { + keyIDsToUpdate[key.ID] = true + } + + // Iterate through all environment variables and remove entries for updated keys + // The updated keys will re-add their env vars during processing + for envVar, infos := range s.EnvKeys { + filteredInfos := make([]EnvKeyInfo, 0, len(infos)) + + for _, info := range infos { + // Keep entries that either: + // 1. Don't belong to this provider, OR + // 2. Don't have a KeyID (meta config, MCP), OR + // 3. Have a KeyID that's not being updated + shouldKeep := info.Provider != provider || + info.KeyID == "" || + !keyIDsToUpdate[info.KeyID] + + if shouldKeep { + filteredInfos = append(filteredInfos, info) + } + } + + // Update or delete the environment variable entry + if len(filteredInfos) == 0 { + delete(s.EnvKeys, envVar) + } else { + s.EnvKeys[envVar] = filteredInfos + } + } +} + // autoDetectProviders automatically detects common environment variables and sets up providers // when no configuration file exists. This enables zero-config startup when users have set // standard environment variables like OPENAI_API_KEY, ANTHROPIC_API_KEY, etc. 
@@ -1391,10 +1445,14 @@ func (s *ConfigStore) autoDetectProviders() { for provider, envVars := range providerEnvVars { for _, envVar := range envVars { if apiKey := os.Getenv(envVar); apiKey != "" { + // Generate a unique ID for the auto-detected key + keyID := uuid.NewString() + // Create default provider configuration providerConfig := ProviderConfig{ Keys: []schemas.Key{ { + ID: keyID, Value: apiKey, Models: []string{}, // Empty means all supported models Weight: 1.0, @@ -1411,7 +1469,8 @@ func (s *ConfigStore) autoDetectProviders() { EnvVar: envVar, Provider: string(provider), KeyType: "api_key", - ConfigPath: fmt.Sprintf("providers.%s.keys[0]", provider), + ConfigPath: fmt.Sprintf("providers.%s.keys[%s]", provider, keyID), + KeyID: keyID, }) s.logger.Info(fmt.Sprintf("Auto-detected %s provider from environment variable %s", provider, envVar)) @@ -1425,3 +1484,104 @@ func (s *ConfigStore) autoDetectProviders() { s.logger.Info(fmt.Sprintf("Auto-configured %d provider(s) from environment variables", detectedCount)) } } + +// processAzureKeyConfigEnvVars processes environment variables in Azure key configuration +func (s *ConfigStore) processAzureKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { + azureConfig := key.AzureKeyConfig + + // Process Endpoint + processedEndpoint, envVar, err := s.processEnvValue(azureConfig.Endpoint) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ + EnvVar: envVar, + Provider: string(provider), + KeyType: "azure_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.endpoint", provider, key.ID), + KeyID: key.ID, + }) + } + azureConfig.Endpoint = processedEndpoint + + // Process APIVersion if present + if azureConfig.APIVersion != nil { + processedAPIVersion, envVar, err := s.processEnvValue(*azureConfig.APIVersion) + if err != nil { + return err + 
} + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ + EnvVar: envVar, + Provider: string(provider), + KeyType: "azure_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].azure_key_config.api_version", provider, key.ID), + KeyID: key.ID, + }) + } + azureConfig.APIVersion = &processedAPIVersion + } + + return nil +} + +// processVertexKeyConfigEnvVars processes environment variables in Vertex key configuration +func (s *ConfigStore) processVertexKeyConfigEnvVars(key *schemas.Key, provider schemas.ModelProvider, keyIndex int, newEnvKeys map[string]struct{}) error { + vertexConfig := key.VertexKeyConfig + + // Process ProjectID + processedProjectID, envVar, err := s.processEnvValue(vertexConfig.ProjectID) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ + EnvVar: envVar, + Provider: string(provider), + KeyType: "vertex_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.project_id", provider, key.ID), + KeyID: key.ID, + }) + } + vertexConfig.ProjectID = processedProjectID + + // Process Region + processedRegion, envVar, err := s.processEnvValue(vertexConfig.Region) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ + EnvVar: envVar, + Provider: string(provider), + KeyType: "vertex_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.region", provider, key.ID), + KeyID: key.ID, + }) + } + vertexConfig.Region = processedRegion + + // Process AuthCredentials + processedAuthCredentials, envVar, err := s.processEnvValue(vertexConfig.AuthCredentials) + if err != nil { + return err + } + if envVar != "" { + newEnvKeys[envVar] = struct{}{} + s.EnvKeys[envVar] = append(s.EnvKeys[envVar], EnvKeyInfo{ + EnvVar: envVar, + Provider: string(provider), + KeyType: 
"vertex_config", + ConfigPath: fmt.Sprintf("providers.%s.keys[%s].vertex_key_config.auth_credentials", provider, key.ID), + KeyID: key.ID, + }) + } + vertexConfig.AuthCredentials = processedAuthCredentials + + return nil +} diff --git a/transports/bifrost-http/ui/404.html b/transports/bifrost-http/ui/404.html index a0fd813772..1a654e3737 100644 --- a/transports/bifrost-http/ui/404.html +++ b/transports/bifrost-http/ui/404.html @@ -1,4 +1,4 @@ -