diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 4f8c0f0fe9..34f120432c 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -223,6 +223,7 @@ const ( BifrostContextKeyHTTPRequestType BifrostContextKey = "bifrost-http-request-type" // RequestType (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.) + BifrostContextKeyPromptStreamRequest BifrostContextKey = "bifrost-prompt-stream-request" // bool (set by prompts HTTP plugin when prompt version model_params.stream is true and body omitted stream) BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) diff --git a/framework/configstore/prompts.go b/framework/configstore/prompts.go index 18b3638bb8..e760351b95 100644 --- a/framework/configstore/prompts.go +++ b/framework/configstore/prompts.go @@ -30,9 +30,6 @@ func (s *RDBConfigStore) GetFolders(ctx context.Context) ([]tables.TableFolder, if err := s.db.WithContext(ctx). Order("created_at DESC"). Find(&folders).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TableFolder{}, nil - } return nil, err } @@ -147,9 +144,6 @@ func (s *RDBConfigStore) GetPrompts(ctx context.Context, folderID *string) ([]ta } if err := query.Find(&prompts).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TablePrompt{}, nil - } return nil, err } @@ -261,6 +255,18 @@ func (s *RDBConfigStore) DeletePrompt(ctx context.Context, id string) error { // Prompt Repository - Versions // ============================================================================ +// GetAllPromptVersions returns every version across all prompts in a single query. +func (s *RDBConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { + var versions []tables.TablePromptVersion + if err := s.db.WithContext(ctx). + Preload("Messages", func(db *gorm.DB) *gorm.DB { return db.Order("order_index ASC") }). + Order("prompt_id ASC, version_number DESC"). + Find(&versions).Error; err != nil { + return nil, err + } + return versions, nil +} + // GetPromptVersions gets all versions for a prompt func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { var versions []tables.TablePromptVersion @@ -269,9 +275,6 @@ func (s *RDBConfigStore) GetPromptVersions(ctx context.Context, promptID string) Where("prompt_id = ?", promptID). Order("version_number DESC"). Find(&versions).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TablePromptVersion{}, nil - } return nil, err } return versions, nil @@ -416,9 +419,6 @@ func (s *RDBConfigStore) GetPromptSessions(ctx context.Context, promptID string) Where("prompt_id = ?", promptID). Order("created_at DESC"). Find(&sessions).Error; err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return []tables.TablePromptSession{}, nil - } return nil, err } return sessions, nil diff --git a/framework/configstore/store.go b/framework/configstore/store.go index f20db6191a..b0fe5c74c7 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -349,6 +349,7 @@ type ConfigStore interface { DeletePrompt(ctx context.Context, id string) error // Prompt Repository - Versions + GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) GetLatestPromptVersion(ctx context.Context, promptID string) (*tables.TablePromptVersion, error) diff --git a/plugins/prompts/go.mod b/plugins/prompts/go.mod new file mode 100644 index 0000000000..5c5acb4864 --- /dev/null +++ b/plugins/prompts/go.mod @@ -0,0 +1,79 @@ +module github.com/maximhq/bifrost/plugins/prompts + +go 1.26.1 + +require ( + github.com/maximhq/bifrost/core v1.4.13 + github.com/maximhq/bifrost/framework v1.2.32 + github.com/stretchr/testify v1.11.1 +) + +require ( + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect + github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.3 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 // indirect + github.com/aws/aws-sdk-go-v2/config v1.32.11 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.11 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 // indirect + github.com/aws/smithy-go v1.24.2 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.2 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect + github.com/mailru/easyjson v0.9.1 // indirect + github.com/mark3labs/mcp-go v0.43.2 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rs/zerolog v1.34.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.68.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + go.starlark.net v0.0.0-20260102030733-3fee463870c9 // indirect + golang.org/x/arch v0.23.0 // indirect + golang.org/x/crypto v0.49.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/gorm v1.31.1 // indirect +) diff --git a/plugins/prompts/go.sum b/plugins/prompts/go.sum new file mode 100644 index 0000000000..b51fcaafaa --- /dev/null +++ b/plugins/prompts/go.sum @@ -0,0 +1,194 @@ +cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2 h1:yz1bePFlP5Vws5+8ez6T3HWXPmwOK7Yvq8QxDBD3SKY= +github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.2/go.mod h1:Pa9ZNPuoNu/GztvBSKk9J1cDJW6vk/n0zLtV4mgd8N8= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1 h1:WJTmL004Abzc5wDB5VtZG2PJk5ndYDgVacGqfirKxjM= +github.com/AzureAD/microsoft-authentication-extensions-for-go/cache v0.1.1/go.mod h1:tCcJZ0uHAmvjsVYzEFivsRTN00oz5BEsRgQHu5JZ9WE= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 h1:N4lRUXZpZ1KVEUn6hxtco/1d2lgYhNn1fHkkl8WhlyQ= +github.com/aws/aws-sdk-go-v2/config v1.32.11 h1:ftxI5sgz8jZkckuUHXfC/wMUc8u3fG1vQS0plr2F2Zs= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 h1:CjMzUs78RDDv4ROu3JnJn/Ig1r6ZD7/T2DXLLRpejic= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16/go.mod h1:uVW4OLBqbJXSHJYA9svT9BluSvvwbzLQ2Crf6UPzR3c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 h1:DIBqIrJ7hv+e4CmIk2z3pyKT+3B6qVMgRsawHiR3qso= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7/go.mod h1:vLm00xmBke75UmpNvOcZQ/Q30ZFjbczeLFqGx5urmGo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 h1:NSbvS17MlI2lurYgXnCOLvCFX38sBW4eiVER7+kkgsU= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16/go.mod h1:SwT8Tmqd4sA6G1qaGdzWCJN99bUmPGHfRwwq3G5Qb+A= +github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0 h1:SWTxh/EcUCDVqi/0s26V6pVUq0BBG7kx0tDTmF/hCgA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.94.0/go.mod h1:79S2BdqCJpScXZA2y+cpZuocWsjGjJINyXnOsf5DTz8= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.7 h1:Y2cAXlClHsXkkOvWZFXATr34b0hxxloeQu/pAZz2row= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= +github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.2 h1:frqHqw7otoVbk5M8LlE/L7HTnIq2v9RX6EJ48i9AxJk= +github.com/buger/jsonparser v1.1.2/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fasthttp/websocket v1.5.12 h1:e4RGPpWW2HTbL3zV0Y/t7g0ub294LkiuXXUuTOUInlE= +github.com/fasthttp/websocket v1.5.12/go.mod h1:I+liyL7/4moHojiOgUOIKEWm9EIxHqxZChS+aMFltyg= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hajimehoshi/go-mp3 v0.3.4 h1:NUP7pBYH8OguP4diaTZ9wJbUbk3tC0KlfzsEpWmYj68= +github.com/hajimehoshi/go-mp3 v0.3.4/go.mod h1:fRtZraRFcWb0pu7ok0LqyFhCUrPeMsGRSVop0eemFmo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= +github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= +github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/mark3labs/mcp-go v0.43.2 h1:21PUSlWWiSbUPQwXIJ5WKlETixpFpq+WBpbMGDSVy/I= +github.com/mark3labs/mcp-go v0.43.2/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/maximhq/bifrost/core v1.4.13 h1:ECCIbdgLUy+jYRXOVn3E9uYCu3mYCOh7GV4ElVjHKLU= +github.com/maximhq/bifrost/core v1.4.13/go.mod h1:Kc11vnzU8UgwBTJS+TgG8S9vuSnas+T8uYx3xwzFuIA= +github.com/maximhq/bifrost/framework v1.2.32 h1:J8xhYXM/5bOmNmpWP9avQYoPV63bQ6IoKLAl3ZvxHok= +github.com/maximhq/bifrost/framework v1.2.32/go.mod h1:8IegKP+/HGpbl1Kh7TP/CFuENPjQVUpJiuKh/u3IvXk= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= +github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= +github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287 h1:qIQ0tWF9vxGtkJa24bR+2i53WBCz1nW/Pc47oVYauC4= +github.com/savsgio/gotils v0.0.0-20250408102913-196191ec6287/go.mod h1:sM7Mt7uEoCeFSCBM+qBrqvEo+/9vdmj19wzp3yzUhmg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.68.0 h1:v12Nx16iepr8r9ySOwqI+5RBJ/DqTxhOy1HrHoDFnok= +github.com/valyala/fasthttp v1.68.0/go.mod h1:5EXiRfYQAoiO/khu4oU9VISC/eVY6JqmSpPJoHCKsz4= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +go.starlark.net v0.0.0-20260102030733-3fee463870c9 h1:nV1OyvU+0CYrp5eKfQ3rD03TpFYYhH08z31NK1HmtTk= +go.starlark.net v0.0.0-20260102030733-3fee463870c9/go.mod h1:YKMCv9b1WrfWmeqdV5MAuEHWsu5iC+fe6kYl2sQjdI8= +golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg= +golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= diff --git a/plugins/prompts/helpers_test.go b/plugins/prompts/helpers_test.go new file mode 100644 index 0000000000..ec8f778c7b --- /dev/null +++ b/plugins/prompts/helpers_test.go @@ -0,0 +1,385 @@ +package prompts + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/maximhq/bifrost/core/schemas" + tables "github.com/maximhq/bifrost/framework/configstore/tables" +) + +// ============================================================ +// MockLogger — captures log output per level for assertions. +// Follows the same pattern as plugins/governance/test_utils.go. +// ============================================================ + +type MockLogger struct { + mu sync.Mutex + debugs []string + infos []string + warnings []string + errors []string +} + +func NewMockLogger() *MockLogger { + return &MockLogger{ + debugs: make([]string, 0), + infos: make([]string, 0), + warnings: make([]string, 0), + errors: make([]string, 0), + } +} + +func (l *MockLogger) Debug(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.debugs = append(l.debugs, format) +} + +func (l *MockLogger) Info(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.infos = append(l.infos, format) +} + +func (l *MockLogger) Warn(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.warnings = append(l.warnings, format) +} + +func (l *MockLogger) Error(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.errors = append(l.errors, format) +} + +func (l *MockLogger) Fatal(format string, args ...any) { + l.mu.Lock() + defer l.mu.Unlock() + l.errors = append(l.errors, format) +} + +func (l *MockLogger) SetLevel(_ schemas.LogLevel) {} +func (l *MockLogger) SetOutputType(_ schemas.LoggerOutputType) {} +func (l *MockLogger) LogHTTPRequest(_ schemas.LogLevel, _ string) schemas.LogEventBuilder { + return schemas.NoopLogEvent +} + +// Warned returns true if at least one warning was logged. +func (l *MockLogger) Warned() bool { + l.mu.Lock() + defer l.mu.Unlock() + return len(l.warnings) > 0 +} + +// ============================================================ +// mockStore — satisfies promptStore with controllable responses. +// ============================================================ + +type mockStore struct { + prompts []tables.TablePrompt + versions []tables.TablePromptVersion + err error +} + +func (m *mockStore) GetPrompts(_ context.Context, _ *string) ([]tables.TablePrompt, error) { + return m.prompts, m.err +} + +func (m *mockStore) GetAllPromptVersions(_ context.Context) ([]tables.TablePromptVersion, error) { + return m.versions, m.err +} + +// versionsErrStore succeeds on GetPrompts but fails on GetAllPromptVersions. +type versionsErrStore struct { + prompts []tables.TablePrompt + err error +} + +func (s *versionsErrStore) GetPrompts(_ context.Context, _ *string) ([]tables.TablePrompt, error) { + return s.prompts, nil +} + +func (s *versionsErrStore) GetAllPromptVersions(_ context.Context) ([]tables.TablePromptVersion, error) { + return nil, s.err +} + +// ============================================================ +// staticResolver — returns fixed IDs; decouples PreLLMHook +// tests from HTTP header / context mechanics. +// ============================================================ + +type staticResolver struct { + promptID string + versionNumber int + versionSpecified bool + err error +} + +func (r *staticResolver) Resolve(_ *schemas.BifrostContext, _ *schemas.BifrostRequest) (string, int, bool, error) { + return r.promptID, r.versionNumber, r.versionSpecified, r.err +} + +// ============================================================ +// Plugin builders +// ============================================================ + +// newPluginWithStore builds a Plugin whose store is set but maps are empty. +// Use only for loadCache tests. +func newPluginWithStore(s promptStore) *Plugin { + return &Plugin{ + store: s, + logger: NewMockLogger(), + resolver: &staticResolver{}, + promptsByID: make(map[string]*tables.TablePrompt), + versionsByPromptAndNumber: make(map[string]map[int]*tables.TablePromptVersion), + } +} + +// newTestPlugin builds a Plugin with pre-seeded in-memory maps, bypassing Init +// and loadCache entirely. The store is nil — safe as long as no test path calls +// into the store. +func newTestPlugin(resolver PromptResolver, promptMap map[string]*tables.TablePrompt, versionMap map[string]map[int]*tables.TablePromptVersion) *Plugin { + return newTestPluginWithLogger(resolver, promptMap, versionMap, NewMockLogger()) +} + +// newTestPluginWithLogger is like newTestPlugin but accepts a caller-provided logger +// so tests can inspect logged warnings. +func newTestPluginWithLogger(resolver PromptResolver, promptMap map[string]*tables.TablePrompt, versionMap map[string]map[int]*tables.TablePromptVersion, log schemas.Logger) *Plugin { + if resolver == nil { + resolver = &staticResolver{} + } + if promptMap == nil { + promptMap = make(map[string]*tables.TablePrompt) + } + if versionMap == nil { + versionMap = make(map[string]map[int]*tables.TablePromptVersion) + } + return &Plugin{ + store: nil, + logger: log, + resolver: resolver, + promptsByID: promptMap, + versionsByPromptAndNumber: versionMap, + } +} + +// ============================================================ +// Message builders +// ============================================================ + +// versionMsg creates a TablePromptVersionMessage in the production envelope +// format {"payload": }, matching what the frontend writes +// to the DB and what AfterFind populates into the Message field. +func versionMsg(role schemas.ChatMessageRole, text string) tables.TablePromptVersionMessage { + content := text + inner := schemas.ChatMessage{ + Role: role, + Content: &schemas.ChatMessageContent{ContentStr: &content}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsg: marshal inner failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// versionMsgViaJSON creates a TablePromptVersionMessage that has an empty Message +// field but a populated MessageJSON field, exercising the fallback branch in +// chatMessagesFromVersionMessages. +func versionMsgViaJSON(role schemas.ChatMessageRole, text string) tables.TablePromptVersionMessage { + content := text + inner := schemas.ChatMessage{ + Role: role, + Content: &schemas.ChatMessageContent{ContentStr: &content}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgViaJSON: marshal failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: nil, // empty — triggers MessageJSON fallback + MessageJSON: envelope, + } +} + +// makeVersion returns a TablePromptVersion with the supplied messages. +// VersionNumber is set to int(id) so tests can reference versions by their number. +func makeVersion(id uint, promptID string, isLatest bool, msgs ...tables.TablePromptVersionMessage) tables.TablePromptVersion { + return tables.TablePromptVersion{ + ID: id, + PromptID: promptID, + IsLatest: isLatest, + VersionNumber: int(id), + Messages: msgs, + } +} + +// makePrompt returns a TablePrompt, optionally linked to a latest version. +func makePrompt(id string, latest *tables.TablePromptVersion) tables.TablePrompt { + return tables.TablePrompt{ID: id, Name: id, LatestVersion: latest} +} + +// ============================================================ +// Request / context builders +// ============================================================ + +// chatRequest returns a BifrostRequest wrapping a ChatRequest with the given messages. +func chatRequest(msgs ...schemas.ChatMessage) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + ChatRequest: &schemas.BifrostChatRequest{ + Input: append([]schemas.ChatMessage{}, msgs...), + }, + } +} + +// userMsg returns a user-role ChatMessage with plain text content. +func userMsg(text string) schemas.ChatMessage { + t := text + return schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ContentStr: &t}, + } +} + +// systemMsg returns a system-role ChatMessage with plain text content. +func systemMsg(text string) schemas.ChatMessage { + t := text + return schemas.ChatMessage{ + Role: schemas.ChatMessageRoleSystem, + Content: &schemas.ChatMessageContent{ContentStr: &t}, + } +} + +// bfCtx returns a fresh BifrostContext with no deadline. +func bfCtx() *schemas.BifrostContext { + return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) +} + +// versionMsgWithToolCall creates a TablePromptVersionMessage for an assistant +// message that contains a single tool call (role=assistant, tool_calls=[...]). +func versionMsgWithToolCall(callID, funcName, funcArgs string) tables.TablePromptVersionMessage { + name := funcName + id := callID + inner := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + ChatAssistantMessage: &schemas.ChatAssistantMessage{ + ToolCalls: []schemas.ChatAssistantMessageToolCall{ + { + ID: &id, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: &name, + Arguments: funcArgs, + }, + }, + }, + }, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgWithToolCall: marshal failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// versionMsgToolResult creates a TablePromptVersionMessage for a tool-result +// message (role=tool) with the given tool_call_id and result text. +func versionMsgToolResult(callID, result string) tables.TablePromptVersionMessage { + id := callID + inner := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ContentStr: &result}, + ChatToolMessage: &schemas.ChatToolMessage{ToolCallID: &id}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgToolResult: marshal failed: %v", err)) + } + envelope := fmt.Sprintf(`{"payload":%s}`, string(innerJSON)) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// makeVersionWithParams returns a TablePromptVersion with explicit ModelParams and messages. +// VersionNumber is set to int(id) so tests can reference versions by their number. +func makeVersionWithParams(id uint, promptID string, isLatest bool, params tables.ModelParams, msgs ...tables.TablePromptVersionMessage) tables.TablePromptVersion { + return tables.TablePromptVersion{ + ID: id, + PromptID: promptID, + IsLatest: isLatest, + VersionNumber: int(id), + ModelParams: params, + Messages: msgs, + } +} + +// chatRequestWithParams returns a BifrostRequest with Params pre-set. +func chatRequestWithParams(params *schemas.ChatParameters, msgs ...schemas.ChatMessage) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + ChatRequest: &schemas.BifrostChatRequest{ + Input: append([]schemas.ChatMessage{}, msgs...), + Params: params, + }, + } +} + +// chatRequestWithModel returns a BifrostRequest with the Model field pre-set. +func chatRequestWithModel(model string, msgs ...schemas.ChatMessage) *schemas.BifrostRequest { + return &schemas.BifrostRequest{ + ChatRequest: &schemas.BifrostChatRequest{ + Model: model, + Input: append([]schemas.ChatMessage{}, msgs...), + }, + } +} + +// versionMsgAssistantUIFormat creates a TablePromptVersionMessage in the format +// the Bifrost UI writes for assistant (completion_result) messages. +// The message is nested at payload.choices[0].message, matching SerializedMessage. +func versionMsgAssistantUIFormat(text string) tables.TablePromptVersionMessage { + content := text + inner := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: &content}, + } + innerJSON, err := json.Marshal(inner) + if err != nil { + panic(fmt.Sprintf("versionMsgAssistantUIFormat: marshal failed: %v", err)) + } + payload := fmt.Sprintf(`{"id":"resp-1","choices":[{"index":0,"message":%s,"finish_reason":"stop"}]}`, string(innerJSON)) + envelope := fmt.Sprintf(`{"originalType":"completion_result","payload":%s}`, payload) + return tables.TablePromptVersionMessage{ + Message: tables.PromptMessage(envelope), + } +} + +// ============================================================ +// errTest — minimal error type for test use +// ============================================================ + +type errTest string + +func (e errTest) Error() string { return string(e) } + +// ============================================================ +// Assertion helpers +// ============================================================ + +// msgText extracts the ContentStr from a ChatMessage, returning "" if absent. +func msgText(msg schemas.ChatMessage) string { + if msg.Content == nil || msg.Content.ContentStr == nil { + return "" + } + return *msg.Content.ContentStr +} diff --git a/plugins/prompts/main.go b/plugins/prompts/main.go new file mode 100644 index 0000000000..05c2507b9f --- /dev/null +++ b/plugins/prompts/main.go @@ -0,0 +1,570 @@ +package prompts + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "strconv" + "strings" + "sync" + + "github.com/maximhq/bifrost/core/schemas" + configstoreTables "github.com/maximhq/bifrost/framework/configstore/tables" +) + +const ( + PluginName = "prompts" + PromptIDHeader = "bf-prompt-id" + PromptVersionHeader = "bf-prompt-version" + PromptIDKey schemas.BifrostContextKey = PromptIDHeader + PromptVersionKey schemas.BifrostContextKey = PromptVersionHeader +) + +type promptStore interface { + GetPrompts(ctx context.Context, folderID *string) ([]configstoreTables.TablePrompt, error) + GetAllPromptVersions(ctx context.Context) ([]configstoreTables.TablePromptVersion, error) +} + +// PromptResolver decides which prompt and version to inject for a given request. +// Returning an empty promptID means no injection for this request. +type PromptResolver interface { + Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (promptID string, versionNumber int, versionSpecified bool, err error) +} + +// headerResolver is the default OSS resolver: reads prompt ID and version from context +// keys that were populated from HTTP headers in HTTPTransportPreHook. +type headerResolver struct { + logger schemas.Logger +} + +func (r *headerResolver) Resolve(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (string, int, bool, error) { + promptID := promptStringFromCtx(ctx, PromptIDKey) + if promptID == "" { + return "", 0, false, nil + } + versionNumber, specified, err := parsePromptVersionNumber(ctx) + if err != nil { + return "", 0, false, fmt.Errorf("invalid bifrost-prompt-version: %w", err) + } + return promptID, versionNumber, specified, nil +} + +// Plugin resolves stored prompt templates and prepends their messages to LLM requests. +type Plugin struct { + store promptStore + logger schemas.Logger + resolver PromptResolver + + mu sync.RWMutex + promptsByID map[string]*configstoreTables.TablePrompt + versionsByPromptAndNumber map[string]map[int]*configstoreTables.TablePromptVersion +} + +// Init wires the prompts plugin with the default header-based resolver. +func Init(ctx context.Context, store promptStore, logger schemas.Logger) (schemas.LLMPlugin, error) { + return InitWithResolver(ctx, store, &headerResolver{logger: logger}, logger) +} + +// InitWithResolver wires the prompts plugin with a custom resolver. +func InitWithResolver(ctx context.Context, store promptStore, resolver PromptResolver, logger schemas.Logger) (*Plugin, error) { + if store == nil { + return nil, fmt.Errorf("config store is required for prompts plugin") + } + if resolver == nil { + resolver = &headerResolver{logger: logger} + } + p := &Plugin{ + store: store, + logger: logger, + resolver: resolver, + promptsByID: make(map[string]*configstoreTables.TablePrompt), + versionsByPromptAndNumber: make(map[string]map[int]*configstoreTables.TablePromptVersion), + } + if err := p.loadCache(ctx); err != nil { + return nil, fmt.Errorf("failed to load prompts into memory: %w", err) + } + return p, nil +} + +// loadCache rebuilds the in-memory maps with exactly two DB queries: +// one for all prompts (with their latest version), one for all versions. +func (p *Plugin) loadCache(ctx context.Context) error { + prompts, err := p.store.GetPrompts(ctx, nil) + if err != nil { + return err + } + + versions, err := p.store.GetAllPromptVersions(ctx) + if err != nil { + return fmt.Errorf("loading all prompt versions: %w", err) + } + + newPrompts := make(map[string]*configstoreTables.TablePrompt, len(prompts)) + for i := range prompts { + newPrompts[prompts[i].ID] = &prompts[i] + } + + newVersionsByPromptAndNumber := make(map[string]map[int]*configstoreTables.TablePromptVersion) + for i := range versions { + v := &versions[i] + if _, ok := newVersionsByPromptAndNumber[v.PromptID]; !ok { + newVersionsByPromptAndNumber[v.PromptID] = make(map[int]*configstoreTables.TablePromptVersion) + } + newVersionsByPromptAndNumber[v.PromptID][v.VersionNumber] = v + } + + p.mu.Lock() + p.promptsByID = newPrompts + p.versionsByPromptAndNumber = newVersionsByPromptAndNumber + p.mu.Unlock() + return nil +} + +// Reload refreshes the in-memory cache from the store. Called by the HTTP handler +// after any create/update/delete operation on prompts or versions. +func (p *Plugin) Reload(ctx context.Context) error { + return p.loadCache(ctx) +} + +func (p *Plugin) GetName() string { + return PluginName +} + +func (p *Plugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + if req == nil { + return nil, nil + } + if id := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptIDHeader)); id != "" { + ctx.SetValue(PromptIDKey, id) + } + if v := strings.TrimSpace(req.CaseInsensitiveHeaderLookup(PromptVersionHeader)); v != "" { + ctx.SetValue(PromptVersionKey, v) + } + p.setPromptStreamFromVersionForTransport(ctx) + return nil, nil +} + +// setPromptStreamFromVersionForTransport sets BifrostContextKeyPromptStreamRequest when +// the resolved prompt version has stream:true in its ModelParams. +func (p *Plugin) setPromptStreamFromVersionForTransport(ctx *schemas.BifrostContext) { + promptID := promptStringFromCtx(ctx, PromptIDKey) + if promptID == "" { + return + } + versionNumber, versionSpecified, err := parsePromptVersionNumber(ctx) + if err != nil { + return + } + _, version, ok := p.resolveVersion(promptID, versionNumber, versionSpecified) + if !ok || version == nil || len(version.ModelParams) == 0 { + return + } + if includesStreamInModelParams(version.ModelParams) { + ctx.SetValue(schemas.BifrostContextKeyPromptStreamRequest, true) + } +} + +func includesStreamInModelParams(mp configstoreTables.ModelParams) bool { + raw, ok := mp["stream"] + if !ok { + return true // default to true if stream is not set, this is done because for the initial version, the stream key is not present but we default to true for the initial version and show it as well on the UI. If the user toggles stream off, we set `stream: false` in the model params in db. + } + switch v := raw.(type) { + case bool: + return v + case json.Number: + if i, err := strconv.ParseInt(string(v), 10, 64); err == nil { + return i != 0 + } + b, err := strconv.ParseBool(string(v)) + return err == nil && b + case string: + switch strings.ToLower(strings.TrimSpace(v)) { + case "true", "1", "yes": + return true + default: + return false + } + default: + return false + } +} + +func (p *Plugin) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error { + return nil +} + +func (p *Plugin) HTTPTransportStreamChunkHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, chunk *schemas.BifrostStreamChunk) (*schemas.BifrostStreamChunk, error) { + return chunk, nil +} + +func (p *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + promptID, versionNumber, versionSpecified, err := p.resolver.Resolve(ctx, req) + if err != nil { + p.logger.Warn("prompts plugin: failed to resolve prompt: %v", err) + return req, nil, nil + } + if promptID == "" { + return req, nil, nil + } + + _, version, found := p.resolveVersion(promptID, versionNumber, versionSpecified) + if !found { + p.logger.Warn("prompts plugin: prompt or version not found: %s", promptID) + return req, nil, nil + } + + if version == nil { + p.logger.Warn("prompts plugin: prompt %s has no versions", promptID) + return req, nil, nil + } + + // Apply model params from the version (version params are defaults; request params win). + switch { + case req.ChatRequest != nil: + applyVersionParamsToChatRequest(version, req.ChatRequest, p.logger) + case req.ResponsesRequest != nil: + applyVersionParamsToResponsesRequest(version, req.ResponsesRequest, p.logger) + } + + template, err := chatMessagesFromVersionMessages(version.Messages) + if err != nil { + p.logger.Warn("prompts plugin: failed to parse messages for prompt %s: %v", promptID, err) + return req, nil, nil + } + if len(template) == 0 { + return req, nil, nil + } + + switch { + case req.ChatRequest != nil: + mergeChatMessages(&req.ChatRequest.Input, template) + case req.ResponsesRequest != nil: + mergeResponsesMessages(&req.ResponsesRequest.Input, template) + } + + return req, nil, nil +} + +func (p *Plugin) PostLLMHook(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} + +// knownSyntheticChatParamKeys are flat JSON keys that ChatParameters.UnmarshalJSON +// promotes into nested structs. They should not be treated as ExtraParams even though +// they won't appear as top-level keys in a re-marshaled ChatParameters. +var knownSyntheticChatParamKeys = map[string]struct{}{ + "reasoning_effort": {}, + "reasoning_max_tokens": {}, +} + +// buildMergedParamsMap builds a merged map[string]interface{} where version params +// serve as defaults and request params take priority. reqParamsBytes is the JSON of +// the request's standard params (ExtraParams excluded); reqExtraParams is its ExtraParams map. +func buildMergedParamsMap(versionParams configstoreTables.ModelParams, reqParamsBytes []byte, reqExtraParams map[string]interface{}) (map[string]interface{}, error) { + merged := make(map[string]interface{}, len(versionParams)) + maps.Copy(merged, versionParams) + if len(reqParamsBytes) > 0 && string(reqParamsBytes) != "null" { + var reqMap map[string]interface{} + if err := schemas.Unmarshal(reqParamsBytes, &reqMap); err != nil { + return nil, fmt.Errorf("unmarshal request params: %w", err) + } + maps.Copy(merged, reqMap) + } + maps.Copy(merged, reqExtraParams) + return merged, nil +} + +// applyVersionParamsToChatRequest applies the prompt version's ModelParams to the +// chat request. Version params are defaults; params already set in the request win. +func applyVersionParamsToChatRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostChatRequest, logger schemas.Logger) { + if len(version.ModelParams) == 0 { + return + } + + var reqParamsBytes []byte + var reqExtraParams map[string]interface{} + if req.Params != nil { + b, err := schemas.Marshal(req.Params) + if err != nil { + logger.Warn("prompts plugin: failed to marshal chat request params: %v", err) + return + } + reqParamsBytes = b + reqExtraParams = req.Params.ExtraParams + } + + merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams) + if err != nil { + logger.Warn("prompts plugin: failed to build merged chat params: %v", err) + return + } + + mergedJSON, err := schemas.Marshal(merged) + if err != nil { + logger.Warn("prompts plugin: failed to marshal merged chat params: %v", err) + return + } + + var result schemas.ChatParameters + if err := schemas.Unmarshal(mergedJSON, &result); err != nil { + logger.Warn("prompts plugin: failed to unmarshal merged chat params: %v", err) + return + } + + // Detect keys from merged that were not recognized as standard ChatParameters fields + // (i.e. they won't appear in the re-marshaled output) and put them in ExtraParams. + var recognizedMap map[string]interface{} + recognizedBytes, err := schemas.Marshal(&result) + if err != nil { + logger.Warn("prompts plugin: failed to marshal result chat params: %v", err) + return + } + if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil { + logger.Warn("prompts plugin: failed to unmarshal recognized chat params: %v", err) + return + } + for k, v := range merged { + if _, ok := recognizedMap[k]; ok { + continue + } + if _, synthetic := knownSyntheticChatParamKeys[k]; synthetic { + continue + } + if result.ExtraParams == nil { + result.ExtraParams = make(map[string]interface{}) + } + if _, alreadySet := result.ExtraParams[k]; !alreadySet { + result.ExtraParams[k] = v + } + } + + req.Params = &result +} + +// applyVersionParamsToResponsesRequest applies the prompt version's ModelParams to the +// responses request. Version params are defaults; params already set in the request win. +func applyVersionParamsToResponsesRequest(version *configstoreTables.TablePromptVersion, req *schemas.BifrostResponsesRequest, logger schemas.Logger) { + if len(version.ModelParams) == 0 { + return + } + + var reqParamsBytes []byte + var reqExtraParams map[string]interface{} + if req.Params != nil { + b, err := schemas.Marshal(req.Params) + if err != nil { + logger.Warn("prompts plugin: failed to marshal responses request params: %v", err) + return + } + reqParamsBytes = b + reqExtraParams = req.Params.ExtraParams + } + + merged, err := buildMergedParamsMap(version.ModelParams, reqParamsBytes, reqExtraParams) + if err != nil { + logger.Warn("prompts plugin: failed to build merged responses params: %v", err) + return + } + + mergedJSON, err := schemas.Marshal(merged) + if err != nil { + logger.Warn("prompts plugin: failed to marshal merged responses params: %v", err) + return + } + + var result schemas.ResponsesParameters + if err := schemas.Unmarshal(mergedJSON, &result); err != nil { + logger.Warn("prompts plugin: failed to unmarshal merged responses params: %v", err) + return + } + + // Detect unrecognized keys and add them to ExtraParams. + var recognizedMap map[string]interface{} + recognizedBytes, err := schemas.Marshal(&result) + if err != nil { + logger.Warn("prompts plugin: failed to marshal result responses params: %v", err) + return + } + if err := schemas.Unmarshal(recognizedBytes, &recognizedMap); err != nil { + logger.Warn("prompts plugin: failed to unmarshal recognized responses params: %v", err) + return + } + for k, v := range merged { + if _, ok := recognizedMap[k]; ok { + continue + } + if result.ExtraParams == nil { + result.ExtraParams = make(map[string]interface{}) + } + if _, alreadySet := result.ExtraParams[k]; !alreadySet { + result.ExtraParams[k] = v + } + } + + req.Params = &result +} + +// resolveVersion centralises the map-lookup logic shared by setPromptStreamFromVersionForTransport +// and PreLLMHook. It returns the prompt and its resolved version (either the explicitly requested +// version or the prompt's latest version), plus a bool indicating whether both were found. +func (p *Plugin) resolveVersion(promptID string, versionNumber int, versionSpecified bool) ( + *configstoreTables.TablePrompt, *configstoreTables.TablePromptVersion, bool, +) { + p.mu.RLock() + defer p.mu.RUnlock() + + prompt, ok := p.promptsByID[promptID] + if !ok || prompt == nil { + return nil, nil, false + } + if !versionSpecified { + return prompt, prompt.LatestVersion, true + } + byNumber, ok := p.versionsByPromptAndNumber[promptID] + if !ok { + return nil, nil, false + } + v, found := byNumber[versionNumber] + if !found || v == nil { + return nil, nil, false + } + return prompt, v, true +} + +func (p *Plugin) Cleanup() error { + return nil +} + +func promptStringFromCtx(ctx *schemas.BifrostContext, key schemas.BifrostContextKey) string { + if v, ok := ctx.Value(key).(string); ok { + return strings.TrimSpace(v) + } + return "" +} + +func parsePromptVersionNumber(ctx *schemas.BifrostContext) (num int, specified bool, err error) { + s, ok := ctx.Value(PromptVersionKey).(string) + if !ok { + return 0, false, nil + } + s = strings.TrimSpace(s) + if s == "" { + return 0, false, nil + } + n, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return 0, true, err + } + return int(n), true, nil +} + +func chatMessagePopulated(cm schemas.ChatMessage) bool { + if strings.TrimSpace(string(cm.Role)) != "" { + return true + } + if cm.Content != nil { + return true + } + if cm.Name != nil && strings.TrimSpace(*cm.Name) != "" { + return true + } + if cm.ChatToolMessage != nil { + return true + } + if cm.ChatAssistantMessage != nil { + return true + } + return false +} + +// convertVersionMessagesToChatMessages unmarshals prompt-repo JSON into ChatMessage. +func convertVersionMessagesToChatMessages(data []byte) (schemas.ChatMessage, error) { + s := strings.TrimSpace(string(data)) + if s == "" || s == "null" { + return schemas.ChatMessage{}, fmt.Errorf("empty message") + } + data = []byte(s) + + var msg struct { + OriginalType string `json:"originalType"` + Payload json.RawMessage `json:"payload"` + } + if err := schemas.Unmarshal(data, &msg); err == nil { + ps := strings.TrimSpace(string(msg.Payload)) + if ps != "" && ps != "null" { + if msg.OriginalType == "completion_result" { + var result struct { + Choices []struct { + Message *schemas.ChatMessage `json:"message"` + } `json:"choices"` + } + if err := schemas.Unmarshal([]byte(ps), &result); err == nil && + len(result.Choices) > 0 && result.Choices[0].Message != nil { + if chatMessagePopulated(*result.Choices[0].Message) { + return *result.Choices[0].Message, nil + } + } + } + + // completion_request / tool_result / legacy envelope: payload is a direct ChatMessage. + var message schemas.ChatMessage + if err := schemas.Unmarshal([]byte(ps), &message); err != nil { + return schemas.ChatMessage{}, fmt.Errorf("decoding prompt message envelope payload: %w", err) + } + if chatMessagePopulated(message) { + return message, nil + } + } + } + + var chatMessage schemas.ChatMessage + if err := schemas.Unmarshal(data, &chatMessage); err != nil { + return schemas.ChatMessage{}, err + } + return chatMessage, nil +} + +func chatMessagesFromVersionMessages(messages []configstoreTables.TablePromptVersionMessage) ([]schemas.ChatMessage, error) { + out := make([]schemas.ChatMessage, 0, len(messages)) + for i := range messages { + row := &messages[i] + data := row.Message + if len(data) == 0 && row.MessageJSON != "" { + data = []byte(row.MessageJSON) + } + cm, err := convertVersionMessagesToChatMessages(data) + if err != nil { + return nil, fmt.Errorf("stored prompt message is not valid chat JSON: %w", err) + } + out = append(out, cm) + } + return out, nil +} + +func mergeChatMessages(dest *[]schemas.ChatMessage, prefix []schemas.ChatMessage) { + if dest == nil || len(prefix) == 0 { + return + } + cur := *dest + merged := make([]schemas.ChatMessage, 0, len(prefix)+len(cur)) + merged = append(merged, prefix...) + merged = append(merged, cur...) + *dest = merged +} + +func mergeResponsesMessages(dest *[]schemas.ResponsesMessage, template []schemas.ChatMessage) { + if dest == nil || len(template) == 0 { + return + } + var prefix []schemas.ResponsesMessage + for i := range template { + prefix = append(prefix, template[i].ToResponsesMessages()...) + } + cur := *dest + merged := make([]schemas.ResponsesMessage, 0, len(prefix)+len(cur)) + merged = append(merged, prefix...) + merged = append(merged, cur...) + *dest = merged +} diff --git a/plugins/prompts/plugin_test.go b/plugins/prompts/plugin_test.go new file mode 100644 index 0000000000..6202e8bef7 --- /dev/null +++ b/plugins/prompts/plugin_test.go @@ -0,0 +1,1065 @@ +package prompts + +import ( + "context" + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" + tables "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ============================================================ +// InitWithResolver +// ============================================================ + +func TestInitWithResolver_NilStore(t *testing.T) { + _, err := InitWithResolver(context.Background(), nil, &staticResolver{}, NewMockLogger()) + require.Error(t, err, "expected error for nil store") +} + +func TestInitWithResolver_NilResolverFallsBackToHeader(t *testing.T) { + ms := &mockStore{} + p, err := InitWithResolver(context.Background(), ms, nil, NewMockLogger()) + require.NoError(t, err) + require.NotNil(t, p) + _, ok := p.resolver.(*headerResolver) + assert.True(t, ok, "expected headerResolver, got %T", p.resolver) +} + +// ============================================================ +// loadCache +// ============================================================ + +func TestLoadCache_EmptyStore(t *testing.T) { + p := newPluginWithStore(&mockStore{}) + require.NoError(t, p.loadCache(context.Background())) + assert.Empty(t, p.promptsByID) + assert.Empty(t, p.versionsByPromptAndNumber) +} + +func TestLoadCache_PopulatesMaps(t *testing.T) { + v1 := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "Hello")) + v2 := makeVersion(2, "p2", true) + p1 := makePrompt("p1", &v1) + p2 := makePrompt("p2", &v2) + + p := newPluginWithStore(&mockStore{ + prompts: []tables.TablePrompt{p1, p2}, + versions: []tables.TablePromptVersion{v1, v2}, + }) + + require.NoError(t, p.loadCache(context.Background())) + assert.Len(t, p.promptsByID, 2) + assert.Len(t, p.versionsByPromptAndNumber, 2) + assert.NotNil(t, p.promptsByID["p1"]) + assert.NotNil(t, p.versionsByPromptAndNumber["p1"][1]) +} + +func TestLoadCache_GetPromptsError(t *testing.T) { + p := newPluginWithStore(&mockStore{err: errTest("boom")}) + err := p.loadCache(context.Background()) + require.Error(t, err) +} + +func TestLoadCache_GetVersionsError(t *testing.T) { + p := newPluginWithStore(&versionsErrStore{ + prompts: []tables.TablePrompt{makePrompt("p1", nil)}, + err: errTest("versions boom"), + }) + err := p.loadCache(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "versions boom") +} + +// ============================================================ +// PreLLMHook +// ============================================================ + +func TestPreLLMHook_NoPromptID(t *testing.T) { + p := newTestPlugin(&staticResolver{promptID: ""}, nil, nil) + out, sc, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + assert.Nil(t, sc) + assert.Len(t, out.ChatRequest.Input, 1) +} + +func TestPreLLMHook_PromptNotFound(t *testing.T) { + log := NewMockLogger() + p := newTestPluginWithLogger(&staticResolver{promptID: "missing"}, nil, nil, log) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected a warning for unknown prompt") +} + +func TestPreLLMHook_UseLatestVersion(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "Be helpful"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2, "expected system prompt + user message") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "Be helpful", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role) + assert.Equal(t, "hello", msgText(out.ChatRequest.Input[1])) +} + +func TestPreLLMHook_UseSpecificVersion(t *testing.T) { + vLatest := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "latest system prompt"), + ) + vOld := makeVersion(2, "p1", false, + versionMsg(schemas.ChatMessageRoleSystem, "old system prompt"), + ) + prompt := makePrompt("p1", &vLatest) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionNumber: 2, versionSpecified: true}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + + // Must use vOld, not vLatest. + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "old system prompt", msgText(out.ChatRequest.Input[0])) +} + +func TestPreLLMHook_VersionNotFound(t *testing.T) { + v := makeVersion(1, "p1", true, versionMsg(schemas.ChatMessageRoleSystem, "hello")) + prompt := makePrompt("p1", &v) + log := NewMockLogger() + + p := newTestPluginWithLogger( + &staticResolver{promptID: "p1", versionNumber: 99, versionSpecified: true}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for missing version") +} + +func TestPreLLMHook_VersionBelongsToDifferentPrompt(t *testing.T) { + v := makeVersion(1, "p2", true, versionMsg(schemas.ChatMessageRoleSystem, "wrong")) + prompt := makePrompt("p1", nil) + log := NewMockLogger() + + p := newTestPluginWithLogger( + &staticResolver{promptID: "p1", versionNumber: 1, versionSpecified: true}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p2": {1: &v}}, + log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for version/prompt mismatch") +} + +func TestPreLLMHook_NoLatestVersion(t *testing.T) { + prompt := makePrompt("p1", nil) // LatestVersion is nil + log := NewMockLogger() + + p := newTestPluginWithLogger( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + nil, + log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for missing latest version") +} + +func TestPreLLMHook_EmptyTemplate(t *testing.T) { + v := makeVersion(1, "p1", true) // no messages + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + assert.Len(t, out.ChatRequest.Input, 1) +} + +func TestPreLLMHook_MultipleTemplateMessages(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "sys prompt"), + versionMsg(schemas.ChatMessageRoleUser, "example input"), + versionMsg(schemas.ChatMessageRoleAssistant, "example output"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("actual question"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 4, "expected 3 template messages + 1 original") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "sys prompt", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role) + assert.Equal(t, "example input", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[2].Role) + assert.Equal(t, "example output", msgText(out.ChatRequest.Input[2])) + + // Original user message must be last, content preserved. + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role) + assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[3])) +} + +func TestPreLLMHook_ResolverError(t *testing.T) { + log := NewMockLogger() + p := newTestPluginWithLogger( + &staticResolver{err: errTest("resolver failed")}, + nil, nil, log, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err, "PreLLMHook must not propagate resolver errors") + assert.Len(t, out.ChatRequest.Input, 1, "input should be unchanged") + assert.True(t, log.Warned(), "expected warning for resolver error") +} + +func TestPreLLMHook_MessageJSON_FallbackPath(t *testing.T) { + // Messages where Message ([]byte) is nil but MessageJSON is set — the fallback + // branch in chatMessagesFromVersionMessages. This mirrors rows loaded from + // an older DB schema before AfterFind was established. + v := makeVersion(1, "p1", true, + versionMsgViaJSON(schemas.ChatMessageRoleSystem, "from json field"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "from json field", msgText(out.ChatRequest.Input[0])) +} + +func TestPreLLMHook_ResponsesRequest(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "be concise"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + userRole := schemas.ResponsesMessageRoleType("user") + req := &schemas.BifrostRequest{ + ResponsesRequest: &schemas.BifrostResponsesRequest{ + Input: []schemas.ResponsesMessage{{Role: &userRole}}, + }, + } + + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + // Template message(s) prepended before the original user input. + assert.Greater(t, len(out.ResponsesRequest.Input), 1, "expected template prepended before user message") + // Original user message must still be last. + last := out.ResponsesRequest.Input[len(out.ResponsesRequest.Input)-1] + assert.Equal(t, schemas.ResponsesMessageRoleType("user"), *last.Role) +} + +// TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg verifies that when the +// prompt template contains a system message and the incoming request also starts +// with a system message, both system messages are forwarded to the model — +// the plugin's only job is prepending, not de-duplicating or filtering roles. +func TestPreLLMHook_PromptSystemMsg_PlusUserInputSystemMsg(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "prompt system"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + // Incoming request already has its own system message before the user turn. + out, _, err := p.PreLLMHook(bfCtx(), chatRequest( + systemMsg("user-side system context"), + userMsg("actual question"), + )) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 3, "expected prompt system + user system + user message") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "prompt system", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role) + assert.Equal(t, "user-side system context", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role) + assert.Equal(t, "actual question", msgText(out.ChatRequest.Input[2])) +} + +// TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage verifies that when +// the prompt template contains a full tool-call turn (system → assistant with +// tool_calls → tool result) and the user sends a new message, the entire +// template is prepended and all fields (ToolCalls, ToolCallID) are preserved. +func TestPreLLMHook_PromptWithToolCallMessages_PlusUserMessage(t *testing.T) { + const callID = "call_abc123" + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "you are a weather bot"), + versionMsgWithToolCall(callID, "get_weather", `{"city":"Paris"}`), + versionMsgToolResult(callID, "Sunny, 22°C"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("what about tomorrow?"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 4, "expected system + assistant(tool_calls) + tool_result + user") + + // System message from prompt. + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "you are a weather bot", msgText(out.ChatRequest.Input[0])) + + // Assistant message with tool_calls must carry its ToolCalls slice. + assistantMsg := out.ChatRequest.Input[1] + assert.Equal(t, schemas.ChatMessageRoleAssistant, assistantMsg.Role) + require.NotNil(t, assistantMsg.ChatAssistantMessage, "ChatAssistantMessage must be present") + require.Len(t, assistantMsg.ChatAssistantMessage.ToolCalls, 1) + tc := assistantMsg.ChatAssistantMessage.ToolCalls[0] + require.NotNil(t, tc.ID) + assert.Equal(t, callID, *tc.ID) + require.NotNil(t, tc.Function.Name) + assert.Equal(t, "get_weather", *tc.Function.Name) + assert.Equal(t, `{"city":"Paris"}`, tc.Function.Arguments) + + // Tool result message must carry the ToolCallID. + toolResultMsg := out.ChatRequest.Input[2] + assert.Equal(t, schemas.ChatMessageRoleTool, toolResultMsg.Role) + assert.Equal(t, "Sunny, 22°C", msgText(toolResultMsg)) + require.NotNil(t, toolResultMsg.ChatToolMessage, "ChatToolMessage must be present") + require.NotNil(t, toolResultMsg.ChatToolMessage.ToolCallID) + assert.Equal(t, callID, *toolResultMsg.ChatToolMessage.ToolCallID) + + // Original user message is last. + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[3].Role) + assert.Equal(t, "what about tomorrow?", msgText(out.ChatRequest.Input[3])) +} + +// TestPreLLMHook_MultipleSystemMessages_InPromptTemplate verifies that a prompt +// template may itself contain multiple system messages and all of them are +// prepended before the user's input in the original order. +func TestPreLLMHook_MultipleSystemMessages_InPromptTemplate(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "first system"), + versionMsg(schemas.ChatMessageRoleSystem, "second system"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 3, "expected 2 system messages + user message") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "first system", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[1].Role) + assert.Equal(t, "second system", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role) + assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2])) +} + +// ============================================================ +// HTTPTransportPreHook +// ============================================================ + +func TestHTTPTransportPreHook_NilRequest(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + resp, err := p.HTTPTransportPreHook(bfCtx(), nil) + assert.NoError(t, err) + assert.Nil(t, resp) +} + +func TestHTTPTransportPreHook_SetsPromptID(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptIDHeader: "my-prompt"}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptIDKey).(string) + assert.Equal(t, "my-prompt", got) +} + +func TestHTTPTransportPreHook_SetsVersionID(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptVersionHeader: "42"}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptVersionKey).(string) + assert.Equal(t, "42", got) +} + +func TestHTTPTransportPreHook_TrimsWhitespace(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptIDHeader: " padded "}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptIDKey).(string) + assert.Equal(t, "padded", got) +} + +func TestHTTPTransportPreHook_WhitespaceOnlyNotSet(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{PromptIDHeader: " "}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + assert.Nil(t, ctx.Value(PromptIDKey), "whitespace-only header must not be stored in context") +} + +func TestHTTPTransportPreHook_CaseInsensitiveHeaders(t *testing.T) { + p := newTestPlugin(nil, nil, nil) + ctx := bfCtx() + // "Bf-Prompt-Id" is a title-case variant of the canonical "bf-prompt-id". + req := &schemas.HTTPRequest{ + Headers: map[string]string{"Bf-Prompt-Id": "upper-case"}, + } + + _, _ = p.HTTPTransportPreHook(ctx, req) + + got, _ := ctx.Value(PromptIDKey).(string) + assert.Equal(t, "upper-case", got) +} + +// ============================================================ +// chatMessageFromStoredJSON +// ============================================================ + +func TestChatMessageFromStoredJSON(t *testing.T) { + systemText := "you are helpful" + directMsg := schemas.ChatMessage{ + Role: schemas.ChatMessageRoleSystem, + Content: &schemas.ChatMessageContent{ContentStr: &systemText}, + } + directJSON, _ := json.Marshal(directMsg) + envelopeJSON := []byte(`{"payload":` + string(directJSON) + `}`) + + tests := []struct { + name string + input []byte + wantErr bool + wantRole schemas.ChatMessageRole + wantText string + }{ + { + name: "direct format", + input: directJSON, + wantRole: schemas.ChatMessageRoleSystem, + wantText: systemText, + }, + { + name: "envelope format", + input: envelopeJSON, + wantRole: schemas.ChatMessageRoleSystem, + wantText: systemText, + }, + { + // UI format for assistant messages: originalType=completion_result, + // payload is a BifrostChatResponse; message lives at choices[0].message. + name: "completion_result envelope (UI assistant format)", + input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}]}}`), + wantRole: schemas.ChatMessageRoleAssistant, + wantText: "hi there", + }, + { + // completion_result with no choices falls through to direct ChatMessage parse. + name: "completion_result envelope with empty choices", + input: []byte(`{"originalType":"completion_result","payload":{"id":"r1","choices":[]}}`), + wantErr: false, + wantRole: "", + wantText: "", + }, + { + name: "empty bytes", + input: []byte(""), + wantErr: true, + }, + { + name: "null bytes", + input: []byte("null"), + wantErr: true, + }, + { + name: "whitespace only", + input: []byte(" "), + wantErr: true, + }, + { + name: "malformed envelope payload", + input: []byte(`{"payload":"not-a-chat-message"}`), + wantErr: true, + }, + { + // {"payload":null} — envelope path is skipped (payload is "null"), + // falls through to direct decode of the outer object as ChatMessage. + // schemas.Unmarshal succeeds on an unknown-field object → empty ChatMessage, no error. + name: "envelope with null payload falls through to direct decode", + input: []byte(`{"payload":null}`), + wantErr: false, + wantRole: "", + wantText: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := convertVersionMessagesToChatMessages(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantRole, got.Role) + assert.Equal(t, tt.wantText, msgText(got)) + }) + } +} + +func TestChatMessageFromStoredJSON_EnvelopeWithEmptyStringPayload(t *testing.T) { + // {"payload":""} — the payload field is a non-null, non-empty JSON string `""`. + // The envelope path attempts to unmarshal `""` (a JSON string literal) into + // schemas.ChatMessage (a struct), which fails. The error is returned directly; + // there is no further fallback. + input := []byte(`{"payload":""}`) + _, err := convertVersionMessagesToChatMessages(input) + require.Error(t, err) + assert.Contains(t, err.Error(), "decoding prompt message envelope payload") +} + +// ============================================================ +// parsePromptVersionNumber +// ============================================================ + +func TestParsePromptVersionNumber(t *testing.T) { + type want struct { + num int + specified bool + wantErr bool + } + + tests := []struct { + name string + value any // stored in context; nil means don't set + want want + }{ + {name: "nil — not specified", value: nil, want: want{0, false, false}}, + {name: "string valid", value: "99", want: want{99, true, false}}, + {name: "string empty", value: "", want: want{0, false, false}}, + {name: "string whitespace", value: " ", want: want{0, false, false}}, + {name: "string invalid", value: "abc", want: want{0, true, true}}, + {name: "unknown type", value: struct{}{}, want: want{0, false, false}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := bfCtx() + if tt.value != nil { + ctx.SetValue(PromptVersionKey, tt.value) + } + + num, specified, err := parsePromptVersionNumber(ctx) + + if tt.want.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want.specified, specified) + assert.Equal(t, tt.want.num, num) + }) + } +} + +// ============================================================ +// mergeChatMessages +// ============================================================ + +func TestMergeChatMessages(t *testing.T) { + t.Run("prepends prefix before existing messages", func(t *testing.T) { + dest := []schemas.ChatMessage{userMsg("original")} + prefix := []schemas.ChatMessage{systemMsg("system")} + mergeChatMessages(&dest, prefix) + + require.Len(t, dest, 2) + assert.Equal(t, schemas.ChatMessageRoleSystem, dest[0].Role) + assert.Equal(t, "system", msgText(dest[0])) + assert.Equal(t, schemas.ChatMessageRoleUser, dest[1].Role) + assert.Equal(t, "original", msgText(dest[1])) + }) + + t.Run("nil dest is a no-op", func(t *testing.T) { + // Must not panic. + mergeChatMessages(nil, []schemas.ChatMessage{systemMsg("x")}) + }) + + t.Run("empty prefix is a no-op", func(t *testing.T) { + dest := []schemas.ChatMessage{userMsg("only")} + mergeChatMessages(&dest, nil) + assert.Len(t, dest, 1) + assert.Equal(t, "only", msgText(dest[0])) + }) +} + +// ============================================================ +// chatMessagesFromVersionMessages +// ============================================================ + +func TestChatMessagesFromVersionMessages_SingleMessage(t *testing.T) { + msg := versionMsg(schemas.ChatMessageRoleUser, "hello") + out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg}) + require.NoError(t, err) + require.Len(t, out, 1) + assert.Equal(t, schemas.ChatMessageRoleUser, out[0].Role) + assert.Equal(t, "hello", msgText(out[0])) +} + +func TestChatMessagesFromVersionMessages_MessageJSONFallback(t *testing.T) { + // Row has no Message bytes but has MessageJSON — exercises the fallback branch. + msg := versionMsgViaJSON(schemas.ChatMessageRoleAssistant, "assistant reply") + out, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{msg}) + require.NoError(t, err) + require.Len(t, out, 1) + assert.Equal(t, schemas.ChatMessageRoleAssistant, out[0].Role) + assert.Equal(t, "assistant reply", msgText(out[0])) +} + +func TestChatMessagesFromVersionMessages_PreservesOrder(t *testing.T) { + msgs := []tables.TablePromptVersionMessage{ + versionMsg(schemas.ChatMessageRoleSystem, "first"), + versionMsg(schemas.ChatMessageRoleUser, "second"), + versionMsg(schemas.ChatMessageRoleAssistant, "third"), + } + out, err := chatMessagesFromVersionMessages(msgs) + require.NoError(t, err) + require.Len(t, out, 3) + assert.Equal(t, schemas.ChatMessageRoleSystem, out[0].Role) + assert.Equal(t, "first", msgText(out[0])) + assert.Equal(t, schemas.ChatMessageRoleUser, out[1].Role) + assert.Equal(t, "second", msgText(out[1])) + assert.Equal(t, schemas.ChatMessageRoleAssistant, out[2].Role) + assert.Equal(t, "third", msgText(out[2])) +} + +func TestChatMessagesFromVersionMessages_InvalidJSON(t *testing.T) { + bad := tables.TablePromptVersionMessage{Message: []byte(`not-json`)} + _, err := chatMessagesFromVersionMessages([]tables.TablePromptVersionMessage{bad}) + require.Error(t, err) +} + +// ============================================================ +// loadCache + PreLLMHook integration (store → cache → injection) +// ============================================================ + +// ============================================================ +// includesStreamInModelParams +// ============================================================ + +func TestIncludesStreamInModelParams(t *testing.T) { + tests := []struct { + name string + params tables.ModelParams + want bool + }{ + {"bool true", tables.ModelParams{"stream": true}, true}, + {"bool false", tables.ModelParams{"stream": false}, false}, + {"string true", tables.ModelParams{"stream": "true"}, true}, + {"string yes", tables.ModelParams{"stream": "yes"}, true}, + {"string 1", tables.ModelParams{"stream": "1"}, true}, + {"string false", tables.ModelParams{"stream": "false"}, false}, + {"string 0", tables.ModelParams{"stream": "0"}, false}, + {"absent key", tables.ModelParams{"temperature": 0.7}, false}, + {"empty params", tables.ModelParams{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, includesStreamInModelParams(tt.params)) + }) + } +} + +// ============================================================ +// HTTPTransportPreHook — stream routing via version ModelParams +// ============================================================ + +// TestHTTPTransportPreHook_StreamTrue_SetsStreamContext verifies that when the +// resolved version has stream:true in ModelParams, the hook marks the bifrost +// context so that the inference handler opens an SSE response. +func TestHTTPTransportPreHook_StreamTrue_SetsStreamContext(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": true}) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + streamVal, _ := ctx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool) + assert.True(t, streamVal, "expected BifrostContextKeyPromptStreamRequest=true when version has stream:true") +} + +// TestHTTPTransportPreHook_StreamFalse_NoStreamContext verifies that stream:false +// in ModelParams does NOT set the stream context key. +func TestHTTPTransportPreHook_StreamFalse_NoStreamContext(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": false}) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + assert.Nil(t, ctx.Value(schemas.BifrostContextKeyPromptStreamRequest), + "expected BifrostContextKeyPromptStreamRequest not set when version has stream:false") +} + +// TestHTTPTransportPreHook_NoStreamParam_NoStreamContext verifies that when no +// "stream" key is present in ModelParams, the stream context key is not set. +func TestHTTPTransportPreHook_NoStreamParam_NoStreamContext(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, tables.ModelParams{"temperature": float64(0.7)}) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{Headers: map[string]string{PromptIDHeader: "p1"}} + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + assert.Nil(t, ctx.Value(schemas.BifrostContextKeyPromptStreamRequest), + "expected BifrostContextKeyPromptStreamRequest not set when no stream key in params") +} + +// TestHTTPTransportPreHook_SpecificVersion_StreamTrue_SetsStreamContext verifies +// that when a specific (non-latest) version is requested via header and that +// version has stream:true, the stream context key is set — even if the latest +// version has stream:false. +func TestHTTPTransportPreHook_SpecificVersion_StreamTrue_SetsStreamContext(t *testing.T) { + vLatest := makeVersionWithParams(1, "p1", true, tables.ModelParams{"stream": false}) + vOld := makeVersionWithParams(2, "p1", false, tables.ModelParams{"stream": true}) + prompt := makePrompt("p1", &vLatest) + + p := newTestPlugin( + nil, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &vLatest, 2: &vOld}}, + ) + + ctx := bfCtx() + req := &schemas.HTTPRequest{ + Headers: map[string]string{ + PromptIDHeader: "p1", + PromptVersionHeader: "2", + }, + } + + _, err := p.HTTPTransportPreHook(ctx, req) + require.NoError(t, err) + + streamVal, _ := ctx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool) + assert.True(t, streamVal, "expected stream=true from explicitly requested version with stream:true") +} + +// ============================================================ +// PreLLMHook — model params merge and override +// ============================================================ + +// TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams verifies that when +// the request carries no Params at all, the version's ModelParams become the +// effective parameters on the outgoing request. +func TestPreLLMHook_VersionParamsApplied_WhenRequestHasNoParams(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{"temperature": float64(0.7)}, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hi"))) + require.NoError(t, err) + require.NotNil(t, out.ChatRequest.Params, "expected Params to be set from version ModelParams") + require.NotNil(t, out.ChatRequest.Params.Temperature) + assert.InDelta(t, 0.7, *out.ChatRequest.Params.Temperature, 0.001) +} + +// TestPreLLMHook_RequestParamsOverrideVersionParams verifies that when both the +// version and the request specify the same parameter, the request value wins. +func TestPreLLMHook_RequestParamsOverrideVersionParams(t *testing.T) { + reqTemp := 0.9 + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{"temperature": float64(0.3)}, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello")) + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + require.NotNil(t, out.ChatRequest.Params) + require.NotNil(t, out.ChatRequest.Params.Temperature) + assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001, + "request temperature must override version default temperature") +} + +// TestPreLLMHook_RequestParamsPartialOverride verifies the mixed case: version +// sets temperature and max_completion_tokens; request overrides only temperature. +// The version's max_completion_tokens must still be applied. +func TestPreLLMHook_RequestParamsPartialOverride(t *testing.T) { + reqTemp := 0.9 + maxTokens := 200 + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{ + "temperature": float64(0.3), + "max_completion_tokens": float64(maxTokens), + }, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + req := chatRequestWithParams(&schemas.ChatParameters{Temperature: &reqTemp}, userMsg("hello")) + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + require.NotNil(t, out.ChatRequest.Params) + require.NotNil(t, out.ChatRequest.Params.Temperature) + assert.InDelta(t, reqTemp, *out.ChatRequest.Params.Temperature, 0.001, + "request temperature must override version temperature") + require.NotNil(t, out.ChatRequest.Params.MaxCompletionTokens, + "version max_completion_tokens must be applied when request does not override it") + assert.Equal(t, maxTokens, *out.ChatRequest.Params.MaxCompletionTokens) +} + +// ============================================================ +// PreLLMHook — model field preservation +// ============================================================ + +// TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel verifies that +// a "model" key inside a version's ModelParams (which the UI may store alongside +// temperature etc.) does NOT replace the model field on the outgoing +// BifrostChatRequest. The model chosen by the caller must always win. +func TestPreLLMHook_ModelInVersionParams_DoesNotOverrideRequestModel(t *testing.T) { + v := makeVersionWithParams(1, "p1", true, + tables.ModelParams{ + "model": "openai/gpt-4o", + "temperature": float64(0.5), + }, + versionMsg(schemas.ChatMessageRoleSystem, "sys"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + req := chatRequestWithModel("openai/gpt-3.5-turbo", userMsg("hi")) + out, _, err := p.PreLLMHook(bfCtx(), req) + require.NoError(t, err) + assert.Equal(t, "openai/gpt-3.5-turbo", out.ChatRequest.Model, + "request model must not be overridden by model stored in version ModelParams") +} + +// ============================================================ +// loadCache + PreLLMHook integration (store → cache → injection) +// ============================================================ + +// TestLoadCacheAndPreLLMHook_EndToEnd verifies the full pipeline: +// mockStore returns TablePrompt/TablePromptVersion structs → loadCache populates +// the in-memory maps → PreLLMHook injects the template messages correctly. +// This catches any mismatch between how loadCache builds the maps and how +// PreLLMHook reads them (e.g. pointer aliasing, LatestVersion linking). +func TestLoadCacheAndPreLLMHook_EndToEnd(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "end-to-end system"), + ) + prompt := makePrompt("p1", &v) + + ms := &mockStore{ + prompts: []tables.TablePrompt{prompt}, + versions: []tables.TablePromptVersion{v}, + } + + p := newPluginWithStore(ms) + require.NoError(t, p.loadCache(context.Background())) + + p.resolver = &staticResolver{promptID: "p1", versionSpecified: false} + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("user msg"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "end-to-end system", msgText(out.ChatRequest.Input[0])) + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[1].Role) + assert.Equal(t, "user msg", msgText(out.ChatRequest.Input[1])) +} + +// TestLoadCacheAndPreLLMHook_SpecificVersion exercises the loadCache → PreLLMHook +// path for a version lookup by ID (not just the LatestVersion pointer). +func TestLoadCacheAndPreLLMHook_SpecificVersion(t *testing.T) { + vOld := makeVersion(2, "p1", false, + versionMsg(schemas.ChatMessageRoleSystem, "old via store"), + ) + vLatest := makeVersion(3, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "latest via store"), + ) + prompt := makePrompt("p1", &vLatest) + + ms := &mockStore{ + prompts: []tables.TablePrompt{prompt}, + versions: []tables.TablePromptVersion{vOld, vLatest}, + } + + p := newPluginWithStore(ms) + require.NoError(t, p.loadCache(context.Background())) + + p.resolver = &staticResolver{promptID: "p1", versionNumber: 2, versionSpecified: true} + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("question"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 2) + assert.Equal(t, "old via store", msgText(out.ChatRequest.Input[0])) +} + +// TestPreLLMHook_AssistantMessage_UIFormat verifies that assistant messages stored +// in the Bifrost UI's completion_result format (payload.choices[0].message) are +// correctly extracted and prepended to the request. +func TestPreLLMHook_AssistantMessage_UIFormat(t *testing.T) { + v := makeVersion(1, "p1", true, + versionMsg(schemas.ChatMessageRoleSystem, "be helpful"), + versionMsgAssistantUIFormat("sure, how can I help?"), + ) + prompt := makePrompt("p1", &v) + + p := newTestPlugin( + &staticResolver{promptID: "p1", versionSpecified: false}, + map[string]*tables.TablePrompt{"p1": &prompt}, + map[string]map[int]*tables.TablePromptVersion{"p1": {1: &v}}, + ) + + out, _, err := p.PreLLMHook(bfCtx(), chatRequest(userMsg("hello"))) + require.NoError(t, err) + require.Len(t, out.ChatRequest.Input, 3, "expected system + assistant + user") + + assert.Equal(t, schemas.ChatMessageRoleSystem, out.ChatRequest.Input[0].Role) + assert.Equal(t, "be helpful", msgText(out.ChatRequest.Input[0])) + + assert.Equal(t, schemas.ChatMessageRoleAssistant, out.ChatRequest.Input[1].Role) + assert.Equal(t, "sure, how can I help?", msgText(out.ChatRequest.Input[1])) + + assert.Equal(t, schemas.ChatMessageRoleUser, out.ChatRequest.Input[2].Role) + assert.Equal(t, "hello", msgText(out.ChatRequest.Input[2])) +} diff --git a/plugins/prompts/version b/plugins/prompts/version new file mode 100644 index 0000000000..afaf360d37 --- /dev/null +++ b/plugins/prompts/version @@ -0,0 +1 @@ +1.0.0 \ No newline at end of file diff --git a/transports/bifrost-http/handlers/inference.go b/transports/bifrost-http/handlers/inference.go index f5acfe8c73..8877d5675a 100644 --- a/transports/bifrost-http/handlers/inference.go +++ b/transports/bifrost-http/handlers/inference.go @@ -92,6 +92,8 @@ var chatParamsKnownFields = map[string]bool{ "presence_penalty": true, "prompt_cache_key": true, "reasoning": true, + "reasoning_effort": true, + "reasoning_max_tokens": true, "response_format": true, "safety_identifier": true, "service_tier": true, @@ -515,6 +517,16 @@ func parseFallbacks(fallbackStrings []string) ([]schemas.Fallback, error) { return fallbacks, nil } +func effectiveStream(bodyStream *bool, bifrostCtx *schemas.BifrostContext) bool { + if bodyStream != nil { + return *bodyStream + } + if v, ok := bifrostCtx.Value(schemas.BifrostContextKeyPromptStreamRequest).(bool); ok && v { + return true + } + return false +} + // extractExtraParams processes unknown fields from JSON data into ExtraParams func extractExtraParams(data []byte, knownFields map[string]bool) (map[string]any, error) { // Parse JSON to extract unknown fields @@ -932,7 +944,7 @@ func (h *CompletionHandler) chatCompletion(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, "Failed to convert context") return } - if req.Stream != nil && *req.Stream { + if effectiveStream(req.Stream, bifrostCtx) { h.handleStreamingChatCompletion(ctx, bifrostChatReq, bifrostCtx, cancel) return } @@ -1027,7 +1039,7 @@ func (h *CompletionHandler) responses(ctx *fasthttp.RequestCtx) { return } - if req.Stream != nil && *req.Stream { + if effectiveStream(req.Stream, bifrostCtx) { h.handleStreamingResponses(ctx, bifrostResponsesReq, bifrostCtx, cancel) return } diff --git a/transports/bifrost-http/handlers/prompts.go b/transports/bifrost-http/handlers/prompts.go index 9f4ac48470..e5b96f0c38 100644 --- a/transports/bifrost-http/handlers/prompts.go +++ b/transports/bifrost-http/handlers/prompts.go @@ -1,8 +1,10 @@ package handlers import ( + "context" "encoding/json" "errors" + "fmt" "strconv" "github.com/fasthttp/router" @@ -14,15 +16,35 @@ import ( "github.com/valyala/fasthttp" ) +// PromptCacheReloader is implemented by the prompts plugin to allow the HTTP handler +// to trigger an in-memory cache refresh after any repository mutation. +type PromptCacheReloader interface { + Reload(ctx context.Context) error +} + // PromptsHandler handles prompt repository endpoints type PromptsHandler struct { - store configstore.ConfigStore + store configstore.ConfigStore + reloader PromptCacheReloader // optional; nil when the prompts plugin is not loaded +} + +// NewPromptsHandler creates a new PromptsHandler. +// reloader may be nil; when set, the in-memory prompt cache is refreshed after mutations. +func NewPromptsHandler(store configstore.ConfigStore, reloader PromptCacheReloader) *PromptsHandler { + if store == nil { + return nil + } + return &PromptsHandler{store: store, reloader: reloader} } -// NewPromptsHandler creates a new PromptsHandler -func NewPromptsHandler(store configstore.ConfigStore) *PromptsHandler { - return &PromptsHandler{ - store: store, +// reloadCache triggers a cache refresh if a reloader is configured. +// Errors are logged but do not fail the originating request. +func (h *PromptsHandler) reloadCache(ctx context.Context) { + if h.reloader == nil { + return + } + if err := h.reloader.Reload(ctx); err != nil { + logger.Error("failed to reload prompt cache: %v", err) } } @@ -143,7 +165,8 @@ type RenameSessionRequest struct { // CommitSessionRequest represents the request body for committing a session as a version type CommitSessionRequest struct { - CommitMessage string `json:"commit_message"` + CommitMessage string `json:"commit_message"` + MessageIndices *[]int `json:"message_indices,omitempty"` // optional: indices of messages to include (0-based). If nil/absent, all messages are included. } // ============================================================================ @@ -294,6 +317,7 @@ func (h *PromptsHandler) deleteFolder(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "message": "folder deleted successfully", }) @@ -392,6 +416,7 @@ func (h *PromptsHandler) createPrompt(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "prompt": prompt, }) @@ -465,6 +490,7 @@ func (h *PromptsHandler) updatePrompt(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "prompt": prompt, }) @@ -493,6 +519,7 @@ func (h *PromptsHandler) deletePrompt(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "message": "prompt deleted successfully", }) @@ -619,6 +646,7 @@ func (h *PromptsHandler) createVersion(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "version": version, }) @@ -652,6 +680,7 @@ func (h *PromptsHandler) deleteVersion(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "message": "version deleted successfully", }) @@ -1005,11 +1034,36 @@ func (h *PromptsHandler) commitSession(ctx *fasthttp.RequestCtx) { // Convert session messages to version messages var messages []tables.TablePromptVersionMessage - for _, msg := range session.Messages { - messages = append(messages, tables.TablePromptVersionMessage{ - PromptID: session.PromptID, - Message: msg.Message, - }) + if req.MessageIndices != nil { + // Only include messages at the specified indices, deduplicating + seen := make(map[int]struct{}) + for _, idx := range *req.MessageIndices { + if _, ok := seen[idx]; ok { + continue + } + seen[idx] = struct{}{} + if idx < 0 || idx >= len(session.Messages) { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("message index %d out of range (0-%d)", idx, len(session.Messages)-1)) + return + } + msg := session.Messages[idx] + messages = append(messages, tables.TablePromptVersionMessage{ + PromptID: session.PromptID, + Message: msg.Message, + }) + } + } else { + for _, msg := range session.Messages { + messages = append(messages, tables.TablePromptVersionMessage{ + PromptID: session.PromptID, + Message: msg.Message, + }) + } + } + + if len(messages) == 0 { + SendError(ctx, fasthttp.StatusBadRequest, "at least one message must be included in the version") + return } version := &tables.TablePromptVersion{ @@ -1027,6 +1081,7 @@ func (h *PromptsHandler) commitSession(ctx *fasthttp.RequestCtx) { return } + h.reloadCache(ctx) SendJSON(ctx, map[string]any{ "version": version, }) diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index 751687f0c8..9241752b68 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -41,6 +41,7 @@ import ( "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" + "github.com/maximhq/bifrost/plugins/prompts" "github.com/maximhq/bifrost/plugins/semanticcache" "github.com/maximhq/bifrost/plugins/telemetry" "gorm.io/gorm" @@ -103,6 +104,7 @@ func getWeight(w *float64) float64 { // IsBuiltinPlugin checks if a plugin is a built-in plugin func IsBuiltinPlugin(name string) bool { return name == telemetry.PluginName || + name == prompts.PluginName || name == logging.PluginName || name == governance.PluginName || name == litellmcompat.PluginName || diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 25462ed8f5..e81456dc3b 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -1263,6 +1263,9 @@ func (m *MockConfigStore) DeletePrompt(ctx context.Context, id string) error { r func (m *MockConfigStore) GetPromptVersions(ctx context.Context, promptID string) ([]tables.TablePromptVersion, error) { return nil, nil } +func (m *MockConfigStore) GetAllPromptVersions(ctx context.Context) ([]tables.TablePromptVersion, error) { + return nil, nil +} func (m *MockConfigStore) GetPromptVersionByID(ctx context.Context, id uint) (*tables.TablePromptVersion, error) { return nil, nil } diff --git a/transports/bifrost-http/server/plugins.go b/transports/bifrost-http/server/plugins.go index b5fb80be3f..3cdf2f31fa 100644 --- a/transports/bifrost-http/server/plugins.go +++ b/transports/bifrost-http/server/plugins.go @@ -11,6 +11,7 @@ import ( "github.com/maximhq/bifrost/plugins/logging" "github.com/maximhq/bifrost/plugins/maxim" "github.com/maximhq/bifrost/plugins/otel" + "github.com/maximhq/bifrost/plugins/prompts" "github.com/maximhq/bifrost/plugins/semanticcache" "github.com/maximhq/bifrost/plugins/telemetry" "github.com/maximhq/bifrost/transports/bifrost-http/handlers" @@ -62,6 +63,9 @@ func loadBuiltinPlugin(ctx context.Context, name string, pluginConfig any, bifro } return telemetry.Init(telConfig, bifrostConfig.ModelCatalog, logger) + case prompts.PluginName: + return prompts.Init(ctx, bifrostConfig.ConfigStore, logger) + case logging.PluginName: loggingConfig, err := MarshalPluginConfig[logging.Config](pluginConfig) if err != nil { @@ -159,7 +163,15 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } s.Config.SetPluginOrderInfo(telemetry.PluginName, builtinPlacement, schemas.Ptr(1)) - // 2. Logging (if enabled) + // 2. Prompts (requires config store for prompt repository) + if s.Config.ConfigStore != nil { + s.registerPluginWithStatus(ctx, prompts.PluginName, nil, nil, false) + } else { + s.markPluginDisabled(prompts.PluginName) + } + s.Config.SetPluginOrderInfo(prompts.PluginName, builtinPlacement, schemas.Ptr(2)) + + // 3. Logging (if enabled) if (s.Config.ClientConfig.EnableLogging == nil || *s.Config.ClientConfig.EnableLogging) && s.Config.LogsStore != nil { config := &logging.Config{ DisableContentLogging: &s.Config.ClientConfig.DisableContentLogging, @@ -169,9 +181,9 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } else { s.markPluginDisabled(logging.PluginName) } - s.Config.SetPluginOrderInfo(logging.PluginName, builtinPlacement, schemas.Ptr(2)) + s.Config.SetPluginOrderInfo(logging.PluginName, builtinPlacement, schemas.Ptr(3)) - // 3. Governance (if enabled and not enterprise) + // 4. Governance (if enabled and not enterprise) if ctx.Value(schemas.BifrostContextKeyIsEnterprise) == nil { config := &governance.Config{ IsVkMandatory: &s.Config.ClientConfig.EnforceAuthOnInference, @@ -183,48 +195,47 @@ func (s *BifrostHTTPServer) loadBuiltinPlugins(ctx context.Context) error { } else { s.markPluginDisabled(governance.PluginName) } - s.Config.SetPluginOrderInfo(governance.PluginName, builtinPlacement, schemas.Ptr(3)) + s.Config.SetPluginOrderInfo(governance.PluginName, builtinPlacement, schemas.Ptr(4)) - // 4. OTEL (if configured in PluginConfigs) + // 5. OTEL (if configured in PluginConfigs) otelConfig := s.getPluginConfig(otel.PluginName) if otelConfig != nil && otelConfig.Enabled { s.registerPluginWithStatus(ctx, otel.PluginName, nil, otelConfig.Config, false) } else { s.markPluginDisabled(otel.PluginName) } - s.Config.SetPluginOrderInfo(otel.PluginName, builtinPlacement, schemas.Ptr(4)) + s.Config.SetPluginOrderInfo(otel.PluginName, builtinPlacement, schemas.Ptr(5)) - // 5. Semantic Cache (if configured in PluginConfigs) + // 6. Semantic Cache (if configured in PluginConfigs) semanticCacheConfig := s.getPluginConfig(semanticcache.PluginName) if semanticCacheConfig != nil && semanticCacheConfig.Enabled { s.registerPluginWithStatus(ctx, semanticcache.PluginName, nil, semanticCacheConfig.Config, false) } else { s.markPluginDisabled(semanticcache.PluginName) } - s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(5)) + s.Config.SetPluginOrderInfo(semanticcache.PluginName, builtinPlacement, schemas.Ptr(6)) - // 6. Litellmcompat (if configured in PluginConfigs) + // 7. Litellmcompat (if configured in PluginConfigs) litellmcompatConfig := s.getPluginConfig(litellmcompat.PluginName) if litellmcompatConfig != nil && litellmcompatConfig.Enabled { s.registerPluginWithStatus(ctx, litellmcompat.PluginName, nil, litellmcompatConfig.Config, false) } else { s.markPluginDisabled(litellmcompat.PluginName) } - s.Config.SetPluginOrderInfo(litellmcompat.PluginName, builtinPlacement, schemas.Ptr(6)) + s.Config.SetPluginOrderInfo(litellmcompat.PluginName, builtinPlacement, schemas.Ptr(7)) - // 7. Maxim (if configured in PluginConfigs) + // 8. Maxim (if configured in PluginConfigs) maximConfig := s.getPluginConfig(maxim.PluginName) if maximConfig != nil && maximConfig.Enabled { s.registerPluginWithStatus(ctx, maxim.PluginName, nil, maximConfig.Config, false) } else { s.markPluginDisabled(maxim.PluginName) } - s.Config.SetPluginOrderInfo(maxim.PluginName, builtinPlacement, schemas.Ptr(7)) + s.Config.SetPluginOrderInfo(maxim.PluginName, builtinPlacement, schemas.Ptr(8)) return nil } - // loadCustomPlugins loads plugins from PluginConfigs func (s *BifrostHTTPServer) loadCustomPlugins(ctx context.Context) error { for _, cfg := range s.Config.PluginConfigs { diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index b8e902a36c..faff3ae70d 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -26,6 +26,7 @@ import ( "github.com/maximhq/bifrost/framework/tracing" "github.com/maximhq/bifrost/plugins/governance" "github.com/maximhq/bifrost/plugins/logging" + "github.com/maximhq/bifrost/plugins/prompts" "github.com/maximhq/bifrost/plugins/semanticcache" "github.com/maximhq/bifrost/plugins/telemetry" "github.com/maximhq/bifrost/transports/bifrost-http/handlers" @@ -1047,6 +1048,10 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser if semanticCachePlugin != nil { cacheHandler = handlers.NewCacheHandler(semanticCachePlugin) } + var promptsReloader handlers.PromptCacheReloader + if promptsPlugin, err := lib.FindPluginAs[*prompts.Plugin](s.Config, prompts.PluginName); err == nil && promptsPlugin != nil { + promptsReloader = promptsPlugin + } // Websocket handler needs to go below UI handler logger.Debug("initializing websocket server") if s.WebSocketHandler == nil { @@ -1073,7 +1078,7 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser configHandler := handlers.NewConfigHandler(callbacks, s.Config) pluginsHandler := handlers.NewPluginsHandler(callbacks, s.Config.ConfigStore) sessionHandler := handlers.NewSessionHandler(s.Config.ConfigStore, s.WSTicketStore) - promptsHandler := handlers.NewPromptsHandler(s.Config.ConfigStore) + promptsHandler := handlers.NewPromptsHandler(s.Config.ConfigStore, promptsReloader) // Going ahead with API handlers healthHandler.RegisterRoutes(s.Router, middlewares...) providerHandler.RegisterRoutes(s.Router, middlewares...) diff --git a/ui/components/prompts/components/promptsViewHeader.tsx b/ui/components/prompts/components/promptsViewHeader.tsx index c3d424812d..fc6159e031 100644 --- a/ui/components/prompts/components/promptsViewHeader.tsx +++ b/ui/components/prompts/components/promptsViewHeader.tsx @@ -110,6 +110,13 @@ export default function PromptsViewHeader() { const handleCommitVersion = useCallback(async () => { if (!selectedPrompt) return; + if (!hasChanges) { + const selectedSession = sessions.find((s) => s.id === selectedSessionId); + if (selectedSession) { + onSessionSaved(selectedSession); + } + return; + } try { // Always create a new session with current state before committing const result = await createSession({ @@ -126,7 +133,7 @@ export default function PromptsViewHeader() { } catch (err) { toast.error("Failed to save session", { description: getErrorMessage(err) }); } - }, [selectedPrompt?.id, messages, buildSaveParams, provider, model, createSession, setUrlState, onSessionSaved]); + }, [selectedPrompt?.id, messages, buildSaveParams, provider, model, createSession, setUrlState, onSessionSaved, hasChanges]); const handleRenameSession = useCallback( async (sessionId: number, name: string) => { diff --git a/ui/components/prompts/sheets/commitVersionSheet.tsx b/ui/components/prompts/sheets/commitVersionSheet.tsx index dab3a5a724..3bd6839e03 100644 --- a/ui/components/prompts/sheets/commitVersionSheet.tsx +++ b/ui/components/prompts/sheets/commitVersionSheet.tsx @@ -1,13 +1,18 @@ import { Button } from '@/components/ui/button' +import { Checkbox } from '@/components/ui/checkbox' import { Input } from '@/components/ui/input' import { Label } from '@/components/ui/label' +import { ScrollArea } from '@/components/ui/scrollArea' import { Sheet, SheetContent, SheetDescription, SheetFooter, SheetHeader, SheetTitle } from '@/components/ui/sheet' +import { Message, MessageType } from '@/lib/message' +import { Markdown } from '@/components/ui/markdown' import { getErrorMessage } from '@/lib/store' import { useCommitSessionMutation } from '@/lib/store/apis/promptsApi' -import { PromptSession } from '@/lib/types/prompts' -import { useEffect } from 'react' +import { PromptSession, PromptSessionMessage } from '@/lib/types/prompts' +import { useCallback, useEffect, useMemo, useState } from 'react' import { useForm } from 'react-hook-form' import { toast } from 'sonner' +import { cn } from '@/lib/utils' interface CommitVersionFormData { commitMessage: string @@ -20,8 +25,49 @@ interface CommitVersionSheetProps { onCommitted: (versionId: number) => void } +function MessagePreview({ sessionMessage, selected, onToggle }: { + sessionMessage: PromptSessionMessage + selected: boolean + onToggle: () => void +}) { + const msg = useMemo(() => Message.deserialize(sessionMessage.message), [sessionMessage.message]) + const role = msg.role + const content = msg.content + const hasToolCalls = msg.type === MessageType.CompletionResult && msg.toolCalls && msg.toolCalls.length > 0 + + return ( + + ) +} + export function CommitVersionSheet({ open, onOpenChange, session, onCommitted }: CommitVersionSheetProps) { const [commitSession, { isLoading }] = useCommitSessionMutation() + const [selectedIndices, setSelectedIndices] = useState>(new Set()) const { register, @@ -32,18 +78,54 @@ export function CommitVersionSheet({ open, onOpenChange, session, onCommitted }: defaultValues: { commitMessage: '' }, }) + // Reset form and select only the first message when sheet opens useEffect(() => { if (open) { reset({ commitMessage: '' }) + setSelectedIndices(new Set(session.messages.length > 0 ? [0] : [])) } - }, [open, reset]) + }, [open, reset, session?.messages?.length]) + + const toggleMessage = useCallback((index: number) => { + setSelectedIndices(prev => { + const next = new Set(prev) + if (next.has(index)) { + next.delete(index) + } else { + next.add(index) + } + return next + }) + }, []) + + const allSelected = selectedIndices.size === session.messages.length + + const toggleAll = useCallback(() => { + if (allSelected) { + setSelectedIndices(new Set()) + } else { + setSelectedIndices(new Set(session.messages.map((_, i) => i))) + } + }, [allSelected, session.messages]) async function onSubmit(data: CommitVersionFormData) { + if (selectedIndices.size === 0) { + toast.error('Please select at least one message to commit') + return + } try { + const sortedIndices = Array.from(selectedIndices).sort((a, b) => a - b) + const commitData: { commit_message: string; message_indices?: number[] } = { + commit_message: data.commitMessage.trim(), + } + // Only send message_indices if not all messages are selected + if (!allSelected) { + commitData.message_indices = sortedIndices + } const result = await commitSession({ id: session.id, promptId: session.prompt_id, - data: { commit_message: data.commitMessage.trim() }, + data: commitData, }).unwrap() toast.success('Version committed') reset() @@ -58,16 +140,43 @@ export function CommitVersionSheet({ open, onOpenChange, session, onCommitted }: return ( - { e.preventDefault(); document.getElementById("commitMessage")?.focus(); }}> -
+ { e.preventDefault(); document.getElementById("commitMessage")?.focus(); }}> + Commit as Version - Create a new immutable version from the current session. Versions cannot be modified after creation. + Select the messages to include in this version. Uncheck any messages you want to exclude. -
+ {/* Messages selection - scrollable */} +
+
+ + +
+ +
+ {session.messages.map((sessionMsg, index) => ( + toggleMessage(index)} + /> + ))} +
+
+
+ + {/* Commit message + CTAs - always visible at bottom */} +
{errors.commitMessage.message}

) : (

- Describe what changed in this version (e.g., "Added error handling instructions") + Describe what changed in this version (e.g., "Added error handling instructions")

)}
-
- - - - + + + + +
diff --git a/ui/lib/types/prompts.ts b/ui/lib/types/prompts.ts index ce067dae29..9a90878f87 100644 --- a/ui/lib/types/prompts.ts +++ b/ui/lib/types/prompts.ts @@ -243,6 +243,7 @@ export interface DeleteSessionResponse { export interface CommitSessionRequest { commit_message: string + message_indices?: number[] } export interface CommitSessionResponse {