Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 16 additions & 3 deletions framework/modelcatalog/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -670,10 +670,23 @@ func (mc *ModelCatalog) GetProvidersForModel(model string) []schemas.ModelProvid
// // Explicit allowedModels without prefix
// mc.IsModelAllowedForProvider("openai", "gpt-4o", []string{"gpt-4o"})
// // Returns: true (direct match)
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, allowedModels []string) bool {
// Case 1: Empty allowedModels = use catalog to determine support
func (mc *ModelCatalog) IsModelAllowedForProvider(provider schemas.ModelProvider, model string, providerConfig *configstore.ProviderConfig, allowedModels []string) bool {
isCustomProvider := false
hasListModelsEndpointDisabled := false
if providerConfig != nil {
isCustomProvider = providerConfig.CustomProviderConfig != nil
hasListModelsEndpointDisabled = !providerConfig.CustomProviderConfig.IsOperationAllowed(schemas.ListModelsRequest)
Comment thread
danpiths marked this conversation as resolved.
}

// Case 1: Unrestricted allowedModels (empty or ["*"]) = use catalog to determine support
// This leverages GetProvidersForModel which already handles all cross-provider logic
if len(allowedModels) == 0 {
isUnrestricted := len(allowedModels) == 0 || (len(allowedModels) == 1 && allowedModels[0] == "*")
if isUnrestricted {
// Custom providers without a list-models endpoint can't be in the catalog,
// so allow any model through rather than blocking on missing catalog data
if isCustomProvider && hasListModelsEndpointDisabled {
return true
}
supportedProviders := mc.GetProvidersForModel(model)
return slices.Contains(supportedProviders, provider)
}
Expand Down
54 changes: 53 additions & 1 deletion framework/modelcatalog/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -154,5 +155,56 @@ func TestIsModelAllowedForProvider_PrefixedAllowedModelInCatalog(t *testing.T) {
nil,
)

assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", []string{"openai/gpt-4o"}))
providerConfig := configstore.ProviderConfig{}

assert.True(t, mc.IsModelAllowedForProvider(schemas.OpenRouter, "gpt-4o", &providerConfig, []string{"openai/gpt-4o"}))
}

func TestIsModelAllowedForProvider_CustomProviderListModelsDisabled(t *testing.T) {
mc := newTestCatalog(nil, nil)

// Custom provider with list-models disabled + ["*"] → should return true
providerConfig := configstore.ProviderConfig{
CustomProviderConfig: &schemas.CustomProviderConfig{
AllowedRequests: &schemas.AllowedRequests{
ListModels: false,
},
},
}
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "any-model", &providerConfig, []string{"*"}))
}

func TestIsModelAllowedForProvider_CustomProviderListModelsEnabled(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
"custom-provider": {"model-a"},
},
nil,
)

// Custom provider with list-models enabled + ["*"] → should go through catalog
providerConfig := configstore.ProviderConfig{
CustomProviderConfig: &schemas.CustomProviderConfig{
AllowedRequests: &schemas.AllowedRequests{
ListModels: true,
},
},
}
// model-a is in catalog → allowed
assert.True(t, mc.IsModelAllowedForProvider("custom-provider", "model-a", &providerConfig, []string{"*"}))
// model-b is NOT in catalog → denied
assert.False(t, mc.IsModelAllowedForProvider("custom-provider", "model-b", &providerConfig, []string{"*"}))
}

func TestIsModelAllowedForProvider_NilProviderConfig(t *testing.T) {
mc := newTestCatalog(
map[schemas.ModelProvider][]string{
"some-provider": {"model-x"},
},
nil,
)

// nil providerConfig + ["*"] → should go through catalog (not bypass)
assert.True(t, mc.IsModelAllowedForProvider("some-provider", "model-x", nil, []string{"*"}))
assert.False(t, mc.IsModelAllowedForProvider("some-provider", "model-y", nil, []string{"*"}))
}
10 changes: 6 additions & 4 deletions plugins/governance/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func Init(
}
// Initialize components in dependency order with fixed, optimal settings
// Resolver (pure decision engine for hierarchical governance, depends only on store)
resolver := NewBudgetResolver(governanceStore, modelCatalog, logger)
resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore)

// 3. Tracker (business logic owner, depends on store and resolver)
tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger)
Expand Down Expand Up @@ -263,7 +263,7 @@ func InitFromStore(
isVkMandatory = config.IsVkMandatory
requiredHeaders = config.RequiredHeaders
}
resolver := NewBudgetResolver(governanceStore, modelCatalog, logger)
resolver := NewBudgetResolver(governanceStore, modelCatalog, logger, inMemoryStore)
tracker := NewUsageTracker(ctx, governanceStore, resolver, configStore, logger)
engine, err := NewRoutingEngine(governanceStore, logger)
if err != nil {
Expand Down Expand Up @@ -576,8 +576,10 @@ func (p *GovernancePlugin) loadBalanceProvider(ctx *schemas.BifrostContext, req
// This handles all cross-provider logic (OpenRouter, Vertex, Groq, Bedrock)
// and provider-prefixed allowed_models entries
isProviderAllowed := false
if p.modelCatalog != nil {
isProviderAllowed = p.modelCatalog.IsModelAllowedForProvider(schemas.ModelProvider(config.Provider), modelStr, config.AllowedModels)
if p.modelCatalog != nil && p.inMemoryStore != nil {
provider := schemas.ModelProvider(config.Provider)
providerConfig := p.inMemoryStore.GetConfiguredProviders()[provider]
isProviderAllowed = p.modelCatalog.IsModelAllowedForProvider(provider, modelStr, &providerConfig, config.AllowedModels)
} else {
// Fallback when model catalog is not available: simple string matching
if len(config.AllowedModels) == 0 {
Expand Down
30 changes: 15 additions & 15 deletions plugins/governance/model_provider_governance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ func TestResolver_EvaluateModelAndProviderRequest_NoConfigs(t *testing.T) {
store, err := NewLocalGovernanceStore(context.Background(), logger, nil, &configstore.GovernanceConfig{}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -810,7 +810,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderBudgetExceeded(t *test
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -828,7 +828,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderRateLimitExceeded(t *t
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -846,7 +846,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelBudgetExceeded(t *testing
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -864,7 +864,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelRateLimitExceeded(t *test
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -882,7 +882,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelRateLimitExceeded_Request
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -905,7 +905,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderBudgetThenModelBudget(
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -929,7 +929,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderRateLimitThenModelRate
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -953,7 +953,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderRateLimitThenModelRate
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -980,7 +980,7 @@ func TestResolver_EvaluateModelAndProviderRequest_AllChecksPass(t *testing.T) {
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

result := resolver.EvaluateModelAndProviderRequest(ctx, schemas.OpenAI, "gpt-4")
Expand All @@ -998,7 +998,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderOnly_NoModel(t *testin
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

// No model provided
Expand All @@ -1016,7 +1016,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ModelOnly_NoProvider(t *testin
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

// No provider provided
Expand All @@ -1036,7 +1036,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderSpecificBudget_Differe
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

// Request with Azure (different provider) for same model should pass
Expand All @@ -1056,7 +1056,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderSpecificRateLimit_Diff
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

// Request with Azure (different provider) for same model should pass
Expand All @@ -1076,7 +1076,7 @@ func TestResolver_EvaluateModelAndProviderRequest_ProviderSpecificRateLimit_Diff
}, nil)
require.NoError(t, err)

resolver := NewBudgetResolver(store, nil, logger)
resolver := NewBudgetResolver(store, nil, logger, nil)
ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)

// Request with Azure (different provider) for same model should pass
Expand Down
21 changes: 12 additions & 9 deletions plugins/governance/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,17 +62,19 @@ type UsageInfo struct {

// BudgetResolver provides decision logic for the new hierarchical governance system
type BudgetResolver struct {
store GovernanceStore
logger schemas.Logger
modelCatalog *modelcatalog.ModelCatalog
store GovernanceStore
logger schemas.Logger
modelCatalog *modelcatalog.ModelCatalog
governanceInMemoryStore InMemoryStore
}

// NewBudgetResolver creates a new budget-based governance resolver
func NewBudgetResolver(store GovernanceStore, modelCatalog *modelcatalog.ModelCatalog, logger schemas.Logger) *BudgetResolver {
func NewBudgetResolver(store GovernanceStore, modelCatalog *modelcatalog.ModelCatalog, logger schemas.Logger, governanceInMemoryStore InMemoryStore) *BudgetResolver {
return &BudgetResolver{
store: store,
logger: logger,
modelCatalog: modelCatalog,
store: store,
logger: logger,
modelCatalog: modelCatalog,
governanceInMemoryStore: governanceInMemoryStore,
}
}

Expand Down Expand Up @@ -334,8 +336,9 @@ func (r *BudgetResolver) isModelAllowed(vk *configstoreTables.TableVirtualKey, p
// Delegate model allowance check to model catalog
// This handles all cross-provider logic (OpenRouter, Vertex, Groq, Bedrock)
// and provider-prefixed allowed_models entries
if r.modelCatalog != nil {
return r.modelCatalog.IsModelAllowedForProvider(provider, model, pc.AllowedModels)
if r.modelCatalog != nil && r.governanceInMemoryStore != nil {
providerConfig := r.governanceInMemoryStore.GetConfiguredProviders()[provider]
return r.modelCatalog.IsModelAllowedForProvider(provider, model, &providerConfig, pc.AllowedModels)
}
// Fallback when model catalog is not available: simple string matching
if len(pc.AllowedModels) == 0 {
Expand Down
Loading
Loading