From 7897425c800988ed3b6070caa8770db8140b16b2 Mon Sep 17 00:00:00 2001 From: Dan Piths <85949566+danpiths@users.noreply.github.com> Date: Wed, 18 Mar 2026 16:39:04 +0530 Subject: [PATCH] feat: add provider keys HTTP handlers and refactor optional keys --- .../bifrost-http/handlers/provider_keys.go | 416 ++++++++++++++++++ transports/bifrost-http/handlers/providers.go | 327 ++------------ .../bifrost-http/handlers/providers_test.go | 1 + transports/bifrost-http/server/server.go | 25 +- 4 files changed, 479 insertions(+), 290 deletions(-) create mode 100644 transports/bifrost-http/handlers/provider_keys.go create mode 100644 transports/bifrost-http/handlers/providers_test.go diff --git a/transports/bifrost-http/handlers/provider_keys.go b/transports/bifrost-http/handlers/provider_keys.go new file mode 100644 index 0000000000..70a86ebfcc --- /dev/null +++ b/transports/bifrost-http/handlers/provider_keys.go @@ -0,0 +1,416 @@ +package handlers + +import ( + "errors" + "fmt" + "net/url" + + "github.com/bytedance/sonic" + "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ListProviderKeysResponse represents the response for listing keys for a provider. +type ListProviderKeysResponse struct { + Keys []schemas.Key `json:"keys"` + Total int `json:"total"` +} + +func (h *ProviderHandler) listProviderKeys(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keys, err := h.inMemoryStore.GetProviderKeysRedacted(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider keys: %v", err)) + return + } + + SendJSON(ctx, ListProviderKeysResponse{Keys: keys, Total: len(keys)}) +} + +func (h *ProviderHandler) getProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keyID, err := getKeyIDFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + key, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + SendJSON(ctx, key) +} + +func (h *ProviderHandler) createProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + var key schemas.Key + if err := sonic.Unmarshal(ctx.PostBody(), &key); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) + return + } + + providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err)) + return + } + + if bifrost.IsKeylessProvider(provider) || (providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess) { + SendError(ctx, fasthttp.StatusBadRequest, "Cannot add keys to a keyless provider") + return + } + + if key.Value.GetValue() == "" { + SendError(ctx, fasthttp.StatusBadRequest, "Key value must not be empty") + return + } + + if key.ID == "" { + key.ID = uuid.NewString() + } + if key.Enabled == nil { + key.Enabled = bifrost.Ptr(true) + } + + if err := h.inMemoryStore.AddProviderKey(ctx, provider, key); err != nil { + logger.Warn("Failed to create key for provider %s: %v", provider, err) + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + if errors.Is(err, lib.ErrAlreadyExists) { + SendError(ctx, fasthttp.StatusConflict, err.Error()) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create provider key: %v", err)) + return + } + + if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil { + logger.Warn("Model discovery failed for provider %s after key create: %v", provider, err) + } + + redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, key.ID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get created provider key: %v", err)) + return + } + + SendJSON(ctx, redactedKey) +} + +func (h *ProviderHandler) updateProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keyID, err := getKeyIDFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + var updateKey schemas.Key + if err := sonic.Unmarshal(ctx.PostBody(), &updateKey); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) + return + } + + providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err)) + return + } + + if bifrost.IsKeylessProvider(provider) || (providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess) { + SendError(ctx, fasthttp.StatusBadRequest, "Cannot update keys on a keyless provider") + return + } + + oldRawKey, err := h.inMemoryStore.GetProviderKeyRaw(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + oldRedactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + updateKey.ID = keyID + mergedKey := h.mergeUpdatedKey(*oldRawKey, *oldRedactedKey, updateKey) + + if err := h.inMemoryStore.UpdateProviderKey(ctx, provider, keyID, mergedKey); err != nil { + logger.Warn("Failed to update key %s for provider %s: %v", keyID, provider, err) + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to update provider key: %v", err)) + return + } + + if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil { + logger.Warn("Model discovery failed for provider %s after key update: %v", provider, err) + } + + redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get updated provider key: %v", err)) + return + } + + SendJSON(ctx, redactedKey) +} + +func (h *ProviderHandler) deleteProviderKey(ctx *fasthttp.RequestCtx) { + provider, err := getProviderFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid provider: %v", err)) + return + } + + keyID, err := getKeyIDFromCtx(ctx) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return + } + + providerConfig, err := h.inMemoryStore.GetProviderConfigRaw(provider) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider config: %v", err)) + return + } + + if bifrost.IsKeylessProvider(provider) || (providerConfig.CustomProviderConfig != nil && providerConfig.CustomProviderConfig.IsKeyLess) { + SendError(ctx, fasthttp.StatusBadRequest, "Cannot delete keys on a keyless provider") + return + } + + redactedKey, err := h.inMemoryStore.GetProviderKeyRedacted(provider, keyID) + if err != nil { + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to get provider key: %v", err)) + return + } + + if err := h.inMemoryStore.RemoveProviderKey(ctx, provider, keyID); err != nil { + logger.Warn("Failed to delete key %s for provider %s: %v", keyID, provider, err) + if errors.Is(err, lib.ErrNotFound) { + SendError(ctx, fasthttp.StatusNotFound, fmt.Sprintf("Provider key not found: %v", err)) + return + } + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to delete provider key: %v", err)) + return + } + + if err := h.attemptModelDiscovery(ctx, provider, providerConfig.CustomProviderConfig); err != nil { + logger.Warn("Model discovery failed for provider %s after key delete: %v", provider, err) + } + + SendJSON(ctx, redactedKey) +} + +// mergeUpdatedKey merges an updated key with the old raw/redacted versions, +// preserving real values for fields that were sent back in redacted form. +func (h *ProviderHandler) mergeUpdatedKey(oldRawKey, oldRedactedKey, updateKey schemas.Key) schemas.Key { + mergedKey := updateKey + + if updateKey.Value.IsRedacted() && updateKey.Value.Equals(&oldRedactedKey.Value) { + mergedKey.Value = oldRawKey.Value + } + + if updateKey.AzureKeyConfig != nil && oldRedactedKey.AzureKeyConfig != nil && oldRawKey.AzureKeyConfig != nil { + if updateKey.AzureKeyConfig.Endpoint.IsRedacted() && + updateKey.AzureKeyConfig.Endpoint.Equals(&oldRedactedKey.AzureKeyConfig.Endpoint) { + mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint + } + if updateKey.AzureKeyConfig.APIVersion != nil && + oldRedactedKey.AzureKeyConfig.APIVersion != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.APIVersion.IsRedacted() && + updateKey.AzureKeyConfig.APIVersion.Equals(oldRedactedKey.AzureKeyConfig.APIVersion) { + mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion + } + if updateKey.AzureKeyConfig.ClientID != nil && + oldRedactedKey.AzureKeyConfig.ClientID != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.ClientID.IsRedacted() && + updateKey.AzureKeyConfig.ClientID.Equals(oldRedactedKey.AzureKeyConfig.ClientID) { + mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID + } + if updateKey.AzureKeyConfig.ClientSecret != nil && + oldRedactedKey.AzureKeyConfig.ClientSecret != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.ClientSecret.IsRedacted() && + updateKey.AzureKeyConfig.ClientSecret.Equals(oldRedactedKey.AzureKeyConfig.ClientSecret) { + mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret + } + if updateKey.AzureKeyConfig.TenantID != nil && + oldRedactedKey.AzureKeyConfig.TenantID != nil && + oldRawKey.AzureKeyConfig != nil && + updateKey.AzureKeyConfig.TenantID.IsRedacted() && + updateKey.AzureKeyConfig.TenantID.Equals(oldRedactedKey.AzureKeyConfig.TenantID) { + mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID + } + } + + if updateKey.VertexKeyConfig != nil && oldRedactedKey.VertexKeyConfig != nil && oldRawKey.VertexKeyConfig != nil { + if updateKey.VertexKeyConfig.ProjectID.IsRedacted() && + updateKey.VertexKeyConfig.ProjectID.Equals(&oldRedactedKey.VertexKeyConfig.ProjectID) { + mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID + } + if updateKey.VertexKeyConfig.ProjectNumber.IsRedacted() && + updateKey.VertexKeyConfig.ProjectNumber.Equals(&oldRedactedKey.VertexKeyConfig.ProjectNumber) { + mergedKey.VertexKeyConfig.ProjectNumber = oldRawKey.VertexKeyConfig.ProjectNumber + } + if updateKey.VertexKeyConfig.Region.IsRedacted() && + updateKey.VertexKeyConfig.Region.Equals(&oldRedactedKey.VertexKeyConfig.Region) { + mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region + } + if updateKey.VertexKeyConfig.AuthCredentials.IsRedacted() && + updateKey.VertexKeyConfig.AuthCredentials.Equals(&oldRedactedKey.VertexKeyConfig.AuthCredentials) { + mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials + } + } + + if updateKey.BedrockKeyConfig != nil && oldRedactedKey.BedrockKeyConfig != nil && oldRawKey.BedrockKeyConfig != nil { + if updateKey.BedrockKeyConfig.AccessKey.IsRedacted() && + updateKey.BedrockKeyConfig.AccessKey.Equals(&oldRedactedKey.BedrockKeyConfig.AccessKey) { + mergedKey.BedrockKeyConfig.AccessKey = oldRawKey.BedrockKeyConfig.AccessKey + } + if updateKey.BedrockKeyConfig.SecretKey.IsRedacted() && + updateKey.BedrockKeyConfig.SecretKey.Equals(&oldRedactedKey.BedrockKeyConfig.SecretKey) { + mergedKey.BedrockKeyConfig.SecretKey = oldRawKey.BedrockKeyConfig.SecretKey + } + if updateKey.BedrockKeyConfig.SessionToken != nil && + oldRedactedKey.BedrockKeyConfig.SessionToken != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.SessionToken.IsRedacted() && + updateKey.BedrockKeyConfig.SessionToken.Equals(oldRedactedKey.BedrockKeyConfig.SessionToken) { + mergedKey.BedrockKeyConfig.SessionToken = oldRawKey.BedrockKeyConfig.SessionToken + } + if updateKey.BedrockKeyConfig.Region != nil && + oldRedactedKey.BedrockKeyConfig.Region != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.Region.IsRedacted() && + updateKey.BedrockKeyConfig.Region.Equals(oldRedactedKey.BedrockKeyConfig.Region) { + mergedKey.BedrockKeyConfig.Region = oldRawKey.BedrockKeyConfig.Region + } + if updateKey.BedrockKeyConfig.ARN != nil && + oldRedactedKey.BedrockKeyConfig.ARN != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.ARN.IsRedacted() && + updateKey.BedrockKeyConfig.ARN.Equals(oldRedactedKey.BedrockKeyConfig.ARN) { + mergedKey.BedrockKeyConfig.ARN = oldRawKey.BedrockKeyConfig.ARN + } + if updateKey.BedrockKeyConfig.RoleARN != nil && + oldRedactedKey.BedrockKeyConfig.RoleARN != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.RoleARN.IsRedacted() && + updateKey.BedrockKeyConfig.RoleARN.Equals(oldRedactedKey.BedrockKeyConfig.RoleARN) { + mergedKey.BedrockKeyConfig.RoleARN = oldRawKey.BedrockKeyConfig.RoleARN + } + if updateKey.BedrockKeyConfig.ExternalID != nil && + oldRedactedKey.BedrockKeyConfig.ExternalID != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.ExternalID.IsRedacted() && + updateKey.BedrockKeyConfig.ExternalID.Equals(oldRedactedKey.BedrockKeyConfig.ExternalID) { + mergedKey.BedrockKeyConfig.ExternalID = oldRawKey.BedrockKeyConfig.ExternalID + } + if updateKey.BedrockKeyConfig.RoleSessionName != nil && + oldRedactedKey.BedrockKeyConfig.RoleSessionName != nil && + oldRawKey.BedrockKeyConfig != nil && + updateKey.BedrockKeyConfig.RoleSessionName.IsRedacted() && + updateKey.BedrockKeyConfig.RoleSessionName.Equals(oldRedactedKey.BedrockKeyConfig.RoleSessionName) { + mergedKey.BedrockKeyConfig.RoleSessionName = oldRawKey.BedrockKeyConfig.RoleSessionName + } + } + + if updateKey.VLLMKeyConfig != nil && oldRedactedKey.VLLMKeyConfig != nil && oldRawKey.VLLMKeyConfig != nil { + if updateKey.VLLMKeyConfig.URL.IsRedacted() && + updateKey.VLLMKeyConfig.URL.Equals(&oldRedactedKey.VLLMKeyConfig.URL) { + mergedKey.VLLMKeyConfig.URL = oldRawKey.VLLMKeyConfig.URL + } + } + + mergedKey.ConfigHash = oldRawKey.ConfigHash + mergedKey.Status = oldRawKey.Status + + return mergedKey +} + +func getKeyIDFromCtx(ctx *fasthttp.RequestCtx) (string, error) { + keyValue := ctx.UserValue("key_id") + if keyValue == nil { + return "", fmt.Errorf("missing key_id parameter") + } + + keyID, ok := keyValue.(string) + if !ok || keyID == "" { + return "", fmt.Errorf("invalid key_id parameter") + } + + decoded, err := url.PathUnescape(keyID) + if err != nil { + return "", fmt.Errorf("invalid key_id parameter encoding: %v", err) + } + + return decoded, nil +} diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 8b0649fa1d..e608813c7e 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -4,7 +4,6 @@ package handlers import ( "context" - "encoding/json" "errors" "fmt" "net/url" @@ -59,19 +58,18 @@ const ( // ProviderResponse represents the response for provider operations type ProviderResponse struct { - Name schemas.ModelProvider `json:"name"` - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings - ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings - ProxyConfig *schemas.ProxyConfig `json:"proxy_config"` // Proxy configuration - SendBackRawRequest bool `json:"send_back_raw_request"` // Include raw request in BifrostResponse - SendBackRawResponse bool `json:"send_back_raw_response"` // Include raw response in BifrostResponse - StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only - CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration - ProviderStatus ProviderStatus `json:"provider_status"` // Health/initialization status of the provider - Status string `json:"status,omitempty"` // Operational status (e.g., list_models_failed) - Description string `json:"description,omitempty"` // Error/status description - ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection + Name schemas.ModelProvider `json:"name"` + NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings + ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings + ProxyConfig *schemas.ProxyConfig `json:"proxy_config"` // Proxy configuration + SendBackRawRequest bool `json:"send_back_raw_request"` // Include raw request in BifrostResponse + SendBackRawResponse bool `json:"send_back_raw_response"` // Include raw response in BifrostResponse + StoreRawRequestResponse bool `json:"store_raw_request_response"` // Capture raw request/response for internal logging only + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration + ProviderStatus ProviderStatus `json:"provider_status"` // Health/initialization status of the provider + Status string `json:"status,omitempty"` // Operational status (e.g., list_models_failed) + Description string `json:"description,omitempty"` // Error/status description + ConfigHash string `json:"config_hash,omitempty"` // Hash of config.json version, used for change detection } // ListProvidersResponse represents the response for listing all providers @@ -86,14 +84,40 @@ type ErrorResponse struct { Message string `json:"message,omitempty"` } +type providerCreatePayload struct { + Provider schemas.ModelProvider `json:"provider"` + NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` + ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` + SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` + SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` + StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` +} + +type providerUpdatePayload struct { + NetworkConfig schemas.NetworkConfig `json:"network_config"` + ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` + SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` + SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` + StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` +} + // RegisterRoutes registers all provider management routes func (h *ProviderHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // Provider CRUD operations r.GET("/api/providers", lib.ChainMiddlewares(h.listProviders, middlewares...)) r.GET("/api/providers/{provider}", lib.ChainMiddlewares(h.getProvider, middlewares...)) + r.GET("/api/providers/{provider}/keys", lib.ChainMiddlewares(h.listProviderKeys, middlewares...)) + r.GET("/api/providers/{provider}/keys/{key_id}", lib.ChainMiddlewares(h.getProviderKey, middlewares...)) r.POST("/api/providers", lib.ChainMiddlewares(h.addProvider, middlewares...)) + r.POST("/api/providers/{provider}/keys", lib.ChainMiddlewares(h.createProviderKey, middlewares...)) r.PUT("/api/providers/{provider}", lib.ChainMiddlewares(h.updateProvider, middlewares...)) + r.PUT("/api/providers/{provider}/keys/{key_id}", lib.ChainMiddlewares(h.updateProviderKey, middlewares...)) r.DELETE("/api/providers/{provider}", lib.ChainMiddlewares(h.deleteProvider, middlewares...)) + r.DELETE("/api/providers/{provider}/keys/{key_id}", lib.ChainMiddlewares(h.deleteProviderKey, middlewares...)) r.GET("/api/keys", lib.ChainMiddlewares(h.listKeys, middlewares...)) r.GET("/api/models", lib.ChainMiddlewares(h.listModels, middlewares...)) r.GET("/api/models/parameters", lib.ChainMiddlewares(h.getModelParameters, middlewares...)) @@ -195,19 +219,8 @@ func (h *ProviderHandler) getProvider(ctx *fasthttp.RequestCtx) { // addProvider handles POST /api/providers - Add a new provider // NOTE: This only gets called when a new custom provider is added func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { - // Payload structure - var payload = struct { - Provider schemas.ModelProvider `json:"provider"` - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig *schemas.NetworkConfig `json:"network_config,omitempty"` // Network-related settings - ConcurrencyAndBufferSize *schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size,omitempty"` // Concurrency settings - ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` // Include raw request in BifrostResponse - SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse - StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` // Capture raw request/response for internal logging only - CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration - }{} - if err := json.Unmarshal(ctx.PostBody(), &payload); err != nil { + var payload providerCreatePayload + if err := sonic.Unmarshal(ctx.PostBody(), &payload); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) return } @@ -266,7 +279,6 @@ func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { // Construct ProviderConfig from individual fields config := configstore.ProviderConfig{ - Keys: payload.Keys, NetworkConfig: payload.NetworkConfig, ProxyConfig: payload.ProxyConfig, ConcurrencyAndBufferSize: payload.ConcurrencyAndBufferSize, @@ -293,9 +305,7 @@ func (h *ProviderHandler) addProvider(ctx *fasthttp.RequestCtx) { logger.Info("Provider %s added successfully", payload.Provider) // Attempt model discovery - err := h.attemptModelDiscovery(ctx, payload.Provider, payload.CustomProviderConfig) - - if err != nil { + if err := h.attemptModelDiscovery(ctx, payload.Provider, payload.CustomProviderConfig); err != nil { logger.Warn("Model discovery failed for provider %s: %v", payload.Provider, err) } @@ -337,16 +347,7 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { return } - var payload = struct { - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings - ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings - ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` // Include raw request in BifrostResponse - SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse - StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` // Capture raw request/response for internal logging only - CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration - }{} + var payload providerUpdatePayload if err := sonic.Unmarshal(ctx.PostBody(), &payload); err != nil { SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid JSON: %v", err)) @@ -367,20 +368,7 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { oldConfigRaw = &configstore.ProviderConfig{} } - oldConfigRedacted, err := h.inMemoryStore.GetProviderConfigRedacted(provider) - if err != nil { - if !errors.Is(err, lib.ErrNotFound) { - logger.Warn("Failed to get old redacted config for provider %s: %v", provider, err) - SendError(ctx, fasthttp.StatusInternalServerError, err.Error()) - return - } - } - - if oldConfigRedacted == nil { - oldConfigRedacted = &configstore.ProviderConfig{} - } - - // Construct ProviderConfig from individual fields + // Construct ProviderConfig from individual fields (keys are managed separately via /keys endpoints) config := configstore.ProviderConfig{ Keys: oldConfigRaw.Keys, NetworkConfig: oldConfigRaw.NetworkConfig, @@ -392,39 +380,6 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { Description: oldConfigRaw.Description, } - // Environment variable cleanup is now handled automatically by mergeKeys function - - var keysToAdd []schemas.Key - var keysToUpdate []schemas.Key - - for _, key := range payload.Keys { - if !slices.ContainsFunc(oldConfigRaw.Keys, func(k schemas.Key) bool { - return k.ID == key.ID - }) { - // By default new keys are enabled - key.Enabled = bifrost.Ptr(true) - keysToAdd = append(keysToAdd, key) - } else { - keysToUpdate = append(keysToUpdate, key) - } - } - - var keysToDelete []schemas.Key - for _, key := range oldConfigRaw.Keys { - if !slices.ContainsFunc(payload.Keys, func(k schemas.Key) bool { - return k.ID == key.ID - }) { - keysToDelete = append(keysToDelete, key) - } - } - - keys, err := h.mergeKeys(oldConfigRaw.Keys, oldConfigRedacted.Keys, keysToAdd, keysToDelete, keysToUpdate) - if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("Invalid keys: %v", err)) - return - } - config.Keys = keys - if payload.ConcurrencyAndBufferSize.Concurrency == 0 { SendError(ctx, fasthttp.StatusBadRequest, "Concurrency must be greater than 0") return @@ -851,201 +806,6 @@ func (h *ProviderHandler) listBaseModels(ctx *fasthttp.RequestCtx) { SendJSON(ctx, ListBaseModelsResponse{Models: baseModels, Total: total}) } -// mergeKeys merges new keys with old, preserving values that are redacted in the new config -func (h *ProviderHandler) mergeKeys(oldRawKeys []schemas.Key, oldRedactedKeys []schemas.Key, keysToAdd []schemas.Key, keysToDelete []schemas.Key, keysToUpdate []schemas.Key) ([]schemas.Key, error) { - // Create a map of indices to delete - toDelete := make(map[int]bool) - for _, key := range keysToDelete { - for i, oldKey := range oldRawKeys { - if oldKey.ID == key.ID { - toDelete[i] = true - break - } - } - } - - // Create a map of updates by ID for quick lookup - updates := make(map[string]schemas.Key) - for _, key := range keysToUpdate { - updates[key.ID] = key - } - - // Map old redacted keys by ID for reliable lookup - redactedByID := make(map[string]schemas.Key) - for _, rk := range oldRedactedKeys { - redactedByID[rk.ID] = rk - } - - // Process existing keys (handle updates and deletions) - var resultKeys []schemas.Key - for i, oldRawKey := range oldRawKeys { - // Skip if this key should be deleted - if toDelete[i] { - continue - } - // Check if this key should be updated - if updateKey, exists := updates[oldRawKey.ID]; exists { - oldRedactedKey, ok := redactedByID[oldRawKey.ID] - if !ok { - oldRedactedKey = schemas.Key{} - } - mergedKey := updateKey - - // Handle redacted values - preserve old value if new value is redacted/env var AND it's the same as old redacted value - if updateKey.Value.IsRedacted() && - updateKey.Value.Equals(&oldRedactedKey.Value) { - mergedKey.Value = oldRawKey.Value - } - - // Handle Azure config redacted values - if updateKey.AzureKeyConfig != nil && oldRedactedKey.AzureKeyConfig != nil && oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.Endpoint.IsRedacted() && - updateKey.AzureKeyConfig.Endpoint.Equals(&oldRedactedKey.AzureKeyConfig.Endpoint) { - mergedKey.AzureKeyConfig.Endpoint = oldRawKey.AzureKeyConfig.Endpoint - } - if updateKey.AzureKeyConfig.APIVersion != nil && - oldRedactedKey.AzureKeyConfig.APIVersion != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.APIVersion.IsRedacted() && - updateKey.AzureKeyConfig.APIVersion.Equals(oldRedactedKey.AzureKeyConfig.APIVersion) { - mergedKey.AzureKeyConfig.APIVersion = oldRawKey.AzureKeyConfig.APIVersion - } - } - // handle client id and secret and tenant id - if updateKey.AzureKeyConfig.ClientID != nil && - oldRedactedKey.AzureKeyConfig.ClientID != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.ClientID.IsRedacted() && - updateKey.AzureKeyConfig.ClientID.Equals(oldRedactedKey.AzureKeyConfig.ClientID) { - mergedKey.AzureKeyConfig.ClientID = oldRawKey.AzureKeyConfig.ClientID - } - } - if updateKey.AzureKeyConfig.ClientSecret != nil && - oldRedactedKey.AzureKeyConfig.ClientSecret != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.ClientSecret.IsRedacted() && - updateKey.AzureKeyConfig.ClientSecret.Equals(oldRedactedKey.AzureKeyConfig.ClientSecret) { - mergedKey.AzureKeyConfig.ClientSecret = oldRawKey.AzureKeyConfig.ClientSecret - } - } - if updateKey.AzureKeyConfig.TenantID != nil && - oldRedactedKey.AzureKeyConfig.TenantID != nil && - oldRawKey.AzureKeyConfig != nil { - if updateKey.AzureKeyConfig.TenantID.IsRedacted() && - updateKey.AzureKeyConfig.TenantID.Equals(oldRedactedKey.AzureKeyConfig.TenantID) { - mergedKey.AzureKeyConfig.TenantID = oldRawKey.AzureKeyConfig.TenantID - } - } - } - - // Handle Vertex config redacted values - if updateKey.VertexKeyConfig != nil && oldRedactedKey.VertexKeyConfig != nil && oldRawKey.VertexKeyConfig != nil { - if updateKey.VertexKeyConfig.ProjectID.IsRedacted() && - updateKey.VertexKeyConfig.ProjectID.Equals(&oldRedactedKey.VertexKeyConfig.ProjectID) { - mergedKey.VertexKeyConfig.ProjectID = oldRawKey.VertexKeyConfig.ProjectID - } - if updateKey.VertexKeyConfig.ProjectNumber.IsRedacted() && - updateKey.VertexKeyConfig.ProjectNumber.Equals(&oldRedactedKey.VertexKeyConfig.ProjectNumber) { - mergedKey.VertexKeyConfig.ProjectNumber = oldRawKey.VertexKeyConfig.ProjectNumber - } - if updateKey.VertexKeyConfig.Region.IsRedacted() && - updateKey.VertexKeyConfig.Region.Equals(&oldRedactedKey.VertexKeyConfig.Region) { - mergedKey.VertexKeyConfig.Region = oldRawKey.VertexKeyConfig.Region - } - if updateKey.VertexKeyConfig.AuthCredentials.IsRedacted() && - updateKey.VertexKeyConfig.AuthCredentials.Equals(&oldRedactedKey.VertexKeyConfig.AuthCredentials) { - mergedKey.VertexKeyConfig.AuthCredentials = oldRawKey.VertexKeyConfig.AuthCredentials - } - } - - // Handle Bedrock config redacted values - if updateKey.BedrockKeyConfig != nil && oldRedactedKey.BedrockKeyConfig != nil && oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.AccessKey.IsRedacted() && - updateKey.BedrockKeyConfig.AccessKey.Equals(&oldRedactedKey.BedrockKeyConfig.AccessKey) { - mergedKey.BedrockKeyConfig.AccessKey = oldRawKey.BedrockKeyConfig.AccessKey - } - if updateKey.BedrockKeyConfig.SecretKey.IsRedacted() && - updateKey.BedrockKeyConfig.SecretKey.Equals(&oldRedactedKey.BedrockKeyConfig.SecretKey) { - mergedKey.BedrockKeyConfig.SecretKey = oldRawKey.BedrockKeyConfig.SecretKey - } - if updateKey.BedrockKeyConfig.SessionToken != nil && - oldRedactedKey.BedrockKeyConfig.SessionToken != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.SessionToken.IsRedacted() && - updateKey.BedrockKeyConfig.SessionToken.Equals(oldRedactedKey.BedrockKeyConfig.SessionToken) { - mergedKey.BedrockKeyConfig.SessionToken = oldRawKey.BedrockKeyConfig.SessionToken - } - } - if updateKey.BedrockKeyConfig.Region != nil && - oldRedactedKey.BedrockKeyConfig.Region != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.Region.IsRedacted() && - updateKey.BedrockKeyConfig.Region.Equals(oldRedactedKey.BedrockKeyConfig.Region) { - mergedKey.BedrockKeyConfig.Region = oldRawKey.BedrockKeyConfig.Region - } - } - if updateKey.BedrockKeyConfig.ARN != nil && - oldRedactedKey.BedrockKeyConfig.ARN != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.ARN.IsRedacted() && - updateKey.BedrockKeyConfig.ARN.Equals(oldRedactedKey.BedrockKeyConfig.ARN) { - mergedKey.BedrockKeyConfig.ARN = oldRawKey.BedrockKeyConfig.ARN - } - } - if updateKey.BedrockKeyConfig.RoleARN != nil && - oldRedactedKey.BedrockKeyConfig.RoleARN != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.RoleARN.IsRedacted() && - updateKey.BedrockKeyConfig.RoleARN.Equals(oldRedactedKey.BedrockKeyConfig.RoleARN) { - mergedKey.BedrockKeyConfig.RoleARN = oldRawKey.BedrockKeyConfig.RoleARN - } - } - if updateKey.BedrockKeyConfig.ExternalID != nil && - oldRedactedKey.BedrockKeyConfig.ExternalID != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.ExternalID.IsRedacted() && - updateKey.BedrockKeyConfig.ExternalID.Equals(oldRedactedKey.BedrockKeyConfig.ExternalID) { - mergedKey.BedrockKeyConfig.ExternalID = oldRawKey.BedrockKeyConfig.ExternalID - } - } - if updateKey.BedrockKeyConfig.RoleSessionName != nil && - oldRedactedKey.BedrockKeyConfig.RoleSessionName != nil && - oldRawKey.BedrockKeyConfig != nil { - if updateKey.BedrockKeyConfig.RoleSessionName.IsRedacted() && - updateKey.BedrockKeyConfig.RoleSessionName.Equals(oldRedactedKey.BedrockKeyConfig.RoleSessionName) { - mergedKey.BedrockKeyConfig.RoleSessionName = oldRawKey.BedrockKeyConfig.RoleSessionName - } - } - } - - // Handle VLLM config redacted values - if updateKey.VLLMKeyConfig != nil && oldRedactedKey.VLLMKeyConfig != nil && oldRawKey.VLLMKeyConfig != nil { - if updateKey.VLLMKeyConfig.URL.IsRedacted() && - updateKey.VLLMKeyConfig.URL.Equals(&oldRedactedKey.VLLMKeyConfig.URL) { - mergedKey.VLLMKeyConfig.URL = oldRawKey.VLLMKeyConfig.URL - } - } - - // Preserve ConfigHash from old key (UI doesn't send it back) - mergedKey.ConfigHash = oldRawKey.ConfigHash - - // Preserve Status and Description from old key (UI doesn't send them back, they're updated by model discovery) - mergedKey.Status = oldRawKey.Status - mergedKey.Description = oldRawKey.Description - - resultKeys = append(resultKeys, mergedKey) - } else { - // Keep unchanged key - resultKeys = append(resultKeys, oldRawKey) - } - } - - // Add new keys - resultKeys = append(resultKeys, keysToAdd...) - - return resultKeys, nil -} - // attemptModelDiscovery performs model discovery with timeout func (h *ProviderHandler) attemptModelDiscovery(ctx *fasthttp.RequestCtx, provider schemas.ModelProvider, customProviderConfig *schemas.CustomProviderConfig) error { // Determine if we should attempt model discovery @@ -1079,7 +839,6 @@ func (h *ProviderHandler) getProviderResponseFromConfig(provider schemas.ModelPr return ProviderResponse{ Name: provider, - Keys: config.Keys, NetworkConfig: *config.NetworkConfig, ConcurrencyAndBufferSize: *config.ConcurrencyAndBufferSize, ProxyConfig: config.ProxyConfig, diff --git a/transports/bifrost-http/handlers/providers_test.go b/transports/bifrost-http/handlers/providers_test.go new file mode 100644 index 0000000000..5ac8282f4b --- /dev/null +++ b/transports/bifrost-http/handlers/providers_test.go @@ -0,0 +1 @@ +package handlers diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 1e710cf9ac..c17deef8c4 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -505,10 +505,10 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas } } - bfCtx := schemas.NewBifrostContext(ctx, time.Now().Add(15*time.Second)) - bfCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) - bfCtx.SetValue(schemas.BifrostContextKeyValidateKeys, true) // Validate keys during provider add/update - defer bfCtx.Cancel() + // Read current key count from in-memory store (providerInfo.Keys is not preloaded from DB) + inMemoryKeys, _ := s.Config.GetProviderKeysRaw(provider) + isKeylessProvider := bifrost.IsKeylessProvider(provider) || (providerInfo.CustomProviderConfig != nil && providerInfo.CustomProviderConfig.IsKeyLess) + hasNoKeys := len(inMemoryKeys) == 0 && !isKeylessProvider // Getting allowed models from all provider keys (needed before model listing) providerKeys, err := s.Config.ConfigStore.GetKeysByProvider(ctx, string(provider)) @@ -516,6 +516,11 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas return nil, fmt.Errorf("failed to update provider model catalog: failed to get keys by provider: %s", err) } + bfCtx := schemas.NewBifrostContext(ctx, time.Now().Add(15*time.Second)) + bfCtx.SetValue(schemas.BifrostContextKeySkipPluginPipeline, true) + bfCtx.SetValue(schemas.BifrostContextKeyValidateKeys, true) // Validate keys during provider add/update + defer bfCtx.Cancel() + // Run filtered and unfiltered model listing concurrently var ( allModels *schemas.BifrostListModelsResponse @@ -548,7 +553,11 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas s.updateKeyStatus(ctx, bifrostErr.ExtraFields.KeyStatuses) } - logger.Warn("failed to update provider model catalog: failed to list all models: %s. We are falling back onto the static datasheet", bifrost.GetErrorMessage(bifrostErr)) + if hasNoKeys { + logger.Warn("model discovery skipped for provider %s: no keys configured", provider) + } else { + logger.Warn("failed to update provider model catalog: failed to list all models: %s. We are falling back onto the static datasheet", bifrost.GetErrorMessage(bifrostErr)) + } // In case of error, we return an empty list of models, and fallback onto the static datasheet allModels = &schemas.BifrostListModelsResponse{ Data: make([]schemas.Model, 0), @@ -568,7 +577,11 @@ func (s *BifrostHTTPServer) ReloadProvider(ctx context.Context, provider schemas } s.Config.ModelCatalog.UpsertModelDataForProvider(provider, allModels, modelsInKeys) if listModelsErr != nil { - logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + if hasNoKeys { + logger.Warn("unfiltered model discovery skipped for provider %s: no keys configured", provider) + } else { + logger.Error("failed to list unfiltered models for provider %s: %v: falling back onto the static datasheet", provider, bifrost.GetErrorMessage(listModelsErr)) + } } else { s.Config.ModelCatalog.UpsertUnfilteredModelDataForProvider(provider, unfilteredModels) }