diff --git a/README.md b/README.md index b1cba0bcae..a00936be34 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Bifrost is a high-performance AI gateway that connects you to 10+ providers (Ope **What You Need** - Any AI provider API key (OpenAI, Anthropic, Bedrock, etc.) -- Node.js 18+ installed (or use Docker instead via [Docker installation](#using-bifrost-http-transport)) +- Node.js 18+ installed (or use Docker instead via [Docker installation](./docs/quickstart/http-transport.md)) - 20 seconds of your time ⏰ ### Using Bifrost HTTP Transport diff --git a/core/bifrost.go b/core/bifrost.go index c0de1d3889..d3d576e39b 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -485,14 +485,6 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi oldQueue := oldQueueValue.(chan ChannelMessage) - // Check if the provider has any keys (skip keyless providers) - if providerRequiresKey(providerKey) { - keys, err := bifrost.account.GetKeysForProvider(providerKey) - if err != nil || len(keys) == 0 { - return fmt.Errorf("failed to get keys for provider %s: %v", providerKey, err) - } - } - bifrost.logger.Debug(fmt.Sprintf("Gracefully stopping existing workers for provider %s", providerKey)) // Step 1: Create new queue with updated buffer size @@ -836,14 +828,6 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi return fmt.Errorf("failed to get config for provider: %v", err) } - // Check if the provider has any keys (skip keyless providers) - if providerRequiresKey(providerKey) { - keys, err := bifrost.account.GetKeysForProvider(providerKey) - if err != nil || len(keys) == 0 { - return fmt.Errorf("failed to get keys for provider: %v", err) - } - } - queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider bifrost.requestQueues.Store(providerKey, queue) @@ -1094,7 +1078,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan key := schemas.Key{} if providerRequiresKey(provider.GetProviderKey()) { - key, err = bifrost.selectKeyFromProviderForModel(provider.GetProviderKey(), req.Model) + key, err = bifrost.selectKeyFromProviderForModel(&req.Context, provider.GetProviderKey(), req.Model) if err != nil { bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err)) req.Err <- schemas.BifrostError{ @@ -1384,8 +1368,8 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) { // selectKeyFromProviderForModel selects an appropriate API key for a given provider and model. // It uses weighted random selection if multiple keys are available. -func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (schemas.Key, error) { - keys, err := bifrost.account.GetKeysForProvider(providerKey) +func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { + keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey) if err != nil { return schemas.Key{}, err } diff --git a/core/schemas/account.go b/core/schemas/account.go index 9d4121d09a..1572624e0c 100644 --- a/core/schemas/account.go +++ b/core/schemas/account.go @@ -1,6 +1,8 @@ // Package schemas defines the core schemas and types used by the Bifrost system. package schemas +import "context" + // Key represents an API key and its associated configuration for a provider. // It contains the key value, supported models, and a weight for load balancing. type Key struct { @@ -37,7 +39,10 @@ type Account interface { // GetKeysForProvider returns the API keys configured for a specific provider. // The keys include their values, supported models, and weights for load balancing. - GetKeysForProvider(providerKey ModelProvider) ([]Key, error) + // The context can carry data from any source that sets values before the Bifrost request, + // including but not limited to plugin pre-hooks, application logic, or any in app middleware sharing the context. + // This enables dynamic key selection based on any context values present during the request. + GetKeysForProvider(ctx *context.Context, providerKey ModelProvider) ([]Key, error) // GetConfigForProvider returns the configuration for a specific provider. // This includes network settings, authentication details, and other provider-specific diff --git a/docs/quickstart/go-package.md b/docs/quickstart/go-package.md index 177c4a0bcd..5c7da476a7 100644 --- a/docs/quickstart/go-package.md +++ b/docs/quickstart/go-package.md @@ -39,7 +39,7 @@ func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { return []schemas.ModelProvider{schemas.OpenAI}, nil } -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.OpenAI { return []schemas.Key{{ Value: os.Getenv("OPENAI_API_KEY"), @@ -119,7 +119,7 @@ func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { } // Update GetKeysForProvider to handle both providers -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { switch provider { case schemas.OpenAI: return []schemas.Key{{ diff --git a/docs/usage/go-package/account.md b/docs/usage/go-package/account.md index 535b334bcc..06cc627818 100644 --- a/docs/usage/go-package/account.md +++ b/docs/usage/go-package/account.md @@ -17,7 +17,7 @@ The Account interface is your configuration provider that tells Bifrost: ```go type Account interface { GetConfiguredProviders() ([]schemas.ModelProvider, error) - GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error) + GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) } ``` @@ -34,6 +34,7 @@ Perfect for getting started or simple use cases: package main import ( + "context" "fmt" "os" "github.com/maximhq/bifrost/core/schemas" @@ -45,7 +46,7 @@ func (a *SimpleAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error return []schemas.ModelProvider{schemas.OpenAI}, nil } -func (a *SimpleAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *SimpleAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.OpenAI { apiKey := os.Getenv("OPENAI_API_KEY") if apiKey == "" { @@ -110,7 +111,7 @@ func (a *MultiProviderAccount) GetConfiguredProviders() ([]schemas.ModelProvider return providers, nil } -func (a *MultiProviderAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MultiProviderAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { switch provider { case schemas.OpenAI: return []schemas.Key{{ @@ -217,7 +218,7 @@ func (a *MultiProviderAccount) GetConfigForProvider(provider schemas.ModelProvid Distribute requests across multiple API keys for higher rate limits: ```go -func (a *AdvancedAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *AdvancedAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.OpenAI { return []schemas.Key{ { @@ -236,6 +237,98 @@ func (a *AdvancedAccount) GetKeysForProvider(provider schemas.ModelProvider) ([] } ``` +### **Plugin Context Usage** + +Leverage plugin pre-hook data for dynamic key selection: + +```go +type ContextAwareAccount struct { + standardKeys map[schemas.ModelProvider][]schemas.Key + premiumKeys map[schemas.ModelProvider][]schemas.Key + regionKeys map[string][]schemas.Key +} + +func (a *ContextAwareAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + // Early validation + standardKeys, ok := a.standardKeys[provider] + if !ok { + return nil, fmt.Errorf("provider %s not configured", provider) + } + + // No context means use standard keys + if ctx == nil { + return standardKeys, nil + } + + // Example: Access control based on user role + if userRole, ok := (*ctx).Value("user_role").(string); ok { + switch userRole { + case "premium": + if premiumKeys, ok := a.premiumKeys[provider]; ok { + return premiumKeys, nil + } + } + } + + // Example: Geographic routing + if region, ok := (*ctx).Value("geo_region").(string); ok { + if regionKeys, ok := a.regionKeys[region]; ok { + return regionKeys, nil + } + } + + // Example: Custom routing based on request type + if reqType, ok := (*ctx).Value("request_type").(string); ok { + switch reqType { + case "streaming": + return []schemas.Key{{ + Value: os.Getenv("DEDICATED_STREAMING_KEY"), + Models: []string{"gpt-4o-mini"}, + Weight: 1.0, + }}, nil + case "batch": + return []schemas.Key{{ + Value: os.Getenv("BATCH_PROCESSING_KEY"), + Models: []string{"gpt-4o"}, + Weight: 1.0, + }}, nil + } + } + + // Example: Rate limit management + if quota, ok := (*ctx).Value("remaining_quota").(int); ok { + if quota < 100 { + // Switch to backup keys when quota is low + return []schemas.Key{{ + Value: os.Getenv("BACKUP_API_KEY"), + Models: []string{"gpt-4o-mini"}, + Weight: 1.0, + }}, nil + } + } + + return standardKeys, nil +} +``` + +This implementation shows how to: +- Use plugin-set context data for dynamic key selection +- Implement role-based access control +- Handle geographic routing requirements +- Support request type-specific key allocation +- Manage rate limits and quotas + +Common context values set by plugins: +- `user_role`: User permission level +- `geo_region`: Geographic location +- `request_type`: Type of request (streaming, batch, etc.) +- `remaining_quota`: Rate limit tracking +- `request_priority`: Priority level +- `client_id`: Client identifier +- `custom_routing`: Custom routing rules + +> **💡 Tip:** Plugins can set any context values during their pre-hook phase. Document the expected context keys and their format to help plugin developers integrate with your key selection logic. + ### **Custom Network Settings** Optimize timeouts and retries for different providers: @@ -354,7 +447,7 @@ type DatabaseAccount struct { db *sql.DB } -func (a *DatabaseAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *DatabaseAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { rows, err := a.db.Query(` SELECT api_key, models, weight FROM provider_keys @@ -403,7 +496,7 @@ apiKey := "sk-..." // Never do this! ### **Error Handling** ```go -func (a *Account) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *Account) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { apiKey := os.Getenv("OPENAI_API_KEY") if apiKey == "" { return nil, fmt.Errorf("OPENAI_API_KEY not configured") @@ -442,7 +535,7 @@ func TestAccount(t *testing.T) { assert.Contains(t, providers, schemas.OpenAI) // Test key retrieval - keys, err := account.GetKeysForProvider(schemas.OpenAI) + keys, err := account.GetKeysForProvider(context.Background(), schemas.OpenAI) assert.NoError(t, err) assert.Len(t, keys, 1) assert.Equal(t, "sk-test-key", keys[0].Value) diff --git a/docs/usage/go-package/schemas.md b/docs/usage/go-package/schemas.md index e613aa804a..616dd0b0e6 100644 --- a/docs/usage/go-package/schemas.md +++ b/docs/usage/go-package/schemas.md @@ -385,7 +385,11 @@ Provider configuration and key management: ```go type Account interface { GetConfiguredProviders() ([]ModelProvider, error) - GetKeysForProvider(ModelProvider) ([]Key, error) + // GetKeysForProvider receives a context that can contain data from any source that sets + // values before the Bifrost request. This includes application code, middleware, plugin + // pre-hooks, or any other source. Implementations can use this context data to make + // dynamic key selection decisions based on any values present during the request. + GetKeysForProvider(ctx *context.Context, providerKey ModelProvider) ([]Key, error) GetConfigForProvider(ModelProvider) (*ProviderConfig, error) } @@ -402,15 +406,36 @@ func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { }, nil } -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +// Example of context-aware key selection +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { switch provider { case schemas.OpenAI: + // Check for any context values + if ctx != nil { + // Example: Value set by application code + if userRole, ok := (*ctx).Value("user_role").(string); ok && userRole == "premium" { + return []schemas.Key{{ + Value: os.Getenv("OPENAI_PREMIUM_KEY"), + Models: []string{"gpt-4o"}, + Weight: 1.0, + }}, nil + } + + // Example: Value set by middleware + if region, ok := (*ctx).Value("geo_region").(string); ok && region == "eu" { + return []schemas.Key{{ + Value: os.Getenv("OPENAI_EU_KEY"), + Models: []string{"gpt-4o"}, + Weight: 1.0, + }}, nil + } + } + // Default key if no special context return []schemas.Key{{ Value: os.Getenv("OPENAI_API_KEY"), Models: []string{"gpt-4o-mini", "gpt-4o"}, Weight: 1.0, }}, nil - // ... other providers } return nil, fmt.Errorf("provider not supported") } diff --git a/docs/usage/key-management.md b/docs/usage/key-management.md index 857b3187b6..73fd3977a9 100644 --- a/docs/usage/key-management.md +++ b/docs/usage/key-management.md @@ -29,8 +29,20 @@ Advanced API key management with weighted distribution, automatic rotation, and
🔧 Go Package Usage +The `GetKeysForProvider` method allows you to implement custom key selection logic for each provider. The method receives a context parameter that carries data set by plugin pre-hooks, enabling dynamic key selection based on plugin-defined criteria. + +For example, plugins can set request metadata, user preferences, or routing rules in the context during their pre-hook phase. Your key management implementation can then access this data to make informed decisions about which keys to return. This is particularly useful for scenarios like: + +- Route requests to specific API keys based on user roles or permissions +- Implement key rotation based on request patterns +- Apply custom rate limiting or quota management +- Select keys based on geographical routing rules +- Use different keys for different types of requests or model configurations + +Here's a basic example implementation: + ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { switch provider { case schemas.OpenAI: return []schemas.Key{ @@ -96,6 +108,118 @@ export ANTHROPIC_API_KEY="sk-ant-..." --- +### Context-Aware Key Selection + +
+🔧 Go Package - Context Usage + +The `GetKeysForProvider` method receives a context that can contain data from any source that sets values before the Bifrost request. This includes plugin pre-hooks, application logic, middleware, or direct context manipulation. Here's an example that demonstrates various context-based key selection strategies: + +```go +type ContextAwareAccount struct { + standardKeys []schemas.Key + premiumKeys []schemas.Key +} + +func (a *ContextAwareAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { + if provider != schemas.OpenAI { + return nil, fmt.Errorf("provider not supported") + } + + // Access context values from any source + if ctx != nil { + // Example: Application-set user role + if userRole, ok := (*ctx).Value("user_role").(string); ok { + switch userRole { + case "premium": + return a.premiumKeys, nil + case "standard": + return a.standardKeys, nil + } + } + + // Example: Middleware-set geographic region + if region, ok := (*ctx).Value("geo_region").(string); ok { + // Return region-specific keys + switch region { + case "eu": + return []schemas.Key{{ + Value: os.Getenv("OPENAI_EU_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4o"}, + Weight: 1.0, + }}, nil + case "us": + return []schemas.Key{{ + Value: os.Getenv("OPENAI_US_KEY"), + Models: []string{"gpt-4o-mini", "gpt-4o"}, + Weight: 1.0, + }}, nil + } + } + + // Example: Plugin-set request priority + if priority, ok := (*ctx).Value("request_priority").(string); ok { + switch priority { + case "high": + return []schemas.Key{{ + Value: os.Getenv("OPENAI_DEDICATED_KEY"), + Models: []string{"gpt-4o"}, + Weight: 1.0, + }}, nil + } + } + + // Example: Direct context value from application code + if customKey, ok := (*ctx).Value("custom_api_key").(string); ok { + return []schemas.Key{{ + Value: customKey, + Models: []string{"gpt-4o-mini", "gpt-4o"}, + Weight: 1.0, + }}, nil + } + } + + // Default to standard keys if no context or matching criteria + return a.standardKeys, nil +} +``` + +This implementation demonstrates: +- Reading context values set by various sources +- Application-level user role based selection +- Geographic routing from middleware +- Priority-based selection from plugins +- Custom key injection through direct context manipulation + +You can set context values in several ways: + +```go +// Direct in your application code +ctx := context.WithValue(context.Background(), "user_role", "premium") + +// In middleware +func MyMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), "geo_region", "eu") + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// In a plugin's PreHook +func (p *MyPlugin) PreHook(ctx *context.Context, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.PluginShortCircuit, error) { + *ctx = context.WithValue(*ctx, "request_priority", "high") + return req, nil, nil +} + +// When making a Bifrost request +ctx := context.WithValue(context.Background(), "custom_api_key", "sk-...") +response, err := client.ChatCompletionRequest(ctx, request) +``` + +
+ +--- + ## 🔄 Key Distribution Strategies ### Load Balancing Strategy @@ -106,7 +230,7 @@ Distribute requests evenly across multiple keys for maximum throughput: 🔧 Go Package - Equal Distribution ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.OpenAI { return []schemas.Key{ { @@ -190,7 +314,7 @@ Use premium keys for expensive models, standard keys for cheaper models: 🔧 Go Package - Tiered Strategy ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.OpenAI { return []schemas.Key{ // Standard keys for cheap models @@ -269,7 +393,7 @@ Route traffic based on key priority and reliability: 🔧 Go Package - Priority Strategy ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.OpenAI { return []schemas.Key{ // Primary key (highest priority) @@ -338,7 +462,7 @@ func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schema 🔧 Go Package - Cross-Provider Keys ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { switch provider { case schemas.OpenAI: return []schemas.Key{ @@ -447,7 +571,7 @@ type DynamicAccount struct { keys map[schemas.ModelProvider][]schemas.Key } -func (a *DynamicAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *DynamicAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { // Rotate keys every hour if time.Since(a.lastRotation) > a.keyRotationInterval { a.rotateKeys() diff --git a/docs/usage/providers.md b/docs/usage/providers.md index 3cb5d79d5b..71ecc08609 100644 --- a/docs/usage/providers.md +++ b/docs/usage/providers.md @@ -44,7 +44,7 @@ func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) { return []schemas.ModelProvider{schemas.OpenAI}, nil } -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { switch provider { case schemas.OpenAI: return []schemas.Key{ @@ -153,7 +153,7 @@ Configure multiple providers for fallbacks and load distribution. 🔧 Go Package - Multi-Provider ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { switch provider { case schemas.OpenAI: return []schemas.Key{ @@ -314,7 +314,7 @@ echo "$response" **Go Package:** ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.Azure { return []schemas.Key{ { @@ -368,7 +368,7 @@ func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schema **Go Package:** ```go -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.Vertex { return []schemas.Key{ { @@ -427,7 +427,7 @@ func (a *MyAccount) GetConfigForProvider(provider schemas.ModelProvider) (*schem return &schemas.ProviderConfig{}, nil } -func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) { +func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) { if provider == schemas.Ollama { return []schemas.Key{ {