Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/providers/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -1198,7 +1198,7 @@ func handleAnthropicStreaming(
logger.Warn(fmt.Sprintf("Error reading %s stream: %v", providerType, err))
processAndSendError(ctx, postHookRunner, err, responseChan, logger)
} else {
response := createBifrostChatCompletionChunkResponse(usage, finishReason, chunkIndex, params, providerType)
response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, params, providerType)
handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, logger)
}
}()
Expand Down
87 changes: 49 additions & 38 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
// CompleteRequest sends a request to Bedrock's API and handles the response.
// It constructs the API URL, sets up AWS authentication, and processes the response.
// Returns the response body or an error if the request fails.
func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, config schemas.BedrockKeyConfig) ([]byte, *schemas.BifrostError) {
func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, key schemas.Key) ([]byte, *schemas.BifrostError) {
config := key.BedrockKeyConfig

region := "us-east-1"
if config.Region != nil {
region = *config.Region
Expand Down Expand Up @@ -301,9 +303,14 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod
// Set any extra headers from network config
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)

// Sign the request using either explicit credentials or IAM role authentication
if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, region, "bedrock", provider.GetProviderKey()); err != nil {
return nil, err
// If Value is set, use API Key authentication - else use IAM role authentication
if key.Value != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
} else {
// Sign the request using either explicit credentials or IAM role authentication
if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, region, "bedrock", provider.GetProviderKey()); err != nil {
return nil, err
}
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// Execute the request
Expand Down Expand Up @@ -834,7 +841,8 @@ func (provider *BedrockProvider) TextCompletion(ctx context.Context, model strin
"prompt": text,
}, preparedParams)

body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), *key.BedrockKeyConfig)
path := provider.getModelPath("invoke", model, key)
body, err := provider.completeRequest(ctx, requestBody, path, key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1018,19 +1026,10 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model strin
requestBody := mergeConfig(messageBody, preparedParams)

// Format the path with proper model identifier
path := fmt.Sprintf("%s/converse", model)

if key.BedrockKeyConfig.Deployments != nil {
if inferenceProfileId, ok := key.BedrockKeyConfig.Deployments[model]; ok {
if key.BedrockKeyConfig.ARN != nil {
encodedModelIdentifier := url.QueryEscape(fmt.Sprintf("%s/%s", *key.BedrockKeyConfig.ARN, inferenceProfileId))
path = fmt.Sprintf("%s/converse", encodedModelIdentifier)
}
}
}
path := provider.getModelPath("converse", model, key)

// Create the signed request
responseBody, err := provider.completeRequest(ctx, requestBody, path, *key.BedrockKeyConfig)
responseBody, err := provider.completeRequest(ctx, requestBody, path, key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1223,16 +1222,16 @@ func (provider *BedrockProvider) Embedding(ctx context.Context, model string, ke

switch {
case strings.Contains(model, "amazon.titan-embed-text"):
return provider.handleTitanEmbedding(ctx, model, *key.BedrockKeyConfig, input, params, providerName)
return provider.handleTitanEmbedding(ctx, model, key, input, params, providerName)
case strings.Contains(model, "cohere.embed"):
return provider.handleCohereEmbedding(ctx, model, *key.BedrockKeyConfig, input, params, providerName)
return provider.handleCohereEmbedding(ctx, model, key, input, params, providerName)
default:
return nil, newConfigurationError("embedding is not supported for this Bedrock model", providerName)
}
}

// handleTitanEmbedding handles embedding requests for Amazon Titan models.
func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, config schemas.BedrockKeyConfig, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) {
// Titan Text Embeddings V1/V2 - only supports single text input
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", providerName)
Expand All @@ -1258,8 +1257,8 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model
}

// Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly
path := url.PathEscape(model) + "/invoke"
rawResponse, err := provider.completeRequest(ctx, requestBody, path, config)
path := provider.getModelPath("invoke", model, key)
rawResponse, err := provider.completeRequest(ctx, requestBody, path, key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1306,7 +1305,7 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model
}

// handleCohereEmbedding handles embedding requests for Cohere models on Bedrock.
func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, config schemas.BedrockKeyConfig, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters, providerName schemas.ModelProvider) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", providerName)
}
Expand All @@ -1320,8 +1319,8 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode
}

// Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly
path := url.PathEscape(model) + "/invoke"
rawResponse, err := provider.completeRequest(ctx, requestBody, path, config)
path := provider.getModelPath("invoke", model, key)
rawResponse, err := provider.completeRequest(ctx, requestBody, path, key)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1433,16 +1432,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
requestBody := mergeConfig(messageBody, preparedParams)

// Format the path with proper model identifier for streaming
path := fmt.Sprintf("%s/converse-stream", model)

if key.BedrockKeyConfig.Deployments != nil {
if inferenceProfileId, ok := key.BedrockKeyConfig.Deployments[model]; ok {
if key.BedrockKeyConfig.ARN != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *key.BedrockKeyConfig.ARN, inferenceProfileId))
path = fmt.Sprintf("%s/converse-stream", encodedModelIdentifier)
}
}
}
path := provider.getModelPath("converse-stream", model, key)

region := "us-east-1"
if key.BedrockKeyConfig.Region != nil {
Expand All @@ -1464,9 +1454,14 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
// Set any extra headers from network config
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)

// Sign the request using either explicit credentials or IAM role authentication
if signErr := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock", providerName); signErr != nil {
return nil, signErr
// If Value is set, use API Key authentication - else use IAM role authentication
if key.Value != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value))
} else {
// Sign the request using either explicit credentials or IAM role authentication
if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock", providerName); err != nil {
return nil, err
}
}

// Make the request
Expand Down Expand Up @@ -1531,7 +1526,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
}

// Send final response
response := createBifrostChatCompletionChunkResponse(usage, finishReason, chunkIndex, params, providerName)
response := createBifrostChatCompletionChunkResponse(messageID, usage, finishReason, chunkIndex, params, providerName)
handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, provider.logger)
}()

Expand Down Expand Up @@ -1827,3 +1822,19 @@ func (provider *BedrockProvider) Transcription(ctx context.Context, model string
func (provider *BedrockProvider) TranscriptionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, input *schemas.TranscriptionInput, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) {
return nil, newUnsupportedOperationError("transcription stream", "bedrock")
}

func (provider *BedrockProvider) getModelPath(basePath string, model string, key schemas.Key) string {
// Format the path with proper model identifier for streaming
path := fmt.Sprintf("%s/%s", model, basePath)

if key.BedrockKeyConfig.Deployments != nil {
if inferenceProfileId, ok := key.BedrockKeyConfig.Deployments[model]; ok {
if key.BedrockKeyConfig.ARN != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *key.BedrockKeyConfig.ARN, inferenceProfileId))
path = fmt.Sprintf("%s/%s", encodedModelIdentifier, basePath)
}
}
}

return path
}
7 changes: 6 additions & 1 deletion core/providers/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ func handleOpenAIStreaming(
usage := &schemas.LLMUsage{}

var finishReason *string
var id string

for scanner.Scan() {
line := scanner.Text()
Expand Down Expand Up @@ -513,6 +514,10 @@ func handleOpenAIStreaming(
response.Choices[0].FinishReason = nil
}

if response.ID != "" && id == "" {
id = response.ID
}

// Handle regular content chunks
if choice.BifrostStreamResponseChoice != nil && (choice.BifrostStreamResponseChoice.Delta.Content != nil || len(choice.BifrostStreamResponseChoice.Delta.ToolCalls) > 0) {
chunkIndex++
Expand All @@ -529,7 +534,7 @@ func handleOpenAIStreaming(
logger.Warn(fmt.Sprintf("Error reading stream: %v", err))
processAndSendError(ctx, postHookRunner, err, responseChan, logger)
} else {
response := createBifrostChatCompletionChunkResponse(usage, finishReason, chunkIndex, params, providerName)
response := createBifrostChatCompletionChunkResponse(id, usage, finishReason, chunkIndex, params, providerName)
handleStreamEndWithSuccess(ctx, response, postHookRunner, responseChan, logger)
}
}()
Expand Down
2 changes: 2 additions & 0 deletions core/providers/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,13 +804,15 @@ func processAndSendError(
}

func createBifrostChatCompletionChunkResponse(
id string,
usage *schemas.LLMUsage,
finishReason *string,
currentChunkIndex int,
params *schemas.ModelParameters,
providerName schemas.ModelProvider,
) *schemas.BifrostResponse {
response := &schemas.BifrostResponse{
ID: id,
Object: "chat.completion.chunk",
Usage: usage,
Choices: []schemas.BifrostResponseChoice{
Expand Down
5 changes: 5 additions & 0 deletions core/schemas/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ type VertexKeyConfig struct {
AuthCredentials string `json:"auth_credentials,omitempty"`
}

// NOTE: To use Vertex IAM role authentication, set AuthCredentials to empty string.

// BedrockKeyConfig represents the AWS Bedrock-specific configuration.
// It contains AWS-specific settings required for authentication and service access.
type BedrockKeyConfig struct {
Expand All @@ -42,6 +44,9 @@ type BedrockKeyConfig struct {
Deployments map[string]string `json:"deployments,omitempty"` // Mapping of model identifiers to inference profiles
}

// NOTE: To use Bedrock IAM role authentication, set both AccessKey and SecretKey to empty strings.
// To use Bedrock API Key authentication, set Value in Key struct instead.

// Account defines the interface for managing provider accounts and their configurations.
// It provides methods to access provider-specific settings, API keys, and configurations.
type Account interface {
Expand Down
25 changes: 0 additions & 25 deletions core/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package bifrost

import (
"context"
"encoding/json"
"math/rand"
"time"

Expand All @@ -14,30 +13,6 @@ func Ptr[T any](v T) *T {
return &v
}

// MarshalToString marshals the given value to a JSON string.
func MarshalToString(v any) (string, error) {
if v == nil {
return "", nil
}
data, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(data), nil
}

// MarshalToStringPtr marshals the given value to a JSON string and returns a pointer to the string.
func MarshalToStringPtr(v any) (*string, error) {
if v == nil {
return nil, nil
}
data, err := MarshalToString(v)
if err != nil {
return nil, err
}
return &data, nil
}

func attachContextKeys(ctx context.Context, req *schemas.BifrostRequest, requestType schemas.RequestType) context.Context {
ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestType, requestType)
ctx = context.WithValue(ctx, schemas.BifrostContextKeyRequestProvider, req.Provider)
Expand Down
20 changes: 11 additions & 9 deletions docs/quickstart/gateway/provider-configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -764,12 +764,13 @@ AWS Bedrock supports both explicit credentials and IAM role authentication:
![AWS Bedrock Configuration Interface](../../media/ui-bedrock-config.png)

1. Navigate to **"Providers"** → **"AWS Bedrock"**
2. Set **Access Key**: AWS Access Key ID (or leave empty for IAM)
3. Set **Secret Key**: AWS Secret Access Key (or leave empty for IAM)
4. Set **Region**: e.g., `us-east-1`
5. Configure **Deployments**: Map model names to inference profiles
6. Set **ARN**: Required for deployments mapping
7. Save configuration
2. Set **API Key**: AWS API Key (or leave empty if using IAM role authentication)
3. Set **Access Key**: AWS Access Key ID (or leave empty to use IAM in environment)
4. Set **Secret Key**: AWS Secret Access Key (or leave empty to use IAM in environment)
5. Set **Region**: e.g., `us-east-1`
6. Configure **Deployments**: Map model names to inference profiles
7. Set **ARN**: Required for deployments mapping
8. Save configuration

</Tab>

Expand Down Expand Up @@ -833,9 +834,10 @@ curl --location 'http://localhost:8080/api/providers' \
</Tabs>

**Notes:**
- If both `access_key` and `secret_key` are empty, Bifrost uses IAM role authentication from environment
- `arn` is required for URL formation - `deployments` mapping is ignored without it
- When using `arn` + `deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly
- If using API Key authentication, set `value` field to the API key, else leave it empty for IAM role authentication.
- In IAM role authentication, if both `access_key` and `secret_key` are empty, Bifrost uses IAM role authentication from the environment.
- `arn` is required for URL formation - `deployments` mapping is ignored without it.
- When using `arn` + `deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly.

### Google Vertex

Expand Down
12 changes: 7 additions & 5 deletions docs/quickstart/go-sdk/provider-configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,10 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo
{
Models: []string{"anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"},
Weight: 1.0,
Value: os.Getenv("AWS_API_KEY"), // Leave empty for IAM role authentication
BedrockKeyConfig: &schemas.BedrockKeyConfig{
AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"), // Leave empty for IAM role
SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), // Leave empty for IAM role
AccessKey: os.Getenv("AWS_ACCESS_KEY_ID"), // Leave empty for API Key authentication or system's IAM pickup
SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), // Leave empty for API Key authentication or system's IAM pickup
SessionToken: bifrost.Ptr(os.Getenv("AWS_SESSION_TOKEN")), // Optional
Region: bifrost.Ptr("us-east-1"),
// For model profiles (inference profiles)
Expand All @@ -327,9 +328,10 @@ func (a *MyAccount) GetKeysForProvider(ctx *context.Context, provider schemas.Mo
```

**Notes:**
- If both `AccessKey` and `SecretKey` are empty, Bifrost uses IAM role authentication from environment
- `ARN` is required for URL formation - `Deployments` mapping is ignored without it
- When using `ARN` + `Deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly
- If using API Key authentication, set `Value` field to the API key, else leave it empty for IAM role authentication.
- In IAM role authentication, if both `AccessKey` and `SecretKey` are empty, Bifrost uses IAM from the environment.
- `ARN` is required for URL formation - `Deployments` mapping is ignored without it.
- When using `ARN` + `Deployments`, Bifrost uses model profiles; otherwise forms path with incoming model name directly.

</Tab>

Expand Down
4 changes: 1 addition & 3 deletions framework/configstore/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import (

// Migrate performs the necessary database migrations.
func triggerMigrations(db *gorm.DB) error {
var err error
err = migrationInit(db)
if err != nil {
if err := migrationInit(db); err != nil {
return err
}
return nil
Expand Down
5 changes: 2 additions & 3 deletions framework/configstore/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os"
"strings"

bifrost "github.com/maximhq/bifrost/core"
"github.com/maximhq/bifrost/core/schemas"
"github.com/maximhq/bifrost/framework/logstore"
"github.com/maximhq/bifrost/framework/vectorstore"
Expand Down Expand Up @@ -603,7 +602,7 @@ func (s *SQLiteConfigStore) UpdateVectorStoreConfig(config *vectorstore.Config)
if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableVectorStoreConfig{}).Error; err != nil {
return err
}
jsonConfig, err := bifrost.MarshalToStringPtr(config.Config)
jsonConfig, err := marshalToStringPtr(config.Config)
if err != nil {
return err
}
Expand Down Expand Up @@ -642,7 +641,7 @@ func (s *SQLiteConfigStore) UpdateLogsStoreConfig(config *logstore.Config) error
if err := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&TableLogStoreConfig{}).Error; err != nil {
return err
}
jsonConfig, err := bifrost.MarshalToStringPtr(config)
jsonConfig, err := marshalToStringPtr(config)
if err != nil {
return err
}
Expand Down
Loading