diff --git a/core/bifrost.go b/core/bifrost.go index 3cf8415556..395949a2b8 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -50,6 +50,70 @@ type Bifrost struct { backgroundCtx context.Context // Shared background context for nil context handling } +// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks which plugins ran, and manages short-circuiting and error aggregation. +type PluginPipeline struct { + plugins []schemas.Plugin + logger schemas.Logger + + // Indices of plugins whose PreHook ran (for reverse PostHook) + preHookRan []int + // Errors from PreHooks and PostHooks + preHookErrors []error + postHookErrors []error +} + +// NewPluginPipeline creates a new pipeline for a given plugin slice and logger. +func NewPluginPipeline(plugins []schemas.Plugin, logger schemas.Logger) *PluginPipeline { + return &PluginPipeline{ + plugins: plugins, + logger: logger, + } +} + +// RunPreHooks executes PreHooks in order, tracks which ran, and returns the final request, any short-circuit response, and error. +func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.BifrostResponse, int) { + var resp *schemas.BifrostResponse + var err error + for i, plugin := range p.plugins { + req, resp, err = plugin.PreHook(ctx, req) + if err != nil { + p.preHookErrors = append(p.preHookErrors, err) + p.logger.Warn(fmt.Sprintf("Error in PreHook for plugin %s: %v", plugin.GetName(), err)) + } + p.preHookRan = append(p.preHookRan, i) + if resp != nil { + return req, resp, i + 1 // short-circuit: only plugins up to and including i ran + } + } + return req, nil, len(p.plugins) +} + +// RunPostHooks executes PostHooks in reverse order for the plugins whose PreHook ran. +// Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response). +// Returns the final response and error after all hooks. If both are set, error takes precedence unless a plugin clears it. +func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, count int) (*schemas.BifrostResponse, *schemas.BifrostError) { + var err error + for i := count - 1; i >= 0; i-- { + plugin := p.plugins[i] + resp, bifrostErr, err = plugin.PostHook(ctx, resp, bifrostErr) + if err != nil { + p.postHookErrors = append(p.postHookErrors, err) + p.logger.Warn(fmt.Sprintf("Error in PostHook for plugin %s: %v", plugin.GetName(), err)) + } + // If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that + // If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that + } + // Final logic: if both are set, error takes precedence, unless error is nil + if bifrostErr != nil { + if resp != nil && bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil { + // Defensive: treat as recovery if error is empty + return resp, nil + } + return resp, bifrostErr + } + return resp, nil +} + // createProviderFromProviderKey creates a new provider instance based on the provider key. // It returns an error if the provider is not supported. func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) (schemas.Provider, error) { @@ -517,51 +581,35 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte return nil, newBifrostError(err) } - var resp *schemas.BifrostResponse - var processedPluginCount int - for i, plugin := range bifrost.plugins { - req, resp, err = plugin.PreHook(&ctx, req) - if err != nil { - return nil, newBifrostError(err) - } - processedPluginCount = i + 1 - if resp != nil { - // Run post-hooks in reverse order for plugins that had PreHook executed - for j := processedPluginCount - 1; j >= 0; j-- { - resp, err = bifrost.plugins[j].PostHook(&ctx, resp) - if err != nil { - return nil, newBifrostError(err) - } - } - return resp, nil + pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger) + preReq, preResp, preCount := pipeline.RunPreHooks(&ctx, req) + if preResp != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, preResp, nil, preCount) + // If PostHooks recovered from error, return resp; if not, return error + if bifrostErr != nil { + return nil, bifrostErr } + return resp, nil } - - if req == nil { + if preReq == nil { return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } - // Get a ChannelMessage from the pool - msg := bifrost.getChannelMessage(*req, TextCompletionRequest) + msg := bifrost.getChannelMessage(*preReq, TextCompletionRequest) msg.Context = ctx - // Handle queue send with context and proper cleanup select { case queue <- *msg: // Message was sent successfully case <-ctx.Done(): - // Request was cancelled by caller bifrost.releaseChannelMessage(msg) return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") default: if bifrost.dropExcessRequests { - // Drop request immediately if configured to do so bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") return nil, newBifrostErrorFromMsg("request dropped: queue is full") } - - // If not dropping excess requests, wait with context if ctx == nil { ctx = bifrost.backgroundCtx } @@ -574,26 +622,25 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte } } - // Handle response var result *schemas.BifrostResponse select { case result = <-msg.Response: - // Run plugins in reverse order - for i := len(bifrost.plugins) - 1; i >= 0; i-- { - result, err = bifrost.plugins[i].PostHook(&ctx, result) - if err != nil { - bifrost.releaseChannelMessage(msg) - return nil, newBifrostError(err) - } + resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins)) + if bifrostErr != nil { + bifrost.releaseChannelMessage(msg) + return nil, bifrostErr } - case err := <-msg.Err: bifrost.releaseChannelMessage(msg) - return nil, &err + return resp, nil + case bifrostErrVal := <-msg.Err: + bifrostErrPtr := &bifrostErrVal + resp, bifrostErrPtr := pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins)) + bifrost.releaseChannelMessage(msg) + if bifrostErrPtr != nil { + return nil, bifrostErrPtr + } + return resp, nil } - - // Return message to pool - bifrost.releaseChannelMessage(msg) - return result, nil } // ChatCompletionRequest sends a chat completion request to the specified provider. @@ -654,50 +701,34 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte return nil, newBifrostError(err) } - var resp *schemas.BifrostResponse - var processedPluginCount int - for i, plugin := range bifrost.plugins { - req, resp, err = plugin.PreHook(&ctx, req) - if err != nil { - return nil, newBifrostError(err) - } - processedPluginCount = i + 1 - if resp != nil { - // Run post-hooks in reverse order for plugins that had PreHook executed - for j := processedPluginCount - 1; j >= 0; j-- { - resp, err = bifrost.plugins[j].PostHook(&ctx, resp) - if err != nil { - return nil, newBifrostError(err) - } - } - return resp, nil + pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger) + preReq, preResp, preCount := pipeline.RunPreHooks(&ctx, req) + if preResp != nil { + resp, bifrostErr := pipeline.RunPostHooks(&ctx, preResp, nil, preCount) + if bifrostErr != nil { + return nil, bifrostErr } + return resp, nil } - - if req == nil { + if preReq == nil { return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") } - // Get a ChannelMessage from the pool - msg := bifrost.getChannelMessage(*req, ChatCompletionRequest) + msg := bifrost.getChannelMessage(*preReq, ChatCompletionRequest) msg.Context = ctx - // Handle queue send with context and proper cleanup select { case queue <- *msg: // Message was sent successfully case <-ctx.Done(): - // Request was cancelled by caller bifrost.releaseChannelMessage(msg) return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space") default: if bifrost.dropExcessRequests { - // Drop request immediately if configured to do so bifrost.releaseChannelMessage(msg) bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false") return nil, newBifrostErrorFromMsg("request dropped: queue is full") } - // If not dropping excess requests, wait with context if ctx == nil { ctx = bifrost.backgroundCtx } @@ -710,26 +741,25 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte } } - // Handle response var result *schemas.BifrostResponse select { case result = <-msg.Response: - // Run plugins in reverse order - for i := len(bifrost.plugins) - 1; i >= 0; i-- { - result, err = bifrost.plugins[i].PostHook(&ctx, result) - if err != nil { - bifrost.releaseChannelMessage(msg) - return nil, newBifrostError(err) - } + resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins)) + if bifrostErr != nil { + bifrost.releaseChannelMessage(msg) + return nil, bifrostErr } - case err := <-msg.Err: bifrost.releaseChannelMessage(msg) - return nil, &err + return resp, nil + case bifrostErrVal := <-msg.Err: + bifrostErrPtr := &bifrostErrVal + resp, bifrostErrPtr := pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins)) + bifrost.releaseChannelMessage(msg) + if bifrostErrPtr != nil { + return nil, bifrostErrPtr + } + return resp, nil } - - // Return message to pool - bifrost.releaseChannelMessage(msg) - return result, nil } // Cleanup gracefully stops all workers when triggered. diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index e16d30cac5..755427d0e9 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -19,6 +19,9 @@ import "context" // No Plugin errors are returned to the caller, they are logged as warnings by the Bifrost instance. type Plugin interface { + // GetName returns the name of the plugin. + GetName() string + // PreHook is called before a request is processed by a provider. // It allows plugins to modify the request before it is sent to the provider. // The context parameter can be used to maintain state across plugin calls. @@ -27,9 +30,9 @@ type Plugin interface { PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error) // PostHook is called after a response is received from a provider. - // It allows plugins to modify the response before it is returned to the caller. - // Returns the modified response and any error that occurred during processing. - PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error) + // It allows plugins to modify the response/error before it is returned to the caller. + // Returns the modified response, bifrost error and any error that occurred during processing. + PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error) // Cleanup is called on bifrost shutdown. // It allows plugins to clean up any resources they have allocated.