diff --git a/flake.lock b/flake.lock index 089c1fdaa1..6eef0dc1c4 100644 --- a/flake.lock +++ b/flake.lock @@ -2,11 +2,11 @@ "nodes": { "nixpkgs": { "locked": { - "lastModified": 1773144721, - "narHash": "sha256-1fa382ppXYOqqFIECQ3A1qogn/QLwNFvpjx/WivuNBc=", + "lastModified": 1776062742, + "narHash": "sha256-CYncVXVsUzYK+JZldSuK08ibXrAIJh+T22V13Z4ySS0=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "fb30d84f085815771af9decacb4b41b841798601", + "rev": "1c742e001e98f5191a5586751e16311fe1481f61", "type": "github" }, "original": { diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 8972689637..6603d04db7 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -670,10 +670,23 @@ func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvider // // Explicit allowedModels without prefix // mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"}) // // Returns: true (direct match) -func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels []string) bool { - // Case 1: Empty allowedModels = use catalog to determine support +func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels []string) bool { + isCustomProvider := false + hasListModelsEndpointDisabled := false + if providerConfig != nil && providerConfig.CustomProviderConfig != nil { + isCustomProvider = true + hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest) + } + + // Case 1: Unrestricted allowedModels (empty or ["*"]) = use catalog to determine support // This leverages GetProvidersForModel which already handles all cross-provider logic - if len(allowedModels) == 0 { + isUnrestricted := len(allowedModels) == 0 || (len(allowedModels) == 1 && allowedModels[0] == "*") + if isUnrestricted { + // Custom providers without a 
list-models endpoint can't be in the catalog, + // so allow any model through rather than blocking on missing catalog data + if isCustomProvider && hasListModelsEndpointDisabled { + return true + } supportedProviders := mc.GetProvidersForModel(model) return slices.Contains(supportedProviders, provider) } diff --git a/framework/modelcatalog/main_test.go b/framework/modelcatalog/main_test.go index 3b7e67e702..324b28c791 100644 --- a/framework/modelcatalog/main_test.go +++ b/framework/modelcatalog/main_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/stretchr/testify/assert" ) @@ -154,5 +155,56 @@ func TestIsModelAllowedForProvider_PrefixedAllowedModelInCatalog(t *testing.T) { nil, ) - assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", []string{"openai/gpt-4o"})) + providerConfig := configstore.ProviderConfig{} + + assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", &providerConfig, []string{"openai/gpt-4o"})) +} + +func TestIsModelAllowedForProvider_CustomProviderListModelsDisabled(t *testing.T) { + mc := newTestCatalog(nil, nil) + + // Custom provider with list-models disabled + ["*"] → should return true + providerConfig := configstore.ProviderConfig{ + CustomProviderConfig: &schemas.CustomProviderConfig{ + AllowedRequests: &schemas.AllowedRequests{ + ListModels: false, + }, + }, + } + assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "any-model", &providerConfig, []string{"*"})) +} + +func TestIsModelAllowedForProvider_CustomProviderListModelsEnabled(t *testing.T) { + mc := newTestCatalog( + map[schemas.ModelProvider][]string{ + "custom-provider": {"model-a"}, + }, + nil, + ) + + // Custom provider with list-models enabled + ["*"] → should go through catalog + providerConfig := configstore.ProviderConfig{ + CustomProviderConfig: 
&schemas.CustomProviderConfig{ + AllowedRequests: &schemas.AllowedRequests{ + ListModels: true, + }, + }, + } + // model-a is in catalog → allowed + assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "model-a", &providerConfig, []string{"*"})) + // model-b is NOT in catalog → denied + assert.False(t, mc.IsModelAllowedForProvider("custom-provider", "model-b", &providerConfig, []string{"*"})) +} + +func TestIsModelAllowedForProvider_NilProviderConfig(t *testing.T) { + mc := newTestCatalog( + map[schemas.ModelProvider][]string{ + "some-provider": {"model-x"}, + }, + nil, + ) + + // nil providerConfig + ["*"] → should go through catalog (not bypass) + assert.True(t, mc.IsModelAllowedForProvider("some-provider", "model-x", nil, []string{"*"})) + assert.False(t, mc.IsModelAllowedForProvider("some-provider", "model-y", nil, []string{"*"})) } diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 0098de2fe5..89a645b905 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -161,7 +161,7 @@ func Init( } // Initialize components in dependency order with fixed, optimal settings // Resolver (pure decision engine for hierarchical governance, depends only on store) - resolver := NewBudgetResolver(governanceStore, modelCatalog, logger) + resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore) // 3. 
Tracker (business logic owner, depends on store and resolver) tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) @@ -263,7 +263,7 @@ func InitFromStore( isVkMandatory = config.IsVkMandatory requiredHeaders = config.RequiredHeaders } - resolver := NewBudgetResolver(governanceStore, modelCatalog, logger) + resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore) tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger) engine, err := NewRoutingEngine(governanceStore, logger) if err != nil { @@ -576,8 +576,10 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req // This handles all cross-provider logic (OpenRouter, Vertex, Groq, Bedrock) // and provider-prefixed allowed_models entries isProviderAllowed := false - if p.modelCatalog != nil { - isProviderAllowed = p.modelCatalog.IsModelAllowedForProvider(schemas.ModelProvider(config.Provider), modelStr, config.AllowedModels) + if p.modelCatalog != nil && p.inMemoryStore != nil { + provider := schemas.ModelProvider(config.Provider) + providerConfig := p.inMemoryStore.GetConfiguredProviders()[provider] + isProviderAllowed = p.modelCatalog.IsModelAllowedForProvider(provider, modelStr, &providerConfig, config.AllowedModels) } else { // Fallback when model catalog is not available: simple string matching if len(config.AllowedModels) == 0 { diff --git a/plugins/governance/model_provider_governance_test.go b/plugins/governance/model_provider_governance_test.go index a17855552a..fb8d188f66 100644 --- a/plugins/governance/model_provider_governance_test.go +++ b/plugins/governance/model_provider_governance_test.go @@ -793,7 +793,7 @@ func TestResolver_EvaluateModelAndProviderRequest_NoConfigs(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := 
NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -810,7 +810,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderBudgetExceeded(t *test }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -828,7 +828,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderRateLimitExceeded(t *t }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -846,7 +846,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelBudgetExceeded(t *testing }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -864,7 +864,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelRateLimitExceeded(t *test }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -882,7 +882,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelRateLimitExceeded_Request }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := 
NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -905,7 +905,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderBudgetThenModelBudget( }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -929,7 +929,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderRateLimitThenModelRate }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -953,7 +953,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderRateLimitThenModelRate }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -980,7 +980,7 @@ func TestResolver_EvaluateModelAndProviderRequest_AllChecksPass(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4") @@ -998,7 +998,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderOnly_NoModel(t *testin }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := 
NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) // No model provided @@ -1016,7 +1016,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelOnly_NoProvider(t *testin }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) // No provider provided @@ -1036,7 +1036,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderSpecificBudget_Differe }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) // Request with Azure (different provider) for same model should pass @@ -1056,7 +1056,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderSpecificRateLimit_Diff }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) // Request with Azure (different provider) for same model should pass @@ -1076,7 +1076,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderSpecificRateLimit_Diff }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) // Request with Azure (different provider) for same model should pass diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index 8d3da777af..bccca93a67 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -62,17 +62,19 @@ type UsageInfo struct { // BudgetResolver provides decision logic for the new hierarchical governance system type BudgetResolver struct { - 
store GovernanceStore - logger schemas.Logger - modelCatalog *modelcatalog.ModelCatalog + store GovernanceStore + logger schemas.Logger + modelCatalog *modelcatalog.ModelCatalog + governanceInMemoryStore InMemoryStore } // NewBudgetResolver creates a new budget-based governance resolver -func NewBudgetResolver(store GovernanceStore, modelCatalog *modelcatalog.ModelCatalog, logger schemas.Logger) *BudgetResolver { +func NewBudgetResolver(store GovernanceStore, modelCatalog *modelcatalog.ModelCatalog, logger schemas.Logger, governanceInMemoryStore InMemoryStore) *BudgetResolver { return &BudgetResolver{ - store: store, - logger: logger, - modelCatalog: modelCatalog, + store: store, + logger: logger, + modelCatalog: modelCatalog, + governanceInMemoryStore: governanceInMemoryStore, } } @@ -334,8 +336,9 @@ func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, p // Delegate model allowance check to model catalog // This handles all cross-provider logic (OpenRouter, Vertex, Groq, Bedrock) // and provider-prefixed allowed_models entries - if r.modelCatalog != nil { - return r.modelCatalog.IsModelAllowedForProvider(provider, model, pc.AllowedModels) + if r.modelCatalog != nil && r.governanceInMemoryStore != nil { + providerConfig := r.governanceInMemoryStore.GetConfiguredProviders()[provider] + return r.modelCatalog.IsModelAllowedForProvider(provider, model, &providerConfig, pc.AllowedModels) } // Fallback when model catalog is not available: simple string matching if len(pc.AllowedModels) == 0 { diff --git a/plugins/governance/resolver_test.go b/plugins/governance/resolver_test.go index 7e55c57328..5dc1d3e0d7 100644 --- a/plugins/governance/resolver_test.go +++ b/plugins/governance/resolver_test.go @@ -23,7 +23,7 @@ func TestBudgetResolver_EvaluateRequest_AllowedRequest(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := 
&schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -38,7 +38,7 @@ func TestBudgetResolver_EvaluateRequest_VirtualKeyNotFound(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-nonexistent", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -56,7 +56,7 @@ func TestBudgetResolver_EvaluateRequest_VirtualKeyBlocked(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -79,7 +79,7 @@ func TestBudgetResolver_EvaluateRequest_ProviderBlocked(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} // Try to use OpenAI (not allowed) @@ -111,7 +111,7 @@ func TestBudgetResolver_EvaluateRequest_ModelBlocked(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} // Try to use gpt-4o-mini (not in allowed list) @@ -134,7 +134,7 @@ func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_TokenLimit(t *testing. 
}, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -157,7 +157,7 @@ func TestBudgetResolver_EvaluateRequest_RateLimitExceeded_RequestLimit(t *testin }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -195,7 +195,7 @@ func TestBudgetResolver_EvaluateRequest_RateLimitExpired(t *testing.T) { err = store.ResetExpiredRateLimits(context.Background(), expiredRateLimits) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -217,7 +217,7 @@ func TestBudgetResolver_EvaluateRequest_BudgetExceeded(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -244,7 +244,7 @@ func TestBudgetResolver_EvaluateRequest_BudgetExpired(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -278,7 +278,7 @@ func 
TestBudgetResolver_EvaluateRequest_MultiLevelBudgetHierarchy(t *testing.T) }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} // Test: All under limit should pass @@ -312,7 +312,7 @@ func TestBudgetResolver_EvaluateRequest_ProviderLevelRateLimit(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -335,7 +335,7 @@ func TestBudgetResolver_CheckRateLimits_BothExceeded(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := &schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) @@ -350,7 +350,7 @@ func TestBudgetResolver_IsProviderAllowed(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) tests := []struct { name string @@ -398,7 +398,7 @@ func TestBudgetResolver_IsModelAllowed(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) tests := []struct { name string @@ -473,7 +473,7 @@ func TestBudgetResolver_ContextPopulation(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) ctx := 
&schemas.BifrostContext{} result := resolver.EvaluateVirtualKeyRequest(ctx, "sk-bf-test", schemas.OpenAI, "gpt-4", schemas.ChatCompletionRequest) diff --git a/plugins/governance/tracker_test.go b/plugins/governance/tracker_test.go index 7102a8e06d..6af947d0fa 100644 --- a/plugins/governance/tracker_test.go +++ b/plugins/governance/tracker_test.go @@ -25,7 +25,7 @@ func TestUsageTracker_UpdateUsage_FailedRequest(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) defer tracker.Cleanup() @@ -60,7 +60,7 @@ func TestUsageTracker_UpdateUsage_VirtualKeyNotFound(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) defer tracker.Cleanup() @@ -94,7 +94,7 @@ func TestUsageTracker_UpdateUsage_StreamingOptimization(t *testing.T) { }, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) defer tracker.Cleanup() @@ -157,7 +157,7 @@ func TestUsageTracker_Cleanup(t *testing.T) { store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil) require.NoError(t, err) - resolver := NewBudgetResolver(store, nil, logger) + resolver := NewBudgetResolver(store, nil, logger, nil) tracker := NewUsageTracker(context.Background(), store, resolver, nil, logger) // Should cleanup without error diff --git a/plugins/governance/utils.go b/plugins/governance/utils.go index f33abb3aa1..3ba0849f6d 100644 --- 
a/plugins/governance/utils.go +++ b/plugins/governance/utils.go @@ -2,6 +2,7 @@ package governance import ( + "slices" "strings" bifrost "github.com/maximhq/bifrost/core" @@ -113,9 +114,17 @@ func (p *GovernancePlugin) filterModelsForVirtualKey( isAllowed := false for _, pc := range vk.ProviderConfigs { if pc.Provider == string(provider) { - if p.modelCatalog.IsModelAllowedForProvider(provider, modelName, pc.AllowedModels) { - isAllowed = true - break + if p.modelCatalog != nil && p.inMemoryStore != nil { + providerConfig := p.inMemoryStore.GetConfiguredProviders()[provider] + if p.modelCatalog.IsModelAllowedForProvider(provider, modelName, &providerConfig, pc.AllowedModels) { + isAllowed = true + break + } + } else { + if len(pc.AllowedModels) == 0 || slices.Contains(pc.AllowedModels, modelName) { + isAllowed = true + break + } } } } diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 1696c9530a..b1be3d9ba9 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -357,16 +357,16 @@ func (h *ProviderHandler) updateProvider(ctx *fasthttp.RequestCtx) { } var payload = struct { - Keys []schemas.Key `json:"keys"` // API keys for the provider - NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings - ConcurrencyAndBufferSize schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings - ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration - SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` // Include raw request in BifrostResponse - SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse + Keys []schemas.Key `json:"keys"` // API keys for the provider + NetworkConfig schemas.NetworkConfig `json:"network_config"` // Network-related settings + ConcurrencyAndBufferSize 
schemas.ConcurrencyAndBufferSize `json:"concurrency_and_buffer_size"` // Concurrency settings + ProxyConfig *schemas.ProxyConfig `json:"proxy_config,omitempty"` // Proxy configuration + SendBackRawRequest *bool `json:"send_back_raw_request,omitempty"` // Include raw request in BifrostResponse + SendBackRawResponse *bool `json:"send_back_raw_response,omitempty"` // Include raw response in BifrostResponse StoreRawRequestResponse *bool `json:"store_raw_request_response,omitempty"` // Capture raw request/response for internal logging only - CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration - OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration - PricingOverrides []schemas.ProviderPricingOverride `json:"pricing_overrides,omitempty"` // Provider-level pricing overrides + CustomProviderConfig *schemas.CustomProviderConfig `json:"custom_provider_config,omitempty"` // Custom provider configuration + OpenAIConfig *schemas.OpenAIConfig `json:"openai_config,omitempty"` // OpenAI-specific configuration + PricingOverrides []schemas.ProviderPricingOverride `json:"pricing_overrides,omitempty"` // Provider-level pricing overrides }{} if err := sonic.Unmarshal(ctx.PostBody(), &payload); err != nil { @@ -891,19 +891,19 @@ func (h *ProviderHandler) getModelParameters(ctx *fasthttp.RequestCtx) { } // keyAllowsModelForList reports whether a provider key permits model for catalog listing. 
-func keyAllowsModelForList(provider schemas.ModelProvider, model string, key schemas.Key, modelCatalog *modelcatalog.ModelCatalog) bool { - if len(key.BlacklistedModels) > 0 && keyModelListAllowsModel(provider, model, key.BlacklistedModels, modelCatalog) { +func keyAllowsModelForList(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, key schemas.Key, modelCatalog *modelcatalog.ModelCatalog) bool { + if len(key.BlacklistedModels) > 0 && keyModelListAllowsModel(provider, model, providerConfig, key.BlacklistedModels, modelCatalog) { return false } if len(key.Models) > 0 { - return keyModelListAllowsModel(provider, model, key.Models, modelCatalog) + return keyModelListAllowsModel(provider, model, providerConfig, key.Models, modelCatalog) } return true } // keyModelListAllowsModel reports whether model matches a key allow/deny list entry, // using catalog-aware alias matching when model metadata is available. -func keyModelListAllowsModel(provider schemas.ModelProvider, model string, allowedModels []string, modelCatalog *modelcatalog.ModelCatalog) bool { +func keyModelListAllowsModel(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels []string, modelCatalog *modelcatalog.ModelCatalog) bool { if len(allowedModels) == 0 { return false } @@ -912,7 +912,7 @@ func keyModelListAllowsModel(provider schemas.ModelProvider, model string, allow return slices.Contains(allowedModels, model) } - if modelCatalog.IsModelAllowedForProvider(provider, model, allowedModels) { + if modelCatalog.IsModelAllowedForProvider(provider, model, providerConfig, allowedModels) { return true } @@ -1010,7 +1010,7 @@ func filterModelsByKeysWithAccessMap(config *configstore.ProviderConfig, provide for _, model := range models { grantedBy := make([]string, 0, len(matchedKeys)) for _, matched := range matchedKeys { - if keyAllowsModelForList(provider, model, matched.key, modelCatalog) { + if 
keyAllowsModelForList(provider, model, config, matched.key, modelCatalog) { grantedBy = append(grantedBy, matched.id) } } @@ -1391,8 +1391,8 @@ func validatePricingOverrideNonNegativeFields(index int, override schemas.Provid "input_cost_per_token_above_200k_tokens": override.InputCostPerTokenAbove200kTokens, "output_cost_per_token_above_200k_tokens": override.OutputCostPerTokenAbove200kTokens, "cache_creation_input_token_cost_above_200k_tokens": override.CacheCreationInputTokenCostAbove200kTokens, - "cache_read_input_token_cost_above_200k_tokens": override.CacheReadInputTokenCostAbove200kTokens, - "cache_read_input_token_cost": override.CacheReadInputTokenCost, + "cache_read_input_token_cost_above_200k_tokens": override.CacheReadInputTokenCostAbove200kTokens, + "cache_read_input_token_cost": override.CacheReadInputTokenCost, "cache_creation_input_token_cost": override.CacheCreationInputTokenCost, "input_cost_per_token_batches": override.InputCostPerTokenBatches, "output_cost_per_token_batches": override.OutputCostPerTokenBatches,