Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ const (
BifrostContextKeyGovernanceBusinessUnitName BifrostContextKey = "bifrost-governance-business-unit-name" // string (to store the business unit name (set by enterprise governance plugin - DO NOT SET THIS MANUALLY))
BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY))
BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY))
BifrostContextKeySelectedPromptName BifrostContextKey = "bifrost-selected-prompt-name" // string (display name of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY))
Comment thread
roroghost17 marked this conversation as resolved.
BifrostContextKeySelectedPromptVersion BifrostContextKey = "bifrost-selected-prompt-version" // string (numeric version as string, e.g. "3" (set by prompts plugin - DO NOT SET THIS MANUALLY))
BifrostContextKeySelectedPromptID BifrostContextKey = "bifrost-selected-prompt-id" // string (id of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY))
BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY))
BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY))
BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc.
Expand Down Expand Up @@ -222,6 +225,7 @@ const (
BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates)
BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY))
BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost)
BifrostContextKeyPromptsPluginName BifrostContextKey = "prompts-plugin-name" // string (name of the prompts plugin to use - set by bifrost - DO NOT SET THIS MANUALLY))
BifrostContextKeyIsEnterprise BifrostContextKey = "is-enterprise" // bool (set by bifrost - DO NOT SET THIS MANUALLY))
BifrostContextKeyAvailableProviders BifrostContextKey = "available-providers" // []ModelProvider (set by bifrost - DO NOT SET THIS MANUALLY))
BifrostContextKeyStoreRawRequestResponse BifrostContextKey = "bifrost-store-raw-request-response" // bool (per-request override — read by bifrost.go, never overwritten)
Expand All @@ -235,7 +239,6 @@ const (
BifrostContextKeyHTTPRequestType BifrostContextKey = "bifrost-http-request-type" // RequestType (set by bifrost - DO NOT SET THIS MANUALLY))
BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool
BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.)
BifrostContextKeyPromptStreamRequest BifrostContextKey = "bifrost-prompt-stream-request" // bool (set by prompts HTTP plugin when prompt version model_params.stream is true and body omitted stream)
BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries
BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks)
BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware)
Expand Down
47 changes: 47 additions & 0 deletions framework/configstore/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error {
if err := migrationAddRoutingChainMaxDepthColumn(ctx, db); err != nil {
return err
}
if err := migrationAddPromptVariablesColumns(ctx, db); err != nil {
return err
}
if err := migrationAddModelCapabilityColumns(ctx, db); err != nil {
return err
}
Expand Down Expand Up @@ -5466,6 +5469,50 @@ func migrationAddOpenAIConfigJSONColumn(ctx context.Context, db *gorm.DB) error
return nil
}

// migrationAddPromptVariablesColumns adds variables_json column to prompt_sessions and prompt_versions
func migrationAddPromptVariablesColumns(ctx context.Context, db *gorm.DB) error {
m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{
ID: "add_prompt_variables_columns",
Migrate: func(tx *gorm.DB) error {
tx = tx.WithContext(ctx)
migrator := tx.Migrator()

if !migrator.HasColumn(&tables.TablePromptSession{}, "variables_json") {
if err := migrator.AddColumn(&tables.TablePromptSession{}, "VariablesJSON"); err != nil {
return fmt.Errorf("failed to add variables_json column to prompt_sessions: %w", err)
}
}

if !migrator.HasColumn(&tables.TablePromptVersion{}, "variables_json") {
if err := migrator.AddColumn(&tables.TablePromptVersion{}, "VariablesJSON"); err != nil {
return fmt.Errorf("failed to add variables_json column to prompt_versions: %w", err)
}
}

return nil
},
Rollback: func(tx *gorm.DB) error {
tx = tx.WithContext(ctx)
migrator := tx.Migrator()
if migrator.HasColumn(&tables.TablePromptSession{}, "variables_json") {
if err := migrator.DropColumn(&tables.TablePromptSession{}, "variables_json"); err != nil {
return err
}
}
if migrator.HasColumn(&tables.TablePromptVersion{}, "variables_json") {
if err := migrator.DropColumn(&tables.TablePromptVersion{}, "variables_json"); err != nil {
return err
}
}
return nil
},
}})
if err := m.Migrate(); err != nil {
return fmt.Errorf("error while running add_prompt_variables_columns migration: %s", err.Error())
}
return nil
}

// migrationAddKeyBlacklistedModelsJSONColumn adds blacklisted_models_json to config_keys
// for per-key model deny lists (JSON array of model ids, default []).
func migrationAddKeyBlacklistedModelsJSONColumn(ctx context.Context, db *gorm.DB) error {
Expand Down
22 changes: 22 additions & 0 deletions framework/configstore/tables/promptSessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type TablePromptSession struct {
ModelParams ModelParams `gorm:"-" json:"model_params"`
Provider string `gorm:"type:varchar(100)" json:"provider"`
Model string `gorm:"type:varchar(100)" json:"model"`
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
CreatedAt time.Time `gorm:"not null" json:"created_at"`
UpdatedAt time.Time `gorm:"not null" json:"updated_at"`

Expand All @@ -40,6 +42,17 @@ func (s *TablePromptSession) BeforeSave(tx *gorm.DB) error {
}
paramsStr := string(data)
s.ModelParamsJSON = &paramsStr

if s.Variables != nil {
varsData, err := json.Marshal(s.Variables)
if err != nil {
return err
}
varsStr := string(varsData)
s.VariablesJSON = &varsStr
} else {
s.VariablesJSON = nil
}
return nil
}

Expand All @@ -52,6 +65,15 @@ func (s *TablePromptSession) AfterFind(tx *gorm.DB) error {
return err
}
}
if s.VariablesJSON != nil && *s.VariablesJSON != "" {
var vars PromptVariables
if err := json.Unmarshal([]byte(*s.VariablesJSON), &vars); err != nil {
return err
}
s.Variables = vars
} else {
s.Variables = nil
}
return nil
}

Expand Down
41 changes: 30 additions & 11 deletions framework/configstore/tables/promptVersions.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ import (
// TablePromptVersion represents an immutable version of a prompt
// Once created, a version cannot be modified - to make changes, create a new version
type TablePromptVersion struct {
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"`
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"`
CommitMessage string `gorm:"type:text" json:"commit_message"`
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
ModelParams ModelParams `gorm:"-" json:"model_params"`
Provider string `gorm:"type:varchar(100)" json:"provider"`
Model string `gorm:"type:varchar(100)" json:"model"`
IsLatest bool `gorm:"not null;default:false" json:"is_latest"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
ID uint `gorm:"primaryKey;autoIncrement" json:"id"`
PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"`
Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"`
VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"`
CommitMessage string `gorm:"type:text" json:"commit_message"`
ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"`
ModelParams ModelParams `gorm:"-" json:"model_params"`
Provider string `gorm:"type:varchar(100)" json:"provider"`
Model string `gorm:"type:varchar(100)" json:"model"`
VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"`
Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables
IsLatest bool `gorm:"not null;default:false" json:"is_latest"`
CreatedAt time.Time `gorm:"not null" json:"created_at"`
// No UpdatedAt - versions are immutable

// Relationships
Expand All @@ -36,6 +38,10 @@ func (TablePromptVersion) TableName() string { return "prompt_versions" }
// so that any provider-specific params (response_format, seed, logprobs, etc.) are preserved.
type ModelParams map[string]interface{}

// PromptVariables represents a map of Jinja2 variable names to their values.
// Sessions store full {key: value} pairs; versions store {key: ""} (keys only).
type PromptVariables map[string]string

// BeforeSave GORM hook to serialize JSON fields
func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error {
if v.ModelParams != nil {
Expand All @@ -46,6 +52,14 @@ func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error {
paramsStr := string(data)
v.ModelParamsJSON = &paramsStr
}
if v.Variables != nil {
varsData, err := json.Marshal(v.Variables)
if err != nil {
return err
}
varsStr := string(varsData)
v.VariablesJSON = &varsStr
}
return nil
}

Expand All @@ -58,6 +72,11 @@ func (v *TablePromptVersion) AfterFind(tx *gorm.DB) error {
return err
}
}
if v.VariablesJSON != nil && *v.VariablesJSON != "" {
if err := json.Unmarshal([]byte(*v.VariablesJSON), &v.Variables); err != nil {
return err
}
}
return nil
}

Expand Down
66 changes: 66 additions & 0 deletions framework/routing/routing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package routing

import (
"fmt"
"regexp"
"strings"
)

// headerKeyPattern matches header map access patterns like headers["X-Api-Key"] or headers['X-Api-Key']
var headerKeyPattern = regexp.MustCompile(`headers\[["']([^"']+)["']\]`)

// headerInPattern matches "in headers" membership test patterns like "X-Api-Key" in headers or 'X-Api-Key' in headers
var headerInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+headers`)

// paramKeyPattern matches param map access patterns like params["Region"] or params['Region']
var paramKeyPattern = regexp.MustCompile(`params\[["']([^"']+)["']\]`)

// paramInPattern matches "in params" membership test patterns like "Region" in params or 'Region' in params
var paramInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+params`)

// normalizeMapKeysInCEL lowercases header and param keys in CEL expressions
// so that headers["X-Api-Key"] becomes headers["x-api-key"], "X-Api-Key" in headers becomes "x-api-key" in headers,
// params["Region"] becomes params["region"], and "Region" in params becomes "region" in params.
// This ensures CEL expressions match against the normalized (lowercase) map keys at runtime.
func NormalizeMapKeysInCEL(expr string) string {
toLower := func(match string) string {
return strings.ToLower(match)
}
// Normalize bracket access
expr = headerKeyPattern.ReplaceAllStringFunc(expr, toLower)
expr = paramKeyPattern.ReplaceAllStringFunc(expr, toLower)
// Normalize "in" membership test
expr = headerInPattern.ReplaceAllStringFunc(expr, toLower)
expr = paramInPattern.ReplaceAllStringFunc(expr, toLower)
return expr
}

// validateCELExpression performs basic validation on CEL expression format
func ValidateCELExpression(expr string) error {
normalized := strings.TrimSpace(expr)
if normalized == "" || normalized == "true" || normalized == "false" {
return nil // Empty, true, or false are valid
}

// List of allowed operators and keywords
validPatterns := []string{
"==", "!=", "&&", "||", ">", "<", ">=", "<=",
"in ", "matches ", ".startsWith(", ".contains(", ".endsWith(",
"[", "]", "(", ")", "!",
}

// Check if expression contains at least one valid operator
hasPattern := false
for _, pattern := range validPatterns {
if strings.Contains(normalized, pattern) {
hasPattern = true
break
}
}

if !hasPattern {
return fmt.Errorf("expression must contain at least one operator: %s", expr)
}

return nil
}
59 changes: 0 additions & 59 deletions plugins/governance/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package governance
import (
"fmt"
"math/rand/v2"
"regexp"
"strings"

"github.com/google/cel-go/cel"
Expand All @@ -14,18 +13,6 @@ import (
// DefaultRoutingChainMaxDepth is the default maximum depth for routing rule chain evaluation.
const DefaultRoutingChainMaxDepth = 10

// headerKeyPattern matches header map access patterns like headers["X-Api-Key"] or headers['X-Api-Key']
var headerKeyPattern = regexp.MustCompile(`headers\[["']([^"']+)["']\]`)

// headerInPattern matches "in headers" membership test patterns like "X-Api-Key" in headers or 'X-Api-Key' in headers
var headerInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+headers`)

// paramKeyPattern matches param map access patterns like params["Region"] or params['Region']
var paramKeyPattern = regexp.MustCompile(`params\[["']([^"']+)["']\]`)

// paramInPattern matches "in params" membership test patterns like "Region" in params or 'Region' in params
var paramInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+params`)

// ScopeLevel represents a level in the scope precedence hierarchy
type ScopeLevel struct {
ScopeName string // "virtual_key", "team", "customer", or "global"
Expand Down Expand Up @@ -482,52 +469,6 @@ func scopeChainToStrings(chain []ScopeLevel) []string {
return scopes
}

// validateCELExpression performs basic validation on CEL expression format
func validateCELExpression(expr string) error {
if expr == "" || expr == "true" || expr == "false" {
return nil // Empty, true, or false are valid
}

// List of allowed operators and keywords
validPatterns := []string{
"==", "!=", "&&", "||", ">", "<", ">=", "<=",
"in ", "matches ", ".startsWith(", ".contains(", ".endsWith(",
"[", "]", "(", ")", "!",
}

// Check if expression contains at least one valid operator
hasPattern := false
for _, pattern := range validPatterns {
if strings.Contains(expr, pattern) {
hasPattern = true
break
}
}

if !hasPattern {
return fmt.Errorf("expression must contain at least one operator: %s", expr)
}

return nil
}

// normalizeMapKeysInCEL lowercases header and param keys in CEL expressions
// so that headers["X-Api-Key"] becomes headers["x-api-key"], "X-Api-Key" in headers becomes "x-api-key" in headers,
// params["Region"] becomes params["region"], and "Region" in params becomes "region" in params.
// This ensures CEL expressions match against the normalized (lowercase) map keys at runtime.
func normalizeMapKeysInCEL(expr string) string {
toLower := func(match string) string {
return strings.ToLower(match)
}
// Normalize bracket access
expr = headerKeyPattern.ReplaceAllStringFunc(expr, toLower)
expr = paramKeyPattern.ReplaceAllStringFunc(expr, toLower)
// Normalize "in" membership test
expr = headerInPattern.ReplaceAllStringFunc(expr, toLower)
expr = paramInPattern.ReplaceAllStringFunc(expr, toLower)
return expr
}

// createCELEnvironment creates a new CEL environment for routing rules
func createCELEnvironment() (*cel.Env, error) {
return cel.NewEnv(
Expand Down
7 changes: 4 additions & 3 deletions plugins/governance/routing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/configstore"
configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables"
"github.com/maximhq/bifrost/framework/routing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -1159,7 +1160,7 @@ func TestValidateCELExpression_Valid(t *testing.T) {
}

for _, expr := range tests {
err := validateCELExpression(expr)
err := routing.ValidateCELExpression(expr)
assert.NoError(t, err, "expression should be valid: %s", expr)
}
}
Expand All @@ -1173,7 +1174,7 @@ func TestValidateCELExpression_Invalid(t *testing.T) {
}

for _, expr := range tests {
err := validateCELExpression(expr)
err := routing.ValidateCELExpression(expr)
assert.Error(t, err, "expression should be invalid: %s", expr)
}
}
Expand Down Expand Up @@ -1733,7 +1734,7 @@ func TestNormalizeMapKeysInCEL(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := normalizeMapKeysInCEL(tt.input)
result := routing.NormalizeMapKeysInCEL(tt.input)
assert.Equal(t, tt.expected, result)
})
}
Expand Down
Loading
Loading