Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Bifrost is a high-performance AI gateway that connects you to 10+ providers (Ope
**What You Need**

- Any AI provider API key (OpenAI, Anthropic, Bedrock, etc.)
- Node.js 18+ installed (or use Docker instead via [Docker installation](#using-bifrost-http-transport))
- Node.js 18+ installed (or use Docker instead via [Docker installation](./docs/quickstart/http-transport.md))
- 20 seconds of your time ⏰
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

### Using Bifrost HTTP Transport
Expand Down
22 changes: 3 additions & 19 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,14 +485,6 @@ func (bifrost *Bifrost) UpdateProviderConcurrency(providerKey schemas.ModelProvi

oldQueue := oldQueueValue.(chan ChannelMessage)

// Check if the provider has any keys (skip keyless providers)
if providerRequiresKey(providerKey) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil || len(keys) == 0 {
return fmt.Errorf("failed to get keys for provider %s: %v", providerKey, err)
}
}

bifrost.logger.Debug(fmt.Sprintf("Gracefully stopping existing workers for provider %s", providerKey))

// Step 1: Create new queue with updated buffer size
Expand Down Expand Up @@ -836,14 +828,6 @@ func (bifrost *Bifrost) prepareProvider(providerKey schemas.ModelProvider, confi
return fmt.Errorf("failed to get config for provider: %v", err)
}

// Check if the provider has any keys (skip keyless providers)
if providerRequiresKey(providerKey) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil || len(keys) == 0 {
return fmt.Errorf("failed to get keys for provider: %v", err)
}
}

queue := make(chan ChannelMessage, providerConfig.ConcurrencyAndBufferSize.BufferSize) // Buffered channel per provider

bifrost.requestQueues.Store(providerKey, queue)
Expand Down Expand Up @@ -1094,7 +1078,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan

key := schemas.Key{}
if providerRequiresKey(provider.GetProviderKey()) {
key, err = bifrost.selectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
key, err = bifrost.selectKeyFromProviderForModel(&req.Context, provider.GetProviderKey(), req.Model)
if err != nil {
bifrost.logger.Warn(fmt.Sprintf("Error selecting key for model %s: %v", req.Model, err))
req.Err <- schemas.BifrostError{
Expand Down Expand Up @@ -1384,8 +1368,8 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {

// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// It uses weighted random selection if multiple keys are available.
func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
keys, err := bifrost.account.GetKeysForProvider(ctx, providerKey)
if err != nil {
return schemas.Key{}, err
}
Expand Down
7 changes: 6 additions & 1 deletion core/schemas/account.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Package schemas defines the core schemas and types used by the Bifrost system.
package schemas

import "context"

// Key represents an API key and its associated configuration for a provider.
// It contains the key value, supported models, and a weight for load balancing.
type Key struct {
Expand Down Expand Up @@ -37,7 +39,10 @@ type Account interface {

// GetKeysForProvider returns the API keys configured for a specific provider.
// The keys include their values, supported models, and weights for load balancing.
GetKeysForProvider(providerKey ModelProvider) ([]Key, error)
// The context can carry data from any source that sets values before the Bifrost request,
// including but not limited to plugin pre-hooks, application logic, or any in app middleware sharing the context.
// This enables dynamic key selection based on any context values present during the request.
GetKeysForProvider(ctx *context.Context, providerKey ModelProvider) ([]Key, error)

// GetConfigForProvider returns the configuration for a specific provider.
// This includes network settings, authentication details, and other provider-specific
Expand Down
4 changes: 2 additions & 2 deletions docs/quickstart/go-package.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
return []schemas.ModelProvider{schemas.OpenAI}, nil
}

func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
if provider == schemas.OpenAI {
return []schemas.Key{{
Value: os.Getenv("OPENAI_API_KEY"),
Expand Down Expand Up @@ -119,7 +119,7 @@ func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
}

// Update GetKeysForProvider to handle both providers
func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
switch provider {
case schemas.OpenAI:
return []schemas.Key{{
Expand Down
107 changes: 100 additions & 7 deletions docs/usage/go-package/account.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ The Account interface is your configuration provider that tells Bifrost:
```go
type Account interface {
GetConfiguredProviders() ([]schemas.ModelProvider, error)
GetKeysForProvider(providerKey schemas.ModelProvider) ([]schemas.Key, error)
GetKeysForProvider(ctx *context.Context, providerKey schemas.ModelProvider) ([]schemas.Key, error)
GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error)
}
```
Expand All @@ -34,6 +34,7 @@ Perfect for getting started or simple use cases:
package main

import (
"context"
"fmt"
"os"
"github.com/maximhq/bifrost/core/schemas"
Expand All @@ -45,7 +46,7 @@ func (a *SimpleAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error
return []schemas.ModelProvider{schemas.OpenAI}, nil
}

func (a *SimpleAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
func (a *SimpleAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
if provider == schemas.OpenAI {
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
Expand Down Expand Up @@ -110,7 +111,7 @@ func (a *MultiProviderAccount) GetConfiguredProviders() ([]schemas.ModelProvider
return providers, nil
}

func (a *MultiProviderAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
func (a *MultiProviderAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
switch provider {
case schemas.OpenAI:
return []schemas.Key{{
Expand Down Expand Up @@ -217,7 +218,7 @@ func (a *MultiProviderAccount) GetConfigForProvider(provider schemas.ModelProvid
Distribute requests across multiple API keys for higher rate limits:

```go
func (a *AdvancedAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
func (a *AdvancedAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
if provider == schemas.OpenAI {
return []schemas.Key{
{
Expand All @@ -236,6 +237,98 @@ func (a *AdvancedAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]
}
```

### **Plugin Context Usage**

Leverage plugin pre-hook data for dynamic key selection:

```go
type ContextAwareAccount struct {
standardKeys map[schemas.ModelProvider][]schemas.Key
premiumKeys map[schemas.ModelProvider][]schemas.Key
regionKeys map[string][]schemas.Key
}

func (a *ContextAwareAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
// Early validation
standardKeys, ok := a.standardKeys[provider]
if !ok {
return nil, fmt.Errorf("provider %s not configured", provider)
}

// No context means use standard keys
if ctx == nil {
return standardKeys, nil
}

// Example: Access control based on user role
if userRole, ok := (*ctx).Value("user_role").(string); ok {
switch userRole {
case "premium":
if premiumKeys, ok := a.premiumKeys[provider]; ok {
return premiumKeys, nil
}
}
}

// Example: Geographic routing
if region, ok := (*ctx).Value("geo_region").(string); ok {
if regionKeys, ok := a.regionKeys[region]; ok {
return regionKeys, nil
}
}

// Example: Custom routing based on request type
if reqType, ok := (*ctx).Value("request_type").(string); ok {
switch reqType {
case "streaming":
return []schemas.Key{{
Value: os.Getenv("DEDICATED_STREAMING_KEY"),
Models: []string{"gpt-4o-mini"},
Weight: 1.0,
}}, nil
case "batch":
return []schemas.Key{{
Value: os.Getenv("BATCH_PROCESSING_KEY"),
Models: []string{"gpt-4o"},
Weight: 1.0,
}}, nil
}
}

// Example: Rate limit management
if quota, ok := (*ctx).Value("remaining_quota").(int); ok {
if quota < 100 {
// Switch to backup keys when quota is low
return []schemas.Key{{
Value: os.Getenv("BACKUP_API_KEY"),
Models: []string{"gpt-4o-mini"},
Weight: 1.0,
}}, nil
}
}

return standardKeys, nil
}
```
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

This implementation shows how to:
- Use plugin-set context data for dynamic key selection
- Implement role-based access control
- Handle geographic routing requirements
- Support request type-specific key allocation
- Manage rate limits and quotas

Common context values set by plugins:
- `user_role`: User permission level
- `geo_region`: Geographic location
- `request_type`: Type of request (streaming, batch, etc.)
- `remaining_quota`: Rate limit tracking
- `request_priority`: Priority level
- `client_id`: Client identifier
- `custom_routing`: Custom routing rules

> **💡 Tip:** Plugins can set any context values during their pre-hook phase. Document the expected context keys and their format to help plugin developers integrate with your key selection logic.

### **Custom Network Settings**

Optimize timeouts and retries for different providers:
Expand Down Expand Up @@ -354,7 +447,7 @@ type DatabaseAccount struct {
db *sql.DB
}

func (a *DatabaseAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
func (a *DatabaseAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
rows, err := a.db.Query(`
SELECT api_key, models, weight
FROM provider_keys
Expand Down Expand Up @@ -403,7 +496,7 @@ apiKey := "sk-..." // Never do this!
### **Error Handling**

```go
func (a *Account) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
func (a *Account) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
return nil, fmt.Errorf("OPENAI_API_KEY not configured")
Expand Down Expand Up @@ -442,7 +535,7 @@ func TestAccount(t *testing.T) {
assert.Contains(t, providers, schemas.OpenAI)

// Test key retrieval
keys, err := account.GetKeysForProvider(schemas.OpenAI)
keys, err := account.GetKeysForProvider(context.Background(), schemas.OpenAI)
assert.NoError(t, err)
assert.Len(t, keys, 1)
assert.Equal(t, "sk-test-key", keys[0].Value)
Expand Down
31 changes: 28 additions & 3 deletions docs/usage/go-package/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,11 @@ Provider configuration and key management:
```go
type Account interface {
GetConfiguredProviders() ([]ModelProvider, error)
GetKeysForProvider(ModelProvider) ([]Key, error)
// GetKeysForProvider receives a context that can contain data from any source that sets
// values before the Bifrost request. This includes application code, middleware, plugin
// pre-hooks, or any other source. Implementations can use this context data to make
// dynamic key selection decisions based on any values present during the request.
GetKeysForProvider(ctx *context.Context, providerKey ModelProvider) ([]Key, error)
GetConfigForProvider(ModelProvider) (*ProviderConfig, error)
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}

Expand All @@ -402,15 +406,36 @@ func (a *MyAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
}, nil
}

func (a *MyAccount) GetKeysForProvider(provider schemas.ModelProvider) ([]schemas.Key, error) {
// Example of context-aware key selection
func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.ModelProvider) ([]schemas.Key, error) {
switch provider {
case schemas.OpenAI:
// Check for any context values
if ctx != nil {
// Example: Value set by application code
if userRole, ok := (*ctx).Value("user_role").(string); ok && userRole == "premium" {
return []schemas.Key{{
Value: os.Getenv("OPENAI_PREMIUM_KEY"),
Models: []string{"gpt-4o"},
Weight: 1.0,
}}, nil
}

// Example: Value set by middleware
if region, ok := (*ctx).Value("geo_region").(string); ok && region == "eu" {
return []schemas.Key{{
Value: os.Getenv("OPENAI_EU_KEY"),
Models: []string{"gpt-4o"},
Weight: 1.0,
}}, nil
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}
// Default key if no special context
return []schemas.Key{{
Value: os.Getenv("OPENAI_API_KEY"),
Models: []string{"gpt-4o-mini", "gpt-4o"},
Weight: 1.0,
}}, nil
// ... other providers
}
return nil, fmt.Errorf("provider not supported")
}
Expand Down
Loading