Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,7 @@ func (bifrost *Bifrost) selectKeyFromProviderForModel(ctx *context.Context, prov
// filter out keys which don't support the model; a key with no models listed is treated as supporting all models
var supportedKeys []schemas.Key
for _, key := range keys {
if (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || providerKey == schemas.Vertex)) || len(key.Models) == 0 {
if (slices.Contains(key.Models, model) && (strings.TrimSpace(key.Value) != "" || canProviderKeyValueBeEmpty(providerKey))) || len(key.Models) == 0 {
supportedKeys = append(supportedKeys, key)
}
}
Expand Down
99 changes: 54 additions & 45 deletions core/providers/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,6 @@ type BedrockStreamMetadataEvent struct {
type BedrockProvider struct {
logger schemas.Logger // Logger for provider operations
client *http.Client // HTTP client for API requests
meta schemas.MetaConfig // Bedrock-specific configuration
networkConfig schemas.NetworkConfig // Network configuration including extra headers
sendBackRawResponse bool // Whether to include raw response in BifrostResponse
}
Expand Down Expand Up @@ -227,10 +226,6 @@ func releaseBedrockChatResponse(resp *BedrockChatResponse) {
func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (*BedrockProvider, error) {
config.CheckAndSetDefaults()

if config.MetaConfig == nil {
return nil, fmt.Errorf("meta config is not set")
}

client := &http.Client{Timeout: time.Second * time.Duration(config.NetworkConfig.DefaultRequestTimeoutInSeconds)}

// Pre-warm response pools
Expand All @@ -242,7 +237,6 @@ func NewBedrockProvider(config *schemas.ProviderConfig, logger schemas.Logger) (
return &BedrockProvider{
logger: logger,
client: client,
meta: config.MetaConfig,
networkConfig: config.NetworkConfig,
sendBackRawResponse: config.SendBackRawResponse,
}, nil
Expand All @@ -256,19 +250,10 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
// completeRequest sends a request to Bedrock's API and handles the response.
// It constructs the API URL, sets up AWS authentication, and processes the response.
// Returns the response body or an error if the request fails.
func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, accessKey string) ([]byte, *schemas.BifrostError) {
if provider.meta == nil {
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: schemas.ErrorField{
Message: "meta config for bedrock is not provided",
},
}
}

func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBody map[string]interface{}, path string, config schemas.BedrockKeyConfig) ([]byte, *schemas.BifrostError) {
region := "us-east-1"
if provider.meta.GetRegion() != nil {
region = *provider.meta.GetRegion()
if config.Region != nil {
region = *config.Region
}

jsonBody, err := sonic.Marshal(requestBody)
Expand Down Expand Up @@ -307,8 +292,8 @@ func (provider *BedrockProvider) completeRequest(ctx context.Context, requestBod
// Set any extra headers from network config
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)

if provider.meta.GetSecretAccessKey() != nil {
if err := signAWSRequest(req, accessKey, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); err != nil {
if config.SecretKey != "" {
if err := signAWSRequest(req, config.AccessKey, config.SecretKey, config.SessionToken, region, "bedrock"); err != nil {
return nil, err
}
} else {
Expand Down Expand Up @@ -822,13 +807,17 @@ func (provider *BedrockProvider) prepareTextCompletionParams(params map[string]i
// It formats the request, sends it to Bedrock, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *BedrockProvider) TextCompletion(ctx context.Context, model string, key schemas.Key, text string, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if key.BedrockKeyConfig == nil {
return nil, newConfigurationError("bedrock key config is not provided", schemas.Bedrock)
}

preparedParams := provider.prepareTextCompletionParams(prepareParams(params), model)

requestBody := mergeConfig(map[string]interface{}{
"prompt": text,
}, preparedParams)

body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), key.Value)
body, err := provider.completeRequest(ctx, requestBody, fmt.Sprintf("%s/invoke", model), *key.BedrockKeyConfig)
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -908,6 +897,10 @@ func (provider *BedrockProvider) extractToolsFromHistory(messages []schemas.Bifr
// It formats the request, sends it to Bedrock, and processes the response.
// Returns a BifrostResponse containing the completion results or an error if the request fails.
func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if key.BedrockKeyConfig == nil {
return nil, newConfigurationError("bedrock key config is not provided", schemas.Bedrock)
}

messageBody, err := provider.prepareChatCompletionMessages(messages, model)
if err != nil {
return nil, err
Expand Down Expand Up @@ -939,17 +932,17 @@ func (provider *BedrockProvider) ChatCompletion(ctx context.Context, model strin
// Format the path with proper model identifier
path := fmt.Sprintf("%s/converse", model)

if provider.meta != nil && provider.meta.GetInferenceProfiles() != nil {
if inferenceProfileId, ok := provider.meta.GetInferenceProfiles()[model]; ok {
if provider.meta.GetARN() != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.GetARN(), inferenceProfileId))
if key.BedrockKeyConfig.Deployments != nil {
if inferenceProfileId, ok := key.BedrockKeyConfig.Deployments[model]; ok {
if key.BedrockKeyConfig.ARN != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *key.BedrockKeyConfig.ARN, inferenceProfileId))
path = fmt.Sprintf("%s/converse", encodedModelIdentifier)
}
}
}

// Create the signed request
responseBody, err := provider.completeRequest(ctx, requestBody, path, key.Value)
responseBody, err := provider.completeRequest(ctx, requestBody, path, *key.BedrockKeyConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1088,7 +1081,7 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken
AccessKeyID: accessKey,
SecretAccessKey: secretKey,
}
if sessionToken != nil {
if sessionToken != nil && *sessionToken != "" {
creds.SessionToken = *sessionToken
}
return creds, nil
Expand Down Expand Up @@ -1118,18 +1111,30 @@ func signAWSRequest(req *http.Request, accessKey, secretKey string, sessionToken
// Embedding generates embeddings for the given input text(s) using Amazon Bedrock.
// Supports Titan and Cohere embedding models. Returns a BifrostResponse containing the embedding(s) and any error that occurred.
func (provider *BedrockProvider) Embedding(ctx context.Context, model string, key schemas.Key, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if key.BedrockKeyConfig == nil {
return nil, newConfigurationError("bedrock key config is not provided", schemas.Bedrock)
}

switch {
case strings.HasPrefix(model, "amazon.titan-embed-text"):
return provider.handleTitanEmbedding(ctx, model, key.Value, input, params)
return provider.handleTitanEmbedding(ctx, model, *key.BedrockKeyConfig, input, params)
case strings.HasPrefix(model, "cohere.embed"):
return provider.handleCohereEmbedding(ctx, model, key.Value, input, params)
return provider.handleCohereEmbedding(ctx, model, *key.BedrockKeyConfig, input, params)
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
default:
return nil, newConfigurationError("embedding is not supported for this Bedrock model", schemas.Bedrock)
}
}

// handleTitanEmbedding handles embedding requests for Amazon Titan models.
func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model string, config schemas.BedrockKeyConfig, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
// Titan Text Embeddings V1/V2 - only supports single text input
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", schemas.Bedrock)
}
if len(input.Texts) > 1 {
return nil, newConfigurationError("Amazon Titan embedding models support only single text input, received multiple texts", schemas.Bedrock)
}

requestBody := map[string]interface{}{
"inputText": input.Texts[0],
}
Expand All @@ -1148,7 +1153,7 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model

// Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly
path := url.PathEscape(model) + "/invoke"
rawResponse, err := provider.completeRequest(ctx, requestBody, path, key)
rawResponse, err := provider.completeRequest(ctx, requestBody, path, config)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1192,7 +1197,11 @@ func (provider *BedrockProvider) handleTitanEmbedding(ctx context.Context, model
}

// handleCohereEmbedding handles embedding requests for Cohere models on Bedrock.
func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, key string, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, model string, config schemas.BedrockKeyConfig, input *schemas.EmbeddingInput, params *schemas.ModelParameters) (*schemas.BifrostResponse, *schemas.BifrostError) {
if len(input.Texts) == 0 {
return nil, newConfigurationError("no input text provided for embedding", schemas.Bedrock)
}

requestBody := map[string]interface{}{
"texts": input.Texts,
"input_type": "search_document",
Expand All @@ -1203,7 +1212,7 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode

// Properly escape model name for URL path to ensure AWS SIGv4 signing works correctly
path := url.PathEscape(model) + "/invoke"
rawResponse, err := provider.completeRequest(ctx, requestBody, path, key)
rawResponse, err := provider.completeRequest(ctx, requestBody, path, config)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1255,6 +1264,10 @@ func (provider *BedrockProvider) handleCohereEmbedding(ctx context.Context, mode
// It formats the request, sends it to Bedrock, and processes the streaming response.
// Returns a channel for streaming BifrostResponse objects or an error if the request fails.
func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postHookRunner schemas.PostHookRunner, model string, key schemas.Key, messages []schemas.BifrostMessage, params *schemas.ModelParameters) (chan *schemas.BifrostStream, *schemas.BifrostError) {
if key.BedrockKeyConfig == nil {
return nil, newConfigurationError("bedrock key config is not provided", schemas.Bedrock)
}

messageBody, err := provider.prepareChatCompletionMessages(messages, model)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1286,22 +1299,18 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
// Format the path with proper model identifier for streaming
path := fmt.Sprintf("%s/converse-stream", model)

if provider.meta != nil && provider.meta.GetInferenceProfiles() != nil {
if inferenceProfileId, ok := provider.meta.GetInferenceProfiles()[model]; ok {
if provider.meta.GetARN() != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *provider.meta.GetARN(), inferenceProfileId))
if key.BedrockKeyConfig.Deployments != nil {
if inferenceProfileId, ok := key.BedrockKeyConfig.Deployments[model]; ok {
if key.BedrockKeyConfig.ARN != nil {
encodedModelIdentifier := url.PathEscape(fmt.Sprintf("%s/%s", *key.BedrockKeyConfig.ARN, inferenceProfileId))
path = fmt.Sprintf("%s/converse-stream", encodedModelIdentifier)
}
}
}

if provider.meta == nil {
return nil, newConfigurationError("meta config for bedrock is not provided", schemas.Bedrock)
}

region := "us-east-1"
if provider.meta.GetRegion() != nil {
region = *provider.meta.GetRegion()
if key.BedrockKeyConfig.Region != nil {
region = *key.BedrockKeyConfig.Region
}

// Create the streaming request
Expand All @@ -1320,8 +1329,8 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx context.Context, postH
setExtraHeadersHTTP(req, provider.networkConfig.ExtraHeaders, nil)

// Sign the request for AWS
if provider.meta.GetSecretAccessKey() != nil {
if signErr := signAWSRequest(req, key.Value, *provider.meta.GetSecretAccessKey(), provider.meta.GetSessionToken(), region, "bedrock"); signErr != nil {
if key.BedrockKeyConfig.SecretKey != "" {
if signErr := signAWSRequest(req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, region, "bedrock"); signErr != nil {
return nil, signErr
}
} else {
Expand Down
4 changes: 2 additions & 2 deletions core/providers/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (provider *VertexProvider) ChatCompletion(ctx context.Context, model string

region := key.VertexKeyConfig.Region
if region == "" {
return nil, newConfigurationError("region is not set in meta config", schemas.Vertex)
return nil, newConfigurationError("region is not set in key config", schemas.Vertex)
}

url := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions", region, projectID, region)
Expand Down Expand Up @@ -340,7 +340,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx context.Context, postHo

region := key.VertexKeyConfig.Region
if region == "" {
return nil, newConfigurationError("region is not set in meta config", schemas.Vertex)
return nil, newConfigurationError("region is not set in key config", schemas.Vertex)
}

client, err := getAuthClient(key)
Expand Down
24 changes: 18 additions & 6 deletions core/schemas/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import "context"
// Key represents an API key and its associated configuration for a provider.
// It contains the key value, supported models, and a weight for load balancing.
type Key struct {
ID string `json:"id"` // The unique identifier for the key (not used by bifrost, but can be used by users to identify the key)
Value string `json:"value"` // The actual API key value
Models []string `json:"models"` // List of models this key can access
Weight float64 `json:"weight"` // Weight for load balancing between multiple keys
AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration
VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration
ID string `json:"id"` // The unique identifier for the key (not used by bifrost, but can be used by users to identify the key)
Value string `json:"value"` // The actual API key value
Models []string `json:"models"` // List of models this key can access
Weight float64 `json:"weight"` // Weight for load balancing between multiple keys
AzureKeyConfig *AzureKeyConfig `json:"azure_key_config,omitempty"` // Azure-specific key configuration
VertexKeyConfig *VertexKeyConfig `json:"vertex_key_config,omitempty"` // Vertex-specific key configuration
BedrockKeyConfig *BedrockKeyConfig `json:"bedrock_key_config,omitempty"` // AWS Bedrock-specific key configuration
}

// AzureKeyConfig represents the Azure-specific configuration.
Expand All @@ -30,6 +31,17 @@ type VertexKeyConfig struct {
AuthCredentials string `json:"auth_credentials,omitempty"`
}

// BedrockKeyConfig represents the AWS Bedrock-specific configuration attached to a Key.
// It carries the per-key AWS credentials, region, and model routing settings that the
// Bedrock provider reads when building and signing a request.
type BedrockKeyConfig struct {
AccessKey string `json:"access_key,omitempty"` // AWS access key ID passed to request signing
SecretKey string `json:"secret_key,omitempty"` // AWS secret access key; requests are signed with these credentials only when this is non-empty
SessionToken *string `json:"session_token,omitempty"` // Optional AWS session token for temporary credentials; nil or empty means none is attached
Region *string `json:"region,omitempty"` // AWS region; the provider falls back to "us-east-1" when nil
ARN *string `json:"arn,omitempty"` // ARN prefix joined with a deployment's inference profile ID to form the model identifier in the request path
Deployments map[string]string `json:"deployments,omitempty"` // Maps model name to inference profile ID; chat completion paths apply it only when ARN is also set
}

// Account defines the interface for managing provider accounts and their configurations.
// It provides methods to access provider-specific settings, API keys, and configurations.
type Account interface {
Expand Down
44 changes: 0 additions & 44 deletions core/schemas/meta/bedrock.go

This file was deleted.

Loading