diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 9d9f611cec..c365c2d169 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -188,6 +188,9 @@ const ( BifrostContextKeyGovernanceBusinessUnitName BifrostContextKey = "bifrost-governance-business-unit-name" // string (to store the business unit name (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptName BifrostContextKey = "bifrost-selected-prompt-name" // string (display name of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptVersion BifrostContextKey = "bifrost-selected-prompt-version" // string (numeric version as string, e.g. "3" (set by prompts plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedPromptID BifrostContextKey = "bifrost-selected-prompt-id" // string (id of the selected prompt (set by prompts plugin - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. 
@@ -222,6 +225,7 @@ const ( BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost) + BifrostContextKeyPromptsPluginName BifrostContextKey = "prompts-plugin-name" // string (name of the prompts plugin to use - set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyIsEnterprise BifrostContextKey = "is-enterprise" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyAvailableProviders BifrostContextKey = "available-providers" // []ModelProvider (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyStoreRawRequestResponse BifrostContextKey = "bifrost-store-raw-request-response" // bool (per-request override — read by bifrost.go, never overwritten) @@ -235,7 +239,6 @@ const ( BifrostContextKeyHTTPRequestType BifrostContextKey = "bifrost-http-request-type" // RequestType (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.) 
- BifrostContextKeyPromptStreamRequest BifrostContextKey = "bifrost-prompt-stream-request" // bool (set by prompts HTTP plugin when prompt version model_params.stream is true and body omitted stream) BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index bced2ec3f5..42ca3f3c8d 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -359,6 +359,9 @@ func triggerMigrations(ctx context.Context, db *gorm.DB) error { if err := migrationAddRoutingChainMaxDepthColumn(ctx, db); err != nil { return err } + if err := migrationAddPromptVariablesColumns(ctx, db); err != nil { + return err + } if err := migrationAddModelCapabilityColumns(ctx, db); err != nil { return err } @@ -5466,6 +5469,50 @@ func migrationAddOpenAIConfigJSONColumn(ctx context.Context, db *gorm.DB) error return nil } +// migrationAddPromptVariablesColumns adds variables_json column to prompt_sessions and prompt_versions +func migrationAddPromptVariablesColumns(ctx context.Context, db *gorm.DB) error { + m := migrator.New(db, migrator.DefaultOptions, []*migrator.Migration{{ + ID: "add_prompt_variables_columns", + Migrate: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + + if !migrator.HasColumn(&tables.TablePromptSession{}, "variables_json") { + if err := migrator.AddColumn(&tables.TablePromptSession{}, "VariablesJSON"); err != 
nil { + return fmt.Errorf("failed to add variables_json column to prompt_sessions: %w", err) + } + } + + if !migrator.HasColumn(&tables.TablePromptVersion{}, "variables_json") { + if err := migrator.AddColumn(&tables.TablePromptVersion{}, "VariablesJSON"); err != nil { + return fmt.Errorf("failed to add variables_json column to prompt_versions: %w", err) + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + tx = tx.WithContext(ctx) + migrator := tx.Migrator() + if migrator.HasColumn(&tables.TablePromptSession{}, "variables_json") { + if err := migrator.DropColumn(&tables.TablePromptSession{}, "variables_json"); err != nil { + return err + } + } + if migrator.HasColumn(&tables.TablePromptVersion{}, "variables_json") { + if err := migrator.DropColumn(&tables.TablePromptVersion{}, "variables_json"); err != nil { + return err + } + } + return nil + }, + }}) + if err := m.Migrate(); err != nil { + return fmt.Errorf("error while running add_prompt_variables_columns migration: %s", err.Error()) + } + return nil +} + // migrationAddKeyBlacklistedModelsJSONColumn adds blacklisted_models_json to config_keys // for per-key model deny lists (JSON array of model ids, default []). 
func migrationAddKeyBlacklistedModelsJSONColumn(ctx context.Context, db *gorm.DB) error { diff --git a/framework/configstore/tables/promptSessions.go b/framework/configstore/tables/promptSessions.go index df96db09aa..4618704ebf 100644 --- a/framework/configstore/tables/promptSessions.go +++ b/framework/configstore/tables/promptSessions.go @@ -22,6 +22,8 @@ type TablePromptSession struct { ModelParams ModelParams `gorm:"-" json:"model_params"` Provider string `gorm:"type:varchar(100)" json:"provider"` Model string `gorm:"type:varchar(100)" json:"model"` + VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"` + Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables CreatedAt time.Time `gorm:"not null" json:"created_at"` UpdatedAt time.Time `gorm:"not null" json:"updated_at"` @@ -40,6 +42,17 @@ func (s *TablePromptSession) BeforeSave(tx *gorm.DB) error { } paramsStr := string(data) s.ModelParamsJSON = &paramsStr + + if s.Variables != nil { + varsData, err := json.Marshal(s.Variables) + if err != nil { + return err + } + varsStr := string(varsData) + s.VariablesJSON = &varsStr + } else { + s.VariablesJSON = nil + } return nil } @@ -52,6 +65,15 @@ func (s *TablePromptSession) AfterFind(tx *gorm.DB) error { return err } } + if s.VariablesJSON != nil && *s.VariablesJSON != "" { + var vars PromptVariables + if err := json.Unmarshal([]byte(*s.VariablesJSON), &vars); err != nil { + return err + } + s.Variables = vars + } else { + s.Variables = nil + } return nil } diff --git a/framework/configstore/tables/promptVersions.go b/framework/configstore/tables/promptVersions.go index 1703be58f3..ca9e41e039 100644 --- a/framework/configstore/tables/promptVersions.go +++ b/framework/configstore/tables/promptVersions.go @@ -12,17 +12,19 @@ import ( // TablePromptVersion represents an immutable version of a prompt // Once created, a version cannot be modified - to make changes, create a new version type TablePromptVersion 
struct { - ID uint `gorm:"primaryKey;autoIncrement" json:"id"` - PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"` - Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"` - VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"` - CommitMessage string `gorm:"type:text" json:"commit_message"` - ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"` - ModelParams ModelParams `gorm:"-" json:"model_params"` - Provider string `gorm:"type:varchar(100)" json:"provider"` - Model string `gorm:"type:varchar(100)" json:"model"` - IsLatest bool `gorm:"not null;default:false" json:"is_latest"` - CreatedAt time.Time `gorm:"not null" json:"created_at"` + ID uint `gorm:"primaryKey;autoIncrement" json:"id"` + PromptID string `gorm:"type:varchar(36);not null;index;uniqueIndex:idx_prompt_version" json:"prompt_id"` + Prompt *TablePrompt `gorm:"foreignKey:PromptID" json:"prompt,omitempty"` + VersionNumber int `gorm:"not null;uniqueIndex:idx_prompt_version" json:"version_number"` + CommitMessage string `gorm:"type:text" json:"commit_message"` + ModelParamsJSON *string `gorm:"type:text;column:model_params_json" json:"-"` + ModelParams ModelParams `gorm:"-" json:"model_params"` + Provider string `gorm:"type:varchar(100)" json:"provider"` + Model string `gorm:"type:varchar(100)" json:"model"` + VariablesJSON *string `gorm:"type:text;column:variables_json" json:"-"` + Variables PromptVariables `gorm:"-" json:"variables,omitempty"` // {key: value} map for Jinja2 variables + IsLatest bool `gorm:"not null;default:false" json:"is_latest"` + CreatedAt time.Time `gorm:"not null" json:"created_at"` // No UpdatedAt - versions are immutable // Relationships @@ -36,6 +38,10 @@ func (TablePromptVersion) TableName() string { return "prompt_versions" } // so that any provider-specific params (response_format, seed, logprobs, etc.) are preserved. 
type ModelParams map[string]interface{} +// PromptVariables represents a map of Jinja2 variable names to their values. +// Sessions store full {key: value} pairs; versions store {key: ""} (keys only). +type PromptVariables map[string]string + // BeforeSave GORM hook to serialize JSON fields func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error { if v.ModelParams != nil { @@ -46,6 +52,14 @@ func (v *TablePromptVersion) BeforeSave(tx *gorm.DB) error { paramsStr := string(data) v.ModelParamsJSON = &paramsStr } + if v.Variables != nil { + varsData, err := json.Marshal(v.Variables) + if err != nil { + return err + } + varsStr := string(varsData) + v.VariablesJSON = &varsStr + } return nil } @@ -58,6 +72,11 @@ func (v *TablePromptVersion) AfterFind(tx *gorm.DB) error { return err } } + if v.VariablesJSON != nil && *v.VariablesJSON != "" { + if err := json.Unmarshal([]byte(*v.VariablesJSON), &v.Variables); err != nil { + return err + } + } return nil } diff --git a/framework/routing/routing.go b/framework/routing/routing.go new file mode 100644 index 0000000000..c29dd25673 --- /dev/null +++ b/framework/routing/routing.go @@ -0,0 +1,66 @@ +package routing + +import ( + "fmt" + "regexp" + "strings" +) + +// headerKeyPattern matches header map access patterns like headers["X-Api-Key"] or headers['X-Api-Key'] +var headerKeyPattern = regexp.MustCompile(`headers\[["']([^"']+)["']\]`) + +// headerInPattern matches "in headers" membership test patterns like "X-Api-Key" in headers or 'X-Api-Key' in headers +var headerInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+headers`) + +// paramKeyPattern matches param map access patterns like params["Region"] or params['Region'] +var paramKeyPattern = regexp.MustCompile(`params\[["']([^"']+)["']\]`) + +// paramInPattern matches "in params" membership test patterns like "Region" in params or 'Region' in params +var paramInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+params`) + +// NormalizeMapKeysInCEL lowercases header 
and param keys in CEL expressions +// so that headers["X-Api-Key"] becomes headers["x-api-key"], "X-Api-Key" in headers becomes "x-api-key" in headers, +// params["Region"] becomes params["region"], and "Region" in params becomes "region" in params. +// This ensures CEL expressions match against the normalized (lowercase) map keys at runtime. +func NormalizeMapKeysInCEL(expr string) string { + toLower := func(match string) string { + return strings.ToLower(match) + } + // Normalize bracket access + expr = headerKeyPattern.ReplaceAllStringFunc(expr, toLower) + expr = paramKeyPattern.ReplaceAllStringFunc(expr, toLower) + // Normalize "in" membership test + expr = headerInPattern.ReplaceAllStringFunc(expr, toLower) + expr = paramInPattern.ReplaceAllStringFunc(expr, toLower) + return expr +} + +// ValidateCELExpression performs basic validation on CEL expression format +func ValidateCELExpression(expr string) error { + normalized := strings.TrimSpace(expr) + if normalized == "" || normalized == "true" || normalized == "false" { + return nil // Empty, true, or false are valid + } + + // List of allowed operators and keywords + validPatterns := []string{ + "==", "!=", "&&", "||", ">", "<", ">=", "<=", + "in ", "matches ", ".startsWith(", ".contains(", ".endsWith(", + "[", "]", "(", ")", "!", + } + + // Check if expression contains at least one valid operator + hasPattern := false + for _, pattern := range validPatterns { + if strings.Contains(normalized, pattern) { + hasPattern = true + break + } + } + + if !hasPattern { + return fmt.Errorf("expression must contain at least one operator: %s", expr) + } + + return nil +} diff --git a/plugins/governance/routing.go b/plugins/governance/routing.go index b32044be02..8e20d5ac48 100644 --- a/plugins/governance/routing.go +++ b/plugins/governance/routing.go @@ -3,7 +3,6 @@ package governance import ( "fmt" "math/rand/v2" - "regexp" "strings" "github.com/google/cel-go/cel" @@ -14,18 +13,6 @@ import ( // 
DefaultRoutingChainMaxDepth is the default maximum depth for routing rule chain evaluation. const DefaultRoutingChainMaxDepth = 10 -// headerKeyPattern matches header map access patterns like headers["X-Api-Key"] or headers['X-Api-Key'] -var headerKeyPattern = regexp.MustCompile(`headers\[["']([^"']+)["']\]`) - -// headerInPattern matches "in headers" membership test patterns like "X-Api-Key" in headers or 'X-Api-Key' in headers -var headerInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+headers`) - -// paramKeyPattern matches param map access patterns like params["Region"] or params['Region'] -var paramKeyPattern = regexp.MustCompile(`params\[["']([^"']+)["']\]`) - -// paramInPattern matches "in params" membership test patterns like "Region" in params or 'Region' in params -var paramInPattern = regexp.MustCompile(`["']([^"']+)["']\s+in\s+params`) - // ScopeLevel represents a level in the scope precedence hierarchy type ScopeLevel struct { ScopeName string // "virtual_key", "team", "customer", or "global" @@ -482,52 +469,6 @@ func scopeChainToStrings(chain []ScopeLevel) []string { return scopes } -// validateCELExpression performs basic validation on CEL expression format -func validateCELExpression(expr string) error { - if expr == "" || expr == "true" || expr == "false" { - return nil // Empty, true, or false are valid - } - - // List of allowed operators and keywords - validPatterns := []string{ - "==", "!=", "&&", "||", ">", "<", ">=", "<=", - "in ", "matches ", ".startsWith(", ".contains(", ".endsWith(", - "[", "]", "(", ")", "!", - } - - // Check if expression contains at least one valid operator - hasPattern := false - for _, pattern := range validPatterns { - if strings.Contains(expr, pattern) { - hasPattern = true - break - } - } - - if !hasPattern { - return fmt.Errorf("expression must contain at least one operator: %s", expr) - } - - return nil -} - -// normalizeMapKeysInCEL lowercases header and param keys in CEL expressions -// so that 
headers["X-Api-Key"] becomes headers["x-api-key"], "X-Api-Key" in headers becomes "x-api-key" in headers, -// params["Region"] becomes params["region"], and "Region" in params becomes "region" in params. -// This ensures CEL expressions match against the normalized (lowercase) map keys at runtime. -func normalizeMapKeysInCEL(expr string) string { - toLower := func(match string) string { - return strings.ToLower(match) - } - // Normalize bracket access - expr = headerKeyPattern.ReplaceAllStringFunc(expr, toLower) - expr = paramKeyPattern.ReplaceAllStringFunc(expr, toLower) - // Normalize "in" membership test - expr = headerInPattern.ReplaceAllStringFunc(expr, toLower) - expr = paramInPattern.ReplaceAllStringFunc(expr, toLower) - return expr -} - // createCELEnvironment creates a new CEL environment for routing rules func createCELEnvironment() (*cel.Env, error) { return cel.NewEnv( diff --git a/plugins/governance/routing_test.go b/plugins/governance/routing_test.go index 02aaa6a59a..21e8f7469d 100644 --- a/plugins/governance/routing_test.go +++ b/plugins/governance/routing_test.go @@ -11,6 +11,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/framework/routing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1159,7 +1160,7 @@ func TestValidateCELExpression_Valid(t *testing.T) { } for _, expr := range tests { - err := validateCELExpression(expr) + err := routing.ValidateCELExpression(expr) assert.NoError(t, err, "expression should be valid: %s", expr) } } @@ -1173,7 +1174,7 @@ func TestValidateCELExpression_Invalid(t *testing.T) { } for _, expr := range tests { - err := validateCELExpression(expr) + err := routing.ValidateCELExpression(expr) assert.Error(t, err, "expression should be invalid: %s", expr) } } @@ -1733,7 +1734,7 @@ func TestNormalizeMapKeysInCEL(t *testing.T) { 
for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := normalizeMapKeysInCEL(tt.input) + result := routing.NormalizeMapKeysInCEL(tt.input) assert.Equal(t, tt.expected, result) }) } diff --git a/plugins/governance/store.go b/plugins/governance/store.go index 4775c71bd4..495f18944d 100644 --- a/plugins/governance/store.go +++ b/plugins/governance/store.go @@ -14,6 +14,7 @@ import ( "github.com/maximhq/bifrost/framework/configstore" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" "github.com/maximhq/bifrost/framework/modelcatalog" + "github.com/maximhq/bifrost/framework/routing" "gorm.io/gorm" ) @@ -3451,10 +3452,10 @@ func (gs *LocalGovernanceStore) GetRoutingProgram(rule *configstoreTables.TableR } // Normalize header and param keys to lowercase so CEL expressions match normalized map keys - expr = normalizeMapKeysInCEL(expr) + expr = routing.NormalizeMapKeysInCEL(expr) // Validate expression format - if err := validateCELExpression(expr); err != nil { + if err := routing.ValidateCELExpression(expr); err != nil { return nil, fmt.Errorf("invalid CEL expression: %w", err) } diff --git a/plugins/prompts/helpers_test.go b/plugins/prompts/helpers_test.go index ec8f778c7b..594ca5a7b8 100644 --- a/plugins/prompts/helpers_test.go +++ b/plugins/prompts/helpers_test.go @@ -76,7 +76,7 @@ func (l *MockLogger) Warned() bool { } // ============================================================ -// mockStore — satisfies promptStore with controllable responses. +// mockStore — satisfies InMemoryStore with controllable responses. 
// ============================================================ type mockStore struct { @@ -113,14 +113,13 @@ func (s *versionsErrStore) GetAllPromptVersions(_ context.Context) ([]tables.Tab // ============================================================ type staticResolver struct { - promptID string - versionNumber int - versionSpecified bool - err error + promptID string + versionNumber int + err error } -func (r *staticResolver) Resolve(_ *schemas.BifrostContext, _ *schemas.BifrostRequest) (string, int, bool, error) { - return r.promptID, r.versionNumber, r.versionSpecified, r.err +func (r *staticResolver) Resolve(_ *schemas.BifrostContext, _ *schemas.BifrostRequest) (string, int, error) { + return r.promptID, r.versionNumber, r.err } // ============================================================ @@ -129,7 +128,7 @@ func (r *staticResolver) Resolve(_ *schemas.BifrostContext, _ *schemas.BifrostRe // newPluginWithStore builds a Plugin whose store is set but maps are empty. // Use only for loadCache tests. -func newPluginWithStore(s promptStore) *Plugin { +func newPluginWithStore(s InMemoryStore) *Plugin { return &Plugin{ store: s, logger: NewMockLogger(), diff --git a/plugins/prompts/main.go b/plugins/prompts/main.go index 05c2507b9f..e710980dcb 100644 --- a/plugins/prompts/main.go +++ b/plugins/prompts/main.go @@ -1,3 +1,7 @@ +// Package prompts implements the Bifrost LLM plugin that resolves stored prompt templates +// from the config store and prepends their messages to chat and Responses API requests. +// HTTP clients select a prompt via x-bf-prompt-id / x-bf-prompt-version headers; optional +// custom PromptResolver implementations can override how ID and version are chosen. 
package prompts import ( @@ -9,19 +13,28 @@ import ( "strings" "sync" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" ) const ( - PluginName = "prompts" - PromptIDHeader = "bf-prompt-id" - PromptVersionHeader = "bf-prompt-version" - PromptIDKey schemas.BifrostContextKey = PromptIDHeader - PromptVersionKey schemas.BifrostContextKey = PromptVersionHeader + // PluginName is the canonical name registered for the prompts plugin. + PluginName = "prompts" + + // PromptIDHeader and PromptVersionHeader are request headers copied into BifrostContext + // in HTTPTransportPreHook so PreLLMHook and custom resolvers can read them. + PromptIDHeader = "x-bf-prompt-id" + PromptVersionHeader = "x-bf-prompt-version" + + // PromptIDKey and PromptVersionKey are context keys for the resolved header values. + PromptIDKey schemas.BifrostContextKey = PromptIDHeader + PromptVersionKey schemas.BifrostContextKey = PromptVersionHeader ) -type promptStore interface { +// InMemoryStore is the data source for prompts and all versions. Implementations typically +// wrap the framework config store; the plugin keeps an in-memory index built by loadCache. +type InMemoryStore interface { GetPrompts(ctx context.Context, folderID *string) ([]configstoreTables.TablePrompt, error) GetAllPromptVersions(ctx context.Context) ([]configstoreTables.TablePromptVersion, error) } @@ -29,30 +42,43 @@ type promptStore interface { // PromptResolver decides which prompt and version to inject for a given request. // Returning an empty promptID means no injection for this request. 
type PromptResolver interface { - Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, versionSpecified bool, err error) + Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, err error) } -// headerResolver is the default OSS resolver: reads prompt ID and version from context -// keys that were populated from HTTP headers in HTTPTransportPreHook. +// headerResolver is the default OSS resolver: it reads prompt ID and version from context +// keys populated from HTTP headers in HTTPTransportPreHook (x-bf-prompt-id, x-bf-prompt-version). type headerResolver struct { logger schemas.Logger } -func (r *headerResolver) Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (string, int, bool, error) { - promptID := promptStringFromCtx(ctx, PromptIDKey) +// Resolve returns the prompt ID and version number from context. An empty promptID means +// no prompt injection for this request. Version 0 means “use latest” when passed to resolveVersion. +func (r *headerResolver) Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (string, int, error) { + promptID := bifrost.GetStringFromContext(ctx, PromptIDKey) if promptID == "" { - return "", 0, false, nil + return "", 0, nil } - versionNumber, specified, err := parsePromptVersionNumber(ctx) + versionNumber, err := parseNumberFromContext(ctx, PromptVersionKey) if err != nil { - return "", 0, false, fmt.Errorf("invalid bifrost-prompt-version: %w", err) + return "", 0, fmt.Errorf("failed to parse version number: %w", err) } - return promptID, versionNumber, specified, nil + return promptID, versionNumber, nil } -// Plugin resolves stored prompt templates and prepends their messages to LLM requests. +// Plugin implements schemas.LLMPlugin (and HTTP transport hooks) for server-side prompt injection. 
+// It loads prompts and versions into memory, resolves which version to use per request, merges +// the version’s model parameters with the client request (request wins), and prepends template +// messages before chat or Responses input. +// +// Fields: +// - store: backing persistence for prompts and versions +// - logger: Bifrost logger for non-fatal merge/param warnings +// - resolver: chooses prompt ID and version; defaults to headerResolver +// - mu: protects promptsByID and versionsByPromptAndNumber +// - promptsByID: prompt ID → prompt row (includes LatestVersion when using “latest”) +// - versionsByPromptAndNumber: prompt ID → version number → version row type Plugin struct { - store promptStore + store InMemoryStore logger schemas.Logger resolver PromptResolver @@ -61,13 +87,32 @@ type Plugin struct { versionsByPromptAndNumber map[string]map[int]*configstoreTables.TablePromptVersion } -// Init wires the prompts plugin with the default header-based resolver. -func Init(ctx context.Context, store promptStore, logger schemas.Logger) (schemas.LLMPlugin, error) { +// Init constructs a Plugin using the default header-based resolver (x-bf-prompt-id / x-bf-prompt-version). +// +// Parameters: +// - ctx: used for the initial loadCache call +// - store: required config store backend for prompts +// - logger: used by the default resolver and param merge paths +// +// Returns: +// - schemas.LLMPlugin: the initialized plugin +// - error: if the store is missing or the initial cache load fails +func Init(ctx context.Context, store InMemoryStore, logger schemas.Logger) (schemas.LLMPlugin, error) { return InitWithResolver(ctx, store, &headerResolver{logger: logger}, logger) } -// InitWithResolver wires the prompts plugin with a custom resolver. 
-func InitWithResolver(ctx context.Context, store promptStore, resolver PromptResolver, logger schemas.Logger) (*Plugin, error) { +// InitWithResolver constructs a Plugin with an explicit PromptResolver (nil falls back to headerResolver). +// +// Parameters: +// - ctx: used for the initial loadCache call +// - store: required config store backend for prompts +// - resolver: custom resolution logic; if nil, headerResolver is used +// - logger: passed to the default resolver when it is constructed internally +// +// Returns: +// - *Plugin: the initialized plugin (concrete type for Reload and handler integration) +// - error: if the store is missing or the initial cache load fails +func InitWithResolver(ctx context.Context, store InMemoryStore, resolver PromptResolver, logger schemas.Logger) (*Plugin, error) { if store == nil { return nil, fmt.Errorf("config store is required for prompts plugin") } @@ -127,10 +172,13 @@ func (p *Plugin) Reload(ctx context.Context) error { return p.loadCache(ctx) } +// GetName returns the plugin identifier ("prompts"). func (p *Plugin) GetName() string { return PluginName } +// HTTPTransportPreHook copies x-bf-prompt-id and x-bf-prompt-version from the incoming HTTP request +// into BifrostContext so the default header resolver and PreLLMHook can read them. func (p *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { if req == nil { return nil, nil @@ -141,66 +189,41 @@ func (p *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas. if v := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptVersionHeader)); v != "" { ctx.SetValue(PromptVersionKey, v) } - p.setPromptStreamFromVersionForTransport(ctx) return nil, nil } -// setPromptStreamFromVersionForTransport sets BifrostContextKeyPromptStreamRequest when -// the resolved prompt version has stream:true in its ModelParams. 
-func (p *Plugin) setPromptStreamFromVersionForTransport(ctx *schemas.BifrostContext) { - promptID := promptStringFromCtx(ctx, PromptIDKey) - if promptID == "" { - return - } - versionNumber, versionSpecified, err := parsePromptVersionNumber(ctx) - if err != nil { - return - } - _, version, ok := p.resolveVersion(promptID, versionNumber, versionSpecified) - if !ok || version == nil || len(version.ModelParams) == 0 { - return - } - if includesStreamInModelParams(version.ModelParams) { - ctx.SetValue(schemas.BifrostContextKeyPromptStreamRequest, true) - } -} - -func includesStreamInModelParams(mp configstoreTables.ModelParams) bool { - raw, ok := mp["stream"] - if !ok { - return true // default to true if stream is not set, this is done because for the initial version, the stream key is not present but we default to true for the initial version and show it as well on the UI. If the user toggles stream off, we set `stream: false` in the model params in db. - } - switch v := raw.(type) { - case bool: - return v - case json.Number: - if i, err := strconv.ParseInt(string(v), 10, 64); err == nil { - return i != 0 - } - b, err := strconv.ParseBool(string(v)) - return err == nil && b - case string: - switch strings.ToLower(strings.TrimSpace(v)) { - case "true", "1", "yes": - return true - default: - return false - } - default: - return false - } -} - +// HTTPTransportPostHook is a no-op; this plugin does not modify HTTP response headers. func (p *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { return nil } +// HTTPTransportStreamChunkHook passes streaming chunks through unchanged; prompt injection +// happens in PreLLMHook before the provider call. 
func (p *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { return chunk, nil } +// PreLLMHook resolves the prompt via PromptResolver, loads the version from the in-memory +// cache, sets governance/observability context (selected prompt name and version), merges +// version ModelParams with the request (request overrides), converts stored messages to +// chat messages, and prepends them to Chat or Responses input. Non-HTTP transports rely +// on context keys set by callers instead of HTTPTransportPreHook. +// +// Parameters: +// - ctx: may set BifrostContextKeySelectedPromptName, BifrostContextKeySelectedPromptID and BifrostContextKeySelectedPromptVersion when a prompt is applied +// - req: chat or Responses request to mutate in place +// +// Returns: +// - *schemas.BifrostRequest: possibly modified request +// - *schemas.LLMPluginShortCircuit: always nil +// - error: resolution failure or missing prompt/version; invalid or empty template returns +// the request unchanged with a nil error func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { - promptID, versionNumber, versionSpecified, err := p.resolver.Resolve(ctx, req) + if req == nil { + return req, nil, nil + } + + promptID, versionNumber, err := p.resolver.Resolve(ctx, req) if err != nil { p.logger.Warn("prompts plugin: failed to resolve prompt: %v", err) return req, nil, nil @@ -209,17 +232,23 @@ func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostReq return req, nil, nil } - _, version, found := p.resolveVersion(promptID, versionNumber, versionSpecified) + prompt, version, found := p.resolveVersion(promptID, versionNumber) if !found { - p.logger.Warn("prompts plugin: prompt or version not found: %s", promptID) + p.logger.Warn("prompts plugin: prompt or version not found: 
promptID=%s versionNumber=%d", promptID, versionNumber) return req, nil, nil } if version == nil { - p.logger.Warn("prompts plugin: prompt %s has no versions", promptID) + p.logger.Warn("prompts plugin: prompt has no resolved version: promptID=%s", promptID) return req, nil, nil } + if prompt != nil && prompt.Name != "" { + ctx.SetValue(schemas.BifrostContextKeySelectedPromptID, prompt.ID) + ctx.SetValue(schemas.BifrostContextKeySelectedPromptName, prompt.Name) + } + ctx.SetValue(schemas.BifrostContextKeySelectedPromptVersion, strconv.Itoa(version.VersionNumber)) + // Apply model params from the version (version params are defaults; request params win). switch { case req.ChatRequest != nil: @@ -230,10 +259,11 @@ func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostReq template, err := chatMessagesFromVersionMessages(version.Messages) if err != nil { - p.logger.Warn("prompts plugin: failed to parse messages for prompt %s: %v", promptID, err) + p.logger.Warn("prompts plugin: failed to convert version messages to chat messages: %v", err) return req, nil, nil } if len(template) == 0 { + p.logger.Warn("prompts plugin: no template messages found for prompt %s version %d", promptID, version.VersionNumber) return req, nil, nil } @@ -247,6 +277,7 @@ func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostReq return req, nil, nil } +// PostLLMHook is a no-op; the plugin does not modify responses. func (p *Plugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { return resp, bifrostErr, nil } @@ -407,9 +438,12 @@ func applyVersionParamsToResponsesRequest(version *configstoreTables.TablePrompt } // resolveVersion centralises the map-lookup logic shared by setPromptStreamFromVersionForTransport -// and PreLLMHook. 
It returns the prompt and its resolved version (either the explicitly requested -// version or the prompt's latest version), plus a bool indicating whether both were found. -func (p *Plugin) resolveVersion(promptID string, versionNumber int, versionSpecified bool) ( +// and PreLLMHook. It returns the prompt and its resolved version. +// +// If versionNumber > 0, that explicit version is loaded from versionsByPromptAndNumber (from +// x-bf-prompt-version header or a custom PromptResolver such as deployment traffic routing). +// If versionNumber == 0, the prompt's latest version is used (no header / resolver chose latest). +func (p *Plugin) resolveVersion(promptID string, versionNumber int) ( *configstoreTables.TablePrompt, *configstoreTables.TablePromptVersion, bool, ) { p.mu.RLock() @@ -419,47 +453,44 @@ func (p *Plugin) resolveVersion(promptID string, versionNumber int, versionSpeci if !ok || prompt == nil { return nil, nil, false } - if !versionSpecified { - return prompt, prompt.LatestVersion, true - } - byNumber, ok := p.versionsByPromptAndNumber[promptID] - if !ok { - return nil, nil, false - } - v, found := byNumber[versionNumber] - if !found || v == nil { - return nil, nil, false + if versionNumber > 0 { + byNumber, ok := p.versionsByPromptAndNumber[promptID] + if !ok { + return nil, nil, false + } + v, found := byNumber[versionNumber] + if !found || v == nil { + return nil, nil, false + } + return prompt, v, true } - return prompt, v, true + return prompt, prompt.LatestVersion, true } +// Cleanup releases plugin resources; the prompts plugin has nothing to tear down. 
func (p *Plugin) Cleanup() error { return nil } -func promptStringFromCtx(ctx *schemas.BifrostContext, key schemas.BifrostContextKey) string { - if v, ok := ctx.Value(key).(string); ok { - return strings.TrimSpace(v) - } - return "" -} - -func parsePromptVersionNumber(ctx *schemas.BifrostContext) (num int, specified bool, err error) { - s, ok := ctx.Value(PromptVersionKey).(string) +// parseNumberFromContext parses a decimal integer from a string context value. Missing or +// empty values yield 0 with no error (treated as “no explicit version”). +func parseNumberFromContext(ctx *schemas.BifrostContext, key schemas.BifrostContextKey) (num int, err error) { + s, ok := ctx.Value(key).(string) if !ok { - return 0, false, nil + return 0, nil } s = strings.TrimSpace(s) if s == "" { - return 0, false, nil + return 0, nil } n, err := strconv.ParseInt(s, 10, 64) if err != nil { - return 0, true, err + return 0, err } - return int(n), true, nil + return int(n), nil } +// chatMessagePopulated reports whether a ChatMessage carries any meaningful content for injection. func chatMessagePopulated(cm schemas.ChatMessage) bool { if strings.TrimSpace(string(cm.Role)) != "" { return true @@ -526,6 +557,8 @@ func convertVersionMessagesToChatMessages(data []byte) (schemas.ChatMessage, err return chatMessage, nil } +// chatMessagesFromVersionMessages decodes each stored row into schemas.ChatMessage, preferring +// Message bytes and falling back to MessageJSON when needed. func chatMessagesFromVersionMessages(messages []configstoreTables.TablePromptVersionMessage) ([]schemas.ChatMessage, error) { out := make([]schemas.ChatMessage, 0, len(messages)) for i := range messages { @@ -543,6 +576,7 @@ func chatMessagesFromVersionMessages(messages []configstoreTables.TablePromptVer return out, nil } +// mergeChatMessages prepends prefix to the chat input slice (template first, then client messages). 
func mergeChatMessages(dest *[]schemas.ChatMessage, prefix []schemas.ChatMessage) { if dest == nil || len(prefix) == 0 { return @@ -554,6 +588,8 @@ func mergeChatMessages(dest *[]schemas.ChatMessage, prefix []schemas.ChatMessage *dest = merged } +// mergeResponsesMessages converts template chat messages to ResponsesMessage entries and +// prepends them before the client’s Responses input. func mergeResponsesMessages(dest *[]schemas.ResponsesMessage, template []schemas.ChatMessage) { if dest == nil || len(template) == 0 { return diff --git a/plugins/prompts/plugin_test.go b/plugins/prompts/plugin_test.go index 554174705c..df4f88e22e 100644 --- a/plugins/prompts/plugin_test.go +++ b/plugins/prompts/plugin_test.go @@ -103,7 +103,7 @@ func TestPreLLMHook_UseLatestVersion(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -129,7 +129,7 @@ func TestPreLLMHook_UseSpecificVersion(t *testing.T) { prompt := makePrompt("p1", &vLatest) p := newTestPlugin( - &staticResolver{promptID: "p1", versionNumber: 2, versionSpecified: true}, + &staticResolver{promptID: "p1", versionNumber: 2}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}}, ) @@ -149,7 +149,7 @@ func TestPreLLMHook_VersionNotFound(t *testing.T) { log := NewMockLogger() p := newTestPluginWithLogger( - &staticResolver{promptID: "p1", versionNumber: 99, versionSpecified: true}, + &staticResolver{promptID: "p1", versionNumber: 99}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, log, @@ -167,7 +167,7 @@ func TestPreLLMHook_VersionBelongsToDifferentPrompt(t *testing.T) { log := NewMockLogger() p := newTestPluginWithLogger( - &staticResolver{promptID: "p1", 
versionNumber: 1, versionSpecified: true}, + &staticResolver{promptID: "p1", versionNumber: 1}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p2": {1: &v}}, log, @@ -184,7 +184,7 @@ func TestPreLLMHook_NoLatestVersion(t *testing.T) { log := NewMockLogger() p := newTestPluginWithLogger( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, nil, log, @@ -201,7 +201,7 @@ func TestPreLLMHook_EmptyTemplate(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -220,7 +220,7 @@ func TestPreLLMHook_MultipleTemplateMessages(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -266,7 +266,7 @@ func TestPreLLMHook_MessageJSON_FallbackPath(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -286,7 +286,7 @@ func TestPreLLMHook_ResponsesRequest(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -318,7 +318,7 @@ func TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", 
versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -355,7 +355,7 @@ func TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -404,7 +404,7 @@ func TestPreLLMHook_MultipleSystemMessages_InPromptTemplate(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -488,9 +488,9 @@ func TestHTTPTransportPreHook_WhitespaceOnlyNotSet(t *testing.T) { func TestHTTPTransportPreHook_CaseInsensitiveHeaders(t *testing.T) { p := newTestPlugin(nil, nil, nil) ctx := bfCtx() - // "Bf-Prompt-Id" is a title-case variant of the canonical "bf-prompt-id". + // "X-Bf-Prompt-Id" is a title-case variant of the canonical "x-bf-prompt-id". 
req := &schemas.HTTPRequest{ - Headers: map[string]string{"Bf-Prompt-Id": "upper-case"}, + Headers: map[string]string{"X-Bf-Prompt-Id": "upper-case"}, } _, _ = p.HTTPTransportPreHook(ctx, req) @@ -635,14 +635,13 @@ func TestParsePromptVersionNumber(t *testing.T) { ctx.SetValue(PromptVersionKey, tt.value) } - num, specified, err := parsePromptVersionNumber(ctx) + num, err := parseNumberFromContext(ctx, PromptVersionKey) if tt.want.wantErr { require.Error(t, err) return } require.NoError(t, err) - assert.Equal(t, tt.want.specified, specified) assert.Equal(t, tt.want.num, num) }) } @@ -724,138 +723,6 @@ func TestChatMessagesFromVersionMessages_InvalidJSON(t *testing.T) { require.Error(t, err) } -// ============================================================ -// loadCache + PreLLMHook integration (store → cache → injection) -// ============================================================ - -// ============================================================ -// includesStreamInModelParams -// ============================================================ - -func TestIncludesStreamInModelParams(t *testing.T) { - tests := []struct { - name string - params tables.ModelParams - want bool - }{ - {"bool true", tables.ModelParams{"stream": true}, true}, - {"bool false", tables.ModelParams{"stream": false}, false}, - {"string true", tables.ModelParams{"stream": "true"}, true}, - {"string yes", tables.ModelParams{"stream": "yes"}, true}, - {"string 1", tables.ModelParams{"stream": "1"}, true}, - {"string false", tables.ModelParams{"stream": "false"}, false}, - {"string 0", tables.ModelParams{"stream": "0"}, false}, - {"absent key", tables.ModelParams{"temperature": 0.7}, true}, - {"empty params", tables.ModelParams{}, true}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equal(t, tt.want, includesStreamInModelParams(tt.params)) - }) - } -} - -// ============================================================ -// HTTPTransportPreHook — stream routing via 
version ModelParams -// ============================================================ - -// TestHTTPTransportPreHook_StreamTrue_SetsStreamContext verifies that when the -// resolved version has stream:true in ModelParams, the hook marks the bifrost -// context so that the inference handler opens an SSE response. -func TestHTTPTransportPreHook_StreamTrue_SetsStreamContext(t *testing.T) { - v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": true}) - prompt := makePrompt("p1", &v) - - p := newTestPlugin( - nil, - map[string]*tables.TablePrompt{"p1": &prompt}, - map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, - ) - - ctx := bfCtx() - req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} - - _, err := p.HTTPTransportPreHook(ctx, req) - require.NoError(t, err) - - streamVal, _ := ctx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool) - assert.True(t, streamVal, "expected BifrostContextKeyPromptStreamRequest=true when version has stream:true") -} - -// TestHTTPTransportPreHook_StreamFalse_NoStreamContext verifies that stream:false -// in ModelParams does NOT set the stream context key. -func TestHTTPTransportPreHook_StreamFalse_NoStreamContext(t *testing.T) { - v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": false}) - prompt := makePrompt("p1", &v) - - p := newTestPlugin( - nil, - map[string]*tables.TablePrompt{"p1": &prompt}, - map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, - ) - - ctx := bfCtx() - req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} - - _, err := p.HTTPTransportPreHook(ctx, req) - require.NoError(t, err) - - assert.Nil(t, ctx.Value(schemas.BifrostContextKeyPromptStreamRequest), - "expected BifrostContextKeyPromptStreamRequest not set when version has stream:false") -} - -// TestHTTPTransportPreHook_NoStreamParam_NoStreamContext verifies that when no -// "stream" key is present in ModelParams, the stream context key is not set. 
-func TestHTTPTransportPreHook_NoStreamParam_NoStreamContext(t *testing.T) { - v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"temperature": float64(0.7)}) - prompt := makePrompt("p1", &v) - - p := newTestPlugin( - nil, - map[string]*tables.TablePrompt{"p1": &prompt}, - map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, - ) - - ctx := bfCtx() - req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} - - _, err := p.HTTPTransportPreHook(ctx, req) - require.NoError(t, err) - - assert.Equal(t, true, ctx.Value(schemas.BifrostContextKeyPromptStreamRequest), - "expected BifrostContextKeyPromptStreamRequest to default to true when no stream key in params") -} - -// TestHTTPTransportPreHook_SpecificVersion_StreamTrue_SetsStreamContext verifies -// that when a specific (non-latest) version is requested via header and that -// version has stream:true, the stream context key is set — even if the latest -// version has stream:false. -func TestHTTPTransportPreHook_SpecificVersion_StreamTrue_SetsStreamContext(t *testing.T) { - vLatest := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": false}) - vOld := makeVersionWithParams(2, "p1", false, tables.ModelParams{"stream": true}) - prompt := makePrompt("p1", &vLatest) - - p := newTestPlugin( - nil, - map[string]*tables.TablePrompt{"p1": &prompt}, - map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}}, - ) - - ctx := bfCtx() - req := &schemas.HTTPRequest{ - Headers: map[string]string{ - PromptIDHeader: "p1", - PromptVersionHeader: "2", - }, - } - - _, err := p.HTTPTransportPreHook(ctx, req) - require.NoError(t, err) - - streamVal, _ := ctx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool) - assert.True(t, streamVal, "expected stream=true from explicitly requested version with stream:true") -} - // ============================================================ // PreLLMHook — model params merge and override // 
============================================================ @@ -871,7 +738,7 @@ func TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -894,7 +761,7 @@ func TestPreLLMHook_RequestParamsOverrideVersionParams(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -924,7 +791,7 @@ func TestPreLLMHook_RequestParamsPartialOverride(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -960,7 +827,7 @@ func TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel(t *testing. 
prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -995,7 +862,7 @@ func TestLoadCacheAndPreLLMHook_EndToEnd(t *testing.T) { p := newPluginWithStore(ms) require.NoError(t, p.loadCache(context.Background())) - p.resolver = &staticResolver{promptID: "p1", versionSpecified: false} + p.resolver = &staticResolver{promptID: "p1"} out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("user msg"))) require.NoError(t, err) @@ -1026,7 +893,7 @@ func TestLoadCacheAndPreLLMHook_SpecificVersion(t *testing.T) { p := newPluginWithStore(ms) require.NoError(t, p.loadCache(context.Background())) - p.resolver = &staticResolver{promptID: "p1", versionNumber: 2, versionSpecified: true} + p.resolver = &staticResolver{promptID: "p1", versionNumber: 2} out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("question"))) require.NoError(t, err) @@ -1045,7 +912,7 @@ func TestPreLLMHook_AssistantMessage_UIFormat(t *testing.T) { prompt := makePrompt("p1", &v) p := newTestPlugin( - &staticResolver{promptID: "p1", versionSpecified: false}, + &staticResolver{promptID: "p1"}, map[string]*tables.TablePrompt{"p1": &prompt}, map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, ) @@ -1062,4 +929,4 @@ func TestPreLLMHook_AssistantMessage_UIFormat(t *testing.T) { assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role) assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2])) -} +} \ No newline at end of file diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index a813bb2702..eb3c4bccb2 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -544,13 +544,10 @@ func parseFallbacks(fallbackStrings []string) ([]schemas.Fallback, error) { return 
fallbacks, nil } -func effectiveStream(bodyStream *bool, bifrostCtx *schemas.BifrostContext) bool { +func effectiveStream(bodyStream *bool) bool { if bodyStream != nil { return *bodyStream } - if v, ok := bifrostCtx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool); ok && v { - return true - } return false } @@ -973,7 +970,7 @@ func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return } - if effectiveStream(req.Stream, bifrostCtx) { + if effectiveStream(req.Stream) { h.handleStreamingChatCompletion(ctx, bifrostChatReq, bifrostCtx, cancel) return } @@ -1068,7 +1065,7 @@ func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { return } - if effectiveStream(req.Stream, bifrostCtx) { + if effectiveStream(req.Stream) { h.handleStreamingResponses(ctx, bifrostResponsesReq, bifrostCtx, cancel) return } diff --git a/transports/bifrost-http/handlers/prompts.go b/transports/bifrost-http/handlers/prompts.go index e5b96f0c38..c5e5737f0f 100644 --- a/transports/bifrost-http/handlers/prompts.go +++ b/transports/bifrost-http/handlers/prompts.go @@ -137,25 +137,28 @@ type CreateVersionRequest struct { ModelParams tables.ModelParams `json:"model_params"` Provider string `json:"provider"` Model string `json:"model"` + Variables tables.PromptVariables `json:"variables,omitempty"` } // CreateSessionRequest represents the request body for creating a session type CreateSessionRequest struct { - Name string `json:"name"` - VersionID *uint `json:"version_id,omitempty"` - Messages []tables.PromptMessage `json:"messages,omitempty"` - ModelParams tables.ModelParams `json:"model_params"` - Provider string `json:"provider"` - Model string `json:"model"` + Name string `json:"name"` + VersionID *uint `json:"version_id,omitempty"` + Messages []tables.PromptMessage `json:"messages,omitempty"` + ModelParams tables.ModelParams `json:"model_params"` + Provider string `json:"provider"` + Model 
string `json:"model"` + Variables tables.PromptVariables `json:"variables,omitempty"` } // UpdateSessionRequest represents the request body for updating a session type UpdateSessionRequest struct { - Name string `json:"name"` - Messages []tables.PromptMessage `json:"messages"` - ModelParams tables.ModelParams `json:"model_params"` - Provider string `json:"provider"` - Model string `json:"model"` + Name string `json:"name"` + Messages []tables.PromptMessage `json:"messages"` + ModelParams tables.ModelParams `json:"model_params"` + Provider string `json:"provider"` + Model string `json:"model"` + Variables tables.PromptVariables `json:"variables,omitempty"` } // RenameSessionRequest represents the request body for renaming a session @@ -631,12 +634,22 @@ func (h *PromptsHandler) createVersion(ctx *fasthttp.RequestCtx) { }) } + // Strip variable values — versions store keys only; values live in sessions + var versionVars tables.PromptVariables + if len(req.Variables) > 0 { + versionVars = make(tables.PromptVariables, len(req.Variables)) + for key := range req.Variables { + versionVars[key] = "" + } + } + version := &tables.TablePromptVersion{ PromptID: promptID, CommitMessage: req.CommitMessage, ModelParams: req.ModelParams, Provider: req.Provider, Model: req.Model, + Variables: versionVars, Messages: messages, } @@ -835,6 +848,7 @@ func (h *PromptsHandler) createSession(ctx *fasthttp.RequestCtx) { ModelParams: req.ModelParams, Provider: req.Provider, Model: req.Model, + Variables: req.Variables, Messages: messages, } @@ -890,6 +904,7 @@ func (h *PromptsHandler) updateSession(ctx *fasthttp.RequestCtx) { session.ModelParams = req.ModelParams session.Provider = req.Provider session.Model = req.Model + session.Variables = req.Variables // Update messages var messages []tables.TablePromptSessionMessage @@ -1066,12 +1081,22 @@ func (h *PromptsHandler) commitSession(ctx *fasthttp.RequestCtx) { return } + // Copy variable keys from session with empty values for the version + 
var versionVars tables.PromptVariables + if len(session.Variables) > 0 { + versionVars = make(tables.PromptVariables, len(session.Variables)) + for key := range session.Variables { + versionVars[key] = "" + } + } + version := &tables.TablePromptVersion{ PromptID: session.PromptID, CommitMessage: req.CommitMessage, ModelParams: session.ModelParams, Provider: session.Provider, Model: session.Model, + Variables: versionVars, Messages: messages, } diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index 031dadd73b..c0137d6e6b 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -163,8 +163,8 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } s.Config.SetPluginOrderInfo(telemetry.PluginName, builtinPlacement, schemas.Ptr(1)) - // 2. Prompts (requires config store for prompt repository) - if s.Config.ConfigStore != nil { + // 2. Prompts (requires config store for prompt repository; disabled in enterprise) + if s.Config.ConfigStore != nil && ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil { s.registerPluginWithStatus(ctx, prompts.PluginName, nil, nil, false) } else { s.markPluginDisabled(prompts.PluginName) diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 9ebca91e88..5c7a3ad679 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -285,6 +285,14 @@ func (s *BifrostHTTPServer) getGovernancePluginName() string { return governance.PluginName } +// getPromptsPluginName returns the prompts plugin name from context or default +func (s *BifrostHTTPServer) getPromptsPluginName() string { + if name, ok := s.Ctx.Value(schemas.BifrostContextKeyPromptsPluginName).(string); ok && name != "" { + return name + } + return prompts.PluginName +} + // getGovernancePlugin safely retrieves the governance plugin with proper locking. 
// It acquires a read lock, finds the plugin, releases the lock, performs type assertion, // and returns the BaseGovernancePlugin implementation or an error. @@ -1071,7 +1079,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser cacheHandler = handlers.NewCacheHandler(semanticCachePlugin) } var promptsReloader handlers.PromptCacheReloader - if promptsPlugin, err := lib.FindPluginAs[*prompts.Plugin](s.Config, prompts.PluginName); err == nil && promptsPlugin != nil { + if promptsPlugin, err := lib.FindPluginAs[handlers.PromptCacheReloader](s.Config, s.getPromptsPluginName()); err == nil && promptsPlugin != nil { promptsReloader = promptsPlugin } // Websocket handler needs to go below UI handler diff --git a/ui/app/_fallbacks/enterprise/components/prompt-deployments/promptDeploymentView.tsx b/ui/app/_fallbacks/enterprise/components/prompt-deployments/promptDeploymentView.tsx index 8f392276e8..628e44439a 100644 --- a/ui/app/_fallbacks/enterprise/components/prompt-deployments/promptDeploymentView.tsx +++ b/ui/app/_fallbacks/enterprise/components/prompt-deployments/promptDeploymentView.tsx @@ -3,10 +3,11 @@ import ContactUsView from "../views/contactUsView"; export default function PromptDeploymentView() { return ( -
+
} + align="top" + className="justify-start gap-3 rounded-md border p-4" + icon={} title="Unlock prompt deployments for better prompt versioning and A/B testing." description="This feature is a part of the Bifrost enterprise license. We would love to know more about your use case and how we can help you." readmeLink="https://docs.getbifrost.ai/enterprise/prompt-deployments" diff --git a/ui/app/globals.css b/ui/app/globals.css index 2a043151fb..96650d22a5 100644 --- a/ui/app/globals.css +++ b/ui/app/globals.css @@ -3,6 +3,7 @@ @source "../app/**/*.tsx"; @source "../node_modules/streamdown/dist/*.js"; +@source "../../../bifrost-enterprise/ui/**/*.tsx"; @custom-variant dark (&:is(.dark *)); diff --git a/ui/app/workspace/mcp-logs/views/bifrost.code-workspace b/ui/app/workspace/mcp-logs/views/bifrost.code-workspace new file mode 100644 index 0000000000..8c5511061b --- /dev/null +++ b/ui/app/workspace/mcp-logs/views/bifrost.code-workspace @@ -0,0 +1,11 @@ +{ + "folders": [ + { + "path": "../../../../.." + }, + { + "path": "../../../../../../bifrost-enterprise" + } + ], + "settings": {} +} \ No newline at end of file diff --git a/ui/app/workspace/prompt-repo/deployments/page.tsx b/ui/app/workspace/prompt-repo/deployments/page.tsx deleted file mode 100644 index 26adfda683..0000000000 --- a/ui/app/workspace/prompt-repo/deployments/page.tsx +++ /dev/null @@ -1,9 +0,0 @@ -import PromptDeploymentView from "@enterprise/components/prompt-deployments/promptDeploymentView"; - -export default function PromptDeploymentsPage() { - return ( -
- -
- ); -} \ No newline at end of file diff --git a/ui/app/workspace/prompt-repo/page.tsx b/ui/app/workspace/prompt-repo/page.tsx new file mode 100644 index 0000000000..935bff8f59 --- /dev/null +++ b/ui/app/workspace/prompt-repo/page.tsx @@ -0,0 +1,12 @@ +"use client"; + +import { PromptProvider } from "@/components/prompts/context"; +import PromptsView from "@/components/prompts/promptsView"; + +export default function PromptRepoPage() { + return ( + + + + ); +} diff --git a/ui/app/workspace/prompt-repo/prompts/page.tsx b/ui/app/workspace/prompt-repo/prompts/page.tsx index fa96bb86f7..b2d6f8a7db 100644 --- a/ui/app/workspace/prompt-repo/prompts/page.tsx +++ b/ui/app/workspace/prompt-repo/prompts/page.tsx @@ -1,10 +1,9 @@ -import { PromptProvider } from "@/components/prompts/context"; -import PromptsView from "@/components/prompts/promptsView"; +"use client"; + +import { redirect, useSearchParams } from "next/navigation"; export default function PromptsPage() { - return ( - - - - ); -} \ No newline at end of file + const searchParams = useSearchParams(); + const queryString = searchParams.toString(); + redirect(`/workspace/prompt-repo${queryString ? 
`?${queryString}` : ""}`); +} diff --git a/ui/app/workspace/routing-rules/components/celBuilder/celRuleBuilder.tsx b/ui/app/workspace/routing-rules/components/celBuilder/celRuleBuilder.tsx index 4660699352..29f034d5d9 100644 --- a/ui/app/workspace/routing-rules/components/celBuilder/celRuleBuilder.tsx +++ b/ui/app/workspace/routing-rules/components/celBuilder/celRuleBuilder.tsx @@ -1,25 +1,16 @@ /** - * CEL Rule Builder Component for Routing Rules - * Visual query builder for creating CEL expressions + * CEL Rule Builder for Routing Rules + * Thin wrapper around the reusable CELRuleBuilder with routing-specific config */ -import { Button } from "@/components/ui/button"; -import { Label } from "@/components/ui/label"; -import { Textarea } from "@/components/ui/textarea"; +"use client"; + +import { CELRuleBuilder as BaseCELRuleBuilder } from "@/components/ui/custom/celBuilder"; import { getRoutingFields } from "@/lib/config/celFieldsRouting"; import { celOperatorsRouting } from "@/lib/config/celOperatorsRouting"; -import { convertRuleGroupToCEL } from "@/lib/utils/celConverterRouting"; -import { useCopyToClipboard } from "@/hooks/useCopyToClipboard"; -import { Check, Copy, Loader2 } from "lucide-react"; -import { useEffect, useMemo, useRef, useState } from "react"; -import { Field, QueryBuilder, RuleGroupType } from "react-querybuilder"; -import "react-querybuilder/dist/query-builder.css"; -import { ActionButton } from "./actionButton"; -import { CombinatorSelector } from "./combinatorSelector"; -import { FieldSelector } from "./fieldSelector"; -import { OperatorSelector } from "./operatorSelector"; -import { QueryBuilderWrapper } from "./queryBuilderWrapper"; -import { ValueEditor } from "./valueEditor"; +import { convertRuleGroupToCEL, validateRegexPattern } from "@/lib/utils/celConverterRouting"; +import { useMemo } from "react"; +import { RuleGroupType } from "react-querybuilder"; interface CELRuleBuilderProps { onChange?: (celExpression: string, query: 
RuleGroupType) => void; @@ -30,11 +21,6 @@ interface CELRuleBuilderProps { isLoading?: boolean; } -const defaultQuery: RuleGroupType = { - combinator: "and", - rules: [], -}; - export function CELRuleBuilder({ onChange, initialQuery, @@ -43,96 +29,18 @@ export function CELRuleBuilder({ isLoading = false, allowCustomModels = false, }: CELRuleBuilderProps) { - const [query, setQuery] = useState(initialQuery || defaultQuery); - const [celExpression, setCelExpression] = useState(""); - const { copy, copied } = useCopyToClipboard(); - const onChangeRef = useRef(onChange); - - // Keep ref updated so the query effect always invokes the latest callback - useEffect(() => { - onChangeRef.current = onChange; - }, [onChange]); - - // Generate fields with dynamic providers and models - const fields = useMemo(() => { - const celFields = getRoutingFields(providers, models); - return celFields.map((field) => ({ - ...field, - value: field.name, - })) as Field[]; - }, [providers, models]); - - useEffect(() => { - const expression = convertRuleGroupToCEL(query); - setCelExpression(expression); - onChangeRef.current?.(expression, query); - }, [query]); - - const handleCopy = () => copy(celExpression); - - // Show loading state - if (isLoading) { - return ( -
- - Loading CEL builder... -
- ); - } + const fields = useMemo(() => getRoutingFields(providers, models), [providers, models]); return ( -
-
-
- - ({ - name: op.name, - label: op.label, - }))} - controlElements={{ - fieldSelector: FieldSelector, - operatorSelector: OperatorSelector, - valueEditor: ValueEditor, - addRuleAction: ActionButton, - addGroupAction: ActionButton, - removeRuleAction: ActionButton, - removeGroupAction: ActionButton, - combinatorSelector: CombinatorSelector, - }} - translations={{ - addRule: { label: "Add Rule" }, - addGroup: { label: "Add Rule Group" }, - }} - /> - -
-
- -
-
- - -
-