Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -6191,10 +6191,9 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p
// - If key.Models is ["*"] → include key (supports all non-blacklisted models)
// - If key.Models is empty → exclude key (deny-by-default)
// - If key.Models is non-empty → only include if model is in list
// Blacklist wins over allowlist
if model != nil && *model != "" {
if k.Models.IsUnrestricted() {
// wildcard: allow all models
} else if !k.Models.IsAllowed(*model) {
if k.BlacklistedModels.IsBlocked(*model) || !k.Models.IsAllowed(*model) {
continue
}
}
Expand Down Expand Up @@ -6289,7 +6288,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *schemas.BifrostContex
}
hasValue := strings.TrimSpace(key.Value.GetValue()) != "" || CanProviderKeyValueBeEmpty(baseProviderType)
// ["*"] = allow all models; [] = deny all; specific list = allow only listed
modelSupported := hasValue && key.Models.IsAllowed(model)
modelSupported := hasValue && key.Models.IsAllowed(model) && !key.BlacklistedModels.IsBlocked(model)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
// Additional deployment checks for Azure, Bedrock and Vertex
deploymentSupported := true
if baseProviderType == schemas.Azure && key.AzureKeyConfig != nil {
Expand Down
3 changes: 1 addition & 2 deletions core/bifrost_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ func TestSelectKeyFromProviderForModel_BlacklistedModels(t *testing.T) {
t.Run("second key used when first blacklists", func(t *testing.T) {
account.SetKeysForProvider(schemas.OpenAI, []schemas.Key{
{ID: "k1", Name: "K1", Value: *schemas.NewEnvVar("sk-1"), Weight: 1, BlacklistedModels: []string{"gpt-4"}},
{ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1},
{ID: "k2", Name: "K2", Value: *schemas.NewEnvVar("sk-2"), Weight: 1, Models: []string{"*"}},
})
key, err := bifrost.selectKeyFromProviderForModel(bfCtx, schemas.ChatCompletionRequest, schemas.OpenAI, "gpt-4", schemas.OpenAI)
if err != nil {
Expand Down Expand Up @@ -1226,4 +1226,3 @@ func TestUpdateProvider_ProviderSliceIntegrity(t *testing.T) {
}
})
}

2 changes: 1 addition & 1 deletion core/providers/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ func (provider *AnthropicProvider) listModelsByKey(ctx *schemas.BifrostContext,
}

// Create final response
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, request.Unfiltered)
response := anthropicResponse.ToBifrostListModelsResponse(provider.GetProviderKey(), key.Models, key.BlacklistedModels, request.Unfiltered)
response.ExtraFields.Latency = latency.Milliseconds()

// Set raw request if enabled
Expand Down
10 changes: 8 additions & 2 deletions core/providers/anthropic/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/maximhq/bifrost/core/schemas"
)

func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -24,7 +24,7 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide
bifrostResponse.NextPageToken = *response.LastID
}

if !unfiltered && allowedModels.IsEmpty() {
if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) {
return bifrostResponse
}

Expand All @@ -44,6 +44,9 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide
continue
}
}
if !unfiltered && blacklistedModels.IsBlocked(modelID) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + modelID,
Name: schemas.Ptr(model.DisplayName),
Expand All @@ -55,6 +58,9 @@ func (response *AnthropicListModelsResponse) ToBifrostListModelsResponse(provide
// Backfill allowed models that were not in the response
if !unfiltered && allowedModels.IsRestricted() {
for _, allowedModel := range allowedModels {
if blacklistedModels.IsBlocked(allowedModel) {
continue
}
if !includedModels[allowedModel] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func (provider *AzureProvider) listModelsByKey(ctx *schemas.BifrostContext, key
}

// Convert to Bifrost response
response := azureResponse.ToBifrostListModelsResponse(key.Models, key.AzureKeyConfig.Deployments, request.Unfiltered)
response := azureResponse.ToBifrostListModelsResponse(key.Models, key.BlacklistedModels, key.AzureKeyConfig.Deployments, request.Unfiltered)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if response == nil {
return nil, providerUtils.NewBifrostOperationError("failed to convert Azure model list response", nil, schemas.Azure)
}
Expand Down
31 changes: 29 additions & 2 deletions core/providers/azure/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,24 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym
return "", ""
}

func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
// matchesBlacklist reports whether modelID matches any entry in the blacklist,
// using the same matching logic as findMatchingAllowedModel (exact and base-model).
func matchesBlacklist(bl schemas.BlackList, modelID string) bool {
if bl.IsEmpty() {
return false
}
if bl.Contains(modelID) {
return true
}
for _, item := range bl {
if schemas.SameBaseModel(item, modelID) {
return true
}
}
return false
}

func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -67,7 +84,7 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
Data: make([]schemas.Model, 0, len(response.Data)),
}

if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 {
if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) {
return bifrostResponse
}

Expand Down Expand Up @@ -113,6 +130,10 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
if shouldFilter {
continue
}
if !unfiltered && (matchesBlacklist(blacklistedModels, model.ID) ||
(deploymentAlias != "" && matchesBlacklist(blacklistedModels, deploymentAlias))) {
continue
}

// Use the matched name from allowedModels or deployments (like Anthropic)
// Priority: deployment value > matched allowedModel > original model.ID
Expand Down Expand Up @@ -148,6 +169,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
if restrictAllowed && !allowedModels.Contains(alias) {
continue
}
if !unfiltered && matchesBlacklist(blacklistedModels, alias) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(schemas.Azure) + "/" + alias,
Name: schemas.Ptr(alias),
Expand All @@ -160,6 +184,9 @@ func (response *AzureListModelsResponse) ToBifrostListModelsResponse(allowedMode
// Backfill allowed models that were not in the response
if restrictAllowed {
for _, allowedModel := range allowedModels {
if matchesBlacklist(blacklistedModels, allowedModel) {
continue
}
if !includedModels[strings.ToLower(allowedModel)] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(schemas.Azure) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
}

// Convert to Bifrost response
response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, config.Deployments, request.Unfiltered)
response := bedrockResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, config.Deployments, request.Unfiltered)
if response == nil {
return nil, providerUtils.NewBifrostOperationError("failed to convert Bedrock model list response", nil, providerName)
}
Expand Down
42 changes: 40 additions & 2 deletions core/providers/bedrock/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,35 @@ func findDeploymentMatch(deployments map[string]string, modelID string) (deploym
return "", ""
}

func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
// matchesBlacklist reports whether modelID matches any entry in the blacklist,
// using the same matching logic as findMatchingAllowedModel (exact, prefix-normalized, base-model).
func matchesBlacklist(bl schemas.BlackList, modelID string) bool {
if bl.IsEmpty() {
return false
}
if bl.Contains(modelID) {
return true
}
if extractPrefix(modelID) != "" {
if bl.Contains(removePrefix(modelID)) {
return true
}
}
for _, item := range bl {
if extractPrefix(item) != "" && removePrefix(item) == modelID {
return true
}
}
valueNormalized := removePrefix(modelID)
for _, item := range bl {
if schemas.SameBaseModel(removePrefix(item), valueNormalized) {
return true
}
}
return false
}

func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, deployments map[string]string, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -229,7 +257,7 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK
Data: make([]schemas.Model, 0, len(response.ModelSummaries)),
}

if !unfiltered && allowedModels.IsEmpty() && len(deployments) == 0 {
if !unfiltered && (allowedModels.IsEmpty() && len(deployments) == 0 || blacklistedModels.IsBlockAll()) {
Comment thread
TejasGhatte marked this conversation as resolved.
return bifrostResponse
}

Expand Down Expand Up @@ -280,6 +308,10 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK
if shouldFilter {
continue
}
if !unfiltered && (matchesBlacklist(blacklistedModels, model.ModelID) ||
(deploymentAlias != "" && matchesBlacklist(blacklistedModels, deploymentAlias))) {
continue
}

// Use the matched name from allowedModels or deployments (like Anthropic)
// Priority: deployment value > matched allowedModel > original model.ModelID
Expand Down Expand Up @@ -320,6 +352,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK
if restrictAllowed && !allowedModels.Contains(alias) {
continue
}
if !unfiltered && matchesBlacklist(blacklistedModels, alias) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + alias,
Name: schemas.Ptr(alias),
Expand All @@ -332,6 +367,9 @@ func (response *BedrockListModelsResponse) ToBifrostListModelsResponse(providerK
// Backfill allowed models that were not in the response
if restrictAllowed {
for _, allowedModel := range allowedModels {
if matchesBlacklist(blacklistedModels, allowedModel) {
continue
}
if !includedModels[strings.ToLower(allowedModel)] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/cohere/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func (provider *CohereProvider) listModelsByKey(ctx *schemas.BifrostContext, key
}

// Convert Cohere v2 response to Bifrost response
response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered)
response := cohereResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)

response.ExtraFields.Latency = latency.Milliseconds()

Expand Down
10 changes: 8 additions & 2 deletions core/providers/cohere/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type CohereRerankMeta struct {
Tokens *CohereTokenUsage `json:"tokens,omitempty"`
}

func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -53,7 +53,7 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe
Data: make([]schemas.Model, 0, len(response.Models)),
}

if !unfiltered && allowedModels.IsEmpty() {
if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) {
return bifrostResponse
}

Expand All @@ -62,6 +62,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe
if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.Name) {
continue
}
if !unfiltered && blacklistedModels.IsBlocked(model.Name) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + model.Name,
Name: schemas.Ptr(model.Name),
Expand All @@ -74,6 +77,9 @@ func (response *CohereListModelsResponse) ToBifrostListModelsResponse(providerKe
// Backfill allowed models that were not in the response
if !unfiltered && allowedModels.IsRestricted() {
for _, allowedModel := range allowedModels {
if blacklistedModels.IsBlocked(allowedModel) {
continue
}
if !includedModels[strings.ToLower(allowedModel)] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/elevenlabs/elevenlabs.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (provider *ElevenlabsProvider) listModelsByKey(ctx *schemas.BifrostContext,
return nil, bifrostErr
}

response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered)
response := elevenlabsResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)

response.ExtraFields.Latency = latency.Milliseconds()
response.ExtraFields.ProviderResponseHeaders = providerUtils.ExtractProviderResponseHeaders(resp)
Expand Down
10 changes: 8 additions & 2 deletions core/providers/elevenlabs/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/maximhq/bifrost/core/schemas"
)

func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -15,7 +15,7 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid
Data: make([]schemas.Model, 0, len(*response)),
}

if !unfiltered && allowedModels.IsEmpty() {
if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) {
return bifrostResponse
}

Expand All @@ -24,6 +24,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid
if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(model.ModelID) {
continue
}
if !unfiltered && blacklistedModels.IsBlocked(model.ModelID) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + model.ModelID,
Name: schemas.Ptr(model.Name),
Expand All @@ -34,6 +37,9 @@ func (response *ElevenlabsListModelsResponse) ToBifrostListModelsResponse(provid
// Backfill allowed models that were not in the response
if !unfiltered && allowedModels.IsRestricted() {
for _, allowedModel := range allowedModels {
if blacklistedModels.IsBlocked(allowedModel) {
continue
}
if !includedModels[strings.ToLower(allowedModel)] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ func (provider *GeminiProvider) listModelsByKey(ctx *schemas.BifrostContext, key
}
}

response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, request.Unfiltered)
response := geminiResponse.ToBifrostListModelsResponse(providerName, key.Models, key.BlacklistedModels, request.Unfiltered)

response.ExtraFields.Latency = latency.Milliseconds()

Expand Down
10 changes: 8 additions & 2 deletions core/providers/gemini/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func toGeminiModelResourceName(modelID string) string {
return "models/" + modelID
}

func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, unfiltered bool) *schemas.BifrostListModelsResponse {
func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKey schemas.ModelProvider, allowedModels schemas.WhiteList, blacklistedModels schemas.BlackList, unfiltered bool) *schemas.BifrostListModelsResponse {
if response == nil {
return nil
}
Expand All @@ -25,7 +25,7 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe
Data: make([]schemas.Model, 0, len(response.Models)),
}

if !unfiltered && allowedModels.IsEmpty() {
if !unfiltered && (allowedModels.IsEmpty() || blacklistedModels.IsBlockAll()) {
return bifrostResponse
}

Expand All @@ -38,6 +38,9 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe
if !unfiltered && allowedModels.IsRestricted() && !allowedModels.Contains(modelName) {
continue
}
if !unfiltered && blacklistedModels.IsBlocked(modelName) {
continue
}
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + modelName,
Name: schemas.Ptr(model.DisplayName),
Expand All @@ -53,6 +56,9 @@ func (response *GeminiListModelsResponse) ToBifrostListModelsResponse(providerKe
// Backfill allowed models that were not in the response
if !unfiltered && allowedModels.IsRestricted() {
for _, allowedModel := range allowedModels {
if blacklistedModels.IsBlocked(allowedModel) {
continue
}
if !includedModels[strings.ToLower(allowedModel)] {
bifrostResponse.Data = append(bifrostResponse.Data, schemas.Model{
ID: string(providerKey) + "/" + allowedModel,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/huggingface/huggingface.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ func (provider *HuggingFaceProvider) listModelsByKey(ctx *schemas.BifrostContext
}

if result.response != nil {
providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, request.Unfiltered)
providerResponse := result.response.ToBifrostListModelsResponse(providerName, result.provider, key.Models, key.BlacklistedModels, request.Unfiltered)
if providerResponse != nil {
aggregatedResponse.Data = append(aggregatedResponse.Data, providerResponse.Data...)
totalLatency += result.latency
Expand Down
Loading