Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const (
)

// executor is a function type that handles specific request types.
type executor func(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError)
type executor func(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError)

// messageExecutors is a factory map for handling different request types.
var messageExecutors = map[RequestType]executor{
Expand Down Expand Up @@ -1092,7 +1092,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan
var bifrostError *schemas.BifrostError
var err error

key := ""
key := schemas.Key{}
if providerRequiresKey(provider.GetProviderKey()) {
key, err = bifrost.selectKeyFromProviderForModel(provider.GetProviderKey(), req.Model)
if err != nil {
Expand Down Expand Up @@ -1207,7 +1207,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, queue chan Chan
}

// handleTextCompletion executes a text completion request
func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) {
func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) {
if req.Input.TextCompletionInput == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand All @@ -1220,7 +1220,7 @@ func handleTextCompletion(provider schemas.Provider, req *ChannelMessage, key st
}

// handleChatCompletion executes a chat completion request
func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) {
func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) {
if req.Input.ChatCompletionInput == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand All @@ -1233,7 +1233,7 @@ func handleChatCompletion(provider schemas.Provider, req *ChannelMessage, key st
}

// handleEmbedding executes an embedding request
func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key string) (*schemas.BifrostResponse, *schemas.BifrostError) {
func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) {
if req.Input.EmbeddingInput == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand All @@ -1246,7 +1246,7 @@ func handleEmbedding(provider schemas.Provider, req *ChannelMessage, key string)
}

// handleChatCompletionStream executes a chat completion stream request
func handleChatCompletionStream(provider schemas.Provider, req *ChannelMessage, key string, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStream, *schemas.BifrostError) {
func handleChatCompletionStream(provider schemas.Provider, req *ChannelMessage, key schemas.Key, postHookRunner schemas.PostHookRunner) (chan *schemas.BifrostStream, *schemas.BifrostError) {
if req.Input.ChatCompletionInput == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand Down Expand Up @@ -1384,30 +1384,30 @@ func (bifrost *Bifrost) releaseChannelMessage(msg *ChannelMessage) {

// selectKeyFromProviderForModel selects an appropriate API key for a given provider and model.
// It uses weighted random selection if multiple keys are available.
func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (string, error) {
func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
keys, err := bifrost.account.GetKeysForProvider(providerKey)
if err != nil {
return "", err
return schemas.Key{}, err
}

if len(keys) == 0 {
return "", fmt.Errorf("no keys found for provider: %v", providerKey)
return schemas.Key{}, fmt.Errorf("no keys found for provider: %v", providerKey)
}

// filter out keys which dont support the model, if the key has no models, it is supported for all models
var supportedKeys []schemas.Key
for _, key := range keys {
if (slices.Contains(key.Models, model) && strings.TrimSpace(key.Value) != "") || len(key.Models) == 0 {
if (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || providerKey == schemas.Vertex)) || len(key.Models) == 0 {
supportedKeys = append(supportedKeys, key)
}
}

if len(supportedKeys) == 0 {
return "", fmt.Errorf("no keys found that support model: %s", model)
return schemas.Key{}, fmt.Errorf("no keys found that support model: %s", model)
}

if len(supportedKeys) == 1 {
return supportedKeys[0].Value, nil
return supportedKeys[0], nil
}

// Use a weighted random selection based on key weights
Expand All @@ -1425,12 +1425,12 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(providerKey schemas.ModelP
for _, key := range supportedKeys {
currentWeight += int(key.Weight * 100)
if randomValue < currentWeight {
return key.Value, nil
return key, nil
}
}

// Fallback to first key if something goes wrong
return supportedKeys[0].Value, nil
return supportedKeys[0], nil
}

// CLEANUP
Expand Down
14 changes: 7 additions & 7 deletions core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func (provider *AnthropicProvider) completeRequest(ctx context.Context, requestB
// TextCompletion performs a text completion request to Anthropic's API.
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
preparedParams := provider.prepareTextCompletionParams(prepareParams(params))

// Merge additional parameters
Expand All @@ -327,7 +327,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke
"prompt": fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", text),
}, preparedParams)

responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key)
responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/complete", key.Value)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -379,7 +379,7 @@ func (provider *AnthropicProvider) TextCompletion(ctx context.Context, model, ke
// ChatCompletion performs a chat completion request to Anthropic's API.
// It formats the request, sends it to Anthropic, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params)

// Merge additional parameters
Expand All @@ -388,7 +388,7 @@ func (provider *AnthropicProvider) ChatCompletion(ctx context.Context, model, ke
"messages": formattedMessages,
}, preparedParams)

responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key)
responseBody, err := provider.completeRequest(ctx, requestBody, provider.networkConfig.BaseURL+"/v1/messages", key.Value)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -775,14 +775,14 @@ func parseAnthropicResponse(response *AnthropicChatResponse, bifrostResponse *sc
}

// Embedding is not supported by the Anthropic provider.
func (provider *AnthropicProvider) Embedding(ctx context.Context, model, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AnthropicProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
return nil, newUnsupportedOperationError("embedding", "anthropic")
}

// ChatCompletionStream performs a streaming chat completion request to the Anthropic API.
// It supports real-time streaming of responses using Server-Sent Events (SSE).
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) {
func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) {
formattedMessages, preparedParams := prepareAnthropicChatRequest(messages, params)

// Merge additional parameters and set stream to true
Expand All @@ -795,7 +795,7 @@ func (provider *AnthropicProvider) ChatCompletionStream(ctx context.Context, pos
// Prepare Anthropic headers
headers := map[string]string{
"Content-Type": "application/json",
"x-api-key": key,
"x-api-key": key.Value,
"anthropic-version": provider.apiVersion,
"Accept": "text/event-stream",
"Cache-Control": "no-cache",
Expand Down
58 changes: 35 additions & 23 deletions core/providers/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ type AzureProvider struct {
logger schemas.Logger // Logger for provider operations
client *fasthttp.Client // HTTP client for API requests
streamClient *http.Client // HTTP client for streaming requests
meta schemas.MetaConfig // Azure-specific configuration
networkConfig schemas.NetworkConfig // Network configuration including extra headers
}

Expand All @@ -127,10 +126,6 @@ type AzureProvider struct {
func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*AzureProvider, error) {
config.CheckAndSetDefaults()

if config.MetaConfig == nil {
return nil, fmt.Errorf("meta config is not set")
}

client := &fasthttp.Client{
ReadTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
WriteTimeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds),
Expand All @@ -156,7 +151,6 @@ func NewAzureProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*A
logger: logger,
client: client,
streamClient: streamClient,
meta: config.MetaConfig,
networkConfig: config.NetworkConfig,
}, nil
}
Expand All @@ -169,7 +163,16 @@ func (provider *AzureProvider) GetProviderKey() schemas.ModelProvider {
// completeRequest sends a request to Azure's API and handles the response.
// It constructs the API URL, sets up authentication, and processes the response.
// Returns the response body or an error if the request fails.
func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key string, model string) ([]byte, *schemas.BifrostError) {
func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key, model string) ([]byte, *schemas.BifrostError) {
if key.AzureKeyConfig == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "azure key config not set",
},
}
}

// Marshal the request body
jsonData, err := json.Marshal(requestBody)
if err != nil {
Expand All @@ -182,7 +185,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody
}
}

if provider.meta.GetEndpoint() == nil {
if key.AzureKeyConfig.Endpoint == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Expand All @@ -191,10 +194,10 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody
}
}

url := *provider.meta.GetEndpoint()
url := key.AzureKeyConfig.Endpoint

if provider.meta.GetDeployments() != nil {
deployment := provider.meta.GetDeployments()[model]
if key.AzureKeyConfig.Deployments != nil {
deployment := key.AzureKeyConfig.Deployments[model]
if deployment == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand All @@ -204,7 +207,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody
}
}

apiVersion := provider.meta.GetAPIVersion()
apiVersion := key.AzureKeyConfig.APIVersion
if apiVersion == nil {
apiVersion = StrPtr("2024-02-01")
}
Expand Down Expand Up @@ -236,7 +239,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody
// Ensure api-key is not accidentally present (from extra headers, etc.)
req.Header.Del("api-key")
} else {
req.Header.Set("api-key", key)
req.Header.Set("api-key", key.Value)
}

req.SetBody(jsonData)
Expand Down Expand Up @@ -269,7 +272,7 @@ func (provider *AzureProvider) completeRequest(ctx context.Context, requestBody
// TextCompletion performs a text completion request to Azure's API.
// It formats the request, sends it to Azure, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AzureProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
preparedParams := prepareParams(params)

// Merge additional parameters
Expand Down Expand Up @@ -337,7 +340,7 @@ func (provider *AzureProvider) TextCompletion(ctx context.Context, model, key, t
// ChatCompletion performs a chat completion request to Azure's API.
// It formats the request, sends it to Azure, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AzureProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params)

// Merge additional parameters
Expand Down Expand Up @@ -384,7 +387,7 @@ func (provider *AzureProvider) ChatCompletion(ctx context.Context, model, key st
// Embedding generates embeddings for the given input text(s) using Azure OpenAI.
// The input can be either a single string or a slice of strings for batch embedding.
// Returns a BifrostResponse containing the embedding(s) and any error that occurred.
func (provider *AzureProvider) Embedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *AzureProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, &schemas.BifrostError{
IsBifrostError: true,
Expand Down Expand Up @@ -493,9 +496,18 @@ func (provider *AzureProvider) Embedding(ctx context.Context, model string, key
// It supports real-time streaming of responses using Server-Sent Events (SSE).
// Uses Azure-specific URL construction with deployments and supports both api-key and Bearer token authentication.
// Returns a channel containing BifrostResponse objects representing the stream or an error if the request fails.
func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model, key string, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) {
func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) {
formattedMessages, preparedParams := prepareOpenAIChatRequest(messages, params)

if key.AzureKeyConfig == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "azure key config not set",
},
}
}

// Merge additional parameters and set stream to true
requestBody := mergeConfig(map[string]interface{}{
"model": model,
Expand All @@ -504,7 +516,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
}, preparedParams)

// Construct Azure-specific URL with deployment
if provider.meta.GetEndpoint() == nil {
if key.AzureKeyConfig.Endpoint == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Expand All @@ -513,11 +525,11 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
}
}

baseURL := *provider.meta.GetEndpoint()
baseURL := key.AzureKeyConfig.Endpoint
var fullURL string

if provider.meta.GetDeployments() != nil {
deployment := provider.meta.GetDeployments()[model]
if key.AzureKeyConfig.Deployments != nil {
deployment := key.AzureKeyConfig.Deployments[model]
if deployment == "" {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Expand All @@ -527,7 +539,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
}
}

apiVersion := provider.meta.GetAPIVersion()
apiVersion := key.AzureKeyConfig.APIVersion
if apiVersion == nil {
apiVersion = StrPtr("2024-02-01")
}
Expand All @@ -552,7 +564,7 @@ func (provider *AzureProvider) ChatCompletionStream(ctx context.Context, postHoo
if authToken, ok := ctx.Value(AzureAuthorizationTokenKey).(string); ok {
headers["Authorization"] = fmt.Sprintf("Bearer %s", authToken)
} else {
headers["api-key"] = key
headers["api-key"] = key.Value
}

// Use shared streaming logic from OpenAI
Expand Down
Loading