diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 02c58f3fef..45e2ca874a 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -649,19 +649,16 @@ func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext func (p *GovernancePlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { // Extract governance headers and virtual key using utility functions virtualKeyValue := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) - + // Getting provider and mode from the request provider, model, _ := req.GetRequestFields() - // Create request context for evaluation evaluationRequest := &EvaluationRequest{ VirtualKey: virtualKeyValue, Provider: provider, Model: model, } - // Evaluate governance using common function _, bifrostError := p.evaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType) - // Convert BifrostError to LLMPluginShortCircuit if needed if bifrostError != nil { return req, &schemas.LLMPluginShortCircuit{ diff --git a/plugins/governance/resolver.go b/plugins/governance/resolver.go index 6e4f267ef4..ce41450ce4 100644 --- a/plugins/governance/resolver.go +++ b/plugins/governance/resolver.go @@ -83,7 +83,6 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon Provider: provider, Model: model, } - // 1. Check provider-level rate limits FIRST (before model-level checks) if provider != "" { if err, decision := r.store.CheckProviderRateLimit(ctx, request, nil, nil); err != nil { @@ -92,7 +91,6 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon Reason: fmt.Sprintf("Provider-level rate limit check failed: %s", err.Error()), } } - // 2. Check provider-level budgets FIRST (before model-level checks) if err := r.store.CheckProviderBudget(ctx, request, nil); err != nil { return &EvaluationResult{ @@ -101,7 +99,6 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon } } } - // 3. Check model-level rate limits (after provider-level checks) if model != "" { if err, decision := r.store.CheckModelRateLimit(ctx, request, nil, nil); err != nil { @@ -119,7 +116,6 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon } } } - // All provider-level and model-level checks passed return &EvaluationResult{ Decision: DecisionAllow, @@ -127,6 +123,17 @@ func (r *BudgetResolver) EvaluateModelAndProviderRequest(ctx *schemas.BifrostCon } } +// isModelRequired checks if the requested model is required for this request +func (r *BudgetResolver) isModelRequired(requestType schemas.RequestType) bool { + // Here we will have to check for some requests which do not need model + // For example, batches, container, files requests + // For these requests, we will only check for provider filtering + if requestType == schemas.MCPToolExecutionRequest || requestType == schemas.BatchCreateRequest || requestType == schemas.BatchListRequest || requestType == schemas.BatchRetrieveRequest || requestType == schemas.BatchCancelRequest || requestType == schemas.BatchResultsRequest || requestType == schemas.FileUploadRequest || requestType == schemas.FileListRequest || requestType == schemas.FileRetrieveRequest || requestType == schemas.FileDeleteRequest || requestType == schemas.FileContentRequest || requestType == schemas.ContainerCreateRequest || requestType == schemas.ContainerListRequest || requestType == schemas.ContainerRetrieveRequest || requestType == schemas.ContainerDeleteRequest || requestType == schemas.ContainerFileCreateRequest || requestType == schemas.ContainerFileListRequest || requestType == schemas.ContainerFileRetrieveRequest || requestType == schemas.ContainerFileContentRequest || requestType == schemas.ContainerFileDeleteRequest { + return false + } + return true +} + // EvaluateVirtualKeyRequest evaluates virtual key-specific checks including validation, filtering, rate limits, and budgets func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, virtualKeyValue string, provider schemas.ModelProvider, model string, requestType schemas.RequestType) *EvaluationResult { // 1. Validate virtual key exists and is active @@ -137,7 +144,6 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, Reason: "Virtual key not found", } } - // Set virtual key id and name in context ctx.SetValue(schemas.BifrostContextKey("bf-governance-virtual-key-id"), vk.ID) ctx.SetValue(schemas.BifrostContextKey("bf-governance-virtual-key-name"), vk.Name) @@ -153,14 +159,12 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, ctx.SetValue(schemas.BifrostContextKey("bf-governance-customer-id"), vk.Customer.ID) ctx.SetValue(schemas.BifrostContextKey("bf-governance-customer-name"), vk.Customer.Name) } - if !vk.IsActive { return &EvaluationResult{ Decision: DecisionVirtualKeyBlocked, Reason: "Virtual key is inactive", } } - // 2. Check provider filtering if requestType != schemas.MCPToolExecutionRequest && !r.isProviderAllowed(vk, provider) { return &EvaluationResult{ @@ -169,9 +173,8 @@ func (r *BudgetResolver) EvaluateVirtualKeyRequest(ctx *schemas.BifrostContext, VirtualKey: vk, } } - // 3. Check model filtering - if requestType != schemas.MCPToolExecutionRequest && !r.isModelAllowed(vk, provider, model) { + if r.isModelRequired(requestType) && !r.isModelAllowed(vk, provider, model) { return &EvaluationResult{ Decision: DecisionModelBlocked, Reason: fmt.Sprintf("Model '%s' is not allowed for this virtual key", model), diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 1dfe268e43..88b2141349 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -1350,7 +1350,6 @@ func createGovernanceConfigInStore(ctx context.Context, config *Config) { } // Resolve MCP client names to IDs for config file mcp_configs - // This is done outside the transaction setup so we can access the store mcpConfigs = resolveMCPConfigClientIDs(ctx, config.ConfigStore, mcpConfigs, virtualKey.ID) for _, mc := range mcpConfigs { diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 0060f74fbc..8fafc9e9ba 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -9924,13 +9924,15 @@ func TestSQLite_VKMCPConfig_MCPClientNameResolution(t *testing.T) { // Now create config.json with virtual key using mcp_client_name (not mcp_client_id) // This simulates the real-world scenario where config.json uses human-readable names + dbPath := filepath.Join(tempDir, "config.db") + cfgPath := filepath.Join(tempDir, "config.json") configJSON := fmt.Sprintf(`{ "$schema": "https://www.getbifrost.ai/schema", "config_store": { "enabled": true, "type": "sqlite", "config": { - "path": "%s/config.db" + "path": %s } }, "providers": { @@ -9988,10 +9990,10 @@ func TestSQLite_VKMCPConfig_MCPClientNameResolution(t *testing.T) { } ] } - }`, tempDir, keyID) + }`, fmt.Sprintf("%q", dbPath), keyID) // Write the config file directly - err = os.WriteFile(tempDir+"/config.json", []byte(configJSON), 0644) + err = os.WriteFile(cfgPath, []byte(configJSON), 0644) if err != nil { t.Fatalf("Failed to write config.json: %v", err) }