diff --git a/cli/go.mod b/cli/go.mod index 2c1c930bc2..c6a77eaf26 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -46,7 +46,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/arch v0.23.0 // indirect - golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.33.0 // indirect ) diff --git a/cli/go.sum b/cli/go.sum index e6e613043a..9746bb475a 100644 --- a/cli/go.sum +++ b/cli/go.sum @@ -89,8 +89,7 @@ github.com/zalando/go-keyring v0.2.6 h1:r7Yc3+H+Ux0+M72zacZoItR3UDxeWfKTcabvkI8u github.com/zalando/go-keyring v0.2.6/go.mod h1:2TCrxYrbUNYfNS/Kgy/LSrkSQzZ5UPVH85RwfczwvcI= golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= diff --git a/framework/configstore/clientconfig.go b/framework/configstore/clientconfig.go index 8656166e18..8bb66cfa65 100644 --- a/framework/configstore/clientconfig.go +++ b/framework/configstore/clientconfig.go @@ -967,6 +967,23 @@ func GenerateRoutingRuleHash(r tables.TableRoutingRule) (string, error) { return hex.EncodeToString(hash.Sum(nil)), nil } +// GeneratePricingOverrideHash generates a SHA256 hash for a pricing override. +// Skips: CreatedAt, UpdatedAt, ConfigHash (dynamic/meta fields). +func GeneratePricingOverrideHash(p tables.TablePricingOverride) (string, error) { + hash := sha256.New() + hash.Write([]byte(p.ID)) + hash.Write([]byte(p.Name)) + hash.Write([]byte(p.ScopeKind)) + hash.Write([]byte(derefStr(p.VirtualKeyID))) + hash.Write([]byte(derefStr(p.ProviderID))) + hash.Write([]byte(derefStr(p.ProviderKeyID))) + hash.Write([]byte(p.MatchType)) + hash.Write([]byte(p.Pattern)) + hash.Write([]byte(p.RequestTypesJSON)) + hash.Write([]byte(p.PricingPatchJSON)) + return hex.EncodeToString(hash.Sum(nil)), nil +} + // GenerateMCPClientHash generates a SHA256 hash for an MCP client. // This is used to detect changes to MCP clients between config.json and database. // Skips: ID (autoIncrement), CreatedAt, UpdatedAt (dynamic fields) @@ -1093,13 +1110,14 @@ type ConfigMap map[schemas.ModelProvider]ProviderConfig // GovernanceConfig contains governance entities loaded from the config store or // reconciled from config.json. type GovernanceConfig struct { - VirtualKeys []tables.TableVirtualKey `json:"virtual_keys"` - Teams []tables.TableTeam `json:"teams"` - Customers []tables.TableCustomer `json:"customers"` - Budgets []tables.TableBudget `json:"budgets"` - RateLimits []tables.TableRateLimit `json:"rate_limits"` - ModelConfigs []tables.TableModelConfig `json:"model_configs"` - Providers []tables.TableProvider `json:"providers"` - RoutingRules []tables.TableRoutingRule `json:"routing_rules"` - AuthConfig *AuthConfig `json:"auth_config,omitempty"` + VirtualKeys []tables.TableVirtualKey `json:"virtual_keys"` + Teams []tables.TableTeam `json:"teams"` + Customers []tables.TableCustomer `json:"customers"` + Budgets []tables.TableBudget `json:"budgets"` + RateLimits []tables.TableRateLimit `json:"rate_limits"` + ModelConfigs []tables.TableModelConfig `json:"model_configs"` + Providers []tables.TableProvider `json:"providers"` + RoutingRules []tables.TableRoutingRule `json:"routing_rules"` + PricingOverrides []tables.TablePricingOverride `json:"pricing_overrides,omitempty"` + AuthConfig *AuthConfig `json:"auth_config,omitempty"` } diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 3e3f091642..2c7e87cb83 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -355,7 +355,6 @@ func migrationAddStoreRawRequestResponseColumn(ctx context.Context, db *gorm.DB) "concurrency_buffer_json", "proxy_config_json", "custom_provider_config_json", - "pricing_overrides_json", "send_back_raw_request", "send_back_raw_response", "store_raw_request_response", diff --git a/framework/configstore/migrations_test.go b/framework/configstore/migrations_test.go index 7d81111d27..a57f7262cb 100644 --- a/framework/configstore/migrations_test.go +++ b/framework/configstore/migrations_test.go @@ -11,9 +11,7 @@ import ( "testing" "time" - "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore/tables" - "github.com/maximhq/bifrost/framework/pricingoverrides" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/driver/postgres" @@ -247,47 +245,6 @@ func TestFindUniqueName_NormalizationAndCollision(t *testing.T) { assert.Contains(t, logOutput, "MCP Client Name Normalized: 'my-tool' -> 'my_tool2'", "Should log the full transformation") } -func TestMigrationReconcilePricingOverridesTable_PreservesExistingRows(t *testing.T) { - db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) - require.NoError(t, err) - - err = db.AutoMigrate(&tables.TablePricingOverride{}) - require.NoError(t, err) - - inputCost := 1.25 - override := tables.TablePricingOverride{ - ID: "override-1", - Name: "Config Override", - ScopeKind: pricingoverrides.ScopeKindGlobal, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-4.1", - RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, - Patch: pricingoverrides.Patch{ - InputCostPerToken: &inputCost, - }, - ConfigHash: "config-hash-1", - CreatedAt: time.Now().UTC().Round(time.Second), - UpdatedAt: time.Now().UTC().Round(time.Second), - } - require.NoError(t, db.Create(&override).Error) - - require.NoError(t, db.Migrator().DropIndex(&tables.TablePricingOverride{}, "idx_pricing_override_match")) - require.False(t, db.Migrator().HasIndex(&tables.TablePricingOverride{}, "idx_pricing_override_match")) - - require.NoError(t, migrationReconcilePricingOverridesTable(context.Background(), db)) - - var stored []tables.TablePricingOverride - require.NoError(t, db.Order("id").Find(&stored).Error) - require.Len(t, stored, 1) - assert.Equal(t, override.ID, stored[0].ID) - assert.Equal(t, override.Name, stored[0].Name) - require.NotNil(t, stored[0].Patch.InputCostPerToken) - assert.Equal(t, inputCost, *stored[0].Patch.InputCostPerToken) - assert.Equal(t, override.ConfigHash, stored[0].ConfigHash) - assert.True(t, db.Migrator().HasIndex(&tables.TablePricingOverride{}, "idx_pricing_override_scope")) - assert.True(t, db.Migrator().HasIndex(&tables.TablePricingOverride{}, "idx_pricing_override_match")) -} - func TestFindUniqueName_MultipleNormalizationsToSameBase(t *testing.T) { db := setupTestDB(t) ctx := context.Background() diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index 722ff0bf90..95b641ff15 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -1302,20 +1302,20 @@ func (s *RDBConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) return txDB.WithContext(ctx).Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&tables.TableModelPricing{}).Error } -func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filter PricingOverrideFilter) ([]tables.TablePricingOverride, error) { +func (s *RDBConfigStore) GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) { var overrides []tables.TablePricingOverride q := s.db.WithContext(ctx).Model(&tables.TablePricingOverride{}) - if filter.ScopeKind != nil { - q = q.Where("scope_kind = ?", *filter.ScopeKind) + if filters.ScopeKind != nil { + q = q.Where("scope_kind = ?", *filters.ScopeKind) } - if filter.VirtualKeyID != nil { - q = q.Where("virtual_key_id = ?", *filter.VirtualKeyID) + if filters.VirtualKeyID != nil { + q = q.Where("virtual_key_id = ?", *filters.VirtualKeyID) } - if filter.ProviderID != nil { - q = q.Where("provider_id = ?", *filter.ProviderID) + if filters.ProviderID != nil { + q = q.Where("provider_id = ?", *filters.ProviderID) } - if filter.ProviderKeyID != nil { - q = q.Where("provider_key_id = ?", *filter.ProviderKeyID) + if filters.ProviderKeyID != nil { + q = q.Where("provider_key_id = ?", *filters.ProviderKeyID) } if err := q.Order("created_at ASC").Find(&overrides).Error; err != nil { return nil, s.parseGormError(err) diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 5aed312a80..f6803a6ad3 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -10,7 +10,6 @@ import ( "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/framework/migrator" - "github.com/maximhq/bifrost/framework/pricingoverrides" "github.com/maximhq/bifrost/framework/vectorstore" "gorm.io/gorm" ) @@ -60,6 +59,14 @@ type CustomersQueryParams struct { Search string } +// PricingOverrideFilters holds the filters for pricing overrides. +type PricingOverrideFilters struct { + ScopeKind *string + VirtualKeyID *string + ProviderID *string + ProviderKeyID *string +} + // ConfigStore is the interface for the config store. type ConfigStore interface { // Health check @@ -220,7 +227,7 @@ type ConfigStore interface { DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) error // Governance pricing overrides CRUD - GetPricingOverrides(ctx context.Context, filter PricingOverrideFilter) ([]tables.TablePricingOverride, error) + GetPricingOverrides(ctx context.Context, filters PricingOverrideFilters) ([]tables.TablePricingOverride, error) GetPricingOverrideByID(ctx context.Context, id string) (*tables.TablePricingOverride, error) CreatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error UpdatePricingOverride(ctx context.Context, override *tables.TablePricingOverride, tx ...*gorm.DB) error @@ -317,13 +324,6 @@ type ConfigStore interface { Close(ctx context.Context) error } -type PricingOverrideFilter struct { - ScopeKind *pricingoverrides.ScopeKind - VirtualKeyID *string - ProviderID *string - ProviderKeyID *string -} - // NewConfigStore creates a new config store based on the configuration func NewConfigStore(ctx context.Context, config *Config, logger schemas.Logger) (ConfigStore, error) { if config == nil { diff --git a/framework/configstore/tables/pricingoverride.go b/framework/configstore/tables/pricingoverride.go index df0d6dd5d0..ab6ee9b5c5 100644 --- a/framework/configstore/tables/pricingoverride.go +++ b/framework/configstore/tables/pricingoverride.go @@ -2,65 +2,36 @@ package tables import ( "encoding/json" - "fmt" - "strings" "time" "github.com/maximhq/bifrost/core/schemas" - "github.com/maximhq/bifrost/framework/pricingoverrides" "gorm.io/gorm" ) -// TablePricingOverride is the persistence model for governance pricing -// overrides. +// TablePricingOverride is the persistence model for governance pricing overrides. type TablePricingOverride struct { - ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` - Name string `gorm:"type:varchar(255);not null" json:"name"` - ScopeKind pricingoverrides.ScopeKind `gorm:"type:varchar(50);index:idx_pricing_override_scope;not null" json:"scope_kind"` - VirtualKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"virtual_key_id,omitempty"` - ProviderID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_id,omitempty"` - ProviderKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_key_id,omitempty"` - MatchType pricingoverrides.MatchType `gorm:"type:varchar(20);index:idx_pricing_override_match;not null" json:"match_type"` - Pattern string `gorm:"type:varchar(255);not null" json:"pattern"` - RequestTypesJSON string `gorm:"type:text" json:"-"` - PricingPatchJSON string `gorm:"type:text" json:"-"` - ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash,omitempty"` - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` - UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` - - RequestTypes []schemas.RequestType `gorm:"-" json:"request_types,omitempty"` - Patch pricingoverrides.Patch `gorm:"-" json:"patch,omitempty"` + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + Name string `gorm:"type:varchar(255);not null" json:"name"` + ScopeKind string `gorm:"type:varchar(50);index:idx_pricing_override_scope;not null" json:"scope_kind"` + VirtualKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"virtual_key_id,omitempty"` + ProviderID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_id,omitempty"` + ProviderKeyID *string `gorm:"type:varchar(255);index:idx_pricing_override_scope" json:"provider_key_id,omitempty"` + MatchType string `gorm:"type:varchar(20);index:idx_pricing_override_match;not null" json:"match_type"` + Pattern string `gorm:"type:varchar(255);not null" json:"pattern"` + RequestTypesJSON string `gorm:"type:text" json:"-"` + PricingPatchJSON string `gorm:"type:text" json:"pricing_patch,omitempty"` + ConfigHash string `gorm:"type:varchar(255);null" json:"config_hash,omitempty"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + + RequestTypes []schemas.RequestType `gorm:"-" json:"request_types,omitempty"` } // TableName returns the backing table name for governance pricing overrides. func (TablePricingOverride) TableName() string { return "governance_pricing_overrides" } -// BeforeSave validates and serializes the sparse pricing override fields before -// the row is persisted. +// BeforeSave serializes virtual fields into their JSON columns before persistence. func (p *TablePricingOverride) BeforeSave(tx *gorm.DB) error { - p.Name = strings.TrimSpace(p.Name) - if p.Name == "" { - return fmt.Errorf("name is required") - } - - if err := pricingoverrides.ValidateScopeKind(p.ScopeKind, p.VirtualKeyID, p.ProviderID, p.ProviderKeyID); err != nil { - return err - } - - normalizedPattern, err := pricingoverrides.ValidatePattern(p.MatchType, p.Pattern) - if err != nil { - return err - } - p.Pattern = normalizedPattern - - if err := pricingoverrides.ValidateRequestTypes(p.RequestTypes); err != nil { - return err - } - - if err := pricingoverrides.ValidatePatchNonNegative(p.Patch); err != nil { - return err - } - if len(p.RequestTypes) > 0 { b, err := json.Marshal(p.RequestTypes) if err != nil { @@ -70,66 +41,15 @@ func (p *TablePricingOverride) BeforeSave(tx *gorm.DB) error { } else { p.RequestTypesJSON = "" } - - b, err := json.Marshal(p.Patch) - if err != nil { - return err - } - p.PricingPatchJSON = string(b) - return nil } -// AfterFind restores the request type and patch fields from their persisted -// JSON columns. +// AfterFind restores virtual fields from their persisted JSON columns. func (p *TablePricingOverride) AfterFind(tx *gorm.DB) error { if p.RequestTypesJSON != "" { if err := json.Unmarshal([]byte(p.RequestTypesJSON), &p.RequestTypes); err != nil { return err } } - if p.PricingPatchJSON != "" { - if err := json.Unmarshal([]byte(p.PricingPatchJSON), &p.Patch); err != nil { - return err - } - } return nil } - -// ToPricingOverride converts the persisted row into the shared pricing override -// contract used by runtime components. -func (p TablePricingOverride) ToPricingOverride() pricingoverrides.Override { - return pricingoverrides.Override{ - ID: p.ID, - Name: p.Name, - ScopeKind: p.ScopeKind, - VirtualKeyID: p.VirtualKeyID, - ProviderID: p.ProviderID, - ProviderKeyID: p.ProviderKeyID, - MatchType: p.MatchType, - Pattern: p.Pattern, - RequestTypes: p.RequestTypes, - Patch: p.Patch, - CreatedAt: p.CreatedAt, - UpdatedAt: p.UpdatedAt, - } -} - -// TablePricingOverrideFromPricingOverride converts the shared runtime override -// contract into its persistence representation. -func TablePricingOverrideFromPricingOverride(override pricingoverrides.Override) TablePricingOverride { - return TablePricingOverride{ - ID: override.ID, - Name: override.Name, - ScopeKind: override.ScopeKind, - VirtualKeyID: override.VirtualKeyID, - ProviderID: override.ProviderID, - ProviderKeyID: override.ProviderKeyID, - MatchType: override.MatchType, - Pattern: override.Pattern, - RequestTypes: override.RequestTypes, - Patch: override.Patch, - CreatedAt: override.CreatedAt, - UpdatedAt: override.UpdatedAt, - } -} diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index dae5502078..c408a8f408 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -29,22 +29,22 @@ const ( // SearchFilters represents the available filters for log searches type SearchFilters struct { - Providers []string `json:"providers,omitempty"` - Models []string `json:"models,omitempty"` - Status []string `json:"status,omitempty"` - Objects []string `json:"objects,omitempty"` // For filtering by request type (chat.completion, text.completion, embedding) - SelectedKeyIDs []string `json:"selected_key_ids,omitempty"` - VirtualKeyIDs []string `json:"virtual_key_ids,omitempty"` - RoutingRuleIDs []string `json:"routing_rule_ids,omitempty"` - RoutingEngineUsed []string `json:"routing_engine_used,omitempty"` // For filtering by routing engine (routing-rule, governance, loadbalancing) - StartTime *time.Time `json:"start_time,omitempty"` - EndTime *time.Time `json:"end_time,omitempty"` - MinLatency *float64 `json:"min_latency,omitempty"` - MaxLatency *float64 `json:"max_latency,omitempty"` - MinTokens *int `json:"min_tokens,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - MinCost *float64 `json:"min_cost,omitempty"` - MaxCost *float64 `json:"max_cost,omitempty"` + Providers []string `json:"providers,omitempty"` + Models []string `json:"models,omitempty"` + Status []string `json:"status,omitempty"` + Objects []string `json:"objects,omitempty"` // For filtering by request type (chat.completion, text.completion, embedding) + SelectedKeyIDs []string `json:"selected_key_ids,omitempty"` + VirtualKeyIDs []string `json:"virtual_key_ids,omitempty"` + RoutingRuleIDs []string `json:"routing_rule_ids,omitempty"` + RoutingEngineUsed []string `json:"routing_engine_used,omitempty"` // For filtering by routing engine (routing-rule, governance, loadbalancing) + StartTime *time.Time `json:"start_time,omitempty"` + EndTime *time.Time `json:"end_time,omitempty"` + MinLatency *float64 `json:"min_latency,omitempty"` + MaxLatency *float64 `json:"max_latency,omitempty"` + MinTokens *int `json:"min_tokens,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MinCost *float64 `json:"min_cost,omitempty"` + MaxCost *float64 `json:"max_cost,omitempty"` MissingCostOnly bool `json:"missing_cost_only,omitempty"` ContentSearch string `json:"content_search,omitempty"` MetadataFilters map[string]string `json:"metadata_filters,omitempty"` // key=metadataKey, value=metadataValue for filtering by metadata @@ -78,59 +78,59 @@ type SearchStats struct { // Log represents a complete log entry for a request/response cycle // This is the GORM model with appropriate tags type Log struct { - ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` - ParentRequestID *string `gorm:"type:varchar(255)" json:"parent_request_id"` - Timestamp time.Time `gorm:"index;index:idx_logs_ts_provider_status,priority:1;not null" json:"timestamp"` - Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding - Provider string `gorm:"type:varchar(255);index;index:idx_logs_ts_provider_status,priority:2;not null" json:"provider"` - Model string `gorm:"type:varchar(255);index;not null" json:"model"` - NumberOfRetries int `gorm:"default:0" json:"number_of_retries"` - FallbackIndex int `gorm:"default:0" json:"fallback_index"` - SelectedKeyID string `gorm:"type:varchar(255);index:idx_logs_selected_key_id" json:"selected_key_id"` - SelectedKeyName string `gorm:"type:varchar(255)" json:"selected_key_name"` - VirtualKeyID *string `gorm:"type:varchar(255);index:idx_logs_virtual_key_id" json:"virtual_key_id"` - VirtualKeyName *string `gorm:"type:varchar(255)" json:"virtual_key_name"` - RoutingEnginesUsedStr *string `gorm:"type:varchar(255);column:routing_engines_used" json:"-"` // Comma-separated routing engines - RoutingRuleID *string `gorm:"type:varchar(255);index:idx_logs_routing_rule_id" json:"routing_rule_id"` - RoutingRuleName *string `gorm:"type:varchar(255)" json:"routing_rule_name"` - InputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ChatMessage - ResponsesInputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ResponsesMessage - OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ChatMessage - ResponsesOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ResponsesMessage - EmbeddingOutput string `gorm:"type:text" json:"-"` // JSON serialized [][]float32 - RerankOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.RerankResult - Params string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ModelParameters - Tools string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Tool - ToolCalls string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ToolCall (For backward compatibility, tool calls are now in the content) - SpeechInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.SpeechInput - TranscriptionInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.TranscriptionInput - ImageGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageGenerationInput - VideoGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.VideoGenerationInput - SpeechOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostSpeech - TranscriptionOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostTranscribe - ImageGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostImageGenerationResponse - ListModelsOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Model - VideoGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoGenerationResponse - VideoRetrieveOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoRetrieveResponse - VideoDownloadOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDownloadResponse - VideoListOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoListResponse - VideoDeleteOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDeleteResponse - CacheDebug string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostCacheDebug - Latency *float64 `gorm:"index:idx_logs_latency" json:"latency,omitempty"` - TokenUsage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.LLMUsage - Cost *float64 `gorm:"index" json:"cost,omitempty"` // Cost in dollars (total cost of the request - includes cache lookup cost) - Status string `gorm:"type:varchar(50);index;index:idx_logs_ts_provider_status,priority:3;not null" json:"status"` // "processing", "success", or "error" - ErrorDetails string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostError - Stream bool `gorm:"default:false" json:"stream"` // true if this was a streaming response - ContentSummary string `gorm:"type:text" json:"-"` - RawRequest string `gorm:"type:text" json:"raw_request"` // Populated when `send-back-raw-request` is on - RawResponse string `gorm:"type:text" json:"raw_response"` // Populated when `send-back-raw-response` is on + ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` + ParentRequestID *string `gorm:"type:varchar(255)" json:"parent_request_id"` + Timestamp time.Time `gorm:"index;index:idx_logs_ts_provider_status,priority:1;not null" json:"timestamp"` + Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding + Provider string `gorm:"type:varchar(255);index;index:idx_logs_ts_provider_status,priority:2;not null" json:"provider"` + Model string `gorm:"type:varchar(255);index;not null" json:"model"` + NumberOfRetries int `gorm:"default:0" json:"number_of_retries"` + FallbackIndex int `gorm:"default:0" json:"fallback_index"` + SelectedKeyID string `gorm:"type:varchar(255);index:idx_logs_selected_key_id" json:"selected_key_id"` + SelectedKeyName string `gorm:"type:varchar(255)" json:"selected_key_name"` + VirtualKeyID *string `gorm:"type:varchar(255);index:idx_logs_virtual_key_id" json:"virtual_key_id"` + VirtualKeyName *string `gorm:"type:varchar(255)" json:"virtual_key_name"` + RoutingEnginesUsedStr *string `gorm:"type:varchar(255);column:routing_engines_used" json:"-"` // Comma-separated routing engines + RoutingRuleID *string `gorm:"type:varchar(255);index:idx_logs_routing_rule_id" json:"routing_rule_id"` + RoutingRuleName *string `gorm:"type:varchar(255)" json:"routing_rule_name"` + InputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ChatMessage + ResponsesInputHistory string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ResponsesMessage + OutputMessage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ChatMessage + ResponsesOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ResponsesMessage + EmbeddingOutput string `gorm:"type:text" json:"-"` // JSON serialized [][]float32 + RerankOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.RerankResult + Params string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ModelParameters + Tools string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Tool + ToolCalls string `gorm:"type:text" json:"-"` // JSON serialized []schemas.ToolCall (For backward compatibility, tool calls are now in the content) + SpeechInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.SpeechInput + TranscriptionInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.TranscriptionInput + ImageGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.ImageGenerationInput + VideoGenerationInput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.VideoGenerationInput + SpeechOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostSpeech + TranscriptionOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostTranscribe + ImageGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostImageGenerationResponse + ListModelsOutput string `gorm:"type:text" json:"-"` // JSON serialized []schemas.Model + VideoGenerationOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoGenerationResponse + VideoRetrieveOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoRetrieveResponse + VideoDownloadOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDownloadResponse + VideoListOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoListResponse + VideoDeleteOutput string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostVideoDeleteResponse + CacheDebug string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostCacheDebug + Latency *float64 `gorm:"index:idx_logs_latency" json:"latency,omitempty"` + TokenUsage string `gorm:"type:text" json:"-"` // JSON serialized *schemas.LLMUsage + Cost *float64 `gorm:"index" json:"cost,omitempty"` // Cost in dollars (total cost of the request - includes cache lookup cost) + Status string `gorm:"type:varchar(50);index;index:idx_logs_ts_provider_status,priority:3;not null" json:"status"` // "processing", "success", or "error" + ErrorDetails string `gorm:"type:text" json:"-"` // JSON serialized *schemas.BifrostError + Stream bool `gorm:"default:false" json:"stream"` // true if this was a streaming response + ContentSummary string `gorm:"type:text" json:"-"` + RawRequest string `gorm:"type:text" json:"raw_request"` // Populated when `send-back-raw-request` is on + RawResponse string `gorm:"type:text" json:"raw_response"` // Populated when `send-back-raw-response` is on PassthroughRequestBody string `gorm:"type:text" json:"passthrough_request_body,omitempty"` // Raw body for passthrough requests (UTF-8) PassthroughResponseBody string `gorm:"type:text" json:"passthrough_response_body,omitempty"` // Raw body for passthrough responses (UTF-8) - RoutingEngineLogs string `gorm:"type:text" json:"routing_engine_logs,omitempty"` // Formatted routing engine decision logs - Metadata *string `gorm:"type:text" json:"-"` // JSON serialized map[string]interface{} - IsLargePayloadRequest bool `gorm:"default:false" json:"is_large_payload_request"` - IsLargePayloadResponse bool `gorm:"default:false" json:"is_large_payload_response"` + RoutingEngineLogs string `gorm:"type:text" json:"routing_engine_logs,omitempty"` // Formatted routing engine decision logs + Metadata *string `gorm:"type:text" json:"-"` // JSON serialized map[string]interface{} + IsLargePayloadRequest bool `gorm:"default:false" json:"is_large_payload_request"` + IsLargePayloadResponse bool `gorm:"default:false" json:"is_large_payload_response"` // Denormalized token fields for easier querying PromptTokens int `gorm:"default:0" json:"-"` diff --git a/framework/modelcatalog/main.go b/framework/modelcatalog/main.go index 578ebd23a7..fef252a9a7 100644 --- a/framework/modelcatalog/main.go +++ b/framework/modelcatalog/main.go @@ -13,7 +13,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" - "github.com/maximhq/bifrost/framework/pricingoverrides" ) // Default sync interval and config key @@ -39,10 +38,13 @@ type ModelCatalog struct { pricingData map[string]configstoreTables.TableModelPricing mu sync.RWMutex - // Scoped pricing overrides are maintained separately to avoid contention - // with pricing cache rebuilds. - scopedOverrides *compiledScopedOverrides - overridesMu sync.RWMutex + // rawOverrides is the canonical list of all active overrides. It exists solely + // to support incremental mutations: UpsertPricingOverrides and DeletePricingOverride + // iterate over it to rebuild the list, then derive customPricing from it. + // customPricing is the actual lookup structure used at query time. + rawOverrides []PricingOverride + customPricing *customPricingData + overridesMu sync.RWMutex modelPool map[schemas.ModelProvider][]string unfilteredModelPool map[schemas.ModelProvider][]string // model pool without allowed models filtering @@ -62,7 +64,10 @@ type PricingEntry struct { BaseModel string `json:"base_model,omitempty"` Provider string `json:"provider"` Mode string `json:"mode"` + PricingOptions +} +type PricingOptions struct { // Costs - Text InputCostPerToken float64 `json:"input_cost_per_token"` OutputCostPerToken float64 `json:"output_cost_per_token"` @@ -195,7 +200,6 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto configStore: configStore, logger: logger, pricingData: make(map[string]configstoreTables.TableModelPricing), - scopedOverrides: &compiledScopedOverrides{buckets: make(map[string]*pricingOverrideScopeBucket), byID: make(map[string]pricingoverrides.Override)}, modelPool: make(map[schemas.ModelProvider][]string), unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: make(map[string]string), @@ -251,6 +255,7 @@ func Init(ctx context.Context, config *Config, configStore configstore.ConfigSto // Populate model pool with normalized providers from pricing data mc.populateModelPoolFromPricingData() + if err := mc.loadPricingOverridesFromStore(ctx); err != nil { logger.Warn("failed to load pricing overrides: %v", err) } @@ -324,6 +329,7 @@ func (mc *ModelCatalog) ForceReloadPricing(ctx context.Context) error { // Rebuild model pool from updated pricing data mc.populateModelPoolFromPricingData() + if err := mc.loadPricingOverridesFromStore(ctx); err != nil { return fmt.Errorf("failed to load pricing overrides: %w", err) } @@ -791,6 +797,71 @@ func (mc *ModelCatalog) RefineModelForProvider(provider schemas.ModelProvider, m return model, nil } +// SetPricingOverrides replaces the full in-memory pricing override set. +func (mc *ModelCatalog) SetPricingOverrides(rows []configstoreTables.TablePricingOverride) error { + overrides := make([]PricingOverride, 0, len(rows)) + for i := range rows { + o, err := convertTablePricingOverrideToPricingOverride(&rows[i]) + if err != nil { + return err + } + overrides = append(overrides, o) + } + mc.overridesMu.Lock() + mc.rawOverrides = overrides + mc.customPricing = buildCustomPricingData(overrides) + mc.overridesMu.Unlock() + return nil +} + +// UpsertPricingOverrides inserts or replaces one or more pricing overrides in a single +// operation, rebuilding the lookup map only once at the end. +func (mc *ModelCatalog) UpsertPricingOverrides(rows ...*configstoreTables.TablePricingOverride) error { + overrides := make([]PricingOverride, 0, len(rows)) + for _, row := range rows { + o, err := convertTablePricingOverrideToPricingOverride(row) + if err != nil { + return err + } + overrides = append(overrides, o) + } + + // Build a set of IDs being upserted for O(1) lookup during dedup. + incoming := make(map[string]struct{}, len(overrides)) + for _, o := range overrides { + incoming[o.ID] = struct{}{} + } + + mc.overridesMu.Lock() + defer mc.overridesMu.Unlock() + + updated := make([]PricingOverride, 0, len(mc.rawOverrides)+len(overrides)) + for _, o := range mc.rawOverrides { + if _, replacing := incoming[o.ID]; !replacing { + updated = append(updated, o) + } + } + updated = append(updated, overrides...) + mc.rawOverrides = updated + mc.customPricing = buildCustomPricingData(updated) + return nil +} + +// DeletePricingOverride removes a pricing override by ID. +func (mc *ModelCatalog) DeletePricingOverride(id string) { + mc.overridesMu.Lock() + defer mc.overridesMu.Unlock() + + updated := make([]PricingOverride, 0, len(mc.rawOverrides)) + for _, o := range mc.rawOverrides { + if o.ID != id { + updated = append(updated, o) + } + } + mc.rawOverrides = updated + mc.customPricing = buildCustomPricingData(updated) +} + // IsTextCompletionSupported checks if a model supports text completion for the given provider. // Returns true if the model has pricing data for text completion ("text_completion"), // false otherwise. This is used by the litellmcompat plugin to determine whether to @@ -885,7 +956,6 @@ func NewTestCatalog(baseModelIndex map[string]string) *ModelCatalog { unfilteredModelPool: make(map[schemas.ModelProvider][]string), baseModelIndex: baseModelIndex, pricingData: make(map[string]configstoreTables.TableModelPricing), - scopedOverrides: &compiledScopedOverrides{buckets: make(map[string]*pricingOverrideScopeBucket), byID: make(map[string]pricingoverrides.Override)}, done: make(chan struct{}), } } diff --git a/framework/modelcatalog/main_test.go b/framework/modelcatalog/main_test.go index fced3ef0f7..c406a951f1 100644 --- a/framework/modelcatalog/main_test.go +++ b/framework/modelcatalog/main_test.go @@ -5,7 +5,6 @@ import ( "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" - "github.com/maximhq/bifrost/framework/pricingoverrides" "github.com/stretchr/testify/assert" ) @@ -21,7 +20,6 @@ func newTestCatalog(modelPool map[schemas.ModelProvider][]string, baseModelIndex modelPool: modelPool, baseModelIndex: baseModelIndex, pricingData: make(map[string]configstoreTables.TableModelPricing), - scopedOverrides: &compiledScopedOverrides{buckets: make(map[string]*pricingOverrideScopeBucket), byID: make(map[string]pricingoverrides.Override)}, } } diff --git a/framework/modelcatalog/overrides.go b/framework/modelcatalog/overrides.go index 3c51baaca8..1458d5a2d6 100644 --- a/framework/modelcatalog/overrides.go +++ b/framework/modelcatalog/overrides.go @@ -3,13 +3,11 @@ package modelcatalog import ( "context" "fmt" - "slices" "strings" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" - "github.com/maximhq/bifrost/framework/pricingoverrides" ) // PricingLookupScopes carries the runtime identifiers used to resolve scoped @@ -20,351 +18,345 @@ type PricingLookupScopes struct { Provider string } -func normalizeScopeIDPointer(id *string) *string { - if id == nil { - return nil - } - trimmed := strings.TrimSpace(*id) - if trimmed == "" { - return nil - } - return &trimmed -} +// ScopeKind identifies which governance scope an override applies to. +type ScopeKind string -type compiledPricingOverride struct { - override pricingoverrides.Override - pricingPatch pricingoverrides.Patch - wildcardPrefix string - requestModes map[string]struct{} - hasRequestFilter bool - order int -} +const ( + ScopeKindGlobal ScopeKind = "global" + ScopeKindProvider ScopeKind = "provider" + ScopeKindProviderKey ScopeKind = "provider_key" + ScopeKindVirtualKey ScopeKind = "virtual_key" + ScopeKindVirtualKeyProvider ScopeKind = "virtual_key_provider" + ScopeKindVirtualKeyProviderKey ScopeKind = "virtual_key_provider_key" +) + +// MatchType controls how an override pattern is matched against model names. +type MatchType string -type pricingOverrideScopeBucket struct { - exact map[string][]compiledPricingOverride - wildcard map[string][]compiledPricingOverride - wildcardPrefixLengths []int +const ( + MatchTypeExact MatchType = "exact" + MatchTypeWildcard MatchType = "wildcard" +) + +// PricingOverride describes a scoped pricing override shared across config storage, +// model catalog compilation, and governance APIs. +type PricingOverride struct { + ID string `json:"id"` + Name string `json:"name"` + ScopeKind ScopeKind `json:"scope_kind"` + VirtualKeyID *string `json:"virtual_key_id,omitempty"` + ProviderID *string `json:"provider_id,omitempty"` + ProviderKeyID *string `json:"provider_key_id,omitempty"` + MatchType MatchType `json:"match_type"` + Pattern string `json:"pattern"` + RequestTypes []schemas.RequestType `json:"request_types,omitempty"` + Options PricingOptions `json:"options"` } -type compiledScopedOverrides struct { - buckets map[string]*pricingOverrideScopeBucket - byID map[string]pricingoverrides.Override +// customPricingEntry is a single flattened override ready for lookup. +type customPricingEntry struct { + id string + scopeKind ScopeKind + virtualKeyID string + providerID string + providerKeyID string + pattern string // exact model name, or wildcard prefix (trailing * stripped) + wildcard bool + requestModes map[string]struct{} // nil = matches all request types + options PricingOptions } -func normalizedScopeKey(scopeKind pricingoverrides.ScopeKind, virtualKeyID, providerID, providerKeyID string) string { - return string(scopeKind) + "|" + virtualKeyID + "|" + providerID + "|" + providerKeyID +// customPricingData is the in-memory lookup structure for pricing overrides. +// Exact matches are indexed by model name; wildcards are a flat slice. +type customPricingData struct { + exact map[string][]customPricingEntry + wildcard []customPricingEntry } -// SetPricingOverrides replaces the in-memory compiled pricing override set with -// overrides. -func (mc *ModelCatalog) SetPricingOverrides(overrides []pricingoverrides.Override) error { - compiled, err := compileScopedOverrides(overrides) - if err != nil { +// IsValid validates the shared pricing override contract before persistence or runtime use. +// +// Input: override — the PricingOverride to validate (receiver). +// Output: error — non-nil if any scope, pattern, or request-type constraint is violated. +func (override *PricingOverride) IsValid() error { + if err := override.validateScopeKind(); err != nil { return err } - - mc.overridesMu.Lock() - mc.scopedOverrides = compiled - mc.overridesMu.Unlock() - return nil + if err := override.validatePattern(); err != nil { + return err + } + return override.validateRequestTypes() } -// UpsertPricingOverride inserts or replaces a single pricing override in the -// compiled in-memory override set. -func (mc *ModelCatalog) UpsertPricingOverride(override pricingoverrides.Override) error { - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - current := mc.scopedOverrides - - overrides := make([]pricingoverrides.Override, 0) - if current != nil { - for _, ov := range current.byID { - if ov.ID == override.ID { - continue - } - overrides = append(overrides, ov) +// validateScopeKind validates the scope identifiers required by override.ScopeKind. +// +// Input: override — receiver; ScopeKind and the three optional ID fields are inspected. +// Output: error — non-nil when required identifiers are absent or forbidden ones are present. +func (override *PricingOverride) validateScopeKind() error { + switch override.ScopeKind { + case ScopeKindGlobal: + if override.VirtualKeyID != nil || override.ProviderID != nil || override.ProviderKeyID != nil { + return fmt.Errorf("global scope_kind must not include scope identifiers") } - } - overrides = append(overrides, override) - slices.SortFunc(overrides, func(a, b pricingoverrides.Override) int { - if a.ID < b.ID { - return -1 + case ScopeKindProvider: + if override.ProviderID == nil { + return fmt.Errorf("provider_id is required for provider scope_kind") } - if a.ID > b.ID { - return 1 + if override.VirtualKeyID != nil || override.ProviderKeyID != nil { + return fmt.Errorf("provider scope_kind only supports provider_id") } - return 0 - }) - compiled, err := compileScopedOverrides(overrides) - if err != nil { - return err + case ScopeKindProviderKey: + if override.ProviderKeyID == nil { + return fmt.Errorf("provider_key_id is required for provider_key scope_kind") + } + if override.VirtualKeyID != nil || override.ProviderID != nil { + return fmt.Errorf("provider_key scope_kind only supports provider_key_id") + } + case ScopeKindVirtualKey: + if override.VirtualKeyID == nil { + return fmt.Errorf("virtual_key_id is required for virtual_key scope_kind") + } + if override.ProviderID != nil || override.ProviderKeyID != nil { + return fmt.Errorf("virtual_key scope_kind only supports virtual_key_id") + } + case ScopeKindVirtualKeyProvider: + if override.VirtualKeyID == nil || override.ProviderID == nil { + return fmt.Errorf("virtual_key_id and provider_id are required for virtual_key_provider scope_kind") + } + if override.ProviderKeyID != nil { + return fmt.Errorf("virtual_key_provider scope_kind does not support provider_key_id") + } + case ScopeKindVirtualKeyProviderKey: + if override.VirtualKeyID == nil || override.ProviderID == nil || override.ProviderKeyID == nil { + return fmt.Errorf("virtual_key_id, provider_id, and provider_key_id are required for virtual_key_provider_key scope_kind") + } + default: + return fmt.Errorf("unsupported scope_kind %q", override.ScopeKind) } - mc.scopedOverrides = compiled return nil } -// DeletePricingOverride removes a pricing override from the compiled in-memory -// override set. -func (mc *ModelCatalog) DeletePricingOverride(id string) { - mc.overridesMu.Lock() - defer mc.overridesMu.Unlock() - current := mc.scopedOverrides - if current == nil { - return +// validatePattern checks that Pattern is non-empty and consistent with MatchType. +// +// Input: override — receiver; Pattern and MatchType are inspected. +// Output: error — non-nil when the pattern is empty, contains a wildcard for exact mode, +// +// or does not end with a single trailing "*" for wildcard mode. +func (override *PricingOverride) validatePattern() error { + pattern := strings.TrimSpace(override.Pattern) + if pattern == "" { + return fmt.Errorf("pattern is required") } - overrides := make([]pricingoverrides.Override, 0, len(current.byID)) - for _, ov := range current.byID { - if ov.ID == id { - continue + switch override.MatchType { + case MatchTypeExact: + if strings.Contains(pattern, "*") { + return fmt.Errorf("exact match pattern must not contain wildcards") } - overrides = append(overrides, ov) - } - slices.SortFunc(overrides, func(a, b pricingoverrides.Override) int { - if a.ID < b.ID { - return -1 + case MatchTypeWildcard: + if !strings.HasSuffix(pattern, "*") { + return fmt.Errorf("wildcard pattern must end with *") } - if a.ID > b.ID { - return 1 + if strings.Count(pattern, "*") != 1 { + return fmt.Errorf("wildcard pattern must contain exactly one trailing *") } - return 0 - }) - compiled, err := compileScopedOverrides(overrides) - if err != nil { - mc.logger.Warn("failed to recompile overrides after delete: %v", err) - return + default: + return fmt.Errorf("unsupported match_type %q", override.MatchType) } - mc.scopedOverrides = compiled + return nil } -func compileScopedOverrides(overrides []pricingoverrides.Override) (*compiledScopedOverrides, error) { - compiled := &compiledScopedOverrides{ - buckets: make(map[string]*pricingOverrideScopeBucket), - byID: make(map[string]pricingoverrides.Override, len(overrides)), - } - - for i := range overrides { - item, err := compilePricingOverride(i, overrides[i]) - if err != nil { - return nil, err - } - virtualKeyID := "" - if item.override.VirtualKeyID != nil { - virtualKeyID = *item.override.VirtualKeyID - } - providerID := "" - if item.override.ProviderID != nil { - providerID = *item.override.ProviderID - } - providerKeyID := "" - if item.override.ProviderKeyID != nil { - providerKeyID = *item.override.ProviderKeyID +// validateRequestTypes checks that every entry in RequestTypes maps to a known pricing mode. +// +// Input: override — receiver; RequestTypes slice is inspected. +// Output: error — non-nil if any request type normalizes to "unknown". +func (override *PricingOverride) validateRequestTypes() error { + for _, rt := range override.RequestTypes { + if normalizeRequestType(rt) == "unknown" { + return fmt.Errorf("unsupported request_type %q", rt) } - key := normalizedScopeKey(item.override.ScopeKind, virtualKeyID, providerID, providerKeyID) - bucket := compiled.buckets[key] - if bucket == nil { - bucket = &pricingOverrideScopeBucket{ - exact: make(map[string][]compiledPricingOverride), - wildcard: make(map[string][]compiledPricingOverride), - } - compiled.buckets[key] = bucket - } - switch item.override.MatchType { - case pricingoverrides.MatchTypeExact: - bucket.exact[item.override.Pattern] = append(bucket.exact[item.override.Pattern], item) - case pricingoverrides.MatchTypeWildcard: - if _, exists := bucket.wildcard[item.wildcardPrefix]; !exists { - bucket.wildcardPrefixLengths = append(bucket.wildcardPrefixLengths, len(item.wildcardPrefix)) - } - bucket.wildcard[item.wildcardPrefix] = append(bucket.wildcard[item.wildcardPrefix], item) - } - compiled.byID[item.override.ID] = item.override - } - for key := range compiled.buckets { - bucket := compiled.buckets[key] - slices.SortFunc(bucket.wildcardPrefixLengths, func(a, b int) int { - if a > b { - return -1 - } - if a < b { - return 1 - } - return 0 - }) } - - return compiled, nil + return nil } -func (mc *ModelCatalog) applyScopedPricingOverrides(model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing, scopes PricingLookupScopes) configstoreTables.TableModelPricing { - mc.overridesMu.RLock() - compiled := mc.scopedOverrides - mc.overridesMu.RUnlock() - if compiled == nil { - return pricing - } +// matchesScope reports whether the entry's governance scope matches the runtime identifiers. +// +// Input: scopes — runtime VirtualKeyID, SelectedKeyID, and Provider to match against. +// Output: bool — true when the entry's scope kind and stored IDs align with scopes. +func (e *customPricingEntry) matchesScope(scopes PricingLookupScopes) bool { + switch e.scopeKind { + case ScopeKindGlobal: + return true + case ScopeKindProvider: + return e.providerID == scopes.Provider + case ScopeKindProviderKey: + return e.providerKeyID == scopes.SelectedKeyID + case ScopeKindVirtualKey: + return e.virtualKeyID == scopes.VirtualKeyID + case ScopeKindVirtualKeyProvider: + return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider + case ScopeKindVirtualKeyProviderKey: + return e.virtualKeyID == scopes.VirtualKeyID && e.providerID == scopes.Provider && e.providerKeyID == scopes.SelectedKeyID + } + return false +} - mode := normalizeRequestType(requestType) - if mode == "unknown" { - return pricing +// matchesMode reports whether the entry applies to the given normalized request mode. +// +// Input: mode — normalized request type string (e.g. "chat", "embedding"). +// Output: bool — true when requestModes is empty (matches all) or contains mode. +func (e *customPricingEntry) matchesMode(mode string) bool { + if len(e.requestModes) == 0 { + return true } + _, ok := e.requestModes[mode] + return ok +} - if override := resolveScopedOverride(compiled, model, mode, scopes); override != nil { - return patchPricing(pricing, override.pricingPatch) +// resolve walks the 6-scope priority hierarchy and returns the first matching +// pricing patch for the given model, request mode, and runtime scopes. +// +// Input: model — exact model name being priced. +// +// mode — normalized request type string (e.g. "chat", "embedding"). +// scopes — runtime governance identifiers used to narrow the scope search. +// +// Output: *PricingOptions — pointer to the first matching override's options, or nil if none match. +func (c *customPricingData) resolve(model, mode string, scopes PricingLookupScopes) *PricingOptions { + for _, scopeKind := range scopePriorityOrder(scopes) { + for i := range c.exact[model] { + e := &c.exact[model][i] + if e.scopeKind == scopeKind && e.matchesScope(scopes) && e.matchesMode(mode) { + return &e.options + } + } + for i := range c.wildcard { + e := &c.wildcard[i] + if e.scopeKind == scopeKind && e.matchesScope(scopes) && strings.HasPrefix(model, e.pattern) && e.matchesMode(mode) { + return &e.options + } + } } - return pricing + return nil } -func resolveScopedOverride(compiled *compiledScopedOverrides, model, mode string, scopes PricingLookupScopes) *compiledPricingOverride { - scopeOrder := make([]string, 0, 6) +// scopePriorityOrder returns scope kinds in most-specific-first order, +// skipping scopes that can't match given the available runtime identifiers. +// +// Input: scopes — runtime governance identifiers; empty fields cause the corresponding scope kinds to be omitted. +// Output: []ScopeKind — ordered list from most-specific (VirtualKeyProviderKey) to least-specific (Global). +func scopePriorityOrder(scopes PricingLookupScopes) []ScopeKind { + order := make([]ScopeKind, 0, 6) if scopes.VirtualKeyID != "" && scopes.Provider != "" && scopes.SelectedKeyID != "" { - scopeOrder = append(scopeOrder, normalizedScopeKey(pricingoverrides.ScopeKindVirtualKeyProviderKey, scopes.VirtualKeyID, scopes.Provider, scopes.SelectedKeyID)) + order = append(order, ScopeKindVirtualKeyProviderKey) } if scopes.VirtualKeyID != "" && scopes.Provider != "" { - scopeOrder = append(scopeOrder, normalizedScopeKey(pricingoverrides.ScopeKindVirtualKeyProvider, scopes.VirtualKeyID, scopes.Provider, "")) + order = append(order, ScopeKindVirtualKeyProvider) } if scopes.VirtualKeyID != "" { - scopeOrder = append(scopeOrder, normalizedScopeKey(pricingoverrides.ScopeKindVirtualKey, scopes.VirtualKeyID, "", "")) + order = append(order, ScopeKindVirtualKey) } if scopes.SelectedKeyID != "" { - scopeOrder = append(scopeOrder, normalizedScopeKey(pricingoverrides.ScopeKindProviderKey, "", "", scopes.SelectedKeyID)) + order = append(order, ScopeKindProviderKey) } if scopes.Provider != "" { - scopeOrder = append(scopeOrder, normalizedScopeKey(pricingoverrides.ScopeKindProvider, "", scopes.Provider, "")) + order = append(order, ScopeKindProvider) } - scopeOrder = append(scopeOrder, normalizedScopeKey(pricingoverrides.ScopeKindGlobal, "", "", "")) + order = append(order, ScopeKindGlobal) + return order +} - for _, key := range scopeOrder { - bucket := compiled.buckets[key] - if bucket == nil { - continue +// buildCustomPricingData constructs a customPricingData lookup structure from a raw override slice. +// +// Input: overrides — slice of validated PricingOverride records loaded from the config store. +// Output: *customPricingData — ready-to-query structure with exact and wildcard indexes populated. +func buildCustomPricingData(overrides []PricingOverride) *customPricingData { + data := &customPricingData{ + exact: make(map[string][]customPricingEntry, len(overrides)), + } + for _, o := range overrides { + entry := customPricingEntry{ + id: o.ID, + scopeKind: o.ScopeKind, + options: o.Options, } - if best := selectBestOverride(bucket.exact[model], mode); best != nil { - return best + if o.VirtualKeyID != nil { + entry.virtualKeyID = *o.VirtualKeyID } - for _, prefixLength := range bucket.wildcardPrefixLengths { - if prefixLength > len(model) { - continue - } - prefix := model[:prefixLength] - if best := selectBestOverride(bucket.wildcard[prefix], mode); best != nil { - return best - } + if o.ProviderID != nil { + entry.providerID = *o.ProviderID } - } - return nil -} - -func selectBestOverride(candidates []compiledPricingOverride, mode string) *compiledPricingOverride { - if len(candidates) == 0 { - return nil - } - var bestSpecific *compiledPricingOverride - var bestGeneric *compiledPricingOverride - for i := range candidates { - candidate := &candidates[i] - if candidate.hasRequestFilter { - if _, ok := candidate.requestModes[mode]; !ok { - continue - } - if bestSpecific == nil || isBetterOverrideCandidate(candidate, bestSpecific) { - bestSpecific = candidate + if o.ProviderKeyID != nil { + entry.providerKeyID = *o.ProviderKeyID + } + if len(o.RequestTypes) > 0 { + entry.requestModes = make(map[string]struct{}, len(o.RequestTypes)) + for _, rt := range o.RequestTypes { + entry.requestModes[normalizeRequestType(rt)] = struct{}{} } - continue } - if bestGeneric == nil || isBetterOverrideCandidate(candidate, bestGeneric) { - bestGeneric = candidate + switch o.MatchType { + case MatchTypeExact: + entry.pattern = o.Pattern + data.exact[o.Pattern] = append(data.exact[o.Pattern], entry) + case MatchTypeWildcard: + entry.pattern = strings.TrimSuffix(o.Pattern, "*") + entry.wildcard = true + data.wildcard = append(data.wildcard, entry) } } - if bestSpecific != nil { - return bestSpecific - } - return bestGeneric + return data } -func isBetterOverrideCandidate(candidate, current *compiledPricingOverride) bool { - if candidate.override.UpdatedAt.After(current.override.UpdatedAt) { - return true - } - if candidate.override.UpdatedAt.Before(current.override.UpdatedAt) { - return false - } - - if candidate.override.ID < current.override.ID { - return true - } - if candidate.override.ID > current.override.ID { - return false - } - - return candidate.order < current.order -} - -func compilePricingOverride(order int, override pricingoverrides.Override) (compiledPricingOverride, error) { - override.VirtualKeyID = normalizeScopeIDPointer(override.VirtualKeyID) - override.ProviderID = normalizeScopeIDPointer(override.ProviderID) - override.ProviderKeyID = normalizeScopeIDPointer(override.ProviderKeyID) - - if err := pricingoverrides.ValidateScopeKind(override.ScopeKind, override.VirtualKeyID, override.ProviderID, override.ProviderKeyID); err != nil { - return compiledPricingOverride{}, err - } - - pattern, err := pricingoverrides.ValidatePattern(override.MatchType, override.Pattern) - if err != nil { - return compiledPricingOverride{}, err - } - override.Pattern = pattern +// applyPricingOverrides resolves any active scoped pricing override for the given model +// and request type, then patches the catalog base pricing with the override values. +// It returns the original pricing unchanged when no custom pricing tree is loaded or +// when the request type cannot be mapped to a known pricing mode. +// +// Input: model — exact model name being priced. +// +// requestType — the request type used to derive the pricing mode. +// pricing — base pricing row from the catalog to patch. +// scopes — runtime governance identifiers used to narrow the override scope. +// +// Output: TableModelPricing — patched pricing row, or pricing unchanged if no override matches. +func (mc *ModelCatalog) applyPricingOverrides(model string, requestType schemas.RequestType, pricing configstoreTables.TableModelPricing, scopes PricingLookupScopes) configstoreTables.TableModelPricing { + mc.overridesMu.RLock() + custom := mc.customPricing + mc.overridesMu.RUnlock() - compiled := compiledPricingOverride{ - override: override, - pricingPatch: override.Patch, - requestModes: make(map[string]struct{}), - order: order, + if custom == nil { + return pricing } - switch override.MatchType { - case pricingoverrides.MatchTypeExact: - case pricingoverrides.MatchTypeWildcard: - compiled.wildcardPrefix = strings.TrimSuffix(override.Pattern, "*") - default: - return compiledPricingOverride{}, fmt.Errorf("unsupported match_type: %s", override.MatchType) + mode := normalizeRequestType(requestType) + if mode == "unknown" { + return pricing } - if len(override.RequestTypes) > 0 { - if err := pricingoverrides.ValidateRequestTypes(override.RequestTypes); err != nil { - return compiledPricingOverride{}, err - } - compiled.hasRequestFilter = true - for _, requestType := range override.RequestTypes { - compiled.requestModes[normalizeRequestType(requestType)] = struct{}{} - } + if patch := custom.resolve(model, mode, scopes); patch != nil { + return patchPricing(pricing, *patch) } - - return compiled, nil + return pricing } -func patchPricing(pricing configstoreTables.TableModelPricing, override pricingoverrides.Patch) configstoreTables.TableModelPricing { +// patchPricing applies non-zero override values onto a copy of the base pricing row. +// For plain float64 fields (InputCostPerToken, OutputCostPerToken), a non-zero override +// replaces the base value. For pointer fields, a non-nil override pointer replaces the +// corresponding destination pointer; a nil override leaves the base value intact. +// The original pricing row is never modified; a patched copy is always returned. +// +// Input: pricing — base pricing row from the catalog. +// +// override — pricing options sourced from the matched override entry. +// +// Output: TableModelPricing — shallow copy of pricing with override fields applied. +func patchPricing(pricing configstoreTables.TableModelPricing, override PricingOptions) configstoreTables.TableModelPricing { patched := pricing - for _, field := range []struct { - dst *float64 - src *float64 - }{ - {dst: &patched.InputCostPerToken, src: override.InputCostPerToken}, - {dst: &patched.OutputCostPerToken, src: override.OutputCostPerToken}, - } { - setFloatValue(field.dst, field.src) - } - if override.OutputCostPerImageLowQuality != nil { - patched.OutputCostPerImageLowQuality = override.OutputCostPerImageLowQuality + if override.InputCostPerToken != 0 { + patched.InputCostPerToken = override.InputCostPerToken } - if override.OutputCostPerImageMediumQuality != nil { - patched.OutputCostPerImageMediumQuality = override.OutputCostPerImageMediumQuality - } - if override.OutputCostPerImageHighQuality != nil { - patched.OutputCostPerImageHighQuality = override.OutputCostPerImageHighQuality - } - if override.OutputCostPerImageAutoQuality != nil { - patched.OutputCostPerImageAutoQuality = override.OutputCostPerImageAutoQuality + if override.OutputCostPerToken != 0 { + patched.OutputCostPerToken = override.OutputCostPerToken } for _, field := range []struct { @@ -412,35 +404,25 @@ func patchPricing(pricing configstoreTables.TableModelPricing, override pricingo {dst: &patched.CacheReadInputImageTokenCost, src: override.CacheReadInputImageTokenCost}, {dst: &patched.SearchContextCostPerQuery, src: override.SearchContextCostPerQuery}, {dst: &patched.CodeInterpreterCostPerSession, src: override.CodeInterpreterCostPerSession}, + {dst: &patched.OutputCostPerImageLowQuality, src: override.OutputCostPerImageLowQuality}, + {dst: &patched.OutputCostPerImageMediumQuality, src: override.OutputCostPerImageMediumQuality}, + {dst: &patched.OutputCostPerImageHighQuality, src: override.OutputCostPerImageHighQuality}, + {dst: &patched.OutputCostPerImageAutoQuality, src: override.OutputCostPerImageAutoQuality}, } { - setOptionalFloatValue(field.dst, field.src) + if field.src != nil { + *field.dst = field.src + } } return patched } -func setFloatValue(dst *float64, src *float64) { - if src != nil { - *dst = *src - } -} - -func setOptionalFloatValue(dst **float64, src *float64) { - if src != nil { - *dst = src - } -} - func (mc *ModelCatalog) loadPricingOverridesFromStore(ctx context.Context) error { if mc.configStore == nil { return nil } - rows, err := mc.configStore.GetPricingOverrides(ctx, configstore.PricingOverrideFilter{}) + rows, err := mc.configStore.GetPricingOverrides(ctx, configstore.PricingOverrideFilters{}) if err != nil { return err } - overrides := make([]pricingoverrides.Override, 0, len(rows)) - for i := range rows { - overrides = append(overrides, rows[i].ToPricingOverride()) - } - return mc.SetPricingOverrides(overrides) + return mc.SetPricingOverrides(rows) } diff --git a/framework/modelcatalog/overrides_test.go b/framework/modelcatalog/overrides_test.go index cf3ee5a8af..93e51fd330 100644 --- a/framework/modelcatalog/overrides_test.go +++ b/framework/modelcatalog/overrides_test.go @@ -1,13 +1,10 @@ package modelcatalog import ( - "fmt" "testing" - "time" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" - "github.com/maximhq/bifrost/framework/pricingoverrides" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,40 +22,6 @@ func (noOpLogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuild return schemas.NoopLogEvent } -type providerOverrideCompat struct { - ModelPattern string - MatchType pricingoverrides.MatchType - RequestTypes []schemas.RequestType - InputCostPerToken *float64 - ID string - UpdatedAt time.Time -} - -func setProviderScopedOverrides(t *testing.T, mc *ModelCatalog, provider schemas.ModelProvider, overrides []providerOverrideCompat) error { - t.Helper() - scopeID := string(provider) - compiled := make([]pricingoverrides.Override, 0, len(overrides)) - for i, override := range overrides { - id := override.ID - if id == "" { - id = fmt.Sprintf("%s-override-%d", scopeID, i) - } - compiled = append(compiled, pricingoverrides.Override{ - ID: id, - ScopeKind: pricingoverrides.ScopeKindProvider, - ProviderID: &scopeID, - MatchType: override.MatchType, - Pattern: override.ModelPattern, - RequestTypes: override.RequestTypes, - UpdatedAt: override.UpdatedAt, - Patch: pricingoverrides.Patch{ - InputCostPerToken: override.InputCostPerToken, - }, - }) - } - return mc.SetPricingOverrides(compiled) -} - func TestGetPricing_OverridePrecedenceExactWildcard(t *testing.T) { mc := newTestCatalog(nil, nil) mc.logger = noOpLogger{} @@ -70,26 +33,29 @@ func TestGetPricing_OverridePrecedenceExactWildcard(t *testing.T) { OutputCostPerToken: 2, } - exact := 20.0 - wildcard := 10.0 - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &wildcard, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-*", + PricingPatchJSON: `{"input_cost_per_token":10}`, }, { - ModelPattern: "gpt-4o", - MatchType: pricingoverrides.MatchTypeExact, - InputCostPerToken: &exact, + ID: "openai-override-1", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":20}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 20.0, pricing.InputCostPerToken) - assert.Equal(t, 2.0, pricing.OutputCostPerToken) } func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) { @@ -104,24 +70,28 @@ func TestGetPricing_RequestTypeSpecificOverrideBeatsGeneric(t *testing.T) { OutputCostPerToken: 2, } - specific := 15.0 - generic := 9.0 - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: pricingoverrides.MatchTypeExact, - InputCostPerToken: &generic, + ID: "openai-generic", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":9}`, }, { - ModelPattern: "gpt-4o", - MatchType: pricingoverrides.MatchTypeExact, - RequestTypes: []schemas.RequestType{schemas.ResponsesRequest}, - InputCostPerToken: &specific, + ID: "openai-specific", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + RequestTypes: []schemas.RequestType{schemas.ResponsesRequest}, + PricingPatchJSON: `{"input_cost_per_token":15}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ResponsesRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 15.0, pricing.InputCostPerToken) } @@ -138,17 +108,19 @@ func TestGetPricing_AppliesOverrideAfterFallbackResolution(t *testing.T) { OutputCostPerToken: 2, } - override := 7.0 - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.Gemini, []providerOverrideCompat{ + geminiProviderID := "gemini" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: pricingoverrides.MatchTypeExact, - InputCostPerToken: &override, + ID: "gemini-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &geminiProviderID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":7}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "gemini", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"}) require.NotNil(t, pricing) assert.Equal(t, 7.0, pricing.InputCostPerToken) } @@ -164,29 +136,19 @@ func TestGetPricing_DeploymentLookupUsesRequestedModelForOverrideMatching(t *tes OutputCostPerToken: 2, } - override := 7.0 - providerID := string(schemas.OpenAI) - require.NoError(t, mc.SetPricingOverrides([]pricingoverrides.Override{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ID: "requested-model-override", - ScopeKind: pricingoverrides.ScopeKindProvider, - ProviderID: &providerID, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-4o", - Patch: pricingoverrides.Patch{ - InputCostPerToken: &override, - }, + ID: "requested-model-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":7}`, }, })) - pricing, ok := mc.getPricingLocked( - "dep-gpt4o", - "gpt-4o", - "openai", - schemas.ChatCompletionRequest, - PricingLookupScopes{Provider: "openai"}, - ) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "dep-gpt4o", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 7.0, pricing.InputCostPerToken) } @@ -202,35 +164,28 @@ func TestGetPricing_FallbackUsesRequestedProviderForScopeMatching(t *testing.T) OutputCostPerToken: 2, } - geminiProviderID := string(schemas.Gemini) - vertexProviderID := string(schemas.Vertex) - geminiOverrideCost := 5.0 - vertexOverrideCost := 9.0 - require.NoError(t, mc.SetPricingOverrides([]pricingoverrides.Override{ + geminiProviderID := "gemini" + vertexProviderID := "vertex" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ID: "gemini-provider-override", - ScopeKind: pricingoverrides.ScopeKindProvider, - ProviderID: &geminiProviderID, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-4o", - Patch: pricingoverrides.Patch{ - InputCostPerToken: &geminiOverrideCost, - }, + ID: "gemini-provider-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &geminiProviderID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":5}`, }, { - ID: "vertex-provider-override", - ScopeKind: pricingoverrides.ScopeKindProvider, - ProviderID: &vertexProviderID, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-4o", - Patch: pricingoverrides.Patch{ - InputCostPerToken: &vertexOverrideCost, - }, + ID: "vertex-provider-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &vertexProviderID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":9}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "gemini", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("gemini", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"}) require.NotNil(t, pricing) assert.Equal(t, 5.0, pricing.InputCostPerToken) } @@ -247,17 +202,19 @@ func TestGetPricing_ExactOverrideDoesNotMatchProviderPrefixedModel(t *testing.T) OutputCostPerToken: 2, } - override := 19.0 - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: pricingoverrides.MatchTypeExact, - InputCostPerToken: &override, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":19}`, }, })) - pricing, ok := mc.getPricing("openai/gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "openai/gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 1.0, pricing.InputCostPerToken) } @@ -276,17 +233,19 @@ func TestGetPricing_NoMatchingOverrideLeavesPricingUnchanged(t *testing.T) { CacheReadInputTokenCost: &baseCacheRead, } - override := 9.0 - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "claude-*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &override, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "claude-*", + PricingPatchJSON: `{"input_cost_per_token":9}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 1.0, pricing.InputCostPerToken) assert.Equal(t, 2.0, pricing.OutputCostPerToken) @@ -306,24 +265,25 @@ func TestDeleteProviderPricingOverrides_StopsApplying(t *testing.T) { OutputCostPerToken: 2, } - override := 11.0 - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o", - MatchType: pricingoverrides.MatchTypeExact, - InputCostPerToken: &override, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-4o", + PricingPatchJSON: `{"input_cost_per_token":11}`, }, })) - pricing, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 11.0, pricing.InputCostPerToken) require.NoError(t, mc.SetPricingOverrides(nil)) - pricing, ok = mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing = mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 1.0, pricing.InputCostPerToken) } @@ -340,28 +300,34 @@ func TestGetPricing_WildcardSpecificityLongerLiteralWins(t *testing.T) { OutputCostPerToken: 2, } - generic := 5.0 - specific := 6.0 - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &generic, + ID: "openai-override-0", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-*", + PricingPatchJSON: `{"input_cost_per_token":5}`, }, { - ModelPattern: "gpt-4o*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &specific, + ID: "openai-override-1", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-4o*", + PricingPatchJSON: `{"input_cost_per_token":6}`, }, })) - pricing, ok := mc.getPricing("gpt-4o-mini", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 6.0, pricing.InputCostPerToken) } -func TestGetPricing_TieBreakLatestUpdatedAtWins(t *testing.T) { +// TestGetPricing_FirstInsertionWinsOnTie verifies that when multiple wildcard overrides +// match the same model and scope, the first one inserted takes precedence. +func TestGetPricing_FirstInsertionWinsOnTie(t *testing.T) { mc := newTestCatalog(nil, nil) mc.logger = noOpLogger{} mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{ @@ -372,65 +338,27 @@ func TestGetPricing_TieBreakLatestUpdatedAtWins(t *testing.T) { OutputCostPerToken: 2, } - first := 8.0 - second := 9.0 - now := time.Now().UTC() - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ + providerID := "openai" + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ModelPattern: "gpt-4o*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &first, - ID: "older", - UpdatedAt: now.Add(-1 * time.Minute), + ID: "a-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-4o*", + PricingPatchJSON: `{"input_cost_per_token":8}`, }, { - ModelPattern: "gpt-4o*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &second, - ID: "newer", - UpdatedAt: now, + ID: "b-override", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerID, + MatchType: string(MatchTypeWildcard), + Pattern: "gpt-4o*", + PricingPatchJSON: `{"input_cost_per_token":9}`, }, })) - pricing, ok := mc.getPricing("gpt-4o-mini", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) - require.NotNil(t, pricing) - assert.Equal(t, 9.0, pricing.InputCostPerToken) -} - -func TestGetPricing_TieBreakIDWinsWhenUpdatedAtEqual(t *testing.T) { - mc := newTestCatalog(nil, nil) - mc.logger = noOpLogger{} - mc.pricingData[makeKey("gpt-4o-mini", "openai", "chat")] = configstoreTables.TableModelPricing{ - Model: "gpt-4o-mini", - Provider: "openai", - Mode: "chat", - InputCostPerToken: 1, - OutputCostPerToken: 2, - } - - first := 8.0 - second := 9.0 - now := time.Now().UTC() - require.NoError(t, setProviderScopedOverrides(t, mc, schemas.OpenAI, []providerOverrideCompat{ - { - ModelPattern: "gpt-4o*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &first, - ID: "a-override", - UpdatedAt: now, - }, - { - ModelPattern: "gpt-4o*", - MatchType: pricingoverrides.MatchTypeWildcard, - InputCostPerToken: &second, - ID: "b-override", - UpdatedAt: now, - }, - })) - - pricing, ok := mc.getPricing("gpt-4o-mini", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + pricing := mc.resolvePricing("openai", "gpt-4o-mini", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) require.NotNil(t, pricing) assert.Equal(t, 8.0, pricing.InputCostPerToken) } @@ -449,17 +377,16 @@ func TestPatchPricing_PartialPatchOnlyChangesSpecifiedFields(t *testing.T) { InputCostPerImage: &baseInputImage, } - patched := patchPricing(base, pricingoverrides.Patch{ - InputCostPerToken: schemas.Ptr(3.0), - CacheReadInputTokenCost: schemas.Ptr(0.9), + cacheRead := 0.9 + patched := patchPricing(base, PricingOptions{ + InputCostPerToken: 3.0, + CacheReadInputTokenCost: &cacheRead, }) - // Changed fields assert.Equal(t, 3.0, patched.InputCostPerToken) require.NotNil(t, patched.CacheReadInputTokenCost) assert.Equal(t, 0.9, *patched.CacheReadInputTokenCost) - // Unchanged fields assert.Equal(t, 2.0, patched.OutputCostPerToken) require.NotNil(t, patched.InputCostPerImage) assert.Equal(t, 0.7, *patched.InputCostPerImage) @@ -473,50 +400,37 @@ func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) { providerKeyScopeID := "provider-key-1" virtualKeyScopeID := "virtual-key-1" - globalCost := 2.0 - providerCost := 3.0 - providerKeyCost := 4.0 - virtualKeyCost := 5.0 - - require.NoError(t, mc.SetPricingOverrides([]pricingoverrides.Override{ + require.NoError(t, mc.SetPricingOverrides([]configstoreTables.TablePricingOverride{ { - ID: "global", - ScopeKind: pricingoverrides.ScopeKindGlobal, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-5-nano", - Patch: pricingoverrides.Patch{ - InputCostPerToken: &globalCost, - }, + ID: "global", + ScopeKind: string(ScopeKindGlobal), + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + PricingPatchJSON: `{"input_cost_per_token":2}`, }, { - ID: "provider", - ScopeKind: pricingoverrides.ScopeKindProvider, - ProviderID: &providerScopeID, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-5-nano", - Patch: pricingoverrides.Patch{ - InputCostPerToken: &providerCost, - }, + ID: "provider", + ScopeKind: string(ScopeKindProvider), + ProviderID: &providerScopeID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + PricingPatchJSON: `{"input_cost_per_token":3}`, }, { - ID: "provider-key", - ScopeKind: pricingoverrides.ScopeKindProviderKey, - ProviderKeyID: &providerKeyScopeID, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-5-nano", - Patch: pricingoverrides.Patch{ - InputCostPerToken: &providerKeyCost, - }, + ID: "provider-key", + ScopeKind: string(ScopeKindProviderKey), + ProviderKeyID: &providerKeyScopeID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + PricingPatchJSON: `{"input_cost_per_token":4}`, }, { - ID: "virtual-key", - ScopeKind: pricingoverrides.ScopeKindVirtualKey, - VirtualKeyID: &virtualKeyScopeID, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-5-nano", - Patch: pricingoverrides.Patch{ - InputCostPerToken: &virtualKeyCost, - }, + ID: "virtual-key", + ScopeKind: string(ScopeKindVirtualKey), + VirtualKeyID: &virtualKeyScopeID, + MatchType: string(MatchTypeExact), + Pattern: "gpt-5-nano", + PricingPatchJSON: `{"input_cost_per_token":5}`, }, })) @@ -540,7 +454,7 @@ func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) { SelectedKeyID: providerKeyScopeID, Provider: providerScopeID, }, - expected: virtualKeyCost, + expected: 5.0, }, { name: "provider key wins over provider and global", @@ -548,25 +462,25 @@ func TestApplyScopedPricingOverrides_ScopePrecedence(t *testing.T) { SelectedKeyID: providerKeyScopeID, Provider: providerScopeID, }, - expected: providerKeyCost, + expected: 4.0, }, { name: "provider wins over global", scopes: PricingLookupScopes{ Provider: providerScopeID, }, - expected: providerCost, + expected: 3.0, }, { name: "global applies when no narrower scope is provided", scopes: PricingLookupScopes{}, - expected: globalCost, + expected: 2.0, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - patched := mc.applyScopedPricingOverrides("gpt-5-nano", schemas.ChatCompletionRequest, base, tc.scopes) + patched := mc.applyPricingOverrides("gpt-5-nano", schemas.ChatCompletionRequest, base, tc.scopes) assert.Equal(t, tc.expected, patched.InputCostPerToken) }) } diff --git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index 2f1d6ffec7..f6158bff6b 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -71,8 +71,8 @@ func (mc *ModelCatalog) computeCacheEmbeddingCost(cacheDebug *schemas.BifrostCac if scopes.Provider == "" { scopes.Provider = *cacheDebug.ProviderUsed } - pricing, exists := mc.getPricingLocked(*cacheDebug.ModelUsed, *cacheDebug.ModelUsed, *cacheDebug.ProviderUsed, schemas.EmbeddingRequest, scopes) - if !exists { + pricing := mc.resolvePricing(*cacheDebug.ProviderUsed, *cacheDebug.ModelUsed, "", schemas.EmbeddingRequest, scopes) + if pricing == nil { return 0 } return float64(*cacheDebug.InputTokens) * tieredInputRate(pricing, *cacheDebug.InputTokens) @@ -759,16 +759,19 @@ func (mc *ModelCatalog) resolvePricing(provider, model, deployment string, reque scopes.Provider = provider } - pricing, exists := mc.getPricingLocked(model, model, provider, requestType, scopes) + base, exists := mc.getBasePricing(model, provider, requestType) if exists { - return pricing + result := mc.applyPricingOverrides(model, requestType, base, scopes) + return &result } if deployment != "" { mc.logger.Debug("pricing not found for model %s, trying deployment %s", model, deployment) - pricing, exists = mc.getPricingLocked(deployment, model, provider, requestType, scopes) + base, exists = mc.getBasePricing(deployment, provider, requestType) if exists { - return pricing + // Apply overrides using the requested model name, not the deployment name + result := mc.applyPricingOverrides(model, requestType, base, scopes) + return &result } } @@ -776,35 +779,29 @@ func (mc *ModelCatalog) resolvePricing(provider, model, deployment string, reque return nil } -// getPricing returns pricing information for a model (thread-safe) -func (mc *ModelCatalog) getPricing(model, provider string, requestType schemas.RequestType) (*configstoreTables.TableModelPricing, bool) { - return mc.getPricingLocked(model, model, provider, requestType, PricingLookupScopes{Provider: provider}) -} - -// getPricingLocked acquires a read lock and resolves pricing for a model with scoped overrides. -func (mc *ModelCatalog) getPricingLocked(lookupModel, matchModel, provider string, requestType schemas.RequestType, scopes PricingLookupScopes) (*configstoreTables.TableModelPricing, bool) { +// getBasePricing looks up catalog pricing for the given model, provider, and request type. +// It applies a provider-specific fallback chain when an exact match is not found: +// +// - Gemini: retries under the "vertex" provider, then falls back to chat mode for Responses requests. +// - Vertex: strips the "provider/model" prefix and retries, then falls back to chat mode for Responses requests. +// - Bedrock: prepends the "anthropic." namespace for Claude models, then falls back to chat mode for Responses requests. +// - All providers: for Responses/ResponsesStream requests, retries the lookup in chat mode. +// - All providers: for ImageEdit/ImageVariation requests, retries the lookup in image-generation mode. +// +// The method acquires a read lock for the duration of the lookup. +// +// Input: model — exact model name to look up. +// +// provider — provider identifier (e.g. "openai", "anthropic"). +// requestType — the request type used to derive the pricing mode. +// +// Output: TableModelPricing — the matched pricing row (zero value when not found). +// +// bool — true when a pricing entry was found, false otherwise. +func (mc *ModelCatalog) getBasePricing(model, provider string, requestType schemas.RequestType) (configstoreTables.TableModelPricing, bool) { mc.mu.RLock() - pricing, ok := mc.resolvePricingEntryLocked(lookupModel, matchModel, provider, requestType, scopes) - mc.mu.RUnlock() - if !ok { - return nil, false - } - return &pricing, true -} - -// resolvePricingEntryLocked resolves pricing data including scoped overrides. -// Caller must hold mc.mu read lock. -func (mc *ModelCatalog) resolvePricingEntryLocked(lookupModel, matchModel, provider string, requestType schemas.RequestType, scopes PricingLookupScopes) (configstoreTables.TableModelPricing, bool) { - pricing, ok := mc.resolveBasePricingEntryLocked(lookupModel, provider, requestType) - if !ok { - return configstoreTables.TableModelPricing{}, false - } - return mc.applyScopedPricingOverrides(matchModel, requestType, pricing, scopes), true -} + defer mc.mu.RUnlock() -// resolveBasePricingEntryLocked resolves pricing data from the base catalog including all fallback logic. -// Caller must hold mc.mu read lock. -func (mc *ModelCatalog) resolveBasePricingEntryLocked(model, provider string, requestType schemas.RequestType) (configstoreTables.TableModelPricing, bool) { mode := normalizeRequestType(requestType) pricing, ok := mc.pricingData[makeKey(model, provider, mode)] diff --git a/framework/modelcatalog/pricing_test.go b/framework/modelcatalog/pricing_test.go index 61c23993ca..b8f8434a71 100644 --- a/framework/modelcatalog/pricing_test.go +++ b/framework/modelcatalog/pricing_test.go @@ -1140,8 +1140,7 @@ func TestGetPricing_DirectLookup(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), }) - p, ok := mc.getPricing("gpt-4o", "openai", schemas.ChatCompletionRequest) - require.True(t, ok) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) assert.Equal(t, 0.000005, p.InputCostPerToken) } @@ -1152,8 +1151,7 @@ func TestGetPricing_GeminiFallsBackToVertex(t *testing.T) { InputCostPerToken: 0.0000001, OutputCostPerToken: 0.0000004, }, }) - p, ok := mc.getPricing("gemini-2.0-flash", "gemini", schemas.ChatCompletionRequest) - require.True(t, ok) + p := mc.resolvePricing("gemini", "gemini-2.0-flash", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "gemini"}) assert.Equal(t, 0.0000001, p.InputCostPerToken) } @@ -1161,8 +1159,7 @@ func TestGetPricing_VertexStripsProviderPrefix(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gemini-2.0-flash", "vertex", "chat"): chatPricing(0.0000001, 0.0000004), }) - p, ok := mc.getPricing("google/gemini-2.0-flash", "vertex", schemas.ChatCompletionRequest) - require.True(t, ok) + p := mc.resolvePricing("vertex", "google/gemini-2.0-flash", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "vertex"}) assert.Equal(t, 0.0000001, p.InputCostPerToken) } @@ -1170,8 +1167,7 @@ func TestGetPricing_BedrockAddsAnthropicPrefix(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("anthropic.claude-3-5-sonnet-20241022-v2:0", "bedrock", "chat"): chatPricing(0.000003, 0.000015), }) - p, ok := mc.getPricing("claude-3-5-sonnet-20241022-v2:0", "bedrock", schemas.ChatCompletionRequest) - require.True(t, ok) + p := mc.resolvePricing("bedrock", "claude-3-5-sonnet-20241022-v2:0", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "bedrock"}) assert.Equal(t, 0.000003, p.InputCostPerToken) } @@ -1179,8 +1175,7 @@ func TestGetPricing_ResponsesFallsBackToChat(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), }) - p, ok := mc.getPricing("gpt-4o", "openai", schemas.ResponsesRequest) - require.True(t, ok) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "openai"}) assert.Equal(t, 0.000005, p.InputCostPerToken) } @@ -1188,8 +1183,7 @@ func TestGetPricing_ResponsesStreamFallsBackToChat(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), }) - p, ok := mc.getPricing("gpt-4o", "openai", schemas.ResponsesStreamRequest) - require.True(t, ok) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.ResponsesStreamRequest, PricingLookupScopes{Provider: "openai"}) assert.Equal(t, 0.000005, p.InputCostPerToken) } @@ -1198,15 +1192,14 @@ func TestGetPricing_GeminiResponsesFallsBackToVertexChat(t *testing.T) { makeKey("gemini-2.0-flash", "vertex", "chat"): chatPricing(0.0000001, 0.0000004), }) // gemini provider + responses request → try vertex + responses → try vertex + chat - p, ok := mc.getPricing("gemini-2.0-flash", "gemini", schemas.ResponsesRequest) - require.True(t, ok) + p := mc.resolvePricing("gemini", "gemini-2.0-flash", "", schemas.ResponsesRequest, PricingLookupScopes{Provider: "gemini"}) assert.Equal(t, 0.0000001, p.InputCostPerToken) } func TestGetPricing_NotFound(t *testing.T) { mc := testCatalogWithPricing(nil) - _, ok := mc.getPricing("nonexistent", "openai", schemas.ChatCompletionRequest) - assert.False(t, ok) + p := mc.resolvePricing("openai", "nonexistent", "", schemas.ChatCompletionRequest, PricingLookupScopes{Provider: "openai"}) + assert.Nil(t, p) } // ========================================================================= diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index c477696c6a..4808ee844d 100644 --- a/framework/modelcatalog/utils.go +++ b/framework/modelcatalog/utils.go @@ -3,6 +3,7 @@ package modelcatalog import ( "strings" + "github.com/bytedance/sonic" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) @@ -163,11 +164,7 @@ func convertPricingDataToTableModelPricing(modelKey string, entry PricingEntry) // convertTableModelPricingToPricingData converts the TableModelPricing struct to a PricingEntry struct func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModelPricing) *PricingEntry { - return &PricingEntry{ - BaseModel: pricing.BaseModel, - Provider: pricing.Provider, - Mode: pricing.Mode, - + options := PricingOptions{ // Costs - Text InputCostPerToken: pricing.InputCostPerToken, OutputCostPerToken: pricing.OutputCostPerToken, @@ -230,4 +227,30 @@ func convertTableModelPricingToPricingData(pricing *configstoreTables.TableModel SearchContextCostPerQuery: pricing.SearchContextCostPerQuery, CodeInterpreterCostPerSession: pricing.CodeInterpreterCostPerSession, } + return &PricingEntry{ + BaseModel: pricing.BaseModel, + Provider: pricing.Provider, + Mode: pricing.Mode, + PricingOptions: options, + } +} + +// convertTablePricingOverrideToPricingOverride converts a TablePricingOverride to a PricingOverride. +func convertTablePricingOverrideToPricingOverride(override *configstoreTables.TablePricingOverride) (PricingOverride, error) { + var options PricingOptions + if err := sonic.Unmarshal([]byte(override.PricingPatchJSON), &options); err != nil { + return PricingOverride{}, err + } + return PricingOverride{ + ID: override.ID, + Name: override.Name, + ScopeKind: ScopeKind(override.ScopeKind), + VirtualKeyID: override.VirtualKeyID, + ProviderID: override.ProviderID, + ProviderKeyID: override.ProviderKeyID, + MatchType: MatchType(override.MatchType), + Pattern: override.Pattern, + RequestTypes: override.RequestTypes, + Options: options, + }, nil } diff --git a/framework/pricingoverrides/pricing_overrides.go b/framework/pricingoverrides/pricing_overrides.go deleted file mode 100644 index 588a638d47..0000000000 --- a/framework/pricingoverrides/pricing_overrides.go +++ /dev/null @@ -1,476 +0,0 @@ -// Package pricingoverrides defines the shared pricing override contract used by -// config storage, model catalog compilation, and HTTP governance handlers. -package pricingoverrides - -import ( - "fmt" - "strings" - "time" - - "github.com/maximhq/bifrost/core/schemas" -) - -// ScopeKind identifies which governance scope an override applies to. -type ScopeKind string - -const ( - ScopeKindGlobal ScopeKind = "global" - ScopeKindProvider ScopeKind = "provider" - ScopeKindProviderKey ScopeKind = "provider_key" - ScopeKindVirtualKey ScopeKind = "virtual_key" - ScopeKindVirtualKeyProvider ScopeKind = "virtual_key_provider" - ScopeKindVirtualKeyProviderKey ScopeKind = "virtual_key_provider_key" -) - -// MatchType controls how an override pattern is matched against model names. -type MatchType string - -const ( - MatchTypeExact MatchType = "exact" - MatchTypeWildcard MatchType = "wildcard" -) - -// Patch is a sparse pricing override payload. -// -// Nil fields mean "leave the base pricing unchanged". -type Patch struct { - InputCostPerToken *float64 `json:"input_cost_per_token,omitempty"` - OutputCostPerToken *float64 `json:"output_cost_per_token,omitempty"` - InputCostPerTokenPriority *float64 `json:"input_cost_per_token_priority,omitempty"` - OutputCostPerTokenPriority *float64 `json:"output_cost_per_token_priority,omitempty"` - - InputCostPerVideoPerSecond *float64 `json:"input_cost_per_video_per_second,omitempty"` - OutputCostPerVideoPerSecond *float64 `json:"output_cost_per_video_per_second,omitempty"` - OutputCostPerSecond *float64 `json:"output_cost_per_second,omitempty"` - InputCostPerAudioPerSecond *float64 `json:"input_cost_per_audio_per_second,omitempty"` - InputCostPerSecond *float64 `json:"input_cost_per_second,omitempty"` - InputCostPerAudioToken *float64 `json:"input_cost_per_audio_token,omitempty"` - OutputCostPerAudioToken *float64 `json:"output_cost_per_audio_token,omitempty"` - - InputCostPerCharacter *float64 `json:"input_cost_per_character,omitempty"` - OutputCostPerCharacter *float64 `json:"output_cost_per_character,omitempty"` - - InputCostPerTokenAbove128kTokens *float64 `json:"input_cost_per_token_above_128k_tokens,omitempty"` - InputCostPerCharacterAbove128kTokens *float64 `json:"input_cost_per_character_above_128k_tokens,omitempty"` - InputCostPerImageAbove128kTokens *float64 `json:"input_cost_per_image_above_128k_tokens,omitempty"` - InputCostPerVideoPerSecondAbove128kTokens *float64 `json:"input_cost_per_video_per_second_above_128k_tokens,omitempty"` - InputCostPerAudioPerSecondAbove128kTokens *float64 `json:"input_cost_per_audio_per_second_above_128k_tokens,omitempty"` - OutputCostPerTokenAbove128kTokens *float64 `json:"output_cost_per_token_above_128k_tokens,omitempty"` - OutputCostPerCharacterAbove128kTokens *float64 `json:"output_cost_per_character_above_128k_tokens,omitempty"` - - InputCostPerTokenAbove200kTokens *float64 `json:"input_cost_per_token_above_200k_tokens,omitempty"` - OutputCostPerTokenAbove200kTokens *float64 `json:"output_cost_per_token_above_200k_tokens,omitempty"` - CacheCreationInputTokenCostAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_200k_tokens,omitempty"` - CacheReadInputTokenCostAbove200kTokens *float64 `json:"cache_read_input_token_cost_above_200k_tokens,omitempty"` - - CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"` - CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost,omitempty"` - CacheCreationInputTokenCostAbove1hr *float64 `json:"cache_creation_input_token_cost_above_1hr,omitempty"` - CacheCreationInputTokenCostAbove1hrAbove200kTokens *float64 `json:"cache_creation_input_token_cost_above_1hr_above_200k_tokens,omitempty"` - CacheCreationInputAudioTokenCost *float64 `json:"cache_creation_input_audio_token_cost,omitempty"` - CacheReadInputTokenCostPriority *float64 `json:"cache_read_input_token_cost_priority,omitempty"` - InputCostPerTokenBatches *float64 `json:"input_cost_per_token_batches,omitempty"` - OutputCostPerTokenBatches *float64 `json:"output_cost_per_token_batches,omitempty"` - - InputCostPerImageToken *float64 `json:"input_cost_per_image_token,omitempty"` - OutputCostPerImageToken *float64 `json:"output_cost_per_image_token,omitempty"` - InputCostPerImage *float64 `json:"input_cost_per_image,omitempty"` - OutputCostPerImage *float64 `json:"output_cost_per_image,omitempty"` - InputCostPerPixel *float64 `json:"input_cost_per_pixel,omitempty"` - OutputCostPerPixel *float64 `json:"output_cost_per_pixel,omitempty"` - OutputCostPerImagePremiumImage *float64 `json:"output_cost_per_image_premium_image,omitempty"` - OutputCostPerImageAbove512x512Pixels *float64 `json:"output_cost_per_image_above_512_and_512_pixels,omitempty"` - OutputCostPerImageAbove512x512PixelsPremium *float64 `json:"output_cost_per_image_above_512_and_512_pixels_and_premium_image,omitempty"` - OutputCostPerImageAbove1024x1024Pixels *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels,omitempty"` - OutputCostPerImageAbove1024x1024PixelsPremium *float64 `json:"output_cost_per_image_above_1024_and_1024_pixels_and_premium_image,omitempty"` - OutputCostPerImageAbove2048x2048Pixels *float64 `json:"output_cost_per_image_above_2048_and_2048_pixels,omitempty"` - OutputCostPerImageAbove4096x4096Pixels *float64 `json:"output_cost_per_image_above_4096_and_4096_pixels,omitempty"` - OutputCostPerImageLowQuality *float64 `json:"output_cost_per_image_low_quality,omitempty"` - OutputCostPerImageMediumQuality *float64 `json:"output_cost_per_image_medium_quality,omitempty"` - OutputCostPerImageHighQuality *float64 `json:"output_cost_per_image_high_quality,omitempty"` - OutputCostPerImageAutoQuality *float64 `json:"output_cost_per_image_auto_quality,omitempty"` - CacheReadInputImageTokenCost *float64 `json:"cache_read_input_image_token_cost,omitempty"` - - SearchContextCostPerQuery *float64 `json:"search_context_cost_per_query,omitempty"` - CodeInterpreterCostPerSession *float64 `json:"code_interpreter_cost_per_session,omitempty"` -} - -// Override describes a scoped pricing override shared across config storage, -// model catalog compilation, and governance APIs. -type Override struct { - ID string `json:"id"` - Name string `json:"name"` - ScopeKind ScopeKind `json:"scope_kind"` - VirtualKeyID *string `json:"virtual_key_id,omitempty"` - ProviderID *string `json:"provider_id,omitempty"` - ProviderKeyID *string `json:"provider_key_id,omitempty"` - MatchType MatchType `json:"match_type"` - Pattern string `json:"pattern"` - RequestTypes []schemas.RequestType `json:"request_types,omitempty"` - Patch Patch `json:"patch,omitempty"` - CreatedAt time.Time `json:"created_at,omitempty"` - UpdatedAt time.Time `json:"updated_at,omitempty"` -} - -// NormalizeOverride trims identifiers and validates the shared pricing override -// contract before persistence or runtime compilation. -func NormalizeOverride(override Override) (Override, error) { - override.Name = strings.TrimSpace(override.Name) - if override.Name == "" { - return Override{}, fmt.Errorf("name is required") - } - - override.VirtualKeyID = normalizeOptionalID(override.VirtualKeyID) - override.ProviderID = normalizeOptionalID(override.ProviderID) - override.ProviderKeyID = normalizeOptionalID(override.ProviderKeyID) - - if err := ValidateScopeKind(override.ScopeKind, override.VirtualKeyID, override.ProviderID, override.ProviderKeyID); err != nil { - return Override{}, err - } - - normalizedPattern, err := ValidatePattern(override.MatchType, override.Pattern) - if err != nil { - return Override{}, err - } - override.Pattern = normalizedPattern - - if err := ValidateRequestTypes(override.RequestTypes); err != nil { - return Override{}, err - } - if err := ValidatePatchNonNegative(override.Patch); err != nil { - return Override{}, err - } - - return override, nil -} - -// MergePatch overlays non-nil fields from updates onto base. -func MergePatch(base, updates Patch) Patch { - merged := base - - if updates.InputCostPerToken != nil { - merged.InputCostPerToken = updates.InputCostPerToken - } - if updates.OutputCostPerToken != nil { - merged.OutputCostPerToken = updates.OutputCostPerToken - } - if updates.InputCostPerTokenPriority != nil { - merged.InputCostPerTokenPriority = updates.InputCostPerTokenPriority - } - if updates.OutputCostPerTokenPriority != nil { - merged.OutputCostPerTokenPriority = updates.OutputCostPerTokenPriority - } - if updates.InputCostPerVideoPerSecond != nil { - merged.InputCostPerVideoPerSecond = updates.InputCostPerVideoPerSecond - } - if updates.OutputCostPerVideoPerSecond != nil { - merged.OutputCostPerVideoPerSecond = updates.OutputCostPerVideoPerSecond - } - if updates.OutputCostPerSecond != nil { - merged.OutputCostPerSecond = updates.OutputCostPerSecond - } - if updates.InputCostPerAudioPerSecond != nil { - merged.InputCostPerAudioPerSecond = updates.InputCostPerAudioPerSecond - } - if updates.InputCostPerSecond != nil { - merged.InputCostPerSecond = updates.InputCostPerSecond - } - if updates.InputCostPerAudioToken != nil { - merged.InputCostPerAudioToken = updates.InputCostPerAudioToken - } - if updates.OutputCostPerAudioToken != nil { - merged.OutputCostPerAudioToken = updates.OutputCostPerAudioToken - } - if updates.InputCostPerCharacter != nil { - merged.InputCostPerCharacter = updates.InputCostPerCharacter - } - if updates.OutputCostPerCharacter != nil { - merged.OutputCostPerCharacter = updates.OutputCostPerCharacter - } - if updates.InputCostPerTokenAbove128kTokens != nil { - merged.InputCostPerTokenAbove128kTokens = updates.InputCostPerTokenAbove128kTokens - } - if updates.InputCostPerCharacterAbove128kTokens != nil { - merged.InputCostPerCharacterAbove128kTokens = updates.InputCostPerCharacterAbove128kTokens - } - if updates.InputCostPerImageAbove128kTokens != nil { - merged.InputCostPerImageAbove128kTokens = updates.InputCostPerImageAbove128kTokens - } - if updates.InputCostPerVideoPerSecondAbove128kTokens != nil { - merged.InputCostPerVideoPerSecondAbove128kTokens = updates.InputCostPerVideoPerSecondAbove128kTokens - } - if updates.InputCostPerAudioPerSecondAbove128kTokens != nil { - merged.InputCostPerAudioPerSecondAbove128kTokens = updates.InputCostPerAudioPerSecondAbove128kTokens - } - if updates.OutputCostPerTokenAbove128kTokens != nil { - merged.OutputCostPerTokenAbove128kTokens = updates.OutputCostPerTokenAbove128kTokens - } - if updates.OutputCostPerCharacterAbove128kTokens != nil { - merged.OutputCostPerCharacterAbove128kTokens = updates.OutputCostPerCharacterAbove128kTokens - } - if updates.InputCostPerTokenAbove200kTokens != nil { - merged.InputCostPerTokenAbove200kTokens = updates.InputCostPerTokenAbove200kTokens - } - if updates.OutputCostPerTokenAbove200kTokens != nil { - merged.OutputCostPerTokenAbove200kTokens = updates.OutputCostPerTokenAbove200kTokens - } - if updates.CacheCreationInputTokenCostAbove200kTokens != nil { - merged.CacheCreationInputTokenCostAbove200kTokens = updates.CacheCreationInputTokenCostAbove200kTokens - } - if updates.CacheReadInputTokenCostAbove200kTokens != nil { - merged.CacheReadInputTokenCostAbove200kTokens = updates.CacheReadInputTokenCostAbove200kTokens - } - if updates.CacheReadInputTokenCost != nil { - merged.CacheReadInputTokenCost = updates.CacheReadInputTokenCost - } - if updates.CacheCreationInputTokenCost != nil { - merged.CacheCreationInputTokenCost = updates.CacheCreationInputTokenCost - } - if updates.CacheCreationInputTokenCostAbove1hr != nil { - merged.CacheCreationInputTokenCostAbove1hr = updates.CacheCreationInputTokenCostAbove1hr - } - if updates.CacheCreationInputTokenCostAbove1hrAbove200kTokens != nil { - merged.CacheCreationInputTokenCostAbove1hrAbove200kTokens = updates.CacheCreationInputTokenCostAbove1hrAbove200kTokens - } - if updates.CacheCreationInputAudioTokenCost != nil { - merged.CacheCreationInputAudioTokenCost = updates.CacheCreationInputAudioTokenCost - } - if updates.CacheReadInputTokenCostPriority != nil { - merged.CacheReadInputTokenCostPriority = updates.CacheReadInputTokenCostPriority - } - if updates.InputCostPerTokenBatches != nil { - merged.InputCostPerTokenBatches = updates.InputCostPerTokenBatches - } - if updates.OutputCostPerTokenBatches != nil { - merged.OutputCostPerTokenBatches = updates.OutputCostPerTokenBatches - } - if updates.InputCostPerImageToken != nil { - merged.InputCostPerImageToken = updates.InputCostPerImageToken - } - if updates.OutputCostPerImageToken != nil { - merged.OutputCostPerImageToken = updates.OutputCostPerImageToken - } - if updates.InputCostPerImage != nil { - merged.InputCostPerImage = updates.InputCostPerImage - } - if updates.OutputCostPerImage != nil { - merged.OutputCostPerImage = updates.OutputCostPerImage - } - if updates.InputCostPerPixel != nil { - merged.InputCostPerPixel = updates.InputCostPerPixel - } - if updates.OutputCostPerPixel != nil { - merged.OutputCostPerPixel = updates.OutputCostPerPixel - } - if updates.OutputCostPerImagePremiumImage != nil { - merged.OutputCostPerImagePremiumImage = updates.OutputCostPerImagePremiumImage - } - if updates.OutputCostPerImageAbove512x512Pixels != nil { - merged.OutputCostPerImageAbove512x512Pixels = updates.OutputCostPerImageAbove512x512Pixels - } - if updates.OutputCostPerImageAbove512x512PixelsPremium != nil { - merged.OutputCostPerImageAbove512x512PixelsPremium = updates.OutputCostPerImageAbove512x512PixelsPremium - } - if updates.OutputCostPerImageAbove1024x1024Pixels != nil { - merged.OutputCostPerImageAbove1024x1024Pixels = updates.OutputCostPerImageAbove1024x1024Pixels - } - if updates.OutputCostPerImageAbove1024x1024PixelsPremium != nil { - merged.OutputCostPerImageAbove1024x1024PixelsPremium = updates.OutputCostPerImageAbove1024x1024PixelsPremium - } - if updates.CacheReadInputImageTokenCost != nil { - merged.CacheReadInputImageTokenCost = updates.CacheReadInputImageTokenCost - } - if updates.SearchContextCostPerQuery != nil { - merged.SearchContextCostPerQuery = updates.SearchContextCostPerQuery - } - if updates.CodeInterpreterCostPerSession != nil { - merged.CodeInterpreterCostPerSession = updates.CodeInterpreterCostPerSession - } - - return merged -} - -// IsSupportedRequestType reports whether requestType can be used in pricing -// override request filters. -func IsSupportedRequestType(requestType schemas.RequestType) bool { - switch requestType { - case schemas.TextCompletionRequest, - schemas.TextCompletionStreamRequest, - schemas.ChatCompletionRequest, - schemas.ChatCompletionStreamRequest, - schemas.ResponsesRequest, - schemas.ResponsesStreamRequest, - schemas.EmbeddingRequest, - schemas.RerankRequest, - schemas.SpeechRequest, - schemas.SpeechStreamRequest, - schemas.TranscriptionRequest, - schemas.TranscriptionStreamRequest, - schemas.ImageGenerationRequest, - schemas.ImageGenerationStreamRequest, - schemas.ImageEditRequest, - schemas.ImageEditStreamRequest, - schemas.ImageVariationRequest, - schemas.VideoGenerationRequest: - return true - default: - return false - } -} - -// ValidatePattern trims and validates a model pattern for the given match type. -func ValidatePattern(matchType MatchType, pattern string) (string, error) { - pattern = strings.TrimSpace(pattern) - if pattern == "" { - return "", fmt.Errorf("pattern is required") - } - switch matchType { - case MatchTypeExact: - if strings.Contains(pattern, "*") { - return "", fmt.Errorf("exact pattern cannot include '*'") - } - case MatchTypeWildcard: - if !strings.HasSuffix(pattern, "*") || strings.Count(pattern, "*") != 1 { - return "", fmt.Errorf("wildcard pattern supports a single trailing '*' only") - } - if strings.TrimSuffix(pattern, "*") == "" { - return "", fmt.Errorf("wildcard prefix cannot be empty") - } - default: - return "", fmt.Errorf("unsupported match_type %q", matchType) - } - return pattern, nil -} - -// ValidateRequestTypes validates that every request type in requestTypes is -// supported by pricing overrides. -func ValidateRequestTypes(requestTypes []schemas.RequestType) error { - for _, requestType := range requestTypes { - if !IsSupportedRequestType(requestType) { - return fmt.Errorf("unsupported request_type %q", requestType) - } - } - return nil -} - -// ValidatePatchNonNegative validates that all populated pricing values are -// non-negative. -func ValidatePatchNonNegative(patch Patch) error { - values := []struct { - name string - value *float64 - }{ - {name: "input_cost_per_token", value: patch.InputCostPerToken}, - {name: "output_cost_per_token", value: patch.OutputCostPerToken}, - {name: "input_cost_per_token_priority", value: patch.InputCostPerTokenPriority}, - {name: "output_cost_per_token_priority", value: patch.OutputCostPerTokenPriority}, - {name: "input_cost_per_video_per_second", value: patch.InputCostPerVideoPerSecond}, - {name: "output_cost_per_video_per_second", value: patch.OutputCostPerVideoPerSecond}, - {name: "output_cost_per_second", value: patch.OutputCostPerSecond}, - {name: "input_cost_per_audio_per_second", value: patch.InputCostPerAudioPerSecond}, - {name: "input_cost_per_second", value: patch.InputCostPerSecond}, - {name: "input_cost_per_audio_token", value: patch.InputCostPerAudioToken}, - {name: "output_cost_per_audio_token", value: patch.OutputCostPerAudioToken}, - {name: "input_cost_per_character", value: patch.InputCostPerCharacter}, - {name: "output_cost_per_character", value: patch.OutputCostPerCharacter}, - {name: "input_cost_per_token_above_128k_tokens", value: patch.InputCostPerTokenAbove128kTokens}, - {name: "input_cost_per_character_above_128k_tokens", value: patch.InputCostPerCharacterAbove128kTokens}, - {name: "input_cost_per_image_above_128k_tokens", value: patch.InputCostPerImageAbove128kTokens}, - {name: "input_cost_per_video_per_second_above_128k_tokens", value: patch.InputCostPerVideoPerSecondAbove128kTokens}, - {name: "input_cost_per_audio_per_second_above_128k_tokens", value: patch.InputCostPerAudioPerSecondAbove128kTokens}, - {name: "output_cost_per_token_above_128k_tokens", value: patch.OutputCostPerTokenAbove128kTokens}, - {name: "output_cost_per_character_above_128k_tokens", value: patch.OutputCostPerCharacterAbove128kTokens}, - {name: "input_cost_per_token_above_200k_tokens", value: patch.InputCostPerTokenAbove200kTokens}, - {name: "output_cost_per_token_above_200k_tokens", value: patch.OutputCostPerTokenAbove200kTokens}, - {name: "cache_creation_input_token_cost_above_200k_tokens", value: patch.CacheCreationInputTokenCostAbove200kTokens}, - {name: "cache_read_input_token_cost_above_200k_tokens", value: patch.CacheReadInputTokenCostAbove200kTokens}, - {name: "cache_read_input_token_cost", value: patch.CacheReadInputTokenCost}, - {name: "cache_creation_input_token_cost", value: patch.CacheCreationInputTokenCost}, - {name: "cache_creation_input_token_cost_above_1hr", value: patch.CacheCreationInputTokenCostAbove1hr}, - {name: "cache_creation_input_token_cost_above_1hr_above_200k_tokens", value: patch.CacheCreationInputTokenCostAbove1hrAbove200kTokens}, - {name: "cache_creation_input_audio_token_cost", value: patch.CacheCreationInputAudioTokenCost}, - {name: "cache_read_input_token_cost_priority", value: patch.CacheReadInputTokenCostPriority}, - {name: "input_cost_per_token_batches", value: patch.InputCostPerTokenBatches}, - {name: "output_cost_per_token_batches", value: patch.OutputCostPerTokenBatches}, - {name: "input_cost_per_image_token", value: patch.InputCostPerImageToken}, - {name: "output_cost_per_image_token", value: patch.OutputCostPerImageToken}, - {name: "input_cost_per_image", value: patch.InputCostPerImage}, - {name: "output_cost_per_image", value: patch.OutputCostPerImage}, - {name: "input_cost_per_pixel", value: patch.InputCostPerPixel}, - {name: "output_cost_per_pixel", value: patch.OutputCostPerPixel}, - {name: "output_cost_per_image_premium_image", value: patch.OutputCostPerImagePremiumImage}, - {name: "output_cost_per_image_above_512_and_512_pixels", value: patch.OutputCostPerImageAbove512x512Pixels}, - {name: "output_cost_per_image_above_512_and_512_pixels_and_premium_image", value: patch.OutputCostPerImageAbove512x512PixelsPremium}, - {name: "output_cost_per_image_above_1024_and_1024_pixels", value: patch.OutputCostPerImageAbove1024x1024Pixels}, - {name: "output_cost_per_image_above_1024_and_1024_pixels_and_premium_image", value: patch.OutputCostPerImageAbove1024x1024PixelsPremium}, - {name: "cache_read_input_image_token_cost", value: patch.CacheReadInputImageTokenCost}, - {name: "search_context_cost_per_query", value: patch.SearchContextCostPerQuery}, - {name: "code_interpreter_cost_per_session", value: patch.CodeInterpreterCostPerSession}, - } - for _, item := range values { - if item.value != nil && *item.value < 0 { - return fmt.Errorf("%s must be non-negative", item.name) - } - } - return nil -} - -// ValidateScopeKind validates the scope identifiers required by scopeKind. -func ValidateScopeKind(scopeKind ScopeKind, virtualKeyID, providerID, providerKeyID *string) error { - normalizedVK := normalizeOptionalID(virtualKeyID) - normalizedProvider := normalizeOptionalID(providerID) - normalizedProviderKey := normalizeOptionalID(providerKeyID) - - switch scopeKind { - case ScopeKindGlobal: - if normalizedVK != nil || normalizedProvider != nil || normalizedProviderKey != nil { - return fmt.Errorf("global scope_kind must not include scope identifiers") - } - case ScopeKindProvider: - if normalizedProvider == nil { - return fmt.Errorf("provider_id is required for provider scope_kind") - } - if normalizedVK != nil || normalizedProviderKey != nil { - return fmt.Errorf("provider scope_kind only supports provider_id") - } - case ScopeKindProviderKey: - if normalizedProviderKey == nil { - return fmt.Errorf("provider_key_id is required for provider_key scope_kind") - } - if normalizedVK != nil || normalizedProvider != nil { - return fmt.Errorf("provider_key scope_kind only supports provider_key_id") - } - case ScopeKindVirtualKey: - if normalizedVK == nil { - return fmt.Errorf("virtual_key_id is required for virtual_key scope_kind") - } - if normalizedProvider != nil || normalizedProviderKey != nil { - return fmt.Errorf("virtual_key scope_kind only supports virtual_key_id") - } - case ScopeKindVirtualKeyProvider: - if normalizedVK == nil || normalizedProvider == nil { - return fmt.Errorf("virtual_key_id and provider_id are required for virtual_key_provider scope_kind") - } - if normalizedProviderKey != nil { - return fmt.Errorf("virtual_key_provider scope_kind does not support provider_key_id") - } - case ScopeKindVirtualKeyProviderKey: - if normalizedVK == nil || normalizedProvider == nil || normalizedProviderKey == nil { - return fmt.Errorf("virtual_key_id, provider_id, and provider_key_id are required for virtual_key_provider_key scope_kind") - } - default: - return fmt.Errorf("unsupported scope_kind %q", scopeKind) - } - return nil -} - -func normalizeOptionalID(id *string) *string { - if id == nil { - return nil - } - trimmed := strings.TrimSpace(*id) - if trimmed == "" { - return nil - } - return &trimmed -} diff --git a/transports/bifrost-http/handlers/governance.go b/transports/bifrost-http/handlers/governance.go index ae677a9a29..e530eefca8 100644 --- a/transports/bifrost-http/handlers/governance.go +++ b/transports/bifrost-http/handlers/governance.go @@ -20,7 +20,6 @@ import ( "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/modelcatalog" - "github.com/maximhq/bifrost/framework/pricingoverrides" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" @@ -304,7 +303,7 @@ func (h *GovernanceHandler) RegisterRoutes(r *router.Router, middlewares ...sche // Pricing override operations r.GET("/api/governance/pricing-overrides", lib.ChainMiddlewares(h.getPricingOverrides, middlewares...)) r.POST("/api/governance/pricing-overrides", lib.ChainMiddlewares(h.createPricingOverride, middlewares...)) - r.PATCH("/api/governance/pricing-overrides/{id}", lib.ChainMiddlewares(h.patchPricingOverride, middlewares...)) + r.PUT("/api/governance/pricing-overrides/{id}", lib.ChainMiddlewares(h.updatePricingOverride, middlewares...)) r.DELETE("/api/governance/pricing-overrides/{id}", lib.ChainMiddlewares(h.deletePricingOverride, middlewares...)) } @@ -3261,29 +3260,15 @@ func (h *GovernanceHandler) deleteRoutingRule(ctx *fasthttp.RequestCtx) { // CreatePricingOverrideRequest is the request payload for creating a governance // pricing override. type CreatePricingOverrideRequest struct { - Name string `json:"name"` - ScopeKind pricingoverrides.ScopeKind `json:"scope_kind"` - VirtualKeyID *string `json:"virtual_key_id,omitempty"` - ProviderID *string `json:"provider_id,omitempty"` - ProviderKeyID *string `json:"provider_key_id,omitempty"` - MatchType pricingoverrides.MatchType `json:"match_type"` - Pattern string `json:"pattern"` - RequestTypes []schemas.RequestType `json:"request_types,omitempty"` - Patch pricingoverrides.Patch `json:"patch,omitempty"` -} - -// PatchPricingOverrideRequest is the sparse request payload for updating a -// governance pricing override. -type PatchPricingOverrideRequest struct { - Name *string `json:"name,omitempty"` - ScopeKind *pricingoverrides.ScopeKind `json:"scope_kind,omitempty"` + Name string `json:"name"` + ScopeKind modelcatalog.ScopeKind `json:"scope_kind"` VirtualKeyID *string `json:"virtual_key_id,omitempty"` ProviderID *string `json:"provider_id,omitempty"` ProviderKeyID *string `json:"provider_key_id,omitempty"` - MatchType *pricingoverrides.MatchType `json:"match_type,omitempty"` - Pattern *string `json:"pattern,omitempty"` - RequestTypes *[]schemas.RequestType `json:"request_types,omitempty"` - Patch *pricingoverrides.Patch `json:"patch,omitempty"` + MatchType modelcatalog.MatchType `json:"match_type"` + Pattern string `json:"pattern"` + RequestTypes []schemas.RequestType `json:"request_types,omitempty"` + Patch modelcatalog.PricingOptions `json:"patch,omitempty"` } func (h *GovernanceHandler) getPricingOverrides(ctx *fasthttp.RequestCtx) { @@ -3312,26 +3297,41 @@ func (h *GovernanceHandler) createPricingOverride(ctx *fasthttp.RequestCtx) { return } - if err := validatePricingOverrideRequest(req.ScopeKind, req.VirtualKeyID, req.ProviderID, req.ProviderKeyID, req.MatchType, req.Pattern, req.RequestTypes, req.Patch); err != nil { + shape := modelcatalog.PricingOverride{ + ScopeKind: req.ScopeKind, + VirtualKeyID: req.VirtualKeyID, + ProviderID: req.ProviderID, + ProviderKeyID: req.ProviderKeyID, + MatchType: req.MatchType, + Pattern: req.Pattern, + RequestTypes: req.RequestTypes, + } + if err := shape.IsValid(); err != nil { SendError(ctx, fasthttp.StatusBadRequest, err.Error()) return } + patchJSON, err := sonic.Marshal(req.Patch) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid patch") + return + } + now := time.Now() override := configstoreTables.TablePricingOverride{ - ID: uuid.NewString(), - Name: name, - ScopeKind: req.ScopeKind, - VirtualKeyID: normalizeOptionalString(req.VirtualKeyID), - ProviderID: normalizeOptionalString(req.ProviderID), - ProviderKeyID: normalizeOptionalString(req.ProviderKeyID), - MatchType: req.MatchType, - Pattern: strings.TrimSpace(req.Pattern), - RequestTypes: req.RequestTypes, - Patch: req.Patch, - ConfigHash: "", - CreatedAt: now, - UpdatedAt: now, + ID: uuid.NewString(), + Name: name, + ScopeKind: string(req.ScopeKind), + VirtualKeyID: normalizeOptionalString(req.VirtualKeyID), + ProviderID: normalizeOptionalString(req.ProviderID), + ProviderKeyID: normalizeOptionalString(req.ProviderKeyID), + MatchType: string(req.MatchType), + Pattern: strings.TrimSpace(req.Pattern), + RequestTypes: req.RequestTypes, + PricingPatchJSON: string(patchJSON), + ConfigHash: "", + CreatedAt: now, + UpdatedAt: now, } if err := h.configStore.CreatePricingOverride(ctx, &override); err != nil { @@ -3340,73 +3340,86 @@ func (h *GovernanceHandler) createPricingOverride(ctx *fasthttp.RequestCtx) { return } - h.refreshPricingOverrides(ctx) + if h.modelCatalog != nil { + if err := h.modelCatalog.UpsertPricingOverrides(&override); err != nil { + logger.Error("failed to upsert pricing override: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to upsert pricing override") + } + } SendJSONWithStatus(ctx, map[string]interface{}{ "message": "Pricing override created successfully", "pricing_override": override, }, fasthttp.StatusCreated) } -func (h *GovernanceHandler) patchPricingOverride(ctx *fasthttp.RequestCtx) { +func (h *GovernanceHandler) updatePricingOverride(ctx *fasthttp.RequestCtx) { id := ctx.UserValue("id").(string) - var req PatchPricingOverrideRequest + var req CreatePricingOverrideRequest if !decodePricingOverrideJSON(ctx, &req) { return } - override, ok := h.getPricingOverrideOrSendError(ctx, id) + existing, ok := h.getPricingOverrideOrSendError(ctx, id) if !ok { return } - if req.ScopeKind != nil { - override.ScopeKind = *req.ScopeKind - } - if req.Name != nil { - name, err := normalizeAndValidatePricingOverrideName(*req.Name) - if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, err.Error()) - return - } - override.Name = name - } - if req.VirtualKeyID != nil { - override.VirtualKeyID = normalizeOptionalString(req.VirtualKeyID) - } - if req.ProviderID != nil { - override.ProviderID = normalizeOptionalString(req.ProviderID) - } - if req.ProviderKeyID != nil { - override.ProviderKeyID = normalizeOptionalString(req.ProviderKeyID) - } - if req.MatchType != nil { - override.MatchType = *req.MatchType - } - if req.Pattern != nil { - override.Pattern = strings.TrimSpace(*req.Pattern) + name, err := normalizeAndValidatePricingOverrideName(req.Name) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return } - if req.RequestTypes != nil { - override.RequestTypes = *req.RequestTypes + + shape := modelcatalog.PricingOverride{ + ScopeKind: req.ScopeKind, + VirtualKeyID: req.VirtualKeyID, + ProviderID: req.ProviderID, + ProviderKeyID: req.ProviderKeyID, + MatchType: req.MatchType, + Pattern: req.Pattern, + RequestTypes: req.RequestTypes, } - if req.Patch != nil { - override.Patch = pricingoverrides.MergePatch(override.Patch, *req.Patch) + if err := shape.IsValid(); err != nil { + SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + return } - override.ConfigHash = "" - override.UpdatedAt = time.Now() - if err := validatePricingOverrideRequest(override.ScopeKind, override.VirtualKeyID, override.ProviderID, override.ProviderKeyID, override.MatchType, override.Pattern, override.RequestTypes, override.Patch); err != nil { - SendError(ctx, fasthttp.StatusBadRequest, err.Error()) + patchJSON, err := sonic.Marshal(req.Patch) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid patch") return } - if err := h.configStore.UpdatePricingOverride(ctx, override); err != nil { + override := configstoreTables.TablePricingOverride{ + ID: id, + Name: name, + ScopeKind: string(req.ScopeKind), + VirtualKeyID: normalizeOptionalString(req.VirtualKeyID), + ProviderID: normalizeOptionalString(req.ProviderID), + ProviderKeyID: normalizeOptionalString(req.ProviderKeyID), + MatchType: string(req.MatchType), + Pattern: strings.TrimSpace(req.Pattern), + RequestTypes: req.RequestTypes, + PricingPatchJSON: string(patchJSON), + ConfigHash: "", + CreatedAt: existing.CreatedAt, + UpdatedAt: time.Now(), + } + + if err := h.configStore.UpdatePricingOverride(ctx, &override); err != nil { logger.Error("failed to update pricing override: %v", err) SendError(ctx, fasthttp.StatusInternalServerError, "Failed to update pricing override") return } - h.refreshPricingOverrides(ctx) + if h.modelCatalog != nil { + if err := h.modelCatalog.UpsertPricingOverrides(&override); err != nil { + logger.Error("failed to upsert pricing override: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to upsert pricing override") + return + } + } SendJSON(ctx, map[string]interface{}{ "message": "Pricing override updated successfully", "pricing_override": override, @@ -3425,17 +3438,18 @@ func (h *GovernanceHandler) deletePricingOverride(ctx *fasthttp.RequestCtx) { return } - h.refreshPricingOverrides(ctx) + if h.modelCatalog != nil { + h.modelCatalog.DeletePricingOverride(id) + } SendJSON(ctx, map[string]interface{}{ "message": "Pricing override deleted successfully", }) } -func pricingOverrideFilterFromQuery(ctx *fasthttp.RequestCtx) configstore.PricingOverrideFilter { - var filter configstore.PricingOverrideFilter +func pricingOverrideFilterFromQuery(ctx *fasthttp.RequestCtx) configstore.PricingOverrideFilters { + var filter configstore.PricingOverrideFilters if scopeKindRaw := strings.TrimSpace(string(ctx.QueryArgs().Peek("scope_kind"))); scopeKindRaw != "" { - scopeKind := pricingoverrides.ScopeKind(scopeKindRaw) - filter.ScopeKind = &scopeKind + filter.ScopeKind = &scopeKindRaw } if virtualKeyID := strings.TrimSpace(string(ctx.QueryArgs().Peek("virtual_key_id"))); virtualKeyID != "" { filter.VirtualKeyID = &virtualKeyID @@ -3471,48 +3485,6 @@ func (h *GovernanceHandler) getPricingOverrideOrSendError(ctx *fasthttp.RequestC return nil, false } -func (h *GovernanceHandler) refreshPricingOverrides(ctx context.Context) { - if h.modelCatalog == nil { - return - } - rows, err := h.configStore.GetPricingOverrides(ctx, configstore.PricingOverrideFilter{}) - if err != nil { - logger.Warn("failed to load pricing overrides for model catalog refresh: %v", err) - return - } - if err := h.modelCatalog.SetPricingOverrides(toPricingOverrides(rows)); err != nil { - logger.Warn("failed to apply pricing override refresh: %v", err) - } -} - -func toPricingOverrides(rows []configstoreTables.TablePricingOverride) []pricingoverrides.Override { - overrides := make([]pricingoverrides.Override, 0, len(rows)) - for i := range rows { - overrides = append(overrides, rows[i].ToPricingOverride()) - } - return overrides -} - -func validatePricingOverrideRequest( - scopeKind pricingoverrides.ScopeKind, - virtualKeyID, providerID, providerKeyID *string, - matchType pricingoverrides.MatchType, - pattern string, - requestTypes []schemas.RequestType, - patch pricingoverrides.Patch, -) error { - if err := pricingoverrides.ValidateScopeKind(scopeKind, virtualKeyID, providerID, providerKeyID); err != nil { - return err - } - if _, err := pricingoverrides.ValidatePattern(matchType, pattern); err != nil { - return err - } - if err := pricingoverrides.ValidateRequestTypes(requestTypes); err != nil { - return err - } - return pricingoverrides.ValidatePatchNonNegative(patch) -} - func normalizeAndValidatePricingOverrideName(name string) (string, error) { trimmed := strings.TrimSpace(name) if trimmed == "" { diff --git a/transports/bifrost-http/handlers/pricing_override_test.go b/transports/bifrost-http/handlers/pricing_override_test.go index 42aec91c68..9e7948455b 100644 --- a/transports/bifrost-http/handlers/pricing_override_test.go +++ b/transports/bifrost-http/handlers/pricing_override_test.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "encoding/json" "net" "os" "strings" @@ -12,7 +13,6 @@ import ( "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/modelcatalog" - "github.com/maximhq/bifrost/framework/pricingoverrides" "github.com/maximhq/bifrost/plugins/governance" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -89,7 +89,7 @@ func newTestRequestCtx(body string) *fasthttp.RequestCtx { return ctx } -func TestPatchPricingOverride_MergesPatch(t *testing.T) { +func TestUpdatePricingOverride_ReplacesFullBody(t *testing.T) { SetLogger(&mockLogger{}) store := setupPricingOverrideHandlerStore(t) handler := &GovernanceHandler{ @@ -98,39 +98,44 @@ func TestPatchPricingOverride_MergesPatch(t *testing.T) { modelCatalog: &modelcatalog.ModelCatalog{}, } - inputCost := 1.0 - outputCost := 2.0 + now := time.Now().UTC() override := configstoreTables.TablePricingOverride{ - ID: "override-1", - Name: "Config Managed", - ScopeKind: pricingoverrides.ScopeKindGlobal, - MatchType: pricingoverrides.MatchTypeExact, - Pattern: "gpt-4.1", - CreatedAt: time.Now().UTC(), - UpdatedAt: time.Now().UTC(), - RequestTypes: []schemas.RequestType{ - schemas.ChatCompletionRequest, - }, - Patch: pricingoverrides.Patch{ - InputCostPerToken: &inputCost, - OutputCostPerToken: &outputCost, - }, + ID: "override-1", + Name: "Original", + ScopeKind: string(modelcatalog.ScopeKindGlobal), + MatchType: string(modelcatalog.MatchTypeExact), + Pattern: "gpt-4.1", + CreatedAt: now, + UpdatedAt: now, + PricingPatchJSON: `{"input_cost_per_token":1,"output_cost_per_token":2}`, + RequestTypes: []schemas.RequestType{schemas.ChatCompletionRequest}, } require.NoError(t, store.CreatePricingOverride(context.Background(), &override)) - ctx := newTestRequestCtx(`{"patch":{"output_cost_per_token":3.5}}`) + // Send complete replacement body — output cost changed, input cost kept + body := `{ + "name":"Updated", + "scope_kind":"global", + "match_type":"exact", + "pattern":"gpt-4.1", + "request_types":["chat_completion"], + "patch":{"input_cost_per_token":1,"output_cost_per_token":3.5} + }` + ctx := newTestRequestCtx(body) ctx.SetUserValue("id", override.ID) - handler.patchPricingOverride(ctx) + handler.updatePricingOverride(ctx) require.Equal(t, fasthttp.StatusOK, ctx.Response.StatusCode(), string(ctx.Response.Body())) stored, err := store.GetPricingOverrideByID(context.Background(), override.ID) require.NoError(t, err) - require.NotNil(t, stored.Patch.InputCostPerToken) - assert.Equal(t, inputCost, *stored.Patch.InputCostPerToken) - require.NotNil(t, stored.Patch.OutputCostPerToken) - assert.Equal(t, 3.5, *stored.Patch.OutputCostPerToken) + assert.Equal(t, "Updated", stored.Name) + + var patch modelcatalog.PricingOptions + require.NoError(t, json.Unmarshal([]byte(stored.PricingPatchJSON), &patch)) + assert.Equal(t, 1.0, patch.InputCostPerToken) + assert.Equal(t, 3.5, patch.OutputCostPerToken) assert.Empty(t, stored.ConfigHash) } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 3945ab0daa..0e85f22e77 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -937,6 +937,12 @@ func loadGovernanceConfigFromFile(ctx context.Context, config *Config, configDat logger.Debug("no governance config found in store, processing from config file") config.GovernanceConfig = configData.Governance createGovernanceConfigInStore(ctx, config) + // No config store: load pricing overrides directly into the model catalog + if config.ConfigStore == nil && config.ModelCatalog != nil && len(configData.Governance.PricingOverrides) > 0 { + if err := config.ModelCatalog.SetPricingOverrides(configData.Governance.PricingOverrides); err != nil { + logger.Warn("failed to set pricing overrides from config file: %v", err) + } + } } else { logger.Debug("no governance config in store or config file") } @@ -1175,6 +1181,35 @@ func mergeGovernanceConfig(ctx context.Context, config *Config, configData *Conf routingRulesToAdd = append(routingRulesToAdd, configData.Governance.RoutingRules[i]) } } + // Merge PricingOverrides by ID with hash comparison + pricingOverridesToAdd := make([]configstoreTables.TablePricingOverride, 0) + pricingOverridesToUpdate := make([]configstoreTables.TablePricingOverride, 0) + for i, newOverride := range configData.Governance.PricingOverrides { + fileHash, err := configstore.GeneratePricingOverrideHash(newOverride) + if err != nil { + logger.Warn("failed to generate pricing override hash for %s: %v", newOverride.ID, err) + continue + } + configData.Governance.PricingOverrides[i].ConfigHash = fileHash + + found := false + for j, existing := range governanceConfig.PricingOverrides { + if existing.ID == newOverride.ID { + found = true + if existing.ConfigHash != fileHash { + logger.Debug("config hash mismatch for pricing override %s, syncing from config file", newOverride.ID) + pricingOverridesToUpdate = append(pricingOverridesToUpdate, configData.Governance.PricingOverrides[i]) + governanceConfig.PricingOverrides[j] = configData.Governance.PricingOverrides[i] + } else { + logger.Debug("config hash matches for pricing override %s, keeping DB config", newOverride.ID) + } + break + } + } + if !found { + pricingOverridesToAdd = append(pricingOverridesToAdd, configData.Governance.PricingOverrides[i]) + } + } // Add merged items to config config.GovernanceConfig.Budgets = append(governanceConfig.Budgets, budgetsToAdd...) config.GovernanceConfig.RateLimits = append(governanceConfig.RateLimits, rateLimitsToAdd...) @@ -1182,13 +1217,15 @@ func mergeGovernanceConfig(ctx context.Context, config *Config, configData *Conf config.GovernanceConfig.Teams = append(governanceConfig.Teams, teamsToAdd...) config.GovernanceConfig.VirtualKeys = append(governanceConfig.VirtualKeys, virtualKeysToAdd...) config.GovernanceConfig.RoutingRules = append(governanceConfig.RoutingRules, routingRulesToAdd...) + config.GovernanceConfig.PricingOverrides = append(governanceConfig.PricingOverrides, pricingOverridesToAdd...) // Update store with merged config items hasChanges := len(budgetsToAdd) > 0 || len(budgetsToUpdate) > 0 || len(rateLimitsToAdd) > 0 || len(rateLimitsToUpdate) > 0 || len(customersToAdd) > 0 || len(customersToUpdate) > 0 || len(teamsToAdd) > 0 || len(teamsToUpdate) > 0 || len(virtualKeysToAdd) > 0 || len(virtualKeysToUpdate) > 0 || - len(routingRulesToAdd) > 0 || len(routingRulesToUpdate) > 0 + len(routingRulesToAdd) > 0 || len(routingRulesToUpdate) > 0 || + len(pricingOverridesToAdd) > 0 || len(pricingOverridesToUpdate) > 0 if config.ConfigStore != nil && hasChanges { err := updateGovernanceConfigInStore(ctx, config, budgetsToAdd, budgetsToUpdate, @@ -1196,11 +1233,28 @@ func mergeGovernanceConfig(ctx context.Context, config *Config, configData *Conf customersToAdd, customersToUpdate, teamsToAdd, teamsToUpdate, virtualKeysToAdd, virtualKeysToUpdate, - routingRulesToAdd, routingRulesToUpdate) + routingRulesToAdd, routingRulesToUpdate, + pricingOverridesToAdd, pricingOverridesToUpdate) if err != nil { logger.Fatal("failed to sync governance config: %v", err) } } + // Sync pricing overrides into the model catalog in one batch to avoid + // rebuilding the lookup map on every iteration. + if config.ModelCatalog != nil { + rows := make([]*configstoreTables.TablePricingOverride, 0, len(pricingOverridesToAdd)+len(pricingOverridesToUpdate)) + for i := range pricingOverridesToAdd { + rows = append(rows, &pricingOverridesToAdd[i]) + } + for i := range pricingOverridesToUpdate { + rows = append(rows, &pricingOverridesToUpdate[i]) + } + if len(rows) > 0 { + if err := config.ModelCatalog.UpsertPricingOverrides(rows...); err != nil { + logger.Error("failed to upsert pricing overrides into model catalog: %v", err) + } + } + } } // updateGovernanceConfigInStore updates governance config items in the store @@ -1219,6 +1273,8 @@ func updateGovernanceConfigInStore( virtualKeysToUpdate []configstoreTables.TableVirtualKey, routingRulesToAdd []configstoreTables.TableRoutingRule, routingRulesToUpdate []configstoreTables.TableRoutingRule, + pricingOverridesToAdd []configstoreTables.TablePricingOverride, + pricingOverridesToUpdate []configstoreTables.TablePricingOverride, ) error { logger.Debug("updating governance config in store with merged items") return config.ConfigStore.ExecuteTransaction(ctx, func(tx *gorm.DB) error { @@ -1330,6 +1386,20 @@ func updateGovernanceConfigInStore( } } + // Create pricing overrides (new from config.json) + for _, override := range pricingOverridesToAdd { + if err := config.ConfigStore.CreatePricingOverride(ctx, &override, tx); err != nil { + return fmt.Errorf("failed to create pricing override %s: %w", override.ID, err) + } + } + + // Update pricing overrides (config.json changed) + for _, override := range pricingOverridesToUpdate { + if err := config.ConfigStore.UpdatePricingOverride(ctx, &override, tx); err != nil { + return fmt.Errorf("failed to update pricing override %s: %w", override.ID, err) + } + } + return nil }) } @@ -1411,6 +1481,19 @@ func createGovernanceConfigInStore(ctx context.Context, config *Config) { } } + for i := range config.GovernanceConfig.PricingOverrides { + override := &config.GovernanceConfig.PricingOverrides[i] + overrideHash, err := configstore.GeneratePricingOverrideHash(*override) + if err != nil { + logger.Warn("failed to generate pricing override hash for %s: %v", override.ID, err) + } else { + override.ConfigHash = overrideHash + } + if err := config.ConfigStore.CreatePricingOverride(ctx, override, tx); err != nil { + return fmt.Errorf("failed to create pricing override %s: %w", override.ID, err) + } + } + for i := range config.GovernanceConfig.VirtualKeys { virtualKey := &config.GovernanceConfig.VirtualKeys[i] logger.Debug("creating virtual key: id=%s, name=%s, value=%s", virtualKey.ID, virtualKey.Name, virtualKey.Value) diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 691ab73a02..a363c940d8 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -855,7 +855,7 @@ func (m *MockConfigStore) DeleteModelPrices(ctx context.Context, tx ...*gorm.DB) return nil } -func (m *MockConfigStore) GetPricingOverrides(ctx context.Context, filter configstore.PricingOverrideFilter) ([]tables.TablePricingOverride, error) { +func (m *MockConfigStore) GetPricingOverrides(ctx context.Context, filter configstore.PricingOverrideFilters) ([]tables.TablePricingOverride, error) { return []tables.TablePricingOverride{}, nil } diff --git a/transports/config.schema.json b/transports/config.schema.json index 985c83a667..95580bd67a 100644 --- a/transports/config.schema.json +++ b/transports/config.schema.json @@ -497,6 +497,13 @@ "$ref": "#/$defs/routing_rule" } }, + "pricing_overrides": { + "type": "array", + "description": "Scoped pricing overrides applied at runtime by the model catalog", + "items": { + "$ref": "#/$defs/pricing_override" + } + }, "auth_config": { "$ref": "#/$defs/auth_config" }, @@ -3056,6 +3063,206 @@ } }, "additionalProperties": false + }, + "pricing_override": { + "type": "object", + "description": "Scoped pricing override applied at runtime by the model catalog", + "properties": { + "id": { + "type": "string", + "description": "Unique pricing override ID" + }, + "name": { + "type": "string", + "description": "Human-readable name for this override" + }, + "scope_kind": { + "type": "string", + "description": "Scope level for this override", + "enum": ["global", "provider", "provider_key", "virtual_key", "virtual_key_provider", "virtual_key_provider_key"] + }, + "virtual_key_id": { + "type": "string", + "description": "Virtual key ID (required for virtual_key* scopes)" + }, + "provider_id": { + "type": "string", + "description": "Provider ID (required for provider* scopes)" + }, + "provider_key_id": { + "type": "string", + "description": "Provider key ID (required for provider_key and virtual_key_provider_key scopes)" + }, + "match_type": { + "type": "string", + "description": "How the pattern is matched against model names", + "enum": ["exact", "wildcard"] + }, + "pattern": { + "type": "string", + "description": "Model name pattern to match (exact name or wildcard prefix ending with *)" + }, + "request_types": { + "type": "array", + "description": "Request types this override applies to (empty = all types)", + "items": { + "type": "string" + } + }, + "pricing_patch": { + "type": "string", + "description": "JSON-encoded pricing fields to override (e.g. '{\"input_cost_per_token\":0.000001}')" + }, + "config_hash": { + "type": "string", + "description": "Internal hash for change detection (auto-managed)" + } + }, + "required": ["id", "name", "scope_kind", "match_type", "pattern"], + "additionalProperties": false + }, + "pricing_override_match_type": { + "type": "string", + "enum": ["exact", "wildcard"] + }, + "pricing_override_request_type": { + "type": "string", + "enum": [ + "list_models", "text_completion", "text_completion_stream", + "chat_completion", "chat_completion_stream", "responses", "responses_stream", + "count_tokens", "embedding", "rerank", "speech", "speech_stream", + "transcription", "transcription_stream", "image_generation", "image_generation_stream", + "image_edit", "image_edit_stream", "image_variation", "video_generation", + "video_retrieve", "video_download", "video_delete", "video_list", "video_remix", + "batch_create", "batch_list", "batch_retrieve", "batch_cancel", "batch_delete", "batch_results", + "file_upload", "file_list", "file_retrieve", "file_delete", "file_content", + "container_create", "container_list", "container_retrieve", "container_delete", + "container_file_create", "container_file_list", "container_file_retrieve", + "container_file_content", "container_file_delete", "passthrough", "passthrough_stream" + ] + }, + "provider_pricing_override": { + "type": "object", + "properties": { + "model_pattern": { + "type": "string", + "minLength": 1 + }, + "match_type": { + "$ref": "#/$defs/pricing_override_match_type" + }, + "request_types": { + "type": "array", + "items": { + "$ref": "#/$defs/pricing_override_request_type" + } + }, + "input_cost_per_token": { "type": "number", "minimum": 0 }, + "output_cost_per_token": { "type": "number", "minimum": 0 }, + "input_cost_per_video_per_second": { "type": "number", "minimum": 0 }, + "input_cost_per_audio_per_second": { "type": "number", "minimum": 0 }, + "input_cost_per_character": { "type": "number", "minimum": 0 }, + "output_cost_per_character": { "type": "number", "minimum": 0 }, + "input_cost_per_token_above_128k_tokens": { "type": "number", "minimum": 0 }, + "input_cost_per_character_above_128k_tokens": { "type": "number", "minimum": 0 }, + "input_cost_per_image_above_128k_tokens": { "type": "number", "minimum": 0 }, + "input_cost_per_video_per_second_above_128k_tokens": { "type": "number", "minimum": 0 }, + "input_cost_per_audio_per_second_above_128k_tokens": { "type": "number", "minimum": 0 }, + "output_cost_per_token_above_128k_tokens": { "type": "number", "minimum": 0 }, + "output_cost_per_character_above_128k_tokens": { "type": "number", "minimum": 0 }, + "input_cost_per_token_above_200k_tokens": { "type": "number", "minimum": 0 }, + "output_cost_per_token_above_200k_tokens": { "type": "number", "minimum": 0 }, + "cache_creation_input_token_cost_above_200k_tokens": { "type": "number", "minimum": 0 }, + "cache_read_input_token_cost_above_200k_tokens": { "type": "number", "minimum": 0 }, + "cache_read_input_token_cost": { "type": "number", "minimum": 0 }, + "cache_creation_input_token_cost": { "type": "number", "minimum": 0 }, + "input_cost_per_token_batches": { "type": "number", "minimum": 0 }, + "output_cost_per_token_batches": { "type": "number", "minimum": 0 }, + "input_cost_per_image_token": { "type": "number", "minimum": 0 }, + "output_cost_per_image_token": { "type": "number", "minimum": 0 }, + "input_cost_per_image": { "type": "number", "minimum": 0 }, + "output_cost_per_image": { "type": "number", "minimum": 0 }, + "cache_read_input_image_token_cost": { "type": "number", "minimum": 0 } + }, + "required": ["model_pattern", "match_type"], + "additionalProperties": false + }, + "custom_provider_config": { + "type": "object", + "description": "Custom provider configuration for extending or customizing provider behavior", + "properties": { + "is_key_less": { + "type": "boolean", + "description": "Whether the custom provider requires a key" + }, + "base_provider_type": { + "type": "string", + "description": "Base provider type to extend" + }, + "allowed_requests": { + "type": "object", + "description": "Allowed request types for the custom provider", + "properties": { + "list_models": { "type": "boolean" }, + "text_completion": { "type": "boolean" }, + "text_completion_stream": { "type": "boolean" }, + "chat_completion": { "type": "boolean" }, + "chat_completion_stream": { "type": "boolean" }, + "responses": { "type": "boolean" }, + "responses_stream": { "type": "boolean" }, + "count_tokens": { "type": "boolean" }, + "embedding": { "type": "boolean" }, + "rerank": { "type": "boolean" }, + "speech": { "type": "boolean" }, + "speech_stream": { "type": "boolean" }, + "transcription": { "type": "boolean" }, + "transcription_stream": { "type": "boolean" }, + "image_generation": { "type": "boolean" }, + "image_generation_stream": { "type": "boolean" }, + "image_edit": { "type": "boolean" }, + "image_edit_stream": { "type": "boolean" }, + "image_variation": { "type": "boolean" }, + "video_generation": { "type": "boolean" }, + "video_retrieve": { "type": "boolean" }, + "video_download": { "type": "boolean" }, + "video_delete": { "type": "boolean" }, + "video_list": { "type": "boolean" }, + "video_remix": { "type": "boolean" }, + "batch_create": { "type": "boolean" }, + "batch_list": { "type": "boolean" }, + "batch_retrieve": { "type": "boolean" }, + "batch_cancel": { "type": "boolean" }, + "batch_delete": { "type": "boolean" }, + "batch_results": { "type": "boolean" }, + "file_upload": { "type": "boolean" }, + "file_list": { "type": "boolean" }, + "file_retrieve": { "type": "boolean" }, + "file_delete": { "type": "boolean" }, + "file_content": { "type": "boolean" }, + "container_create": { "type": "boolean" }, + "container_list": { "type": "boolean" }, + "container_retrieve": { "type": "boolean" }, + "container_delete": { "type": "boolean" }, + "container_file_create": { "type": "boolean" }, + "container_file_list": { "type": "boolean" }, + "container_file_retrieve": { "type": "boolean" }, + "container_file_content": { "type": "boolean" }, + "container_file_delete": { "type": "boolean" }, + "passthrough": { "type": "boolean" }, + "passthrough_stream": { "type": "boolean" }, + "websocket_responses": { "type": "boolean" }, + "realtime": { "type": "boolean" } + }, + "additionalProperties": false + }, + "request_path_overrides": { + "type": "object", + "description": "Mapping of request type to custom path overriding the default provider path", + "additionalProperties": { "type": "string" } + } + }, + "required": ["base_provider_type"], + "additionalProperties": false } } } diff --git a/ui/app/workspace/custom-pricing/overrides/pricingFieldSelector.tsx b/ui/app/workspace/custom-pricing/overrides/pricingFieldSelector.tsx new file mode 100644 index 0000000000..b6184cc2fc --- /dev/null +++ b/ui/app/workspace/custom-pricing/overrides/pricingFieldSelector.tsx @@ -0,0 +1,204 @@ +"use client"; + +import { Badge } from "@/components/ui/badge"; +import { Input } from "@/components/ui/input"; +import { cn } from "@/lib/utils"; +import { ChevronDown, Plus, X } from "lucide-react"; +import { useEffect, useMemo, useState } from "react"; +import type { FieldErrors, PricingFieldKey } from "./pricingOverrideDrawer"; +import { PRICING_FIELDS } from "./pricingOverrideDrawer"; + +const PRICING_GROUPS = [ + { key: "token" as const, label: "Token" }, + { key: "cache" as const, label: "Cache" }, + { key: "image" as const, label: "Image" }, + { key: "av" as const, label: "Audio & Video" }, + { key: "other" as const, label: "Other" }, +]; + +type GroupKey = (typeof PRICING_GROUPS)[number]["key"]; + +interface PricingFieldSelectorProps { + values: Partial>; + errors: FieldErrors; + onChange: (key: PricingFieldKey, value: string) => void; + onFieldInteraction?: () => void; +} + +export function PricingFieldSelector({ values, errors, onChange, onFieldInteraction }: PricingFieldSelectorProps) { + const [search, setSearch] = useState(""); + const [openGroups, setOpenGroups] = useState>(new Set(["token"])); + + const [activeFields, setActiveFields] = useState>( + () => new Set(PRICING_FIELDS.filter((f) => values[f.key] != null && values[f.key]!.trim() !== "").map((f) => f.key)), + ); + + // Auto-activate fields that gain values (e.g., from JSON editing or loading an existing override) + useEffect(() => { + setActiveFields((prev) => { + const next = new Set(prev); + for (const f of PRICING_FIELDS) { + if (values[f.key] != null && values[f.key]!.trim() !== "") { + next.add(f.key); + } + } + return next; + }); + }, [values]); + + const trimmedSearch = search.trim().toLowerCase(); + const isSearching = trimmedSearch.length > 0; + + const filteredFields = useMemo(() => { + if (!isSearching) return null; + return PRICING_FIELDS.filter((f) => f.label.toLowerCase().includes(trimmedSearch) || f.key.toLowerCase().includes(trimmedSearch)); + }, [isSearching, trimmedSearch]); + + const groupedFields = useMemo( + () => + PRICING_GROUPS.map((group) => ({ + ...group, + fields: PRICING_FIELDS.filter((f) => f.group === group.key), + })), + [], + ); + + const toggleGroup = (key: GroupKey) => { + setOpenGroups((prev) => { + const next = new Set(prev); + if (next.has(key)) next.delete(key); + else next.add(key); + return next; + }); + }; + + const activateField = (key: PricingFieldKey) => { + setActiveFields((prev) => new Set([...prev, key])); + }; + + const deactivateField = (key: PricingFieldKey) => { + setActiveFields((prev) => { + const next = new Set(prev); + next.delete(key); + return next; + }); + onFieldInteraction?.(); + onChange(key, ""); + }; + + const handleInputChange = (key: PricingFieldKey, value: string) => { + onFieldInteraction?.(); + onChange(key, value); + }; + + const renderFieldRow = (field: { key: PricingFieldKey; label: string }) => { + const isActive = activeFields.has(field.key); + const hasValue = values[field.key]?.trim(); + const error = errors[field.key]; + + if (!isActive) { + return ( + + ); + } + + return ( +
+
+ {field.label} + +
+ handleInputChange(field.key, e.target.value)} + placeholder="0.0" + /> + {error &&

{error}

} +
+ ); + }; + + return ( +
+ setSearch(e.target.value)} + className="h-9" + data-testid="pricing-field-search" + /> + +
+ {isSearching ? ( +
+ {filteredFields!.length === 0 ? ( +
No fields match “{search}”
+ ) : ( + filteredFields!.map((field) => renderFieldRow(field)) + )} +
+ ) : ( +
+ {groupedFields.map((group) => { + const isOpen = openGroups.has(group.key); + const valueCount = group.fields.filter((f) => values[f.key]?.trim()).length; + + return ( +
+ + + {isOpen && ( +
+ {group.fields.map((field) => renderFieldRow(field))} +
+ )} +
+ ); + })} +
+ )} +
+
+ ); +} diff --git a/ui/app/workspace/custom-pricing/overrides/pricingOverrideDrawer.tsx b/ui/app/workspace/custom-pricing/overrides/pricingOverrideDrawer.tsx index a71c0b889b..fb4b503c86 100644 --- a/ui/app/workspace/custom-pricing/overrides/pricingOverrideDrawer.tsx +++ b/ui/app/workspace/custom-pricing/overrides/pricingOverrideDrawer.tsx @@ -1,8 +1,6 @@ "use client"; -import { CodeEditor } from "@/app/workspace/logs/views/codeEditor"; -import { Accordion, AccordionContent, AccordionItem, AccordionTrigger } from "@/components/ui/accordion"; -import { Badge } from "@/components/ui/badge"; +import { CodeEditor } from "@/components/ui/codeEditor"; import { Button } from "@/components/ui/button"; import { Checkbox } from "@/components/ui/checkbox"; import { Input } from "@/components/ui/input"; @@ -10,26 +8,26 @@ import { Label } from "@/components/ui/label"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; import { DottedSeparator } from "@/components/ui/separator"; -import { Sheet, SheetContent, SheetFooter, SheetHeader, SheetTitle } from "@/components/ui/sheet"; +import { Sheet, SheetContent, SheetHeader, SheetTitle } from "@/components/ui/sheet"; +import { PricingFieldSelector } from "./pricingFieldSelector"; import { getErrorMessage, useCreatePricingOverrideMutation, useGetProvidersQuery, useGetVirtualKeysQuery, - usePatchPricingOverrideMutation, + useUpdatePricingOverrideMutation, } from "@/lib/store"; import { RequestTypeLabels } from "@/lib/constants/logs"; import { ModelProvider } from "@/lib/types/config"; import { CreatePricingOverrideRequest, - PatchPricingOverrideRequest, PricingOverride, PricingOverrideMatchType, PricingOverridePatch, PricingOverrideScopeKind, } from "@/lib/types/governance"; import { cn } from "@/lib/utils"; -import { ChevronDown } from "lucide-react"; +import { ChevronDown, Save, X } from "lucide-react"; import { Dispatch, SetStateAction, useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; @@ -71,15 +69,12 @@ export const PRICING_FIELDS = [ { key: "input_cost_per_token", label: "Input / token", group: "token" }, { key: "output_cost_per_token", label: "Output / token", group: "token" }, { key: "input_cost_per_character", label: "Input / character", group: "token" }, - { key: "output_cost_per_character", label: "Output / character", group: "token" }, { key: "input_cost_per_token_batches", label: "Input / token (batch)", group: "token" }, { key: "output_cost_per_token_batches", label: "Output / token (batch)", group: "token" }, { key: "input_cost_per_token_priority", label: "Input / token (priority)", group: "token" }, { key: "output_cost_per_token_priority", label: "Output / token (priority)", group: "token" }, { key: "input_cost_per_token_above_128k_tokens", label: "Input / token (>128k)", group: "token" }, { key: "output_cost_per_token_above_128k_tokens", label: "Output / token (>128k)", group: "token" }, - { key: "input_cost_per_character_above_128k_tokens", label: "Input / character (>128k)", group: "token" }, - { key: "output_cost_per_character_above_128k_tokens", label: "Output / character (>128k)", group: "token" }, { key: "input_cost_per_token_above_200k_tokens", label: "Input / token (>200k)", group: "token" }, { key: "output_cost_per_token_above_200k_tokens", label: "Output / token (>200k)", group: "token" }, { key: "cache_creation_input_token_cost", label: "Cache creation / token", group: "cache" }, @@ -90,6 +85,7 @@ export const PRICING_FIELDS = [ { key: "cache_creation_input_token_cost_above_1hr_above_200k_tokens", label: "Cache creation / token (>1hr, >200k)", group: "cache" }, { key: "cache_creation_input_audio_token_cost", label: "Cache creation / audio token", group: "cache" }, { key: "cache_read_input_token_cost_priority", label: "Cache read / token (priority)", group: "cache" }, + { key: "cache_read_input_image_token_cost", label: "Cache read / image token", group: "cache" }, { key: "input_cost_per_image_token", label: "Input / image token", group: "image" }, { key: "output_cost_per_image_token", label: "Output / image token", group: "image" }, { key: "input_cost_per_image", label: "Input / image", group: "image" }, @@ -102,7 +98,10 @@ export const PRICING_FIELDS = [ { key: "output_cost_per_image_above_512_and_512_pixels_and_premium_image", label: "Output / image (>512px, premium)", group: "image" }, { key: "output_cost_per_image_above_1024_and_1024_pixels", label: "Output / image (>1024px)", group: "image" }, { key: "output_cost_per_image_above_1024_and_1024_pixels_and_premium_image", label: "Output / image (>1024px, premium)", group: "image" }, - { key: "cache_read_input_image_token_cost", label: "Cache read / image token", group: "image" }, + { key: "output_cost_per_image_low_quality", label: "Output / image (low quality)", group: "image" }, + { key: "output_cost_per_image_medium_quality", label: "Output / image (medium quality)", group: "image" }, + { key: "output_cost_per_image_high_quality", label: "Output / image (high quality)", group: "image" }, + { key: "output_cost_per_image_auto_quality", label: "Output / image (auto quality)", group: "image" }, { key: "input_cost_per_audio_token", label: "Input / audio token", group: "av" }, { key: "input_cost_per_audio_per_second", label: "Input / audio second", group: "av" }, { key: "input_cost_per_audio_per_second_above_128k_tokens", label: "Input / audio second (>128k)", group: "av" }, @@ -266,11 +265,11 @@ export function renderFields( {fields.map((field) => (
- { onFieldChange?.(); @@ -287,9 +286,6 @@ export function renderFields( ); } -function countFieldsWithValues(fields: ReadonlyArray<{ key: PricingFieldKey }>, form: FormState): number { - return fields.filter((f) => form.pricingValues[f.key]?.trim()).length; -} interface PricingOverrideDrawerProps { open: boolean; @@ -329,7 +325,7 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr const { data: providersData } = useGetProvidersQuery(); const { data: virtualKeysData } = useGetVirtualKeysQuery(); const [createOverride, { isLoading: isCreating }] = useCreatePricingOverrideMutation(); - const [patchOverride, { isLoading: isPatching }] = usePatchPricingOverrideMutation(); + const [updateOverride, { isLoading: isPatching }] = useUpdatePricingOverrideMutation(); const [form, setForm] = useState(defaultFormState); const [jsonPatch, setJSONPatch] = useState(""); @@ -374,8 +370,8 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr providerKeyID: scopeLock.providerKeyID ?? "", scopeRoot: scopeLock.scopeKind === "virtual_key" || - scopeLock.scopeKind === "virtual_key_provider" || - scopeLock.scopeKind === "virtual_key_provider_key" + scopeLock.scopeKind === "virtual_key_provider" || + scopeLock.scopeKind === "virtual_key_provider_key" ? "virtual_key" : "global", }; @@ -405,46 +401,26 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr return form.providerKeyID || undefined; }, [scopeLock, shouldLockScope, form.providerKeyID]); - const validation = useMemo(() => { - const errors: FieldErrors = {}; - if (!form.name.trim()) { - errors.name = "Name is required"; - } - if ( - (resolvedScopeKind === "virtual_key" || - resolvedScopeKind === "virtual_key_provider" || - resolvedScopeKind === "virtual_key_provider_key") && - !resolvedVirtualKeyID - ) { - errors.scope = "Virtual key is required"; + const pricingFieldErrors = useMemo(() => { + const errors: FieldErrors = {}; + for (const key of patchKeys) { + const raw = form.pricingValues[key]; + if (!raw || raw.trim() === "") continue; + const parsed = Number(raw); + if (!Number.isFinite(parsed)) errors[key] = "Must be a number"; + else if (parsed < 0) errors[key] = "Must be >= 0"; } - if ((resolvedScopeKind === "provider" || resolvedScopeKind === "virtual_key_provider") && !resolvedProviderID) { - errors.scope = "Provider is required"; - } - if (resolvedScopeKind === "provider_key" && !resolvedProviderKeyID) { - errors.scope = "Provider key is required"; - } - if (resolvedScopeKind === "virtual_key_provider_key" && (!resolvedProviderID || !resolvedProviderKeyID)) { - errors.scope = "Provider and provider key are required"; - } - - const pError = patternError(form.matchType, form.pattern); - if (pError) errors.pattern = pError; - - const built = buildPatchFromForm(form); - Object.assign(errors, built.errors); - if (Object.keys(built.patch).length === 0) errors.patch = "At least one pricing field must be overridden"; - - return { errors, patch: built.patch }; - }, [form, resolvedScopeKind, resolvedVirtualKeyID, resolvedProviderID, resolvedProviderKeyID]); + return errors; + }, [form.pricingValues]); useEffect(() => { if (!jsonEditingRef.current) { - const json = Object.keys(validation.patch).length > 0 ? JSON.stringify(validation.patch, null, 2) : ""; + const { patch } = buildPatchFromForm(form); + const json = Object.keys(patch).length > 0 ? JSON.stringify(patch, null, 2) : ""; setJSONPatch(json); setJSONError(undefined); } - }, [validation.patch]); + }, [form]); const handleJSONChange = useCallback((value: string) => { jsonEditingRef.current = true; @@ -484,7 +460,6 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr jsonEditingRef.current = false; }, []); - const isFormValid = Object.keys(validation.errors).length === 0 && !jsonError; const selectedRequestTypeGroup = form.requestTypes.length > 0 ? getRequestTypeGroup(form.requestTypes[0]) || "Other request types" : undefined; @@ -502,61 +477,98 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr })); }; - const handleSave = async () => { - if (!isFormValid) return; - let scopedVirtualKeyID: string | undefined; - let scopedProviderID: string | undefined; - let scopedProviderKeyID: string | undefined; - - switch (resolvedScopeKind) { - case "global": - break; - case "provider": - scopedProviderID = resolvedProviderID; - break; - case "provider_key": - scopedProviderKeyID = resolvedProviderKeyID; - break; - case "virtual_key": - scopedVirtualKeyID = resolvedVirtualKeyID; - break; - case "virtual_key_provider": - scopedVirtualKeyID = resolvedVirtualKeyID; - scopedProviderID = resolvedProviderID; - break; - case "virtual_key_provider_key": - scopedVirtualKeyID = resolvedVirtualKeyID; - scopedProviderID = resolvedProviderID; - scopedProviderKeyID = resolvedProviderKeyID; - break; - } + const handleSave = async () => { + if (!form.name.trim()) { + toast.error("Name is required"); + return; + } + + if ( + (resolvedScopeKind === "virtual_key" || + resolvedScopeKind === "virtual_key_provider" || + resolvedScopeKind === "virtual_key_provider_key") && + !resolvedVirtualKeyID + ) { + toast.error("Virtual key is required"); + return; + } + if ((resolvedScopeKind === "provider" || resolvedScopeKind === "virtual_key_provider") && !resolvedProviderID) { + toast.error("Provider is required"); + return; + } + if (resolvedScopeKind === "provider_key" && !resolvedProviderKeyID) { + toast.error("Provider key is required"); + return; + } + if (resolvedScopeKind === "virtual_key_provider_key" && (!resolvedProviderID || !resolvedProviderKeyID)) { + toast.error("Provider and provider key are required"); + return; + } + + const pError = patternError(form.matchType, form.pattern); + if (pError) { + toast.error(pError); + return; + } + + if (jsonError) { + toast.error("Fix the JSON error before saving"); + return; + } + + const { patch, errors: pricingErrors } = buildPatchFromForm(form); + const firstPricingError = Object.values(pricingErrors)[0]; + if (firstPricingError) { + toast.error(firstPricingError); + return; + } + if (Object.keys(patch).length === 0) { + toast.error("At least one pricing field must be overridden"); + return; + } + + let scopedVirtualKeyID: string | undefined; + let scopedProviderID: string | undefined; + let scopedProviderKeyID: string | undefined; + + switch (resolvedScopeKind) { + case "global": + break; + case "provider": + scopedProviderID = resolvedProviderID; + break; + case "provider_key": + scopedProviderKeyID = resolvedProviderKeyID; + break; + case "virtual_key": + scopedVirtualKeyID = resolvedVirtualKeyID; + break; + case "virtual_key_provider": + scopedVirtualKeyID = resolvedVirtualKeyID; + scopedProviderID = resolvedProviderID; + break; + case "virtual_key_provider_key": + scopedVirtualKeyID = resolvedVirtualKeyID; + scopedProviderID = resolvedProviderID; + scopedProviderKeyID = resolvedProviderKeyID; + break; + } - const requestPayload: CreatePricingOverrideRequest = { - name: form.name.trim(), - scope_kind: resolvedScopeKind, - virtual_key_id: scopedVirtualKeyID, - provider_id: scopedProviderID, - provider_key_id: scopedProviderKeyID, - match_type: form.matchType, - pattern: form.pattern.trim(), - request_types: form.requestTypes.length > 0 ? form.requestTypes : [], - patch: validation.patch, + const requestPayload: CreatePricingOverrideRequest = { + name: form.name.trim(), + scope_kind: resolvedScopeKind, + virtual_key_id: scopedVirtualKeyID, + provider_id: scopedProviderID, + provider_key_id: scopedProviderKeyID, + match_type: form.matchType, + pattern: form.pattern.trim(), + request_types: form.requestTypes.length > 0 ? form.requestTypes : [], + patch, }; try { if (editingOverride) { - const payload: PatchPricingOverrideRequest = { - name: requestPayload.name, - scope_kind: requestPayload.scope_kind, - virtual_key_id: requestPayload.virtual_key_id ?? "", - provider_id: requestPayload.provider_id ?? "", - provider_key_id: requestPayload.provider_key_id ?? "", - match_type: requestPayload.match_type, - pattern: requestPayload.pattern, - request_types: requestPayload.request_types, - patch: requestPayload.patch, - }; - await patchOverride({ id: editingOverride.id, data: payload }).unwrap(); + await updateOverride({ id: editingOverride.id, data: requestPayload }).unwrap(); toast.success("Pricing override updated"); } else { await createOverride(requestPayload).unwrap(); @@ -569,13 +581,6 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr } }; - const advancedSections = { - cache: PRICING_FIELDS.filter((field) => field.group === "cache"), - image: PRICING_FIELDS.filter((field) => field.group === "image"), - av: PRICING_FIELDS.filter((field) => field.group === "av"), - other: PRICING_FIELDS.filter((field) => field.group === "other"), - }; - const tokenFields = PRICING_FIELDS.filter((field) => field.group === "token"); return ( (o ? onOpenChange(true) : handleCloseDrawer())}> @@ -587,17 +592,16 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr
- - setForm((prev) => ({ ...prev, name: e.target.value }))} /> - {validation.errors.name &&

{validation.errors.name}

} + + setForm((prev) => ({ ...prev, name: e.target.value }))} />
- {shouldLockScope && scopeLock ? ( -
- - -
- ) : ( + {shouldLockScope && scopeLock ? ( +
+ + +
+ ) : ( <>
@@ -619,7 +623,7 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr {form.scopeRoot === "virtual_key" && (
- + + setForm((prev) => ({ ...prev, providerID: value === "__none__" ? "" : value, providerKeyID: "" })) + } + > + + + + + All providers + {providers.map((provider) => ( + + {provider.name} + + ))} + + +
+ + {form.providerID ? (
- +
- - {form.providerID ? ( -
- - -
- ) : ( -
- )} -
+ ) : ( +
+ )} +
)} - {validation.errors.scope &&

{validation.errors.scope}

} -
- - +
-
+
- + setForm((prev) => ({ ...prev, pattern: e.target.value }))} - placeholder={form.matchType === "exact" ? "gpt-5-mini" : "gpt-5*"} + placeholder={form.matchType === "exact" ? "e.g., gpt-4o" : "e.g., gpt-4*"} />
- {validation.errors.pattern &&

{validation.errors.pattern}

} -
- - +
- + - @@ -762,11 +760,11 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr isGroupDisabled ? "cursor-not-allowed opacity-50" : "hover:bg-muted cursor-pointer", )} > - toggleRequestType(requestType)} + toggleRequestType(requestType)} /> {RequestTypeLabels[requestType as keyof typeof RequestTypeLabels] ?? requestType} @@ -778,93 +776,33 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr })()}
- -
+ +
- - -
- - - - - - Token - {countFieldsWithValues(tokenFields, form) > 0 && ( - - {countFieldsWithValues(tokenFields, form)} - - )} - - - {renderFields(tokenFields, form, setForm, validation.errors, handleFieldChange)} - - - - - Cache - {countFieldsWithValues(advancedSections.cache, form) > 0 && ( - - {countFieldsWithValues(advancedSections.cache, form)} - - )} - - - {renderFields(advancedSections.cache, form, setForm, validation.errors, handleFieldChange)} - - - - - Image - {countFieldsWithValues(advancedSections.image, form) > 0 && ( - - {countFieldsWithValues(advancedSections.image, form)} - - )} - - - {renderFields(advancedSections.image, form, setForm, validation.errors, handleFieldChange)} - - - - - Audio and Video - {countFieldsWithValues(advancedSections.av, form) > 0 && ( - - {countFieldsWithValues(advancedSections.av, form)} - - )} - - - {renderFields(advancedSections.av, form, setForm, validation.errors, handleFieldChange)} - - - - - Other - {countFieldsWithValues(advancedSections.other, form) > 0 && ( - - {countFieldsWithValues(advancedSections.other, form)} - - )} - - - {renderFields(advancedSections.other, form, setForm, validation.errors, handleFieldChange)} - - - {validation.errors.patch &&

{validation.errors.patch}

} -
+
+ + { + handleFieldChange(); + setForm((prev) => ({ ...prev, pricingValues: { ...prev.pricingValues, [key]: value } })); + }} + onFieldInteraction={handleFieldChange} + /> +
@@ -884,14 +822,16 @@ export default function PricingOverrideDrawer({ open, onOpenChange, editingOverr
- +
- - +
); diff --git a/ui/app/workspace/custom-pricing/overrides/pricingOverridesEmptyState.tsx b/ui/app/workspace/custom-pricing/overrides/pricingOverridesEmptyState.tsx new file mode 100644 index 0000000000..52c6dae93b --- /dev/null +++ b/ui/app/workspace/custom-pricing/overrides/pricingOverridesEmptyState.tsx @@ -0,0 +1,45 @@ +"use client"; + +import { Button } from "@/components/ui/button"; +import { ArrowUpRight, SlidersHorizontal } from "lucide-react"; + +const PRICING_OVERRIDES_DOCS_URL = "https://docs.getbifrost.ai/features/governance/custom-pricing"; + +interface PricingOverridesEmptyStateProps { + onCreateClick: () => void; +} + +export function PricingOverridesEmptyState({ onCreateClick }: PricingOverridesEmptyStateProps) { + return ( +
+
+ +
+
+

Pricing overrides customize cost tracking per scope

+
+ Define custom per-token prices for specific providers, keys, or virtual keys to accurately reflect your negotiated rates. +
+
+ + +
+
+
+ ); +} diff --git a/ui/app/workspace/custom-pricing/overrides/scopedPricingOverridesView.tsx b/ui/app/workspace/custom-pricing/overrides/scopedPricingOverridesView.tsx index 6acfcd1fa3..93c223af55 100644 --- a/ui/app/workspace/custom-pricing/overrides/scopedPricingOverridesView.tsx +++ b/ui/app/workspace/custom-pricing/overrides/scopedPricingOverridesView.tsx @@ -25,6 +25,7 @@ import { useSearchParams } from "next/navigation"; import { useEffect, useMemo, useState } from "react"; import { toast } from "sonner"; import PricingOverrideDrawer from "./pricingOverrideDrawer"; +import { PricingOverridesEmptyState } from "./pricingOverridesEmptyState"; type ScopeFilter = "all" | PricingOverrideScopeKind; @@ -190,6 +191,20 @@ export default function ScopedPricingOverridesView() { } }; + if (!isLoading && !error && rows.length === 0) { + return ( + <> + + + + ); + } + return (
@@ -204,8 +219,6 @@ export default function ScopedPricingOverridesView() {
Loading overrides...
) : error ? (
Failed to load pricing overrides. Please try refreshing the page.
- ) : rows.length === 0 ? ( -
No pricing overrides configured.
) : ( diff --git a/ui/components/sidebar.tsx b/ui/components/sidebar.tsx index b0e35f4f33..e7dab835e2 100644 --- a/ui/components/sidebar.tsx +++ b/ui/components/sidebar.tsx @@ -250,12 +250,12 @@ const SidebarItemView = ({ data-testid={`nav-button-${item.title.toLowerCase().replace(/\s+/g, "-")}`} data-nav-url={!hasSubItems ? item.url : undefined} className={`relative h-7.5 cursor-pointer rounded-sm border px-3 transition-all duration-200 ${isHighlighted - ? "bg-sidebar-accent text-accent-foreground border-primary/20" - : isActive || isAnySubItemActive - ? "bg-sidebar-accent text-primary border-primary/20" - : item.hasAccess - ? "hover:bg-sidebar-accent hover:text-accent-foreground border-transparent text-slate-500 dark:text-zinc-400" - : "hover:bg-destructive/5 hover:text-muted-foreground text-muted-foreground cursor-not-allowed border-transparent" + ? "bg-sidebar-accent text-accent-foreground border-primary/20" + : isActive || isAnySubItemActive + ? "bg-sidebar-accent text-primary border-primary/20" + : item.hasAccess + ? "hover:bg-sidebar-accent hover:text-accent-foreground border-transparent text-slate-500 dark:text-zinc-400" + : "hover:bg-destructive/5 hover:text-muted-foreground text-muted-foreground cursor-not-allowed border-transparent" } `} onClick={hasSubItems ? handleClick : item.hasAccess ? (e) => handleNavigation(item.url, e) : undefined} > @@ -297,12 +297,12 @@ const SidebarItemView = ({ data-testid={`nav-submenu-toggle-${subItem.title.toLowerCase().replace(/\s+/g, "-")}`} data-nav-url={subItem.url} className={`h-7 cursor-pointer rounded-sm px-2 transition-all duration-200 ${isSubItemHighlighted - ? "bg-sidebar-accent text-accent-foreground" - : isSubItemActive - ? "bg-sidebar-accent text-primary font-medium" - : subItem.hasAccess === false - ? "hover:bg-destructive/5 hover:text-muted-foreground text-muted-foreground cursor-not-allowed border-transparent" - : "hover:bg-sidebar-accent hover:text-accent-foreground text-slate-500 dark:text-zinc-400" + ? "bg-sidebar-accent text-accent-foreground" + : isSubItemActive + ? "bg-sidebar-accent text-primary font-medium" + : subItem.hasAccess === false + ? "hover:bg-destructive/5 hover:text-muted-foreground text-muted-foreground cursor-not-allowed border-transparent" + : "hover:bg-sidebar-accent hover:text-accent-foreground text-slate-500 dark:text-zinc-400" }`} onClick={(e) => (subItem.hasAccess === false ? undefined : handleSubItemClick(subItem, e))} > @@ -486,7 +486,7 @@ export default function AppSidebar() { hasAccess: hasSettingsAccess, }, { - title: "Pricing overrides", + title: "Pricing Overrides", url: "/workspace/custom-pricing/overrides", icon: SlidersHorizontal, description: "Scoped pricing overrides", diff --git a/ui/lib/store/apis/governanceApi.ts b/ui/lib/store/apis/governanceApi.ts index d2b2027463..5a0141d2e2 100644 --- a/ui/lib/store/apis/governanceApi.ts +++ b/ui/lib/store/apis/governanceApi.ts @@ -23,7 +23,6 @@ import { HealthCheckResponse, ModelConfig, ProviderGovernance, - PatchPricingOverrideRequest, PricingOverride, RateLimit, ResetUsageRequest, @@ -590,13 +589,13 @@ export const governanceApi = baseApi.injectEndpoints({ invalidatesTags: ["PricingOverrides"], }), - patchPricingOverride: builder.mutation< + updatePricingOverride: builder.mutation< { message: string; pricing_override: PricingOverride }, - { id: string; data: PatchPricingOverrideRequest } + { id: string; data: CreatePricingOverrideRequest } >({ query: ({ id, data }) => ({ url: `/governance/pricing-overrides/${id}`, - method: "PATCH", + method: "PUT", body: data, }), invalidatesTags: ["PricingOverrides"], @@ -729,7 +728,7 @@ export const { useDeleteModelConfigMutation, useGetPricingOverridesQuery, useCreatePricingOverrideMutation, - usePatchPricingOverrideMutation, + useUpdatePricingOverrideMutation, useDeletePricingOverrideMutation, // Provider Governance diff --git a/ui/lib/types/governance.ts b/ui/lib/types/governance.ts index 8a960d1a9c..4109095f9c 100644 --- a/ui/lib/types/governance.ts +++ b/ui/lib/types/governance.ts @@ -372,38 +372,34 @@ export type PricingOverrideScopeKind = export type PricingOverrideMatchType = "exact" | "wildcard"; export interface PricingOverridePatch { + // Token input_cost_per_token?: number; output_cost_per_token?: number; + input_cost_per_token_batches?: number; + output_cost_per_token_batches?: number; input_cost_per_token_priority?: number; output_cost_per_token_priority?: number; input_cost_per_character?: number; - output_cost_per_character?: number; - input_cost_per_audio_token?: number; - input_cost_per_video_per_second?: number; - input_cost_per_second?: number; - output_cost_per_audio_token?: number; - output_cost_per_video_per_second?: number; - output_cost_per_second?: number; - input_cost_per_audio_per_second?: number; + // 128k tier input_cost_per_token_above_128k_tokens?: number; - input_cost_per_character_above_128k_tokens?: number; + output_cost_per_token_above_128k_tokens?: number; input_cost_per_image_above_128k_tokens?: number; input_cost_per_video_per_second_above_128k_tokens?: number; input_cost_per_audio_per_second_above_128k_tokens?: number; - output_cost_per_token_above_128k_tokens?: number; - output_cost_per_character_above_128k_tokens?: number; + // 200k tier input_cost_per_token_above_200k_tokens?: number; output_cost_per_token_above_200k_tokens?: number; + // Cache + cache_creation_input_token_cost?: number; + cache_read_input_token_cost?: number; cache_creation_input_token_cost_above_200k_tokens?: number; cache_read_input_token_cost_above_200k_tokens?: number; cache_creation_input_token_cost_above_1hr?: number; cache_creation_input_token_cost_above_1hr_above_200k_tokens?: number; cache_creation_input_audio_token_cost?: number; cache_read_input_token_cost_priority?: number; - cache_read_input_token_cost?: number; - cache_creation_input_token_cost?: number; - input_cost_per_token_batches?: number; - output_cost_per_token_batches?: number; + cache_read_input_image_token_cost?: number; + // Image input_cost_per_image_token?: number; output_cost_per_image_token?: number; input_cost_per_image?: number; @@ -415,7 +411,19 @@ export interface PricingOverridePatch { output_cost_per_image_above_512_and_512_pixels_and_premium_image?: number; output_cost_per_image_above_1024_and_1024_pixels?: number; output_cost_per_image_above_1024_and_1024_pixels_and_premium_image?: number; - cache_read_input_image_token_cost?: number; + output_cost_per_image_low_quality?: number; + output_cost_per_image_medium_quality?: number; + output_cost_per_image_high_quality?: number; + output_cost_per_image_auto_quality?: number; + // Audio/Video + input_cost_per_audio_token?: number; + input_cost_per_audio_per_second?: number; + input_cost_per_second?: number; + input_cost_per_video_per_second?: number; + output_cost_per_audio_token?: number; + output_cost_per_video_per_second?: number; + output_cost_per_second?: number; + // Other search_context_cost_per_query?: number; code_interpreter_cost_per_session?: number; } @@ -448,17 +456,6 @@ export interface CreatePricingOverrideRequest { patch?: PricingOverridePatch; } -export interface PatchPricingOverrideRequest { - name?: string; - scope_kind?: PricingOverrideScopeKind; - virtual_key_id?: string; - provider_id?: string; - provider_key_id?: string; - match_type?: PricingOverrideMatchType; - pattern?: string; - request_types?: string[]; - patch?: PricingOverridePatch; -} export interface GetPricingOverridesResponse { pricing_overrides: PricingOverride[];