diff --git a/core/bifrost.go b/core/bifrost.go index dd1b7cf84c..6d06ee7af2 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -3995,6 +3995,85 @@ func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostCont return bifrost.keySelector(ctx, supportedKeys, providerKey, model) } +func setProviderContextMetadata(ctx *schemas.BifrostContext, config *schemas.ProviderConfig) { + isCustomProvider := config != nil && config.CustomProviderConfig != nil + ctx.SetValue(schemas.BifrostContextKeyIsCustomProvider, isCustomProvider) + if isCustomProvider { + ctx.SetValue(schemas.BifrostContextKeyCustomProviderMetadata, &schemas.CustomProviderContextMetadata{ + ProviderKey: schemas.ModelProvider(config.CustomProviderConfig.CustomProviderKey), + BaseProviderType: config.CustomProviderConfig.BaseProviderType, + SupportsResponsesAPI: config.CustomProviderConfig.SupportsResponsesAPI, + }) + return + } + ctx.SetValue(schemas.BifrostContextKeyCustomProviderMetadata, nil) +} + +func (bifrost *Bifrost) setProviderContextMetadataForKey(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider) error { + config, err := bifrost.account.GetConfigForProvider(providerKey) + if err != nil { + return fmt.Errorf("failed to get config for provider %s: %w", providerKey, err) + } + if config == nil { + return fmt.Errorf("config is nil for provider %s", providerKey) + } + setProviderContextMetadata(ctx, config) + return nil +} + +func logResponsesToChatCompletionRetry(logger schemas.Logger, providerKey schemas.ModelProvider, model string, bifrostErr *schemas.BifrostError, warnings []string, isStreaming bool) { + if logger == nil { + return + } + + if isStreaming { + logger.Warn("custom provider %s does not appear to support the OpenAI responses API cleanly for streaming; retrying via chat completions fallback: %s", providerKey, schemas.ResponsesToChatCompletionFallbackErrorMessage(bifrostErr)) + } else { + logger.Warn("custom provider %s does not appear to support the OpenAI 
responses API cleanly; retrying via chat completions fallback: %s", providerKey, schemas.ResponsesToChatCompletionFallbackErrorMessage(bifrostErr)) + } + + if len(warnings) > 0 { + logger.Warn("responses->chat completion fallback for provider %s model %s is compatibility-only: %s", providerKey, model, strings.Join(warnings, "; ")) + } +} + +func runPostLLMHooksWithResponsesCompatMetadata(ctx *schemas.BifrostContext, pipeline *PluginPipeline, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError, runFrom int) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, processedErr := pipeline.RunPostLLMHooks(ctx, result, bifrostErr, runFrom) + schemas.ApplyResponsesToChatCompletionFallbackMetadata(ctx, resp, processedErr) + return resp, processedErr +} + +// tryResponsesToChatCompletionFallback checks whether a failed Responses API call is eligible +// for an operation-level fallback to Chat Completions (runtime auto-detection path, +// SupportsResponsesAPI == nil). This is NOT a second retry system — it reuses the same +// executeRequestWithRetries engine with the transformed request. +// +// Streaming semantics: fallback only triggers when executeRequestWithRetries returned an error, +// meaning no partial stream was ever emitted to the client. The fallback starts a fresh +// stream from scratch — no chunk duplication or data loss. +// +// Key reuse: the same keyProvider closure from the original call is passed through, so key +// selection and rotation follow the exact same rules. For keyless providers (keyProvider == nil), +// executeRequestWithRetries handles this natively with a zero-value key. 
+func (bifrost *Bifrost) tryResponsesToChatCompletionFallback( + ctx *schemas.BifrostContext, + provider schemas.Provider, + bifrostErr *schemas.BifrostError, +) bool { + if bifrostErr == nil { + return false + } + state, ok := schemas.GetResponsesToChatCompletionCompatState(ctx) + if !ok || state == nil || state.FallbackRequest == nil || !state.ShouldRetry(bifrostErr) { + return false + } + + logResponsesToChatCompletionRetry(bifrost.logger, provider.GetProviderKey(), state.OriginalModel, bifrostErr, state.Warnings, state.IsStreaming) + + _, activated := schemas.ActivateResponsesToChatCompletionCompatState(ctx, schemas.ResponsesToChatCompletionFallbackReasonRuntimeUnsupported) + return activated +} + // WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks. // Call PostHookRunner for each streaming chunk, setting StreamEndIndicator on the final chunk. // Call Cleanup when done to release the pipeline back to the pool. @@ -4021,6 +4100,19 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche ctx = bifrost.ctx } + provider, model, _ := req.GetRequestFields() + if err := bifrost.setProviderContextMetadataForKey(ctx, provider); err != nil { + bifrostErr := newBifrostError(err) + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + OriginalModelRequested: model, + } + return nil, bifrostErr + } + schemas.ClearResponsesToChatCompletionFallback(ctx) + schemas.ClearResponsesToChatCompletionCompatState(ctx) + if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) } @@ -4073,7 +4165,7 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche if shortCircuit != nil { if shortCircuit.Error != nil { shortCircuit.Error.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) - _, bifrostErr := pipeline.RunPostLLMHooks(ctx, 
nil, shortCircuit.Error, preCount) + _, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, nil, shortCircuit.Error, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) } @@ -4089,7 +4181,7 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche } if shortCircuit.Response != nil { shortCircuit.Response.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, shortCircuit.Response, nil, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, reqProvider, reqModel, reqModel) } else if resp != nil { @@ -4120,7 +4212,7 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche if err != nil { err.PopulateExtraFields(req.RequestType, wsProvider, wsModel, wsModel) } - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, result, err, preCount) if IsFinalChunk(ctx) { drainAndAttachPluginLogs(ctx) if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { @@ -4620,6 +4712,18 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif return nil, bifrostErr } + if err := bifrost.setProviderContextMetadataForKey(ctx, provider); err != nil { + bifrostErr := newBifrostError(err) + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + OriginalModelRequested: model, + } + return nil, bifrostErr + } + schemas.ClearResponsesToChatCompletionFallback(ctx) + schemas.ClearResponsesToChatCompletionCompatState(ctx) + // Add MCP tools to request if MCP is configured and requested if bifrost.MCPManager != nil { req = 
bifrost.MCPManager.AddToolsToRequest(ctx, req) @@ -4648,8 +4752,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { - shortCircuit.Response.PopulateExtraFields(req.RequestType, provider, model, model) - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, shortCircuit.Response, nil, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { @@ -4663,8 +4766,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif } // Handle short-circuit with error if shortCircuit.Error != nil { - shortCircuit.Error.PopulateExtraFields(req.RequestType, provider, model, model) - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, nil, shortCircuit.Error, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { @@ -4766,7 +4868,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif pluginCount := len(*bifrost.llmPlugins.Load()) select { case result = <-msg.Response: - resp, bifrostErr := pipeline.RunPostLLMHooks(msg.Context, result, nil, pluginCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(msg.Context, pipeline, result, nil, pluginCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { @@ -4795,7 +4897,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif return resp, nil case bifrostErrVal := <-msg.Err: bifrostErrPtr := &bifrostErrVal - resp, bifrostErrPtr = 
pipeline.RunPostLLMHooks(msg.Context, nil, bifrostErrPtr, pluginCount) + resp, bifrostErrPtr = runPostLLMHooksWithResponsesCompatMetadata(msg.Context, pipeline, nil, bifrostErrPtr, pluginCount) if bifrostErrPtr != nil { bifrostErrPtr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { @@ -4855,6 +4957,18 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem return nil, bifrostErr } + if err := bifrost.setProviderContextMetadataForKey(ctx, provider); err != nil { + bifrostErr := newBifrostError(err) + bifrostErr.ExtraFields = schemas.BifrostErrorExtraFields{ + RequestType: req.RequestType, + Provider: provider, + OriginalModelRequested: model, + } + return nil, bifrostErr + } + schemas.ClearResponsesToChatCompletionFallback(ctx) + schemas.ClearResponsesToChatCompletionCompatState(ctx) + // Add MCP tools to request if MCP is configured and requested if req.RequestType != schemas.SpeechStreamRequest && req.RequestType != schemas.TranscriptionStreamRequest && bifrost.MCPManager != nil { req = bifrost.MCPManager.AddToolsToRequest(ctx, req) @@ -4899,8 +5013,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem if shortCircuit != nil { // Handle short-circuit with response (success case) if shortCircuit.Response != nil { - shortCircuit.Response.PopulateExtraFields(req.RequestType, provider, model, model) - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, shortCircuit.Response, nil, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { @@ -4925,7 +5038,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem if err != nil { err.PopulateExtraFields(req.RequestType, provider, model, model) } - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, 
preCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, result, err, preCount) if IsFinalChunk(ctx) { drainAndAttachPluginLogs(ctx) } @@ -4974,20 +5087,19 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Run post hooks on the stream message processedResponse, processedError := pipelinePostHookRunner(ctx, bifrostResponse, streamMsg.BifrostError) - // Build the client-facing chunk via the shared helper, which strips raw - // request/response fields when in logging-only mode without mutating the - // shared processedResponse or processedError objects. streamResponse := providerUtils.BuildClientStreamChunk(ctx, processedResponse, processedError) - // Guarded send: if the consumer abandons outputStream (client - // disconnect, ctx cancel), drain the upstream shortCircuit.Stream - // so its producer can exit cleanly instead of blocking on its send. - select { - case outputStream <- streamResponse: - case <-ctx.Done(): - for range shortCircuit.Stream { + if streamResponse != nil { + // Guarded send: if the consumer abandons outputStream (client + // disconnect, ctx cancel), drain the upstream shortCircuit.Stream + // so its producer can exit cleanly instead of blocking on its send. 
+ select { + case outputStream <- streamResponse: + case <-ctx.Done(): + for range shortCircuit.Stream { + } + return } - return } // TODO: Release the processed response immediately after use @@ -4998,8 +5110,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem } // Handle short-circuit with error if shortCircuit.Error != nil { - shortCircuit.Error.PopulateExtraFields(req.RequestType, provider, model, model) - resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + resp, bifrostErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, nil, shortCircuit.Error, preCount) if bifrostErr != nil { bifrostErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if resp != nil { @@ -5108,7 +5219,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem // Marking final chunk ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) // On error we will complete post-hooks - recoveredResp, recoveredErr := pipeline.RunPostLLMHooks(ctx, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load())) + recoveredResp, recoveredErr := runPostLLMHooksWithResponsesCompatMetadata(ctx, pipeline, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load())) if recoveredErr != nil { recoveredErr.PopulateExtraFields(req.RequestType, provider, model, model) } else if recoveredResp != nil { @@ -5496,7 +5607,7 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas if cfg := config.CustomProviderConfig; cfg != nil && cfg.BaseProviderType != "" { baseProvider = cfg.BaseProviderType } - req.Context.SetValue(schemas.BifrostContextKeyIsCustomProvider, !IsStandardProvider(baseProvider)) + setProviderContextMetadata(req.Context, config) // Determine whether this provider attempt should capture raw payloads. 
// @@ -5766,6 +5877,90 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas } } + // Responses → Chat Completions operation-level fallback (runtime auto-detection). + // Only triggers when: + // 1. The Responses API returned a retryable error (HTTP 404/405/410/501) + // 2. The compat plugin marked the request as retry-eligible via context state + // Reuses the same keyProvider and executeRequestWithRetries engine — no separate + // retry system. The fallback creates its own per-attempt pipeline/postHookRunner + // inside the closure, matching the primary path's per-attempt allocation model. + // Streaming fallback only triggers if no stream was emitted (bifrostError != nil + // guarantees the original stream failed before any chunks were sent to the caller). + if bifrostError != nil { + if state, ok := schemas.GetResponsesToChatCompletionCompatState(req.Context); ok && state != nil && state.FallbackRequest != nil { + if bifrost.tryResponsesToChatCompletionFallback(req.Context, provider, bifrostError) { + fallbackMsg := &ChannelMessage{ + BifrostRequest: *state.FallbackRequest, + Context: req.Context, + } + fallbackModel := state.OriginalModel + if IsStreamRequestType(req.RequestType) { + stream, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { + // Resolve aliases for the key selected on this attempt + attemptResolvedModel := k.Aliases.Resolve(fallbackModel) + fallbackMsg.SetModel(attemptResolvedModel) + resolvedModel = attemptResolvedModel + + // Per-attempt pipeline/postHookRunner — mirrors the primary streaming path. + // ExtraFields use state.OriginalRequestType (the caller-facing type, e.g. + // ResponsesStreamRequest) rather than the internal FallbackRequest type + // (ChatCompletionStreamRequest), so observability reflects the original intent. 
+ pipeline := bifrost.getPluginPipeline() + postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + if result != nil { + result.PopulateExtraFields(state.OriginalRequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + } + if err != nil { + err.PopulateExtraFields(state.OriginalRequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + } + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load())) + // Tag each streaming chunk with fallback metadata (the compat plugin's + // PostLLMHook only handles non-streaming response conversion). + schemas.ApplyResponsesToChatCompletionFallbackMetadata(ctx, resp, bifrostErr) + if IsFinalChunk(ctx) { + drainAndAttachPluginLogs(ctx) + } + if bifrostErr != nil { + bifrostErr.PopulateExtraFields(state.OriginalRequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + return nil, bifrostErr + } else if resp != nil { + resp.PopulateExtraFields(state.OriginalRequestType, provider.GetProviderKey(), originalModelRequested, attemptResolvedModel) + } + return resp, nil + } + var finalizerOnce sync.Once + postHookSpanFinalizer := func(ctx context.Context) { + finalizerOnce.Do(func() { + pipeline.FinalizeStreamingPostHookSpans(ctx) + bifrost.releasePluginPipeline(pipeline) + }) + } + lastAttemptFinalizer = postHookSpanFinalizer + streamCh, streamErr := bifrost.handleProviderStreamRequest(provider, fallbackMsg, k, postHookRunner, postHookSpanFinalizer) + if streamErr != nil && streamCh == nil { + finalizerOnce.Do(func() { + bifrost.releasePluginPipeline(pipeline) + }) + } + return streamCh, streamErr + }, keyProvider, state.FallbackRequest.RequestType, provider.GetProviderKey(), fallbackModel, state.FallbackRequest, bifrost.logger) + + // Clean up fallback stream's last attempt finalizer on error + if 
bifrostError != nil && lastAttemptFinalizer != nil { + lastAttemptFinalizer(req.Context) + } + } else { + result, bifrostError = executeRequestWithRetries(req.Context, config, func(k schemas.Key) (*schemas.BifrostResponse, *schemas.BifrostError) { + attemptResolvedModel := k.Aliases.Resolve(fallbackModel) + fallbackMsg.SetModel(attemptResolvedModel) + resolvedModel = attemptResolvedModel + return bifrost.handleProviderRequest(provider, config, fallbackMsg, k, keys) + }, keyProvider, state.FallbackRequest.RequestType, provider.GetProviderKey(), fallbackModel, state.FallbackRequest, bifrost.logger) + } + } + } + } + if bifrostError != nil { bifrostError.PopulateExtraFields(req.RequestType, provider.GetProviderKey(), originalModelRequested, resolvedModel) diff --git a/core/custom_provider_context_test.go b/core/custom_provider_context_test.go new file mode 100644 index 0000000000..8ce9d105b9 --- /dev/null +++ b/core/custom_provider_context_test.go @@ -0,0 +1,55 @@ +package bifrost + +import ( + "context" + "testing" + + schemas "github.com/maximhq/bifrost/core/schemas" +) + +func TestSetProviderContextMetadata_MarksCustomProvider(t *testing.T) { + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + config := &schemas.ProviderConfig{ + CustomProviderConfig: &schemas.CustomProviderConfig{ + CustomProviderKey: "lmstudio", + BaseProviderType: schemas.OpenAI, + SupportsResponsesAPI: schemas.Ptr(false), + }, + } + + setProviderContextMetadata(ctx, config) + + isCustomProvider, ok := ctx.Value(schemas.BifrostContextKeyIsCustomProvider).(bool) + if !ok || !isCustomProvider { + t.Fatalf("expected custom provider flag to be true, got %v", ctx.Value(schemas.BifrostContextKeyIsCustomProvider)) + } + metadata, ok := schemas.GetCustomProviderContextMetadata(ctx) + if !ok || metadata == nil { + t.Fatal("expected custom provider metadata to be stored in context") + } + if metadata.ProviderKey != "lmstudio" { + t.Fatalf("expected custom provider key lmstudio, 
got %s", metadata.ProviderKey) + } + if metadata.BaseProviderType != schemas.OpenAI { + t.Fatalf("expected base provider type openai, got %s", metadata.BaseProviderType) + } + if metadata.SupportsResponsesAPI == nil || *metadata.SupportsResponsesAPI { + t.Fatalf("expected supports_responses_api=false metadata, got %+v", metadata.SupportsResponsesAPI) + } +} + +func TestSetProviderContextMetadata_ClearsCustomProviderValues(t *testing.T) { + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyIsCustomProvider, true) + ctx.SetValue(schemas.BifrostContextKeyCustomProviderMetadata, &schemas.CustomProviderContextMetadata{ProviderKey: "lmstudio"}) + + setProviderContextMetadata(ctx, &schemas.ProviderConfig{}) + + isCustomProvider, ok := ctx.Value(schemas.BifrostContextKeyIsCustomProvider).(bool) + if !ok || isCustomProvider { + t.Fatalf("expected custom provider flag to be false, got %v", ctx.Value(schemas.BifrostContextKeyIsCustomProvider)) + } + if metadata := ctx.Value(schemas.BifrostContextKeyCustomProviderMetadata); metadata != nil { + t.Fatalf("expected custom provider metadata to be cleared, got %+v", metadata) + } +} diff --git a/core/providers/openai/openai.go b/core/providers/openai/openai.go index 2197de92e2..e76eecad5e 100644 --- a/core/providers/openai/openai.go +++ b/core/providers/openai/openai.go @@ -93,6 +93,8 @@ func (provider *OpenAIProvider) buildRequestURL(ctx *schemas.BifrostContext, def return provider.networkConfig.BaseURL + path } + + func (provider *OpenAIProvider) ListModels(ctx *schemas.BifrostContext, keys []schemas.Key, request *schemas.BifrostListModelsRequest) (*schemas.BifrostListModelsResponse, *schemas.BifrostError) { if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ListModelsRequest); err != nil { return nil, err @@ -730,8 +732,10 @@ func HandleOpenAITextCompletionStreaming( // Returns a BifrostResponse containing the 
completion results or an error if the request fails. func (provider *OpenAIProvider) ChatCompletion(ctx *schemas.BifrostContext, key schemas.Key, request *schemas.BifrostChatRequest) (*schemas.BifrostChatResponse, *schemas.BifrostError) { // Check if chat completion is allowed for this provider - if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { - return nil, err + if !schemas.ShouldSkipOperationCheck(ctx) { + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionRequest); err != nil { + return nil, err + } } if provider.disableStore { @@ -895,8 +899,10 @@ func HandleOpenAIChatCompletionRequest( // Returns a channel for streaming responses and any error that occurred. func (provider *OpenAIProvider) ChatCompletionStream(ctx *schemas.BifrostContext, postHookRunner schemas.PostHookRunner, postHookSpanFinalizer func(context.Context), key schemas.Key, request *schemas.BifrostChatRequest) (chan *schemas.BifrostStreamChunk, *schemas.BifrostError) { // Check if chat completion stream is allowed for this provider - if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { - return nil, err + if !schemas.ShouldSkipOperationCheck(ctx) { + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ChatCompletionStreamRequest); err != nil { + return nil, err + } } var authHeader map[string]string if key.Value.GetValue() != "" { @@ -1352,7 +1358,7 @@ func (provider *OpenAIProvider) Responses(ctx *schemas.BifrostContext, key schem request.Params.Store = schemas.Ptr(false) } - return HandleOpenAIResponsesRequest( + response, err := HandleOpenAIResponsesRequest( ctx, provider.client, provider.buildRequestURL(ctx, "/v1/responses", schemas.ResponsesRequest), @@ -1366,6 +1372,8 @@ func (provider *OpenAIProvider) 
Responses(ctx *schemas.BifrostContext, key schem nil, provider.logger, ) + + return response, err } // HandleOpenAIResponsesRequest handles a responses request to OpenAI's API. @@ -1507,6 +1515,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.ResponsesStreamRequest); err != nil { return nil, err } + var authHeader map[string]string if key.Value.GetValue() != "" { authHeader = map[string]string{"Authorization": "Bearer " + key.Value.GetValue()} @@ -1519,7 +1528,7 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos } // Use shared streaming logic - return HandleOpenAIResponsesStreaming( + streamChan, err := HandleOpenAIResponsesStreaming( ctx, provider.streamingClient, provider.buildRequestURL(ctx, "/v1/responses", schemas.ResponsesStreamRequest), @@ -1537,6 +1546,8 @@ func (provider *OpenAIProvider) ResponsesStream(ctx *schemas.BifrostContext, pos provider.logger, postHookSpanFinalizer, ) + + return streamChan, err } // HandleOpenAIResponsesStreaming handles streaming for OpenAI-compatible APIs. 
diff --git a/core/providers/openai/responses_fallback_test.go b/core/providers/openai/responses_fallback_test.go new file mode 100644 index 0000000000..9a56a83e64 --- /dev/null +++ b/core/providers/openai/responses_fallback_test.go @@ -0,0 +1,383 @@ +package openai + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/bytedance/sonic" + + schemas "github.com/maximhq/bifrost/core/schemas" +) + +type testOpenAILogger struct{} + +func (testOpenAILogger) Debug(string, ...any) {} +func (testOpenAILogger) Info(string, ...any) {} +func (testOpenAILogger) Warn(string, ...any) {} +func (testOpenAILogger) Error(string, ...any) {} +func (testOpenAILogger) Fatal(string, ...any) {} +func (testOpenAILogger) SetLevel(schemas.LogLevel) {} +func (testOpenAILogger) SetOutputType(schemas.LoggerOutputType) {} +func (testOpenAILogger) LogHTTPRequest(schemas.LogLevel, string) schemas.LogEventBuilder { + return schemas.NoopLogEvent +} + +func testOpenAIResponsesCtx() *schemas.BifrostContext { + return schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) +} + +func testOpenAIResponsesRequest() *schemas.BifrostResponsesRequest { + content := "hello" + return &schemas.BifrostResponsesRequest{ + Model: "test-model", + Input: []schemas.ResponsesMessage{{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &content, + }, + }}, + Params: &schemas.ResponsesParameters{ + MaxOutputTokens: schemas.Ptr(7), + }, + } +} + +func noopOpenAIPostHookRunner(_ *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + return result, err +} + +func newTestOpenAIProvider(baseURL string, customConfig *schemas.CustomProviderConfig) *OpenAIProvider { + return NewOpenAIProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: baseURL, + 
DefaultRequestTimeoutInSeconds: 5, + }, + CustomProviderConfig: customConfig, + }, testOpenAILogger{}) +} + +func testChatCompletionBody(text string) []byte { + finishReason := string(schemas.BifrostFinishReasonStop) + response := &schemas.BifrostChatResponse{ + ID: "chatcmpl-test", + Object: "chat.completion", + Created: 1, + Model: "test-model", + Choices: []schemas.BifrostResponseChoice{{ + Index: 0, + FinishReason: &finishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: &text, + }, + }, + }, + }}, + Usage: &schemas.BifrostLLMUsage{PromptTokens: 1, CompletionTokens: 1, TotalTokens: 2}, + } + body, _ := sonic.Marshal(response) + return body +} + +func testResponsesBody(text string) []byte { + messageType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + status := "completed" + textType := schemas.ResponsesOutputMessageContentTypeText + response := &schemas.BifrostResponsesResponse{ + ID: schemas.Ptr("resp-test"), + Object: "response", + CreatedAt: 1, + Model: "test-model", + Status: &status, + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ + ContentBlocks: []schemas.ResponsesMessageContentBlock{{ + Type: textType, + Text: &text, + }}, + }, + }}, + } + body, _ := sonic.Marshal(response) + return body +} + +func TestResponses_CustomProviderConfiguredUnsupported_DoesNotFallbackInsideProvider(t *testing.T) { + t.Parallel() + + var chatHits atomic.Int32 + var responsesHits atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/chat/completions": + chatHits.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(testChatCompletionBody("fallback response")) + case "/v1/responses": + 
responsesHits.Add(1) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":{"message":"unexpected responses endpoint"}}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := newTestOpenAIProvider(server.URL, &schemas.CustomProviderConfig{ + CustomProviderKey: "lmstudio", + BaseProviderType: schemas.OpenAI, + IsKeyLess: true, + SupportsResponsesAPI: schemas.Ptr(false), + }) + + ctx := testOpenAIResponsesCtx() + response, bifrostErr := provider.Responses(ctx, schemas.Key{}, testOpenAIResponsesRequest()) + if bifrostErr == nil { + t.Fatal("expected provider-level responses call to fail without compat fallback") + } + if response != nil { + t.Fatalf("expected nil response on provider-level failure, got %+v", response) + } + if chatHits.Load() != 0 { + t.Fatalf("expected zero chat completion requests, got %d", chatHits.Load()) + } + if responsesHits.Load() != 1 { + t.Fatalf("expected one responses endpoint request, got %d", responsesHits.Load()) + } + if ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback) != nil { + t.Fatalf("expected no fallback marker on provider-only path, got %v", ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback)) + } +} + +func TestResponses_CustomProviderRuntimeUnsupported_DoesNotFallbackInsideProvider(t *testing.T) { + t.Parallel() + + var chatHits atomic.Int32 + var responsesHits atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/responses": + responsesHits.Add(1) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":{"message":"responses endpoint unsupported"}}`)) + case "/v1/chat/completions": + chatHits.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(testChatCompletionBody("auto fallback response")) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := 
newTestOpenAIProvider(server.URL, &schemas.CustomProviderConfig{ + CustomProviderKey: "lmstudio", + BaseProviderType: schemas.OpenAI, + IsKeyLess: true, + }) + + ctx := testOpenAIResponsesCtx() + response, bifrostErr := provider.Responses(ctx, schemas.Key{}, testOpenAIResponsesRequest()) + if bifrostErr == nil { + t.Fatal("expected provider-level responses call to fail without runtime compat retry") + } + if response != nil { + t.Fatalf("expected nil response on provider-level failure, got %+v", response) + } + if responsesHits.Load() != 1 { + t.Fatalf("expected one native responses attempt, got %d", responsesHits.Load()) + } + if chatHits.Load() != 0 { + t.Fatalf("expected zero chat completion fallback requests, got %d", chatHits.Load()) + } + if ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback) != nil { + t.Fatalf("expected no fallback marker on provider-only path, got %v", ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback)) + } +} + +func TestResponses_CustomProviderUsesNativeResponsesWhenEnabled(t *testing.T) { + t.Parallel() + + var chatHits atomic.Int32 + var responsesHits atomic.Int32 + var requestBody atomic.Value + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + + switch r.URL.Path { + case "/v1/responses": + responsesHits.Add(1) + requestBody.Store(string(bodyBytes)) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(testResponsesBody("native custom response")) + case "/v1/chat/completions": + chatHits.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"unexpected chat endpoint"}}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := newTestOpenAIProvider(server.URL, &schemas.CustomProviderConfig{ + CustomProviderKey: "lmstudio", + BaseProviderType: 
schemas.OpenAI, + IsKeyLess: true, + SupportsResponsesAPI: schemas.Ptr(true), + }) + + response, bifrostErr := provider.Responses(testOpenAIResponsesCtx(), schemas.Key{}, testOpenAIResponsesRequest()) + if bifrostErr != nil { + t.Fatalf("Responses returned error: %v", bifrostErr.Error) + } + if response == nil { + t.Fatal("Responses returned nil response") + } + if responsesHits.Load() != 1 { + t.Fatalf("expected one responses request, got %d", responsesHits.Load()) + } + if chatHits.Load() != 0 { + t.Fatalf("expected zero chat completion requests, got %d", chatHits.Load()) + } + + body, _ := requestBody.Load().(string) + if !strings.Contains(body, `"input"`) { + t.Fatalf("expected native responses payload to contain input, got %s", body) + } + if strings.Contains(body, `"messages"`) { + t.Fatalf("expected native responses payload to omit messages, got %s", body) + } + if !strings.Contains(body, `"max_output_tokens":`) { + t.Fatalf("expected native responses payload to contain max_output_tokens, got %s", body) + } + + // ExtraFields.Provider is populated by the core layer (PopulateExtraFields), + // not the provider itself, so we don't assert it in provider-level tests. 
+ if len(response.Output) == 0 || response.Output[0].Content == nil || len(response.Output[0].Content.ContentBlocks) == 0 { + t.Fatalf("expected native responses output, got %+v", response.Output) + } + if got := response.Output[0].Content.ContentBlocks[0].Text; got == nil || *got != "native custom response" { + t.Fatalf("expected native output text native custom response, got %+v", got) + } +} + +func TestResponsesStream_CustomProviderRuntimeUnsupported_DoesNotFallbackInsideProvider(t *testing.T) { + t.Parallel() + + var chatHits atomic.Int32 + var responsesHits atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/chat/completions": + chatHits.Add(1) + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("unused")) + case "/v1/responses": + responsesHits.Add(1) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"error":{"message":"responses endpoint unsupported"}}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := newTestOpenAIProvider(server.URL, &schemas.CustomProviderConfig{ + CustomProviderKey: "lmstudio", + BaseProviderType: schemas.OpenAI, + IsKeyLess: true, + }) + + ctx := testOpenAIResponsesCtx() + streamChan, bifrostErr := provider.ResponsesStream(ctx, noopOpenAIPostHookRunner, func(context.Context) {}, schemas.Key{}, testOpenAIResponsesRequest()) + if bifrostErr == nil { + t.Fatal("expected provider-level responses stream call to fail without runtime compat retry") + } + if streamChan != nil { + t.Fatal("expected nil stream when provider-level stream setup fails") + } + if chatHits.Load() != 0 { + t.Fatalf("expected zero chat completion stream requests, got %d", chatHits.Load()) + } + if responsesHits.Load() != 1 { + t.Fatalf("expected one native responses stream attempt, got %d", responsesHits.Load()) + } + if ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback) 
!= nil { + t.Fatalf("expected no fallback marker on provider-only path, got %v", ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback)) + } +} + +func TestResponses_NativeOpenAIStillUsesResponsesEndpoint(t *testing.T) { + t.Parallel() + + var chatHits atomic.Int32 + var responsesHits atomic.Int32 + var requestBody atomic.Value + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + + switch r.URL.Path { + case "/v1/responses": + responsesHits.Add(1) + requestBody.Store(string(bodyBytes)) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(testResponsesBody("native response")) + case "/v1/chat/completions": + chatHits.Add(1) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":{"message":"unexpected chat endpoint"}}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + provider := newTestOpenAIProvider(server.URL, nil) + + response, bifrostErr := provider.Responses(testOpenAIResponsesCtx(), schemas.Key{}, testOpenAIResponsesRequest()) + if bifrostErr != nil { + t.Fatalf("Responses returned error: %v", bifrostErr.Error) + } + if response == nil { + t.Fatal("Responses returned nil response") + } + if responsesHits.Load() != 1 { + t.Fatalf("expected one responses request, got %d", responsesHits.Load()) + } + if chatHits.Load() != 0 { + t.Fatalf("expected zero chat completion requests, got %d", chatHits.Load()) + } + + body, _ := requestBody.Load().(string) + if !strings.Contains(body, `"input"`) { + t.Fatalf("expected native responses payload to contain input, got %s", body) + } + if strings.Contains(body, `"messages"`) { + t.Fatalf("expected native responses payload to omit messages, got %s", body) + } + if got := response.Output[0].Content.ContentBlocks[0].Text; got == nil || *got != "native response" { + t.Fatalf("expected 
native output text native response, got %+v", got) + } +} diff --git a/core/providers/utils/utils.go b/core/providers/utils/utils.go index fc66309189..b91debf900 100644 --- a/core/providers/utils/utils.go +++ b/core/providers/utils/utils.go @@ -1909,11 +1909,12 @@ func ProcessAndSendResponse( } streamResponse := BuildClientStreamChunk(ctx, processedResponse, processedError) - - select { - case responseChan <- streamResponse: - case <-ctx.Done(): - return + if streamResponse != nil { + select { + case responseChan <- streamResponse: + case <-ctx.Done(): + return + } } // Check if this is the final chunk and complete deferred span with post-processed data diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index e43ba0ccbf..4e07b03d1b 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -2,6 +2,7 @@ package schemas import ( + "context" "encoding/json" "errors" "fmt" @@ -203,7 +204,11 @@ const ( BifrostContextKeySendBackRawRequest BifrostContextKey = "bifrost-send-back-raw-request" // bool (per-request override — read by bifrost.go, never overwritten) BifrostContextKeySendBackRawResponse BifrostContextKey = "bifrost-send-back-raw-response" // bool (per-request override — read by bifrost.go, never overwritten) BifrostContextKeyIntegrationType BifrostContextKey = "bifrost-integration-type" // integration used in gateway (e.g. openai, anthropic, bedrock, etc.) 
- BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeySkipOperationCheck BifrostContextKey = "bifrost-skip-operation-check" // bool (set by compat plugin / core fallback — skip AllowedRequests gate for internal request-type conversions) + BifrostContextKeyIsResponsesToChatCompletionFallback BifrostContextKey = "bifrost-is-responses-to-chat-completion-fallback" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyResponsesToChatCompletionFallbackReason BifrostContextKey = "bifrost-responses-to-chat-completion-fallback-reason" // ResponsesToChatCompletionFallbackReason (set by bifrost - DO NOT SET THIS MANUALLY)) + BifrostContextKeyResponsesToChatCompletionCompatState BifrostContextKey = "bifrost-responses-to-chat-completion-compat-state" // *ResponsesToChatCompletionCompatState (set by bifrost/plugins - DO NOT SET THIS MANUALLY) + BifrostContextKeyCustomProviderMetadata BifrostContextKey = "bifrost-custom-provider-metadata" // *CustomProviderContextMetadata (set by bifrost - DO NOT SET THIS MANUALLY) BifrostMCPAgentOriginalRequestID BifrostContextKey = "bifrost-mcp-agent-original-request-id" // string (to store the original request ID for MCP agent mode) BifrostContextKeyParentMCPRequestID BifrostContextKey = "bf-parent-mcp-request-id" // string (parent request ID for nested tool calls from executeCode) BifrostContextKeyStructuredOutputToolName BifrostContextKey = "bifrost-structured-output-tool-name" // string (to store the name of the structured output tool (set by bifrost)) @@ -770,6 +775,174 @@ type BifrostResponse struct { PassthroughResponse *BifrostPassthroughResponse } +type ResponsesToChatCompletionFallbackReason string + +const ( + ResponsesToChatCompletionFallbackReasonConfiguredUnsupported ResponsesToChatCompletionFallbackReason = "configured_unsupported" + 
ResponsesToChatCompletionFallbackReasonRuntimeUnsupported ResponsesToChatCompletionFallbackReason = "runtime_unsupported" +) + +type CustomProviderContextMetadata struct { + ProviderKey ModelProvider + BaseProviderType ModelProvider + SupportsResponsesAPI *bool +} + +type ResponsesToChatCompletionRetryPolicy struct { + UnsupportedStatusCodes []int + UnsupportedErrorSubstrings []string +} + +type ResponsesToChatCompletionCompatState struct { + OriginalRequestType RequestType + OriginalModel string + IsStreaming bool + RetryEligible bool + RetryPolicy *ResponsesToChatCompletionRetryPolicy + Active bool + FallbackReason ResponsesToChatCompletionFallbackReason + FallbackRequest *BifrostRequest + Warnings []string +} + +func GetCustomProviderContextMetadata(ctx context.Context) (*CustomProviderContextMetadata, bool) { + if ctx == nil { + return nil, false + } + + metadata, ok := ctx.Value(BifrostContextKeyCustomProviderMetadata).(*CustomProviderContextMetadata) + if !ok || metadata == nil { + return nil, false + } + + return metadata, true +} + +func SetResponsesToChatCompletionCompatState(ctx *BifrostContext, state *ResponsesToChatCompletionCompatState) { + if ctx == nil { + return + } + + if state == nil { + ctx.ClearValue(BifrostContextKeyResponsesToChatCompletionCompatState) + return + } + + ctx.SetValue(BifrostContextKeyResponsesToChatCompletionCompatState, state) +} + +func ClearResponsesToChatCompletionCompatState(ctx *BifrostContext) { + if ctx == nil { + return + } + + ctx.ClearValue(BifrostContextKeyResponsesToChatCompletionCompatState) +} + +func GetResponsesToChatCompletionCompatState(ctx context.Context) (*ResponsesToChatCompletionCompatState, bool) { + if ctx == nil { + return nil, false + } + + state, ok := ctx.Value(BifrostContextKeyResponsesToChatCompletionCompatState).(*ResponsesToChatCompletionCompatState) + if !ok || state == nil { + return nil, false + } + + return state, true +} + +// ShouldSkipOperationCheck returns true when the compat plugin or 
core fallback +// has set the skip-operation-check flag on the context. Providers use this to +// bypass the AllowedRequests gate for internally converted requests (e.g. +// Responses→Chat fallback). The provider has no knowledge of WHY the check is +// skipped — that decision lives in the compat plugin and core orchestration. +func ShouldSkipOperationCheck(ctx *BifrostContext) bool { + if ctx == nil { + return false + } + skip, ok := ctx.Value(BifrostContextKeySkipOperationCheck).(bool) + return ok && skip +} + +func ActivateResponsesToChatCompletionCompatState(ctx *BifrostContext, reason ResponsesToChatCompletionFallbackReason) (*ResponsesToChatCompletionCompatState, bool) { + state, ok := GetResponsesToChatCompletionCompatState(ctx) + if !ok || state == nil { + return nil, false + } + + state.Active = true + state.RetryEligible = false + state.FallbackReason = reason + SetResponsesToChatCompletionFallback(ctx, reason) + // Allow the converted Chat request to bypass the provider's AllowedRequests gate. + // The original Responses request was already authorised; this is an internal conversion. 
+ ctx.SetValue(BifrostContextKeySkipOperationCheck, true) + + return state, true +} + +func SetResponsesToChatCompletionFallback(ctx *BifrostContext, reason ResponsesToChatCompletionFallbackReason) { + if ctx == nil { + return + } + + ctx.SetValue(BifrostContextKeyIsResponsesToChatCompletionFallback, true) + if reason == "" { + ctx.ClearValue(BifrostContextKeyResponsesToChatCompletionFallbackReason) + return + } + ctx.SetValue(BifrostContextKeyResponsesToChatCompletionFallbackReason, reason) +} + +func ClearResponsesToChatCompletionFallback(ctx *BifrostContext) { + if ctx == nil { + return + } + + ctx.ClearValue(BifrostContextKeyIsResponsesToChatCompletionFallback) + ctx.ClearValue(BifrostContextKeyResponsesToChatCompletionFallbackReason) + ctx.ClearValue(BifrostContextKeySkipOperationCheck) +} + +func GetResponsesToChatCompletionFallback(ctx context.Context) (ResponsesToChatCompletionFallbackReason, bool) { + if ctx == nil { + return "", false + } + + isFallback, ok := ctx.Value(BifrostContextKeyIsResponsesToChatCompletionFallback).(bool) + if !ok || !isFallback { + return "", false + } + + switch reason := ctx.Value(BifrostContextKeyResponsesToChatCompletionFallbackReason).(type) { + case ResponsesToChatCompletionFallbackReason: + return reason, true + case string: + return ResponsesToChatCompletionFallbackReason(reason), true + default: + return "", true + } +} + +func ApplyResponsesToChatCompletionFallbackMetadata(ctx context.Context, response *BifrostResponse, bifrostErr *BifrostError) { + reason, ok := GetResponsesToChatCompletionFallback(ctx) + if !ok { + return + } + + if response != nil { + extraFields := response.GetExtraFields() + extraFields.ResponsesToChatCompletionFallback = true + extraFields.ResponsesToChatCompletionFallbackReason = string(reason) + } + + if bifrostErr != nil { + bifrostErr.ExtraFields.ResponsesToChatCompletionFallback = true + bifrostErr.ExtraFields.ResponsesToChatCompletionFallbackReason = string(reason) + } +} + func (r 
*BifrostResponse) GetExtraFields() *BifrostResponseExtraFields { switch { case r.ListModelsResponse != nil: @@ -1096,6 +1269,8 @@ type BifrostResponseExtraFields struct { ParseErrors []BatchError `json:"parse_errors,omitempty"` // errors encountered while parsing JSONL batch results ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` // params dropped by the compat plugin based on model catalog + ResponsesToChatCompletionFallback bool `json:"responses_to_chat_completion_fallback,omitempty"` + ResponsesToChatCompletionFallbackReason string `json:"responses_to_chat_completion_fallback_reason,omitempty"` ProviderResponseHeaders map[string]string `json:"provider_response_headers,omitempty"` // HTTP response headers from the provider (filtered to exclude transport-level headers) } @@ -1282,6 +1457,8 @@ type BifrostErrorExtraFields struct { RawResponse interface{} `json:"raw_response,omitempty"` ConvertedRequestType RequestType `json:"converted_request_type,omitempty"` DroppedCompatPluginParams []string `json:"dropped_compat_plugin_params,omitempty"` + ResponsesToChatCompletionFallback bool `json:"responses_to_chat_completion_fallback,omitempty"` + ResponsesToChatCompletionFallbackReason string `json:"responses_to_chat_completion_fallback_reason,omitempty"` KeyStatuses []KeyStatus `json:"key_statuses,omitempty"` MCPAuthRequired *MCPUserOAuthRequiredError `json:"mcp_auth_required,omitempty"` // Set when a per-user OAuth MCP tool requires authentication } diff --git a/core/schemas/mux.go b/core/schemas/mux.go index 24943d3fbd..0d22f2786f 100644 --- a/core/schemas/mux.go +++ b/core/schemas/mux.go @@ -2,6 +2,7 @@ package schemas import ( "fmt" + "net/http" "sort" "strings" "sync" @@ -1025,6 +1026,7 @@ func (brr *BifrostResponsesRequest) ToChatRequest() *BifrostChatRequest { Temperature: brr.Params.Temperature, TopLogProbs: brr.Params.TopLogProbs, TopP: brr.Params.TopP, + 
// ToChatFallbackRequest is the compatibility seam for Responses -> Chat fallback.
// Keep fallback-specific request shaping behind this method so schema evolution has a
// single place to update. This bridge is best-effort rather than fully lossless.
//
// It currently delegates to ToChatRequest without any additional shaping.
func (brr *BifrostResponsesRequest) ToChatFallbackRequest() *BifrostChatRequest {
	return brr.ToChatRequest()
}
+func (brr *BifrostResponsesRequest) ChatFallbackWarnings() []string { + if brr == nil { + return nil + } + + warnings := make([]string, 0, 5) + + for _, input := range brr.Input { + if input.Type != nil && *input.Type == ResponsesMessageTypeReasoning { + warnings = append(warnings, "reasoning input items are dropped during Responses to Chat fallback") + break + } + } + + for _, input := range brr.Input { + if input.Role != nil && *input.Role == ResponsesInputMessageRoleDeveloper { + warnings = append(warnings, "developer role is normalized to system during Responses to Chat fallback") + break + } + } + + if brr.Params == nil { + return warnings + } + + for _, tool := range brr.Params.Tools { + if tool.Type != ResponsesToolTypeFunction { + warnings = append(warnings, "non-function tools are dropped during Responses to Chat fallback") + break + } + } + + if brr.Params.ToolChoice != nil && brr.Params.ToolChoice.ResponsesToolChoiceStruct != nil { + switch brr.Params.ToolChoice.ResponsesToolChoiceStruct.Type { + case ResponsesToolChoiceTypeAllowedTools, ResponsesToolChoiceTypeCustom: + warnings = append(warnings, fmt.Sprintf("tool_choice type %q is not preserved during Responses to Chat fallback", brr.Params.ToolChoice.ResponsesToolChoiceStruct.Type)) + } + } + + if brr.Params.Truncation != nil { + warnings = append(warnings, "responses truncation is not forwarded during Responses to Chat fallback") + } + + return warnings +} + +// DefaultResponsesToChatCompletionRetryPolicy returns the retry policy for runtime +// auto-detection of unsupported Responses endpoints. 
Only triggers on clear signals +// that the endpoint does not exist: +// - HTTP 404/405/410/501 — unambiguous "not found" / "not implemented" status codes +// - HTML response body — provider returned a web page instead of JSON (common for +// servers that don't recognise the /v1/responses path) +// +// Intentionally excludes unmarshal failures and empty responses to avoid masking +// transient errors or parsing bugs as unsupported-endpoint signals. +func DefaultResponsesToChatCompletionRetryPolicy() *ResponsesToChatCompletionRetryPolicy { + return &ResponsesToChatCompletionRetryPolicy{ + UnsupportedStatusCodes: []int{ + http.StatusNotFound, + http.StatusMethodNotAllowed, + http.StatusGone, + http.StatusNotImplemented, + }, + UnsupportedErrorSubstrings: []string{ + strings.ToLower(ErrProviderResponseHTML), + }, + } +} + +func (policy *ResponsesToChatCompletionRetryPolicy) ShouldRetry(err *BifrostError) bool { + if policy == nil || err == nil { + return false + } + + if err.StatusCode != nil { + for _, statusCode := range policy.UnsupportedStatusCodes { + if *err.StatusCode == statusCode { + return true + } + } + } + + if err.Error == nil { + return false + } + + message := strings.ToLower(strings.TrimSpace(err.Error.Message)) + for _, unsupportedMarker := range policy.UnsupportedErrorSubstrings { + if strings.Contains(message, unsupportedMarker) { + return true + } + } + + return false +} + +func (state *ResponsesToChatCompletionCompatState) ShouldRetry(err *BifrostError) bool { + if state == nil || !state.RetryEligible || state.RetryPolicy == nil { + return false + } + + return state.RetryPolicy.ShouldRetry(err) +} + +func ResponsesToChatCompletionFallbackErrorMessage(err *BifrostError) string { + if err == nil { + return "unknown responses API error" + } + + if err.Error == nil { + return "unknown responses API error" + } + + if err.Error.Message == "" { + return "unknown responses API error" + } + + return err.Error.Message +} + func 
sanitizeResponsesToolsForChatFallback(tools []ResponsesTool) []ChatTool { if len(tools) == 0 { return nil diff --git a/core/schemas/mux_test.go b/core/schemas/mux_test.go index a8e8edd96e..657f900d44 100644 --- a/core/schemas/mux_test.go +++ b/core/schemas/mux_test.go @@ -1,6 +1,10 @@ package schemas -import "testing" +import ( + "net/http" + "strings" + "testing" +) func TestToChatMessages_PreservesDeveloperRole(t *testing.T) { messages := []ResponsesMessage{ @@ -254,6 +258,109 @@ func TestToChatRequest_PreservesStringToolChoiceAutoAndNone(t *testing.T) { } } +func TestToChatFallbackRequest_PreservesUserAndStructuredOutputFormat(t *testing.T) { + user := "user-123" + verbosity := "high" + formatType := "json_schema" + schemaType := "object" + req := &BifrostResponsesRequest{ + Params: &ResponsesParameters{ + User: &user, + Text: &ResponsesTextConfig{ + Verbosity: &verbosity, + Format: &ResponsesTextConfigFormat{ + Type: formatType, + JSONSchema: &ResponsesTextConfigFormatJSONSchema{ + Type: &schemaType, + }, + }, + }, + }, + } + + chatReq := req.ToChatFallbackRequest() + if chatReq == nil || chatReq.Params == nil { + t.Fatal("expected non-nil chat fallback request params") + } + if chatReq.Params.User == nil || *chatReq.Params.User != user { + t.Fatalf("expected user %q to be preserved, got %+v", user, chatReq.Params.User) + } + if chatReq.Params.ResponseFormat == nil { + t.Fatal("expected structured output format to be forwarded") + } + if chatReq.Params.Verbosity == nil || *chatReq.Params.Verbosity != verbosity { + t.Fatalf("expected verbosity %q to be preserved, got %+v", verbosity, chatReq.Params.Verbosity) + } + responseFormat, ok := (*chatReq.Params.ResponseFormat).(map[string]interface{}) + if !ok { + t.Fatalf("expected response_format to be a map, got %T", *chatReq.Params.ResponseFormat) + } + if got, _ := responseFormat["type"].(string); got != formatType { + t.Fatalf("expected response_format type %q, got %q", formatType, got) + } +} + +func 
TestChatFallbackWarnings_ReportsDroppedCompatibilityFeatures(t *testing.T) { + truncation := "auto" + developer := ResponsesInputMessageRoleDeveloper + choiceType := ResponsesToolChoiceTypeAllowedTools + req := &BifrostResponsesRequest{ + Input: []ResponsesMessage{{Role: &developer}}, + Params: &ResponsesParameters{ + Truncation: &truncation, + Tools: []ResponsesTool{{Type: ResponsesToolTypeWebSearch, Name: Ptr("search")}}, + ToolChoice: &ResponsesToolChoice{ + ResponsesToolChoiceStruct: &ResponsesToolChoiceStruct{Type: choiceType}, + }, + }, + } + + warnings := req.ChatFallbackWarnings() + if len(warnings) < 3 { + t.Fatalf("expected multiple compatibility warnings, got %#v", warnings) + } + + joined := strings.Join(warnings, "\n") + if !containsWarning(joined, "developer role") { + t.Fatalf("expected developer-role warning, got %#v", warnings) + } + if !containsWarning(joined, "non-function tools") { + t.Fatalf("expected tool warning, got %#v", warnings) + } + if !containsWarning(joined, "tool_choice type") { + t.Fatalf("expected tool_choice warning, got %#v", warnings) + } + if !containsWarning(joined, "truncation") { + t.Fatalf("expected truncation warning, got %#v", warnings) + } +} + +func containsWarning(joined string, fragment string) bool { + return strings.Contains(joined, fragment) +} + +func TestDefaultResponsesToChatCompletionRetryPolicy_UsesUnsupportedEndpointSignals(t *testing.T) { + policy := DefaultResponsesToChatCompletionRetryPolicy() + if policy == nil { + t.Fatal("expected default retry policy") + } + + notFound := http.StatusNotFound + if !policy.ShouldRetry(&BifrostError{StatusCode: ¬Found}) { + t.Fatal("expected 404 to trigger runtime fallback retry") + } + + unsupportedResponseErr := &BifrostError{Error: &ErrorField{Message: ErrProviderResponseHTML}} + if !policy.ShouldRetry(unsupportedResponseErr) { + t.Fatal("expected provider HTML response marker to trigger runtime fallback retry") + } + + unauthorized := http.StatusUnauthorized + if 
policy.ShouldRetry(&BifrostError{StatusCode: &unauthorized}) { + t.Fatal("expected 401 to remain a hard failure") + } +} + func TestToBifrostResponsesStreamResponse_PopulatesFinalDoneTextAndCompletedOutput(t *testing.T) { state := AcquireChatToResponsesStreamState() defer ReleaseChatToResponsesStreamState(state) diff --git a/core/schemas/provider.go b/core/schemas/provider.go index e087747c07..ce12b324f8 100644 --- a/core/schemas/provider.go +++ b/core/schemas/provider.go @@ -417,6 +417,7 @@ type CustomProviderConfig struct { CustomProviderKey string `json:"-"` // Custom provider key, internally set by Bifrost IsKeyLess bool `json:"is_key_less"` // Whether the custom provider requires a key (not allowed for Bedrock) BaseProviderType ModelProvider `json:"base_provider_type"` // Base provider type + SupportsResponsesAPI *bool `json:"supports_responses_api,omitempty"` // Whether the upstream provider supports the native OpenAI Responses API (true = force native, false = force chat fallback, nil = try native first and auto-fallback on obvious unsupported-endpoint failures) AllowedRequests *AllowedRequests `json:"allowed_requests,omitempty"` // Allowed requests for the custom provider RequestPathOverrides map[RequestType]string `json:"request_path_overrides,omitempty"` // Mapping of request type to its custom path which will override the default path of the provider (not allowed for Bedrock) } diff --git a/plugins/compat/main.go b/plugins/compat/main.go index 6c536d6bd6..3c4865325b 100644 --- a/plugins/compat/main.go +++ b/plugins/compat/main.go @@ -117,6 +117,11 @@ func (p *CompatPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr applyParameterConversion(modifiedReq) } + // Responses → Chat fallback: sets up compat state for configured-unsupported + // (SupportsResponsesAPI==false → immediate transform) and runtime auto-detection + // (SupportsResponsesAPI==nil → prepare retry-eligible state, keep original request). 
+ modifiedReq = transformResponsesToChatRequest(ctx, modifiedReq, p.logger) + return modifiedReq, nil, nil } @@ -144,6 +149,11 @@ func (p *CompatPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. } } + // Responses → Chat fallback: convert chat response back to responses format + // when compat state is active (configured or runtime fallback path). + result = transformResponsesToChatResponse(ctx, result, p.logger) + bifrostErr = transformResponsesToChatError(ctx, bifrostErr) + return result, bifrostErr, nil } diff --git a/plugins/compat/responsestochat.go b/plugins/compat/responsestochat.go new file mode 100644 index 0000000000..17aaabccd0 --- /dev/null +++ b/plugins/compat/responsestochat.go @@ -0,0 +1,114 @@ +package compat + +import ( + "strings" + + "github.com/maximhq/bifrost/core/schemas" +) + +// transformResponsesToChatRequest applies the Responses -> Chat compatibility bridge. +// This path is intentionally best-effort rather than fully lossless; keep request shaping +// behind schemas.ToChatFallbackRequest so schema evolution has one compatibility seam. 
+func transformResponsesToChatRequest(ctx *schemas.BifrostContext, req *schemas.BifrostRequest, logger schemas.Logger) *schemas.BifrostRequest { + if req.RequestType != schemas.ResponsesRequest && req.RequestType != schemas.ResponsesStreamRequest { + return req + } + + if req.ResponsesRequest == nil { + return req + } + + metadata, ok := schemas.GetCustomProviderContextMetadata(ctx) + if !ok || metadata == nil || metadata.BaseProviderType != schemas.OpenAI { + return req + } + + if metadata.SupportsResponsesAPI != nil && *metadata.SupportsResponsesAPI { + return req + } + + chatRequest := req.ResponsesRequest.ToChatFallbackRequest() + if chatRequest == nil { + return req + } + + fallbackRequestType := schemas.ChatCompletionRequest + if req.RequestType == schemas.ResponsesStreamRequest { + fallbackRequestType = schemas.ChatCompletionStreamRequest + } + + state := &schemas.ResponsesToChatCompletionCompatState{ + OriginalRequestType: req.RequestType, + OriginalModel: req.ResponsesRequest.Model, + IsStreaming: req.RequestType == schemas.ResponsesStreamRequest, + FallbackRequest: &schemas.BifrostRequest{ + RequestType: fallbackRequestType, + ChatRequest: chatRequest, + }, + Warnings: req.ResponsesRequest.ChatFallbackWarnings(), + } + schemas.SetResponsesToChatCompletionCompatState(ctx, state) + + if metadata.SupportsResponsesAPI == nil { + state.RetryEligible = true + state.RetryPolicy = schemas.DefaultResponsesToChatCompletionRetryPolicy() + return req + } + + if _, activated := schemas.ActivateResponsesToChatCompletionCompatState(ctx, schemas.ResponsesToChatCompletionFallbackReasonConfiguredUnsupported); !activated { + return req + } + logResponsesToChatFallback(logger, state.OriginalModel, state.FallbackReason, state.Warnings) + + return state.FallbackRequest +} + +func transformResponsesToChatResponse(ctx *schemas.BifrostContext, resp *schemas.BifrostResponse, logger schemas.Logger) *schemas.BifrostResponse { + state, ok := 
schemas.GetResponsesToChatCompletionCompatState(ctx) + if !ok || state == nil || !state.Active || resp == nil || state.IsStreaming || resp.ChatResponse == nil { + return resp + } + + responsesResponse := resp.ChatResponse.ToBifrostResponsesResponse() + if responsesResponse == nil { + return resp + } + + responsesResponse.ExtraFields.RequestType = state.OriginalRequestType + responsesResponse.ExtraFields.OriginalModelRequested = state.OriginalModel + responsesResponse.ExtraFields.ResponsesToChatCompletionFallback = true + + if logger != nil { + logger.Debug("compat: converted chat response back to responses for model %s (reason=%s)", state.OriginalModel, state.FallbackReason) + } + + return &schemas.BifrostResponse{ + ResponsesResponse: responsesResponse, + } +} + +func transformResponsesToChatError(ctx *schemas.BifrostContext, err *schemas.BifrostError) *schemas.BifrostError { + state, ok := schemas.GetResponsesToChatCompletionCompatState(ctx) + if !ok || state == nil || err == nil || !state.Active { + return err + } + + err.ExtraFields.RequestType = state.OriginalRequestType + err.ExtraFields.OriginalModelRequested = state.OriginalModel + err.ExtraFields.ResponsesToChatCompletionFallback = true + + return err +} + +func logResponsesToChatFallback(logger schemas.Logger, model string, reason schemas.ResponsesToChatCompletionFallbackReason, warnings []string) { + if logger == nil { + return + } + + logger.Info("compat: applied responses->chat completion fallback for model %s (reason=%s)", model, reason) + if len(warnings) == 0 { + return + } + + logger.Warn("compat: responses->chat completion fallback for model %s is compatibility-only: %s", model, strings.Join(warnings, "; ")) +} diff --git a/plugins/compat/responsestochat_test.go b/plugins/compat/responsestochat_test.go new file mode 100644 index 0000000000..953415c173 --- /dev/null +++ b/plugins/compat/responsestochat_test.go @@ -0,0 +1,199 @@ +package compat + +import ( + "context" + "net/http" + "testing" + + 
"github.com/maximhq/bifrost/core/schemas" +) + +func testResponsesCompatRequest(requestType schemas.RequestType) *schemas.BifrostRequest { + content := "hello" + return &schemas.BifrostRequest{ + RequestType: requestType, + ResponsesRequest: &schemas.BifrostResponsesRequest{ + Provider: schemas.ModelProvider("lmstudio"), + Model: "test-model", + Input: []schemas.ResponsesMessage{{ + Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser), + Content: &schemas.ResponsesMessageContent{ + ContentStr: &content, + }, + }}, + Params: &schemas.ResponsesParameters{ + MaxOutputTokens: schemas.Ptr(7), + }, + }, + } +} + +func testResponsesCompatContext() *schemas.BifrostContext { + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyCustomProviderMetadata, &schemas.CustomProviderContextMetadata{ + ProviderKey: "lmstudio", + BaseProviderType: schemas.OpenAI, + SupportsResponsesAPI: schemas.Ptr(false), + }) + return ctx +} + +func TestPreLLMHook_TransformsForcedResponsesRequestToChatCompletion(t *testing.T) { + plugin := &CompatPlugin{config: Config{}} + ctx := testResponsesCompatContext() + + transformedReq, shortCircuit, err := plugin.PreLLMHook(ctx, testResponsesCompatRequest(schemas.ResponsesRequest)) + if err != nil { + t.Fatalf("PreLLMHook returned error: %v", err) + } + if shortCircuit != nil { + t.Fatal("expected no short circuit") + } + if transformedReq == nil || transformedReq.ChatRequest == nil { + t.Fatal("expected chat request after transform") + } + if transformedReq.RequestType != schemas.ChatCompletionRequest { + t.Fatalf("expected request type %s, got %s", schemas.ChatCompletionRequest, transformedReq.RequestType) + } + if transformedReq.ChatRequest.Model != "test-model" { + t.Fatalf("expected model test-model, got %s", transformedReq.ChatRequest.Model) + } + if len(transformedReq.ChatRequest.Input) == 0 { + t.Fatal("expected transformed chat input") + } + if 
ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback) != true { + t.Fatalf("expected fallback context marker to be set, got %v", ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback)) + } + if reason, _ := ctx.Value(schemas.BifrostContextKeyResponsesToChatCompletionFallbackReason).(schemas.ResponsesToChatCompletionFallbackReason); reason != schemas.ResponsesToChatCompletionFallbackReasonConfiguredUnsupported { + t.Fatalf("expected fallback reason %q, got %q", schemas.ResponsesToChatCompletionFallbackReasonConfiguredUnsupported, reason) + } + state, ok := schemas.GetResponsesToChatCompletionCompatState(ctx) + if !ok || state == nil { + t.Fatal("expected responses compat state to be stored on context") + } + if !state.Active || state.RetryEligible { + t.Fatalf("expected active forced fallback state, got %+v", state) + } + + chatText := "fallback response" + finishReason := string(schemas.BifrostFinishReasonStop) + result, bifrostErr, err := plugin.PostLLMHook(ctx, &schemas.BifrostResponse{ + ChatResponse: &schemas.BifrostChatResponse{ + ID: "chatcmpl-test", + Created: 1, + Model: "test-model", + Choices: []schemas.BifrostResponseChoice{{ + Index: 0, + FinishReason: &finishReason, + ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{ + Message: &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ + ContentStr: &chatText, + }, + }, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + Provider: schemas.ModelProvider("lmstudio"), + RequestType: schemas.ChatCompletionRequest, + OriginalModelRequested: "test-model", + }, + }, + }, nil) + if err != nil { + t.Fatalf("PostLLMHook returned error: %v", err) + } + if bifrostErr != nil { + t.Fatalf("expected nil error, got %+v", bifrostErr) + } + if result == nil || result.ResponsesResponse == nil { + t.Fatal("expected responses response after transform") + } + if result.ResponsesResponse.ExtraFields.RequestType != 
schemas.ResponsesRequest { + t.Fatalf("expected request type %s, got %s", schemas.ResponsesRequest, result.ResponsesResponse.ExtraFields.RequestType) + } + if !result.ResponsesResponse.ExtraFields.ResponsesToChatCompletionFallback { + t.Fatal("expected responses-to-chat fallback marker to be set") + } + if len(result.ResponsesResponse.Output) == 0 || result.ResponsesResponse.Output[0].Content == nil || len(result.ResponsesResponse.Output[0].Content.ContentBlocks) == 0 { + t.Fatal("expected transformed responses output") + } + if got := result.ResponsesResponse.Output[0].Content.ContentBlocks[0].Text; got == nil || *got != "fallback response" { + t.Fatalf("expected fallback response text, got %+v", got) + } +} + +func TestPreLLMHook_PreparesRuntimeResponsesFallback(t *testing.T) { + plugin := &CompatPlugin{config: Config{}} + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + ctx.SetValue(schemas.BifrostContextKeyCustomProviderMetadata, &schemas.CustomProviderContextMetadata{ + ProviderKey: "lmstudio", + BaseProviderType: schemas.OpenAI, + }) + + originalReq := testResponsesCompatRequest(schemas.ResponsesRequest) + transformedReq, shortCircuit, err := plugin.PreLLMHook(ctx, originalReq) + if err != nil { + t.Fatalf("PreLLMHook returned error: %v", err) + } + if shortCircuit != nil { + t.Fatal("expected no short circuit") + } + if transformedReq != originalReq { + t.Fatal("expected native responses request to be kept for runtime probing") + } + if ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback) != nil { + t.Fatalf("expected fallback context marker to stay unset before runtime retry, got %v", ctx.Value(schemas.BifrostContextKeyIsResponsesToChatCompletionFallback)) + } + state, ok := schemas.GetResponsesToChatCompletionCompatState(ctx) + if !ok || state == nil { + t.Fatal("expected compat state to be stored for runtime retry") + } + if state.Active || !state.RetryEligible { + t.Fatalf("expected runtime retry state to be 
eligible but inactive, got %+v", state) + } + if state.RetryPolicy == nil { + t.Fatal("expected runtime retry policy to be prepared by compat plugin") + } + if state.FallbackRequest == nil || state.FallbackRequest.ChatRequest == nil { + t.Fatal("expected fallback chat request to be prepared") + } + if state.FallbackRequest.RequestType != schemas.ChatCompletionRequest { + t.Fatalf("expected fallback request type %s, got %s", schemas.ChatCompletionRequest, state.FallbackRequest.RequestType) + } + if state.OriginalRequestType != schemas.ResponsesRequest { + t.Fatalf("expected original request type %s, got %s", schemas.ResponsesRequest, state.OriginalRequestType) + } + statusCode := http.StatusNotFound + if !state.ShouldRetry(&schemas.BifrostError{StatusCode: &statusCode}) { + t.Fatal("expected compat state retry policy to treat 404 as runtime fallback trigger") + } +} + +func TestPreLLMHook_TransformsForcedResponsesStreamRequestToChatCompletionStream(t *testing.T) { + plugin := &CompatPlugin{config: Config{}} + ctx := testResponsesCompatContext() + + transformedReq, shortCircuit, err := plugin.PreLLMHook(ctx, testResponsesCompatRequest(schemas.ResponsesStreamRequest)) + if err != nil { + t.Fatalf("PreLLMHook returned error: %v", err) + } + if shortCircuit != nil { + t.Fatal("expected no short circuit") + } + if transformedReq == nil || transformedReq.ChatRequest == nil { + t.Fatal("expected chat stream request after transform") + } + if transformedReq.RequestType != schemas.ChatCompletionStreamRequest { + t.Fatalf("expected request type %s, got %s", schemas.ChatCompletionStreamRequest, transformedReq.RequestType) + } + + state, ok := schemas.GetResponsesToChatCompletionCompatState(ctx) + if !ok || state == nil { + t.Fatal("expected compat state to be stored for forced streaming fallback") + } + if !state.Active || !state.IsStreaming || state.FallbackRequest == nil { + t.Fatalf("expected active streaming fallback state, got %+v", state) + } +} diff --git 
a/plugins/compat/runtime_fallback_integration_test.go b/plugins/compat/runtime_fallback_integration_test.go
new file mode 100644
index 0000000000..c657d18931
--- /dev/null
+++ b/plugins/compat/runtime_fallback_integration_test.go
@@ -0,0 +1,388 @@
+package compat
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/bytedance/sonic"
+
+	bifrost "github.com/maximhq/bifrost/core"
+	"github.com/maximhq/bifrost/core/schemas"
+)
+
+// integrationAccount is an in-memory account for the integration tests. It
+// serves provider configs from a map guarded by mu.
+type integrationAccount struct {
+	mu      sync.RWMutex
+	configs map[schemas.ModelProvider]*schemas.ProviderConfig
+}
+
+// newIntegrationAccount returns an account with a single keyless custom
+// provider, "lmstudio", that extends the OpenAI base provider and sends its
+// traffic to baseURL. supportsResponsesAPI and allowedRequests are copied
+// verbatim into the custom provider config.
+func newIntegrationAccount(baseURL string, supportsResponsesAPI *bool, allowedRequests *schemas.AllowedRequests) *integrationAccount {
+	providerKey := schemas.ModelProvider("lmstudio")
+	return &integrationAccount{
+		configs: map[schemas.ModelProvider]*schemas.ProviderConfig{
+			providerKey: {
+				NetworkConfig: schemas.NetworkConfig{
+					BaseURL:                        baseURL,
+					DefaultRequestTimeoutInSeconds: 5,
+				},
+				ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{
+					Concurrency: 1,
+					BufferSize:  8,
+				},
+				CustomProviderConfig: &schemas.CustomProviderConfig{
+					CustomProviderKey:    string(providerKey),
+					BaseProviderType:     schemas.OpenAI,
+					SupportsResponsesAPI: supportsResponsesAPI,
+					AllowedRequests:      allowedRequests,
+					IsKeyLess:            true,
+				},
+			},
+		},
+	}
+}
+
+// GetConfiguredProviders lists the keys of every provider stored in the account.
+func (a *integrationAccount) GetConfiguredProviders() ([]schemas.ModelProvider, error) {
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	providers := make([]schemas.ModelProvider, 0, len(a.configs))
+	for providerKey := range a.configs {
+		providers = append(providers, providerKey)
+	}
+
+	return providers, nil
+}
+
+// GetKeysForProvider always returns no keys; the test provider is keyless.
+func (a *integrationAccount) GetKeysForProvider(context.Context, schemas.ModelProvider) ([]schemas.Key, error) {
+	return nil, nil
+}
+
+// GetConfigForProvider returns a copy of the stored config for providerKey,
+// or an error when the provider is not configured. The custom provider
+// config is copied as well so callers cannot mutate the stored value.
+func (a *integrationAccount) GetConfigForProvider(providerKey schemas.ModelProvider) (*schemas.ProviderConfig, error) {
+	a.mu.RLock()
+	defer a.mu.RUnlock()
+
+	config, ok := a.configs[providerKey]
+	if !ok {
+		return nil, fmt.Errorf("provider %s not configured", providerKey)
+	}
+
+	configCopy := *config
+	if config.CustomProviderConfig != nil {
+		customProviderConfigCopy := *config.CustomProviderConfig
+		configCopy.CustomProviderConfig = &customProviderConfigCopy
+	}
+
+	return &configCopy, nil
+}
+
+// newIntegrationBifrost starts a Bifrost instance wired to the compat plugin
+// and to a single custom provider that talks to serverURL. The test is failed
+// immediately when either the plugin or Bifrost cannot be initialised.
+func newIntegrationBifrost(t *testing.T, serverURL string, supportsResponsesAPI *bool, allowedRequests *schemas.AllowedRequests) *bifrost.Bifrost {
+	t.Helper()
+
+	logger := bifrost.NewDefaultLogger(schemas.LogLevelError)
+	plugin, err := Init(Config{}, logger, nil)
+	if err != nil {
+		t.Fatalf("compat plugin Init returned error: %v", err)
+	}
+
+	instance, err := bifrost.Init(context.Background(), schemas.BifrostConfig{
+		Account:         newIntegrationAccount(serverURL, supportsResponsesAPI, allowedRequests),
+		Logger:          logger,
+		LLMPlugins:      []schemas.LLMPlugin{plugin},
+		InitialPoolSize: 1,
+	})
+	if err != nil {
+		t.Fatalf("bifrost.Init returned error: %v", err)
+	}
+
+	return instance
+}
+
+// integrationResponsesRequest builds the responses request sent by every
+// integration test: one user message ("hello") for "test-model" on the
+// "lmstudio" provider with a max-output-token limit of 7.
+func integrationResponsesRequest() *schemas.BifrostResponsesRequest {
+	content := "hello"
+	return &schemas.BifrostResponsesRequest{
+		Provider: schemas.ModelProvider("lmstudio"),
+		Model:    "test-model",
+		Input: []schemas.ResponsesMessage{{
+			Role: schemas.Ptr(schemas.ResponsesInputMessageRoleUser),
+			Content: &schemas.ResponsesMessageContent{
+				ContentStr: &content,
+			},
+		}},
+		Params: &schemas.ResponsesParameters{
+			MaxOutputTokens: schemas.Ptr(7),
+		},
+	}
+}
+
+// integrationChatCompletionBody returns the JSON body of a non-streaming chat
+// completion whose single assistant message is text. It panics when the
+// fixture cannot be marshalled, so a broken fixture is reported directly
+// instead of surfacing later as an empty response body.
+func integrationChatCompletionBody(text string) []byte {
+	finishReason := string(schemas.BifrostFinishReasonStop)
+	response := &schemas.BifrostChatResponse{
+		ID:      "chatcmpl-test",
+		Object:  "chat.completion",
+		Created: 1,
+		Model:   "test-model",
+		Choices: []schemas.BifrostResponseChoice{{
+			Index:        0,
+			FinishReason: &finishReason,
+			ChatNonStreamResponseChoice: &schemas.ChatNonStreamResponseChoice{
+				Message: &schemas.ChatMessage{
+					Role: schemas.ChatMessageRoleAssistant,
+					Content: &schemas.ChatMessageContent{
+						ContentStr: &text,
+					},
+				},
+			},
+		}},
+		Usage: &schemas.BifrostLLMUsage{PromptTokens: 1, CompletionTokens: 1, TotalTokens: 2},
+	}
+	body, err := sonic.Marshal(response)
+	if err != nil {
+		panic(fmt.Sprintf("failed to marshal chat completion fixture: %v", err))
+	}
+	return body
+}
+
+// readIntegrationRequestBody reads the whole request body inside a test
+// server handler. Handlers run on server goroutines, where t.Fatal must not
+// be called, so a read failure is reported with t.Errorf and answered with
+// HTTP 500; the boolean tells the handler whether to continue.
+func readIntegrationRequestBody(t *testing.T, w http.ResponseWriter, r *http.Request) ([]byte, bool) {
+	body, err := io.ReadAll(r.Body)
+	if err != nil {
+		t.Errorf("failed to read request body: %v", err)
+		http.Error(w, "failed to read request body", http.StatusInternalServerError)
+		return nil, false
+	}
+	return body, true
+}
+
+// collectRuntimeFallbackStream drains stream until it is closed and returns
+// every responses stream event it carried. The test is failed when a chunk
+// carries an error or when the stream does not close within three seconds.
+func collectRuntimeFallbackStream(t *testing.T, stream chan *schemas.BifrostStreamChunk) []*schemas.BifrostResponsesStreamResponse {
+	t.Helper()
+
+	responses := make([]*schemas.BifrostResponsesStreamResponse, 0)
+	timeout := time.After(3 * time.Second)
+
+	for {
+		select {
+		case chunk, ok := <-stream:
+			if !ok {
+				return responses
+			}
+			if chunk == nil {
+				continue
+			}
+			if chunk.BifrostError != nil {
+				t.Fatalf("unexpected stream error: %+v", chunk.BifrostError)
+			}
+			if chunk.BifrostResponsesStreamResponse != nil {
+				responses = append(responses, chunk.BifrostResponsesStreamResponse)
+			}
+		case <-timeout:
+			t.Fatal("timed out waiting for stream to complete")
+		}
+	}
+}
+
+// TestResponsesRequest_RuntimeFallbackRunsThroughBifrostCompat verifies the
+// runtime downgrade: the provider answers /v1/responses with 404, Bifrost
+// retries once via /v1/chat/completions, and the caller still receives a
+// responses-shaped result marked as a runtime fallback.
+func TestResponsesRequest_RuntimeFallbackRunsThroughBifrostCompat(t *testing.T) {
+	t.Parallel()
+
+	var chatHits atomic.Int32
+	var responsesHits atomic.Int32
+	var chatRequestBody atomic.Value
+	var responsesRequestBody atomic.Value
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		bodyBytes, ok := readIntegrationRequestBody(t, w, r)
+		if !ok {
+			return
+		}
+
+		switch r.URL.Path {
+		case "/v1/responses":
+			responsesHits.Add(1)
+			responsesRequestBody.Store(string(bodyBytes))
+			w.WriteHeader(http.StatusNotFound)
+			_, _ = w.Write([]byte(`{"error":{"message":"responses endpoint unsupported"}}`))
+		case "/v1/chat/completions":
+			chatHits.Add(1)
+			chatRequestBody.Store(string(bodyBytes))
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write(integrationChatCompletionBody("runtime fallback response"))
+		default:
+			w.WriteHeader(http.StatusNotFound)
+		}
+	}))
+	defer server.Close()
+
+	instance := newIntegrationBifrost(t, server.URL, nil, &schemas.AllowedRequests{Responses: true})
+	defer instance.Shutdown()
+
+	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
+	response, bifrostErr := instance.ResponsesRequest(ctx, integrationResponsesRequest())
+	if bifrostErr != nil {
+		t.Fatalf("ResponsesRequest returned error: %v", bifrostErr.Error)
+	}
+	if response == nil {
+		t.Fatal("ResponsesRequest returned nil response")
+	}
+	if responsesHits.Load() != 1 {
+		t.Fatalf("expected one native responses attempt, got %d", responsesHits.Load())
+	}
+	if chatHits.Load() != 1 {
+		t.Fatalf("expected one fallback chat attempt, got %d", chatHits.Load())
+	}
+
+	responsesBody, _ := responsesRequestBody.Load().(string)
+	if !strings.Contains(responsesBody, `"input"`) {
+		t.Fatalf("expected initial responses payload to contain input, got %s", responsesBody)
+	}
+	chatBody, _ := chatRequestBody.Load().(string)
+	if !strings.Contains(chatBody, `"messages"`) {
+		t.Fatalf("expected fallback chat payload to contain messages, got %s", chatBody)
+	}
+	if strings.Contains(chatBody, `"input"`) {
+		t.Fatalf("expected fallback chat payload to omit responses input, got %s", chatBody)
+	}
+
+	if !response.ExtraFields.ResponsesToChatCompletionFallback {
+		t.Fatal("expected responses-to-chat fallback marker on fallback response")
+	}
+	if response.ExtraFields.ResponsesToChatCompletionFallbackReason != string(schemas.ResponsesToChatCompletionFallbackReasonRuntimeUnsupported) {
+		t.Fatalf("expected fallback reason %q, got %q", schemas.ResponsesToChatCompletionFallbackReasonRuntimeUnsupported, response.ExtraFields.ResponsesToChatCompletionFallbackReason)
+	}
+	// Guard the indexing below so a malformed conversion fails the test
+	// instead of panicking.
+	if len(response.Output) == 0 || response.Output[0].Content == nil || len(response.Output[0].Content.ContentBlocks) == 0 {
+		t.Fatal("expected converted responses output")
+	}
+	if got := response.Output[0].Content.ContentBlocks[0].Text; got == nil || *got != "runtime fallback response" {
+		t.Fatalf("expected converted output text runtime fallback response, got %+v", got)
+	}
+	if reason, ok := schemas.GetResponsesToChatCompletionFallback(ctx); !ok || reason != schemas.ResponsesToChatCompletionFallbackReasonRuntimeUnsupported {
+		t.Fatalf("expected runtime fallback context reason, got %q (ok=%v)", reason, ok)
+	}
+}
+
+// TestResponsesRequest_ConfiguredFallbackBypassesChatAllowedRequestsGate
+// verifies the configured downgrade: with SupportsResponsesAPI set to false
+// the native endpoint is never called and the chat completion fallback runs
+// even though only Responses is enabled in AllowedRequests.
+func TestResponsesRequest_ConfiguredFallbackBypassesChatAllowedRequestsGate(t *testing.T) {
+	t.Parallel()
+
+	var chatHits atomic.Int32
+	var responsesHits atomic.Int32
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if _, ok := readIntegrationRequestBody(t, w, r); !ok {
+			return
+		}
+
+		switch r.URL.Path {
+		case "/v1/responses":
+			responsesHits.Add(1)
+			w.WriteHeader(http.StatusNotFound)
+		case "/v1/chat/completions":
+			chatHits.Add(1)
+			w.Header().Set("Content-Type", "application/json")
+			_, _ = w.Write(integrationChatCompletionBody("configured fallback response"))
+		default:
+			w.WriteHeader(http.StatusNotFound)
+		}
+	}))
+	defer server.Close()
+
+	instance := newIntegrationBifrost(t, server.URL, schemas.Ptr(false), &schemas.AllowedRequests{Responses: true})
+	defer instance.Shutdown()
+
+	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
+	response, bifrostErr := instance.ResponsesRequest(ctx, integrationResponsesRequest())
+	if bifrostErr != nil {
+		t.Fatalf("ResponsesRequest returned error: %v", bifrostErr.Error)
+	}
+	if response == nil {
+		t.Fatal("ResponsesRequest returned nil response")
+	}
+	if responsesHits.Load() != 0 {
+		t.Fatalf("expected zero native responses attempts during configured fallback, got %d", responsesHits.Load())
+	}
+	if chatHits.Load() != 1 {
+		t.Fatalf("expected one fallback chat attempt, got %d", chatHits.Load())
+	}
+	// Guard the indexing below so a malformed conversion fails the test
+	// instead of panicking.
+	if len(response.Output) == 0 || response.Output[0].Content == nil || len(response.Output[0].Content.ContentBlocks) == 0 {
+		t.Fatal("expected converted responses output")
+	}
+	if got := response.Output[0].Content.ContentBlocks[0].Text; got == nil || *got != "configured fallback response" {
+		t.Fatalf("expected converted output text configured fallback response, got %+v", got)
+	}
+	if reason, ok := schemas.GetResponsesToChatCompletionFallback(ctx); !ok || reason != schemas.ResponsesToChatCompletionFallbackReasonConfiguredUnsupported {
+		t.Fatalf("expected configured fallback context reason, got %q (ok=%v)", reason, ok)
+	}
+}
+
+// TestResponsesStreamRequest_RuntimeFallbackRunsThroughBifrostCompat verifies
+// the runtime downgrade for streaming: after a 404 on /v1/responses the chat
+// completion SSE stream is converted back into responses stream events, each
+// marked as a runtime fallback.
+func TestResponsesStreamRequest_RuntimeFallbackRunsThroughBifrostCompat(t *testing.T) {
+	t.Parallel()
+
+	var chatHits atomic.Int32
+	var responsesHits atomic.Int32
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if _, ok := readIntegrationRequestBody(t, w, r); !ok {
+			return
+		}
+
+		switch r.URL.Path {
+		case "/v1/responses":
+			responsesHits.Add(1)
+			w.WriteHeader(http.StatusNotFound)
+			_, _ = w.Write([]byte(`{"error":{"message":"responses endpoint unsupported"}}`))
+		case "/v1/chat/completions":
+			chatHits.Add(1)
+			// Check for flushing support before any header is written so a
+			// failure can still be answered with a plain 500. This runs on a
+			// server goroutine, so the failure is reported with t.Error.
+			flusher, ok := w.(http.Flusher)
+			if !ok {
+				t.Error("response writer does not implement http.Flusher")
+				http.Error(w, "streaming unsupported", http.StatusInternalServerError)
+				return
+			}
+			w.Header().Set("Content-Type", "text/event-stream")
+
+			chunks := []string{
+				`{"id":"chatcmpl-test","object":"chat.completion.chunk","created":1,"model":"test-model","choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
+				`{"id":"chatcmpl-test","object":"chat.completion.chunk","created":1,"model":"test-model","choices":[{"index":0,"delta":{"content":"runtime fallback stream"}}]}`,
+				`{"id":"chatcmpl-test","object":"chat.completion.chunk","created":1,"model":"test-model","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`,
+			}
+
+			for _, chunk := range chunks {
+				_, _ = fmt.Fprintf(w, "data: %s\n\n", chunk)
+				flusher.Flush()
+			}
+			_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
+			flusher.Flush()
+		default:
+			w.WriteHeader(http.StatusNotFound)
+		}
+	}))
+	defer server.Close()
+
+	instance := newIntegrationBifrost(t, server.URL, nil, &schemas.AllowedRequests{ResponsesStream: true})
+	defer instance.Shutdown()
+
+	ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline)
+	streamChan, bifrostErr := instance.ResponsesStreamRequest(ctx, integrationResponsesRequest())
+	if bifrostErr != nil {
+		t.Fatalf("ResponsesStreamRequest returned error: %v", bifrostErr.Error)
+	}
+
+	responses := collectRuntimeFallbackStream(t, streamChan)
+	if responsesHits.Load() != 1 {
+		t.Fatalf("expected one native responses stream attempt, got %d", responsesHits.Load())
+	}
+	if chatHits.Load() != 1 {
+		t.Fatalf("expected one fallback chat stream attempt, got %d", chatHits.Load())
+	}
+
+	seenTypes := map[schemas.ResponsesStreamResponseType]bool{}
+	for _, response := range responses {
+		seenTypes[response.Type] = true
+		if response.ExtraFields.RequestType != schemas.ResponsesStreamRequest {
+			t.Fatalf("expected request type %s, got %s", schemas.ResponsesStreamRequest, response.ExtraFields.RequestType)
+		}
+		if !response.ExtraFields.ResponsesToChatCompletionFallback {
+			t.Fatal("expected responses-to-chat fallback marker on streamed fallback response")
+		}
+		if response.ExtraFields.ResponsesToChatCompletionFallbackReason != string(schemas.ResponsesToChatCompletionFallbackReasonRuntimeUnsupported) {
+			t.Fatalf("expected fallback reason %q, got %q", schemas.ResponsesToChatCompletionFallbackReasonRuntimeUnsupported, response.ExtraFields.ResponsesToChatCompletionFallbackReason)
+		}
+	}
+
+	if !seenTypes[schemas.ResponsesStreamResponseTypeCreated] {
+		t.Fatalf("expected response.created event, got %#v", seenTypes)
+	}
+	if !seenTypes[schemas.ResponsesStreamResponseTypeOutputTextDelta] {
+		t.Fatalf("expected response.output_text.delta event, got %#v", seenTypes)
+	}
+	if !seenTypes[schemas.ResponsesStreamResponseTypeCompleted] {
+		t.Fatalf("expected response.completed event, got %#v", seenTypes)
+	}
+	if reason, ok := schemas.GetResponsesToChatCompletionFallback(ctx); !ok || reason != schemas.ResponsesToChatCompletionFallbackReasonRuntimeUnsupported {
+		t.Fatalf("expected runtime fallback context reason, got %q (ok=%v)", reason, ok)
+	}
+}
diff --git a/transports/config.schema.json b/transports/config.schema.json
index 3a65af7b49..13e07a4e16 100644
--- a/transports/config.schema.json
+++ 
b/transports/config.schema.json @@ -4061,6 +4061,10 @@ ], "description": "Base provider type to extend" }, + "supports_responses_api": { + "type": "boolean", + "description": "Controls native OpenAI Responses API usage for custom OpenAI-compatible providers. true forces /v1/responses, false forces an internal Chat Completions fallback, and omitting it keeps native behavior first with an automatic downgrade on obvious unsupported-endpoint failures." + }, "request_path_overrides": { "type": "object", "description": "Mapping of request type to custom path overriding the default provider path",