diff --git a/core/mcp/toolmanager.go b/core/mcp/toolmanager.go index d786b79404..029d3fffd4 100644 --- a/core/mcp/toolmanager.go +++ b/core/mcp/toolmanager.go @@ -5,6 +5,7 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -594,12 +595,22 @@ func (m *ToolsManager) executeToolInternal(ctx *schemas.BifrostContext, toolCall // Try identity-based token lookup first (works even without session token) accessToken, err := m.oauth2Provider.GetUserAccessTokenByIdentity(ctx, virtualKeyID, userID, sessionToken, client.ExecutionConfig.ID) - if err != nil && sessionToken != "" { - // Had session but token lookup failed — return error + if err != nil && !errors.Is(err, schemas.ErrOAuth2TokenNotFound) { + // Had session but token lookup failed with a real error (not just "not found") — return error return nil, "", "", fmt.Errorf("failed to get user access token for MCP server %s: %w", client.ExecutionConfig.Name, err) } if err != nil { - // No session and no token by identity — user hasn't authenticated yet. + // No token found — user hasn't authenticated with this MCP server yet. + // In LLM gateway mode with no identity, we can't track who this user is, + // so an OAuth flow would produce an orphaned token. Return a clear error instead. + isMCPGateway, _ := ctx.Value(schemas.BifrostContextKeyIsMCPGateway).(bool) + if !isMCPGateway && userID == "" && virtualKeyID == "" { + return nil, "", "", fmt.Errorf( + "per-user OAuth for %s requires a user identity: include X-Bf-User-Id or a Virtual Key in your request so the token can be linked to you", + client.ExecutionConfig.Name, + ) + } + // Initiate OAuth flow to get a proper authorize URL with session tracking. if client.ExecutionConfig.OauthConfigID == nil || *client.ExecutionConfig.OauthConfigID == "" { return nil, "", "", fmt.Errorf("per-user OAuth requires an OAuth config but MCP client %s has none", client.ExecutionConfig.Name) diff --git a/core/providers/bedrock/bedrock.go b/core/providers/bedrock/bedrock.go index f40ea760c1..eb5da01fde 100644 --- a/core/providers/bedrock/bedrock.go +++ b/core/providers/bedrock/bedrock.go @@ -221,7 +221,7 @@ func (provider *BedrockProvider) completeRequest(ctx *schemas.BifrostContext, js req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue())) } else { // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil { return nil, 0, nil, err } } @@ -348,7 +348,7 @@ func (provider *BedrockProvider) completeAgentRuntimeRequest(ctx *schemas.Bifros if key.Value.GetValue() != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value.GetValue())) } else { - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, provider.GetProviderKey()); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil { return nil, 0, nil, err } } @@ -450,8 +450,8 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex } else { req.Header.Set("Accept", "application/vnd.amazon.eventstream") // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { - return nil, deployment, err + if err := signAWSRequest(ctx, req, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { + return nil, err } } @@ -480,26 +480,20 @@ func (provider *BedrockProvider) makeStreamingRequest(ctx *schemas.BifrostContex var opErr *net.OpError var dnsErr *net.DNSError if errors.As(respErr, &opErr) || errors.As(respErr, &dnsErr) { - return nil, deployment, &schemas.BifrostError{ + return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: schemas.ErrProviderNetworkError, Error: respErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - }, } } - return nil, deployment, &schemas.BifrostError{ + return nil, &schemas.BifrostError{ IsBifrostError: false, Error: &schemas.ErrorField{ Message: schemas.ErrProviderDoRequest, Error: respErr, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - Provider: providerName, - }, } } @@ -706,7 +700,7 @@ func (provider *BedrockProvider) listModelsByKey(ctx *schemas.BifrostContext, ke } else { // Sign the request using either explicit credentials or IAM role authentication - if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, req, config.AccessKey, config.SecretKey, config.SessionToken, config.RoleARN, config.ExternalID, config.RoleSessionName, region, bedrockSigningService); err != nil { return nil, err } } @@ -969,13 +963,12 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex Error: err, }, ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, + RequestType: schemas.TextCompletionStreamRequest, + Provider: providerName, }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1006,15 +999,10 @@ func (provider *BedrockProvider) TextCompletionStream(ctx *schemas.BifrostContex Error: &schemas.ErrorField{ Message: fmt.Sprintf("%s stream %s: %s", providerName, excType, errMsg), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.TextCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { err := fmt.Errorf("%s stream %s: %s", providerName, excType, errMsg) - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.TextCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1130,7 +1118,6 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex if err := providerUtils.CheckOperationAllowed(schemas.Bedrock, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { return nil, err } - jsonData, bifrostErr := providerUtils.CheckContextAndGetRequestBody( ctx, request, @@ -1152,7 +1139,7 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex responseChan := make(chan *schemas.BifrostStreamChunk, schemas.DefaultStreamBufferSize) providerUtils.SetStreamIdleTimeoutIfEmpty(ctx, provider.networkConfig.StreamIdleTimeoutInSeconds) - + providerName := provider.GetProviderKey() // Start streaming in a goroutine go func() { defer func() { @@ -1225,14 +1212,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex Message: schemas.ErrProviderNetworkError, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1260,14 +1242,9 @@ func (provider *BedrockProvider) ChatCompletionStream(ctx *schemas.BifrostContex Error: &schemas.ErrorField{ Message: err.Error(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ChatCompletionStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ChatCompletionStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1556,7 +1533,7 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po lastChunkTime := startTime decoder := eventstream.NewDecoder() payloadBuf := make([]byte, 0, 1024*1024) // 1MB payload buffer - + providerName := provider.GetProviderKey() for { // If context was cancelled/timed out, let defer handle it if ctx.Err() != nil { @@ -1608,14 +1585,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Message: schemas.ErrProviderNetworkError, Error: err, }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1643,14 +1615,9 @@ func (provider *BedrockProvider) ResponsesStream(ctx *schemas.BifrostContext, po Error: &schemas.ErrorField{ Message: err.Error(), }, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, }, responseChan, provider.logger) } else { - providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, schemas.ResponsesStreamRequest, providerName, request.Model, provider.logger) + providerUtils.ProcessAndSendError(ctx, postHookRunner, err, responseChan, provider.logger) } return } @@ -1819,7 +1786,6 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche if bifrostError != nil { return nil, providerUtils.EnrichError(ctx, bifrostError, jsonData, nil, provider.sendBackRawRequest, provider.sendBackRawResponse) } - // Parse response based on model type var bifrostResponse *schemas.BifrostEmbeddingResponse switch modelType { @@ -1838,7 +1804,7 @@ func (provider *BedrockProvider) Embedding(ctx *schemas.BifrostContext, key sche } converted, convErr := cohereResp.ToBifrostEmbeddingResponse() if convErr != nil { - return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", convErr, providerName), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) + return nil, providerUtils.EnrichError(ctx, providerUtils.NewBifrostOperationError("error parsing Cohere embedding response", convErr), jsonData, rawResponse, provider.sendBackRawRequest, provider.sendBackRawResponse) } bifrostResponse = converted bifrostResponse.Model = request.Model @@ -2878,7 +2844,7 @@ func (provider *BedrockProvider) BatchCreate(ctx *schemas.BifrostContext, key sc } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { return nil, providerUtils.EnrichError(ctx, err, jsonData, nil, sendBackRawRequest, sendBackRawResponse) } @@ -3003,7 +2969,7 @@ func (provider *BedrockProvider) BatchList(ctx *schemas.BifrostContext, keys []s } // Sign request - if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); bifrostErr != nil { + if bifrostErr := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); bifrostErr != nil { return nil, bifrostErr } @@ -3182,7 +3148,7 @@ func (provider *BedrockProvider) BatchRetrieve(ctx *schemas.BifrostContext, keys } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } @@ -3317,7 +3283,7 @@ func (provider *BedrockProvider) BatchCancel(ctx *schemas.BifrostContext, keys [ } // Sign request - if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService, providerName); err != nil { + if err := signAWSRequest(ctx, httpReq, key.BedrockKeyConfig.AccessKey, key.BedrockKeyConfig.SecretKey, key.BedrockKeyConfig.SessionToken, key.BedrockKeyConfig.RoleARN, key.BedrockKeyConfig.ExternalID, key.BedrockKeyConfig.RoleSessionName, region, bedrockSigningService); err != nil { lastErr = err continue } diff --git a/core/providers/bedrock/invoke.go b/core/providers/bedrock/invoke.go index a7886b477d..600a44d0e2 100644 --- a/core/providers/bedrock/invoke.go +++ b/core/providers/bedrock/invoke.go @@ -931,8 +931,12 @@ func ToBedrockInvokeImagesResponse(ctx *schemas.BifrostContext, resp *schemas.Bi } model := resp.Model - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } } // Stability AI models use the same BedrockImageGenerationResponse format as Titan/Nova Canvas @@ -974,8 +978,12 @@ func ToBedrockEmbeddingInvokeResponse(resp *schemas.BifrostEmbeddingResponse) (i // Use model name to distinguish Cohere from Titan — not batch size. // A single-input Cohere request must still return the Cohere envelope format. model := resp.Model - if resp.ExtraFields.ModelRequested != "" { - model = resp.ExtraFields.ModelRequested + if model == "" { + if resp.ExtraFields.ResolvedModelUsed != "" { + model = resp.ExtraFields.ResolvedModelUsed + } else if resp.ExtraFields.OriginalModelRequested != "" { + model = resp.ExtraFields.OriginalModelRequested + } } if strings.Contains(strings.ToLower(model), "cohere") { diff --git a/core/providers/bedrock/transport_test.go b/core/providers/bedrock/transport_test.go index 6751527b5b..1e2a447e9d 100644 --- a/core/providers/bedrock/transport_test.go +++ b/core/providers/bedrock/transport_test.go @@ -138,7 +138,7 @@ func TestMakeStreamingRequest_StaleConnection_IsRetryable(t *testing.T) { ctx := testBedrockCtx() key := testBedrockKey() - _, _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") + _, bifrostErr := provider.makeStreamingRequest(ctx, []byte(`{}`), key, "anthropic.claude-sonnet-4-5", "converse-stream") require.NotNil(t, bifrostErr, "expected error when server closes connection") assert.False(t, bifrostErr.IsBifrostError, diff --git a/core/providers/gemini/gemini.go b/core/providers/gemini/gemini.go index 5dd4940b72..32d3c96f93 100644 --- a/core/providers/gemini/gemini.go +++ b/core/providers/gemini/gemini.go @@ -281,7 +281,7 @@ func (provider *GeminiProvider) ChatCompletion(ctx *schemas.BifrostContext, key ctx, request, func() (providerUtils.RequestBodyWithExtraParams, error) { - return ToGeminiChatCompletionRequest(request), nil + return ToGeminiChatCompletionRequest(request) }) if err != nil { return nil, err diff --git a/core/providers/openai/chat_test.go b/core/providers/openai/chat_test.go index f391f821cb..f5e08c7f8e 100644 --- a/core/providers/openai/chat_test.go +++ b/core/providers/openai/chat_test.go @@ -277,7 +277,6 @@ func TestToOpenAIChatRequest_FireworksPreservesReasoningAndCacheIsolation(t *tes func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAIChatRequest(ctx, bifrostReq), nil }, - schemas.Fireworks, ) if bifrostErr != nil { t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message) diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 8930f3c472..9cb2b3db29 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -1763,11 +1763,6 @@ func HandleOpenAIResponsesStreaming( Type: schemas.Ptr(string(schemas.ResponsesStreamResponseTypeFailed)), IsBifrostError: false, Error: &schemas.ErrorField{}, - ExtraFields: schemas.BifrostErrorExtraFields{ - RequestType: schemas.ResponsesStreamRequest, - Provider: providerName, - ModelRequested: request.Model, - }, } if response.Response != nil && response.Response.Error != nil { bifrostErr.Error.Message = response.Response.Error.Message @@ -1777,12 +1772,7 @@ func HandleOpenAIResponsesStreaming( providerUtils.ProcessAndSendBifrostError(ctx, postHookRunner, providerUtils.EnrichError(ctx, bifrostErr, jsonBody, nil, sendBackRawRequest, sendBackRawResponse), responseChan, logger) return } - - response.ExtraFields.RequestType = schemas.ResponsesStreamRequest - response.ExtraFields.Provider = providerName - response.ExtraFields.ModelRequested = request.Model response.ExtraFields.ChunkIndex = response.SequenceNumber - if response.Type == schemas.ResponsesStreamResponseTypeCompleted || response.Type == schemas.ResponsesStreamResponseTypeIncomplete { // Set raw request if enabled if sendBackRawRequest { diff --git a/core/providers/openai/text_test.go b/core/providers/openai/text_test.go index b2dd53ee35..71c2f195a0 100644 --- a/core/providers/openai/text_test.go +++ b/core/providers/openai/text_test.go @@ -51,7 +51,6 @@ func TestToOpenAITextCompletionRequest_FireworksUsesCacheIsolation(t *testing.T) func() (providerUtils.RequestBodyWithExtraParams, error) { return ToOpenAITextCompletionRequest(bifrostReq), nil }, - schemas.Fireworks, ) if bifrostErr != nil { t.Fatalf("failed to build request body: %v", bifrostErr.Error.Message) diff --git a/core/providers/vertex/vertex.go b/core/providers/vertex/vertex.go index 1d0de4cb23..626cf00207 100644 --- a/core/providers/vertex/vertex.go +++ b/core/providers/vertex/vertex.go @@ -409,7 +409,7 @@ func (provider *VertexProvider) ChatCompletion(ctx *schemas.BifrostContext, key if err != nil { return nil, fmt.Errorf("failed to delete model field: %w", err) } - } else if schemas.IsGeminiModel(deployment) || schemas.IsAllDigitsASCII(deployment) { + } else if schemas.IsGeminiModel(request.Model) || schemas.IsAllDigitsASCII(request.Model) { reqBody, err := gemini.ToGeminiChatCompletionRequest(request) if err != nil { return nil, err @@ -702,7 +702,7 @@ func (provider *VertexProvider) ChatCompletionStream(ctx *schemas.BifrostContext return nil, fmt.Errorf("chat completion input is not provided") } extraParams = reqBody.GetExtraParams() - reqBody.Stream = schemas.Ptr(true) + reqBody.Stream = new(true) // Add provider-aware beta headers for Vertex anthropic.AddMissingBetaHeadersToContext(ctx, reqBody, schemas.Vertex) diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index d2129dfb51..d075239539 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -174,27 +174,27 @@ const ( MCPContextKeyIncludeClients BifrostContextKey = "mcp-include-clients" // Context key for whitelist client filtering MCPContextKeyIncludeTools BifrostContextKey = "mcp-include-tools" // Context key for whitelist tool filtering (Note: toolName should be in "clientName-toolName" format for individual tools, or "clientName-*" for wildcard) - BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceUserID BifrostContextKey = "bifrost-governance-user-id" // string (to store the user ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceBusinessUnitID BifrostContextKey = "bifrost-governance-business-unit-id" // string (to store the business unit ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceBusinessUnitName BifrostContextKey = "bifrost-governance-business-unit-name" // string (to store the business unit name (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) - BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. - BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) - BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) - BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string - BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string + BifrostContextKeySelectedKeyID BifrostContextKey = "bifrost-selected-key-id" // string (to store the selected key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeySelectedKeyName BifrostContextKey = "bifrost-selected-key-name" // string (to store the selected key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyID BifrostContextKey = "bifrost-governance-virtual-key-id" // string (to store the virtual key ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceVirtualKeyName BifrostContextKey = "bifrost-governance-virtual-key-name" // string (to store the virtual key name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamID BifrostContextKey = "bifrost-governance-team-id" // string (to store the team ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceTeamName BifrostContextKey = "bifrost-governance-team-name" // string (to store the team name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerID BifrostContextKey = "bifrost-governance-customer-id" // string (to store the customer ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceCustomerName BifrostContextKey = "bifrost-governance-customer-name" // string (to store the customer name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceUserID BifrostContextKey = "bifrost-governance-user-id" // string (to store the user ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitID BifrostContextKey = "bifrost-governance-business-unit-id" // string (to store the business unit ID (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceBusinessUnitName BifrostContextKey = "bifrost-governance-business-unit-name" // string (to store the business unit name (set by enterprise governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleID BifrostContextKey = "bifrost-governance-routing-rule-id" // string (to store the routing rule ID (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceRoutingRuleName BifrostContextKey = "bifrost-governance-routing-rule-name" // string (to store the routing rule name (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyGovernanceIncludeOnlyKeys BifrostContextKey = "bf-governance-include-only-keys" // []string (to store the include-only key IDs for provider config routing (set by bifrost governance plugin - DO NOT SET THIS MANUALLY)) + BifrostContextKeyNumberOfRetries BifrostContextKey = "bifrost-number-of-retries" // int (to store the number of retries (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyFallbackIndex BifrostContextKey = "bifrost-fallback-index" // int (to store the fallback index (set by bifrost - DO NOT SET THIS MANUALLY)) 0 for primary, 1 for first fallback, etc. + BifrostContextKeyStreamEndIndicator BifrostContextKey = "bifrost-stream-end-indicator" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyStreamIdleTimeout BifrostContextKey = "bifrost-stream-idle-timeout" // time.Duration (per-chunk idle timeout for streaming) + BifrostContextKeySkipKeySelection BifrostContextKey = "bifrost-skip-key-selection" // bool (will pass an empty key to the provider) + BifrostContextKeyExtraHeaders BifrostContextKey = "bifrost-extra-headers" // map[string][]string + BifrostContextKeyURLPath BifrostContextKey = "bifrost-extra-url-path" // string BifrostContextKeyUseRawRequestBody BifrostContextKey = "bifrost-use-raw-request-body" BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool @@ -213,9 +213,10 @@ const ( BifrostContextKeyTraceCompleter BifrostContextKey = "bifrost-trace-completer" // func() (callback to complete trace after streaming - set by tracing middleware) BifrostContextKeyPostHookSpanFinalizer BifrostContextKey = "bifrost-posthook-span-finalizer" // func(context.Context) (callback to finalize post-hook spans after streaming - set by bifrost) BifrostContextKeyAccumulatorID BifrostContextKey = "bifrost-accumulator-id" // string (ID for streaming accumulator lookup - set by tracer for accumulator operations) - BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token from X-Bifrost-MCP-Session header) + BifrostContextKeyMCPUserSession BifrostContextKey = "bifrost-mcp-user-session" // string (per-user OAuth session token, automatically generated by bifrost) BifrostContextKeyMCPUserID BifrostContextKey = "bifrost-mcp-user-id" // string (per-user OAuth user identifier from X-Bf-User-Id header) BifrostContextKeyOAuthRedirectURI BifrostContextKey = "bifrost-oauth-redirect-uri" // string (OAuth callback URL, e.g. https://host/api/oauth/callback - set by HTTP middleware) + BifrostContextKeyIsMCPGateway BifrostContextKey = "bifrost-is-mcp-gateway" // bool (true when request is being handled via the MCP gateway path) BifrostContextKeyHasEmittedMessageDelta BifrostContextKey = "bifrost-has-emitted-message-delta" // bool (tracks whether message_delta was already emitted during streaming - avoids duplicates) BifrostContextKeySkipDBUpdate BifrostContextKey = "bifrost-skip-db-update" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) BifrostContextKeyGovernancePluginName BifrostContextKey = "governance-plugin-name" // string (name of the governance plugin that processed the request - set by bifrost) @@ -1220,8 +1221,8 @@ type BifrostErrorExtraFields struct { OriginalModelRequested string `json:"original_model_requested,omitempty"` ResolvedModelUsed string `json:"resolved_model_used,omitempty"` RequestType RequestType `json:"request_type,omitempty"` - RawRequest interface{} `json:"raw_request,omitempty"` - RawResponse interface{} `json:"raw_response,omitempty"` + RawRequest any `json:"raw_request,omitempty"` + RawResponse any `json:"raw_response,omitempty"` LiteLLMCompat bool `json:"litellm_compat,omitempty"` KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication diff --git a/core/schemas/mcp.go b/core/schemas/mcp.go index ffde53271c..af87cdc743 100644 --- a/core/schemas/mcp.go +++ b/core/schemas/mcp.go @@ -21,6 +21,8 @@ var ( ErrOAuth2TokenInvalid = errors.New("oauth2 token invalid") ErrOAuth2RefreshFailed = errors.New("oauth2 token refresh failed") ErrOAuth2NotPerUserSession = errors.New("state does not match a per-user oauth session") + ErrOAuth2TokenNotFound = errors.New("per-user oauth token not found for this identity and mcp server") + ErrPerUserOAuthPendingFlowExpired = errors.New("per-user oauth pending flow has expired") ) // MCPUserOAuthRequiredError is returned when a per-user OAuth MCP server requires diff --git a/framework/configstore/migrations.go b/framework/configstore/migrations.go index 31a2019947..e64351eeaa 100644 --- a/framework/configstore/migrations.go +++ b/framework/configstore/migrations.go @@ -5949,20 +5949,36 @@ func migrationAddPerUserOAuthTables(ctx context.Context, db *gorm.DB) error { ID: "add_per_user_oauth_tables", Migrate: func(tx *gorm.DB) error { tx = tx.WithContext(ctx) - if err := tx.AutoMigrate(&tables.TablePerUserOAuthClient{}); err != nil { - return fmt.Errorf("failed to create oauth_per_user_clients table: %w", err) + mg := tx.Migrator() + if !mg.HasTable(&tables.TablePerUserOAuthClient{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthClient{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_clients table: %w", err) + } + } + if !mg.HasTable(&tables.TablePerUserOAuthSession{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthSession{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_sessions table: %w", err) + } } - if err := tx.AutoMigrate(&tables.TablePerUserOAuthSession{}); err != nil { - return fmt.Errorf("failed to create oauth_per_user_sessions table: %w", err) + if !mg.HasTable(&tables.TablePerUserOAuthCode{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthCode{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_codes table: %w", err) + } } - if err := tx.AutoMigrate(&tables.TablePerUserOAuthCode{}); err != nil { - return fmt.Errorf("failed to create oauth_per_user_codes table: %w", err) + if !mg.HasTable(&tables.TableOauthUserToken{}) { + if err := mg.CreateTable(&tables.TableOauthUserToken{}); err != nil { + return fmt.Errorf("failed to create oauth_user_tokens table: %w", err) + } } - if err := tx.AutoMigrate(&tables.TableOauthUserToken{}); err != nil { - return fmt.Errorf("failed to add identity columns to oauth_user_tokens: %w", err) + if !mg.HasTable(&tables.TableOauthUserSession{}) { + if err := mg.CreateTable(&tables.TableOauthUserSession{}); err != nil { + return fmt.Errorf("failed to create oauth_user_sessions table: %w", err) + } } - if err := tx.AutoMigrate(&tables.TableOauthUserSession{}); err != nil { - return fmt.Errorf("failed to create oauth_user_sessions table: %w", err) + if !mg.HasTable(&tables.TablePerUserOAuthPendingFlow{}) { + if err := mg.CreateTable(&tables.TablePerUserOAuthPendingFlow{}); err != nil { + return fmt.Errorf("failed to create oauth_per_user_pending_flows table: %w", err) + } } return nil @@ -5971,6 +5987,7 @@ func migrationAddPerUserOAuthTables(ctx context.Context, db *gorm.DB) error { tx = tx.WithContext(ctx) mg := tx.Migrator() for _, table := range []any{ + &tables.TablePerUserOAuthPendingFlow{}, &tables.TablePerUserOAuthCode{}, &tables.TablePerUserOAuthSession{}, &tables.TablePerUserOAuthClient{}, diff --git a/framework/configstore/rdb.go b/framework/configstore/rdb.go index abe0a8da99..624ab27300 100644 --- a/framework/configstore/rdb.go +++ b/framework/configstore/rdb.go @@ -1374,7 +1374,7 @@ func (s *RDBConfigStore) UpdateMCPClientDiscoveredTools(ctx context.Context, cli Updates(map[string]interface{}{ "discovered_tools_json": string(toolsJSON), "tool_name_mapping_json": string(mappingJSON), - "updated_at": time.Now(), + "updated_at": time.Now(), }).Error } @@ -2418,7 +2418,7 @@ func (s *RDBConfigStore) GetVirtualKeyMCPConfigs(ctx context.Context, virtualKey return nil, nil } var mcpConfigs []tables.TableVirtualKeyMCPConfig - if err := s.db.WithContext(ctx).Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { + if err := s.db.WithContext(ctx).Preload("MCPClient").Where("virtual_key_id = ?", virtualKey.ID).Find(&mcpConfigs).Error; err != nil { return nil, err } return mcpConfigs, nil @@ -4101,13 +4101,40 @@ func (s *RDBConfigStore) GetOauthUserTokenBySessionToken(ctx context.Context, se return &token, nil } -// CreateOauthUserToken creates a new per-user OAuth token +// CreateOauthUserToken creates or replaces a per-user OAuth token. +// When an identity (VirtualKeyID or UserID) is set, any existing token for the +// same identity + MCPClientID pair is replaced to keep resolution deterministic. func (s *RDBConfigStore) CreateOauthUserToken(ctx context.Context, token *tables.TableOauthUserToken) error { - result := s.db.WithContext(ctx).Create(token) - if result.Error != nil { - return fmt.Errorf("failed to create oauth user token: %w", result.Error) - } - return nil + // Wrap in a transaction so the SELECT + CREATE/UPDATE is atomic, preventing + // duplicate tokens when concurrent requests race on the same identity+client pair. + return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if token.UserID != nil && *token.UserID != "" { + var existing tables.TableOauthUserToken + err := tx.Where("user_id = ? AND mcp_client_id = ?", *token.UserID, token.MCPClientID).First(&existing).Error + if err == nil { + token.ID = existing.ID // reuse the row + return tx.Save(token).Error + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to query oauth user token: %w", err) + } + } else if token.VirtualKeyID != nil && *token.VirtualKeyID != "" { + var existing tables.TableOauthUserToken + err := tx.Where("virtual_key_id = ? AND mcp_client_id = ?", *token.VirtualKeyID, token.MCPClientID).First(&existing).Error + if err == nil { + token.ID = existing.ID // reuse the row + return tx.Save(token).Error + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return fmt.Errorf("failed to query oauth user token: %w", err) + } + } + + if err := tx.Create(token).Error; err != nil { + return fmt.Errorf("failed to create oauth user token: %w", err) + } + return nil + }) } // UpdateOauthUserToken updates an existing per-user OAuth token @@ -4165,7 +4192,9 @@ func (s *RDBConfigStore) CreatePerUserOAuthClient(ctx context.Context, client *t func (s *RDBConfigStore) GetPerUserOAuthSessionByAccessToken(ctx context.Context, accessToken string) (*tables.TablePerUserOAuthSession, error) { var session tables.TablePerUserOAuthSession tokenHash := encrypt.HashSHA256(accessToken) - result := s.db.WithContext(ctx).Where("access_token_hash = ?", tokenHash).First(&session) + result := s.db.WithContext(ctx).Where("access_token_hash = ?", tokenHash).Preload("VirtualKey", func(db *gorm.DB) *gorm.DB { + return db.Select("id, name, value, encryption_status") + }).First(&session) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -4272,3 +4301,159 @@ func (s *RDBConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *table } return nil } + +// ---------- Per-User OAuth Pending Flow CRUD ---------- + +// GetPerUserOAuthPendingFlow retrieves a pending consent flow by its ID. +func (s *RDBConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { + var flow tables.TablePerUserOAuthPendingFlow + result := s.db.WithContext(ctx).Where("id = ?", id).First(&flow) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, fmt.Errorf("failed to get per-user oauth pending flow: %w", result.Error) + } + return &flow, nil +} + +// CreatePerUserOAuthPendingFlow persists a new pending consent flow. +func (s *RDBConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + result := s.db.WithContext(ctx).Create(flow) + if result.Error != nil { + return fmt.Errorf("failed to create per-user oauth pending flow: %w", result.Error) + } + return nil +} + +// UpdatePerUserOAuthPendingFlow updates an existing pending consent flow (e.g., after VK step). +func (s *RDBConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + result := s.db.WithContext(ctx).Save(flow) + if result.Error != nil { + return fmt.Errorf("failed to update per-user oauth pending flow: %w", result.Error) + } + return nil +} + +// DeletePerUserOAuthPendingFlow deletes a pending consent flow after it has been submitted. +func (s *RDBConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { + result := s.db.WithContext(ctx).Where("id = ?", id).Delete(&tables.TablePerUserOAuthPendingFlow{}) + if result.Error != nil { + return fmt.Errorf("failed to delete per-user oauth pending flow: %w", result.Error) + } + return nil +} + +func (s *RDBConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { + now := time.Now().UTC() + result := s.db.WithContext(ctx).Where("id = ? AND expires_at > ?", id, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) + if result.Error != nil { + return 0, fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) + } + if result.RowsAffected == 0 { + // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). + var count int64 + if err := s.db.WithContext(ctx).Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", id).Count(&count).Error; err != nil { + return 0, fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) + } + if count > 0 { + return 0, schemas.ErrPerUserOAuthPendingFlowExpired + } + } + return result.RowsAffected, nil +} + +// FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session, +// and creates the authorization code in a single transaction. +func (s *RDBConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { + var rowsAffected int64 + err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // 1. Consume the pending flow (atomic idempotency guard). + // Also enforce the TTL so an expired flow cannot be finalized even if callers miss the check. + now := time.Now().UTC() + result := tx.Where("id = ? AND expires_at > ?", flowID, now).Delete(&tables.TablePerUserOAuthPendingFlow{}) + if result.Error != nil { + return fmt.Errorf("failed to consume per-user oauth pending flow: %w", result.Error) + } + rowsAffected = result.RowsAffected + if rowsAffected == 0 { + // Distinguish between already-consumed (record gone) and expired (record exists but TTL elapsed). + var count int64 + if err := tx.Model(&tables.TablePerUserOAuthPendingFlow{}).Where("id = ?", flowID).Count(&count).Error; err != nil { + return fmt.Errorf("failed to inspect per-user oauth pending flow: %w", err) + } + if count > 0 { + return schemas.ErrPerUserOAuthPendingFlowExpired + } + // Record gone — consumed by a concurrent request; caller treats as conflict. + return nil + } + + // 2. Create the Bifrost session. + if err := tx.Create(session).Error; err != nil { + return fmt.Errorf("failed to create per-user oauth session: %w", err) + } + + // 3. Create the authorization code. + if err := tx.Create(code).Error; err != nil { + return fmt.Errorf("failed to create per-user oauth code: %w", err) + } + + return nil + }) + if err != nil { + return 0, err + } + return rowsAffected, nil +} + +// GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID. +func (s *RDBConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { + if strings.TrimSpace(gatewaySessionID) == "" { + return nil, fmt.Errorf("gateway session id is required") + } + // Find all tokens whose session_token_hash matches any upstream session + // linked to this gateway session ID. This supports per-service proxy tokens + // (e.g. "flow::") where each MCP service gets its own hash. + var tokens []tables.TableOauthUserToken + subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.db.WithContext(ctx).Where("session_token_hash IN (?)", subquery).Find(&tokens) + if result.Error != nil { + return nil, fmt.Errorf("failed to get oauth user tokens by gateway session id: %w", result.Error) + } + return tokens, nil +} + +// TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions +// (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID. +func (s *RDBConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error { + if strings.TrimSpace(gatewaySessionID) == "" { + return fmt.Errorf("gateway session id is required") + } + if strings.TrimSpace(realSessionToken) == "" { + return fmt.Errorf("real session token is required") + } + realTokenHash := encrypt.HashSHA256(realSessionToken) + + // Always overwrite both identity columns from the finalized values so stale + // identities from a prior flow phase cannot persist and cause GetOauthUserTokenByIdentity + // to resolve this token under the wrong identity. + updates := map[string]interface{}{ + "session_token": realSessionToken, + "session_token_hash": realTokenHash, + "virtual_key_id": virtualKeyID, + "user_id": userID, + } + + // Update all tokens whose session_token_hash matches any upstream session + // linked to this gateway session ID. + subquery := s.db.Model(&tables.TableOauthUserSession{}).Select("session_token_hash").Where("gateway_session_id = ?", gatewaySessionID) + result := s.db.WithContext(ctx).Model(&tables.TableOauthUserToken{}). + Where("session_token_hash IN (?)", subquery). + Updates(updates) + if result.Error != nil { + return fmt.Errorf("failed to transfer oauth user tokens from gateway session: %w", result.Error) + } + s.logger.Debug("[rdb] TransferOauthUserTokensFromGatewaySession done: rows_affected=%d", result.RowsAffected) + return nil +} diff --git a/framework/configstore/store.go b/framework/configstore/store.go index 0f97fe47de..bb4f54966b 100644 --- a/framework/configstore/store.go +++ b/framework/configstore/store.go @@ -334,6 +334,26 @@ type ConfigStore interface { CreatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error UpdatePerUserOAuthCode(ctx context.Context, code *tables.TablePerUserOAuthCode) error + // Per-user OAuth consent flow (pending flows before code issuance) + GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) + CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error + UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error + DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error + // ConsumePerUserOAuthPendingFlow atomically deletes a pending flow and returns the number of + // rows affected. Returns 0 if the flow was already consumed by a concurrent request. + ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) + // FinalizePerUserOAuthConsent atomically consumes a pending flow, creates the session, + // and creates the authorization code in a single transaction. Returns (0, nil) if the + // flow was already consumed by a concurrent request. + FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) + // GetOauthUserTokensByGatewaySessionID returns all upstream tokens linked to a gateway session ID. + // Used during consent submit to discover which MCPs the user authenticated with. + // Queries tokens via upstream sessions matching the given gateway session ID. + GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) + // TransferOauthUserTokensFromGatewaySession migrates upstream tokens from all flow proxy sessions + // (identified by gateway_session_id) to the real Bifrost session token, and sets VirtualKeyID/UserID on each record. + TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error + // Not found retry wrapper RetryOnNotFound(ctx context.Context, fn func(ctx context.Context) (any, error), maxRetries int, retryDelay time.Duration) (any, error) diff --git a/framework/configstore/tables/oauth.go b/framework/configstore/tables/oauth.go index 2b143a3fb5..9cb65bc4af 100644 --- a/framework/configstore/tables/oauth.go +++ b/framework/configstore/tables/oauth.go @@ -11,26 +11,26 @@ import ( // TableOauthConfig represents an OAuth configuration in the database // This stores the OAuth client configuration and flow state type TableOauthConfig struct { - ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID - ClientID string `gorm:"type:varchar(512)" json:"client_id"` // OAuth provider's client ID (optional for public clients) - ClientSecret string `gorm:"type:text" json:"-"` // Encrypted OAuth client secret (optional for public clients) - AuthorizeURL string `gorm:"type:text" json:"authorize_url"` // Provider's authorization endpoint (optional, can be discovered) - TokenURL string `gorm:"type:text" json:"token_url"` // Provider's token endpoint (optional, can be discovered) - RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"` // Provider's dynamic registration endpoint (optional, can be discovered) - RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Callback URL - Scopes string `gorm:"type:text" json:"scopes"` // JSON array of scopes (optional, can be discovered) - State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token - CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (generated, kept secret) - CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider) - Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked" - TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback) - ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery - UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery - MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion) - EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` - UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` - ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min) + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID + ClientID string `gorm:"type:varchar(512)" json:"client_id"` // OAuth provider's client ID (optional for public clients) + ClientSecret string `gorm:"type:text" json:"-"` // Encrypted OAuth client secret (optional for public clients) + AuthorizeURL string `gorm:"type:text" json:"authorize_url"` // Provider's authorization endpoint (optional, can be discovered) + TokenURL string `gorm:"type:text" json:"token_url"` // Provider's token endpoint (optional, can be discovered) + RegistrationURL *string `gorm:"type:text" json:"registration_url,omitempty"` // Provider's dynamic registration endpoint (optional, can be discovered) + RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Callback URL + Scopes string `gorm:"type:text" json:"scopes"` // JSON array of scopes (optional, can be discovered) + State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token + CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (generated, kept secret) + CodeChallenge string `gorm:"type:varchar(255)" json:"code_challenge"` // PKCE code challenge (sent to provider) + Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired", "revoked" + TokenID *string `gorm:"type:varchar(255);index" json:"token_id"` // Foreign key to oauth_tokens.ID (set after callback) + ServerURL string `gorm:"type:text" json:"server_url"` // MCP server URL for OAuth discovery + UseDiscovery bool `gorm:"default:false" json:"use_discovery"` // Flag to enable OAuth discovery + MCPClientConfigJSON *string `gorm:"type:text" json:"-"` // JSON serialized MCPClientConfig for multi-instance support (pending MCP client waiting for OAuth completion) + EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // State expiry (15 min) } // TableName sets the table name @@ -83,13 +83,13 @@ func (c *TableOauthConfig) AfterFind(tx *gorm.DB) error { // TableOauthToken represents an OAuth token in the database // This stores the actual access and refresh tokens type TableOauthToken struct { - ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID - AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted access token - RefreshToken string `gorm:"type:text" json:"-"` // Encrypted refresh token (optional) - TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer" - ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiration - Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes - LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Track when token was last refreshed + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // UUID + AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted access token + RefreshToken string `gorm:"type:text" json:"-"` // Encrypted refresh token (optional) + TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer" + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiration + Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes + LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Track when token was last refreshed EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` @@ -139,20 +139,20 @@ func (t *TableOauthToken) AfterFind(tx *gorm.DB) error { // Each record maps an OAuth state token to a specific MCP client, allowing // the callback to associate the resulting tokens with the correct user session. type TableOauthUserSession struct { - ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Session UUID - MCPClientID string `gorm:"type:varchar(255);not null;index" json:"mcp_client_id"` // Which MCP server this auth is for - OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config (holds client_id, token_url, etc.) - State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token sent to OAuth provider - RedirectURI string `gorm:"type:text" json:"-"` // Per-request redirect URI used in authorize step - CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (kept secret) - SessionToken string `gorm:"type:varchar(255)" json:"-"` // Bifrost session ID (links to oauth_per_user_sessions) - SessionTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash of SessionToken for secure lookups - GatewaySessionID string `gorm:"type:varchar(255);index" json:"-"` // Bifrost MCP gateway session ID (separate from SessionToken) - VirtualKeyID string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // VK identity (propagated to oauth_user_tokens) - UserID string `gorm:"type:varchar(255);index" json:"user_id"` // Enterprise user identity (propagated to oauth_user_tokens) - Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired" + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Session UUID + MCPClientID string `gorm:"type:varchar(255);not null;index" json:"mcp_client_id"` // Which MCP server this auth is for + OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config (holds client_id, token_url, etc.) + State string `gorm:"type:varchar(255);uniqueIndex;not null" json:"-"` // CSRF state token sent to OAuth provider + RedirectURI string `gorm:"type:text" json:"-"` // Per-request redirect URI used in authorize step + CodeVerifier string `gorm:"type:text" json:"-"` // PKCE code verifier (kept secret) + SessionToken string `gorm:"type:varchar(255)" json:"-"` // Bifrost session ID (links to oauth_per_user_sessions) + SessionTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash of SessionToken for secure lookups + GatewaySessionID string `gorm:"type:varchar(255);index" json:"-"` // Bifrost MCP gateway session ID (separate from SessionToken) + VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // VK identity (propagated to oauth_user_tokens) + UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Enterprise user identity (propagated to oauth_user_tokens) + Status string `gorm:"type:varchar(50);not null;index" json:"status"` // "pending", "authorized", "failed", "expired" EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` - ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Flow expiration (15 min) + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Flow expiration (15 min) CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` } @@ -190,21 +190,21 @@ func (s *TableOauthUserSession) AfterFind(tx *gorm.DB) error { // TableOauthUserToken stores per-user OAuth credentials. // Each record holds the access/refresh tokens for a specific user session + MCP client pair. -// Lookup is by SessionToken (from the X-Bifrost-MCP-Session header). +// Lookup is by SessionToken. type TableOauthUserToken struct { - ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Token UUID - SessionToken string `gorm:"type:varchar(255)" json:"-"` // Maps to Bifrost session (fallback for anonymous users) - SessionTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash of SessionToken for secure lookups - VirtualKeyID string `gorm:"type:varchar(255);index:idx_vk_mcp" json:"virtual_key_id"` // VK identity (persistent across sessions) - UserID string `gorm:"type:varchar(255);index:idx_user_mcp" json:"user_id"` // Enterprise user identity (persistent across sessions) + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` // Token UUID + SessionToken string `gorm:"type:varchar(255)" json:"-"` // Maps to Bifrost session (fallback for anonymous users) + SessionTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash of SessionToken for secure lookups + VirtualKeyID *string `gorm:"type:varchar(255);index:idx_vk_mcp" json:"virtual_key_id"` // VK identity (persistent across sessions) + UserID *string `gorm:"type:varchar(255);index:idx_user_mcp" json:"user_id"` // Enterprise user identity (persistent across sessions) MCPClientID string `gorm:"type:varchar(255);not null;index:idx_vk_mcp;index:idx_user_mcp" json:"mcp_client_id"` // Which MCP server - OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config - AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted user's OAuth access token - RefreshToken string `gorm:"type:text" json:"-"` // Encrypted user's OAuth refresh token - TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer" - ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiry - Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes - LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Last refresh time + OauthConfigID string `gorm:"type:varchar(255);not null;index" json:"oauth_config_id"` // Template OAuth config + AccessToken string `gorm:"type:text;not null" json:"-"` // Encrypted user's OAuth access token + RefreshToken string `gorm:"type:text" json:"-"` // Encrypted user's OAuth refresh token + TokenType string `gorm:"type:varchar(50);not null" json:"token_type"` // "Bearer" + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // Token expiry + Scopes string `gorm:"type:text" json:"scopes"` // JSON array of granted scopes + LastRefreshedAt *time.Time `gorm:"index" json:"last_refreshed_at,omitempty"` // Last refresh time EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` CreatedAt time.Time `gorm:"index;not null" json:"created_at"` UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` @@ -270,18 +270,19 @@ func (TablePerUserOAuthClient) TableName() string { // is created. The access token is included in all subsequent MCP requests. // Upstream provider tokens are linked via the oauth_user_tokens table. type TablePerUserOAuthSession struct { - ID string `gorm:"type:varchar(255);primaryKey" json:"id"` - AccessToken string `gorm:"type:text;not null" json:"-"` // Bifrost-issued access token (encrypted) - AccessTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups - RefreshToken string `gorm:"type:text" json:"-"` // Bifrost-issued refresh token (encrypted, optional) - RefreshTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash for secure lookups (not unique — refresh tokens are optional) - ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Which OAuth client registered this session - VirtualKeyID string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Linked VK identity (set when VK is present during auth) - UserID string `gorm:"type:varchar(255);index" json:"user_id"` // Linked enterprise user identity (set when user ID is present) - ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` - EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` - CreatedAt time.Time `gorm:"index;not null" json:"created_at"` - UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` + AccessToken string `gorm:"type:text;not null" json:"-"` // Bifrost-issued access token (encrypted) + AccessTokenHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups + RefreshToken string `gorm:"type:text" json:"-"` // Bifrost-issued refresh token (encrypted, optional) + RefreshTokenHash string `gorm:"type:varchar(64);index" json:"-"` // SHA-256 hash for secure lookups (not unique — refresh tokens are optional) + ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Which OAuth client registered this session + VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Linked VK identity (set when VK is present during auth) + VirtualKey *TableVirtualKey `gorm:"foreignKey:VirtualKeyID" json:"-"` // Linked VK identity (server-only, not serialized) + UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Linked enterprise user identity (set when user ID is present) + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` + EncryptionStatus string `gorm:"type:varchar(20);default:'plain_text'" json:"-"` + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` } // TableName returns the table name for per-user OAuth sessions. @@ -330,12 +331,13 @@ func (s *TablePerUserOAuthSession) AfterFind(tx *gorm.DB) error { // Codes are short-lived (5 minutes) and single-use. type TablePerUserOAuthCode struct { ID string `gorm:"type:varchar(255);primaryKey" json:"id"` - Code string `gorm:"type:text;not null" json:"-"` // Authorization code - CodeHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups + Code string `gorm:"type:text;not null" json:"-"` // Authorization code + CodeHash string `gorm:"type:varchar(64);uniqueIndex" json:"-"` // SHA-256 hash for secure lookups ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` - CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge + CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge Scopes string `gorm:"type:text" json:"scopes"` // JSON array of requested scopes + SessionID string `gorm:"type:varchar(255);index" json:"-"` // Links to the TablePerUserOAuthSession created during consent submit ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 5 min TTL Used bool `gorm:"default:false;not null" json:"used"` // Single-use flag CreatedAt time.Time `gorm:"index;not null" json:"created_at"` @@ -353,3 +355,25 @@ func (c *TablePerUserOAuthCode) BeforeSave(tx *gorm.DB) error { func (TablePerUserOAuthCode) TableName() string { return "oauth_per_user_codes" } + +// TablePerUserOAuthPendingFlow stores OAuth parameters between the authorize step +// and the final code issuance. It carries state through the multi-step consent +// screen (VK entry + per-MCP upstream auth) before a real authorization code is issued. +type TablePerUserOAuthPendingFlow struct { + ID string `gorm:"type:varchar(255);primaryKey" json:"id"` + ClientID string `gorm:"type:varchar(255);not null;index" json:"client_id"` // Registered OAuth client (from authorize request) + RedirectURI string `gorm:"type:text;not null" json:"redirect_uri"` // Client's callback URL + CodeChallenge string `gorm:"type:varchar(255);not null" json:"-"` // PKCE S256 challenge (echoed into the final code) + State string `gorm:"type:text;not null" json:"-"` // Original OAuth state (echoed back on final redirect) + VirtualKeyID *string `gorm:"type:varchar(255);index" json:"virtual_key_id"` // Set if user chose VK identity + UserID *string `gorm:"type:varchar(255);index" json:"user_id"` // Set if user chose User ID identity + BrowserSecretHash string `gorm:"type:varchar(255)" json:"-"` // SHA-256 hash of browser-binding cookie secret + ExpiresAt time.Time `gorm:"index;not null" json:"expires_at"` // 15-min TTL + CreatedAt time.Time `gorm:"index;not null" json:"created_at"` + UpdatedAt time.Time `gorm:"index;not null" json:"updated_at"` +} + +// TableName returns the table name for per-user OAuth pending flows. +func (TablePerUserOAuthPendingFlow) TableName() string { + return "oauth_per_user_pending_flows" +} diff --git a/framework/oauth2/main.go b/framework/oauth2/main.go index e45de911d1..44667932aa 100644 --- a/framework/oauth2/main.go +++ b/framework/oauth2/main.go @@ -722,12 +722,6 @@ func (p *OAuth2Provider) InitiateUserOAuthFlow(ctx context.Context, oauthConfigI json.Unmarshal([]byte(templateConfig.Scopes), &scopes) } - // Generate session token upfront to avoid empty unique-index collisions on pending sessions - sessionToken, err := generateSessionToken() - if err != nil { - return nil, "", fmt.Errorf("failed to generate session token: %w", err) - } - // Create per-user OAuth session sessionID := uuid.New().String() expiresAt := time.Now().Add(15 * time.Minute) @@ -740,6 +734,24 @@ func (p *OAuth2Provider) InitiateUserOAuthFlow(ctx context.Context, oauthConfigI userID = mcpUserID } + // If a Bifrost MCP session token is present in context, reuse it as the session token + // so the MCP server token is stored under the same key used for subsequent lookups. + // Otherwise generate a fresh token. + sessionToken, _ := ctx.Value(schemas.BifrostContextKeyMCPUserSession).(string) + if sessionToken == "" { + sessionToken, err = generateSessionToken() + if err != nil { + return nil, "", fmt.Errorf("failed to generate session token: %w", err) + } + } + var vkId *string + if virtualKeyID != "" { + vkId = &virtualKeyID + } + var uid *string + if userID != "" { + uid = &userID + } session := &tables.TableOauthUserSession{ ID: sessionID, MCPClientID: mcpClientID, @@ -748,8 +760,8 @@ func (p *OAuth2Provider) InitiateUserOAuthFlow(ctx context.Context, oauthConfigI RedirectURI: redirectURI, CodeVerifier: codeVerifier, SessionToken: sessionToken, - VirtualKeyID: virtualKeyID, - UserID: userID, + VirtualKeyID: vkId, + UserID: uid, Status: "pending", ExpiresAt: expiresAt, } @@ -805,7 +817,6 @@ func (p *OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string p.configStore.UpdateOauthUserSession(ctx, session) return "", fmt.Errorf("failed to load template oauth config: %w", err) } - // Exchange code for tokens with PKCE verifier // Use the redirect URI stored in the session (same one used in authorize step) // to satisfy OAuth spec requirement that redirect_uri must match @@ -828,7 +839,7 @@ func (p *OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string } // Use existing session token if set (e.g., Bifrost session ID from MCP spec OAuth flow), - // otherwise generate a new one (for standalone per-user OAuth via X-Bifrost-MCP-Session header). + // otherwise generate a new one (for standalone per-user OAuth). sessionToken := session.SessionToken if sessionToken == "" { sessionToken, err = generateSessionToken() @@ -860,7 +871,6 @@ func (p *OAuth2Provider) CompleteUserOAuthFlow(ctx context.Context, state string ExpiresAt: time.Now().Add(time.Duration(tokenResponse.ExpiresIn) * time.Second), Scopes: string(scopesJSON), } - if err := p.configStore.CreateOauthUserToken(ctx, tokenRecord); err != nil { return "", fmt.Errorf("failed to create per-user oauth token: %w", err) } @@ -917,7 +927,7 @@ func (p *OAuth2Provider) GetUserAccessTokenByIdentity(ctx context.Context, virtu return "", fmt.Errorf("failed to load per-user oauth token by identity: %w", err) } if token == nil { - return "", fmt.Errorf("per-user oauth token not found for this user and MCP server") + return "", schemas.ErrOAuth2TokenNotFound } // Check if token is expired — attempt refresh diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index a4f2951502..ed8f205dd0 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -556,6 +556,10 @@ func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.B go func() { requestType, _, originalModel, resolvedModel := bifrost.GetResponseFields(result, bifrostErr) + modelTag := resolvedModel + if modelTag == "" { + modelTag = originalModel + } var streamResponse *streaming.ProcessedStreamResponse if bifrost.IsStreamRequestType(requestType) { @@ -660,11 +664,11 @@ func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.B } } } - if hasGenerationID && generationID != "" { - logger.AddTagToGeneration(generationID, "model", string(model)) + if hasGenerationID && generationID != "" && modelTag != "" { + logger.AddTagToGeneration(generationID, "model", string(modelTag)) } - if hasTraceID && traceID != "" { - logger.AddTagToTrace(traceID, "model", string(model)) + if hasTraceID && traceID != "" && modelTag != "" { + logger.AddTagToTrace(traceID, "model", string(modelTag)) } // Flush only the effective logger that was used for this request logger.Flush() diff --git a/plugins/prompts/go.mod b/plugins/prompts/go.mod index 5c5acb4864..d293e006f8 100644 --- a/plugins/prompts/go.mod +++ b/plugins/prompts/go.mod @@ -71,7 +71,7 @@ require ( golang.org/x/arch v0.23.0 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/net v0.52.0 // indirect - golang.org/x/oauth2 v0.35.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/plugins/prompts/go.sum b/plugins/prompts/go.sum index b51fcaafaa..6b203896be 100644 --- a/plugins/prompts/go.sum +++ b/plugins/prompts/go.sum @@ -166,8 +166,7 @@ golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= -golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= -golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/tests/integrations/python/pyproject.toml b/tests/integrations/python/pyproject.toml index 70e3167a4b..8d49b81278 100644 --- a/tests/integrations/python/pyproject.toml +++ b/tests/integrations/python/pyproject.toml @@ -24,12 +24,12 @@ dependencies = [ # AI/ML SDK dependencies "openai>=1.30.0", "anthropic>=0.25.0", - "litellm>=1.80.5", - "langchain-openai>=0.1.0", - "langchain-core>=0.3.0", - "langchain-anthropic>=0.1.0", + "litellm==1.80.5", + "langchain-openai==0.1.0", + "langchain-core==0.3.81", + "langchain-anthropic==0.1.0", "langchain-google-genai==4.1.1", - "langchain-mistralai>=0.1.0", + "langchain-mistralai==0.1.0", "langgraph>=0.1.0", "mistralai>=0.4.0", "google-genai>=1.50.0", @@ -123,4 +123,4 @@ exclude_lines = [ [tool.uv] -exclude-newer = "7 days" \ No newline at end of file +exclude-newer = "2026-04-08" \ No newline at end of file diff --git a/transports/bifrost-http/handlers/mcp.go b/transports/bifrost-http/handlers/mcp.go index b3e875d964..31999287fd 100644 --- a/transports/bifrost-http/handlers/mcp.go +++ b/transports/bifrost-http/handlers/mcp.go @@ -435,18 +435,18 @@ func (h *MCPHandler) addMCPClient(ctx *fasthttp.RequestCtx) { } pendingConfig := schemas.MCPClientConfig{ - ID: req.ClientID, - Name: req.Name, - IsCodeModeClient: req.IsCodeModeClient, - IsPingAvailable: &isPingAvailable, - ToolSyncInterval: toolSyncInterval, - ConnectionType: schemas.MCPConnectionType(req.ConnectionType), - ConnectionString: req.ConnectionString, - StdioConfig: req.StdioConfig, - AuthType: schemas.MCPAuthTypePerUserOauth, - OauthConfigID: &flowInitiation.OauthConfigID, - ToolsToExecute: req.ToolsToExecute, - ToolsToAutoExecute: req.ToolsToAutoExecute, + ID: req.ClientID, + Name: req.Name, + IsCodeModeClient: req.IsCodeModeClient, + IsPingAvailable: &isPingAvailable, + ToolSyncInterval: toolSyncInterval, + ConnectionType: schemas.MCPConnectionType(req.ConnectionType), + ConnectionString: req.ConnectionString, + StdioConfig: req.StdioConfig, + AuthType: schemas.MCPAuthTypePerUserOauth, + OauthConfigID: &flowInitiation.OauthConfigID, + ToolsToExecute: req.ToolsToExecute, + ToolsToAutoExecute: req.ToolsToAutoExecute, ToolPricing: req.ToolPricing, Headers: req.Headers, AllowedExtraHeaders: req.AllowedExtraHeaders, diff --git a/transports/bifrost-http/handlers/mcpserver.go b/transports/bifrost-http/handlers/mcpserver.go index 00d6d5dfb9..f3214e801c 100644 --- a/transports/bifrost-http/handlers/mcpserver.go +++ b/transports/bifrost-http/handlers/mcpserver.go @@ -31,12 +31,11 @@ type MCPToolManager interface { // MCPServerHandler manages HTTP requests for MCP server operations // It implements the MCP protocol over HTTP streaming (SSE) for MCP clients type MCPServerHandler struct { - toolManager MCPToolManager - globalMCPServer *server.MCPServer - vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server - config *lib.Config - hasPerUserOAuthServers bool // Whether any per_user_oauth MCP servers are configured - mu sync.RWMutex + toolManager MCPToolManager + globalMCPServer *server.MCPServer + vkMCPServers map[string]*server.MCPServer // Map of vk value -> mcp server + config *lib.Config + mu sync.RWMutex } // NewMCPServerHandler creates a new MCP server handler instance @@ -83,8 +82,33 @@ func (h *MCPServerHandler) RegisterRoutes(r *router.Router, middlewares ...schem } // handleMCPServer handles POST requests for MCP JSON-RPC 2.0 messages +// injectMCPSessionIdentity sets the MCP gateway flag and, if a per-user OAuth +// session exists, injects the session token and identity (VK / User ID) directly +// into the BifrostContext. This avoids header-based identity propagation which +// would be vulnerable to spoofing by upstream callers. +// +// Governance context keys are set here intentionally (bypassing governance plugin) +// because in the MCP gateway path, identity is pre-authenticated via the OAuth session. +func injectMCPSessionIdentity(bifrostCtx *schemas.BifrostContext, session *tables.TablePerUserOAuthSession) { + bifrostCtx.SetValue(schemas.BifrostContextKeyIsMCPGateway, true) + if session != nil { + if session.AccessToken != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, session.AccessToken) + } + if session.VirtualKeyID != nil && *session.VirtualKeyID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, *session.VirtualKeyID) + if session.VirtualKey != nil && session.VirtualKey.Name != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyName, session.VirtualKey.Name) + } + } + if session.UserID != nil && *session.UserID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceUserID, *session.UserID) + } + } +} + func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { - mcpServer, err := h.getMCPServerForRequest(ctx) + mcpServer, session, err := h.getMCPServerForRequest(ctx) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) return @@ -94,6 +118,8 @@ func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) defer cancel() + injectMCPSessionIdentity(bifrostCtx, session) + // Use mcp-go server to handle the request // HandleMessage processes JSON-RPC messages and returns appropriate responses response := mcpServer.HandleMessage(bifrostCtx, ctx.PostBody()) @@ -118,7 +144,7 @@ func (h *MCPServerHandler) handleMCPServer(ctx *fasthttp.RequestCtx) { // handleMCPServerSSE handles GET requests for MCP Server-Sent Events streaming func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { - _, err := h.getMCPServerForRequest(ctx) + _, session, err := h.getMCPServerForRequest(ctx) if err != nil { SendError(ctx, fasthttp.StatusUnauthorized, err.Error()) return @@ -132,6 +158,8 @@ func (h *MCPServerHandler) handleMCPServerSSE(ctx *fasthttp.RequestCtx) { // Convert context bifrostCtx, cancel := lib.ConvertToBifrostContext(ctx, false, h.config.GetHeaderMatcher(), h.config.GetMCPHeaderCombinedAllowlist()) + injectMCPSessionIdentity(bifrostCtx, session) + // Use SSEStreamReader to bypass fasthttp's internal pipe batching reader := lib.NewSSEStreamReader() ctx.Response.SetBodyStream(reader, -1) @@ -260,6 +288,12 @@ func (h *MCPServerHandler) syncServer(server *server.MCPServer, availableTools [ // Execute the tool via tool executor toolMessage, err := h.toolManager.ExecuteChatMCPTool(ctx, &toolCall) if err != nil { + if err.ExtraFields.MCPAuthRequired != nil { + return mcp.NewToolResultError(fmt.Sprintf( + "Authentication required for %s. Open this URL to connect your account: %s", + err.ExtraFields.MCPAuthRequired.MCPClientName, err.ExtraFields.MCPAuthRequired.AuthorizeURL, + )), nil + } return mcp.NewToolResultError(fmt.Sprintf("Tool execution failed: %v", bifrost.GetErrorMessage(err))), nil } @@ -401,16 +435,7 @@ func (h *MCPServerHandler) makeIncludeClientsFilter() server.ToolFilterFunc { // Utility methods -// SetHasPerUserOAuthServers updates whether any per_user_oauth MCP servers are -// configured. When true, the /mcp endpoint returns 401 with WWW-Authenticate -// for unauthenticated requests to trigger the MCP spec OAuth flow. -func (h *MCPServerHandler) SetHasPerUserOAuthServers(value bool) { - h.mu.Lock() - defer h.mu.Unlock() - h.hasPerUserOAuthServers = value -} - -func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, error) { +func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*server.MCPServer, *tables.TablePerUserOAuthSession, error) { h.mu.RLock() defer h.mu.RUnlock() @@ -421,88 +446,91 @@ func (h *MCPServerHandler) getMCPServerForRequest(ctx *fasthttp.RequestCtx) (*se vk := getVKFromRequest(ctx) // Check for Bifrost per-user OAuth Bearer token (not a VK) - perUserSession := h.getPerUserOAuthSession(ctx) - - // If per_user_oauth servers are configured and no valid auth, return 401 with discovery - if h.hasPerUserOAuthServers && vk == "" && perUserSession == nil { - if !enforceVK { - // Even without enforced VK, per_user_oauth servers require auth - scheme := "http" - if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { - scheme = "https" - } - host := string(ctx.Host()) - resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host) - ctx.Response.Header.Set("WWW-Authenticate", - fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL)) - return nil, fmt.Errorf("OAuth authentication required for MCP access") + userOauthSession, sessionErr := h.getPerUserOAuthSession(ctx) + if sessionErr != nil { + return nil, nil, fmt.Errorf("failed to look up OAuth session: %w", sessionErr) + } + + // If per_user_oauth MCP clients are configured and no valid auth, return 401 with discovery + if clients := h.config.GetPerUserOAuthMCPClients(); len(clients) > 0 && userOauthSession == nil && vk == "" { + scheme := "http" + if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { + scheme = "https" } + host := string(ctx.Host()) + resourceMetadataURL := fmt.Sprintf("%s://%s/.well-known/oauth-protected-resource", scheme, host) + ctx.Response.Header.Set("WWW-Authenticate", + fmt.Sprintf(`Bearer resource_metadata="%s"`, resourceMetadataURL)) + return nil, nil, fmt.Errorf("oauth authentication required for mcp access") } - // If a per-user OAuth session is present, inject it into context and use global server. - // Also attach user identity to the session if not already set (first request after auth). - if perUserSession != nil { - // Propagate the access token so ConvertToBifrostContext picks it up - // and downstream tool execution can find the per-user OAuth session. - if perUserSession.AccessToken != "" { - ctx.Request.Header.Set("X-Bifrost-MCP-Session", perUserSession.AccessToken) + if userOauthSession != nil { + if !enforceVK && (userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "") { + return h.globalMCPServer, userOauthSession, nil } - ctx.SetUserValue(string(schemas.BifrostContextKeyMCPUserSession), perUserSession.AccessToken) - // Identity (VirtualKeyID, UserID) is resolved by governance middleware downstream. - // Do not store the raw VK secret here — governance resolves the actual VK ID. + if userOauthSession.VirtualKeyID == nil || *userOauthSession.VirtualKeyID == "" || userOauthSession.VirtualKey == nil { + return nil, nil, fmt.Errorf("virtual key required in oauth session to access mcp server, please re-authenticate with a virtual key") + } - if vk == "" { - return h.globalMCPServer, nil + vkServer, ok := h.vkMCPServers[userOauthSession.VirtualKey.Value] + if !ok { + return nil, nil, fmt.Errorf("virtual key not found") } + + return vkServer, userOauthSession, nil } // Return global MCP server if not enforcing virtual key header and no virtual key is provided if !enforceVK && vk == "" { - return h.globalMCPServer, nil + return h.globalMCPServer, nil, nil } - // Check if virtual key is provided if vk == "" { - return nil, fmt.Errorf("virtual key header is required to access MCP server.") + return nil, nil, fmt.Errorf("virtual key header required to access mcp server") } - // Check if vk exists in the map vkServer, ok := h.vkMCPServers[vk] if !ok { - return nil, fmt.Errorf("virtual key not found.") + return nil, nil, fmt.Errorf("virtual key not found") } - return vkServer, nil + return vkServer, nil, nil } // getPerUserOAuthSession extracts and validates a Bifrost-issued per-user OAuth // token from the Authorization header. Returns the session if valid, nil otherwise. -func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) *tables.TablePerUserOAuthSession { +func (h *MCPServerHandler) getPerUserOAuthSession(ctx *fasthttp.RequestCtx) (*tables.TablePerUserOAuthSession, error) { authHeader := strings.TrimSpace(string(ctx.Request.Header.Peek("Authorization"))) if authHeader == "" || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { - return nil + return nil, nil } token := strings.TrimSpace(authHeader[7:]) if token == "" || strings.HasPrefix(strings.ToLower(token), governance.VirtualKeyPrefix) { - return nil // It's a virtual key, not a per-user OAuth token + return nil, nil // It's a virtual key, not a per-user OAuth token } if h.config.ConfigStore == nil { - return nil + return nil, nil } session, err := h.config.ConfigStore.GetPerUserOAuthSessionByAccessToken(ctx, token) - if err != nil || session == nil { - return nil + if err != nil { + logger.Warn("[mcp/auth] GetPerUserOAuthSessionByAccessToken error: %v", err) + return nil, err + } + if session == nil { + logger.Debug("[mcp/auth] Session not found for token") + return nil, nil } // Check expiry if session.ExpiresAt.Before(time.Now()) { - return nil + logger.Debug("[mcp/auth] Session expired: session_id=%s expires_at=%v", session.ID, session.ExpiresAt) + return nil, nil } - return session + return session, nil } func getVKFromRequest(ctx *fasthttp.RequestCtx) string { diff --git a/transports/bifrost-http/handlers/oauth2.go b/transports/bifrost-http/handlers/oauth2.go index e094b76773..5240b32b91 100644 --- a/transports/bifrost-http/handlers/oauth2.go +++ b/transports/bifrost-http/handlers/oauth2.go @@ -8,6 +8,8 @@ import ( "errors" "fmt" "html" + "net/url" + "strings" "github.com/fasthttp/router" bifrost "github.com/maximhq/bifrost/core" @@ -69,31 +71,25 @@ func (h *OAuthHandler) handleOAuthCallback(ctx *fasthttp.RequestCtx) { return } if perUserErr == nil && sessionToken != "" { - // Per-user runtime OAuth flow completed — show session token + // Consent flow: session token is a flow proxy ("flow::"). + // Redirect back to the MCPs consent page so the user can continue. + if strings.HasPrefix(sessionToken, "flow:") { + rest := strings.TrimPrefix(sessionToken, "flow:") + flowID := strings.SplitN(rest, ":", 2)[0] + mcpsURL := fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)) + ctx.Redirect(mcpsURL, fasthttp.StatusFound) + return + } + + // Per-user runtime OAuth flow completed — show success page. ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetContentType("text/html") - ctx.SetBodyString(` - - - - Authorization Successful - - - -
-
-

Authorization Successful

-

You can close this tab.

-
-
- - - `) + ctx.SetBodyString(oauthSuccessPage(` + if (window.opener) { + window.opener.postMessage({ type: 'oauth_success' }, window.location.origin); + window.close(); + } + `, "Authorization Successful", "You can close this tab.")) return } @@ -107,31 +103,12 @@ func (h *OAuthHandler) handleOAuthCallback(ctx *fasthttp.RequestCtx) { // Redirect to success page (or close popup) ctx.SetStatusCode(fasthttp.StatusOK) ctx.SetContentType("text/html") - ctx.SetBodyString(` - - - - OAuth Success - - - -
-
-

Authorization Successful

-

This window will close automatically...

-
-
- - - `) + ctx.SetBodyString(oauthSuccessPage(` + if (window.opener) { + window.opener.postMessage({ type: 'oauth_success' }, window.location.origin); + window.close(); + } + `, "Authorization Successful", "OAuth authorization successful! You can close this window.")) } // handleCallbackError handles OAuth callback errors @@ -156,30 +133,7 @@ func (h *OAuthHandler) handleCallbackError(ctx *fasthttp.RequestCtx, state, erro jsEscaped, _ := json.Marshal(errorMsg) // HTML-escape for safe embedding in HTML body (prevents HTML injection) htmlEscaped := html.EscapeString(errorMsg) - ctx.SetBodyString(fmt.Sprintf(` - - - - OAuth Failed - - - -
-
-

✗ Authorization Failed

-

%s

-

You can close this window.

-
-
- - - `, jsEscaped, htmlEscaped)) + ctx.SetBodyString(oauthErrorPage(string(jsEscaped), htmlEscaped)) } // getOAuthConfigStatus returns the current status of an OAuth config @@ -295,6 +249,64 @@ func (h *OAuthHandler) GetAccessToken(ctx context.Context, oauthConfigID string) return h.oauthProvider.GetAccessToken(ctx, oauthConfigID) } +// oauthSuccessPage renders a Bifrost-themed success HTML page. +// extraScript is injected verbatim into a + + +
+
+

%s

+

%s

+
+ +`, html.EscapeString(title), bifrostPageCSS, extraScript, html.EscapeString(title), html.EscapeString(message)) +} + +// oauthErrorPage renders a Bifrost-themed error HTML page. +// jsEscapedError must already be JSON-encoded (with quotes) for safe JS embedding. +// htmlError must already be HTML-escaped for safe body embedding. +func oauthErrorPage(jsEscapedError, htmlError string) string { + return fmt.Sprintf(` + + + + +Authorization Failed + + + + +
+
+

Authorization Failed

+

%s

+

You can close this window.

+
+ +`, bifrostPageCSS, jsEscapedError, htmlError) +} + // jsEscapeString returns a JSON-encoded string (with quotes) safe for embedding in JavaScript. func jsEscapeString(s string) string { b, _ := json.Marshal(s) diff --git a/transports/bifrost-http/handlers/oauth2_consent.go b/transports/bifrost-http/handlers/oauth2_consent.go new file mode 100644 index 0000000000..58b6d6e632 --- /dev/null +++ b/transports/bifrost-http/handlers/oauth2_consent.go @@ -0,0 +1,641 @@ +// Package handlers provides HTTP request handlers for the Bifrost HTTP transport. +// This file implements the per-user OAuth consent flow — the intermediate screens +// shown between the MCP client's authorize request and the final authorization code +// issuance. The flow is: +// +// 1. GET /oauth/consent?flow_id=xxx → VK input page (HTML) +// 2. POST /api/oauth/per-user/consent/vk → validate VK, update PendingFlow, redirect +// 3. GET /oauth/consent/mcps?flow_id=xxx → MCPs page (HTML, server-rendered) +// 4. POST /api/oauth/per-user/consent/submit → create session + code, redirect to client +package handlers + +import ( + "errors" + "fmt" + "html" + "net/url" + "sort" + "strings" + "time" + + "github.com/fasthttp/router" + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// ConsentHandler manages the per-user OAuth consent flow screens. +type ConsentHandler struct { + store *lib.Config +} + +// NewConsentHandler creates a new consent handler instance. +func NewConsentHandler(store *lib.Config) *ConsentHandler { + return &ConsentHandler{store: store} +} + +// RegisterRoutes registers the consent flow routes. +// All routes are public — no auth middleware — since they are part of the OAuth +// flow for unauthenticated users acquiring credentials. +func (h *ConsentHandler) RegisterRoutes(r *router.Router) { + // HTML pages (GET, served by Go) + r.GET("/oauth/consent", h.handleIdentityPage) + r.GET("/oauth/consent/mcps", h.handleMCPsPage) + + // API actions (POST) + // NOTE: All state-mutating endpoints use POST. CSRF protection relies on the + // SameSite=Lax browser-binding cookie (__bifrost_flow_secret) combined with + // the flow_id — SameSite=Lax blocks cross-site POST, and the cookie is + // HttpOnly+Secure. This is sufficient for the threat model here. + r.POST("/api/oauth/per-user/consent/vk", h.handleSubmitVK) + r.POST("/api/oauth/per-user/consent/user-id", h.handleSubmitUserID) + r.POST("/api/oauth/per-user/consent/skip", h.handleSkip) + r.POST("/api/oauth/per-user/consent/submit", h.handleSubmit) +} + +// ---------- HTML pages ---------- + +// handleIdentityPage renders the identity selection page with three options: +// User ID, Virtual Key, or skip (lazy auth when tools are called). +// GET /oauth/consent?flow_id=xxx[&error=xxx] +func (h *ConsentHandler) handleIdentityPage(ctx *fasthttp.RequestCtx) { + flowID := string(ctx.QueryArgs().Peek("flow_id")) + errorMsg := string(ctx.QueryArgs().Peek("error")) + + if flowID == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Missing flow_id") + return + } + + if h.store.ConfigStore == nil { + ctx.SetStatusCode(fasthttp.StatusServiceUnavailable) + ctx.SetBodyString("Config store unavailable") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString("Failed to load consent flow.") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.") + return + } + + h.store.Mu.RLock() + enforceVK := h.store.ClientConfig.EnforceAuthOnInference + h.store.Mu.RUnlock() + + safeFlowID := html.EscapeString(flowID) + safeError := html.EscapeString(errorMsg) + + errorBanner := "" + if safeError != "" { + errorBanner = fmt.Sprintf(`
%s
`, safeError) + } + + skipOption := "" + if !enforceVK { + skipOption = fmt.Sprintf(` +
+ Skip for now + Connect to services when a tool is called +
+ + +
+
`, safeFlowID) + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("text/html; charset=utf-8") + ctx.SetBodyString(fmt.Sprintf(` + + + + +Connect to Bifrost + + + +
+

Connect to Bifrost

+

Choose how to identify yourself for this session.

+

This setup page expires in 15 minutes.

+ %s +
+ User ID + Use a stable identifier — access all available services +
+ + + + +
+
+
+ Virtual Key + Use a VK — access services within your key's limits +
+ + + + +
+
+ %s +
+ +`, bifrostPageCSS, errorBanner, safeFlowID, safeFlowID, skipOption)) +} + +// handleMCPsPage renders the MCP authentication list page. +// GET /oauth/consent/mcps?flow_id=xxx +func (h *ConsentHandler) handleMCPsPage(ctx *fasthttp.RequestCtx) { + flowID := string(ctx.QueryArgs().Peek("flow_id")) + + if flowID == "" { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Missing flow_id") + return + } + + if h.store.ConfigStore == nil { + ctx.SetStatusCode(fasthttp.StatusServiceUnavailable) + ctx.SetBodyString("Config store unavailable") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + ctx.SetStatusCode(fasthttp.StatusInternalServerError) + ctx.SetBodyString("Failed to load consent flow.") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + ctx.SetStatusCode(fasthttp.StatusBadRequest) + ctx.SetBodyString("Invalid or expired consent flow. Please restart the authentication process.") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + ctx.SetStatusCode(fasthttp.StatusForbidden) + ctx.SetBodyString("Flow does not belong to this browser session. Please restart the authentication process.") + return + } + + // Find which MCP clients the user has already authed. + // Check both: tokens stored under the flow proxy (connected during this flow) + // and tokens already stored under the VK/user identity (connected in a prior flow). + completedTokens, err := h.store.ConfigStore.GetOauthUserTokensByGatewaySessionID(ctx, flowID) + if err != nil { + completedTokens = nil // non-fatal; just show no checkmarks + } + completedMCPs := make(map[string]bool, len(completedTokens)) + for _, t := range completedTokens { + completedMCPs[t.MCPClientID] = true + } + + // Per_user_oauth MCP clients visible to this identity — sorted for deterministic rendering. + // When a VK is set on the flow, only show clients that VK is allowed to use. + perUserClients := h.store.GetPerUserOAuthMCPClientsForVirtualKey(ctx, strVal(flow.VirtualKeyID)) + clientIDs := make([]string, 0, len(perUserClients)) + for id := range perUserClients { + clientIDs = append(clientIDs, id) + } + sort.Strings(clientIDs) + + safeFlowID := html.EscapeString(flowID) + + // Determine if user skipped identity selection. + isSkipped := strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == "" + + // Build MCP rows — only show connect buttons if user has an identity. + var mcpRows strings.Builder + if isSkipped { + mcpRows.WriteString(`

You skipped identity selection. Services will be connected when you first use their tools. Since no identity is attached, your connections will only persist as long as the service keeps the OAuth token active — they will not be remembered across sessions.

`) + } else { + for _, clientID := range clientIDs { + clientName := perUserClients[clientID] + safeName := html.EscapeString(clientName) + + // Also check if a token already exists under the user's identity (e.g. from a prior LLM gateway auth). + alreadyConnected := completedMCPs[clientID] + if !alreadyConnected && (strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "") { + existing, tokenErr := h.store.ConfigStore.GetOauthUserTokenByIdentity(ctx, strVal(flow.VirtualKeyID), strVal(flow.UserID), "", clientID) + if tokenErr != nil { + logger.Warn("[consent/mcps] failed to check existing token: mcp_client_id=%s err=%v", clientID, tokenErr) + } + alreadyConnected = existing != nil + } + + if alreadyConnected { + mcpRows.WriteString(fmt.Sprintf(` +
+
%s
+ ✓ Connected +
`, safeName)) + } else { + connectURL := fmt.Sprintf("/api/oauth/per-user/upstream/authorize?mcp_client_id=%s&flow_id=%s", + url.QueryEscape(clientID), url.QueryEscape(flowID)) + mcpRows.WriteString(fmt.Sprintf(` +
+
%s
+ Connect +
`, safeName, html.EscapeString(connectURL))) + } + } + if len(perUserClients) == 0 { + mcpRows.WriteString(`

No MCP services require authentication.

`) + } + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("text/html; charset=utf-8") + ctx.SetBodyString(fmt.Sprintf(` + + + + +Connect Your Apps — Bifrost + + + +
+

Connect Your Apps

+

Authenticate with the services below to enable their tools.

+

This setup page expires in 15 minutes.

+
%s
+
+ + +
+ +
+ +`, bifrostPageCSS, mcpRows.String(), safeFlowID, safeFlowID)) +} + +// ---------- API action handlers ---------- + +// handleSubmitVK validates the submitted Virtual Key, links it to the pending flow, +// and redirects to the MCPs page. +// POST /api/oauth/per-user/consent/vk (form: flow_id, vk) +func (h *ConsentHandler) handleSubmitVK(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + vkValue := strings.TrimSpace(string(ctx.FormValue("vk"))) + + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + if vkValue == "" { + redirectToIdentityPage(ctx, flowID, "Please enter a Virtual Key.") + return + } + + vk, err := h.store.ConfigStore.GetVirtualKeyByValue(ctx, vkValue) + if err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to validate Virtual Key. Please try again.") + return + } + if vk == nil || !vk.IsActive { + redirectToIdentityPage(ctx, flowID, "Virtual Key not found or inactive. Please check and try again.") + return + } + + flow.VirtualKeyID = &vk.ID + flow.UserID = nil // Clear other identity to keep selection exclusive + if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to save Virtual Key. Please try again.") + return + } + + ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound) +} + +// handleSubmitUserID links a user-supplied User ID to the pending flow and proceeds to MCPs page. +// SECURITY: The User ID is self-declared (typed in a form) with no server-side verification. +// This matches the trust model of X-Bf-User-Id in the LLM gateway path. Deployments requiring +// verified identity should use Virtual Keys or an auth layer in front of Bifrost. +// POST /api/oauth/per-user/consent/user-id (form: flow_id, user_id) +func (h *ConsentHandler) handleSubmitUserID(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + userID := strings.TrimSpace(string(ctx.FormValue("user_id"))) + + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + if userID == "" { + redirectToIdentityPage(ctx, flowID, "Please enter a User ID.") + return + } + if len(userID) > 255 { + redirectToIdentityPage(ctx, flowID, "User ID is too long (max 255 characters).") + return + } + + if userID != "" { + flow.UserID = &userID + } + flow.VirtualKeyID = nil // Clear other identity to keep selection exclusive + if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to save User ID. Please try again.") + return + } + + ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound) +} + +// handleSkip skips identity selection and proceeds directly to the MCPs page. +// Upstream services will be connected lazily when tools are first called. +// POST /api/oauth/per-user/consent/skip (form: flow_id) +func (h *ConsentHandler) handleSkip(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + h.store.Mu.RLock() + enforceVK := h.store.ClientConfig.EnforceAuthOnInference + h.store.Mu.RUnlock() + + if enforceVK { + redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.") + return + } + + // Clear any previously selected identity so skip truly resets the flow. + if strVal(flow.VirtualKeyID) != "" || strVal(flow.UserID) != "" { + flow.VirtualKeyID = nil + flow.UserID = nil + if err := h.store.ConfigStore.UpdatePerUserOAuthPendingFlow(ctx, flow); err != nil { + redirectToIdentityPage(ctx, flowID, "Failed to clear identity. Please try again.") + return + } + } + + // Skip goes straight to MCPs page; no identity means only lazy auth is available. + ctx.Redirect(fmt.Sprintf("/oauth/consent/mcps?flow_id=%s", url.QueryEscape(flowID)), fasthttp.StatusFound) +} + +// handleSubmit finalises the consent flow: +// 1. Creates a real Bifrost session (TablePerUserOAuthSession) +// 2. Migrates upstream tokens from the flow proxy to the real session +// 3. Issues a TablePerUserOAuthCode +// 4. Deletes the PendingFlow +// 5. Redirects to the original MCP client callback URL with code + state +// +// POST /api/oauth/per-user/consent/submit (form: flow_id) +func (h *ConsentHandler) handleSubmit(ctx *fasthttp.RequestCtx) { + if h.store.ConfigStore == nil { + SendError(ctx, fasthttp.StatusServiceUnavailable, "Config store unavailable") + return + } + + flowID := string(ctx.FormValue("flow_id")) + if flowID == "" { + SendError(ctx, fasthttp.StatusBadRequest, "flow_id is required") + return + } + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to load consent flow") + return + } + if flow == nil { + SendError(ctx, fasthttp.StatusBadRequest, "Invalid consent flow") + return + } + if time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusBadRequest, "Consent flow has expired. Please restart the authentication process.") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + + // Server-side enforcement: reject if identity is required but not provided. + h.store.Mu.RLock() + enforceAuth := h.store.ClientConfig.EnforceAuthOnInference + h.store.Mu.RUnlock() + if enforceAuth && strVal(flow.VirtualKeyID) == "" && strVal(flow.UserID) == "" { + redirectToIdentityPage(ctx, flowID, "An identity (Virtual Key or User ID) is required. Please choose one to continue.") + return + } + + // 1. Generate session credentials. + accessToken, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate session token") + return + } + refreshToken, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate refresh token") + return + } + + session := &tables.TablePerUserOAuthSession{ + ID: uuid.New().String(), + AccessToken: accessToken, + RefreshToken: refreshToken, + ClientID: flow.ClientID, + VirtualKeyID: flow.VirtualKeyID, + UserID: flow.UserID, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + // 2. Generate authorization code. + code, err := generateOpaqueToken(32) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate authorization code") + return + } + codeRecord := &tables.TablePerUserOAuthCode{ + ID: uuid.New().String(), + Code: code, + ClientID: flow.ClientID, + RedirectURI: flow.RedirectURI, + CodeChallenge: flow.CodeChallenge, + SessionID: session.ID, // Links token endpoint to this session so it can return the same access token + // Scopes intentionally omitted: the consent flow has no scope selection step. + ExpiresAt: time.Now().Add(5 * time.Minute), + } + + // 3. Atomically consume the pending flow, create session, and create auth code. + // If another concurrent request already consumed the flow, rowsAffected will be 0. + rowsAffected, err := h.store.ConfigStore.FinalizePerUserOAuthConsent(ctx, flowID, session, codeRecord) + if err != nil { + if errors.Is(err, schemas.ErrPerUserOAuthPendingFlowExpired) { + SendError(ctx, fasthttp.StatusGone, "Consent flow has expired. Please restart the authentication process.") + return + } + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to finalize consent flow") + return + } + if rowsAffected == 0 { + SendError(ctx, fasthttp.StatusConflict, "Consent flow has already been submitted") + return + } + logger.Debug("[consent/submit] session created: session_id=%s flow_id=%s", session.ID, flowID) + + // 4. Migrate upstream tokens from flow proxy sessions to real session (non-fatal). + if err := h.store.ConfigStore.TransferOauthUserTokensFromGatewaySession(ctx, flowID, accessToken, strVal(flow.VirtualKeyID), strVal(flow.UserID)); err != nil { + // Non-fatal: tokens can be re-acquired on first tool use. + logger.Warn("[consent/submit] failed to transfer upstream tokens: flow_id=%s err=%v", flowID, err) + } + + // 5. Redirect to MCP client callback with code + original state. + redirectURL, err := url.Parse(flow.RedirectURI) + if err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, "Invalid redirect URI in pending flow") + return + } + q := redirectURL.Query() + q.Set("code", code) + if flow.State != "" { + q.Set("state", flow.State) + } + redirectURL.RawQuery = q.Encode() + + ctx.Redirect(redirectURL.String(), fasthttp.StatusFound) +} + +// ---------- helpers ---------- + +// bifrostPageCSS is the shared inline CSS for all Go-rendered consent/callback pages. +// It mirrors Bifrost's UI design tokens: teal primary, zinc palette, Geist font stack. +const bifrostPageCSS = ` + *,*::before,*::after{box-sizing:border-box;margin:0;padding:0} + body{font-family:"Geist",system-ui,-apple-system,sans-serif;font-size:0.95rem; + line-height:1.5;background:#f4f4f5;color:oklch(0.141 0.005 285.823); + display:flex;align-items:center;justify-content:center;min-height:100vh; + -webkit-font-smoothing:antialiased} + .card{background:#fff;border:1px solid oklch(0.92 0.004 286.32);border-radius:12px; + padding:40px;width:100%;max-width:480px} + h1{font-size:1.25rem;font-weight:600;color:oklch(0.141 0.005 285.823);margin-bottom:6px} + .subtitle{font-size:0.825rem;color:oklch(0.552 0.016 285.938);line-height:1.5;margin-bottom:24px} + label{display:block;font-size:0.825rem;font-weight:500;color:oklch(0.141 0.005 285.823);margin-bottom:5px} + input[type=text],input[type=password]{width:100%;padding:8px 12px;border:1px solid oklch(0.92 0.004 286.32); + border-radius:0.5rem;font-size:0.875rem;outline:none; + transition:border-color .15s,box-shadow .15s;margin-bottom:10px; + background:#fff;color:oklch(0.141 0.005 285.823)} + input[type=text]:focus,input[type=password]:focus{border-color:oklch(0.5081 0.1049 165.61); + box-shadow:0 0 0 3px oklch(0.5081 0.1049 165.61 / 0.15)} + .btn{display:block;width:100%;padding:9px 16px;border-radius:0.5rem;font-size:0.875rem; + font-weight:500;cursor:pointer;border:none;text-align:center;text-decoration:none; + transition:background .15s;font-family:inherit} + .btn-primary{background:oklch(0.5081 0.1049 165.61);color:oklch(0.985 0 0)} + .btn-primary:hover{background:oklch(0.43 0.1049 165.61)} + .btn-ghost{background:transparent;border:1px solid oklch(0.92 0.004 286.32); + color:oklch(0.552 0.016 285.938);display:inline-block;width:auto;padding:8px 16px} + .btn-ghost:hover{background:#f4f4f5} + .error-banner{background:oklch(0.97 0.02 27);border:1px solid oklch(0.88 0.06 27); + border-radius:0.5rem;padding:12px 14px;margin-bottom:18px; + color:oklch(0.50 0.18 27);font-size:0.825rem} +` + +// redirectToIdentityPage redirects to the identity selection page with an error message. +func redirectToIdentityPage(ctx *fasthttp.RequestCtx, flowID, errorMsg string) { + u := fmt.Sprintf("/oauth/consent?flow_id=%s&error=%s", + url.QueryEscape(flowID), url.QueryEscape(errorMsg)) + ctx.Redirect(u, fasthttp.StatusFound) +} + +// strVal safely dereferences a *string, returning "" for nil. +func strVal(s *string) string { + if s == nil { + return "" + } + return *s +} diff --git a/transports/bifrost-http/handlers/oauth2_metadata.go b/transports/bifrost-http/handlers/oauth2_metadata.go index 32e7eda21d..2a764291e4 100644 --- a/transports/bifrost-http/handlers/oauth2_metadata.go +++ b/transports/bifrost-http/handlers/oauth2_metadata.go @@ -43,6 +43,10 @@ func (h *OAuthMetadataHandler) RegisterRoutes(r *router.Router, middlewares ...s // // GET /.well-known/oauth-protected-resource func (h *OAuthMetadataHandler) handleProtectedResourceMetadata(ctx *fasthttp.RequestCtx) { + if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not Found") + return + } scheme := "http" if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { scheme = "https" @@ -53,7 +57,7 @@ func (h *OAuthMetadataHandler) handleProtectedResourceMetadata(ctx *fasthttp.Req SendJSON(ctx, map[string]interface{}{ "resource": baseURL + "/mcp", "authorization_servers": []string{baseURL}, - "scopes_supported": []string{"mcp:read", "mcp:write"}, + "scopes_supported": []string{"mcp:read", "mcp:write"}, "bearer_methods_supported": []string{"header"}, }) } @@ -64,6 +68,10 @@ func (h *OAuthMetadataHandler) handleProtectedResourceMetadata(ctx *fasthttp.Req // // GET /.well-known/oauth-authorization-server func (h *OAuthMetadataHandler) handleAuthorizationServerMetadata(ctx *fasthttp.RequestCtx) { + if clients := h.store.GetPerUserOAuthMCPClients(); len(clients) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not Found") + return + } scheme := "http" if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { scheme = "https" @@ -78,7 +86,7 @@ func (h *OAuthMetadataHandler) handleAuthorizationServerMetadata(ctx *fasthttp.R "registration_endpoint": baseURL + "/api/oauth/per-user/register", "response_types_supported": []string{"code"}, "grant_types_supported": []string{"authorization_code"}, - "code_challenge_methods_supported": []string{"S256"}, + "code_challenge_methods_supported": []string{"S256"}, "token_endpoint_auth_methods_supported": []string{"none"}, "scopes_supported": []string{"mcp:read", "mcp:write"}, }) diff --git a/transports/bifrost-http/handlers/oauth2_per_user.go b/transports/bifrost-http/handlers/oauth2_per_user.go index e316a3df37..3c6f59041b 100644 --- a/transports/bifrost-http/handlers/oauth2_per_user.go +++ b/transports/bifrost-http/handlers/oauth2_per_user.go @@ -57,6 +57,11 @@ func (h *PerUserOAuthHandler) handleDynamicClientRegistration(ctx *fasthttp.Requ return } + if len(h.store.GetPerUserOAuthMCPClients()) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not found") + return + } + var req struct { ClientName string `json:"client_name"` RedirectURIs []string `json:"redirect_uris"` @@ -104,24 +109,31 @@ func (h *PerUserOAuthHandler) handleDynamicClientRegistration(ctx *fasthttp.Requ ctx.SetStatusCode(fasthttp.StatusCreated) SendJSON(ctx, map[string]interface{}{ "client_id": clientID, - "client_name": req.ClientName, - "redirect_uris": req.RedirectURIs, - "grant_types": grantTypes, - "response_types": req.ResponseTypes, + "client_name": req.ClientName, + "redirect_uris": req.RedirectURIs, + "grant_types": grantTypes, + "response_types": req.ResponseTypes, "token_endpoint_auth_method": "none", }) } // handleAuthorize handles the OAuth 2.1 authorization endpoint. -// It validates the request, shows a consent page, and issues an authorization code. +// Instead of issuing a code immediately, it validates the request parameters, +// creates a PendingFlow record, and redirects the user to the consent screen. +// The code is only issued after the user completes the consent flow (VK + MCP auths). // -// GET /api/oauth/per-user/authorize?response_type=code&client_id=xxx&redirect_uri=xxx&state=xxx&code_challenge=xxx&code_challenge_method=S256 +// GET /api/oauth/per-user/authorize?response_type=code&client_id=xxx&redirect_uri=xxx&code_challenge=xxx&code_challenge_method=S256[&state=xxx] func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) { if h.store.ConfigStore == nil { SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth authorization unavailable: config store is disabled") return } + if len(h.store.GetPerUserOAuthMCPClients()) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not found") + return + } + // Extract parameters responseType := string(ctx.QueryArgs().Peek("response_type")) clientID := string(ctx.QueryArgs().Peek("client_id")) @@ -129,15 +141,14 @@ func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) { state := string(ctx.QueryArgs().Peek("state")) codeChallenge := string(ctx.QueryArgs().Peek("code_challenge")) codeChallengeMethod := string(ctx.QueryArgs().Peek("code_challenge_method")) - scope := string(ctx.QueryArgs().Peek("scope")) // Validate required parameters if responseType != "code" { SendError(ctx, fasthttp.StatusBadRequest, "response_type must be 'code'") return } - if clientID == "" || redirectURI == "" || state == "" { - SendError(ctx, fasthttp.StatusBadRequest, "client_id, redirect_uri, and state are required") + if clientID == "" || redirectURI == "" { + SendError(ctx, fasthttp.StatusBadRequest, "client_id and redirect_uri are required") return } if codeChallenge == "" || codeChallengeMethod != "S256" { @@ -145,7 +156,7 @@ func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) { return } - // Validate client exists and redirect_uri is allowed + // Validate client exists and redirect_uri is registered client, err := h.store.ConfigStore.GetPerUserOAuthClientByClientID(ctx, clientID) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to validate client: %v", err)) @@ -155,8 +166,6 @@ func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) { SendError(ctx, fasthttp.StatusBadRequest, "Unknown client_id") return } - - // Verify redirect_uri is registered var allowedURIs []string json.Unmarshal([]byte(client.RedirectURIs), &allowedURIs) uriAllowed := false @@ -171,41 +180,45 @@ func (h *PerUserOAuthHandler) handleAuthorize(ctx *fasthttp.RequestCtx) { return } - // Generate authorization code - code, err := generateOpaqueToken(32) - if err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate authorization code") - return - } - - // Store authorization code - codeRecord := &tables.TablePerUserOAuthCode{ - ID: uuid.New().String(), - Code: code, - ClientID: clientID, - RedirectURI: redirectURI, - CodeChallenge: codeChallenge, - Scopes: scope, - ExpiresAt: time.Now().Add(5 * time.Minute), - } - if err := h.store.ConfigStore.CreatePerUserOAuthCode(ctx, codeRecord); err != nil { - SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to store authorization code: %v", err)) - return - } - - // Auto-approve and redirect back with code (no consent page for MCP clients) - // Build redirect URL with code and state - redirectURL, err := url.Parse(redirectURI) + // Generate a browser-binding secret so only the initiating browser can resume this flow. + browserSecret, err := generateOpaqueToken(32) if err != nil { - SendError(ctx, fasthttp.StatusBadRequest, "Invalid redirect_uri") - return - } - q := redirectURL.Query() - q.Set("code", code) - q.Set("state", state) - redirectURL.RawQuery = q.Encode() - - ctx.Redirect(redirectURL.String(), fasthttp.StatusFound) + SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate browser secret") + return + } + browserSecretHash := fmt.Sprintf("%x", sha256.Sum256([]byte(browserSecret))) + + // Create a PendingFlow to carry OAuth params through the consent screen. + flow := &tables.TablePerUserOAuthPendingFlow{ + ID: uuid.New().String(), + ClientID: clientID, + RedirectURI: redirectURI, + CodeChallenge: codeChallenge, + State: state, + BrowserSecretHash: browserSecretHash, + ExpiresAt: time.Now().Add(15 * time.Minute), + } + if err := h.store.ConfigStore.CreatePerUserOAuthPendingFlow(ctx, flow); err != nil { + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create pending flow: %v", err)) + return + } + logger.Debug("[oauth/authorize] PendingFlow created: flow_id=%s client_id=%s", flow.ID, clientID) + + // Set HttpOnly cookie binding this flow to the current browser. + var cookie fasthttp.Cookie + cookie.SetKey("__bifrost_flow_secret") + cookie.SetValue(browserSecret) + cookie.SetPath("/") + cookie.SetHTTPOnly(true) + cookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) + isSecure := ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" + cookie.SetSecure(isSecure) + cookie.SetMaxAge(15 * 60) // 15 minutes, matching flow TTL + ctx.Response.Header.SetCookie(&cookie) + + // Redirect to consent screen with flow_id (relative path — stays on current origin). + consentURL := fmt.Sprintf("/oauth/consent?flow_id=%s", url.QueryEscape(flow.ID)) + ctx.Redirect(consentURL, fasthttp.StatusFound) } // handleToken handles the OAuth 2.1 token endpoint. @@ -218,6 +231,11 @@ func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) { return } + if len(h.store.GetPerUserOAuthMCPClients()) == 0 { + sendStringError(ctx, fasthttp.StatusNotFound, "Not found") + return + } + // Parse form-encoded body grantType := string(ctx.FormValue("grant_type")) code := string(ctx.FormValue("code")) @@ -230,8 +248,8 @@ func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) { return } - if code == "" || clientID == "" || codeVerifier == "" { - sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_request", "code, client_id, and code_verifier are required") + if code == "" || codeVerifier == "" { + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_request", "code and code_verifier are required") return } @@ -252,14 +270,20 @@ func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) { return } - // Validate client_id matches - if codeRecord.ClientID != clientID { + // Validate client_id if provided — some public clients omit it (RFC 6749 §4.1.3 allows + // omitting client_id when the client is not authenticating with the server). + // The code record already binds the code to the correct client, so this is safe. + if clientID != "" && codeRecord.ClientID != clientID { + logger.Debug("[oauth/token] client_id mismatch: code_client=%s request_client=%s", codeRecord.ClientID, clientID) sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "client_id mismatch") return } + // Use the client_id from the code record as the authoritative value. + clientID = codeRecord.ClientID // Validate redirect_uri matches if redirectURI != "" && codeRecord.RedirectURI != redirectURI { + logger.Debug("[oauth/token] redirect_uri mismatch: code=%s request=%s", codeRecord.RedirectURI, redirectURI) sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "redirect_uri mismatch") return } @@ -268,36 +292,64 @@ func (h *PerUserOAuthHandler) handleToken(ctx *fasthttp.RequestCtx) { verifierHash := sha256.Sum256([]byte(codeVerifier)) computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:]) if computedChallenge != codeRecord.CodeChallenge { + logger.Debug("[oauth/token] PKCE verification failed") sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "PKCE verification failed") return } - // Generate access token and refresh token - accessToken, err := generateOpaqueToken(32) - if err != nil { - sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate access token") - return - } - refreshToken, err := generateOpaqueToken(32) - if err != nil { - sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate refresh token") - return - } + // If the code was issued by the consent flow (handleSubmit), the session already exists + // with the upstream tokens transferred to it. Reuse that session's access token so the + // client receives the token that the upstream (Notion, GitHub, etc.) tokens are linked to. + var accessToken string + var expiresAt time.Time - // Store session - expiresAt := time.Now().Add(24 * time.Hour) // 24-hour access token - session := &tables.TablePerUserOAuthSession{ - ID: uuid.New().String(), - AccessToken: accessToken, - RefreshToken: refreshToken, - ClientID: clientID, - ExpiresAt: expiresAt, - } - if err := h.store.ConfigStore.CreatePerUserOAuthSession(ctx, session); err != nil { - sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to create session") - return + if codeRecord.SessionID != "" { + existingSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, codeRecord.SessionID) + if err != nil { + logger.Info("[oauth/token] Failed to load existing session: session_id=%s err=%v", codeRecord.SessionID, err) + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to load session") + return + } + if existingSession == nil { + logger.Info("[oauth/token] Existing session not found: session_id=%s", codeRecord.SessionID) + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Session not found") + return + } + if !existingSession.ExpiresAt.After(time.Now()) { + sendOAuthError(ctx, fasthttp.StatusBadRequest, "invalid_grant", "Session expired") + return + } + accessToken = existingSession.AccessToken + expiresAt = existingSession.ExpiresAt + logger.Debug("[oauth/token] reusing consent session: session_id=%s", existingSession.ID) + } else { + // Fallback: no linked session (legacy path) — create a new one. + var newAccessToken, newRefreshToken string + newAccessToken, err = generateOpaqueToken(32) + if err != nil { + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate access token") + return + } + newRefreshToken, err = generateOpaqueToken(32) + if err != nil { + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to generate refresh token") + return + } + expiresAt = time.Now().Add(24 * time.Hour) + newSession := &tables.TablePerUserOAuthSession{ + ID: uuid.New().String(), + AccessToken: newAccessToken, + RefreshToken: newRefreshToken, + ClientID: clientID, + ExpiresAt: expiresAt, + } + if err := h.store.ConfigStore.CreatePerUserOAuthSession(ctx, newSession); err != nil { + sendOAuthError(ctx, fasthttp.StatusInternalServerError, "server_error", "Failed to create session") + return + } + accessToken = newAccessToken + logger.Debug("[oauth/token] created new session (legacy path): session_id=%s", newSession.ID) } - // Return OAuth token response ctx.SetContentType("application/json") ctx.SetStatusCode(fasthttp.StatusOK) @@ -320,7 +372,28 @@ func sendOAuthError(ctx *fasthttp.RequestCtx, statusCode int, errorCode, descrip ctx.SetBody(resp) } +func sendStringError(ctx *fasthttp.RequestCtx, statusCode int, message string) { + ctx.SetContentType("text/plain") + ctx.SetStatusCode(statusCode) + ctx.SetBodyString(message) +} + // generateOpaqueToken generates a cryptographically secure random token. +// validateFlowBrowserSecret checks that the request carries the __bifrost_flow_secret +// cookie matching the hash stored on the pending flow. Returns true if valid. +func validateFlowBrowserSecret(ctx *fasthttp.RequestCtx, flow *tables.TablePerUserOAuthPendingFlow) bool { + if flow.BrowserSecretHash == "" { + // Legacy flow without browser binding — allow for backwards compatibility. + return true + } + secret := ctx.Request.Header.Cookie("__bifrost_flow_secret") + if len(secret) == 0 { + return false + } + hash := fmt.Sprintf("%x", sha256.Sum256(secret)) + return hash == flow.BrowserSecretHash +} + func generateOpaqueToken(length int) (string, error) { bytes := make([]byte, length) if _, err := rand.Read(bytes); err != nil { @@ -333,9 +406,10 @@ func generateOpaqueToken(length int) (string, error) { // When a user needs to authenticate with an upstream MCP server (e.g., Notion), // this endpoint redirects them to the upstream provider's OAuth authorize URL. // After the user authenticates, the callback stores their upstream token linked -// to their Bifrost session. +// to either their Bifrost session (runtime flow) or a PendingFlow (consent flow). // -// GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&session=xxx +// Runtime flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&session=xxx +// Consent flow: GET /api/oauth/per-user/upstream/authorize?mcp_client_id=xxx&flow_id=xxx func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) { if h.store.ConfigStore == nil { SendError(ctx, fasthttp.StatusServiceUnavailable, "OAuth upstream authorization unavailable: config store is disabled") @@ -344,31 +418,64 @@ func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) mcpClientID := string(ctx.QueryArgs().Peek("mcp_client_id")) sessionID := string(ctx.QueryArgs().Peek("session")) + flowID := string(ctx.QueryArgs().Peek("flow_id")) - if mcpClientID == "" || sessionID == "" { - SendError(ctx, fasthttp.StatusBadRequest, "mcp_client_id and session are required") + if mcpClientID == "" || (sessionID == "" && flowID == "") { + SendError(ctx, fasthttp.StatusBadRequest, "mcp_client_id and either session or flow_id are required") return } - // Validate the Bifrost session exists - session, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, sessionID) - if err != nil || session == nil { - SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session") - return + // Resolve identity depending on whether this is a runtime session or a consent flow. + var virtualKeyID, userID, proxySessionToken, gatewaySessionID string + if flowID != "" { + // Consent flow: use the pending flow for identity and proxy token. + flow, err := h.store.ConfigStore.GetPerUserOAuthPendingFlow(ctx, flowID) + if err != nil || flow == nil || time.Now().After(flow.ExpiresAt) { + SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired consent flow") + return + } + if !validateFlowBrowserSecret(ctx, flow) { + SendError(ctx, fasthttp.StatusForbidden, "Flow does not belong to this browser session") + return + } + if strVal(flow.VirtualKeyID) != "" { + virtualKeyID = *flow.VirtualKeyID + } + if strVal(flow.UserID) != "" { + userID = *flow.UserID + } + // Use a prefixed flow token so the callback can detect the consent path. + // Include mcpClientID to avoid unique constraint violations when multiple + // MCP services are connected in the same consent flow. + proxySessionToken = "flow:" + flowID + ":" + mcpClientID + gatewaySessionID = flowID + } else { + // Runtime flow: validate the existing Bifrost session. + bifrostSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, sessionID) + if err != nil || bifrostSession == nil { + SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session") + return + } + if !bifrostSession.ExpiresAt.After(time.Now()) { + SendError(ctx, fasthttp.StatusUnauthorized, "Invalid or expired session") + return + } + virtualKeyID = strVal(bifrostSession.VirtualKeyID) + userID = strVal(bifrostSession.UserID) + proxySessionToken = "runtime:" + sessionID + ":" + mcpClientID + gatewaySessionID = sessionID } - // Look up the MCP client config to get the template OAuth config + // Look up the MCP client config to get the template OAuth config. mcpClient, err := h.store.ConfigStore.GetMCPClientByID(ctx, mcpClientID) if err != nil || mcpClient == nil { SendError(ctx, fasthttp.StatusNotFound, "MCP client not found") return } - if mcpClient.AuthType != string(schemas.MCPAuthTypePerUserOauth) { SendError(ctx, fasthttp.StatusBadRequest, "MCP client does not use per-user OAuth") return } - if mcpClient.OauthConfigID == nil || *mcpClient.OauthConfigID == "" { SendError(ctx, fasthttp.StatusBadRequest, "MCP client has no OAuth configuration") return @@ -381,7 +488,7 @@ func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) return } - // Generate PKCE challenge for upstream + // Generate PKCE challenge for upstream. codeVerifier, err := generateOpaqueToken(32) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate PKCE verifier") @@ -390,53 +497,55 @@ func (h *PerUserOAuthHandler) handleUpstreamAuthorize(ctx *fasthttp.RequestCtx) verifierHash := sha256.Sum256([]byte(codeVerifier)) codeChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:]) - // Generate state for upstream + // Generate state for upstream. state, err := generateOpaqueToken(32) if err != nil { SendError(ctx, fasthttp.StatusInternalServerError, "Failed to generate state token") return } - // Build redirect URI (Bifrost's callback endpoint) + // Build redirect URI (Bifrost's callback endpoint). scheme := "http" if ctx.IsTLS() || string(ctx.Request.Header.Peek("X-Forwarded-Proto")) == "https" { scheme = "https" } host := string(ctx.Host()) redirectURI := fmt.Sprintf("%s://%s/api/oauth/callback", scheme, host) - - // Look up Bifrost session to propagate identity to upstream OAuth flow - var virtualKeyID, userID string - if bifrostSession, err := h.store.ConfigStore.GetPerUserOAuthSessionByID(ctx, sessionID); err == nil && bifrostSession != nil { - virtualKeyID = bifrostSession.VirtualKeyID - userID = bifrostSession.UserID + var vkId *string + if virtualKeyID != "" { + vkId = &virtualKeyID } - - // Store upstream OAuth session (links state → session + mcp_client + identity) - upstreamSession := &tables.TableOauthUserSession{ - ID: uuid.New().String(), - MCPClientID: mcpClientID, - OauthConfigID: *mcpClient.OauthConfigID, - State: state, - CodeVerifier: codeVerifier, - GatewaySessionID: sessionID, // Link to Bifrost MCP gateway session - VirtualKeyID: virtualKeyID, - UserID: userID, - Status: "pending", - ExpiresAt: time.Now().Add(15 * time.Minute), + var uid *string + if userID != "" { + uid = &userID } + // Store upstream OAuth session linking state → MCP client + identity. + upstreamSession := &tables.TableOauthUserSession{ + ID: uuid.New().String(), + MCPClientID: mcpClientID, + OauthConfigID: *mcpClient.OauthConfigID, + State: state, + CodeVerifier: codeVerifier, + SessionToken: proxySessionToken, // "runtime:xxx" for runtime flow; "flow:xxx" for consent flow + GatewaySessionID: gatewaySessionID, + VirtualKeyID: vkId, + UserID: uid, + Status: "pending", + ExpiresAt: time.Now().Add(15 * time.Minute), + } + logger.Debug("[oauth/upstream-authorize] creating upstream session: mcp_client=%s flow=%s", mcpClientID, proxySessionToken) if err := h.store.ConfigStore.CreateOauthUserSession(ctx, upstreamSession); err != nil { SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Failed to create upstream OAuth session: %v", err)) return } - // Parse scopes from template config + // Parse scopes from template config. var scopes []string if templateConfig.Scopes != "" { json.Unmarshal([]byte(templateConfig.Scopes), &scopes) } - // Build upstream authorize URL with PKCE + // Build upstream authorize URL with PKCE. params := url.Values{} params.Set("response_type", "code") params.Set("client_id", templateConfig.ClientID) diff --git a/transports/bifrost-http/handlers/providers.go b/transports/bifrost-http/handlers/providers.go index 665e716b20..2c13fff8a0 100644 --- a/transports/bifrost-http/handlers/providers.go +++ b/transports/bifrost-http/handlers/providers.go @@ -18,6 +18,7 @@ import ( "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/configstore/tables" + "github.com/maximhq/bifrost/framework/modelcatalog" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) @@ -938,7 +939,7 @@ func filterModelsByKeysWithAccessMap(config *configstore.ProviderConfig, provide for _, model := range models { grantedBy := make([]string, 0, len(matchedKeys)) for _, matched := range matchedKeys { - if keyAllowsModelForList(provider, model, matched.key, modelCatalog) { + if keyAllowsModelForList(matched.key, model) { grantedBy = append(grantedBy, matched.id) } } diff --git a/transports/bifrost-http/lib/config.go b/transports/bifrost-http/lib/config.go index a0409d7aec..6a721e5916 100644 --- a/transports/bifrost-http/lib/config.go +++ b/transports/bifrost-http/lib/config.go @@ -2430,6 +2430,73 @@ func (c *Config) GetAllowOnAllVirtualKeysClients() map[string]string { return result } +// GetPerUserOAuthMCPClients returns a map of clientID -> clientName for all MCP clients +// that have AuthType set to "per_user_oauth". The returned map is a copy, safe for concurrent use. +func (c *Config) GetPerUserOAuthMCPClients() map[string]string { + c.muMCP.RLock() + defer c.muMCP.RUnlock() + + if c.MCPConfig == nil { + return nil + } + result := make(map[string]string) + for _, client := range c.MCPConfig.ClientConfigs { + if client != nil && client.AuthType == schemas.MCPAuthTypePerUserOauth { + result[client.ID] = client.Name + } + } + return result +} + +// GetPerUserOAuthMCPClientsForVirtualKey returns a map of clientID -> clientName for +// per_user_oauth MCP clients that the given VK is allowed to use. A client is included if: +// - AllowOnAllVirtualKeys is true, OR +// - The VK has an explicit entry in governance_virtual_key_mcp_configs for that client. +// +// If virtualKeyID is empty, all per-user OAuth clients are returned. If the config store +// is unavailable or the VK lookup fails, only clients with AllowOnAllVirtualKeys=true are returned. +func (c *Config) GetPerUserOAuthMCPClientsForVirtualKey(ctx context.Context, virtualKeyID string) map[string]string { + all := c.GetPerUserOAuthMCPClients() + if virtualKeyID == "" { + return all + } + + // Build set of per-user OAuth clients that allow all virtual keys. + c.muMCP.RLock() + allowAll := make(map[string]string) + if c.MCPConfig != nil { + for _, client := range c.MCPConfig.ClientConfigs { + if client != nil && client.AuthType == schemas.MCPAuthTypePerUserOauth && client.AllowOnAllVirtualKeys { + allowAll[client.ID] = client.Name + } + } + } + c.muMCP.RUnlock() + + if c.ConfigStore == nil { + return allowAll + } + + // Get VK-specific MCP configs (with MCPClient preloaded so we have the string ClientID). + vkConfigs, err := c.ConfigStore.GetVirtualKeyMCPConfigs(ctx, virtualKeyID) + if err != nil { + // Fail closed: only return clients that are allowed on all virtual keys. + return allowAll + } + explicit := make(map[string]bool, len(vkConfigs)) + for _, cfg := range vkConfigs { + explicit[cfg.MCPClient.ClientID] = true + } + + result := make(map[string]string) + for clientID, clientName := range all { + if _, ok := allowAll[clientID]; ok || explicit[clientID] { + result[clientID] = clientName + } + } + return result +} + // GetPluginOrder returns the names of all base plugins in their sorted placement order. // This method is lock-free and safe for concurrent access from hot paths. // Do not modify the returned slice; it is a shared snapshot and must be treated read-only. diff --git a/transports/bifrost-http/lib/config_test.go b/transports/bifrost-http/lib/config_test.go index 6708fac37e..0723aeb2ff 100644 --- a/transports/bifrost-http/lib/config_test.go +++ b/transports/bifrost-http/lib/config_test.go @@ -1196,6 +1196,31 @@ func (m *MockConfigStore) UpdatePerUserOAuthCode(ctx context.Context, code *tabl return nil } +func (m *MockConfigStore) GetPerUserOAuthPendingFlow(ctx context.Context, id string) (*tables.TablePerUserOAuthPendingFlow, error) { + return nil, nil +} +func (m *MockConfigStore) CreatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + return nil +} +func (m *MockConfigStore) UpdatePerUserOAuthPendingFlow(ctx context.Context, flow *tables.TablePerUserOAuthPendingFlow) error { + return nil +} +func (m *MockConfigStore) DeletePerUserOAuthPendingFlow(ctx context.Context, id string) error { + return nil +} +func (m *MockConfigStore) ConsumePerUserOAuthPendingFlow(ctx context.Context, id string) (int64, error) { + return 1, nil +} +func (m *MockConfigStore) GetOauthUserTokensByGatewaySessionID(ctx context.Context, gatewaySessionID string) ([]tables.TableOauthUserToken, error) { + return nil, nil +} +func (m *MockConfigStore) TransferOauthUserTokensFromGatewaySession(ctx context.Context, gatewaySessionID, realSessionToken, virtualKeyID, userID string) error { + return nil +} +func (m *MockConfigStore) FinalizePerUserOAuthConsent(ctx context.Context, flowID string, session *tables.TablePerUserOAuthSession, code *tables.TablePerUserOAuthCode) (int64, error) { + return 1, nil +} + // Routing rules func (m *MockConfigStore) GetRoutingRules(ctx context.Context) ([]tables.TableRoutingRule, error) { return nil, nil diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index 876e487249..36c56cc2e5 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -470,11 +470,6 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat }) bifrostCtx.SetValue(schemas.BifrostContextKeyRequestHeaders, allHeaders) - // Extract per-user MCP OAuth session token from X-Bifrost-MCP-Session header - if mcpSession := string(ctx.Request.Header.Peek("X-Bifrost-MCP-Session")); mcpSession != "" { - bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserSession, mcpSession) - } - // Extract per-user MCP OAuth user identifier from X-Bf-User-Id header if mcpUserID := string(ctx.Request.Header.Peek("X-Bf-User-Id")); mcpUserID != "" { bifrostCtx.SetValue(schemas.BifrostContextKeyMCPUserID, mcpUserID) diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index 2679f5d45d..46178d6680 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -247,9 +247,13 @@ func (s *BifrostHTTPServer) VerifyPerUserOAuthConnection(ctx context.Context, co return s.Client.VerifyPerUserOAuthConnection(ctx, config, accessToken) } -// SetClientTools delegates to the Bifrost client to update tool map for an existing MCP client. +// SetClientTools delegates to the Bifrost client to update tool map for an existing MCP client, +// then re-syncs the MCP server so the new tools are immediately visible via /mcp. func (s *BifrostHTTPServer) SetClientTools(clientID string, tools map[string]schemas.ChatTool, toolNameMapping map[string]string) { s.Client.SetClientTools(clientID, tools, toolNameMapping) + if err := s.MCPServerHandler.SyncAllMCPServers(context.Background()); err != nil { + logger.Warn("failed to sync MCP servers after setting client tools: %v", err) + } } // ExecuteChatMCPTool executes an MCP tool call and returns the result as a chat message. @@ -1100,6 +1104,8 @@ func (s *BifrostHTTPServer) RegisterAPIRoutes(ctx context.Context, callbacks Ser oauthMetadataHandler.RegisterRoutes(s.Router) perUserOAuthHandler := handlers.NewPerUserOAuthHandler(s.Config) perUserOAuthHandler.RegisterRoutes(s.Router) + consentHandler := handlers.NewConsentHandler(s.Config) + consentHandler.RegisterRoutes(s.Router) if pluginsHandler != nil { pluginsHandler.RegisterRoutes(s.Router, middlewares...) } diff --git a/ui/app/workspace/logs/page.tsx b/ui/app/workspace/logs/page.tsx index 08229ca99e..8479700c67 100644 --- a/ui/app/workspace/logs/page.tsx +++ b/ui/app/workspace/logs/page.tsx @@ -57,6 +57,17 @@ export default function LogsPage() { const [selectedSessionId, setSelectedSessionId] = useState(null); const [sessionHighlightedLogId, setSessionHighlightedLogId] = useState(null); + // Stable handler so SessionDetailsSheet's loadSessionPage useCallback doesn't + // recreate on every parent re-render. Without this, every live WebSocket log + // tick would re-render LogsPage, hand the sheet a fresh inline arrow, recreate + // loadSessionPage, and trip the reset effect — wiping sessionLogs and + // refetching from offset 0 while the sheet is open. + const handleSessionSheetOpenChange = useCallback((open: boolean) => { + if (!open) { + setSelectedSessionId(null); + setSessionHighlightedLogId(null); + } + }, []); const [isChartOpen, setIsChartOpen] = useState(true); const [triggerGetLogById] = useLazyGetLogByIdQuery(); const [fetchedLog, setFetchedLog] = useState(null); @@ -224,14 +235,25 @@ export default function LogsPage() { }), // Only re-derive filters when filter-related URL params change (not pagination) [ - urlState.providers, urlState.models, urlState.aliases, urlState.status, urlState.objects, - urlState.selected_key_ids, urlState.virtual_key_ids, urlState.routing_rule_ids, + urlState.providers, + urlState.models, + urlState.aliases, + urlState.status, + urlState.objects, + urlState.selected_key_ids, + urlState.virtual_key_ids, + urlState.routing_rule_ids, urlState.routing_engine_used, - urlState.user_ids, urlState.team_ids, urlState.customer_ids, urlState.business_unit_ids, + urlState.user_ids, + urlState.team_ids, + urlState.customer_ids, + urlState.business_unit_ids, urlState.content_search, urlState.parent_request_id, - urlState.start_time, urlState.end_time, - urlState.missing_cost_only, urlState.metadata_filters, + urlState.start_time, + urlState.end_time, + urlState.missing_cost_only, + urlState.metadata_filters, ], ); diff --git a/ui/app/workspace/logs/sheets/logDetailView.tsx b/ui/app/workspace/logs/sheets/logDetailView.tsx index 423109beb1..b0bf499732 100644 --- a/ui/app/workspace/logs/sheets/logDetailView.tsx +++ b/ui/app/workspace/logs/sheets/logDetailView.tsx @@ -342,7 +342,6 @@ export function LogDetailView({ } /> )} - {log.fallback_index > 0 && } {log.virtual_key && } {log.routing_engines_used && log.routing_engines_used.length > 0 && ( diff --git a/ui/app/workspace/virtual-keys/views/virtualKeysTable.tsx b/ui/app/workspace/virtual-keys/views/virtualKeysTable.tsx index ddc501405b..d8a023658d 100644 --- a/ui/app/workspace/virtual-keys/views/virtualKeysTable.tsx +++ b/ui/app/workspace/virtual-keys/views/virtualKeysTable.tsx @@ -18,17 +18,33 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@ import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import { useCopyToClipboard } from "@/hooks/useCopyToClipboard"; import { resetDurationLabels } from "@/lib/constants/governance"; -import { getErrorMessage, useDeleteVirtualKeyMutation } from "@/lib/store"; +import { getErrorMessage, useDeleteVirtualKeyMutation, useLazyGetVirtualKeysQuery } from "@/lib/store"; import { Customer, Team, VirtualKey } from "@/lib/types/governance"; import { cn } from "@/lib/utils"; import { formatCurrency } from "@/lib/utils/governance"; import { RbacOperation, RbacResource, useRbac } from "@enterprise/lib"; -import { ChevronLeft, ChevronRight, Copy, Edit, Eye, EyeOff, Plus, Search, Trash2 } from "lucide-react"; +import { + ArrowUpDown, + ChevronLeft, + ChevronRight, + Copy, + Download, + Edit, + Eye, + EyeOff, + Loader2, + Plus, + Search, + ShieldCheck, + Trash2, +} from "lucide-react"; import { useMemo, useState } from "react"; import { toast } from "sonner"; import VirtualKeyDetailSheet from "./virtualKeyDetailsSheet"; import { VirtualKeysEmptyState } from "./virtualKeysEmptyState"; import VirtualKeySheet from "./virtualKeySheet"; +import { Dialog, DialogContent, DialogDescription, DialogFooter, DialogHeader, DialogTitle } from "@/components/ui/dialog"; +import { Label } from "@/components/ui/label"; const formatResetDuration = (duration: string) => resetDurationLabels[duration] || duration; @@ -38,7 +54,7 @@ function virtualKeysToCSV(vks: VirtualKey[]): string { const headers = ["Name", "Status", "Assigned To", "Budget Limit", "Budget Spent", "Budget Reset", "Description", "Created At"]; const rows = vks.map((vk) => { const isExhausted = - (vk.budget?.current_usage && vk.budget?.max_limit && vk.budget.current_usage >= vk.budget.max_limit) || + vk.budgets?.some((b) => b.current_usage >= b.max_limit) || (vk.rate_limit?.token_current_usage && vk.rate_limit?.token_max_limit && vk.rate_limit.token_current_usage >= vk.rate_limit.token_max_limit) || @@ -47,9 +63,9 @@ function virtualKeysToCSV(vks: VirtualKey[]): string { vk.rate_limit.request_current_usage >= vk.rate_limit.request_max_limit); const status = vk.is_active ? (isExhausted ? "Exhausted" : "Active") : "Inactive"; const assignedTo = vk.team ? `Team: ${vk.team.name}` : vk.customer ? `Customer: ${vk.customer.name}` : ""; - const budgetLimit = vk.budget ? formatCurrency(vk.budget.max_limit) : ""; - const budgetSpent = vk.budget ? formatCurrency(vk.budget.current_usage) : ""; - const budgetReset = vk.budget ? formatResetDuration(vk.budget.reset_duration) : ""; + const budgetLimit = vk.budgets?.length ? vk.budgets.map((b) => formatCurrency(b.max_limit)).join("; ") : ""; + const budgetSpent = vk.budgets?.length ? vk.budgets.map((b) => formatCurrency(b.current_usage)).join("; ") : ""; + const budgetReset = vk.budgets?.length ? vk.budgets.map((b) => formatResetDuration(b.reset_duration)).join("; ") : ""; return [vk.name, status, assignedTo, budgetLimit, budgetSpent, budgetReset, vk.description || "", vk.created_at]; }); return [headers, ...rows].map((row) => row.map((cell) => `"${String(cell).replace(/"/g, '""')}"`).join(",")).join("\n"); @@ -109,6 +125,10 @@ export default function VirtualKeysTable({ const [revealedKeys, setRevealedKeys] = useState>(new Set()); const [selectedVirtualKeyId, setSelectedVirtualKeyId] = useState(null); const [showDetailSheet, setShowDetailSheet] = useState(false); + const [showExportDialog, setShowExportDialog] = useState(false); + const [exportScope, setExportScope] = useState("current_page"); + const [exportMaxLimit, setExportMaxLimit] = useState(""); + const [fetchVirtualKeys, { isFetching: isExporting }] = useLazyGetVirtualKeysQuery(); // Derive objects from props so they stay in sync with RTK cache updates const editingVirtualKey = useMemo( diff --git a/ui/package-lock.json b/ui/package-lock.json index cc5e052abd..f8f08a138c 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -13740,6 +13740,24 @@ } } }, + "node_modules/vitest/node_modules/yaml": { + "version": "2.8.3", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.8.3.tgz", + "integrity": "sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg==", + "dev": true, + "license": "ISC", + "optional": true, + "peer": true, + "bin": { + "yaml": "bin.mjs" + }, + "engines": { + "node": ">= 14.6" + }, + "funding": { + "url": "https://github.com/sponsors/eemeli" + } + }, "node_modules/watchpack": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.5.1.tgz",