Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ type Bifrost struct {
backgroundCtx context.Context // Shared background context for nil context handling
}

// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks which plugins ran, and manages short-circuiting and error aggregation.
// PluginPipeline encapsulates the execution of plugin PreHooks and PostHooks, tracks how many plugins ran, and manages short-circuiting and error aggregation.
type PluginPipeline struct {
plugins []schemas.Plugin
logger schemas.Logger

// Indices of plugins whose PreHook ran (for reverse PostHook)
preHookRan []int
// Number of PreHooks that were executed (used to determine which PostHooks to run in reverse order)
executedPreHooks int
// Errors from PreHooks and PostHooks
preHookErrors []error
postHookErrors []error
Expand All @@ -70,7 +70,7 @@ func NewPluginPipeline(plugins []schemas.Plugin, logger schemas.Logger) *PluginP
}
}

// RunPreHooks executes PreHooks in order, tracks which ran, and returns the final request, any short-circuit response, and error.
// RunPreHooks executes PreHooks in order, tracks how many ran, and returns the final request, any short-circuit response, and the count.
func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.BifrostResponse, int) {
var resp *schemas.BifrostResponse
var err error
Expand All @@ -80,18 +80,25 @@ func (p *PluginPipeline) RunPreHooks(ctx *context.Context, req *schemas.BifrostR
p.preHookErrors = append(p.preHookErrors, err)
p.logger.Warn(fmt.Sprintf("Error in PreHook for plugin %s: %v", plugin.GetName(), err))
}
p.preHookRan = append(p.preHookRan, i)
p.executedPreHooks = i + 1
if resp != nil {
return req, resp, i + 1 // short-circuit: only plugins up to and including i ran
return req, resp, p.executedPreHooks // short-circuit: only plugins up to and including i ran
}
}
return req, nil, len(p.plugins)
return req, nil, p.executedPreHooks
}

// RunPostHooks executes PostHooks in reverse order for the plugins whose PreHook ran.
// Accepts the response and error, and allows plugins to transform either (e.g., recover from error, or invalidate a response).
// Returns the final response and error after all hooks. If both are set, error takes precedence unless a plugin clears it.
// Returns the final response and error after all hooks. If both are set, error takes precedence unless error is nil.
func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, count int) (*schemas.BifrostResponse, *schemas.BifrostError) {
// Defensive: ensure count is within valid bounds
if count < 0 {
count = 0
}
if count > len(p.plugins) {
count = len(p.plugins)
}
var err error
for i := count - 1; i >= 0; i-- {
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
plugin := p.plugins[i]
Expand All @@ -105,7 +112,8 @@ func (p *PluginPipeline) RunPostHooks(ctx *context.Context, resp *schemas.Bifros
}
// Final logic: if both are set, error takes precedence, unless error is nil
if bifrostErr != nil {
if resp != nil && bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil {
if resp != nil && bifrostErr.StatusCode == nil && bifrostErr.Error.Type == nil &&
bifrostErr.Error.Message == "" && bifrostErr.Error.Error == nil {
// Defensive: treat as recovery if error is empty
return resp, nil
}
Expand Down Expand Up @@ -623,6 +631,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
}

var result *schemas.BifrostResponse
var resp *schemas.BifrostResponse
select {
case result = <-msg.Response:
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
Expand All @@ -634,7 +643,7 @@ func (bifrost *Bifrost) tryTextCompletion(req *schemas.BifrostRequest, ctx conte
return resp, nil
case bifrostErrVal := <-msg.Err:
bifrostErrPtr := &bifrostErrVal
resp, bifrostErrPtr := pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
bifrost.releaseChannelMessage(msg)
if bifrostErrPtr != nil {
return nil, bifrostErrPtr
Expand Down Expand Up @@ -742,6 +751,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
}

var result *schemas.BifrostResponse
var resp *schemas.BifrostResponse
select {
case result = <-msg.Response:
resp, bifrostErr := pipeline.RunPostHooks(&ctx, result, nil, len(bifrost.plugins))
Expand All @@ -753,7 +763,7 @@ func (bifrost *Bifrost) tryChatCompletion(req *schemas.BifrostRequest, ctx conte
return resp, nil
case bifrostErrVal := <-msg.Err:
bifrostErrPtr := &bifrostErrVal
resp, bifrostErrPtr := pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
resp, bifrostErrPtr = pipeline.RunPostHooks(&ctx, nil, bifrostErrPtr, len(bifrost.plugins))
bifrost.releaseChannelMessage(msg)
if bifrostErrPtr != nil {
return nil, bifrostErrPtr
Expand Down
23 changes: 16 additions & 7 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ type AnthropicImageContent struct {

// AnthropicProvider implements the Provider interface for Anthropic's Claude API.
type AnthropicProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
baseURL string // Base URL for the provider
apiVersion string // API version for the provider
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// anthropicChatResponsePool provides a pool for Anthropic chat response objects.
Expand Down Expand Up @@ -145,9 +147,16 @@ func NewAnthropicProvider(config *schemas.ProviderConfig, logger schemas.Logger)
// Configure proxy if provided
client = configureProxy(client, config.ProxyConfig, logger)

baseURL := strings.TrimRight(config.NetworkConfig.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}

return &AnthropicProvider{
logger: logger,
client: client,
logger: logger,
client: client,
baseURL: baseURL,
apiVersion: "2023-06-01",
}
}

Expand Down Expand Up @@ -198,7 +207,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.Header.Set("x-api-key", key)
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Set("anthropic-version", provider.apiVersion)
req.SetBody(jsonData)

// Send the request
Expand Down Expand Up @@ -238,7 +247,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke
"prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text),
}, preparedParams)

responseBody, err := provider.completeRequest(ctx, requestBody, "https://api.anthropic.com/v1/complete", key)
responseBody, err := provider.completeRequest(ctx, requestBody, provider.baseURL+"/v1/complete", key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -294,7 +303,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke
"messages": formattedMessages,
}, preparedParams)

responseBody, err := provider.completeRequest(ctx, requestBody, "https://api.anthropic.com/v1/messages", key)
responseBody, err := provider.completeRequest(ctx, requestBody, provider.baseURL+"/v1/messages", key)
if err != nil {
return nil, err
}
Expand Down
17 changes: 12 additions & 5 deletions core/providers/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ type CohereError struct {

// CohereProvider implements the Provider interface for Cohere.
type CohereProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
baseURL string // Base URL for the provider
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// NewCohereProvider creates a new Cohere provider instance.
Expand All @@ -117,9 +118,15 @@ func NewCohereProvider(config *schemas.ProviderConfig, logger schemas.Logger) *C
bifrostResponsePool.Put(&schemas.BifrostResponse{})
}

baseURL := strings.TrimRight(config.NetworkConfig.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.cohere.ai"
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

return &CohereProvider{
logger: logger,
client: client,
logger: logger,
client: client,
baseURL: baseURL,
}
}

Expand Down Expand Up @@ -339,7 +346,7 @@ func (provider *CohereProvider) ChatCompletion(ctx context.Context, model, key s
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)

req.SetRequestURI("https://api.cohere.ai/v1/chat")
req.SetRequestURI(provider.baseURL + "/v1/chat")
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.Header.Set("Authorization", "Bearer "+key)
Expand Down
18 changes: 13 additions & 5 deletions core/providers/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package providers
import (
"context"
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -64,8 +65,9 @@ func releaseOpenAIResponse(resp *OpenAIResponse) {

// OpenAIProvider implements the Provider interface for OpenAI's API.
type OpenAIProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
baseURL string // Base URL for the provider
}

// NewOpenAIProvider creates a new OpenAI provider instance.
Expand All @@ -89,9 +91,15 @@ func NewOpenAIProvider(config *schemas.ProviderConfig, logger schemas.Logger) *O
// Configure proxy if provided
client = configureProxy(client, config.ProxyConfig, logger)

baseURL := strings.TrimRight(config.NetworkConfig.BaseURL, "/")
if baseURL == "" {
baseURL = "https://api.openai.com"
}

Comment thread
Pratham-Mishra04 marked this conversation as resolved.
return &OpenAIProvider{
logger: logger,
client: client,
logger: logger,
client: client,
baseURL: baseURL,
}
}

Expand Down Expand Up @@ -139,7 +147,7 @@ func (provider *OpenAIProvider) ChatCompletion(ctx context.Context, model, key s
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)

req.SetRequestURI("https://api.openai.com/v1/chat/completions")
req.SetRequestURI(provider.baseURL + "/v1/chat/completions")
req.Header.SetMethod("POST")
req.Header.SetContentType("application/json")
req.Header.Set("Authorization", "Bearer "+key)
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
Expand Down
21 changes: 15 additions & 6 deletions core/schemas/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,22 @@ import "context"
// User can provide multiple plugins in the BifrostConfig.
// PreHooks are executed in the order they are registered.
// PostHooks are executed in the reverse order of PreHooks.

//
// PreHooks and PostHooks can be used to implement custom logic, such as:
// - Rate limiting
// - Caching
// - Logging
// - Monitoring

// No Plugin errors are returned to the caller, they are logged as warnings by the Bifrost instance.
//
// Plugin error handling:
// - No Plugin errors are returned to the caller; they are logged as warnings by the Bifrost instance.
// - PreHook and PostHook can both modify the request/response and the error. Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error).
// - PostHook is always called with both the current response and error, and should handle either being nil.
// - Only truly empty errors (no message, no error, no status code, no type) are treated as recoveries by the pipeline.
// - If a PreHook returns a response, the provider call is skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order.
// - The plugin pipeline ensures symmetry: for every PreHook executed, the corresponding PostHook will be called in reverse order.
//
// Plugin authors should ensure their hooks are robust to both response and error being nil, and should not assume either is always present.

type Plugin interface {
// GetName returns the name of the plugin.
Expand All @@ -29,9 +37,10 @@ type Plugin interface {
// If a response is returned, the provider call is skipped and only the PostHook methods of plugins that had their PreHook executed are called in reverse order.
PreHook(ctx *context.Context, req *BifrostRequest) (*BifrostRequest, *BifrostResponse, error)

// PostHook is called after a response is received from a provider.
// It allows plugins to modify the response/error before it is returned to the caller.
// Returns the modified response, bifrost error and any error that occurred during processing.
// PostHook is called after a response is received from a provider or a PreHook short-circuit.
// It allows plugins to modify the response and/or error before it is returned to the caller.
// Plugins can recover from errors (set error to nil and provide a response), or invalidate a response (set response to nil and provide an error).
// Returns the modified response, bifrost error, and any error that occurred during processing.
PostHook(ctx *context.Context, result *BifrostResponse, err *BifrostError) (*BifrostResponse, *BifrostError, error)

// Cleanup is called on bifrost shutdown.
Expand Down
2 changes: 2 additions & 0 deletions core/schemas/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const (

// NetworkConfig represents the network configuration for provider connections.
type NetworkConfig struct {
// BaseURL is only supported for OpenAI, Anthropic and Cohere providers
BaseURL string `json:"base_url,omitempty"` // Base URL for the provider (optional)
DefaultRequestTimeoutInSeconds int `json:"default_request_timeout_in_seconds"` // Default timeout for requests
MaxRetries int `json:"max_retries"` // Maximum number of retries
RetryBackoffInitial time.Duration `json:"retry_backoff_initial"` // Initial backoff duration
Expand Down
Loading