Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 109 additions & 79 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,70 @@ type Bifrost struct {
backgroundCtx context.Context // Shared background context for nil context handling
}

// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks which plugins ran, and manages short-circuiting and error aggregation.
type PluginPipeline struct {
plugins []schemas.Plugin
logger schemas.Logger

// Indices of plugins whose PreHook ran (for reverse PostHook)
preHookRan []int
// Errors from PreHooks and PostHooks
preHookErrors []error
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

usually we can wrap it under a separate struct for better abstraction. Also we can create pool of these objects to avoid any runtime allocations

type Errs struct {
    Errors []errors
}

func (e *Errs) Print() {
}

func (e *Errs) Add(e error) {
}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I've made changes to this, I pushed it to the next pr by mistake

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add sync pools, even I was thinking about it

postHookErrors []error
}

// NewPluginPipeline creates a new pipeline for a given plugin slice and logger.
func NewPluginPipeline(plugins []schemas.Plugin, logger schemas.Logger) *PluginPipeline {
return &PluginPipeline{
plugins: plugins,
logger: logger,
}
}

// RunPreHooks executes PreHooks in order, tracks which ran, and returns the final request, any short-circuit response, and error.
func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.BifrostResponse, int) {
var resp *schemas.BifrostResponse
var err error
for i, plugin := range p.plugins {
req, resp, err = plugin.PreHook(ctx, req)
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
if err != nil {
p.preHookErrors = append(p.preHookErrors, err)
p.logger.Warn(fmt.Sprintf("Error in PreHook for plugin %s: %v", plugin.GetName(), err))
}
p.preHookRan = append(p.preHookRan, i)
if resp != nil {
return req, resp, i + 1 // short-circuit: only plugins up to and including i ran
}
}
return req, nil, len(p.plugins)
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// RunPostHooks executes PostHooks in reverse order for the plugins whose PreHook ran.
// Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response).
// Returns the final response and error after all hooks. If both are set, error takes precedence unless a plugin clears it.
func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, count int) (*schemas.BifrostResponse, *schemas.BifrostError) {
var err error
for i := count - 1; i >= 0; i-- {
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
plugin := p.plugins[i]
resp, bifrostErr, err = plugin.PostHook(ctx, resp, bifrostErr)
if err != nil {
p.postHookErrors = append(p.postHookErrors, err)
p.logger.Warn(fmt.Sprintf("Error in PostHook for plugin %s: %v", plugin.GetName(), err))
}
// If a plugin recovers from an error (sets bifrostErr to nil and sets resp), allow that
// If a plugin invalidates a response (sets resp to nil and sets bifrostErr), allow that
}
// Final logic: if both are set, error takes precedence, unless error is nil
if bifrostErr != nil {
if resp != nil && bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil {
// Defensive: treat as recovery if error is empty
return resp, nil
}
return resp, bifrostErr
}
return resp, nil
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}

// createProviderFromProviderKey creates a new provider instance based on the provider key.
// It returns an error if the provider is not supported.
func (bifrost *Bifrost) createProviderFromProviderKey(providerKey schemas.ModelProvider, config *schemas.ProviderConfig) (schemas.Provider, error) {
Expand Down Expand Up @@ -517,51 +581,35 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
return nil, newBifrostError(err)
}

var resp *schemas.BifrostResponse
var processedPluginCount int
for i, plugin := range bifrost.plugins {
req, resp, err = plugin.PreHook(&ctx, req)
if err != nil {
return nil, newBifrostError(err)
}
processedPluginCount = i + 1
if resp != nil {
// Run post-hooks in reverse order for plugins that had PreHook executed
for j := processedPluginCount - 1; j >= 0; j-- {
resp, err = bifrost.plugins[j].PostHook(&ctx, resp)
if err != nil {
return nil, newBifrostError(err)
}
}
return resp, nil
pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger)
preReq, preResp, preCount := pipeline.RunPreHooks(&ctx, req)
if preResp != nil {
resp, bifrostErr := pipeline.RunPostHooks(&ctx, preResp, nil, preCount)
// If PostHooks recovered from error, return resp; if not, return error
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}

if req == nil {
if preReq == nil {
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
}

// Get a ChannelMessage from the pool
msg := bifrost.getChannelMessage(*req, TextCompletionRequest)
msg := bifrost.getChannelMessage(*preReq, TextCompletionRequest)
msg.Context = ctx

// Handle queue send with context and proper cleanup
select {
case queue <- *msg:
// Message was sent successfully
case <-ctx.Done():
// Request was cancelled by caller
bifrost.releaseChannelMessage(msg)
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
default:
if bifrost.dropExcessRequests {
// Drop request immediately if configured to do so
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
}

// If not dropping excess requests, wait with context
if ctx == nil {
ctx = bifrost.backgroundCtx
}
Expand All @@ -574,26 +622,25 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
}
}

// Handle response
var result *schemas.BifrostResponse
select {
case result = <-msg.Response:
// Run plugins in reverse order
for i := len(bifrost.plugins) - 1; i >= 0; i-- {
result, err = bifrost.plugins[i].PostHook(&ctx, result)
if err != nil {
bifrost.releaseChannelMessage(msg)
return nil, newBifrostError(err)
}
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
if bifrostErr != nil {
bifrost.releaseChannelMessage(msg)
return nil, bifrostErr
}
case err := <-msg.Err:
bifrost.releaseChannelMessage(msg)
return nil, &err
return resp, nil
case bifrostErrVal := <-msg.Err:
bifrostErrPtr := &bifrostErrVal
resp, bifrostErrPtr := pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
bifrost.releaseChannelMessage(msg)
if bifrostErrPtr != nil {
return nil, bifrostErrPtr
}
return resp, nil
}

// Return message to pool
bifrost.releaseChannelMessage(msg)
return result, nil
}

// ChatCompletionRequest sends a chat completion request to the specified provider.
Expand Down Expand Up @@ -654,50 +701,34 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
return nil, newBifrostError(err)
}

var resp *schemas.BifrostResponse
var processedPluginCount int
for i, plugin := range bifrost.plugins {
req, resp, err = plugin.PreHook(&ctx, req)
if err != nil {
return nil, newBifrostError(err)
}
processedPluginCount = i + 1
if resp != nil {
// Run post-hooks in reverse order for plugins that had PreHook executed
for j := processedPluginCount - 1; j >= 0; j-- {
resp, err = bifrost.plugins[j].PostHook(&ctx, resp)
if err != nil {
return nil, newBifrostError(err)
}
}
return resp, nil
pipeline := NewPluginPipeline(bifrost.plugins, bifrost.logger)
preReq, preResp, preCount := pipeline.RunPreHooks(&ctx, req)
if preResp != nil {
resp, bifrostErr := pipeline.RunPostHooks(&ctx, preResp, nil, preCount)
if bifrostErr != nil {
return nil, bifrostErr
}
return resp, nil
}

if req == nil {
if preReq == nil {
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
}

// Get a ChannelMessage from the pool
msg := bifrost.getChannelMessage(*req, ChatCompletionRequest)
msg := bifrost.getChannelMessage(*preReq, ChatCompletionRequest)
msg.Context = ctx

// Handle queue send with context and proper cleanup
select {
case queue <- *msg:
// Message was sent successfully
case <-ctx.Done():
// Request was cancelled by caller
bifrost.releaseChannelMessage(msg)
return nil, newBifrostErrorFromMsg("request cancelled while waiting for queue space")
default:
if bifrost.dropExcessRequests {
// Drop request immediately if configured to do so
bifrost.releaseChannelMessage(msg)
bifrost.logger.Warn("Request dropped: queue is full, please increase the queue size or set dropExcessRequests to false")
return nil, newBifrostErrorFromMsg("request dropped: queue is full")
}
// If not dropping excess requests, wait with context
if ctx == nil {
ctx = bifrost.backgroundCtx
}
Expand All @@ -710,26 +741,25 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
}
}

// Handle response
var result *schemas.BifrostResponse
select {
case result = <-msg.Response:
// Run plugins in reverse order
for i := len(bifrost.plugins) - 1; i >= 0; i-- {
result, err = bifrost.plugins[i].PostHook(&ctx, result)
if err != nil {
bifrost.releaseChannelMessage(msg)
return nil, newBifrostError(err)
}
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
if bifrostErr != nil {
bifrost.releaseChannelMessage(msg)
return nil, bifrostErr
}
case err := <-msg.Err:
bifrost.releaseChannelMessage(msg)
return nil, &err
return resp, nil
case bifrostErrVal := <-msg.Err:
bifrostErrPtr := &bifrostErrVal
resp, bifrostErrPtr := pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
bifrost.releaseChannelMessage(msg)
if bifrostErrPtr != nil {
return nil, bifrostErrPtr
}
return resp, nil
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}

// Return message to pool
bifrost.releaseChannelMessage(msg)
return result, nil
}

// Cleanup gracefully stops all workers when triggered.
Expand Down
9 changes: 6 additions & 3 deletions core/schemas/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import "context"
// No Plugin errors are returned to the caller, they are logged as warnings by the Bifrost instance.

type Plugin interface {
// GetName returns the name of the plugin.
GetName() string

// PreHook is called before a request is processed by a provider.
// It allows plugins to modify the request before it is sent to the provider.
// The context parameter can be used to maintain state across plugin calls.
Expand All @@ -27,9 +30,9 @@ type Plugin interface {
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error)

// PostHook is called after a response is received from a provider.
// It allows plugins to modify the response before it is returned to the caller.
// Returns the modified response and any error that occurred during processing.
PostHook(ctx *context.Context, result *BifrostResponse) (*BifrostResponse, error)
// It allows plugins to modify the response/error before it is returned to the caller.
// Returns the modified response, bifrost error and any error that occurred during processing.
PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error)

// Cleanup is called on bifrost shutdown.
// It allows plugins to clean up any resources they have allocated.
Expand Down