Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
549 changes: 531 additions & 18 deletions .github/workflows/scripts/run-migration-tests.sh

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -6846,6 +6846,9 @@ func (bifrost *Bifrost) getAllSupportedKeys(ctx *schemas.BifrostContext, provide
if ctx != nil {
key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
if ok {
if err := validateKey(baseProviderType, &key); err != nil {
return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err)
}
// If a direct key is specified, return it as a single-element slice
return []schemas.Key{key}, nil
}
Expand Down Expand Up @@ -6893,6 +6896,9 @@ func (bifrost *Bifrost) getKeysForBatchAndFileOps(ctx *schemas.BifrostContext, p
if ctx != nil {
key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key)
if ok {
if err := validateKey(baseProviderType, &key); err != nil {
return nil, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err)
}
// If a direct key is specified, return it as a single-element slice
return []schemas.Key{key}, nil
}
Expand Down Expand Up @@ -6981,6 +6987,9 @@ func (bifrost *Bifrost) selectKeyFromProviderForModelWithPool(ctx *schemas.Bifro
// DirectKey: caller supplied a key directly — no pool, no rotation.
if ctx != nil {
if key, ok := ctx.Value(schemas.BifrostContextKeyDirectKey).(schemas.Key); ok {
if err := validateKey(baseProviderType, &key); err != nil {
return nil, false, fmt.Errorf("invalid direct key for provider %v: %w", baseProviderType, err)
}
return []schemas.Key{key}, false, nil
}
}
Expand Down
2 changes: 2 additions & 0 deletions core/changelog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- fix: usage of per-user OAuth servers in codemode
- fix: adds validation for direct API keys
3 changes: 3 additions & 0 deletions core/mcp/codemode.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ type CodeModeDependencies struct {

// LogMutex protects concurrent access to logs during code execution
LogMutex *sync.Mutex

// OAuth2Provider handles per-user OAuth token lookup and flow initiation
OAuth2Provider schemas.OAuth2Provider
}

// DefaultCodeModeConfig returns the default configuration for CodeMode.
Expand Down
25 changes: 24 additions & 1 deletion core/mcp/codemode/starlark/executecode.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,30 @@ func (s *StarlarkCodeMode) callMCPTool(ctx *schemas.BifrostContext, clientName,
toolCtx, cancel := context.WithTimeout(nestedCtx, toolExecutionTimeout)
defer cancel()

toolResponse, callErr := client.Conn.CallTool(toolCtx, callRequest)
var toolResponse *mcp.CallToolResult
var callErr error

if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
accessToken, err := utils.ResolvePerUserOAuthToken(nestedCtx, client, s.oauth2Provider)
if err != nil {
return nil, err
}

if client.Conn == nil {
// Per-user OAuth with no persistent connection — use a temporary connection.
// Assign to outer toolResponse/callErr so the shared logging + post-hooks path runs.
toolResponse, callErr = codemcp.ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, toolNameToCall, args, accessToken, s.logger)
if callErr != nil && toolCtx.Err() == context.DeadlineExceeded {
callErr = fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
}
} else {
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
}
} else {
toolResponse, callErr = client.Conn.CallTool(toolCtx, callRequest)
}

latency := time.Since(startTime).Milliseconds()

var mcpResp *schemas.BifrostMCPResponse
Expand Down
2 changes: 2 additions & 0 deletions core/mcp/codemode/starlark/starlark.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type StarlarkCodeMode struct {
pluginPipelineProvider func() mcp.PluginPipeline
releasePluginPipeline func(pipeline mcp.PluginPipeline)
fetchNewRequestIDFunc func(ctx *schemas.BifrostContext) string
oauth2Provider schemas.OAuth2Provider

// Logger for this instance
logger schemas.Logger
Expand Down Expand Up @@ -86,6 +87,7 @@ func (s *StarlarkCodeMode) SetDependencies(deps *mcp.CodeModeDependencies) {
s.pluginPipelineProvider = deps.PluginPipelineProvider
s.releasePluginPipeline = deps.ReleasePluginPipeline
s.fetchNewRequestIDFunc = deps.FetchNewRequestIDFunc
s.oauth2Provider = deps.OAuth2Provider
}
}

Expand Down
75 changes: 6 additions & 69 deletions core/mcp/toolmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -184,6 +183,7 @@ func (m *ToolsManager) GetCodeModeDependencies() *CodeModeDependencies {
PluginPipelineProvider: m.pluginPipelineProvider,
ReleasePluginPipeline: m.releasePluginPipeline,
FetchNewRequestIDFunc: m.fetchNewRequestIDFunc,
OAuth2Provider: m.oauth2Provider,
}
}

Expand Down Expand Up @@ -686,55 +686,9 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall

// Handle per-user OAuth: inject user-specific Authorization header
if client.ExecutionConfig.AuthType == schemas.MCPAuthTypePerUserOauth {
if m.oauth2Provider == nil {
return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured")
}
virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string)
sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string)

// Optional X-Bf-User-Id header overrides user identity; if absent, falls back to virtual key
if mcpUserID, _ := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); mcpUserID != "" {
userID = mcpUserID
}

// Try identity-based token lookup first (works even without session token)
accessToken, err := m.oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID)
if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) {
// Had session but token lookup failed with a real error (not just "not found") — return error
return nil, "", "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err)
}
accessToken, err := utils.ResolvePerUserOAuthToken(ctx, client, m.oauth2Provider)
if err != nil {
// No token found — user hasn't authenticated with this MCP server yet.
// In LLM gateway mode with no identity, we can't track who this user is,
// so an OAuth flow would produce an orphaned token. Return a clear error instead.
isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool)
if !isMCPGateway && userID == "" && virtualKeyID == "" {
return nil, "", "", fmt.Errorf(
"per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you",
client.ExecutionConfig.Name,
)
}

// Initiate OAuth flow to get a proper authorize URL with session tracking.
if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" {
return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name)
}
redirectURI := buildRedirectURIFromContext(ctx)
if redirectURI == "" {
return nil, "", "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context")
}
flowInitiation, sessionID, flowErr := m.oauth2Provider.InitiateUserOAuthFlow(ctx, *client.ExecutionConfig.OauthConfigID, client.ExecutionConfig.ID, redirectURI)
if flowErr != nil {
return nil, "", "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", client.ExecutionConfig.Name, flowErr)
}
return nil, "", "", &schemas.MCPUserOAuthRequiredError{
MCPClientID: client.ExecutionConfig.ID,
MCPClientName: client.ExecutionConfig.Name,
AuthorizeURL: flowInitiation.AuthorizeURL,
SessionID: sessionID,
Message: fmt.Sprintf("Authentication required for %s. Please visit the authorize URL to connect your account.", client.ExecutionConfig.Name),
}
return nil, "", "", err
}

if client.Conn == nil {
Expand All @@ -743,7 +697,7 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall
toolCtx, cancel := context.WithTimeout(ctx, toolExecutionTimeout)
defer cancel()

toolResponse, callErr := executeToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger)
toolResponse, callErr := ExecuteToolWithUserToken(toolCtx, client.ExecutionConfig, originalMCPToolName, arguments, accessToken, m.logger)
if callErr != nil {
if toolCtx.Err() == context.DeadlineExceeded {
return nil, "", "", fmt.Errorf("MCP tool call timed out after %v: %s", toolExecutionTimeout, toolName)
Expand All @@ -755,15 +709,7 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall
return createToolResponseMessage(*toolCall, responseText), client.ExecutionConfig.Name, sanitizedToolName, nil
}

// Persistent connection exists — use per-call headers
headers := make(http.Header)
if client.ExecutionConfig.Headers != nil {
for key, value := range client.ExecutionConfig.Headers {
headers.Add(key, value.GetValue())
}
}
headers.Set("Authorization", "Bearer "+accessToken)
callRequest.Header = headers
callRequest.Header = utils.BuildPerUserOAuthHeaders(callRequest.Header, accessToken)
} else if client.ExecutionConfig.Headers != nil {
headers := make(http.Header)
for key, value := range client.ExecutionConfig.Headers {
Expand Down Expand Up @@ -911,7 +857,7 @@ func (m *ToolsManager) UpdateConfig(config *schemas.MCPToolManagerConfig) {
// Returns:
// - *mcp.CallToolResult: tool execution result
// - error: any error during connection or execution
func executeToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) {
func ExecuteToolWithUserToken(ctx context.Context, config *schemas.MCPClientConfig, toolName string, arguments map[string]interface{}, accessToken string, logger schemas.Logger) (*mcp.CallToolResult, error) {
if config.ConnectionString == nil || config.ConnectionString.GetValue() == "" {
return nil, fmt.Errorf("connection URL is required for per-user OAuth tool execution")
}
Expand Down Expand Up @@ -964,15 +910,6 @@ func executeToolWithUserToken(ctx context.Context, config *schemas.MCPClientConf
return tempClient.CallTool(ctx, callRequest)
}

// buildRedirectURIFromContext returns the OAuth redirect URI stored in ctx under
// schemas.BifrostContextKeyOAuthRedirectURI.
// It yields the empty string when the key is unset, holds a non-string value,
// or holds an empty string; callers treat "" as "no redirect URI available".
func buildRedirectURIFromContext(ctx *schemas.BifrostContext) string {
	// A failed type assertion leaves redirectURI at its zero value (""), which
	// is exactly the fallback this function reports.
	redirectURI, _ := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string)
	return redirectURI
}

// GetCodeModeBindingLevel returns the current code mode binding level.
// This method is safe to call concurrently from multiple goroutines.
Expand Down
71 changes: 71 additions & 0 deletions core/mcp/utils/utils.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,82 @@
package utils

import (
"errors"
"fmt"
"net/http"

"github.com/maximhq/bifrost/core/schemas"
)

// ResolvePerUserOAuthToken returns the per-user OAuth access token to use for the
// MCP server described by client.
//
// The caller's identity is read from ctx: the governance virtual key ID, the
// user ID and the MCP user session token. A non-empty
// schemas.BifrostContextKeyMCPUserID value replaces the user ID.
//
// Returns:
//   - the access token and a nil error when the provider finds one;
//   - a wrapped error when oauth2Provider is nil or the lookup fails with
//     anything other than schemas.ErrOAuth2TokenNotFound;
//   - a plain error when no token exists and a flow cannot be started (no
//     identity outside MCP gateway mode, no OAuth config ID on the client, no
//     redirect URI in ctx, or the provider fails to initiate the flow);
//   - *schemas.MCPUserOAuthRequiredError, carrying the authorize URL and session
//     ID, once a new OAuth flow has been initiated for the user.
func ResolvePerUserOAuthToken(ctx *schemas.BifrostContext, client *schemas.MCPClientState, oauth2Provider schemas.OAuth2Provider) (string, error) {
	if oauth2Provider == nil {
		return "", fmt.Errorf("per-user OAuth requires an OAuth2Provider but none is configured")
	}

	// Collect whatever identity information the request carries.
	virtualKeyID, _ := ctx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID).(string)
	userID, _ := ctx.Value(schemas.BifrostContextKeyUserID).(string)
	sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string)

	// An explicit MCP user ID takes precedence over the authenticated user ID.
	if override, ok := ctx.Value(schemas.BifrostContextKeyMCPUserID).(string); ok && override != "" {
		userID = override
	}

	cfg := client.ExecutionConfig

	accessToken, err := oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, cfg.ID)
	if err == nil {
		return accessToken, nil
	}
	if !errors.Is(err, schemas.ErrOAuth2TokenNotFound) {
		// A real lookup failure, as opposed to "this user has no token yet".
		return "", fmt.Errorf("failed to get user access token for MCP server %s: %w", cfg.Name, err)
	}

	// From here on the user simply has no token for this server yet.
	// Outside MCP gateway mode an anonymous request must not start a flow: the
	// resulting token could not be linked back to anyone.
	isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool)
	hasIdentity := userID != "" || virtualKeyID != ""
	if !isMCPGateway && !hasIdentity {
		return "", fmt.Errorf(
			"per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you",
			cfg.Name,
		)
	}

	oauthConfigID := cfg.OauthConfigID
	if oauthConfigID == nil || *oauthConfigID == "" {
		return "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", cfg.Name)
	}

	redirectURI := BuildRedirectURIFromContext(ctx)
	if redirectURI == "" {
		return "", fmt.Errorf("per-user OAuth requires a redirect URI but none is available in context")
	}

	initiation, sessionID, flowErr := oauth2Provider.InitiateUserOAuthFlow(ctx, *oauthConfigID, cfg.ID, redirectURI)
	if flowErr != nil {
		return "", fmt.Errorf("failed to initiate per-user OAuth flow for %s: %w", cfg.Name, flowErr)
	}

	// Surface the authorize URL to the caller as a typed error.
	return "", &schemas.MCPUserOAuthRequiredError{
		MCPClientID:   cfg.ID,
		MCPClientName: cfg.Name,
		AuthorizeURL:  initiation.AuthorizeURL,
		SessionID:     sessionID,
		Message:       fmt.Sprintf("Authentication required for %s. Please visit the authorize URL to connect your account.", cfg.Name),
	}
}

// BuildPerUserOAuthHeaders returns a copy of headers whose Authorization header
// is set to "Bearer <accessToken>", replacing any Authorization value already
// present. All other request-scoped headers are preserved, and the input is
// never mutated.
//
// A nil headers value is accepted: the result then contains only the
// Authorization header.
func BuildPerUserOAuthHeaders(headers http.Header, accessToken string) http.Header {
	h := headers.Clone()
	if h == nil {
		// Header.Clone returns nil for a nil receiver, and Set on a nil map
		// panics, so start from an empty header instead.
		h = make(http.Header, 1)
	}
	h.Set("Authorization", "Bearer "+accessToken)
	return h
}

// BuildRedirectURIFromContext returns the OAuth redirect URI stored in ctx under
// schemas.BifrostContextKeyOAuthRedirectURI.
// It returns the empty string when the key is unset, holds a non-string value,
// or holds an empty string.
func BuildRedirectURIFromContext(ctx *schemas.BifrostContext) string {
	// A failed type assertion leaves redirectURI as "", the "not available" result.
	redirectURI, _ := ctx.Value(schemas.BifrostContextKeyOAuthRedirectURI).(string)
	return redirectURI
}

// GetHeadersForToolExecution sets additional headers for tool execution.
// It returns the headers for the tool execution.
func GetHeadersForToolExecution(ctx *schemas.BifrostContext, client *schemas.MCPClientState) http.Header {
Expand Down
16 changes: 0 additions & 16 deletions core/providers/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,22 +138,6 @@ func (provider *BedrockProvider) GetProviderKey() schemas.ModelProvider {
return providerUtils.GetProviderName(schemas.Bedrock, provider.customProviderConfig)
}

// ensureBedrockKeyConfig guarantees that key.BedrockKeyConfig is non-nil whenever
// the key can be used at all.
//
// An existing config is left untouched. If the config is missing but key.Value
// is non-empty (API-key authentication), an empty schemas.BedrockKeyConfig is
// installed on the key. It reports false only when the key has neither a
// config nor a value.
func ensureBedrockKeyConfig(key *schemas.Key) bool {
	switch {
	case key.BedrockKeyConfig != nil:
		// Already configured; nothing to do.
		return true
	case key.Value.GetValue() != "":
		// API-key auth without Bedrock-specific settings: attach an empty config.
		key.BedrockKeyConfig = &schemas.BedrockKeyConfig{}
		return true
	default:
		return false
	}
}

// isStreamTransportError reports whether err is a transport-level connection
// failure that occurred while reading the EventStream body — as opposed to a
// semantic error (JSON parse failure, AWS exception event, etc.).
Expand Down
9 changes: 7 additions & 2 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ const (
BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys)
BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS MANUALLY))
BifrostContextKeySCIMClaims BifrostContextKey = "scim_claims"
BifrostContextKeyUserID BifrostContextKey = "bifrost-user-id" // string (to store the user ID (set by enterprise auth middleware - DO NOT SET THIS MANUALLY))
BifrostContextKeyUserName BifrostContextKey = "bifrost-user-name" // string (to store the user name (set by enterprise auth middleware - DO NOT SET THIS MANUALLY))
BifrostContextKeyUserID BifrostContextKey = "bifrost-user-id" // string (to store the user ID (set by enterprise auth middleware - DO NOT SET THIS MANUALLY))
BifrostContextKeyUserName BifrostContextKey = "bifrost-user-name" // string (to store the user name (set by enterprise auth middleware - DO NOT SET THIS MANUALLY))
BifrostContextKeyTargetUserID BifrostContextKey = "target_user_id"
BifrostContextKeyIsAzureUserAgent BifrostContextKey = "bifrost-is-azure-user-agent" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an Azure user agent (only used in gateway)
BifrostContextKeyVideoOutputRequested BifrostContextKey = "bifrost-video-output-requested"
Expand Down Expand Up @@ -1056,6 +1056,11 @@ func (r *BifrostResponse) PopulateExtraFields(requestType RequestType, provider
r.ContainerFileDeleteResponse.ExtraFields.Provider = provider
r.ContainerFileDeleteResponse.ExtraFields.OriginalModelRequested = originalModelRequested
r.ContainerFileDeleteResponse.ExtraFields.ResolvedModelUsed = resolvedModel
case r.OCRResponse != nil:
r.OCRResponse.ExtraFields.RequestType = requestType
r.OCRResponse.ExtraFields.Provider = provider
r.OCRResponse.ExtraFields.OriginalModelRequested = originalModelRequested
r.OCRResponse.ExtraFields.ResolvedModelUsed = resolvedModel
case r.PassthroughResponse != nil:
r.PassthroughResponse.ExtraFields.RequestType = requestType
r.PassthroughResponse.ExtraFields.Provider = provider
Expand Down
7 changes: 1 addition & 6 deletions core/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,8 @@ func validateKey(providerKey schemas.ModelProvider, key *schemas.Key) error {
return fmt.Errorf("azure_key_config.endpoint is required")
}
case schemas.Bedrock:
// Key is valid if either:
// 1. BedrockKeyConfig is provided
// 2. Value is provided and is not empty
// BedrockKeyConfig is optional — an empty config is valid for IRSA / ambient credential auth.
if key.BedrockKeyConfig == nil {
if key.Value.GetValue() == "" {
return fmt.Errorf("either value in key or bedrock_key_config is required")
}
key.BedrockKeyConfig = &schemas.BedrockKeyConfig{}
}
case schemas.Vertex:
Expand Down
1 change: 1 addition & 0 deletions docs/features/async-inference.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Streaming is not supported on async endpoints.
| Image generations | `/v1/async/images/generations` | `/v1/async/images/generations/{job_id}` |
| Image edits | `/v1/async/images/edits` | `/v1/async/images/edits/{job_id}` |
| Image variations | `/v1/async/images/variations` | `/v1/async/images/variations/{job_id}` |
| OCR | `/v1/async/ocr` | `/v1/async/ocr/{job_id}` |
| Rerank | `/v1/async/rerank` | `/v1/async/rerank/{job_id}` |

## Submitting a Request
Expand Down
1 change: 1 addition & 0 deletions framework/changelog.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- fix: adds support for OCR request pricing
Loading
Loading