Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions core/mcp/toolmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -594,12 +595,22 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall

// Try identity-based token lookup first (works even without session token)
accessToken, err := m.oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID)
if err != nil && sessionToken != "" {
// Had session but token lookup failed — return error
if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) {
// Had session but token lookup failed with a real error (not just "not found") — return error
return nil, "", "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err)
}
if err != nil {
// No session and no token by identity — user hasn't authenticated yet.
// No token found — user hasn't authenticated with this MCP server yet.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
// In LLM gateway mode with no identity, we can't track who this user is,
// so an OAuth flow would produce an orphaned token. Return a clear error instead.
isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool)
if !isMCPGateway && userID == "" && virtualKeyID == "" {
return nil, "", "", fmt.Errorf(
"per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you",
client.ExecutionConfig.Name,
)
}
Comment thread
Pratham-Mishra04 marked this conversation as resolved.

// Initiate OAuth flow to get a proper authorize URL with session tracking.
if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" {
return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name)
Expand Down
78 changes: 22 additions & 56 deletions core/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue()))
} else {
// Sign the request using either explicit credentials or IAM role authentication
if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil {
if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
return nil, 0, nil, err
}
}
Expand Down Expand Up @@ -348,7 +348,7 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros
if key.Value.GetValue() != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue()))
} else {
if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil {
if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
return nil, 0, nil, err
}
}
Expand Down Expand Up @@ -450,8 +450,8 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
} else {
req.Header.Set("Accept", "application/vnd.amazon.eventstream")
// Sign the request using either explicit credentials or IAM role authentication
if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
return nil, deployment, err
if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil {
return nil, err
}
}

Expand Down Expand Up @@ -480,26 +480,20 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex
var opErr *net.OpError
var dnsErr *net.DNSError
if errors.As(respErr, &opErr) || errors.As(respErr, &dnsErr) {
return nil, deployment, &schemas.BifrostError{
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: schemas.ErrProviderNetworkError,
Error: respErr,
},
ExtraFields: schemas.BifrostErrorExtraFields{
Provider: providerName,
},
}
}
return nil, deployment, &schemas.BifrostError{
return nil, &schemas.BifrostError{
IsBifrostError: false,
Error: &schemas.ErrorField{
Message: schemas.ErrProviderDoRequest,
Error: respErr,
},
ExtraFields: schemas.BifrostErrorExtraFields{
Provider: providerName,
},
}
}

Expand Down Expand Up @@ -706,7 +700,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke
} else {
// Sign the request using either explicit credentials or IAM role authentication

if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -969,13 +963,12 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
Error: err,
},
ExtraFields: schemas.BifrostErrorExtraFields{
Comment thread
akshaydeo marked this conversation as resolved.
RequestType: schemas.TextCompletionStreamRequest,
Provider: providerName,
ModelRequested: request.Model,
RequestType: schemas.TextCompletionStreamRequest,
Provider: providerName,
},
Comment thread
Pratham-Mishra04 marked this conversation as resolved.
}, responseChan, provider.logger)
} else {
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
}
return
}
Expand Down Expand Up @@ -1006,15 +999,10 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex
Error: &schemas.ErrorField{
Message: fmt.Sprintf("%s stream %s: %s", providerName, excType, errMsg),
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.TextCompletionStreamRequest,
Provider: providerName,
ModelRequested: request.Model,
},
}, responseChan, provider.logger)
} else {
err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg)
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
}
return
}
Expand Down Expand Up @@ -1130,7 +1118,6 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil {
return nil, err
}

jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody(
ctx,
request,
Expand All @@ -1152,7 +1139,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize)

providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds)

providerName := provider.GetProviderKey()
// Start streaming in a goroutine
go func() {
defer func() {
Expand Down Expand Up @@ -1225,14 +1212,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
Message: schemas.ErrProviderNetworkError,
Error: err,
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.ChatCompletionStreamRequest,
Provider: providerName,
ModelRequested: request.Model,
},
}, responseChan, provider.logger)
} else {
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
}
return
}
Expand Down Expand Up @@ -1260,14 +1242,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex
Error: &schemas.ErrorField{
Message: err.Error(),
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.ChatCompletionStreamRequest,
Provider: providerName,
ModelRequested: request.Model,
},
}, responseChan, provider.logger)
} else {
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
}
return
}
Expand Down Expand Up @@ -1556,7 +1533,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
lastChunkTime := startTime
decoder := eventstream.NewDecoder()
payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer

providerName := provider.GetProviderKey()
for {
// If context was cancelled/timed out, let defer handle it
if ctx.Err() != nil {
Expand Down Expand Up @@ -1608,14 +1585,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
Message: schemas.ErrProviderNetworkError,
Error: err,
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.ResponsesStreamRequest,
Provider: providerName,
ModelRequested: request.Model,
},
}, responseChan, provider.logger)
} else {
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
}
return
}
Expand Down Expand Up @@ -1643,14 +1615,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po
Error: &schemas.ErrorField{
Message: err.Error(),
},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.ResponsesStreamRequest,
Provider: providerName,
ModelRequested: request.Model,
},
}, responseChan, provider.logger)
} else {
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger)
providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger)
}
return
}
Expand Down Expand Up @@ -1819,7 +1786,6 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche
if bifrostError != nil {
return nil, providerUtils.EnrichError(ctx, bifrostError, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse)
}

// Parse response based on model type
var bifrostResponse *schemas.BifrostEmbeddingResponse
switch modelType {
Expand All @@ -1838,7 +1804,7 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche
}
converted, convErr := cohereResp.ToBifrostEmbeddingResponse()
if convErr != nil {
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", convErr, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse)
return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", convErr), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse)
}
bifrostResponse = converted
bifrostResponse.Model = request.Model
Expand Down Expand Up @@ -2878,7 +2844,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc
}

// Sign request
if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil {
return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, sendBackRawRequest, sendBackRawResponse)
}

Expand Down Expand Up @@ -3003,7 +2969,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s
}

// Sign request
if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); bifrostErr != nil {
if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); bifrostErr != nil {
return nil, bifrostErr
}

Expand Down Expand Up @@ -3182,7 +3148,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys
}

// Sign request
if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil {
lastErr = err
continue
}
Expand Down Expand Up @@ -3317,7 +3283,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [
}

// Sign request
if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil {
if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil {
lastErr = err
continue
}
Expand Down
16 changes: 12 additions & 4 deletions core/providers/bedrock/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -931,8 +931,12 @@ func ToBedrockInvokeImagesResponse(ctx *schemas.BifrostContext, resp *schemas.Bi
}

model := resp.Model
Comment thread
akshaydeo marked this conversation as resolved.
if resp.ExtraFields.ModelRequested != "" {
model = resp.ExtraFields.ModelRequested
if model == "" {
if resp.ExtraFields.ResolvedModelUsed != "" {
model = resp.ExtraFields.ResolvedModelUsed
} else if resp.ExtraFields.OriginalModelRequested != "" {
model = resp.ExtraFields.OriginalModelRequested
}
}
Comment thread
akshaydeo marked this conversation as resolved.

// Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas
Expand Down Expand Up @@ -974,8 +978,12 @@ func ToBedrockEmbeddingInvokeResponse(resp *schemas.BifrostEmbeddingResponse) (i
// Use model name to distinguish Cohere from Titan — not batch size.
// A single-input Cohere request must still return the Cohere envelope format.
model := resp.Model
Comment thread
akshaydeo marked this conversation as resolved.
if resp.ExtraFields.ModelRequested != "" {
model = resp.ExtraFields.ModelRequested
if model == "" {
if resp.ExtraFields.ResolvedModelUsed != "" {
model = resp.ExtraFields.ResolvedModelUsed
} else if resp.ExtraFields.OriginalModelRequested != "" {
model = resp.ExtraFields.OriginalModelRequested
}
}

if strings.Contains(strings.ToLower(model), "cohere") {
Expand Down
2 changes: 1 addition & 1 deletion core/providers/bedrock/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) {
ctx := testBedrockCtx()
key := testBedrockKey()

_, _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream")
_, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream")

require.NotNil(t, bifrostErr, "expected error when server closes connection")
assert.False(t, bifrostErr.IsBifrostError,
Expand Down
2 changes: 1 addition & 1 deletion core/providers/gemini/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key
ctx,
request,
func() (providerUtils.RequestBodyWithExtraParams, error) {
return ToGeminiChatCompletionRequest(request), nil
return ToGeminiChatCompletionRequest(request)
})
if err != nil {
return nil, err
Expand Down
1 change: 0 additions & 1 deletion core/providers/openai/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ func TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *tes
func() (providerUtils.RequestBodyWithExtraParams, error) {
return ToOpenAIChatRequest(ctx, bifrostReq), nil
},
schemas.Fireworks,
)
if bifrostErr != nil {
t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message)
Expand Down
10 changes: 0 additions & 10 deletions core/providers/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -1763,11 +1763,6 @@ func HandleOpenAIResponsesStreaming(
Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeFailed)),
IsBifrostError: false,
Error: &schemas.ErrorField{},
ExtraFields: schemas.BifrostErrorExtraFields{
RequestType: schemas.ResponsesStreamRequest,
Provider: providerName,
ModelRequested: request.Model,
},
}
if response.Response != nil && response.Response.Error != nil {
bifrostErr.Error.Message = response.Response.Error.Message
Expand All @@ -1777,12 +1772,7 @@ func HandleOpenAIResponsesStreaming(
providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger)
return
}

response.ExtraFields.RequestType = schemas.ResponsesStreamRequest
response.ExtraFields.Provider = providerName
response.ExtraFields.ModelRequested = request.Model
response.ExtraFields.ChunkIndex = response.SequenceNumber

if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete {
// Set raw request if enabled
if sendBackRawRequest {
Expand Down
1 change: 0 additions & 1 deletion core/providers/openai/text_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ func TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation(t *testing.T)
func() (providerUtils.RequestBodyWithExtraParams, error) {
return ToOpenAITextCompletionRequest(bifrostReq), nil
},
schemas.Fireworks,
)
if bifrostErr != nil {
t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message)
Expand Down
Loading
Loading