diff --git a/core/bifrost.go b/core/bifrost.go index f77e6cf07e..3cf8415556 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -514,33 +514,31 @@ func (bifrost *Bifrost) TextCompletionRequest(ctx context.Context, req *schemas. func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } + return nil, newBifrostError(err) } - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) + var resp *schemas.BifrostResponse + var processedPluginCount int + for i, plugin := range bifrost.plugins { + req, resp, err = plugin.PreHook(&ctx, req) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, + return nil, newBifrostError(err) + } + processedPluginCount = i + 1 + if resp != nil { + // Run post-hooks in reverse order for plugins that had PreHook executed + for j := processedPluginCount - 1; j >= 0; j-- { + resp, err = bifrost.plugins[j].PostHook(&ctx, resp) + if err != nil { + return nil, newBifrostError(err) + } } + return resp, nil } } if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request after plugin hooks cannot be nil", - }, - } + return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } // Get a ChannelMessage from the pool @@ -554,23 +552,13 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte case <-ctx.Done(): // Request was cancelled by caller bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return 
nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") default: if bifrost.dropExcessRequests { // Drop request immediately if configured to do so bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request dropped: queue is full", - }, - } + return nil, newBifrostErrorFromMsg("request dropped: queue is full") } // If not dropping excess requests, wait with context @@ -582,12 +570,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte // Message was sent successfully case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") } } @@ -600,12 +583,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte result, err = bifrost.plugins[i].PostHook(&ctx, result) if err != nil { bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } + return nil, newBifrostError(err) } } case err := <-msg.Err: @@ -623,30 +601,15 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte // If the primary provider fails, it will try each fallback provider in order until one succeeds. 
func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas.BifrostRequest) (*schemas.BifrostResponse, *schemas.BifrostError) { if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request cannot be nil", - }, - } + return nil, newBifrostErrorFromMsg("bifrost request cannot be nil") } if req.Provider == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "provider is required", - }, - } + return nil, newBifrostErrorFromMsg("provider is required") } if req.Model == "" { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "model is required", - }, - } + return nil, newBifrostErrorFromMsg("model is required") } // Try the primary provider first @@ -688,33 +651,31 @@ func (bifrost *Bifrost) ChatCompletionRequest(ctx context.Context, req *schemas. func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx context.Context) (*schemas.BifrostResponse, *schemas.BifrostError) { queue, err := bifrost.getProviderQueue(req.Provider) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } + return nil, newBifrostError(err) } - for _, plugin := range bifrost.plugins { - req, err = plugin.PreHook(&ctx, req) + var resp *schemas.BifrostResponse + var processedPluginCount int + for i, plugin := range bifrost.plugins { + req, resp, err = plugin.PreHook(&ctx, req) if err != nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, + return nil, newBifrostError(err) + } + processedPluginCount = i + 1 + if resp != nil { + // Run post-hooks in reverse order for plugins that had PreHook executed + for j := processedPluginCount - 1; j >= 0; j-- { + resp, err = bifrost.plugins[j].PostHook(&ctx, resp) + if err != nil { + return nil, 
newBifrostError(err) + } } + return resp, nil } } if req == nil { - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "bifrost request after plugin hooks cannot be nil", - }, - } + return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } // Get a ChannelMessage from the pool @@ -728,23 +689,13 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte case <-ctx.Done(): // Request was cancelled by caller bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") default: if bifrost.dropExcessRequests { // Drop request immediately if configured to do so bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request dropped: queue is full", - }, - } + return nil, newBifrostErrorFromMsg("request dropped: queue is full") } // If not dropping excess requests, wait with context if ctx == nil { @@ -755,12 +706,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte // Message was sent successfully case <-ctx.Done(): bifrost.releaseChannelMessage(msg) - return nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: "request cancelled while waiting for queue space", - }, - } + return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") } } @@ -773,12 +719,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte result, err = bifrost.plugins[i].PostHook(&ctx, result) if err != nil { bifrost.releaseChannelMessage(msg) - return 
nil, &schemas.BifrostError{ - IsBifrostError: false, - Error: schemas.ErrorField{ - Message: err.Error(), - }, - } + return nil, newBifrostError(err) } } case err := <-msg.Err: @@ -794,7 +735,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte // Cleanup gracefully stops all workers when triggered. // It closes all request channels and waits for workers to exit. func (bifrost *Bifrost) Cleanup() { - bifrost.logger.Info("[BIFROST] Graceful Cleanup Initiated - Closing all request channels...") + bifrost.logger.Info("Graceful Cleanup Initiated - Closing all request channels...") // Close all provider queues to signal workers to stop for _, queue := range bifrost.requestQueues { @@ -805,4 +746,14 @@ func (bifrost *Bifrost) Cleanup() { for _, waitGroup := range bifrost.waitGroups { waitGroup.Wait() } + + // Cleanup plugins + for _, plugin := range bifrost.plugins { + err := plugin.Cleanup() + if err != nil { + bifrost.logger.Warn(fmt.Sprintf("Error cleaning up plugin: %s", err.Error())) + } + } + + bifrost.logger.Info("Graceful Cleanup Completed") } diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index c10adebf3e..e16d30cac5 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -16,15 +16,23 @@ import "context" // - Logging // - Monitoring +// Errors returned by Cleanup are not returned to the caller; they are logged as warnings by the Bifrost instance. Errors returned by PreHook or PostHook abort the request and are returned to the caller. + type Plugin interface { // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. // The context parameter can be used to maintain state across plugin calls. - // Returns the modified request and any error that occurred during processing. - PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) + // Returns the modified request, an optional response (if the plugin wants to short-circuit the provider call), and any error that occurred during processing. 
+ // If a response is returned, the provider call is skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order. + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error) // PostHook is called after a response is received from a provider. // It allows plugins to modify the response before it is returned to the caller. // Returns the modified response and any error that occurred during processing. PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) + + // Cleanup is called on bifrost shutdown. + // It allows plugins to clean up any resources they have allocated. + // Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance. + Cleanup() error } diff --git a/core/utils.go b/core/utils.go index be284fbddf..f9b336cf89 100644 --- a/core/utils.go +++ b/core/utils.go @@ -1,5 +1,30 @@ package bifrost +import schemas "github.com/maximhq/bifrost/core/schemas" + func Ptr[T any](v T) *T { return &v } + +// newBifrostError wraps a standard error into a BifrostError with IsBifrostError set to false. +// This helper function reduces code duplication when handling non-Bifrost errors. +func newBifrostError(err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: err.Error(), + Error: err, + }, + } +} + +// newBifrostErrorFromMsg creates a BifrostError with a custom message. +// This helper function is used for static error messages. 
+func newBifrostErrorFromMsg(message string) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + Error: schemas.ErrorField{ + Message: message, + }, + } +} diff --git a/docs/plugins.md b/docs/plugins.md index 0058484b5e..3e3bf285ac 100644 --- a/docs/plugins.md +++ b/docs/plugins.md @@ -19,15 +19,28 @@ Plugins in Bifrost follow a simple but powerful interface that allows them to in - Can add monitoring or logging - Executed in reverse order of PreHooks +> **Note**: PostHooks maintain symmetry with PreHooks. If a plugin returns a response in its PreHook (short-circuiting the provider call), only the PostHook methods of plugins that had their PreHook executed are called, in reverse order. This ensures proper request/response pairing for each plugin. + ## 2. Plugin Interface ```golang type Plugin interface { - // PreHook is called before a request is processed by a provider - PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) - - // PostHook is called after a response is received from a provider + // PreHook is called before a request is processed by a provider. + // It allows plugins to modify the request before it is sent to the provider. + // The context parameter can be used to maintain state across plugin calls. + // Returns the modified request, an optional response (if the plugin wants to short-circuit the provider call), and any error that occurred during processing. + // If a response is returned, the provider call is skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order. + PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error) + + // PostHook is called after a response is received from a provider. + // It allows plugins to modify the response before it is returned to the caller. + // Returns the modified response and any error that occurred during processing. 
PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) + + // Cleanup is called on bifrost shutdown. + // It allows plugins to clean up any resources they have allocated. + // Returns any error that occurred during cleanup, which will be logged as a warning by the Bifrost instance. + Cleanup() error } ``` @@ -40,15 +53,21 @@ type CustomPlugin struct { // Your plugin fields } -func (p *CustomPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) { +func (p *CustomPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error) { // Modify request or add custom logic - return req, nil + // Return nil for response to continue with provider call + return req, nil, nil } func (p *CustomPlugin) PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) { // Modify response or add custom logic return result, nil } + +func (p *CustomPlugin) Cleanup() error { + // Clean up any resources + return nil +} ``` ### Example: Rate Limiting Plugin @@ -64,16 +83,21 @@ func NewRateLimitPlugin(rps float64) *RateLimitPlugin { } } -func (p *RateLimitPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) { +func (p *RateLimitPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error) { if err := p.limiter.Wait(*ctx); err != nil { - return nil, err + return nil, nil, err } - return req, nil + return req, nil, nil } func (p *RateLimitPlugin) PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) { return result, nil } + +func (p *RateLimitPlugin) Cleanup() error { + // Rate limiter doesn't need cleanup + return nil +} ``` ### Example: Logging Plugin @@ -87,15 +111,20 @@ func NewLoggingPlugin(logger schemas.Logger) *LoggingPlugin { return &LoggingPlugin{logger: logger} } -func (p *LoggingPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, error) { +func (p 
*LoggingPlugin) PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error) { p.logger.Info(fmt.Sprintf("Request to %s with model %s", req.Provider, req.Model)) - return req, nil + return req, nil, nil } func (p *LoggingPlugin) PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) { p.logger.Info(fmt.Sprintf("Response from %s with %d tokens", result.Model, result.Usage.TotalTokens)) return result, nil } + +func (p *LoggingPlugin) Cleanup() error { + // Logger doesn't need cleanup + return nil +} ``` ## 4. Using Plugins @@ -231,15 +260,20 @@ client, err := bifrost.Init(schemas.BifrostConfig{ } } - func (p *YourPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, error) { + func (p *YourPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.BifrostResponse, error) { // Implementation - return req, nil + return req, nil, nil } func (p *YourPlugin) PostHook(ctx *context.Context, result *schemas.BifrostResponse) (*schemas.BifrostResponse, error) { // Implementation return result, nil } + + func (p *YourPlugin) Cleanup() error { + // Clean up any resources + return nil + } ``` Example `README.md`: