Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions plugins/governance/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,19 +649,16 @@ func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext
func (p *GovernancePlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) {
// Extract governance headers and virtual key using utility functions
virtualKeyValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey)

// Getting provider and mode from the request
provider, model, _ := req.GetRequestFields()

// Create request context for evaluation
evaluationRequest := &EvaluationRequest{
VirtualKey: virtualKeyValue,
Provider: provider,
Model: model,
}

// Evaluate governance using common function
_, bifrostError := p.evaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType)

// Convert BifrostError to LLMPluginShortCircuit if needed
if bifrostError != nil {
return req, &schemas.LLMPluginShortCircuit{
Expand Down
21 changes: 12 additions & 9 deletions plugins/governance/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon
Provider: provider,
Model: model,
}

// 1. Check provider-level rate limits FIRST (before model-level checks)
if provider != "" {
if err, decision := r.store.CheckProviderRateLimit(ctx, request, nil, nil); err != nil {
Expand All @@ -92,7 +91,6 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon
Reason: fmt.Sprintf("Provider-level rate limit check failed: %s", err.Error()),
}
}

// 2. Check provider-level budgets FIRST (before model-level checks)
if err := r.store.CheckProviderBudget(ctx, request, nil); err != nil {
return &EvaluationResult{
Expand All @@ -101,7 +99,6 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon
}
}
}

// 3. Check model-level rate limits (after provider-level checks)
if model != "" {
if err, decision := r.store.CheckModelRateLimit(ctx, request, nil, nil); err != nil {
Expand All @@ -119,14 +116,24 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon
}
}
}

// All provider-level and model-level checks passed
return &EvaluationResult{
Decision: DecisionAllow,
Reason: "Request allowed by governance policy (provider-level and model-level checks passed)",
}
}

// isModelRequired checks if the requested model is required for this request
func (r *BudgetResolver) isModelRequired(requestType schemas.RequestType) bool {
// Here we will have to check for some requests which do not need model
// For example, batches, container, files requests
// For these requests, we will only check for provider filtering
if requestType == schemas.MCPToolExecutionRequest || requestType == schemas.BatchCreateRequest || requestType == schemas.BatchListRequest || requestType == schemas.BatchRetrieveRequest || requestType == schemas.BatchCancelRequest || requestType == schemas.BatchResultsRequest || requestType == schemas.FileUploadRequest || requestType == schemas.FileListRequest || requestType == schemas.FileRetrieveRequest || requestType == schemas.FileDeleteRequest || requestType == schemas.FileContentRequest || requestType == schemas.ContainerCreateRequest || requestType == schemas.ContainerListRequest || requestType == schemas.ContainerRetrieveRequest || requestType == schemas.ContainerDeleteRequest || requestType == schemas.ContainerFileCreateRequest || requestType == schemas.ContainerFileListRequest || requestType == schemas.ContainerFileRetrieveRequest || requestType == schemas.ContainerFileContentRequest || requestType == schemas.ContainerFileDeleteRequest {
return false
}
return true
}

// EvaluateVirtualKeyRequest evaluates virtual key-specific checks including validation, filtering, rate limits, and budgets
func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, virtualKeyValue string, provider schemas.ModelProvider, model string, requestType schemas.RequestType) *EvaluationResult {
// 1. Validate virtual key exists and is active
Expand All @@ -137,7 +144,6 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext,
Reason: "Virtual key not found",
}
}

// Set virtual key id and name in context
ctx.SetValue(schemas.BifrostContextKey("bf-governance-virtual-key-id"), vk.ID)
ctx.SetValue(schemas.BifrostContextKey("bf-governance-virtual-key-name"), vk.Name)
Expand All @@ -153,14 +159,12 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext,
ctx.SetValue(schemas.BifrostContextKey("bf-governance-customer-id"), vk.Customer.ID)
ctx.SetValue(schemas.BifrostContextKey("bf-governance-customer-name"), vk.Customer.Name)
}

if !vk.IsActive {
return &EvaluationResult{
Decision: DecisionVirtualKeyBlocked,
Reason: "Virtual key is inactive",
}
}

// 2. Check provider filtering
if requestType != schemas.MCPToolExecutionRequest && !r.isProviderAllowed(vk, provider) {
return &EvaluationResult{
Expand All @@ -169,9 +173,8 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext,
VirtualKey: vk,
}
}

// 3. Check model filtering
if requestType != schemas.MCPToolExecutionRequest && !r.isModelAllowed(vk, provider, model) {
if r.isModelRequired(requestType) && !r.isModelAllowed(vk, provider, model) {
return &EvaluationResult{
Decision: DecisionModelBlocked,
Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", model),
Expand Down
1 change: 0 additions & 1 deletion transports/bifrost-http/lib/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -1350,7 +1350,6 @@ func createGovernanceConfigInStore(ctx context.Context, config *Config) {
}

// Resolve MCP client names to IDs for config file mcp_configs
// This is done outside the transaction setup so we can access the store
mcpConfigs = resolveMCPConfigClientIDs(ctx, config.ConfigStore, mcpConfigs, virtualKey.ID)

for _, mc := range mcpConfigs {
Expand Down
8 changes: 5 additions & 3 deletions transports/bifrost-http/lib/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9924,13 +9924,15 @@ func TestSQLite_VKMCPConfig_MCPClientNameResolution(t *testing.T) {

// Now create config.json with virtual key using mcp_client_name (not mcp_client_id)
// This simulates the real-world scenario where config.json uses human-readable names
dbPath := filepath.Join(tempDir, "config.db")
cfgPath := filepath.Join(tempDir, "config.json")
configJSON := fmt.Sprintf(`{
"$schema": "https://www.getbifrost.ai/schema",
"config_store": {
"enabled": true,
"type": "sqlite",
"config": {
"path": "%s/config.db"
"path": %s
}
},
"providers": {
Expand Down Expand Up @@ -9988,10 +9990,10 @@ func TestSQLite_VKMCPConfig_MCPClientNameResolution(t *testing.T) {
}
]
}
}`, tempDir, keyID)
}`, fmt.Sprintf("%q", dbPath), keyID)

// Write the config file directly
err = os.WriteFile(tempDir+"/config.json", []byte(configJSON), 0644)
err = os.WriteFile(cfgPath, []byte(configJSON), 0644)
if err != nil {
t.Fatalf("Failed to write config.json: %v", err)
}
Expand Down