diff --git a/core/bifrost.go b/core/bifrost.go index db562495b0..bea3ae846c 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -3821,9 +3821,10 @@ func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) sche return bifrost.getProviderByKey(providerKey) } -// SelectKeyForProvider selects an API key for the given provider and model. -// Used by WebSocket handlers that need a key for upstream connections. -func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { +// SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model. +// Used by WebSocket handlers that need a key for upstream connections while honoring request-specific +// AllowedRequests gates such as realtime-only support. +func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { if ctx == nil { ctx = bifrost.ctx } @@ -3832,7 +3833,7 @@ func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, provid config.CustomProviderConfig != nil && config.CustomProviderConfig.BaseProviderType != "" { baseProvider = config.CustomProviderConfig.BaseProviderType } - return bifrost.selectKeyFromProviderForModel(ctx, schemas.WebSocketResponsesRequest, providerKey, model, baseProvider) + return bifrost.selectKeyFromProviderForModel(ctx, requestType, providerKey, model, baseProvider) } // WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks. @@ -3846,6 +3847,13 @@ type WSStreamHooks struct { ShortCircuitResponse *schemas.BifrostResponse } +// RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a +// single realtime turn rather than one long-lived transport connection. 
+type RealtimeTurnHooks struct { + PostHookRunner schemas.PostHookRunner + Cleanup func() +} + // RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks, // and returns a PostHookRunner for per-chunk post-processing. // Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks. @@ -3884,13 +3892,22 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() - return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + return nil, bifrostErr } if shortCircuit != nil { if shortCircuit.Error != nil { _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3900,6 +3917,9 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3934,6 +3954,94 @@ func (bifrost 
*Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche }, nil } +// RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for +// a single realtime turn. Unlike generic stream hooks, realtime turns do not +// support short-circuit responses in v1 because the transports cannot yet emit a +// fully synthetic assistant turn without an upstream generation. +func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) { + if req == nil { + bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil") + bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest + return nil, bifrostErr + } + if ctx == nil { + ctx = bifrost.ctx + } + + if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { + ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) + } + + tracer := bifrost.getTracer() + ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) + + if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { + traceID := tracer.CreateTrace("") + if traceID != "" { + ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) + } + } + + pipeline := bifrost.getPluginPipeline() + cleanup := func() { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.CleanupStreamAccumulator(traceID) + } + bifrost.releasePluginPipeline(pipeline) + } + provider, model, _ := req.GetRequestFields() + + preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) + if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && 
strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + if shortCircuit != nil { + if shortCircuit.Error != nil { + shortCircuit.Error.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + if bifrostErr != nil { + return nil, bifrostErr + } + return nil, shortCircuit.Error + } + if shortCircuit.Response != nil { + // Short-circuit responses are not supported for realtime turns (v1). + // Treat this like an error turn so plugins can close pending state cleanly. + bifrostErr := newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + } + + return &RealtimeTurnHooks{ + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + drainAndAttachPluginLogs(ctx) + return resp, bifrostErr + }, + Cleanup: cleanup, + }, nil +} + // getProviderByKey retrieves a provider instance from the providers array by its provider key. // Returns the provider if found, or nil if no provider with the given key exists. 
func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { @@ -5664,8 +5772,10 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche if runFrom > len(p.llmPlugins) { runFrom = len(p.llmPlugins) } - // Detect streaming mode - if StreamStartTime is set, we're in a streaming context - isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil + requestType, _, _, _ := GetResponseFields(resp, bifrostErr) + // Realtime turns carry StreamStartTime for plugin latency/final-chunk context, + // but they are finalized as one completed turn, not chunk-by-chunk stream output. + isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() var err error diff --git a/core/internal/llmtests/realtime.go b/core/internal/llmtests/realtime.go index 821aeba9eb..400f5f9cda 100644 --- a/core/internal/llmtests/realtime.go +++ b/core/internal/llmtests/realtime.go @@ -43,7 +43,7 @@ func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.RealtimeModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/internal/llmtests/websocket_responses.go b/core/internal/llmtests/websocket_responses.go index 420a049fb7..966463dade 100644 --- a/core/internal/llmtests/websocket_responses.go +++ b/core/internal/llmtests/websocket_responses.go @@ -38,7 +38,7 @@ func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx contex bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := 
client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.ChatModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/providers/elevenlabs/realtime.go b/core/providers/elevenlabs/realtime.go index f124f58339..a18e1cd514 100644 --- a/core/providers/elevenlabs/realtime.go +++ b/core/providers/elevenlabs/realtime.go @@ -39,6 +39,44 @@ func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string] return headers } +// SupportsRealtimeWebRTC returns false — ElevenLabs WebRTC SDP exchange is not yet implemented. +func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool { + return false +} + +// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs. +func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) { + return "", &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(400), + Error: &schemas.ErrorField{Type: schemas.Ptr("invalid_request_error"), Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs"}, + } +} + +func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool { + return false +} + +func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType { + return schemas.RTEventResponseDone +} + +func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string { + return "" +} + +func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string { + return "" +} + +func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool { + return true +} + +func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType 
schemas.RealtimeEventType) bool { + return eventType == schemas.RTEventResponseDone +} + // ElevenLabs Conversational AI WebSocket event types const ( elConversationInitMetadata = "conversation_initiation_metadata" @@ -50,8 +88,8 @@ const ( elInterruption = "interruption" elClientToolCall = "client_tool_call" - elUserAudioChunk = "user_audio_chunk" - elPong = "pong" + elUserAudioChunk = "user_audio_chunk" + elPong = "pong" elClientToolResult = "client_tool_result" elContextualUpdate = "contextual_update" ) @@ -134,7 +172,7 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra } case elAgentResponse: - event.Type = schemas.RTEventResponseTextDone + event.Type = schemas.RTEventResponseDone if raw.AgentResponse != nil { var agentResp elevenlabsTranscriptEvent if err := json.Unmarshal(raw.AgentResponse, &agentResp); err == nil { @@ -194,10 +232,6 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra // ToProviderRealtimeEvent converts a unified Bifrost Realtime event to ElevenLabs' native JSON. func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) { - if bifrostEvent.RawData != nil { - return bifrostEvent.RawData, nil - } - switch bifrostEvent.Type { case schemas.RTEventInputAudioAppend: if bifrostEvent.Delta == nil { diff --git a/core/providers/openai/realtime.go b/core/providers/openai/realtime.go index b73db4ea24..8c88382297 100644 --- a/core/providers/openai/realtime.go +++ b/core/providers/openai/realtime.go @@ -1,13 +1,17 @@ package openai import ( + "bytes" "encoding/json" "fmt" + "mime/multipart" + "net/http" "net/url" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" ) // SupportsRealtimeAPI returns true since OpenAI natively supports the Realtime API. 
@@ -28,7 +32,6 @@ func (provider *OpenAIProvider) RealtimeWebSocketURL(key schemas.Key, model stri func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]string { headers := map[string]string{ "Authorization": "Bearer " + key.Value.GetValue(), - "OpenAI-Beta": "realtime=v1", } for k, v := range provider.networkConfig.ExtraHeaders { headers[k] = v @@ -36,6 +39,380 @@ func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]stri return headers } +// SupportsRealtimeWebRTC reports that OpenAI supports WebRTC SDP exchange. +func (provider *OpenAIProvider) SupportsRealtimeWebRTC() bool { + return true +} + +// ExchangeRealtimeWebRTCSDP performs the GA SDP exchange via multipart POST to /v1/realtime/calls. +func (provider *OpenAIProvider) ExchangeRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + model string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + path := "/v1/realtime/calls" + if session == nil && strings.TrimSpace(model) != "" { + path += "?model=" + url.QueryEscape(model) + } + return provider.exchangeWebRTCSDP(ctx, key, path, sdp, session) +} + +// ExchangeLegacyRealtimeWebRTCSDP performs the beta SDP exchange via multipart POST to /v1/realtime. +// Same multipart format but targets the legacy endpoint with model in the URL. +func (provider *OpenAIProvider) ExchangeLegacyRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + sdp string, + session json.RawMessage, + model string, +) (string, *schemas.BifrostError) { + return provider.exchangeWebRTCSDP(ctx, key, "/v1/realtime?model="+url.QueryEscape(model), sdp, session) +} + +// exchangeWebRTCSDP is the shared multipart SDP exchange implementation. +// Builds a multipart body with sdp + optional session, POSTs to the given path. 
+func (provider *OpenAIProvider) exchangeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + path string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + bodyBuf := &bytes.Buffer{} + writer := multipart.NewWriter(bodyBuf) + if err := writer.WriteField("sdp", sdp); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream SDP body", err) + } + if session != nil { + if err := writer.WriteField("session", string(session)); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream session body", err) + } + } + if err := writer.Close(); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to finalize upstream SDP body", err) + } + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, path, schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType(writer.FormDataContentType()) + req.Header.Set("Authorization", "Bearer "+key.Value.GetValue()) + for k, v := range provider.networkConfig.ExtraHeaders { + req.Header.Set(k, v) + } + if headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); headers != nil { + if agentsSDK := headers["x-openai-agents-sdk"]; agentsSDK != "" { + req.Header.Set("X-OpenAI-Agents-SDK", agentsSDK) + } + } + req.SetBody(bodyBuf.Bytes()) + + _, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return "", bifrostErr + } + + answerBody := resp.Body() + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices { + return "", provider.realtimeWebRTCUpstreamError(ctx, resp.StatusCode(), 
answerBody) + } + + return string(answerBody), nil +} + +func (provider *OpenAIProvider) realtimeWebRTCUpstreamError(ctx *schemas.BifrostContext, statusCode int, body []byte) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(fasthttp.StatusBadGateway), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("upstream_connection_error"), + Message: fmt.Sprintf("upstream realtime WebRTC handshake failed for %s", provider.GetProviderKey()), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: provider.GetProviderKey(), + }, + } + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostErr.ExtraFields.RawResponse = map[string]any{ + "status": statusCode, + "body": string(body), + } + } + return bifrostErr +} + +func newRealtimeWebRTCSDPError(status int, errorType, message string, err error) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + }, + } + if err != nil { + bifrostErr.Error.Error = err + } + return bifrostErr +} + +func (provider *OpenAIProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool { + if event == nil { + return false + } + switch event.Type { + case schemas.RTEventResponseCreate, schemas.RTEventInputAudioBufferCommitted: + return true + default: + return false + } +} + +func (provider *OpenAIProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType { + return schemas.RTEventResponseDone +} + +func (provider *OpenAIProvider) RealtimeWebRTCDataChannelLabel() string { + return "oai-events" +} + +func (provider *OpenAIProvider) RealtimeWebSocketSubprotocol() string { + return "realtime" +} + +func (provider *OpenAIProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool { + return true +} + +func (provider 
*OpenAIProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool { + switch eventType { + case schemas.RTEventResponseTextDelta, + schemas.RTEventResponseAudioTransDelta, + schemas.RealtimeEventType("response.output_text.delta"), + schemas.RealtimeEventType("response.output_audio_transcript.delta"): + return true + default: + return false + } +} + +// CreateRealtimeClientSecret mints an OpenAI Realtime client secret and returns +// the native OpenAI response body unchanged. +func (provider *OpenAIProvider) CreateRealtimeClientSecret( + ctx *schemas.BifrostContext, + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, + rawRequest json.RawMessage, +) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.RealtimeRequest); err != nil { + return nil, err + } + + normalizedBody, requestedModel, bifrostErr := normalizeRealtimeClientSecretRequest(rawRequest, provider.GetProviderKey(), endpointType) + if bifrostErr != nil { + return nil, bifrostErr + } + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, realtimeSessionUpstreamPath(endpointType), schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + for k, v := range provider.realtimeSessionHeaders(key, endpointType) { + req.Header.Set(k, v) + } + req.SetBody(normalizedBody) + + latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return nil, bifrostErr + } + + headers := providerUtils.ExtractProviderResponseHeaders(resp) + ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers) + + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= 
fasthttp.StatusMultipleChoices { + return nil, ParseOpenAIError(resp) + } + + body, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) + } + for k := range headers { + if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { + delete(headers, k) + } + } + + out := &schemas.BifrostPassthroughResponse{ + StatusCode: resp.StatusCode(), + Headers: headers, + Body: body, + } + out.ExtraFields.Provider = provider.GetProviderKey() + out.ExtraFields.OriginalModelRequested = requestedModel + out.ExtraFields.RequestType = schemas.RealtimeRequest + out.ExtraFields.Latency = latency.Milliseconds() + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequestIfJSON(req, &out.ExtraFields) + } + + return out, nil +} + +func normalizeRealtimeClientSecretRequest( + rawRequest json.RawMessage, + defaultProvider schemas.ModelProvider, + endpointType schemas.RealtimeSessionEndpointType, +) ([]byte, string, *schemas.BifrostError) { + root, bifrostErr := schemas.ParseRealtimeClientSecretBody(rawRequest) + if bifrostErr != nil { + return nil, "", bifrostErr + } + + modelValue, bifrostErr := schemas.ExtractRealtimeClientSecretModel(root) + if bifrostErr != nil { + return nil, "", bifrostErr + } + providerKey, normalizedModel := schemas.ParseModelString(modelValue, defaultProvider) + if normalizedModel == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + if providerKey == "" { + providerKey = defaultProvider + } + if providerKey == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "unable to determine provider from model", nil) + } + + if endpointType == schemas.RealtimeSessionEndpointSessions { + return normalizeRealtimeSessionsRequest(root, 
normalizedModel) + } + + return normalizeRealtimeClientSecretsRequest(root, normalizedModel) +} + +func normalizeRealtimeClientSecretsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + session := map[string]json.RawMessage{} + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) { + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + session["model"] = modelJSON + if _, ok := session["type"]; !ok { + typeJSON, marshalErr := json.Marshal("realtime") + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session type", marshalErr) + } + session["type"] = typeJSON + } + delete(root, "model") + + sessionJSON, marshalErr := json.Marshal(session) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session", marshalErr) + } + root["session"] = sessionJSON + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func normalizeRealtimeSessionsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && 
!bytes.Equal(existingSession, []byte("null")) { + session := map[string]json.RawMessage{} + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + for key, value := range session { + if _, exists := root[key]; !exists { + root[key] = value + } + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + root["model"] = modelJSON + delete(root, "session") + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func (provider *OpenAIProvider) realtimeSessionHeaders( + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, +) map[string]string { + headers := map[string]string{ + "Authorization": "Bearer " + key.Value.GetValue(), + } + if endpointType == schemas.RealtimeSessionEndpointSessions { + headers["OpenAI-Beta"] = "realtime=v1" + } + for k, v := range provider.networkConfig.ExtraHeaders { + headers[k] = v + } + return headers +} + +func realtimeSessionUpstreamPath(endpointType schemas.RealtimeSessionEndpointType) string { + if endpointType == schemas.RealtimeSessionEndpointSessions { + return "/v1/realtime/sessions" + } + return "/v1/realtime/client_secrets" +} + +func newRealtimeClientSecretError(status int, errorType, message string, err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: 
schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: schemas.OpenAI, + }, + } +} + // openAIRealtimeEvent is the raw shape of an OpenAI Realtime protocol event. type openAIRealtimeEvent struct { Type string `json:"type"` @@ -44,15 +421,17 @@ type openAIRealtimeEvent struct { Conversation json.RawMessage `json:"conversation,omitempty"` Item json.RawMessage `json:"item,omitempty"` Response json.RawMessage `json:"response,omitempty"` + Part json.RawMessage `json:"part,omitempty"` Delta string `json:"delta,omitempty"` Audio string `json:"audio,omitempty"` Transcript string `json:"transcript,omitempty"` Text string `json:"text,omitempty"` Error json.RawMessage `json:"error,omitempty"` ItemID string `json:"item_id,omitempty"` - OutputIndex int `json:"output_index,omitempty"` - ContentIndex int `json:"content_index,omitempty"` + OutputIndex *int `json:"output_index,omitempty"` + ContentIndex *int `json:"content_index,omitempty"` ResponseID string `json:"response_id,omitempty"` + AudioEndMS *int `json:"audio_end_ms,omitempty"` PreviousItemID string `json:"previous_item_id,omitempty"` } @@ -105,6 +484,17 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes EventID: raw.EventID, RawData: providerEvent, } + setRealtimeExtraParam(event, "item_id", raw.ItemID) + setRealtimeExtraParam(event, "previous_item_id", raw.PreviousItemID) + setRealtimeExtraParam(event, "output_index", raw.OutputIndex) + setRealtimeExtraParam(event, "content_index", raw.ContentIndex) + setRealtimeExtraParam(event, "response_id", raw.ResponseID) + setRealtimeExtraParam(event, "audio_end_ms", raw.AudioEndMS) + setRealtimeExtraParam(event, "transcript", raw.Transcript) + setRealtimeExtraParam(event, "text", raw.Text) + setRealtimeExtraParam(event, "conversation", raw.Conversation) + setRealtimeExtraParam(event, "response", raw.Response) + setRealtimeExtraParam(event, "part", raw.Part) switch { case raw.Session != nil: @@ -123,8 +513,10 @@ func 
(provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes OutputAudioType: sess.OutputAudioType, Tools: sess.Tools, } + if extra := extractRealtimeNestedParams(raw.Session, "id", "model", "modalities", "instructions", "voice", "temperature", "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools"); len(extra) > 0 { + event.Session.ExtraParams = extra + } } - case raw.Item != nil: var item openAIRealtimeItem if err := json.Unmarshal(raw.Item, &item); err == nil { @@ -139,6 +531,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Arguments: item.Arguments, Output: item.Output, } + if extra := extractRealtimeNestedParams(raw.Item, "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output"); len(extra) > 0 { + event.Item.ExtraParams = extra + } } case raw.Error != nil: @@ -150,6 +545,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Message: rtErr.Message, Param: rtErr.Param, } + if extra := extractRealtimeNestedParams(raw.Error, "type", "code", "message", "param"); len(extra) > 0 { + event.Error.ExtraParams = extra + } } } @@ -159,8 +557,8 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Audio: raw.Audio, Transcript: raw.Transcript, ItemID: raw.ItemID, - OutputIdx: &raw.OutputIndex, - ContentIdx: &raw.ContentIndex, + OutputIdx: raw.OutputIndex, + ContentIdx: raw.ContentIndex, ResponseID: raw.ResponseID, } if raw.Delta != "" { @@ -175,19 +573,19 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes // ToProviderRealtimeEvent converts a unified Bifrost Realtime event back to OpenAI's native JSON. 
func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) { - if bifrostEvent.RawData != nil { - return bifrostEvent.RawData, nil - } - out := map[string]interface{}{ "type": string(bifrostEvent.Type), } if bifrostEvent.EventID != "" { out["event_id"] = bifrostEvent.EventID } + mergeRealtimeExtraParams(out, bifrostEvent.ExtraParams) if bifrostEvent.Session != nil { sess := map[string]interface{}{} + if bifrostEvent.Session.ID != "" && bifrostEvent.Type != schemas.RTEventSessionUpdate { + sess["id"] = bifrostEvent.Session.ID + } if bifrostEvent.Session.Model != "" { sess["model"] = bifrostEvent.Session.Model } @@ -218,6 +616,7 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Session.Tools != nil { sess["tools"] = bifrostEvent.Session.Tools } + mergeRealtimeSessionExtraParams(sess, bifrostEvent.Session.ExtraParams, bifrostEvent.Type) out["session"] = sess } @@ -231,6 +630,9 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Role != "" { item["role"] = bifrostEvent.Item.Role } + if bifrostEvent.Item.Status != "" { + item["status"] = bifrostEvent.Item.Status + } if bifrostEvent.Item.Content != nil { item["content"] = bifrostEvent.Item.Content } @@ -246,9 +648,28 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Output != "" { item["output"] = bifrostEvent.Item.Output } + mergeRealtimeExtraParams(item, bifrostEvent.Item.ExtraParams) out["item"] = item } + if bifrostEvent.Error != nil { + rtErr := map[string]interface{}{} + if bifrostEvent.Error.Type != "" { + rtErr["type"] = bifrostEvent.Error.Type + } + if bifrostEvent.Error.Code != "" { + rtErr["code"] = bifrostEvent.Error.Code + } + if bifrostEvent.Error.Message != "" { + rtErr["message"] = bifrostEvent.Error.Message + } + if bifrostEvent.Error.Param != "" { + rtErr["param"] = 
bifrostEvent.Error.Param + } + mergeRealtimeExtraParams(rtErr, bifrostEvent.Error.ExtraParams) + out["error"] = rtErr + } + if bifrostEvent.Delta != nil { if bifrostEvent.Delta.Text != "" { out["delta"] = bifrostEvent.Delta.Text @@ -259,16 +680,16 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Delta.Transcript != "" { out["transcript"] = bifrostEvent.Delta.Transcript } - if bifrostEvent.Delta.ItemID != "" { + if bifrostEvent.Delta.ItemID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "item_id") { out["item_id"] = bifrostEvent.Delta.ItemID } - if bifrostEvent.Delta.OutputIdx != nil { + if bifrostEvent.Delta.OutputIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "output_index") { out["output_index"] = *bifrostEvent.Delta.OutputIdx } - if bifrostEvent.Delta.ContentIdx != nil { + if bifrostEvent.Delta.ContentIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "content_index") { out["content_index"] = *bifrostEvent.Delta.ContentIdx } - if bifrostEvent.Delta.ResponseID != "" { + if bifrostEvent.Delta.ResponseID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "response_id") { out["response_id"] = bifrostEvent.Delta.ResponseID } } @@ -276,11 +697,269 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi return providerUtils.MarshalSorted(out) } +func mergeRealtimeSessionExtraParams(out map[string]interface{}, params map[string]json.RawMessage, eventType schemas.RealtimeEventType) { + filtered := params + if eventType == schemas.RTEventSessionUpdate && len(params) > 0 { + filtered = make(map[string]json.RawMessage, len(params)) + for key, value := range params { + switch key { + case "id", "object", "expires_at", "client_secret": + continue + default: + filtered[key] = value + } + } + } + mergeRealtimeExtraParams(out, filtered) +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnUsage(terminalEventRaw []byte) *schemas.BifrostLLMUsage { + if 
len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil || parsed.Response.Usage == nil { + return nil + } + + usage := &schemas.BifrostLLMUsage{ + PromptTokens: parsed.Response.Usage.InputTokens, + CompletionTokens: parsed.Response.Usage.OutputTokens, + TotalTokens: parsed.Response.Usage.TotalTokens, + } + + if parsed.Response.Usage.InputTokenDetails != nil { + usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens, + ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens, + CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens, + } + } + + if parsed.Response.Usage.OutputTokenDetails != nil { + usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens, + ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens, + ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens, + CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens, + NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries, + AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens, + RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens, + } + } + + return usage +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnOutput(terminalEventRaw []byte) *schemas.ChatMessage { + if len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil { + return nil + } + + content := extractOpenAIRealtimeResponseDoneAssistantText(parsed.Response.Output) + toolCalls := 
extractOpenAIRealtimeResponseDoneToolCalls(parsed.Response.Output) + if content == "" && len(toolCalls) == 0 { + return nil + } + + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if content != "" { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(content)} + } + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ToolCalls: toolCalls} + } + + return message +} + +type openAIRealtimeResponseDoneEnvelope struct { + Response struct { + Output []openAIRealtimeResponseDoneOutput `json:"output"` + Usage *openAIRealtimeResponseDoneUsage `json:"usage"` + } `json:"response"` +} + +type openAIRealtimeResponseDoneOutput struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + CallID string `json:"call_id"` + Arguments string `json:"arguments"` + Content []openAIRealtimeResponseDoneBlock `json:"content"` +} + +type openAIRealtimeResponseDoneBlock struct { + Text string `json:"text"` + Transcript string `json:"transcript"` + Refusal string `json:"refusal"` +} + +type openAIRealtimeResponseDoneUsage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails *openAIRealtimeResponseDoneInputTokenUsage `json:"input_token_details"` + OutputTokenDetails *openAIRealtimeResponseDoneOutputTokenUsage `json:"output_token_details"` +} + +type openAIRealtimeResponseDoneInputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` +} + +type openAIRealtimeResponseDoneOutputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + ImageTokens *int `json:"image_tokens"` + CitationTokens *int `json:"citation_tokens"` + NumSearchQueries *int `json:"num_search_queries"` + 
AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +func extractOpenAIRealtimeResponseDoneAssistantText(outputs []openAIRealtimeResponseDoneOutput) string { + var sb strings.Builder + for _, output := range outputs { + if output.Type != "message" { + continue + } + for _, block := range output.Content { + switch { + case strings.TrimSpace(block.Text) != "": + sb.WriteString(block.Text) + case strings.TrimSpace(block.Transcript) != "": + sb.WriteString(block.Transcript) + case strings.TrimSpace(block.Refusal) != "": + sb.WriteString(block.Refusal) + } + } + } + return strings.TrimSpace(sb.String()) +} + +func extractOpenAIRealtimeResponseDoneToolCalls(outputs []openAIRealtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall { + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, output := range outputs { + if output.Type != "function_call" { + continue + } + + name := strings.TrimSpace(output.Name) + if name == "" { + continue + } + + toolType := "function" + id := strings.TrimSpace(output.CallID) + if id == "" { + id = strings.TrimSpace(output.ID) + } + + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr(name), + Arguments: output.Arguments, + }, + } + if id != "" { + toolCall.ID = schemas.Ptr(id) + } + + toolCalls = append(toolCalls, toolCall) + } + return toolCalls +} + +func setRealtimeExtraParam(event *schemas.BifrostRealtimeEvent, key string, value any) { + if event == nil || key == "" || value == nil { + return + } + + switch v := value.(type) { + case string: + if v == "" { + return + } + case *int: + if v == nil { + return + } + case json.RawMessage: + if len(v) == 0 || string(v) == "null" { + return + } + } + + raw, err := json.Marshal(value) + if err != nil { + return + } + if event.ExtraParams == nil { + 
event.ExtraParams = make(map[string]json.RawMessage) + } + event.ExtraParams[key] = raw +} + +func mergeRealtimeExtraParams(out map[string]interface{}, params map[string]json.RawMessage) { + for key, raw := range params { + if len(raw) == 0 { + continue + } + var value any + if err := json.Unmarshal(raw, &value); err != nil { + continue + } + out[key] = value + } +} + +func hasRealtimeExtraParam(params map[string]json.RawMessage, key string) bool { + if params == nil { + return false + } + raw, ok := params[key] + return ok && len(raw) > 0 +} + +func extractRealtimeNestedParams(raw json.RawMessage, knownKeys ...string) map[string]json.RawMessage { + if len(raw) == 0 { + return nil + } + root := map[string]json.RawMessage{} + if err := json.Unmarshal(raw, &root); err != nil { + return nil + } + for _, key := range knownKeys { + delete(root, key) + } + if len(root) == 0 { + return nil + } + return root +} + func isRealtimeDeltaEvent(eventType string) bool { switch eventType { case "response.text.delta", + "response.output_text.delta", "response.audio.delta", + "response.output_audio.delta", "response.audio_transcript.delta", + "response.output_audio_transcript.delta", "conversation.item.input_audio_transcription.delta": return true } diff --git a/core/providers/openai/realtime_test.go b/core/providers/openai/realtime_test.go new file mode 100644 index 0000000000..6b7f76f98f --- /dev/null +++ b/core/providers/openai/realtime_test.go @@ -0,0 +1,561 @@ +package openai + +import ( + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestNormalizeRealtimeClientSecretRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != 
"gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["model"]; ok { + t.Fatal("top-level model should be removed after normalization") + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeSessionsRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := 
normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointSessions, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["session"]; ok { + t.Fatal("legacy sessions endpoint should not forward nested session object") + } + if payload["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("model = %v, want %q", payload["model"], "gpt-4o-realtime-preview") + } + if payload["voice"] != "alloy" { + t.Fatalf("voice = %v, want %q", payload["voice"], "alloy") + } +} + +func TestToProviderRealtimeEventSerializesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + contentIndex, err := json.Marshal(0) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + audioEndMS, err := json.Marshal(640) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.truncate"), + ExtraParams: map[string]json.RawMessage{ + "item_id": json.RawMessage(`"item_123"`), + "content_index": contentIndex, + "audio_end_ms": audioEndMS, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "conversation.item.truncate" { + t.Fatalf("type = %v, want %q", payload["type"], "conversation.item.truncate") + } + if payload["item_id"] != "item_123" { + t.Fatalf("item_id = %v, want 
%q", payload["item_id"], "item_123") + } + if payload["content_index"] != float64(0) { + t.Fatalf("content_index = %v, want 0", payload["content_index"]) + } + if payload["audio_end_ms"] != float64(640) { + t.Fatalf("audio_end_ms = %v, want 640", payload["audio_end_ms"]) + } +} + +func TestToBifrostRealtimeEventParsesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + var itemID string + if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil { + t.Fatalf("json.Unmarshal(item_id) error = %v", err) + } + if itemID != "item_123" { + t.Fatalf("item_id = %q, want %q", itemID, "item_123") + } + var contentIndex int + if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil { + t.Fatalf("json.Unmarshal(content_index) error = %v", err) + } + if contentIndex != 0 { + t.Fatalf("content_index = %d, want 0", contentIndex) + } + var audioEndMS int + if err := json.Unmarshal(event.ExtraParams["audio_end_ms"], &audioEndMS); err != nil { + t.Fatalf("json.Unmarshal(audio_end_ms) error = %v", err) + } + if audioEndMS != 640 { + t.Fatalf("audio_end_ms = %d, want 640", audioEndMS) + } +} + +func TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.input_audio_transcription.completed","event_id":"evt_123","item_id":"item_123","content_index":0,"transcript":"Who are you?"}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var transcript string + if err := json.Unmarshal(event.ExtraParams["transcript"], &transcript); err != nil { + t.Fatalf("json.Unmarshal(transcript) 
error = %v", err) + } + if transcript != "Who are you?" { + t.Fatalf("transcript = %q, want %q", transcript, "Who are you?") + } +} + +func TestToBifrostRealtimeEventParsesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.output_text.delta", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "response_id":"resp_123", + "delta":"hello" + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + if event.Delta == nil || event.Delta.Text != "hello" { + t.Fatalf("Delta = %+v, want text=hello", event.Delta) + } +} + +func TestShouldStartRealtimeTurn(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + tests := []struct { + name string + event *schemas.BifrostRealtimeEvent + want bool + }{ + { + name: "response create starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseCreate}, + want: true, + }, + { + name: "audio buffer committed starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventInputAudioBufferCommitted}, + want: true, + }, + { + name: "response done does not start turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseDone}, + want: false, + }, + { + name: "nil event does not start turn", + event: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := provider.ShouldStartRealtimeTurn(tt.event); got != tt.want { + t.Fatalf("ShouldStartRealtimeTurn() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToProviderRealtimeEventSerializesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + outputIndex := 0 + contentIndex := 0 + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("response.output_text.delta"), + Delta: 
&schemas.RealtimeDelta{ + Text: "hello", + ItemID: "item_123", + OutputIdx: &outputIndex, + ContentIdx: &contentIndex, + ResponseID: "resp_123", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "response.output_text.delta" { + t.Fatalf("type = %v, want response.output_text.delta", payload["type"]) + } + if payload["delta"] != "hello" { + t.Fatalf("delta = %v, want hello", payload["delta"]) + } +} + +func TestToProviderRealtimeEventSerializesSessionID(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionCreated, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + session, ok := payload["session"].(map[string]any) + if !ok { + t.Fatalf("session = %T, want object", payload["session"]) + } + if session["id"] != "sess_123" { + t.Fatalf("session.id = %v, want sess_123", session["id"]) + } +} + +func TestToProviderRealtimeEventSerializesMessageItemStatus(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + content := json.RawMessage(`[{"type":"input_audio","transcript":"hello"}]`) + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.retrieved"), + Item: &schemas.RealtimeItem{ + ID: "item_123", + Type: "message", + Role: "user", + Status: "completed", + Content: content, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := 
json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + item, ok := payload["item"].(map[string]any) + if !ok { + t.Fatalf("item = %T, want object", payload["item"]) + } + if item["status"] != "completed" { + t.Fatalf("item.status = %v, want completed", item["status"]) + } +} + +func TestToBifrostRealtimeEventPreservesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.done", + "event_id":"evt_123", + "response":{ + "id":"resp_123", + "output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}] + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var response map[string]any + if err := json.Unmarshal(event.ExtraParams["response"], &response); err != nil { + t.Fatalf("json.Unmarshal(response) error = %v", err) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseDone, + ExtraParams: map[string]json.RawMessage{ + "response": json.RawMessage(`{"id":"resp_123","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + response, ok := payload["response"].(map[string]any) + if !ok { + t.Fatalf("response = %T, want object", payload["response"]) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func 
TestToBifrostRealtimeEventPreservesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.content_part.added", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "part":{ + "type":"text", + "text":"hello" + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var part map[string]any + if err := json.Unmarshal(event.ExtraParams["part"], &part); err != nil { + t.Fatalf("json.Unmarshal(part) error = %v", err) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseContentPartAdded, + ExtraParams: map[string]json.RawMessage{ + "part": json.RawMessage(`{"type":"text","text":"hello"}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + part, ok := payload["part"].(map[string]any) + if !ok { + t.Fatalf("part = %T, want object", payload["part"]) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestParseRealtimeEventPreservesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + event, err := schemas.ParseRealtimeEvent([]byte(`{ + "type":"session.update", + "session":{ + "type":"realtime", + "model":"gpt-4o-realtime-preview", + "output_modalities":["text"] + } + }`)) + if err != nil { + t.Fatalf("ParseRealtimeEvent() error = %v", err) + } + if event.Session == nil { + t.Fatal("expected session to be parsed") + } + var outputModalities []string + if err := 
json.Unmarshal(event.Session.ExtraParams["output_modalities"], &outputModalities); err != nil { + t.Fatalf("json.Unmarshal(output_modalities) error = %v", err) + } + if len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("output_modalities = %v, want [text]", outputModalities) + } +} + +func TestToProviderRealtimeEventSerializesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + Model: "gpt-4o-realtime-preview", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "output_modalities": json.RawMessage(`["text"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Type string `json:"type"` + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload.Type != "session.update" { + t.Fatalf("type = %q, want %q", payload.Type, "session.update") + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + outputModalities, ok := payload.Session["output_modalities"].([]any) + if !ok || len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("session.output_modalities = %v, want [text]", payload.Session["output_modalities"]) + } +} + +func TestToProviderRealtimeEventOmitsReadOnlySessionFieldsOnSessionUpdate(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "object": 
json.RawMessage(`"realtime.session"`), + "expires_at": json.RawMessage(`1774614381`), + "client_secret": json.RawMessage(`{"value":"secret"}`), + "modalities": json.RawMessage(`["text","audio"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + for _, key := range []string{"id", "object", "expires_at", "client_secret"} { + if _, ok := payload.Session[key]; ok { + t.Fatalf("session.%s unexpectedly present in session.update payload", key) + } + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + if payload.Session["model"] != "gpt-realtime" { + t.Fatalf("session.model = %v, want gpt-realtime", payload.Session["model"]) + } +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 34f120432c..f50eed4210 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -228,6 +228,11 @@ const ( BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request + BifrostContextKeyParentRequestID BifrostContextKey = "bifrost-parent-request-id" // string (parent linkage for grouped request logs like realtime turns) + BifrostContextKeyRealtimeSessionID BifrostContextKey = "bifrost-realtime-session-id" // string + BifrostContextKeyRealtimeProviderSessionID BifrostContextKey = "bifrost-realtime-provider-session-id" 
// string + BifrostContextKeyRealtimeSource BifrostContextKey = "bifrost-realtime-source" // string ("ei" or "lm") + BifrostContextKeyRealtimeEventType BifrostContextKey = "bifrost-realtime-event-type" // string BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway) BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) diff --git a/core/schemas/plugin.go b/core/schemas/plugin.go index f9ea18a4b3..5e0d068718 100644 --- a/core/schemas/plugin.go +++ b/core/schemas/plugin.go @@ -313,9 +313,15 @@ type ObservabilityPlugin interface { // // Implementations should: // - Convert the trace to their backend's format - // - Send the trace to the backend (can be async) + // - Send the trace to the backend (can be async, but see retention note below) // - Handle errors gracefully (log and continue) // // The context passed is a fresh background context, not the request context. + // + // Retention: implementations MUST NOT retain the *Trace pointer after Inject + // returns. The caller releases the trace back to a sync.Pool immediately after + // Inject completes, so any background goroutine that still references it will + // race with pool reuse. If a plugin needs to forward the trace asynchronously, + // it must copy the data it needs before returning. 
Inject(ctx context.Context, trace *Trace) error } diff --git a/core/schemas/realtime.go b/core/schemas/realtime.go index e1e20d7bf4..ec4fd6789d 100644 --- a/core/schemas/realtime.go +++ b/core/schemas/realtime.go @@ -19,33 +19,75 @@ const ( // Server-to-client event types (received from the provider, forwarded to client) const ( - RTEventSessionCreated RealtimeEventType = "session.created" - RTEventSessionUpdated RealtimeEventType = "session.updated" - RTEventConversationCreated RealtimeEventType = "conversation.created" - RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" - RTEventConversationItemDone RealtimeEventType = "conversation.item.done" - RTEventResponseCreated RealtimeEventType = "response.created" - RTEventResponseDone RealtimeEventType = "response.done" - RTEventResponseTextDelta RealtimeEventType = "response.text.delta" - RTEventResponseTextDone RealtimeEventType = "response.text.done" - RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" - RTEventResponseAudioDone RealtimeEventType = "response.audio.done" - RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" - RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" - RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" - RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" - RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" - RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" - RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" - RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" - RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" - RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" - 
RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" - RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" - RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" - RTEventError RealtimeEventType = "error" + RTEventSessionCreated RealtimeEventType = "session.created" + RTEventSessionUpdated RealtimeEventType = "session.updated" + RTEventConversationCreated RealtimeEventType = "conversation.created" + RTEventConversationItemAdded RealtimeEventType = "conversation.item.added" + RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" + RTEventConversationItemRetrieved RealtimeEventType = "conversation.item.retrieved" + RTEventConversationItemDone RealtimeEventType = "conversation.item.done" + RTEventResponseCreated RealtimeEventType = "response.created" + RTEventResponseDone RealtimeEventType = "response.done" + RTEventResponseTextDelta RealtimeEventType = "response.text.delta" + RTEventResponseTextDone RealtimeEventType = "response.text.done" + RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" + RTEventResponseAudioDone RealtimeEventType = "response.audio.done" + RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" + RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" + RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" + RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" + RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" + RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" + RTEventRateLimitsUpdated RealtimeEventType = "rate_limits.updated" + RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" + RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" + 
RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" + RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" + RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" + RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" + RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" + RTEventError RealtimeEventType = "error" ) +// IsRealtimeConversationItemEventType reports whether the event carries a +// canonical conversation item payload after provider translation. +func IsRealtimeConversationItemEventType(eventType RealtimeEventType) bool { + switch eventType { + case RTEventConversationItemCreate, + RTEventConversationItemAdded, + RTEventConversationItemCreated, + RTEventConversationItemRetrieved, + RTEventConversationItemDone: + return true + default: + return false + } +} + +// IsRealtimeUserInputEvent reports whether the event represents a finalized +// user input item in the canonical Bifrost realtime schema. +func IsRealtimeUserInputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Role == "user" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeToolOutputEvent reports whether the event represents a finalized +// tool output item in the canonical Bifrost realtime schema. +func IsRealtimeToolOutputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Type == "function_call_output" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeInputTranscriptEvent reports whether the event carries a finalized +// input-audio transcript in the canonical Bifrost realtime schema. 
+func IsRealtimeInputTranscriptEvent(event *BifrostRealtimeEvent) bool { + return event != nil && event.Type == RTEventInputAudioTransCompleted +} + // BifrostRealtimeEvent is the unified Bifrost envelope for all Realtime events. // Provider converters translate between this format and the provider-native protocol. type BifrostRealtimeEvent struct { @@ -58,36 +100,42 @@ type BifrostRealtimeEvent struct { Audio []byte `json:"audio,omitempty"` Error *RealtimeError `json:"error,omitempty"` + // ExtraParams preserves provider-specific top-level event fields that are not + // promoted into the common Bifrost schema. + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` + // RawData preserves the original provider event for pass-through or debugging. RawData json.RawMessage `json:"raw_data,omitempty"` } // RealtimeSession describes session configuration for the Realtime connection. type RealtimeSession struct { - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Instructions string `json:"instructions,omitempty"` - Voice string `json:"voice,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` - TurnDetection json.RawMessage `json:"turn_detection,omitempty"` - InputAudioFormat string `json:"input_audio_format,omitempty"` - OutputAudioType string `json:"output_audio_type,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Instructions string `json:"instructions,omitempty"` + Voice string `json:"voice,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` + TurnDetection json.RawMessage `json:"turn_detection,omitempty"` + InputAudioFormat string `json:"input_audio_format,omitempty"` 
+ OutputAudioType string `json:"output_audio_type,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeItem represents a conversation item in the Realtime protocol. type RealtimeItem struct { - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Role string `json:"role,omitempty"` - Status string `json:"status,omitempty"` - Content json.RawMessage `json:"content,omitempty"` - Name string `json:"name,omitempty"` - CallID string `json:"call_id,omitempty"` - Arguments string `json:"arguments,omitempty"` - Output string `json:"output,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Role string `json:"role,omitempty"` + Status string `json:"status,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + CallID string `json:"call_id,omitempty"` + Arguments string `json:"arguments,omitempty"` + Output string `json:"output,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeDelta carries incremental content for streaming events. @@ -103,10 +151,28 @@ type RealtimeDelta struct { // RealtimeError describes an error from the Realtime API. type RealtimeError struct { - Type string `json:"type,omitempty"` - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` - Param string `json:"param,omitempty"` + Type string `json:"type,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param string `json:"param,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` +} + +// RealtimeSessionEndpointType identifies the public ephemeral-token endpoint +// shape the client called so providers can preserve versioned behavior. 
+type RealtimeSessionEndpointType string + +const ( + RealtimeSessionEndpointClientSecrets RealtimeSessionEndpointType = "client_secrets" + RealtimeSessionEndpointSessions RealtimeSessionEndpointType = "sessions" +) + +// RealtimeSessionRoute describes a provider-registered public route for +// ephemeral-token creation. +type RealtimeSessionRoute struct { + Path string + EndpointType RealtimeSessionEndpointType + DefaultProvider ModelProvider } // RealtimeProvider is an optional interface that providers can implement to @@ -116,6 +182,129 @@ type RealtimeProvider interface { SupportsRealtimeAPI() bool RealtimeWebSocketURL(key Key, model string) string RealtimeHeaders(key Key) map[string]string + // SupportsRealtimeWebRTC reports whether the provider supports WebRTC SDP exchange. + SupportsRealtimeWebRTC() bool + // ExchangeRealtimeWebRTCSDP performs the provider-specific SDP signaling exchange. + // The provider owns the HTTP specifics (URL, headers, body format). + // session may be nil if the signaling format doesn't include session config. + ExchangeRealtimeWebRTCSDP(ctx *BifrostContext, key Key, model string, sdp string, session json.RawMessage) (string, *BifrostError) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*BifrostRealtimeEvent, error) ToProviderRealtimeEvent(bifrostEvent *BifrostRealtimeEvent) (json.RawMessage, error) + // ShouldStartRealtimeTurn reports whether the canonical client-side event + // should start pre-hooks. Providers without an explicit turn-start signal + // return false and rely on finalize-time fallback hooks. + ShouldStartRealtimeTurn(event *BifrostRealtimeEvent) bool + // RealtimeTurnFinalEvent returns the canonical provider event that completes + // a turn and should trigger post-hooks. 
+ RealtimeTurnFinalEvent() RealtimeEventType + RealtimeWebRTCDataChannelLabel() string + RealtimeWebSocketSubprotocol() string + ShouldForwardRealtimeEvent(event *BifrostRealtimeEvent) bool + ShouldAccumulateRealtimeOutput(eventType RealtimeEventType) bool +} + +// RealtimeLegacyWebRTCProvider is an optional interface for providers that +// support the beta WebRTC handshake (e.g., OpenAI's /v1/realtime). +// Only checked for legacy integration routes via type assertion. +// Takes SDP offer + optional session JSON, same as ExchangeRealtimeWebRTCSDP +// but targets the provider's legacy/beta endpoint. +type RealtimeLegacyWebRTCProvider interface { + ExchangeLegacyRealtimeWebRTCSDP(ctx *BifrostContext, key Key, sdp string, session json.RawMessage, model string) (string, *BifrostError) +} + +// RealtimeUsageExtractor lets providers parse terminal-turn usage/output from +// their native wire payloads without coupling handlers to a specific protocol. +type RealtimeUsageExtractor interface { + ExtractRealtimeTurnUsage(terminalEventRaw []byte) *BifrostLLMUsage + ExtractRealtimeTurnOutput(terminalEventRaw []byte) *ChatMessage +} + +// RealtimeSessionProvider is an optional interface for providers that can mint +// short-lived client secrets for browser/client-side Realtime connections. +// Checked via type assertion: provider.(RealtimeSessionProvider). +type RealtimeSessionProvider interface { + CreateRealtimeClientSecret(ctx *BifrostContext, key Key, endpointType RealtimeSessionEndpointType, rawRequest json.RawMessage) (*BifrostPassthroughResponse, *BifrostError) +} + +// ParseRealtimeEvent decodes a client/provider realtime event while preserving +// unknown top-level fields in ExtraParams for provider-specific round-tripping. 
+func ParseRealtimeEvent(raw []byte) (*BifrostRealtimeEvent, error) { + type realtimeEventAlias struct { + Type RealtimeEventType `json:"type"` + EventID string `json:"event_id,omitempty"` + Session *RealtimeSession `json:"session,omitempty"` + Item *RealtimeItem `json:"item,omitempty"` + Delta *RealtimeDelta `json:"delta,omitempty"` + Audio []byte `json:"audio,omitempty"` + Error *RealtimeError `json:"error,omitempty"` + } + + var alias realtimeEventAlias + if err := Unmarshal(raw, &alias); err != nil { + return nil, err + } + + event := &BifrostRealtimeEvent{ + Type: alias.Type, + EventID: alias.EventID, + Session: alias.Session, + Item: alias.Item, + Delta: alias.Delta, + Audio: alias.Audio, + Error: alias.Error, + } + + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, err + } + savedSession := root["session"] + savedItem := root["item"] + savedError := root["error"] + for _, key := range []string{"type", "event_id", "session", "item", "delta", "audio", "error", "raw_data"} { + delete(root, key) + } + if len(root) > 0 { + event.ExtraParams = root + } + if event.Session != nil { + var sessionRoot map[string]json.RawMessage + if len(savedSession) > 0 && Unmarshal(savedSession, &sessionRoot) == nil { + for _, key := range []string{ + "id", "model", "modalities", "instructions", "voice", "temperature", + "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools", + } { + delete(sessionRoot, key) + } + if len(sessionRoot) > 0 { + event.Session.ExtraParams = sessionRoot + } + } + } + if event.Item != nil { + var itemRoot map[string]json.RawMessage + if len(savedItem) > 0 && Unmarshal(savedItem, &itemRoot) == nil { + for _, key := range []string{ + "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output", + } { + delete(itemRoot, key) + } + if len(itemRoot) > 0 { + event.Item.ExtraParams = itemRoot + } + } + } + if event.Error != nil { + var errorRoot 
map[string]json.RawMessage + if len(savedError) > 0 && Unmarshal(savedError, &errorRoot) == nil { + for _, key := range []string{"type", "code", "message", "param"} { + delete(errorRoot, key) + } + if len(errorRoot) > 0 { + event.Error.ExtraParams = errorRoot + } + } + } + + return event, nil } diff --git a/core/schemas/realtime_client_secrets.go b/core/schemas/realtime_client_secrets.go new file mode 100644 index 0000000000..ae97b573a1 --- /dev/null +++ b/core/schemas/realtime_client_secrets.go @@ -0,0 +1,66 @@ +package schemas + +import ( + "bytes" + "encoding/json" + "strings" +) + +// ParseRealtimeClientSecretBody parses a realtime client-secret request body +// into a mutable raw JSON map while preserving unknown fields. +func ParseRealtimeClientSecretBody(raw json.RawMessage) (map[string]json.RawMessage, *BifrostError) { + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, NewRealtimeClientSecretBodyError(400, "invalid_request_error", "invalid JSON body", err) + } + return root, nil +} + +// ExtractRealtimeClientSecretModel extracts the model from either session.model +// or the legacy top-level model field. 
+func ExtractRealtimeClientSecretModel(root map[string]json.RawMessage) (string, *BifrostError) { + if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 && !bytes.Equal(sessionJSON, []byte("null")) { + var session map[string]json.RawMessage + if err := Unmarshal(sessionJSON, &session); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session must be an object", err) + } + if modelJSON, ok := session["model"]; ok { + var sessionModel string + if err := Unmarshal(modelJSON, &sessionModel); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model must be a string", err) + } + if strings.TrimSpace(sessionModel) != "" { + return strings.TrimSpace(sessionModel), nil + } + } + } + + if modelJSON, ok := root["model"]; ok { + var model string + if err := Unmarshal(modelJSON, &model); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "model must be a string", err) + } + if strings.TrimSpace(model) != "" { + return strings.TrimSpace(model), nil + } + } + + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model or model is required", nil) +} + +// NewRealtimeClientSecretBodyError builds a standard invalid-request style error +// for HTTP realtime client-secret request parsing/validation. 
+func NewRealtimeClientSecretBodyError(status int, errorType, message string, err error) *BifrostError { + return &BifrostError{ + IsBifrostError: false, + StatusCode: Ptr(status), + Error: &ErrorField{ + Type: Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: BifrostErrorExtraFields{ + RequestType: RealtimeRequest, + }, + } +} diff --git a/core/schemas/realtime_client_secrets_test.go b/core/schemas/realtime_client_secrets_test.go new file mode 100644 index 0000000000..dfd8f8b1d3 --- /dev/null +++ b/core/schemas/realtime_client_secrets_test.go @@ -0,0 +1,40 @@ +package schemas + +import ( + "encoding/json" + "testing" +) + +func TestExtractRealtimeClientSecretModel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "openai/gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "openai/gpt-4o-realtime-preview") + } +} + +func TestExtractRealtimeClientSecretModelFallbackTopLevel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"model":"gpt-4o-realtime-preview"}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } +} diff --git a/core/schemas/realtime_test.go b/core/schemas/realtime_test.go new file mode 100644 index 0000000000..69e9e403c8 --- /dev/null +++ b/core/schemas/realtime_test.go @@ -0,0 +1,68 @@ +package schemas + +import "testing" + +func 
TestIsRealtimeConversationItemEventType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + eventType RealtimeEventType + want bool + }{ + {name: "create", eventType: RTEventConversationItemCreate, want: true}, + {name: "added", eventType: RTEventConversationItemAdded, want: true}, + {name: "created", eventType: RTEventConversationItemCreated, want: true}, + {name: "retrieved", eventType: RTEventConversationItemRetrieved, want: true}, + {name: "done", eventType: RTEventConversationItemDone, want: true}, + {name: "response done", eventType: RTEventResponseDone, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := IsRealtimeConversationItemEventType(tt.eventType); got != tt.want { + t.Fatalf("IsRealtimeConversationItemEventType(%q) = %v, want %v", tt.eventType, got, tt.want) + } + }) + } +} + +func TestRealtimeCanonicalEventClassifiers(t *testing.T) { + t.Parallel() + + userEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemAdded, + Item: &RealtimeItem{ + Role: "user", + Type: "message", + }, + } + if !IsRealtimeUserInputEvent(userEvent) { + t.Fatal("expected conversation.item.added user event to be classified as realtime user input") + } + if IsRealtimeToolOutputEvent(userEvent) { + t.Fatal("did not expect conversation.item.added user event to be classified as realtime tool output") + } + + toolEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemRetrieved, + Item: &RealtimeItem{ + Type: "function_call_output", + }, + } + if !IsRealtimeToolOutputEvent(toolEvent) { + t.Fatal("expected function_call_output item to be classified as realtime tool output") + } + if IsRealtimeUserInputEvent(toolEvent) { + t.Fatal("did not expect function_call_output item to be classified as realtime user input") + } + + transcriptEvent := &BifrostRealtimeEvent{Type: RTEventInputAudioTransCompleted} + if !IsRealtimeInputTranscriptEvent(transcriptEvent) { + t.Fatal("expected input audio 
transcription completion to be classified as transcript event") + } + if IsRealtimeInputTranscriptEvent(&BifrostRealtimeEvent{Type: RTEventInputAudioTransDelta}) { + t.Fatal("did not expect input audio transcription delta to be classified as transcript event") + } +} diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go index 3e5bcbf46f..23c5d4cc4c 100644 --- a/core/schemas/tracer.go +++ b/core/schemas/tracer.go @@ -117,6 +117,10 @@ type Tracer interface { // Thread-safe. Should be called after plugin hooks complete, before trace completion. AttachPluginLogs(traceID string, logs []PluginLogEntry) + // CompleteAndFlushTrace ends a trace, exports it to observability plugins, and + // releases the trace resources. Used by transports that bypass normal HTTP trace completion. + CompleteAndFlushTrace(traceID string) + // Stop releases resources associated with the tracer. // Should be called during shutdown to stop background goroutines. Stop() @@ -185,6 +189,9 @@ func (n *NoOpTracer) ProcessStreamingChunk(_ string, _ bool, _ *BifrostResponse, // AttachPluginLogs does nothing. func (n *NoOpTracer) AttachPluginLogs(_ string, _ []PluginLogEntry) {} +// CompleteAndFlushTrace does nothing. +func (n *NoOpTracer) CompleteAndFlushTrace(_ string) {} + // Stop does nothing. func (n *NoOpTracer) Stop() {} diff --git a/framework/logstore/matviews.go b/framework/logstore/matviews.go index bf51748693..0ccc4a8f28 100644 --- a/framework/logstore/matviews.go +++ b/framework/logstore/matviews.go @@ -164,7 +164,8 @@ func startMatViewRefresher(ctx context.Context, db *gorm.DB, interval time.Durat // mv_logs_hourly. Per-row filters (content search, metadata, numeric ranges) // require the raw logs table. 
func canUseMatView(f SearchFilters) bool { - return f.ContentSearch == "" && + return f.ParentRequestID == "" && + f.ContentSearch == "" && len(f.MetadataFilters) == 0 && len(f.RoutingEngineUsed) == 0 && f.MinLatency == nil && f.MaxLatency == nil && diff --git a/framework/logstore/migrations.go b/framework/logstore/migrations.go index 65ea54e9ac..244f844f9a 100644 --- a/framework/logstore/migrations.go +++ b/framework/logstore/migrations.go @@ -2038,6 +2038,11 @@ var performanceIndexes = []performanceIndexDef{ name: "idx_logs_alias", sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_alias ON logs(alias)", }, + { + table: "logs", + name: "idx_logs_parent_request_id", + sql: "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_logs_parent_request_id ON logs(parent_request_id) WHERE parent_request_id IS NOT NULL", + }, } // ensurePerformanceIndexes checks whether each performance GIN index exists and is diff --git a/framework/logstore/rdb.go b/framework/logstore/rdb.go index 39bea8720c..a14aa61580 100644 --- a/framework/logstore/rdb.go +++ b/framework/logstore/rdb.go @@ -29,6 +29,7 @@ func isValidMetadataKey(key string) bool { } const bulkUpdateCostChunkSize = 500 +const sessionLogPageLimit = 50 const ( // defaultMaxQueryLimit is a safety cap for unbounded queries (FindAll, FindAllDistinct). @@ -88,6 +89,9 @@ func (s *RDBLogStore) applyFilters(baseQuery *gorm.DB, filters SearchFilters) *g if len(filters.Objects) > 0 { baseQuery = baseQuery.Where("object_type IN ?", filters.Objects) } + if filters.ParentRequestID != "" { + baseQuery = baseQuery.Where("parent_request_id = ?", filters.ParentRequestID) + } if len(filters.SelectedKeyIDs) > 0 { baseQuery = baseQuery.Where("selected_key_id IN ?", filters.SelectedKeyIDs) } @@ -444,9 +448,167 @@ func (s *RDBLogStore) SearchLogs(ctx context.Context, filters SearchFilters, pag }, nil } +// GetSessionLogs returns paginated logs for a single parent_request_id session. 
+func (s *RDBLogStore) GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error) { + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + + limit := pagination.Limit + if limit <= 0 || limit > sessionLogPageLimit { + limit = sessionLogPageLimit + } + pagination.Limit = limit + if pagination.Offset < 0 { + pagination.Offset = 0 + } + + pagination.SortBy = "timestamp" + orderDir := "ASC" + if pagination.Order == "desc" { + orderDir = "DESC" + } + orderClause := "timestamp " + orderDir + ", id " + orderDir + + baseQuery := s.db.WithContext(ctx).Model(&Log{}).Where("parent_request_id = ?", sessionID) + + var ( + totalCount int64 + logs []Log + ) + + g, gCtx := errgroup.WithContext(ctx) + + g.Go(func() error { + return s.db.WithContext(gCtx).Model(&Log{}).Where("parent_request_id = ?", sessionID).Count(&totalCount).Error + }) + + g.Go(func() error { + dataQuery := baseQuery.Session(&gorm.Session{}). + WithContext(gCtx). + Order(orderClause). + Select(s.listSelectColumns()). + Limit(limit) + if pagination.Offset > 0 { + dataQuery = dataQuery.Offset(pagination.Offset) + } + err := dataQuery.Find(&logs).Error + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return err + }) + + if err := g.Wait(); err != nil { + return nil, err + } + + pagination.TotalCount = totalCount + returnedCount := len(logs) + return &SessionDetailResult{ + SessionID: sessionID, + Logs: logs, + Pagination: pagination, + Count: totalCount, + ReturnedCount: returnedCount, + HasMore: int64(pagination.Offset+returnedCount) < totalCount, + }, nil +} + +// GetSessionSummary returns aggregate totals for a single parent_request_id session. 
+func (s *RDBLogStore) GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error) { + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + + var ( + count int64 + totalCost float64 + totalTokens int64 + startedAt string + latestAt string + startedRaw any + latestRaw any + ) + + // Single aggregate select keeps Count/SUM/MIN/MAX consistent against the same row snapshot + // and halves the round trips compared to running Count and the aggregate row in parallel. + row := s.db.WithContext(ctx). + Model(&Log{}). + Where("parent_request_id = ?", sessionID). + Select("COUNT(*) AS count, COALESCE(SUM(cost), 0) AS total_cost, COALESCE(SUM(total_tokens), 0) AS total_tokens, MIN(timestamp) AS started_at, MAX(timestamp) AS latest_at"). + Row() + + if err := row.Scan(&count, &totalCost, &totalTokens, &startedRaw, &latestRaw); err != nil { + return nil, err + } + + startedAt = normalizeAggregateTimestamp(startedRaw) + latestAt = normalizeAggregateTimestamp(latestRaw) + + durationMs := int64(0) + if startedAt != "" && latestAt != "" { + if startedTime, err := time.Parse(time.RFC3339Nano, startedAt); err == nil { + if latestTime, err := time.Parse(time.RFC3339Nano, latestAt); err == nil { + durationMs = latestTime.Sub(startedTime).Milliseconds() + if durationMs < 0 { + durationMs = 0 + } + } + } + } + + return &SessionSummaryResult{ + SessionID: sessionID, + Count: count, + TotalCost: totalCost, + TotalTokens: totalTokens, + StartedAt: startedAt, + LatestAt: latestAt, + DurationMs: durationMs, + }, nil +} + +func normalizeAggregateTimestamp(value any) string { + switch v := value.(type) { + case nil: + return "" + case time.Time: + return v.UTC().Format(time.RFC3339Nano) + case []byte: + return normalizeAggregateTimestamp(string(v)) + case string: + raw := strings.TrimSpace(v) + if raw == "" { + return "" + } + layouts := []string{ + time.RFC3339Nano, + time.RFC3339, + "2006-01-02 
15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999Z07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05.999999999", + "2006-01-02T15:04:05", + } + for _, layout := range layouts { + if parsed, err := time.Parse(layout, raw); err == nil { + return parsed.UTC().Format(time.RFC3339Nano) + } + } + return raw + default: + return fmt.Sprint(v) + } +} + // listSelectColumns returns a SELECT clause for list queries that omits large // output/detail TEXT columns and uses SQL JSON functions to extract only the // last element from input_history and responses_input_history arrays. +// +// Realtime turn rows are kept intact because the logs table renders them as a +// combined Tool/User/Assistant summary and needs the full turn context. func (s *RDBLogStore) listSelectColumns() string { baseCols := strings.Join([]string{ "id", "parent_request_id", "timestamp", "object_type", "provider", "model", "alias", @@ -462,25 +624,35 @@ func (s *RDBLogStore) listSelectColumns() string { "created_at", }, ", ") - var inputHistoryExpr, responsesInputExpr string + var inputHistoryExpr, responsesInputExpr, outputMessageExpr string switch s.db.Dialector.Name() { case "postgres": - inputHistoryExpr = `CASE WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' + inputHistoryExpr = `CASE + WHEN object_type = 'realtime.turn' THEN input_history + WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' THEN jsonb_build_array(input_history::jsonb->-1)::text ELSE input_history END AS input_history` - responsesInputExpr = `CASE WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' + responsesInputExpr = `CASE + WHEN object_type = 'realtime.turn' THEN responses_input_history + WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' THEN jsonb_build_array(responses_input_history::jsonb->-1)::text 
ELSE responses_input_history END AS responses_input_history` + outputMessageExpr = `CASE WHEN object_type = 'realtime.turn' THEN output_message ELSE NULL END AS output_message` default: // sqlite - inputHistoryExpr = `CASE WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' + inputHistoryExpr = `CASE + WHEN object_type = 'realtime.turn' THEN input_history + WHEN input_history IS NOT NULL AND input_history != '' AND input_history != '[]' THEN json_array(json_extract(input_history, '$[' || (json_array_length(input_history) - 1) || ']')) ELSE input_history END AS input_history` - responsesInputExpr = `CASE WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' + responsesInputExpr = `CASE + WHEN object_type = 'realtime.turn' THEN responses_input_history + WHEN responses_input_history IS NOT NULL AND responses_input_history != '' AND responses_input_history != '[]' THEN json_array(json_extract(responses_input_history, '$[' || (json_array_length(responses_input_history) - 1) || ']')) ELSE responses_input_history END AS responses_input_history` + outputMessageExpr = `CASE WHEN object_type = 'realtime.turn' THEN output_message ELSE NULL END AS output_message` } - return baseCols + ", " + inputHistoryExpr + ", " + responsesInputExpr + return baseCols + ", " + inputHistoryExpr + ", " + responsesInputExpr + ", " + outputMessageExpr } // GetStats calculates statistics for logs matching the given filters. 
diff --git a/framework/logstore/store.go b/framework/logstore/store.go index 27a2d12f7d..3d6aedf711 100644 --- a/framework/logstore/store.go +++ b/framework/logstore/store.go @@ -30,6 +30,8 @@ type LogStore interface { FindAllDistinct(ctx context.Context, query any, fields ...string) ([]*Log, error) HasLogs(ctx context.Context) (bool, error) SearchLogs(ctx context.Context, filters SearchFilters, pagination PaginationOptions) (*SearchResult, error) + GetSessionLogs(ctx context.Context, sessionID string, pagination PaginationOptions) (*SessionDetailResult, error) + GetSessionSummary(ctx context.Context, sessionID string) (*SessionSummaryResult, error) GetStats(ctx context.Context, filters SearchFilters) (*SearchStats, error) GetHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*HistogramResult, error) GetTokenHistogram(ctx context.Context, filters SearchFilters, bucketSizeSeconds int64) (*TokenHistogramResult, error) diff --git a/framework/logstore/tables.go b/framework/logstore/tables.go index edf0f879d0..c37d51d6f2 100644 --- a/framework/logstore/tables.go +++ b/framework/logstore/tables.go @@ -34,6 +34,7 @@ type SearchFilters struct { Aliases []string `json:"aliases,omitempty"` Status []string `json:"status,omitempty"` Objects []string `json:"objects,omitempty"` // For filtering by request type (chat.completion, text.completion, embedding) + ParentRequestID string `json:"parent_request_id,omitempty"` SelectedKeyIDs []string `json:"selected_key_ids,omitempty"` VirtualKeyIDs []string `json:"virtual_key_ids,omitempty"` RoutingRuleIDs []string `json:"routing_rule_ids,omitempty"` @@ -68,6 +69,25 @@ type SearchResult struct { HasLogs bool `json:"has_logs"` } +type SessionDetailResult struct { + SessionID string `json:"session_id"` + Logs []Log `json:"logs"` + Pagination PaginationOptions `json:"pagination"` + Count int64 `json:"count"` + ReturnedCount int `json:"returned_count"` + HasMore bool `json:"has_more"` +} + +type 
SessionSummaryResult struct { + SessionID string `json:"session_id"` + Count int64 `json:"count"` + TotalCost float64 `json:"total_cost"` + TotalTokens int64 `json:"total_tokens"` + StartedAt string `json:"started_at,omitempty"` + LatestAt string `json:"latest_at,omitempty"` + DurationMs int64 `json:"duration_ms"` +} + type SearchStats struct { TotalRequests int64 `json:"total_requests"` SuccessRate float64 `json:"success_rate"` // Percentage of successful requests @@ -80,7 +100,7 @@ type SearchStats struct { // This is the GORM model with appropriate tags type Log struct { ID string `gorm:"primaryKey;type:varchar(255)" json:"id"` - ParentRequestID *string `gorm:"type:varchar(255)" json:"parent_request_id"` + ParentRequestID *string `gorm:"type:varchar(255);index" json:"parent_request_id"` Timestamp time.Time `gorm:"index;index:idx_logs_ts_provider_status,priority:1;not null" json:"timestamp"` Object string `gorm:"type:varchar(255);index;not null;column:object_type" json:"object"` // text.completion, chat.completion, or embedding Provider string `gorm:"type:varchar(255);index;index:idx_logs_ts_provider_status,priority:2;not null" json:"provider"` diff --git a/framework/modelcatalog/pricing.go b/framework/modelcatalog/pricing.go index 535aed226d..6942a3ad05 100644 --- a/framework/modelcatalog/pricing.go +++ b/framework/modelcatalog/pricing.go @@ -115,7 +115,7 @@ func (mc *ModelCatalog) calculateBaseCost(result *schemas.BifrostResponse, scope // Route to the appropriate compute function switch requestType { - case schemas.ChatCompletionRequest, schemas.TextCompletionRequest, schemas.ResponsesRequest: + case schemas.ChatCompletionRequest, schemas.TextCompletionRequest, schemas.ResponsesRequest, schemas.RealtimeRequest: return computeTextCost(pricing, input.usage) case schemas.EmbeddingRequest: return computeEmbeddingCost(pricing, input.usage) @@ -833,7 +833,7 @@ func (mc *ModelCatalog) getBasePricing(model, provider string, requestType schem } // Lookup in chat if 
responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("secondary lookup failed, trying vertex provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey(model, "vertex", normalizeRequestType(schemas.ChatCompletionRequest))] if ok { @@ -853,7 +853,7 @@ func (mc *ModelCatalog) getBasePricing(model, provider string, requestType schem } // Lookup in chat if responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("secondary lookup failed, trying vertex provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey(modelWithoutProvider, "vertex", normalizeRequestType(schemas.ChatCompletionRequest))] if ok { @@ -873,7 +873,7 @@ func (mc *ModelCatalog) getBasePricing(model, provider string, requestType schem } // Lookup in chat if responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("secondary lookup failed, trying chat provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey("anthropic."+model, provider, normalizeRequestType(schemas.ChatCompletionRequest))] if ok { @@ -884,7 +884,7 @@ func (mc *ModelCatalog) getBasePricing(model, provider string, requestType schem } // Lookup in chat if responses not found - if requestType == schemas.ResponsesRequest || requestType == schemas.ResponsesStreamRequest { + if requestType == schemas.ResponsesRequest || requestType 
== schemas.ResponsesStreamRequest || requestType == schemas.RealtimeRequest { mc.logger.Debug("primary lookup failed, trying chat provider for the same model in chat completion") pricing, ok = mc.pricingData[makeKey(model, provider, normalizeRequestType(schemas.ChatCompletionRequest))] if ok { diff --git a/framework/modelcatalog/pricing_test.go b/framework/modelcatalog/pricing_test.go index d273f32cfb..4fc185e7b4 100644 --- a/framework/modelcatalog/pricing_test.go +++ b/framework/modelcatalog/pricing_test.go @@ -1192,6 +1192,14 @@ func TestGetPricing_ResponsesStreamFallsBackToChat(t *testing.T) { assert.Equal(t, 0.000005, derefF(p.InputCostPerToken)) } +func TestGetPricing_RealtimeFallsBackToChat(t *testing.T) { + mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ + makeKey("gpt-4o", "openai", "chat"): chatPricing(0.000005, 0.000015), + }) + p := mc.resolvePricing("openai", "gpt-4o", "", schemas.RealtimeRequest, PricingLookupScopes{Provider: "openai"}) + assert.Equal(t, 0.000005, derefF(p.InputCostPerToken)) +} + func TestGetPricing_GeminiResponsesFallsBackToVertexChat(t *testing.T) { mc := testCatalogWithPricing(map[string]configstoreTables.TableModelPricing{ makeKey("gemini-2.0-flash", "vertex", "chat"): chatPricing(0.0000001, 0.0000004), @@ -1256,6 +1264,7 @@ func TestNormalizeStreamRequestType(t *testing.T) { {schemas.TranscriptionStreamRequest, schemas.TranscriptionRequest}, {schemas.ImageGenerationStreamRequest, schemas.ImageGenerationRequest}, {schemas.ImageEditStreamRequest, schemas.ImageEditRequest}, + {schemas.RealtimeRequest, schemas.RealtimeRequest}, // realtime is its own base type {schemas.ChatCompletionRequest, schemas.ChatCompletionRequest}, // non-stream unchanged {schemas.EmbeddingRequest, schemas.EmbeddingRequest}, // non-stream unchanged } diff --git a/framework/modelcatalog/utils.go b/framework/modelcatalog/utils.go index dd6d270f60..6fae8251f9 100644 --- a/framework/modelcatalog/utils.go +++ 
b/framework/modelcatalog/utils.go @@ -35,7 +35,7 @@ func normalizeRequestType(reqType schemas.RequestType) string { baseType = "completion" case schemas.ChatCompletionRequest, schemas.ChatCompletionStreamRequest: baseType = "chat" - case schemas.ResponsesRequest, schemas.ResponsesStreamRequest: + case schemas.ResponsesRequest, schemas.ResponsesStreamRequest, schemas.RealtimeRequest: baseType = "responses" case schemas.EmbeddingRequest: baseType = "embedding" @@ -66,6 +66,8 @@ func normalizeStreamRequestType(rt schemas.RequestType) schemas.RequestType { return schemas.ChatCompletionRequest case schemas.ResponsesStreamRequest: return schemas.ResponsesRequest + case schemas.RealtimeRequest: + return schemas.RealtimeRequest case schemas.SpeechStreamRequest: return schemas.SpeechRequest case schemas.TranscriptionStreamRequest: diff --git a/framework/tracing/tracer.go b/framework/tracing/tracer.go index c44f513428..55fd8873d3 100644 --- a/framework/tracing/tracer.go +++ b/framework/tracing/tracer.go @@ -3,6 +3,9 @@ package tracing import ( "context" + "strings" + "sync" + "sync/atomic" "time" "github.com/maximhq/bifrost/core/schemas" @@ -18,6 +21,9 @@ type Tracer struct { store *TraceStore accumulator *streaming.Accumulator pricingManager *modelcatalog.ModelCatalog + logger schemas.Logger + obsPlugins atomic.Pointer[[]schemas.ObservabilityPlugin] + flushWG sync.WaitGroup } // NewTracer creates a new Tracer wrapping the given TraceStore. @@ -28,9 +34,19 @@ func NewTracer(store *TraceStore, pricingManager *modelcatalog.ModelCatalog, log store: store, accumulator: streaming.NewAccumulator(pricingManager, logger), pricingManager: pricingManager, + logger: logger, + obsPlugins: atomic.Pointer[[]schemas.ObservabilityPlugin]{}, } } +// SetObservabilityPlugins updates the plugins that receive completed traces. 
+func (t *Tracer) SetObservabilityPlugins(obsPlugins []schemas.ObservabilityPlugin) { + if t == nil { + return + } + t.obsPlugins.Store(&obsPlugins) +} + // CreateTrace creates a new trace with optional parent ID and returns the trace ID. func (t *Tracer) CreateTrace(parentID string, requestID ...string) string { return t.store.CreateTrace(parentID, requestID...) @@ -360,6 +376,7 @@ func (t *Tracer) AttachPluginLogs(traceID string, logs []schemas.PluginLogEntry) // Stop stops the tracer and releases its resources. // This stops the internal TraceStore's cleanup goroutine. func (t *Tracer) Stop() { + t.flushWG.Wait() if t.store != nil { t.store.Stop() } @@ -368,5 +385,56 @@ func (t *Tracer) Stop() { } } +// CompleteAndFlushTrace ends a trace and forwards it to any observability +// plugins asynchronously. Realtime transports need this explicit flush because +// they bypass the HTTP tracing middleware that normally injects completed traces. +func (t *Tracer) CompleteAndFlushTrace(traceID string) { + if t == nil { + return + } + if strings.TrimSpace(traceID) == "" { + return + } + t.flushWG.Go(func() { + completedTrace := t.EndTrace(strings.TrimSpace(traceID)) + if completedTrace == nil { + return + } + // Defer release so the pooled trace is returned even if a plugin panics; + // otherwise an unrecovered panic in this detached goroutine leaks the + // trace object and takes down the whole process. + defer t.ReleaseTrace(completedTrace) + + var obsPlugins []schemas.ObservabilityPlugin + if loaded := t.obsPlugins.Load(); loaded != nil { + obsPlugins = *loaded + } + seen := make(map[string]struct{}, len(obsPlugins)) + for _, plugin := range obsPlugins { + if plugin == nil { + continue + } + // Isolate each plugin callback — one bad observability backend should + // not crash the server or prevent other plugins from receiving the trace. 
+ func(plugin schemas.ObservabilityPlugin) { + name := "" + defer func() { + if r := recover(); r != nil && t.logger != nil { + t.logger.Error("observability plugin %s panicked during trace injection: %v", name, r) + } + }() + name = plugin.GetName() + if _, exists := seen[name]; exists { + return + } + seen[name] = struct{}{} + if err := plugin.Inject(context.Background(), completedTrace); err != nil && t.logger != nil { + t.logger.Warn("observability plugin %s failed to inject trace: %v", name, err) + } + }(plugin) + } + }) +} + // Ensure Tracer implements schemas.Tracer at compile time var _ schemas.Tracer = (*Tracer)(nil) diff --git a/framework/tracing/tracer_test.go b/framework/tracing/tracer_test.go index 372e075829..33134c67d2 100644 --- a/framework/tracing/tracer_test.go +++ b/framework/tracing/tracer_test.go @@ -8,6 +8,57 @@ import ( "github.com/maximhq/bifrost/core/schemas" ) +type testRealtimeObservabilityPlugin struct { + injected chan *schemas.Trace +} + +func (p *testRealtimeObservabilityPlugin) GetName() string { return "test-observability" } +func (p *testRealtimeObservabilityPlugin) Cleanup() error { return nil } +func (p *testRealtimeObservabilityPlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + return req, nil, nil +} +func (p *testRealtimeObservabilityPlugin) PostLLMHook(_ *schemas.BifrostContext, resp *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} +func (p *testRealtimeObservabilityPlugin) Inject(_ context.Context, trace *schemas.Trace) error { + if trace == nil { + p.injected <- nil + return nil + } + traceCopy := *trace + p.injected <- &traceCopy + return nil +} + +func TestTracer_CompleteAndFlushTraceInjectsObservabilityPlugins(t *testing.T) { + store := NewTraceStore(5*time.Minute, nil) + defer store.Stop() + + tracer := NewTracer(store, nil, 
nil) + defer tracer.Stop() + + traceID := tracer.CreateTrace("") + plugin := &testRealtimeObservabilityPlugin{ + injected: make(chan *schemas.Trace, 1), + } + + tracer.SetObservabilityPlugins([]schemas.ObservabilityPlugin{plugin}) + tracer.CompleteAndFlushTrace(traceID) + + select { + case trace := <-plugin.injected: + if trace == nil || trace.TraceID != traceID { + t.Fatalf("injected trace = %+v, want trace %q", trace, traceID) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for observability inject") + } + + if got := tracer.store.GetTrace(traceID); got != nil { + t.Fatalf("trace %q was not released after flush", traceID) + } +} + func TestTracer_StartSpan_RootSpanWithW3CParent(t *testing.T) { // This is the key test: verifies that when an incoming request has a W3C traceparent header, // the root span in Bifrost correctly links to the upstream service's span. diff --git a/plugins/governance/main.go b/plugins/governance/main.go index 4bdf362d95..d4821d282c 100644 --- a/plugins/governance/main.go +++ b/plugins/governance/main.go @@ -49,6 +49,7 @@ type InMemoryStore interface { type BaseGovernancePlugin interface { GetName() string + EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) HTTPTransportPreHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest) (*schemas.HTTPResponse, error) HTTPTransportPostHook(ctx *schemas.BifrostContext, req *schemas.HTTPRequest, resp *schemas.HTTPResponse) error PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) @@ -366,7 +367,19 @@ func (p *GovernancePlugin) HTTPTransportPreHook(ctx *schemas.BifrostContext, req var needsMarshal bool contentType := req.CaseInsensitiveHeaderLookup("Content-Type") - isMultipart := strings.HasPrefix(strings.ToLower(contentType), "multipart/form-data") + lowerCT := 
strings.ToLower(contentType) + // Strip parameters (e.g., "; charset=utf-8") for clean media type comparison + mediaType := lowerCT + if idx := strings.IndexByte(mediaType, ';'); idx >= 0 { + mediaType = strings.TrimSpace(mediaType[:idx]) + } + isMultipart := strings.HasPrefix(mediaType, "multipart/form-data") + isJSON := mediaType == "" || mediaType == "application/json" || strings.HasSuffix(mediaType, "+json") + + if !isMultipart && !isJSON { + // Non-parseable body (e.g., application/sdp for WebRTC signaling) — skip governance + return nil, nil + } var err error if isMultipart { @@ -994,7 +1007,7 @@ func (p *GovernancePlugin) validateRequiredHeaders(ctx *schemas.BifrostContext) return nil } -// evaluateGovernanceRequest is a common function that handles virtual key validation +// EvaluateGovernanceRequest is a common function that handles virtual key validation // and governance evaluation logic. It returns the evaluation result and a BifrostError // if the request should be rejected, or nil if allowed. // @@ -1005,7 +1018,7 @@ func (p *GovernancePlugin) validateRequiredHeaders(ctx *schemas.BifrostContext) // Returns: // - *EvaluationResult: The governance evaluation result // - *schemas.BifrostError: The error to return if request is not allowed, nil if allowed -func (p *GovernancePlugin) evaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) { +func (p *GovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *EvaluationRequest, requestType schemas.RequestType) (*EvaluationResult, *schemas.BifrostError) { // Check if authentication is mandatory (either VK or user auth) // Checking if the virtual key is valid or not isVirtualKeyValid := false @@ -1218,7 +1231,7 @@ func (p *GovernancePlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas. 
UserID: userID, } // Evaluate governance using common function - _, bifrostError := p.evaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType) + _, bifrostError := p.EvaluateGovernanceRequest(ctx, evaluationRequest, req.RequestType) // Convert BifrostError to LLMPluginShortCircuit if needed if bifrostError != nil { return req, &schemas.LLMPluginShortCircuit{ @@ -1317,7 +1330,7 @@ func (p *GovernancePlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas. } // Evaluate governance using common function - _, bifrostError := p.evaluateGovernanceRequest(ctx, evaluationRequest, schemas.MCPToolExecutionRequest) + _, bifrostError := p.EvaluateGovernanceRequest(ctx, evaluationRequest, schemas.MCPToolExecutionRequest) // Convert BifrostError to MCPPluginShortCircuit if needed if bifrostError != nil { @@ -1327,7 +1340,7 @@ func (p *GovernancePlugin) PreMCPHook(ctx *schemas.BifrostContext, req *schemas. } // Blind single-tool check: validate the specific tool being executed against VK MCPConfigs. - // This runs independently of evaluateGovernanceRequest to enforce execution-time allow-list. + // This runs independently of EvaluateGovernanceRequest to enforce execution-time allow-list. 
if virtualKeyValue != "" { vk, ok := p.store.GetVirtualKey(virtualKeyValue) if !ok || vk == nil || !vk.IsActive { diff --git a/plugins/logging/main.go b/plugins/logging/main.go index 886cf3083e..83bc98f54c 100644 --- a/plugins/logging/main.go +++ b/plugins/logging/main.go @@ -449,6 +449,9 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr Model: model, Object: string(req.RequestType), } + if req.RequestType == schemas.RealtimeRequest { + initialData.Object = "realtime.turn" + } if p.disableContentLogging == nil || !*p.disableContentLogging { inputHistory, responsesInputHistory := p.extractInputHistory(req) @@ -469,6 +472,10 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr tools = append(tools, *tool.ToChatTool()) } initialData.Tools = tools + case schemas.RealtimeRequest: + if req.ResponsesRequest != nil { + initialData.Params = req.ResponsesRequest.Params + } case schemas.EmbeddingRequest: initialData.Params = req.EmbeddingRequest.Params case schemas.RerankRequest: @@ -574,7 +581,7 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr } // Capture configured logging headers and x-bf-lh-* headers into metadata first - initialData.Metadata = p.captureLoggingHeaders(ctx) + initialData.Metadata = mergeRealtimeMetadata(p.captureLoggingHeaders(ctx), ctx) // System entries are set after so they take precedence over dynamic header values if isAsync, ok := ctx.Value(schemas.BifrostIsAsyncRequest).(bool); ok && isAsync { @@ -592,10 +599,15 @@ func (p *LoggerPlugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifr // Determine effective request ID (fallback override) effectiveRequestID := requestID var parentRequestID string + if directParentRequestID, ok := ctx.Value(schemas.BifrostContextKeyParentRequestID).(string); ok && directParentRequestID != "" { + parentRequestID = directParentRequestID + } fallbackRequestID, ok := 
ctx.Value(schemas.BifrostContextKeyFallbackRequestID).(string) if ok && fallbackRequestID != "" { effectiveRequestID = fallbackRequestID - parentRequestID = requestID + if parentRequestID == "" { + parentRequestID = requestID + } } fallbackIndex := bifrost.GetIntFromContext(ctx, schemas.BifrostContextKeyFallbackIndex) @@ -672,7 +684,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. var tracer schemas.Tracer var traceID string - if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest { + if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest && requestType != schemas.RealtimeRequest { var err error tracer, traceID, err = bifrost.GetTracerFromContext(ctx) if err != nil { @@ -686,7 +698,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. // and skip the write queue entirely. The accumulator work (ProcessStreamingChunk) // is fast (mutex + append). Only final chunks, errors, and non-streaming // responses need a DB write. - if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest && !isFinalChunk && result != nil && bifrostErr == nil { + if bifrost.IsStreamRequestType(requestType) && requestType != schemas.PassthroughStreamRequest && requestType != schemas.RealtimeRequest && !isFinalChunk && result != nil && bifrostErr == nil { if tracer != nil && traceID != "" { tracer.ProcessStreamingChunk(traceID, false, result, bifrostErr) } @@ -725,6 +737,11 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. 
} pending := pendingVal.(*PendingLogData) + if requestType == schemas.RealtimeRequest { + if resolvedRealtimeSessionID := bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRealtimeSessionID); resolvedRealtimeSessionID != "" { + pending.ParentRequestID = resolvedRealtimeSessionID + } + } // Build the complete log entry with input (from PreLLMHook) + output (from PostLLMHook) entry := buildCompleteLogEntryFromPending(pending) @@ -735,6 +752,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. } applyOutputFieldsToEntry(entry, selectedKeyID, selectedKeyName, virtualKeyID, virtualKeyName, routingRuleID, routingRuleName, numberOfRetries, latency) entry.MetadataParsed = pending.InitialData.Metadata + entry.MetadataParsed = mergeRealtimeMetadata(entry.MetadataParsed, ctx) entry.RoutingEngineLogs = routingEngineLogs // Branch based on response type to populate output-specific fields @@ -775,7 +793,7 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. } // Path B: Streaming final chunk - if bifrost.IsStreamRequestType(requestType) { + if bifrost.IsStreamRequestType(requestType) && requestType != schemas.RealtimeRequest { var streamResponse *streaming.ProcessedStreamResponse if requestType != schemas.PassthroughStreamRequest && tracer != nil && traceID != "" { accResult := tracer.ProcessStreamingChunk(traceID, isFinalChunk, result, bifrostErr) @@ -834,11 +852,21 @@ func (p *LoggerPlugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas. entry.ErrorDetails = string(data) } entry.ErrorDetailsParsed = bifrostErr + // Realtime turns that fail mid-stream still need their input transcript + // surfaced — backfill from bifrostErr.ExtraFields.RawRequest if present. 
+ if requestType == schemas.RealtimeRequest { + contentLoggingEnabled := p.disableContentLogging == nil || !*p.disableContentLogging + applyRealtimeRawRequestBackfill(entry, bifrostErr.ExtraFields.RawRequest, contentLoggingEnabled) + } } else if result != nil { entry.Status = "success" extraFields := result.GetExtraFields() applyModelAlias(entry, extraFields.OriginalModelRequested, extraFields.ResolvedModelUsed) - p.applyNonStreamingOutputToEntry(entry, result) + if requestType == schemas.RealtimeRequest { + p.applyRealtimeOutputToEntry(entry, result) + } else { + p.applyNonStreamingOutputToEntry(entry, result) + } // Flip status for passthrough error responses (4xx/5xx from provider) if isPassthroughErrorResponse(result) { entry.Status = "error" diff --git a/plugins/logging/operations.go b/plugins/logging/operations.go index 1ec78b951a..1383cba265 100644 --- a/plugins/logging/operations.go +++ b/plugins/logging/operations.go @@ -4,6 +4,7 @@ package logging import ( "context" "fmt" + "strings" "time" "github.com/bytedance/sonic" @@ -13,6 +14,8 @@ import ( "github.com/maximhq/bifrost/framework/streaming" ) +const realtimeMissingTranscriptText = "[Audio transcription unavailable]" + // insertInitialLogEntry creates a new log entry in the database using GORM func (p *LoggerPlugin) insertInitialLogEntry( ctx context.Context, @@ -722,6 +725,350 @@ func (p *LoggerPlugin) applyNonStreamingOutputToEntry(entry *logstore.Log, resul } } +func (p *LoggerPlugin) applyRealtimeOutputToEntry(entry *logstore.Log, result *schemas.BifrostResponse) { + if result == nil || result.ResponsesResponse == nil { + return + } + + if usage := result.ResponsesResponse.Usage; usage != nil { + bifrostUsage := usage.ToBifrostLLMUsage() + entry.TokenUsageParsed = bifrostUsage + entry.PromptTokens = bifrostUsage.PromptTokens + entry.CompletionTokens = bifrostUsage.CompletionTokens + entry.TotalTokens = bifrostUsage.TotalTokens + } + + contentLoggingEnabled := p.disableContentLogging == nil || 
!*p.disableContentLogging + + if contentLoggingEnabled { + if outputMessage := extractRealtimeOutputMessage(result.ResponsesResponse.Output); outputMessage != nil { + entry.OutputMessageParsed = outputMessage + } + } + + extraFields := result.GetExtraFields() + applyRealtimeRawRequestBackfill(entry, extraFields.RawRequest, contentLoggingEnabled) + if contentLoggingEnabled && extraFields.RawResponse != nil { + switch raw := extraFields.RawResponse.(type) { + case string: + entry.RawResponse = strings.TrimSpace(raw) + default: + if rawResponseBytes, err := sonic.Marshal(extraFields.RawResponse); err == nil { + entry.RawResponse = string(rawResponseBytes) + } + } + } +} + +// applyRealtimeRawRequestBackfill writes RawRequest onto entry from an +// ExtraFields.RawRequest value (string or marshalable) and rebuilds +// InputHistoryParsed from any embedded realtime user/transcript events. +// Used by both success and error paths so realtime turns that fail mid-stream +// still surface their input transcript in logs. 
+func applyRealtimeRawRequestBackfill(entry *logstore.Log, rawRequest any, contentLoggingEnabled bool) { + if !contentLoggingEnabled || rawRequest == nil { + return + } + switch raw := rawRequest.(type) { + case string: + entry.RawRequest = strings.TrimSpace(raw) + default: + if rawRequestBytes, err := sonic.Marshal(rawRequest); err == nil { + entry.RawRequest = string(rawRequestBytes) + } + } + if strings.TrimSpace(entry.RawRequest) == "" { + return + } + if inputHistory := extractRealtimeInputHistoryFromRawRequest(entry.RawRequest); len(inputHistory) > 0 { + entry.InputHistoryParsed = mergeRealtimeInputHistory(entry.InputHistoryParsed, inputHistory) + } +} + +func extractRealtimeInputHistoryFromRawRequest(rawRequest string) []schemas.ChatMessage { + rawRequest = strings.TrimSpace(rawRequest) + if rawRequest == "" { + return nil + } + + parts := strings.Split(rawRequest, "\n\n") + messages := make([]schemas.ChatMessage, 0, len(parts)) + for _, part := range parts { + event, err := schemas.ParseRealtimeEvent([]byte(strings.TrimSpace(part))) + if err != nil || event == nil { + continue + } + + switch { + case schemas.IsRealtimeInputTranscriptEvent(event): + if transcript := extractRealtimeTranscript(event); transcript != "" { + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(transcript), + }, + }) + } + case schemas.IsRealtimeUserInputEvent(event): + if content := extractRealtimeRawItemContent(event.Item); content != "" { + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(content), + }, + }) + } + case schemas.IsRealtimeToolOutputEvent(event): + if content := extractRealtimeRawItemContent(event.Item); content != "" { + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: 
schemas.Ptr(content), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr(event.Item.CallID), + }, + }) + } + } + } + + if len(messages) == 0 { + return nil + } + return messages +} + +func mergeRealtimeInputHistory(existing, backfill []schemas.ChatMessage) []schemas.ChatMessage { + if len(backfill) == 0 { + return existing + } + + // Run dedupe even when existing is empty so duplicate events inside the + // same raw-event blob (same turn captured twice) collapse instead of + // getting written out verbatim. + merged := append([]schemas.ChatMessage(nil), existing...) + for _, candidate := range backfill { + if realtimeInputHistoryContainsEquivalent(merged, candidate) { + continue + } + if candidate.Role == schemas.ChatMessageRoleUser { + inserted := false + for idx, msg := range merged { + if msg.Role == schemas.ChatMessageRoleTool { + merged = append(merged[:idx], append([]schemas.ChatMessage{candidate}, merged[idx:]...)...) + inserted = true + break + } + } + if inserted { + continue + } + } + merged = append(merged, candidate) + } + return merged +} + +func realtimeInputHistoryContainsEquivalent(history []schemas.ChatMessage, candidate schemas.ChatMessage) bool { + candidateContent := strings.TrimSpace(realtimeInputHistoryMessageContent(candidate)) + candidateToolCallID := strings.TrimSpace(realtimeInputHistoryToolCallID(candidate)) + + for _, existing := range history { + if existing.Role != candidate.Role { + continue + } + if strings.TrimSpace(realtimeInputHistoryMessageContent(existing)) != candidateContent { + continue + } + if strings.TrimSpace(realtimeInputHistoryToolCallID(existing)) != candidateToolCallID { + continue + } + return true + } + + return false +} + +func realtimeInputHistoryMessageContent(message schemas.ChatMessage) string { + if message.Content == nil || message.Content.ContentStr == nil { + return "" + } + return *message.Content.ContentStr +} + +func realtimeInputHistoryToolCallID(message schemas.ChatMessage) 
string { + if message.ChatToolMessage == nil || message.ChatToolMessage.ToolCallID == nil { + return "" + } + return *message.ChatToolMessage.ToolCallID +} + +func extractRealtimeTranscript(event *schemas.BifrostRealtimeEvent) string { + if event == nil || event.ExtraParams == nil { + return realtimeMissingTranscriptText + } + raw, ok := event.ExtraParams["transcript"] + if !ok || len(raw) == 0 { + return realtimeMissingTranscriptText + } + var transcript string + if err := schemas.Unmarshal(raw, &transcript); err != nil { + return realtimeMissingTranscriptText + } + transcript = strings.TrimSpace(transcript) + if transcript == "" { + return realtimeMissingTranscriptText + } + return transcript +} + +func extractRealtimeRawItemContent(item *schemas.RealtimeItem) string { + if item == nil { + return "" + } + if content := extractRealtimeRawContent(item.Content); content != "" { + return content + } + if item.Role == "user" && realtimeItemHasMissingAudioTranscript(item) { + return realtimeMissingTranscriptText + } + switch { + case strings.TrimSpace(item.Output) != "": + return strings.TrimSpace(item.Output) + case strings.TrimSpace(item.Arguments) != "": + return strings.TrimSpace(item.Arguments) + default: + return "" + } +} + +func realtimeItemHasMissingAudioTranscript(item *schemas.RealtimeItem) bool { + if item == nil || len(item.Content) == 0 { + return false + } + + var decoded []map[string]any + if err := sonic.Unmarshal(item.Content, &decoded); err != nil { + return false + } + + for _, part := range decoded { + partType, _ := part["type"].(string) + if partType != "input_audio" { + continue + } + transcript, exists := part["transcript"] + if !exists || transcript == nil { + return true + } + if text, ok := transcript.(string); ok && strings.TrimSpace(text) == "" { + return true + } + } + + return false +} + +func extractRealtimeRawContent(raw []byte) string { + if len(raw) == 0 { + return "" + } + + var decoded any + if err := sonic.Unmarshal(raw, 
&decoded); err != nil { + return strings.TrimSpace(string(raw)) + } + + var parts []string + collectRealtimeRawTextFragments(decoded, &parts) + return strings.TrimSpace(strings.Join(parts, " ")) +} + +func collectRealtimeRawTextFragments(value any, parts *[]string) { + switch v := value.(type) { + case map[string]any: + for key, field := range v { + switch key { + case "text", "transcript", "input_text", "output_text", "output", "arguments": + if text, ok := field.(string); ok { + text = strings.TrimSpace(text) + if text != "" { + *parts = append(*parts, text) + } + continue + } + } + collectRealtimeRawTextFragments(field, parts) + } + case []any: + for _, item := range v { + collectRealtimeRawTextFragments(item, parts) + } + } +} + +func extractRealtimeOutputMessage(output []schemas.ResponsesMessage) *schemas.ChatMessage { + var contentParts []string + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, item := range output { + if item.Type == nil { + continue + } + switch *item.Type { + case schemas.ResponsesMessageTypeMessage: + if item.Role == nil || *item.Role != schemas.ResponsesInputMessageRoleAssistant { + continue + } + if text := extractRealtimeResponsesContent(item.Content); text != "" { + contentParts = append(contentParts, text) + } + case schemas.ResponsesMessageTypeFunctionCall: + if item.ResponsesToolMessage == nil || item.ResponsesToolMessage.Name == nil { + continue + } + toolType := "function" + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: item.ResponsesToolMessage.Name, + Arguments: derefString(item.ResponsesToolMessage.Arguments), + }, + } + if item.CallID != nil && strings.TrimSpace(*item.CallID) != "" { + toolCall.ID = schemas.Ptr(strings.TrimSpace(*item.CallID)) + } else if item.ID != nil && strings.TrimSpace(*item.ID) != "" { + toolCall.ID = schemas.Ptr(strings.TrimSpace(*item.ID)) + } + toolCalls = 
append(toolCalls, toolCall) + } + } + + if len(contentParts) == 0 && len(toolCalls) == 0 { + return nil + } + + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if len(contentParts) > 0 { + content := strings.Join(contentParts, "\n") + message.Content = &schemas.ChatMessageContent{ContentStr: &content} + } + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: toolCalls, + } + } + return message +} + +func derefString(value *string) string { + if value == nil { + return "" + } + return *value +} + // SearchLogs searches logs with filters and pagination using GORM func (p *LoggerPlugin) SearchLogs(ctx context.Context, filters logstore.SearchFilters, pagination logstore.PaginationOptions) (*logstore.SearchResult, error) { // Set default pagination if not provided @@ -738,6 +1085,25 @@ func (p *LoggerPlugin) SearchLogs(ctx context.Context, filters logstore.SearchFi return p.store.SearchLogs(ctx, filters, pagination) } +// GetSessionLogs returns paginated logs for a single parent_request_id session. +func (p *LoggerPlugin) GetSessionLogs(ctx context.Context, sessionID string, pagination logstore.PaginationOptions) (*logstore.SessionDetailResult, error) { + if pagination.Limit == 0 { + pagination.Limit = 50 + } + if pagination.SortBy == "" { + pagination.SortBy = "timestamp" + } + if pagination.Order == "" { + pagination.Order = "asc" + } + return p.store.GetSessionLogs(ctx, sessionID, pagination) +} + +// GetSessionSummary returns aggregate totals for a single parent_request_id session. +func (p *LoggerPlugin) GetSessionSummary(ctx context.Context, sessionID string) (*logstore.SessionSummaryResult, error) { + return p.store.GetSessionSummary(ctx, sessionID) +} + // GetLog retrieves a single log entry by ID including all fields (raw_request, raw_response). 
func (p *LoggerPlugin) GetLog(ctx context.Context, id string) (*logstore.Log, error) { return p.store.FindByID(ctx, id) diff --git a/plugins/logging/operations_test.go b/plugins/logging/operations_test.go index 27d4436818..7b12b33eed 100644 --- a/plugins/logging/operations_test.go +++ b/plugins/logging/operations_test.go @@ -311,3 +311,339 @@ func TestStoreOrEnqueueRetryPreservesAllEntries(t *testing.T) { t.Fatal("expected pendingLogsToInject to be cleaned up after Inject") } } + +func TestApplyRealtimeOutputToEntryBackfillsUserTranscriptFromRawRequest(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Hello!" + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: `{"type":"conversation.item.input_audio_transcription.completed","transcript":"Hello."}`, + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 1 { + t.Fatalf("len(InputHistoryParsed) = %d, want 1", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[0].Role = %q, want user", entry.InputHistoryParsed[0].Role) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "Hello." 
{ + t.Fatalf("InputHistoryParsed[0] = %+v, want transcript", entry.InputHistoryParsed[0]) + } + if entry.OutputMessageParsed == nil || entry.OutputMessageParsed.Content == nil || entry.OutputMessageParsed.Content.ContentStr == nil || *entry.OutputMessageParsed.Content.ContentStr != assistantText { + t.Fatalf("OutputMessageParsed = %+v, want assistant text", entry.OutputMessageParsed) + } + if !strings.Contains(entry.ContentSummary, "Hello.") { + t.Fatalf("ContentSummary = %q, want user transcript", entry.ContentSummary) + } + if !strings.Contains(entry.ContentSummary, "Hello!") { + t.Fatalf("ContentSummary = %q, want assistant text", entry.ContentSummary) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsMissingTranscriptPlaceholder(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Hi there!" + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}`, + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 1 { + t.Fatalf("len(InputHistoryParsed) = %d, want 1", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != realtimeMissingTranscriptText { + t.Fatalf("InputHistoryParsed[0] = %+v, want 
missing transcript placeholder", entry.InputHistoryParsed[0]) + } + if !strings.Contains(entry.ContentSummary, realtimeMissingTranscriptText) { + t.Fatalf("ContentSummary = %q, want missing transcript placeholder", entry.ContentSummary) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsDoneMissingTranscriptPlaceholder(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Hi there!" + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: `{"type":"conversation.item.done","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_audio","transcript":null}]}}`, + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 1 { + t.Fatalf("len(InputHistoryParsed) = %d, want 1", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != realtimeMissingTranscriptText { + t.Fatalf("InputHistoryParsed[0] = %+v, want missing transcript placeholder", entry.InputHistoryParsed[0]) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsRetrievedUserAndToolHistory(t *testing.T) { + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "I checked that for you." 
+ messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.retrieved","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_text","text":"Where is my order?"}]}}`, + `{"type":"conversation.item.retrieved","item":{"id":"item_tool","type":"function_call_output","call_id":"call_123","status":"completed","output":"{\"status\":\"delivered\"}"}}`, + }, "\n\n"), + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 2 { + t.Fatalf("len(InputHistoryParsed) = %d, want 2", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[0].Role = %q, want user", entry.InputHistoryParsed[0].Role) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "Where is my order?" 
{ + t.Fatalf("InputHistoryParsed[0] = %+v, want user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].Role != schemas.ChatMessageRoleTool { + t.Fatalf("InputHistoryParsed[1].Role = %q, want tool", entry.InputHistoryParsed[1].Role) + } + if entry.InputHistoryParsed[1].Content == nil || entry.InputHistoryParsed[1].Content.ContentStr == nil || *entry.InputHistoryParsed[1].Content.ContentStr != `{"status":"delivered"}` { + t.Fatalf("InputHistoryParsed[1] = %+v, want tool content", entry.InputHistoryParsed[1]) + } + if entry.InputHistoryParsed[1].ChatToolMessage == nil || entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID != "call_123" { + t.Fatalf("InputHistoryParsed[1].ChatToolMessage = %+v, want tool call id", entry.InputHistoryParsed[1].ChatToolMessage) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsCreatedUserAndToolHistory(t *testing.T) { + t.Parallel() + + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + ExtraFields: schemas.BifrostResponseExtraFields{ + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.created","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_text","text":"I need help"}]}}`, + `{"type":"conversation.item.created","item":{"id":"item_tool","type":"function_call_output","call_id":"call_456","status":"completed","output":"{\"status\":\"ok\"}"}}`, + }, "\n\n"), + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + + if len(entry.InputHistoryParsed) != 2 { + t.Fatalf("len(InputHistoryParsed) = %d, want 2", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[0].Role = %q, want user", entry.InputHistoryParsed[0].Role) + } + if entry.InputHistoryParsed[0].Content == nil || 
entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "I need help" { + t.Fatalf("InputHistoryParsed[0] = %+v, want user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].Role != schemas.ChatMessageRoleTool { + t.Fatalf("InputHistoryParsed[1].Role = %q, want tool", entry.InputHistoryParsed[1].Role) + } + if entry.InputHistoryParsed[1].Content == nil || entry.InputHistoryParsed[1].Content.ContentStr == nil || *entry.InputHistoryParsed[1].Content.ContentStr != `{"status":"ok"}` { + t.Fatalf("InputHistoryParsed[1] = %+v, want tool content", entry.InputHistoryParsed[1]) + } + if entry.InputHistoryParsed[1].ChatToolMessage == nil || entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID != "call_456" { + t.Fatalf("InputHistoryParsed[1].ChatToolMessage = %+v, want tool call id", entry.InputHistoryParsed[1].ChatToolMessage) + } +} + +func TestApplyRealtimeOutputToEntryBackfillsAddedUserAndToolHistory(t *testing.T) { + t.Parallel() + + plugin := &LoggerPlugin{} + entry := &logstore.Log{} + + assistantText := "Done." 
+ messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.added","item":{"id":"item_user","type":"message","role":"user","status":"completed","content":[{"type":"input_text","text":"hello from added item"}]}}`, + `{"type":"conversation.item.added","item":{"id":"item_tool","type":"function_call_output","call_id":"call_added","status":"completed","output":"{\"status\":\"ok\"}"}}`, + }, "\n\n"), + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 2 { + t.Fatalf("len(InputHistoryParsed) = %d, want 2", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "hello from added item" { + t.Fatalf("InputHistoryParsed[0] = %+v, want added user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].ChatToolMessage == nil || entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[1].ChatToolMessage.ToolCallID != "call_added" { + t.Fatalf("InputHistoryParsed[1].ChatToolMessage = %+v, want added tool call id", entry.InputHistoryParsed[1].ChatToolMessage) + } +} + +func TestApplyRealtimeOutputToEntryMergesRawTranscriptIntoStructuredRealtimeHistory(t *testing.T) { + t.Parallel() + + plugin := &LoggerPlugin{} + entry := 
&logstore.Log{ + InputHistoryParsed: []schemas.ChatMessage{ + { + Role: schemas.ChatMessageRoleUser, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr("Can you help with my ticket?"), + }, + }, + { + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(`{"status":"open"}`), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: schemas.Ptr("call_789"), + }, + }, + }, + } + + assistantText := "Let me check." + messageType := schemas.ResponsesMessageTypeMessage + assistantRole := schemas.ResponsesInputMessageRoleAssistant + result := &schemas.BifrostResponse{ + ResponsesResponse: &schemas.BifrostResponsesResponse{ + Output: []schemas.ResponsesMessage{{ + Type: &messageType, + Role: &assistantRole, + Content: &schemas.ResponsesMessageContent{ + ContentStr: &assistantText, + }, + }}, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + RawRequest: strings.Join([]string{ + `{"type":"conversation.item.input_audio_transcription.completed","transcript":"Hello."}`, + `{"type":"conversation.item.retrieved","item":{"id":"item_tool","type":"function_call_output","call_id":"call_789","status":"completed","output":"{\"status\":\"open\"}"}}`, + }, "\n\n"), + RawResponse: `{"type":"response.done"}`, + }, + }, + } + + plugin.applyRealtimeOutputToEntry(entry, result) + if err := entry.SerializeFields(); err != nil { + t.Fatalf("SerializeFields() error = %v", err) + } + + if len(entry.InputHistoryParsed) != 3 { + t.Fatalf("len(InputHistoryParsed) = %d, want 3", len(entry.InputHistoryParsed)) + } + if entry.InputHistoryParsed[0].Content == nil || entry.InputHistoryParsed[0].Content.ContentStr == nil || *entry.InputHistoryParsed[0].Content.ContentStr != "Can you help with my ticket?" 
{ + t.Fatalf("InputHistoryParsed[0] = %+v, want structured user content", entry.InputHistoryParsed[0]) + } + if entry.InputHistoryParsed[1].Role != schemas.ChatMessageRoleUser { + t.Fatalf("InputHistoryParsed[1].Role = %q, want user", entry.InputHistoryParsed[1].Role) + } + if entry.InputHistoryParsed[1].Content == nil || entry.InputHistoryParsed[1].Content.ContentStr == nil || *entry.InputHistoryParsed[1].Content.ContentStr != "Hello." { + t.Fatalf("InputHistoryParsed[1] = %+v, want raw transcript merge", entry.InputHistoryParsed[1]) + } + if entry.InputHistoryParsed[2].Role != schemas.ChatMessageRoleTool { + t.Fatalf("InputHistoryParsed[2].Role = %q, want tool", entry.InputHistoryParsed[2].Role) + } + if entry.InputHistoryParsed[2].ChatToolMessage == nil || entry.InputHistoryParsed[2].ChatToolMessage.ToolCallID == nil || *entry.InputHistoryParsed[2].ChatToolMessage.ToolCallID != "call_789" { + t.Fatalf("InputHistoryParsed[2].ChatToolMessage = %+v, want original tool call id", entry.InputHistoryParsed[2].ChatToolMessage) + } + if strings.Count(entry.ContentSummary, "Hello.") != 1 { + t.Fatalf("ContentSummary = %q, want one merged transcript", entry.ContentSummary) + } +} diff --git a/plugins/logging/utils.go b/plugins/logging/utils.go index 7b64d944bb..81ea1b0f3f 100644 --- a/plugins/logging/utils.go +++ b/plugins/logging/utils.go @@ -8,6 +8,7 @@ import ( "strings" "time" + bifrost "github.com/maximhq/bifrost/core" "github.com/maximhq/bifrost/core/schemas" "github.com/maximhq/bifrost/framework/logstore" "github.com/maximhq/bifrost/framework/streaming" @@ -27,6 +28,12 @@ type LogManager interface { // Search searches for log entries based on filters and pagination Search(ctx context.Context, filters *logstore.SearchFilters, pagination *logstore.PaginationOptions) (*logstore.SearchResult, error) + // GetSessionLogs returns paginated logs for a single parent_request_id session. 
+ GetSessionLogs(ctx context.Context, sessionID string, pagination *logstore.PaginationOptions) (*logstore.SessionDetailResult, error) + + // GetSessionSummary returns aggregate totals for a single parent_request_id session. + GetSessionSummary(ctx context.Context, sessionID string) (*logstore.SessionSummaryResult, error) + // GetStats calculates statistics for logs matching the given filters GetStats(ctx context.Context, filters *logstore.SearchFilters) (*logstore.SearchStats, error) @@ -135,6 +142,23 @@ func (p *PluginLogManager) Search(ctx context.Context, filters *logstore.SearchF return p.plugin.SearchLogs(ctx, *filters, *pagination) } +func (p *PluginLogManager) GetSessionLogs(ctx context.Context, sessionID string, pagination *logstore.PaginationOptions) (*logstore.SessionDetailResult, error) { + if pagination == nil { + return nil, fmt.Errorf("pagination cannot be nil") + } + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + return p.plugin.GetSessionLogs(ctx, sessionID, *pagination) +} + +func (p *PluginLogManager) GetSessionSummary(ctx context.Context, sessionID string) (*logstore.SessionSummaryResult, error) { + if strings.TrimSpace(sessionID) == "" { + return nil, fmt.Errorf("sessionID cannot be empty") + } + return p.plugin.GetSessionSummary(ctx, sessionID) +} + func (p *PluginLogManager) GetStats(ctx context.Context, filters *logstore.SearchFilters) (*logstore.SearchStats, error) { if filters == nil { return nil, fmt.Errorf("filters cannot be nil") @@ -386,6 +410,9 @@ func (p *LoggerPlugin) extractInputHistory(request *schemas.BifrostRequest) ([]s if request.ChatRequest != nil { return request.ChatRequest.Input, []schemas.ResponsesMessage{} } + if request.RequestType == schemas.RealtimeRequest && request.ResponsesRequest != nil { + return extractRealtimeInputHistory(request.ResponsesRequest.Input), []schemas.ResponsesMessage{} + } if request.ResponsesRequest != nil && len(request.ResponsesRequest.Input) 
> 0 { return []schemas.ChatMessage{}, request.ResponsesRequest.Input } @@ -459,6 +486,96 @@ func (p *LoggerPlugin) extractInputHistory(request *schemas.BifrostRequest) ([]s return []schemas.ChatMessage{}, []schemas.ResponsesMessage{} } +func extractRealtimeInputHistory(input []schemas.ResponsesMessage) []schemas.ChatMessage { + messages := make([]schemas.ChatMessage, 0, len(input)) + for _, item := range input { + if item.Type == nil { + continue + } + switch *item.Type { + case schemas.ResponsesMessageTypeMessage: + if item.Role == nil || item.Content == nil { + continue + } + content := extractRealtimeResponsesContent(item.Content) + if content == "" { + continue + } + messages = append(messages, schemas.ChatMessage{ + Role: mapRealtimeResponsesRole(*item.Role), + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(content), + }, + }) + case schemas.ResponsesMessageTypeFunctionCallOutput, + schemas.ResponsesMessageTypeCustomToolCallOutput, + schemas.ResponsesMessageTypeLocalShellCallOutput, + schemas.ResponsesMessageTypeComputerCallOutput: + content := extractRealtimeToolOutputContent(item.ResponsesToolMessage) + if content == "" { + continue + } + messages = append(messages, schemas.ChatMessage{ + Role: schemas.ChatMessageRoleTool, + Content: &schemas.ChatMessageContent{ + ContentStr: schemas.Ptr(content), + }, + ChatToolMessage: &schemas.ChatToolMessage{ + ToolCallID: item.ResponsesToolMessage.CallID, + }, + }) + } + } + return messages +} + +func mapRealtimeResponsesRole(role schemas.ResponsesMessageRoleType) schemas.ChatMessageRole { + switch role { + case schemas.ResponsesInputMessageRoleAssistant: + return schemas.ChatMessageRoleAssistant + case schemas.ResponsesInputMessageRoleSystem: + return schemas.ChatMessageRoleSystem + case schemas.ResponsesInputMessageRoleDeveloper: + return schemas.ChatMessageRoleDeveloper + default: + return schemas.ChatMessageRoleUser + } +} + +func extractRealtimeResponsesContent(content 
*schemas.ResponsesMessageContent) string { + if content == nil { + return "" + } + if content.ContentStr != nil { + return strings.TrimSpace(*content.ContentStr) + } + parts := make([]string, 0, len(content.ContentBlocks)) + for _, block := range content.ContentBlocks { + switch { + case block.Text != nil && strings.TrimSpace(*block.Text) != "": + parts = append(parts, strings.TrimSpace(*block.Text)) + case block.ResponsesOutputMessageContentRefusal != nil && strings.TrimSpace(block.Refusal) != "": + parts = append(parts, strings.TrimSpace(block.Refusal)) + } + } + return strings.TrimSpace(strings.Join(parts, "\n")) +} + +func extractRealtimeToolOutputContent(toolMessage *schemas.ResponsesToolMessage) string { + if toolMessage == nil || toolMessage.Output == nil { + return "" + } + switch { + case toolMessage.Output.ResponsesToolCallOutputStr != nil: + return strings.TrimSpace(*toolMessage.Output.ResponsesToolCallOutputStr) + case len(toolMessage.Output.ResponsesFunctionToolCallOutputBlocks) > 0: + content := &schemas.ResponsesMessageContent{ContentBlocks: toolMessage.Output.ResponsesFunctionToolCallOutputBlocks} + return extractRealtimeResponsesContent(content) + default: + return "" + } +} + // convertToProcessedStreamResponse converts a StreamAccumulatorResult to ProcessedStreamResponse // for use with the logging plugin's streaming log update functionality. 
func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, requestType schemas.RequestType) *streaming.ProcessedStreamResponse { @@ -527,6 +644,32 @@ func convertToProcessedStreamResponse(result *schemas.StreamAccumulatorResult, r return resp } +func mergeRealtimeMetadata(metadata map[string]interface{}, ctx *schemas.BifrostContext) map[string]interface{} { + if ctx == nil { + return metadata + } + set := func(key string, ctxKey schemas.BifrostContextKey) { + if value := bifrost.GetStringFromContext(ctx, ctxKey); value != "" { + if metadata == nil { + metadata = make(map[string]interface{}) + } + metadata[key] = value + } + } + + set("realtime_session_id", schemas.BifrostContextKeyRealtimeSessionID) + set("provider_session_id", schemas.BifrostContextKeyRealtimeProviderSessionID) + set("realtime_source", schemas.BifrostContextKeyRealtimeSource) + set("realtime_event_type", schemas.BifrostContextKeyRealtimeEventType) + if bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyRealtimeSessionID) != "" { + if metadata == nil { + metadata = make(map[string]interface{}) + } + metadata["realtime"] = true + } + return metadata +} + // formatRoutingEngineLogs formats routing engine logs into a human-readable string. 
// Format: [timestamp] [engine] - message // Parameters: diff --git a/plugins/maxim/main.go b/plugins/maxim/main.go index 0e85d79bc4..f9d9f7dbfd 100644 --- a/plugins/maxim/main.go +++ b/plugins/maxim/main.go @@ -245,6 +245,10 @@ func (plugin *Plugin) getOrCreateLogger(logRepoID string) (*logging.Logger, erro // - *schemas.BifrostRequest: The original request, unmodified // - error: Any error that occurred during trace/generation creation func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + if req != nil && req.RequestType == schemas.RealtimeRequest { + return req, nil, nil + } + var traceID string var traceName string var sessionID string @@ -491,6 +495,11 @@ func (plugin *Plugin) PreLLMHook(ctx *schemas.BifrostContext, req *schemas.Bifro // - *schemas.BifrostError: The original error, unmodified // - error: Never returns an error as it handles missing IDs gracefully func (plugin *Plugin) PostLLMHook(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + requestType, _, _, _ := bifrost.GetResponseFields(result, bifrostErr) + if requestType == schemas.RealtimeRequest { + return result, bifrostErr, nil + } + // Get effective log repo ID for this request effectiveLogRepoID := plugin.getEffectiveLogRepoID(ctx) if effectiveLogRepoID == "" { diff --git a/transports/bifrost-http/handlers/devpprof.go b/transports/bifrost-http/handlers/devpprof.go index fca8712d9b..19ef208237 100644 --- a/transports/bifrost-http/handlers/devpprof.go +++ b/transports/bifrost-http/handlers/devpprof.go @@ -531,7 +531,7 @@ func categorizeGoroutine(g *GoroutineGroup) { "PostMCPHook", "HTTPTransportPreHook", "HTTPTransportPostHook", - "completeAndFlushTrace", + "CompleteAndFlushTrace", "ProcessAndSend", "handleProvider", "Inject", // Observability plugin inject diff --git 
a/transports/bifrost-http/handlers/integrations.go b/transports/bifrost-http/handlers/integrations.go index 2e5c7d7aec..da9290f117 100644 --- a/transports/bifrost-http/handlers/integrations.go +++ b/transports/bifrost-http/handlers/integrations.go @@ -12,13 +12,16 @@ import ( // IntegrationHandler manages HTTP requests for AI provider integrations type IntegrationHandler struct { - extensions []integrations.ExtensionRouter - wsResponses *WSResponsesHandler + extensions []integrations.ExtensionRouter + wsResponses *WSResponsesHandler + wsRealtime *WSRealtimeHandler + webrtcRealtime *WebRTCRealtimeHandler + realtimeClientSecrets *RealtimeClientSecretsHandler } // NewIntegrationHandler creates a new integration handler instance. -// wsResponses may be nil if WebSocket support is not configured. -func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, wsResponses *WSResponsesHandler) *IntegrationHandler { +// WebSocket handlers may be nil if WebSocket support is not configured. +func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStore, wsResponses *WSResponsesHandler, wsRealtime *WSRealtimeHandler, webrtcRealtime *WebRTCRealtimeHandler, realtimeClientSecrets *RealtimeClientSecretsHandler) *IntegrationHandler { // Initialize all available integration routers extensions := []integrations.ExtensionRouter{ integrations.NewOpenAIRouter(client, handlerStore, logger), @@ -37,8 +40,11 @@ func NewIntegrationHandler(client *bifrost.Bifrost, handlerStore lib.HandlerStor } return &IntegrationHandler{ - extensions: extensions, - wsResponses: wsResponses, + extensions: extensions, + wsResponses: wsResponses, + wsRealtime: wsRealtime, + webrtcRealtime: webrtcRealtime, + realtimeClientSecrets: realtimeClientSecrets, } } @@ -52,6 +58,30 @@ func (h *IntegrationHandler) RegisterRoutes(r *router.Router, middlewares ...sch if h.wsResponses != nil { h.wsResponses.RegisterRoutes(r, middlewares...) 
} + if h.wsRealtime != nil { + h.wsRealtime.RegisterRoutes(r, middlewares...) + } + if h.webrtcRealtime != nil { + h.webrtcRealtime.RegisterRoutes(r, middlewares...) + } + if h.realtimeClientSecrets != nil { + h.realtimeClientSecrets.RegisterRoutes(r, middlewares...) + } +} + +func (h *IntegrationHandler) Close() { + if h == nil { + return + } + if h.wsResponses != nil { + h.wsResponses.Close() + } + if h.wsRealtime != nil { + h.wsRealtime.Close() + } + if h.webrtcRealtime != nil { + h.webrtcRealtime.Close() + } } // SetLargePayloadHook sets the large payload detection hook on all integration routers diff --git a/transports/bifrost-http/handlers/logging.go b/transports/bifrost-http/handlers/logging.go index 852a0695b0..3e57c66329 100644 --- a/transports/bifrost-http/handlers/logging.go +++ b/transports/bifrost-http/handlers/logging.go @@ -29,6 +29,19 @@ type LoggingHandler struct { config *lib.Config } +// Keep session log page size in one place so the session sheet limit is easy to tune later. 
+const sessionLogPageLimit = 500 + +func parseParentRequestIDFilter(ctx *fasthttp.RequestCtx) string { + if parentRequestID := string(ctx.QueryArgs().Peek("parent_request_id")); strings.TrimSpace(parentRequestID) != "" { + return parentRequestID + } + if sessionID := string(ctx.QueryArgs().Peek("session_id")); strings.TrimSpace(sessionID) != "" { + return sessionID + } + return "" +} + type RedactedKeysManager interface { GetAllRedactedKeys(ctx context.Context, ids []string) []schemas.Key GetAllRedactedVirtualKeys(ctx context.Context, ids []string) []tables.TableVirtualKey @@ -55,6 +68,8 @@ func (h *LoggingHandler) shouldHideDeletedVirtualKeysInFilters() bool { func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { // LLM Log retrieval with filtering, search, and pagination r.GET("/api/logs", lib.ChainMiddlewares(h.getLogs, middlewares...)) + r.GET("/api/logs/sessions/{session_id}/summary", lib.ChainMiddlewares(h.getLogSessionSummaryByID, middlewares...)) + r.GET("/api/logs/sessions/{session_id}", lib.ChainMiddlewares(h.getLogSessionByID, middlewares...)) r.GET("/api/logs/{id}", lib.ChainMiddlewares(h.getLogByID, middlewares...)) r.GET("/api/logs/stats", lib.ChainMiddlewares(h.getLogsStats, middlewares...)) r.GET("/api/logs/histogram", lib.ChainMiddlewares(h.getLogsHistogram, middlewares...)) @@ -81,6 +96,126 @@ func (h *LoggingHandler) RegisterRoutes(r *router.Router, middlewares ...schemas r.DELETE("/api/mcp-logs", lib.ChainMiddlewares(h.deleteMCPLogs, middlewares...)) } +// getLogSessionByID handles GET /api/logs/sessions/{session_id} - Get logs in a single session. 
+func (h *LoggingHandler) getLogSessionByID(ctx *fasthttp.RequestCtx) { + rawSessionID, ok := ctx.UserValue("session_id").(string) + if !ok || strings.TrimSpace(rawSessionID) == "" { + SendError(ctx, fasthttp.StatusBadRequest, "session_id is required") + return + } + + pagination := &logstore.PaginationOptions{ + Limit: sessionLogPageLimit, + Offset: 0, + SortBy: "timestamp", + Order: "asc", + } + if limit := string(ctx.QueryArgs().Peek("limit")); limit != "" { + i, err := strconv.Atoi(limit) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "invalid limit") + return + } + if i <= 0 { + SendError(ctx, fasthttp.StatusBadRequest, "limit must be greater than 0") + return + } + if i > sessionLogPageLimit { + SendError(ctx, fasthttp.StatusBadRequest, fmt.Sprintf("limit cannot exceed %d", sessionLogPageLimit)) + return + } + pagination.Limit = i + } + if offset := string(ctx.QueryArgs().Peek("offset")); offset != "" { + i, err := strconv.Atoi(offset) + if err != nil { + SendError(ctx, fasthttp.StatusBadRequest, "invalid offset") + return + } + if i < 0 { + SendError(ctx, fasthttp.StatusBadRequest, "offset cannot be negative") + return + } + pagination.Offset = i + } + if order := string(ctx.QueryArgs().Peek("order")); order != "" { + if order != "asc" && order != "desc" { + SendError(ctx, fasthttp.StatusBadRequest, "order must be 'asc' or 'desc'") + return + } + pagination.Order = order + } + + result, err := h.logManager.GetSessionLogs(ctx, rawSessionID, pagination) + if err != nil { + logger.Error("failed to fetch session logs: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Session fetch failed: %v", err)) + return + } + + selectedKeyIDs := make(map[string]struct{}) + virtualKeyIDs := make(map[string]struct{}) + routingRuleIDs := make(map[string]struct{}) + for _, log := range result.Logs { + if log.SelectedKeyID != "" { + selectedKeyIDs[log.SelectedKeyID] = struct{}{} + } + if log.VirtualKeyID != nil && *log.VirtualKeyID != "" 
{ + virtualKeyIDs[*log.VirtualKeyID] = struct{}{} + } + if log.RoutingRuleID != nil && *log.RoutingRuleID != "" { + routingRuleIDs[*log.RoutingRuleID] = struct{}{} + } + } + + toSlice := func(m map[string]struct{}) []string { + if len(m) == 0 { + return nil + } + out := make([]string, 0, len(m)) + for id := range m { + out = append(out, id) + } + return out + } + + redactedKeys := h.redactedKeysManager.GetAllRedactedKeys(ctx, toSlice(selectedKeyIDs)) + redactedVirtualKeys := h.redactedKeysManager.GetAllRedactedVirtualKeys(ctx, toSlice(virtualKeyIDs)) + redactedRoutingRules := h.redactedKeysManager.GetAllRedactedRoutingRules(ctx, toSlice(routingRuleIDs)) + + for i, log := range result.Logs { + if log.SelectedKeyID != "" && log.SelectedKeyName != "" { + result.Logs[i].SelectedKey = findRedactedKey(redactedKeys, log.SelectedKeyID, log.SelectedKeyName) + } + if log.VirtualKeyID != nil && log.VirtualKeyName != nil && *log.VirtualKeyID != "" && *log.VirtualKeyName != "" { + result.Logs[i].VirtualKey = findRedactedVirtualKey(redactedVirtualKeys, *log.VirtualKeyID, *log.VirtualKeyName) + } + if log.RoutingRuleID != nil && log.RoutingRuleName != nil && *log.RoutingRuleID != "" && *log.RoutingRuleName != "" { + result.Logs[i].RoutingRule = findRedactedRoutingRule(redactedRoutingRules, *log.RoutingRuleID, *log.RoutingRuleName) + } + } + + SendJSON(ctx, result) +} + +// getLogSessionSummaryByID handles GET /api/logs/sessions/{session_id}/summary - Get aggregate totals for a single session. 
+func (h *LoggingHandler) getLogSessionSummaryByID(ctx *fasthttp.RequestCtx) { + rawSessionID, ok := ctx.UserValue("session_id").(string) + if !ok || strings.TrimSpace(rawSessionID) == "" { + SendError(ctx, fasthttp.StatusBadRequest, "session_id is required") + return + } + + result, err := h.logManager.GetSessionSummary(ctx, rawSessionID) + if err != nil { + logger.Error("failed to fetch session summary: %v", err) + SendError(ctx, fasthttp.StatusInternalServerError, fmt.Sprintf("Session summary fetch failed: %v", err)) + return + } + + SendJSON(ctx, result) +} + // getLogs handles GET /api/logs - Get logs with filtering, search, and pagination via query parameters func (h *LoggingHandler) getLogs(ctx *fasthttp.RequestCtx) { // Parse query parameters into filters @@ -103,6 +238,9 @@ func (h *LoggingHandler) getLogs(ctx *fasthttp.RequestCtx) { if objects := string(ctx.QueryArgs().Peek("objects")); objects != "" { filters.Objects = parseCommaSeparated(objects) } + if parentRequestID := parseParentRequestIDFilter(ctx); parentRequestID != "" { + filters.ParentRequestID = parentRequestID + } if selectedKeyIDs := string(ctx.QueryArgs().Peek("selected_key_ids")); selectedKeyIDs != "" { filters.SelectedKeyIDs = parseCommaSeparated(selectedKeyIDs) } @@ -317,6 +455,9 @@ func (h *LoggingHandler) getLogsStats(ctx *fasthttp.RequestCtx) { if objects := string(ctx.QueryArgs().Peek("objects")); objects != "" { filters.Objects = parseCommaSeparated(objects) } + if parentRequestID := parseParentRequestIDFilter(ctx); parentRequestID != "" { + filters.ParentRequestID = parentRequestID + } if selectedKeyIDs := string(ctx.QueryArgs().Peek("selected_key_ids")); selectedKeyIDs != "" { filters.SelectedKeyIDs = parseCommaSeparated(selectedKeyIDs) } @@ -449,6 +590,9 @@ func parseHistogramFilters(ctx *fasthttp.RequestCtx) *logstore.SearchFilters { if objects := string(ctx.QueryArgs().Peek("objects")); objects != "" { filters.Objects = parseCommaSeparated(objects) } + if parentRequestID := 
parseParentRequestIDFilter(ctx); parentRequestID != "" { + filters.ParentRequestID = parentRequestID + } if selectedKeyIDs := string(ctx.QueryArgs().Peek("selected_key_ids")); selectedKeyIDs != "" { filters.SelectedKeyIDs = parseCommaSeparated(selectedKeyIDs) } diff --git a/transports/bifrost-http/handlers/middlewares.go b/transports/bifrost-http/handlers/middlewares.go index 6bf5668e33..4c4c7d2856 100644 --- a/transports/bifrost-http/handlers/middlewares.go +++ b/transports/bifrost-http/handlers/middlewares.go @@ -18,11 +18,13 @@ import ( "github.com/maximhq/bifrost/framework/configstore" "github.com/maximhq/bifrost/framework/encrypt" "github.com/maximhq/bifrost/framework/tracing" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" "github.com/maximhq/bifrost/transports/bifrost-http/lib" "github.com/valyala/fasthttp" ) var loggingSkipPaths = []string{"/health", "/_next", "/api/dev"} +var realtimeTransportPaths = buildRealtimeTransportPathSet() // SecurityHeadersMiddleware sets security-related HTTP headers on every response. // This should wrap the outermost handler so all responses (API, UI, errors) include these headers. 
@@ -83,7 +85,7 @@ func CorsMiddleware(config *lib.Config) schemas.BifrostHTTPMiddleware { isLocalhostOrigin(origin) || slices.Contains(config.ClientConfig.AllowedOrigins, origin) - allowedHeaders := []string{"Content-Type", "Authorization", "X-Requested-With", "X-Stainless-Timeout", "X-Api-Key"} + allowedHeaders := []string{"Content-Type", "Authorization", "X-Requested-With", "X-Stainless-Timeout", "X-Api-Key", "X-OpenAI-Agents-SDK"} if slices.Contains(config.ClientConfig.AllowedHeaders, "*") { if credentialed { // Per the Fetch spec, Access-Control-Allow-Headers: * is NOT treated as a @@ -554,7 +556,41 @@ func validateSession(_ *fasthttp.RequestCtx, store configstore.ConfigStore, toke // isInferenceWSEndpoint returns true for WebSocket endpoints that should use // standard inference auth (Bearer/Basic/VK) rather than dashboard session tokens. func isInferenceWSEndpoint(path string) bool { - return path == "/v1/responses" || path == "/v1/realtime" + for strings.HasPrefix(path, "/openai/") { + path = strings.TrimPrefix(path, "/openai") + } + + switch path { + case "/v1/responses", + "/responses", + "/v1/realtime", + "/realtime": + return true + default: + return false + } +} + +func buildRealtimeTransportPathSet() map[string]struct{} { + paths := map[string]struct{}{} + for _, path := range integrations.OpenAIRealtimePaths("") { + paths[path] = struct{}{} + } + for _, path := range integrations.OpenAIRealtimePaths("/openai") { + paths[path] = struct{}{} + } + for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("") { + paths[path] = struct{}{} + } + for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("/openai") { + paths[path] = struct{}{} + } + return paths +} + +func isRealtimeTransportEndpoint(path string) bool { + _, ok := realtimeTransportPaths[path] + return ok } // AuthMiddleware is a middleware that handles authentication for the API. 
@@ -639,6 +675,10 @@ func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, str next(ctx) return } + if isRealtimeTransportEndpoint(string(ctx.Path())) { + next(ctx) + return + } // If inference is disabled, we skip authorization // Get the authorization header authorization := string(ctx.Request.Header.Peek("Authorization")) @@ -798,24 +838,23 @@ func (m *AuthMiddleware) middleware(shouldSkip func(*configstore.AuthConfig, str // // This middleware should be placed early in the middleware chain to capture the full request lifecycle. type TracingMiddleware struct { - tracer atomic.Pointer[tracing.Tracer] - obsPlugins atomic.Pointer[[]schemas.ObservabilityPlugin] + tracer atomic.Pointer[tracing.Tracer] } // NewTracingMiddleware creates a new tracing middleware -func NewTracingMiddleware(tracer *tracing.Tracer, obsPlugins []schemas.ObservabilityPlugin) *TracingMiddleware { +func NewTracingMiddleware(tracer *tracing.Tracer) *TracingMiddleware { tm := &TracingMiddleware{ - tracer: atomic.Pointer[tracing.Tracer]{}, - obsPlugins: atomic.Pointer[[]schemas.ObservabilityPlugin]{}, + tracer: atomic.Pointer[tracing.Tracer]{}, } tm.tracer.Store(tracer) - tm.obsPlugins.Store(&obsPlugins) return tm } // SetObservabilityPlugins sets the observability plugins for the tracing middleware func (m *TracingMiddleware) SetObservabilityPlugins(obsPlugins []schemas.ObservabilityPlugin) { - m.obsPlugins.Store(&obsPlugins) + if tracer := m.tracer.Load(); tracer != nil { + tracer.SetObservabilityPlugins(obsPlugins) + } } // SetTracer sets the tracer for the tracing middleware @@ -827,8 +866,10 @@ func (m *TracingMiddleware) SetTracer(tracer *tracing.Tracer) { func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { return func(next fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - // Skip if store is nil - if m.tracer.Load() == nil { + // Pin the tracer for the lifetime of this request so that a concurrent + // 
SetTracer() swap cannot split a trace across two instances. + tracer := m.tracer.Load() + if tracer == nil { next(ctx) return } @@ -842,7 +883,7 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { // This is the 32-char trace ID that links all spans in a distributed trace inheritedTraceID := tracing.ExtractParentID(&ctx.Request.Header) // Create trace in store - only ID returned (trace data stays in store) - traceID := m.tracer.Load().CreateTrace(inheritedTraceID, requestID) + traceID := tracer.CreateTrace(inheritedTraceID, requestID) // Only trace ID goes into context (lightweight, no bloat) ctx.SetUserValue(schemas.BifrostContextKeyTraceID, traceID) // Extract parent span ID from W3C traceparent header (if present) @@ -861,16 +902,16 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { } // Attach transport plugin logs before completing the trace (streaming path) if transportLogs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok && len(transportLogs) > 0 { - m.tracer.Load().AttachPluginLogs(traceID, transportLogs) + tracer.AttachPluginLogs(traceID, transportLogs) } - m.completeAndFlushTrace(traceID) + tracer.CompleteAndFlushTrace(traceID) }) // Create root span for the HTTP request - spanCtx, rootSpan := m.tracer.Load().StartSpan(ctx, string(ctx.RequestURI()), schemas.SpanKindHTTPRequest) + spanCtx, rootSpan := tracer.StartSpan(ctx, string(ctx.RequestURI()), schemas.SpanKindHTTPRequest) if rootSpan != nil { - m.tracer.Load().SetAttribute(rootSpan, "http.method", string(ctx.Method())) - m.tracer.Load().SetAttribute(rootSpan, "http.url", string(ctx.RequestURI())) - m.tracer.Load().SetAttribute(rootSpan, "http.user_agent", string(ctx.Request.Header.UserAgent())) + tracer.SetAttribute(rootSpan, "http.method", string(ctx.Method())) + tracer.SetAttribute(rootSpan, "http.url", string(ctx.RequestURI())) + tracer.SetAttribute(rootSpan, "http.user_agent", 
string(ctx.Request.Header.UserAgent())) // Set root span ID in context for child span creation if spanID, ok := spanCtx.Value(schemas.BifrostContextKeySpanID).(string); ok { ctx.SetUserValue(schemas.BifrostContextKeySpanID, spanID) @@ -879,11 +920,11 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { defer func() { // Record response status on the root span if rootSpan != nil { - m.tracer.Load().SetAttribute(rootSpan, "http.status_code", ctx.Response.StatusCode()) + tracer.SetAttribute(rootSpan, "http.status_code", ctx.Response.StatusCode()) if ctx.Response.StatusCode() >= 400 { - m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusError, fmt.Sprintf("HTTP %d", ctx.Response.StatusCode())) + tracer.EndSpan(rootSpan, schemas.SpanStatusError, fmt.Sprintf("HTTP %d", ctx.Response.StatusCode())) } else { - m.tracer.Load().EndSpan(rootSpan, schemas.SpanStatusOk, "") + tracer.EndSpan(rootSpan, schemas.SpanStatusOk, "") } } // Check if trace completion is deferred (for streaming requests) @@ -893,10 +934,10 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { } // Attach transport plugin logs to trace before completion if transportLogs, ok := ctx.UserValue(schemas.BifrostContextKeyTransportPluginLogs).([]schemas.PluginLogEntry); ok && len(transportLogs) > 0 { - m.tracer.Load().AttachPluginLogs(traceID, transportLogs) + tracer.AttachPluginLogs(traceID, transportLogs) } // After response written - async flush - m.completeAndFlushTrace(traceID) + tracer.CompleteAndFlushTrace(traceID) }() next(ctx) @@ -904,32 +945,6 @@ func (m *TracingMiddleware) Middleware() schemas.BifrostHTTPMiddleware { } } -// completeAndFlushTrace completes the trace and forwards it to observability plugins. -// This is called either by the middleware defer (for non-streaming) or by streaming handlers. 
-func (m *TracingMiddleware) completeAndFlushTrace(traceID string) { - go func() { - // Clean up the stream accumulator for this trace - - // Get completed trace from store - completedTrace := m.tracer.Load().EndTrace(traceID) - if completedTrace == nil { - return - } - // Forward to all observability plugins - for _, plugin := range *m.obsPlugins.Load() { - if plugin == nil { - continue - } - // Call inject with a background context (request context is done) - if err := plugin.Inject(context.Background(), completedTrace); err != nil { - logger.Warn("observability plugin %s failed to inject trace: %v", plugin.GetName(), err) - } - } - // Return trace to pool for reuse - m.tracer.Load().ReleaseTrace(completedTrace) - }() -} - // GetTracer returns the tracer instance for use by streaming handlers func (m *TracingMiddleware) GetTracer() *tracing.Tracer { return m.tracer.Load() diff --git a/transports/bifrost-http/handlers/middlewares_test.go b/transports/bifrost-http/handlers/middlewares_test.go index e98edee615..97675d239c 100644 --- a/transports/bifrost-http/handlers/middlewares_test.go +++ b/transports/bifrost-http/handlers/middlewares_test.go @@ -2,8 +2,8 @@ package handlers import ( "bytes" - "compress/zlib" "compress/gzip" + "compress/zlib" cryptoRand "crypto/rand" "encoding/json" "io" @@ -71,7 +71,7 @@ func TestCorsMiddleware_LocalhostOrigins(t *testing.T) { if string(ctx.Response.Header.Peek("Access-Control-Allow-Methods")) != "GET, POST, PUT, DELETE, PATCH, OPTIONS, HEAD" { t.Errorf("Access-Control-Allow-Methods header not set correctly") } - if string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) != "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key" { + if string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) != "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK" { t.Errorf("Access-Control-Allow-Headers header not set correctly") } if 
string(ctx.Response.Header.Peek("Access-Control-Allow-Credentials")) != "true" { @@ -410,6 +410,69 @@ func TestChainMiddlewares_MiddlewareCanModifyContext(t *testing.T) { chained(ctx) } +func TestIsInferenceWSEndpoint(t *testing.T) { + paths := []string{ + "/v1/responses", + "/v1/realtime", + "/responses", + "/realtime", + "/openai/v1/responses", + "/openai/responses", + "/openai/openai/responses", + "/openai/v1/realtime", + "/openai/realtime", + "/openai/openai/realtime", + } + + for _, path := range paths { + if !isInferenceWSEndpoint(path) { + t.Fatalf("expected inference websocket path %s to be recognized", path) + } + } + + if isInferenceWSEndpoint("/api/ws") { + t.Fatal("dashboard websocket path should not be treated as inference websocket") + } + if isInferenceWSEndpoint("/openai/chat/completions") { + t.Fatal("non-websocket OpenAI path should not be treated as inference websocket") + } +} + +func TestIsRealtimeTransportEndpoint(t *testing.T) { + paths := []string{ + "/v1/realtime", + "/realtime", + "/openai/realtime", + "/openai/v1/realtime", + "/openai/openai/realtime", + "/v1/realtime/calls", + "/realtime/calls", + "/openai/realtime/calls", + "/openai/v1/realtime/calls", + "/openai/openai/realtime/calls", + } + + for _, path := range paths { + if !isRealtimeTransportEndpoint(path) { + t.Fatalf("expected realtime transport path %s to be recognized", path) + } + } + + nonTransportPaths := []string{ + "/v1/realtime/client_secrets", + "/v1/realtime/sessions", + "/openai/v1/realtime/client_secrets", + "/openai/v1/realtime/sessions", + "/v1/chat/completions", + } + + for _, path := range nonTransportPaths { + if isRealtimeTransportEndpoint(path) { + t.Fatalf("did not expect non-transport path %s to be recognized", path) + } + } +} + // Testlib.ChainMiddlewares_ShortCircuit tests that when a middleware writes a response // and does not call next, subsequent middlewares and handler do not execute. 
func TestChainMiddlewares_ShortCircuit(t *testing.T) { @@ -663,6 +726,83 @@ func TestAuthMiddleware_WhitelistedRoutes(t *testing.T) { } } +func TestAuthMiddleware_InferenceMiddleware_RealtimeTransportBypassesAuth(t *testing.T) { + SetLogger(&mockLogger{}) + + am := &AuthMiddleware{} + am.UpdateAuthConfig(&configstore.AuthConfig{ + AdminUserName: schemas.NewEnvVar("admin"), + AdminPassword: schemas.NewEnvVar("hashedpassword"), + IsEnabled: true, + }) + + routes := []string{ + "/v1/realtime", + "/openai/v1/realtime", + "/v1/realtime/calls?model=gpt-realtime", + "/openai/v1/realtime/calls?model=gpt-realtime", + } + + for _, route := range routes { + t.Run(route, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI(route) + + nextCalled := false + next := func(ctx *fasthttp.RequestCtx) { + nextCalled = true + } + + handler := am.InferenceMiddleware()(next) + handler(ctx) + + if !nextCalled { + t.Fatalf("expected realtime transport route %s to bypass auth", route) + } + }) + } +} + +func TestAuthMiddleware_InferenceMiddleware_RealtimeMintingStillRequiresAuth(t *testing.T) { + SetLogger(&mockLogger{}) + + am := &AuthMiddleware{} + am.UpdateAuthConfig(&configstore.AuthConfig{ + AdminUserName: schemas.NewEnvVar("admin"), + AdminPassword: schemas.NewEnvVar("hashedpassword"), + IsEnabled: true, + }) + + routes := []string{ + "/v1/realtime/client_secrets", + "/v1/realtime/sessions", + "/openai/v1/realtime/client_secrets", + "/openai/v1/realtime/sessions", + } + + for _, route := range routes { + t.Run(route, func(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.SetRequestURI(route) + + nextCalled := false + next := func(ctx *fasthttp.RequestCtx) { + nextCalled = true + } + + handler := am.InferenceMiddleware()(next) + handler(ctx) + + if nextCalled { + t.Fatalf("expected realtime minting route %s to still require auth", route) + } + if ctx.Response.StatusCode() != fasthttp.StatusUnauthorized { + t.Fatalf("expected %d for route %s, got 
%d", fasthttp.StatusUnauthorized, route, ctx.Response.StatusCode()) + } + }) + } +} + // TestAuthMiddleware_UpdateAuthConfig_NilToEnabled tests updating auth config from nil to enabled func TestAuthMiddleware_UpdateAuthConfig_NilToEnabled(t *testing.T) { SetLogger(&mockLogger{}) @@ -864,7 +1004,7 @@ func TestCorsMiddleware_DefaultHeaders(t *testing.T) { handler(ctx) // Check default headers are set - expectedHeaders := "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key" + expectedHeaders := "Content-Type, Authorization, X-Requested-With, X-Stainless-Timeout, X-Api-Key, X-OpenAI-Agents-SDK" actualHeaders := string(ctx.Response.Header.Peek("Access-Control-Allow-Headers")) if actualHeaders != expectedHeaders { t.Errorf("Expected Access-Control-Allow-Headers to be %s, got %s", expectedHeaders, actualHeaders) diff --git a/transports/bifrost-http/handlers/realtime_client_secrets.go b/transports/bifrost-http/handlers/realtime_client_secrets.go new file mode 100644 index 0000000000..9c761d0692 --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_client_secrets.go @@ -0,0 +1,416 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "mime" + "strings" + "time" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/plugins/governance" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + "github.com/valyala/fasthttp" +) + +// RealtimeClientSecretsHandler exposes OpenAI-compatible HTTP routes for +// minting short-lived Realtime client secrets. 
+type RealtimeClientSecretsHandler struct { + client *bifrost.Bifrost + config *lib.Config + handlerStore lib.HandlerStore + routeSpecs map[string]schemas.RealtimeSessionRoute +} + +func NewRealtimeClientSecretsHandler(client *bifrost.Bifrost, config *lib.Config) *RealtimeClientSecretsHandler { + return &RealtimeClientSecretsHandler{ + client: client, + config: config, + handlerStore: config, + routeSpecs: make(map[string]schemas.RealtimeSessionRoute), + } +} + +func (h *RealtimeClientSecretsHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + handler := lib.ChainMiddlewares(h.handleRequest, middlewares...) + for _, route := range h.realtimeSessionRoutes() { + h.routeSpecs[route.Path] = route + r.POST(route.Path, handler) + } +} + +func (h *RealtimeClientSecretsHandler) findGovernancePlugin() governance.BaseGovernancePlugin { + basePlugins := h.config.BasePlugins.Load() + if basePlugins == nil { + return nil + } + + for _, plugin := range *basePlugins { + if governancePlugin, ok := plugin.(governance.BaseGovernancePlugin); ok { + return governancePlugin + } + } + + return nil +} + +func (h *RealtimeClientSecretsHandler) handleRequest(ctx *fasthttp.RequestCtx) { + if !isJSONContentType(string(ctx.Request.Header.ContentType())) { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "Content-Type must be application/json", + nil, + )) + return + } + + body := append([]byte(nil), ctx.Request.Body()...) 
+ route, ok := h.routeSpecs[string(ctx.Path())] + if !ok { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusNotFound, + "invalid_request_error", + "unsupported realtime client secret route", + nil, + )) + return + } + + providerKey, model, normalizedBody, err := resolveRealtimeClientSecretTarget(route, body) + if err != nil { + SendBifrostError(ctx, err) + return + } + + bifrostCtx, cancel := lib.ConvertToBifrostContext( + ctx, + h.handlerStore.ShouldAllowDirectKeys(), + h.config.GetHeaderMatcher(), + h.config.GetMCPHeaderCombinedAllowlist(), + ) + defer cancel() + bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if route.DefaultProvider == schemas.OpenAI { + bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + } + if governanceUserID, ok := ctx.UserValue(schemas.BifrostContextKeyGovernanceUserID).(string); ok && governanceUserID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceUserID, governanceUserID) + } + if bifrostErr := h.evaluateMintingGovernance(bifrostCtx, providerKey, model); bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + provider := h.client.GetProviderByKey(providerKey) + if provider == nil { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "provider not found: "+string(providerKey), + nil, + )) + return + } + + key, keyErr := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model) + if keyErr != nil { + SendBifrostError(ctx, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + keyErr.Error(), + keyErr, + )) + return + } + + // Resolve model aliases now that the key is selected so the forwarded body + // carries the provider's canonical model, matching wsrealtime/webrtc flows. 
+ if resolved := key.Aliases.Resolve(model); resolved != "" && resolved != model { + model = resolved + reparsed, parseErr := schemas.ParseRealtimeClientSecretBody(normalizedBody) + if parseErr != nil { + SendBifrostError(ctx, parseErr) + return + } + rewritten, normalizeErr := normalizeRealtimeClientSecretBody(reparsed, model) + if normalizeErr != nil { + SendBifrostError(ctx, normalizeErr) + return + } + normalizedBody = rewritten + } + + sessionProvider, ok := provider.(schemas.RealtimeSessionProvider) + if !ok { + SendBifrostError(ctx, realtimeSessionNotSupportedError(providerKey, provider)) + return + } + + resp, bifrostErr := sessionProvider.CreateRealtimeClientSecret(bifrostCtx, key, route.EndpointType, normalizedBody) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + cacheRealtimeEphemeralKeyMapping( + h.handlerStore.GetKVStore(), + resp.Body, + key.ID, + bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey), + ) + + writeRealtimeClientSecretResponse(ctx, resp) +} + +func (h *RealtimeClientSecretsHandler) evaluateMintingGovernance( + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, +) *schemas.BifrostError { + governancePlugin := h.findGovernancePlugin() + if governancePlugin == nil { + return nil + } + + _, bifrostErr := governancePlugin.EvaluateGovernanceRequest(bifrostCtx, &governance.EvaluationRequest{ + VirtualKey: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyVirtualKey), + Provider: providerKey, + Model: model, + UserID: bifrost.GetStringFromContext(bifrostCtx, schemas.BifrostContextKeyGovernanceUserID), + }, schemas.RealtimeRequest) + return bifrostErr +} + +func (h *RealtimeClientSecretsHandler) realtimeSessionRoutes() []schemas.RealtimeSessionRoute { + routes := []schemas.RealtimeSessionRoute{ + { + Path: "/v1/realtime/client_secrets", + EndpointType: schemas.RealtimeSessionEndpointClientSecrets, + }, + { + Path: "/v1/realtime/sessions", + 
EndpointType: schemas.RealtimeSessionEndpointSessions, + }, + } + + for _, path := range integrations.OpenAIRealtimeClientSecretPaths("/openai") { + endpointType := schemas.RealtimeSessionEndpointClientSecrets + if strings.HasSuffix(path, "/realtime/sessions") { + endpointType = schemas.RealtimeSessionEndpointSessions + } + routes = append(routes, schemas.RealtimeSessionRoute{ + Path: path, + EndpointType: endpointType, + DefaultProvider: schemas.OpenAI, + }) + } + return routes +} + +func resolveRealtimeClientSecretTarget(route schemas.RealtimeSessionRoute, body []byte) (schemas.ModelProvider, string, []byte, *schemas.BifrostError) { + root, err := schemas.ParseRealtimeClientSecretBody(body) + if err != nil { + return "", "", nil, err + } + + rawModel, err := schemas.ExtractRealtimeClientSecretModel(root) + if err != nil { + return "", "", nil, err + } + + defaultProvider := route.DefaultProvider + providerKey, model := schemas.ParseModelString(rawModel, defaultProvider) + if defaultProvider == "" && providerKey == "" { + return "", "", nil, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "session.model must use provider/model on /v1 realtime client secret routes", + nil, + ) + } + if providerKey == "" || model == "" { + return "", "", nil, newRealtimeClientSecretHandlerError( + fasthttp.StatusBadRequest, + "invalid_request_error", + "session.model is required", + nil, + ) + } + + // Normalize the forwarded body so the upstream provider sees the bare model + // (strip provider prefix). Mirrors resolveRealtimeSDPTarget normalization. 
+ normalizedBody, normalizeErr := normalizeRealtimeClientSecretBody(root, model) + if normalizeErr != nil { + return "", "", nil, normalizeErr + } + + return providerKey, model, normalizedBody, nil +} + +func normalizeRealtimeClientSecretBody(root map[string]json.RawMessage, bareModel string) ([]byte, *schemas.BifrostError) { + normalizedModel, marshalErr := json.Marshal(bareModel) + if marshalErr != nil { + return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + + // Normalize session.model if present + if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 { + var session map[string]json.RawMessage + if err := json.Unmarshal(sessionJSON, &session); err == nil { + if _, hasModel := session["model"]; hasModel { + session["model"] = normalizedModel + rewritten, err := json.Marshal(session) + if err != nil { + return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode session", err) + } + root["session"] = rewritten + } + } + } + // Normalize top-level model if present + if _, ok := root["model"]; ok { + root["model"] = normalizedModel + } + + normalized, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, newRealtimeClientSecretHandlerError(fasthttp.StatusInternalServerError, "server_error", "failed to re-encode body", marshalErr) + } + return normalized, nil +} + +const realtimeEphemeralKeyMappingPrefix = "realtime:ephemeral-key:" + +type realtimeEphemeralKeyMapping struct { + KeyID string `json:"key_id,omitempty"` + VirtualKey string `json:"virtual_key,omitempty"` +} + +func cacheRealtimeEphemeralKeyMapping(kv schemas.KVStore, body []byte, keyID string, virtualKey string) { + if kv == nil || len(body) == 0 || strings.TrimSpace(keyID) == "" { + return + } + + token, ttl, ok := parseRealtimeEphemeralKeyMapping(body) + if !ok || strings.TrimSpace(token) == "" || ttl <= 0 { + 
// parseRealtimeEphemeralKeyMapping extracts the ephemeral token and its
// remaining lifetime from a realtime client-secret response body.
//
// Two shapes are accepted:
//   - top level:  {"value": "...", "expires_at": <unix seconds>}
//   - nested:     {"client_secret": {"value": "...", "expires_at": <unix seconds>}}
//
// The top-level shape wins when it is complete. Otherwise the nested object is
// decoded on its own, so a value from one object is never paired with an
// expires_at from the other.
//
// It returns the token exactly as it appears in the body, the time left until
// expires_at, and true. It returns ("", 0, false) when the body cannot be
// decoded as a JSON object, when neither shape has a non-blank value and a
// positive expires_at, or when the secret has already expired.
func parseRealtimeEphemeralKeyMapping(body []byte) (string, time.Duration, bool) {
	type secretFields struct {
		Value     string `json:"value"`
		ExpiresAt int64  `json:"expires_at"`
	}
	usable := func(s secretFields) bool {
		return strings.TrimSpace(s.Value) != "" && s.ExpiresAt > 0
	}

	var root map[string]json.RawMessage
	if err := json.Unmarshal(body, &root); err != nil {
		return "", 0, false
	}

	// OpenAI client_secrets responses expose the ephemeral token at the top level.
	// Keep accepting the nested shape too so the mapping logic stays compatible
	// with any provider/session endpoint variants that wrap the secret object.
	var secret secretFields
	if err := json.Unmarshal(body, &secret); err != nil || !usable(secret) {
		raw, ok := root["client_secret"]
		if !ok || len(raw) == 0 || string(raw) == "null" {
			return "", 0, false
		}
		// Decode into a fresh struct: the top-level attempt may have filled one
		// of the two fields, and those leftovers must not leak into the result.
		var nested secretFields
		if err := json.Unmarshal(raw, &nested); err != nil || !usable(nested) {
			return "", 0, false
		}
		secret = nested
	}

	ttl := time.Until(time.Unix(secret.ExpiresAt, 0))
	if ttl <= 0 {
		return "", 0, false
	}

	return secret.Value, ttl, true
}
// isJSONContentType reports whether a Content-Type header value names a JSON
// media type: either application/json or any type with a "+json" suffix.
// Parameters such as charset are ignored, and a value that
// mime.ParseMediaType rejects is not JSON.
func isJSONContentType(contentType string) bool {
	parsed, _, parseErr := mime.ParseMediaType(contentType)
	if parseErr != nil {
		return false
	}

	switch lowered := strings.ToLower(parsed); {
	case lowered == "application/json":
		return true
	default:
		return strings.HasSuffix(lowered, "+json")
	}
}
route with top level model", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions}, + body: []byte(`{"model":"openai/gpt-4o-realtime-preview"}`), + wantProvider: schemas.OpenAI, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "openai alias uses bare model", + route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI}, + body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + wantProvider: schemas.OpenAI, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "base route rejects bare model", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets}, + body: []byte(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + wantErr: true, + }, + { + name: "missing model", + route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI}, + body: []byte(`{"session":{}}`), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + gotProvider, gotModel, _, err := resolveRealtimeClientSecretTarget(tt.route, tt.body) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("resolveRealtimeClientSecretTarget() error = %v", err) + } + if gotProvider != tt.wantProvider { + t.Fatalf("provider = %q, want %q", gotProvider, tt.wantProvider) + } + if gotModel != tt.wantModel { + t.Fatalf("model = %q, want %q", gotModel, tt.wantModel) + } + }) + } +} + +func TestResolveRealtimeClientSecretTarget_NormalizesModel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + route schemas.RealtimeSessionRoute + body string + wantModel string // bare model expected in normalized 
body + }{ + { + name: "session.model provider prefix stripped", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets}, + body: `{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "top-level model provider prefix stripped", + route: schemas.RealtimeSessionRoute{Path: "/v1/realtime/sessions", EndpointType: schemas.RealtimeSessionEndpointSessions}, + body: `{"model":"openai/gpt-4o-realtime-preview"}`, + wantModel: "gpt-4o-realtime-preview", + }, + { + name: "bare model unchanged on alias route", + route: schemas.RealtimeSessionRoute{Path: "/openai/v1/realtime/client_secrets", EndpointType: schemas.RealtimeSessionEndpointClientSecrets, DefaultProvider: schemas.OpenAI}, + body: `{"session":{"model":"gpt-4o-realtime-preview"}}`, + wantModel: "gpt-4o-realtime-preview", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + _, _, normalizedBody, err := resolveRealtimeClientSecretTarget(tt.route, []byte(tt.body)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var root map[string]json.RawMessage + if unmarshalErr := json.Unmarshal(normalizedBody, &root); unmarshalErr != nil { + t.Fatalf("failed to unmarshal normalized body: %v", unmarshalErr) + } + + // Check session.model if present + if sessionJSON, ok := root["session"]; ok { + var session map[string]json.RawMessage + if unmarshalErr := json.Unmarshal(sessionJSON, &session); unmarshalErr != nil { + t.Fatalf("failed to unmarshal session: %v", unmarshalErr) + } + if modelJSON, ok := session["model"]; ok { + var model string + if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil { + t.Fatalf("failed to unmarshal session.model: %v", unmarshalErr) + } + if model != tt.wantModel { + t.Fatalf("session.model = %q, want %q", model, tt.wantModel) + } + } + } + + // Check top-level model if 
present + if modelJSON, ok := root["model"]; ok { + var model string + if unmarshalErr := json.Unmarshal(modelJSON, &model); unmarshalErr != nil { + t.Fatalf("failed to unmarshal model: %v", unmarshalErr) + } + if model != tt.wantModel { + t.Fatalf("model = %q, want %q", model, tt.wantModel) + } + } + }) + } +} + +func TestParseRealtimeEphemeralKeyMapping(t *testing.T) { + t.Parallel() + + token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{ + "value": "ek_test_123", + "expires_at": 4102444800 + }`)) + if !ok { + t.Fatal("expected ephemeral mapping to be parsed") + } + if token != "ek_test_123" { + t.Fatalf("token = %q, want %q", token, "ek_test_123") + } + if ttl <= 0 { + t.Fatalf("ttl = %v, want > 0", ttl) + } +} + +func TestParseRealtimeEphemeralKeyMapping_NestedFallback(t *testing.T) { + t.Parallel() + + token, ttl, ok := parseRealtimeEphemeralKeyMapping([]byte(`{ + "client_secret": { + "value": "ek_test_nested", + "expires_at": 4102444800 + } + }`)) + if !ok { + t.Fatal("expected nested ephemeral mapping to be parsed") + } + if token != "ek_test_nested" { + t.Fatalf("token = %q, want %q", token, "ek_test_nested") + } + if ttl <= 0 { + t.Fatalf("ttl = %v, want > 0", ttl) + } +} + +func TestCacheRealtimeEphemeralKeyMappingStoresKeyID(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + body := []byte(`{ + "value": "ek_test_456", + "expires_at": ` + "4102444800" + ` + }`) + cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "sk-bf-test") + + raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_456")) + if err != nil { + t.Fatalf("store.Get() error = %v", err) + } + value, ok := raw.([]byte) + if !ok { + t.Fatalf("cached value type = %T, want []byte", raw) + } + var mapping realtimeEphemeralKeyMapping + if err := json.Unmarshal(value, &mapping); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if 
mapping.KeyID != "key_123" { + t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123") + } + if mapping.VirtualKey != "sk-bf-test" { + t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test") + } +} + +func TestCacheRealtimeEphemeralKeyMappingSkipsExpiredSecrets(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + expired := time.Now().Add(-time.Minute).Unix() + body := fmt.Appendf(nil, `{ + "value": "ek_expired", + "expires_at": %d + }`, expired) + cacheRealtimeEphemeralKeyMapping(store, body, "key_123", "") + + if _, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_expired")); err == nil { + t.Fatal("expected no cached mapping for expired token") + } +} + +func TestIsJSONContentType(t *testing.T) { + t.Parallel() + + if !isJSONContentType("application/json; charset=utf-8") { + t.Fatal("expected application/json content type to pass") + } + if !isJSONContentType("application/vnd.openai+json") { + t.Fatal("expected +json content type to pass") + } + if isJSONContentType("text/plain") { + t.Fatal("expected text/plain content type to fail") + } +} + +type mockRealtimeMintingGovernancePlugin struct { + err *schemas.BifrostError + seenUserID string + seenVirtualKey string + seenProvider schemas.ModelProvider + seenModel string + evaluateCalls int +} + +func (m *mockRealtimeMintingGovernancePlugin) GetName() string { + return governance.PluginName +} + +func (m *mockRealtimeMintingGovernancePlugin) EvaluateGovernanceRequest(ctx *schemas.BifrostContext, evaluationRequest *governance.EvaluationRequest, _ schemas.RequestType) (*governance.EvaluationResult, *schemas.BifrostError) { + m.evaluateCalls++ + m.seenUserID = "" + m.seenVirtualKey = "" + m.seenProvider = "" + m.seenModel = "" + if evaluationRequest != nil { + m.seenUserID = evaluationRequest.UserID + m.seenVirtualKey = evaluationRequest.VirtualKey + 
m.seenProvider = evaluationRequest.Provider + m.seenModel = evaluationRequest.Model + } + if ctx != nil && m.seenVirtualKey == "" { + m.seenVirtualKey = bifrost.GetStringFromContext(ctx, schemas.BifrostContextKeyVirtualKey) + } + if m.err != nil { + return nil, m.err + } + return &governance.EvaluationResult{Decision: governance.DecisionAllow}, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPreHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest) (*schemas.HTTPResponse, error) { + return nil, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) HTTPTransportPostHook(_ *schemas.BifrostContext, _ *schemas.HTTPRequest, _ *schemas.HTTPResponse) error { + return nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PreLLMHook(_ *schemas.BifrostContext, req *schemas.BifrostRequest) (*schemas.BifrostRequest, *schemas.LLMPluginShortCircuit, error) { + return req, nil, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PostLLMHook(_ *schemas.BifrostContext, result *schemas.BifrostResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError, error) { + return result, bifrostErr, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PreMCPHook(_ *schemas.BifrostContext, req *schemas.BifrostMCPRequest) (*schemas.BifrostMCPRequest, *schemas.MCPPluginShortCircuit, error) { + return req, nil, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) PostMCPHook(_ *schemas.BifrostContext, resp *schemas.BifrostMCPResponse, bifrostErr *schemas.BifrostError) (*schemas.BifrostMCPResponse, *schemas.BifrostError, error) { + return resp, bifrostErr, nil +} + +func (m *mockRealtimeMintingGovernancePlugin) Cleanup() error { + return nil +} + +func (m *mockRealtimeMintingGovernancePlugin) GetGovernanceStore() governance.GovernanceStore { + return nil +} + +func TestRealtimeClientSecretsEvaluateMintingGovernance_RequiresAccess(t *testing.T) { + t.Parallel() + + config := &lib.Config{} + plugin := 
&mockRealtimeMintingGovernancePlugin{ + err: &schemas.BifrostError{ + Type: schemas.Ptr("virtual_key_required"), + StatusCode: schemas.Ptr(401), + Error: &schemas.ErrorField{ + Message: "virtual key is required. Provide a virtual key via the x-bf-vk header.", + }, + }, + } + plugins := []schemas.BasePlugin{plugin} + config.BasePlugins.Store(&plugins) + + handler := NewRealtimeClientSecretsHandler(nil, config) + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + defer bifrostCtx.Done() + + err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime") + if err == nil { + t.Fatal("expected governance error") + } + if err.StatusCode == nil { + t.Fatal("expected status code") + } + if got, want := *err.StatusCode, fasthttp.StatusUnauthorized; got != want { + t.Fatalf("status = %d, want %d", got, want) + } +} + +func TestRealtimeClientSecretsEvaluateMintingGovernance_PassesContext(t *testing.T) { + t.Parallel() + + config := &lib.Config{} + plugin := &mockRealtimeMintingGovernancePlugin{} + plugins := []schemas.BasePlugin{ + plugin, + } + config.BasePlugins.Store(&plugins) + + handler := NewRealtimeClientSecretsHandler(nil, config) + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + defer bifrostCtx.Done() + bifrostCtx.SetValue(schemas.BifrostContextKeyGovernanceUserID, "user_123") + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, "sk-bf-123") + + if err := handler.evaluateMintingGovernance(bifrostCtx, schemas.OpenAI, "gpt-realtime"); err != nil { + t.Fatalf("unexpected governance error: %v", err) + } + if plugin.evaluateCalls != 1 { + t.Fatalf("evaluate calls = %d, want 1", plugin.evaluateCalls) + } + if plugin.seenUserID != "user_123" { + t.Fatalf("governance user id = %q, want %q", plugin.seenUserID, "user_123") + } + if plugin.seenVirtualKey != "sk-bf-123" { + t.Fatalf("virtual key = %q, want %q", plugin.seenVirtualKey, "sk-bf-123") + } + if plugin.seenProvider != 
import (
	"encoding/json"
	"sort"
	"strings"

	"github.com/bytedance/sonic"
	"github.com/maximhq/bifrost/core/schemas"
	bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket"
)
// collectRealtimeTextFragments walks a decoded JSON value (maps, slices and
// scalars as produced by unmarshalling into any) and appends every non-blank,
// trimmed string found under one of the text-bearing keys to *parts.
//
// Text-bearing keys are "text", "transcript", "input_text", "output_text",
// "output" and "arguments". A value under one of those keys that is not a
// string, and every value under any other key, is searched recursively.
//
// Slice elements are visited in order and object fields in sorted key order,
// so the same input always produces the same fragment sequence. Ranging over
// the map directly would make the order random whenever an object has more
// than one text-bearing field.
func collectRealtimeTextFragments(value any, parts *[]string) {
	switch v := value.(type) {
	case map[string]any:
		keys := make([]string, 0, len(v))
		for key := range v {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		for _, key := range keys {
			field := v[key]
			switch key {
			case "text", "transcript", "input_text", "output_text", "output", "arguments":
				if text, ok := field.(string); ok {
					if text = strings.TrimSpace(text); text != "" {
						*parts = append(*parts, text)
					}
					// A string under a text key is fully handled, blank or not.
					continue
				}
			}
			collectRealtimeTextFragments(field, parts)
		}
	case []any:
		for _, item := range v {
			collectRealtimeTextFragments(item, parts)
		}
	}
}
+ } + return realtimeMissingTranscriptText + default: + if event != nil && event.Type == schemas.RTEventConversationItemDone && schemas.IsRealtimeUserInputEvent(event) { + if summary := extractRealtimeItemSummary(event.Item); summary != "" { + return summary + } + if realtimeItemHasMissingAudioTranscript(event.Item) { + return realtimeMissingTranscriptText + } + } + if schemas.IsRealtimeUserInputEvent(event) { + return extractRealtimeItemSummary(event.Item) + } + } + + return "" +} + +func pendingRealtimeInputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) { + if event == nil { + return "", "" + } + + switch event.Type { + case schemas.RTEventConversationItemRetrieved: + return "", "" + case schemas.RTEventInputAudioTransCompleted: + return realtimeEventItemID(event), finalizedRealtimeInputSummary(event) + default: + if schemas.IsRealtimeUserInputEvent(event) { + return realtimeEventItemID(event), finalizedRealtimeInputSummary(event) + } + } + + return "", "" +} + +func realtimeItemHasMissingAudioTranscript(item *schemas.RealtimeItem) bool { + if item == nil || len(item.Content) == 0 { + return false + } + + var decoded []map[string]any + if err := sonic.Unmarshal(item.Content, &decoded); err != nil { + return false + } + + for _, part := range decoded { + partType, _ := part["type"].(string) + if partType != "input_audio" { + continue + } + transcript, exists := part["transcript"] + if !exists || transcript == nil { + return true + } + if text, ok := transcript.(string); ok && strings.TrimSpace(text) == "" { + return true + } + } + + return false +} + +func finalizedRealtimeToolOutputSummary(event *schemas.BifrostRealtimeEvent) string { + if !schemas.IsRealtimeToolOutputEvent(event) { + return "" + } + return extractRealtimeItemSummary(event.Item) +} + +func pendingRealtimeToolOutputUpdate(event *schemas.BifrostRealtimeEvent) (string, string) { + if event == nil || event.Type == schemas.RTEventConversationItemRetrieved || 
!schemas.IsRealtimeToolOutputEvent(event) { + return "", "" + } + return realtimeEventItemID(event), finalizedRealtimeToolOutputSummary(event) +} + +func extractRealtimeExtraParamString(event *schemas.BifrostRealtimeEvent, key string) string { + if event == nil || event.ExtraParams == nil { + return "" + } + raw, ok := event.ExtraParams[key] + if !ok || len(raw) == 0 { + return "" + } + var value string + if err := json.Unmarshal(raw, &value); err != nil { + return "" + } + return strings.TrimSpace(value) +} + +func realtimeEventItemID(event *schemas.BifrostRealtimeEvent) string { + if event == nil { + return "" + } + if event.Item != nil && strings.TrimSpace(event.Item.ID) != "" { + return strings.TrimSpace(event.Item.ID) + } + if event.Delta != nil && strings.TrimSpace(event.Delta.ItemID) != "" { + return strings.TrimSpace(event.Delta.ItemID) + } + return extractRealtimeExtraParamString(event, "item_id") +} + +func combineRealtimeInputRaw(turnInputs []bfws.RealtimeTurnInput) string { + var parts []string + for _, turnInput := range turnInputs { + if trimmed := strings.TrimSpace(turnInput.Raw); trimmed != "" { + parts = append(parts, trimmed) + } + } + return strings.Join(parts, "\n\n") +} + +type realtimeResponseDoneEnvelope struct { + Response struct { + Output []realtimeResponseDoneOutput `json:"output"` + Usage *realtimeResponseDoneUsage `json:"usage"` + } `json:"response"` +} + +type realtimeResponseDoneOutput struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + CallID string `json:"call_id"` + Arguments string `json:"arguments"` + Content []realtimeResponseDoneContent `json:"content"` +} + +type realtimeResponseDoneContent struct { + Type string `json:"type"` + Text string `json:"text"` + Transcript string `json:"transcript"` + Refusal string `json:"refusal"` +} + +type realtimeResponseDoneUsage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int 
`json:"output_tokens"` + InputTokenDetails *realtimeResponseDoneInputTokenUsage `json:"input_token_details"` + OutputTokenDetails *realtimeResponseDoneOutputTokenUsage `json:"output_token_details"` +} + +type realtimeResponseDoneInputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` +} + +type realtimeResponseDoneOutputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + ImageTokens *int `json:"image_tokens"` + CitationTokens *int `json:"citation_tokens"` + NumSearchQueries *int `json:"num_search_queries"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +func extractRealtimeTurnUsage(provider schemas.RealtimeProvider, rawMessage []byte) *schemas.BifrostLLMUsage { + if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok { + if usage := extractor.ExtractRealtimeTurnUsage(rawMessage); usage != nil { + return usage + } + } + return extractRealtimeResponseDoneUsage(rawMessage) +} + +func extractRealtimeTurnOutputMessage(provider schemas.RealtimeProvider, rawMessage []byte, contentSummary string) *schemas.ChatMessage { + if extractor, ok := provider.(schemas.RealtimeUsageExtractor); ok { + if message := extractor.ExtractRealtimeTurnOutput(rawMessage); message != nil { + if strings.TrimSpace(contentSummary) != "" && (message.Content == nil || message.Content.ContentStr == nil || strings.TrimSpace(*message.Content.ContentStr) == "") { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentSummary))} + } + return message + } + } + return buildRealtimeAssistantLogMessage(rawMessage, contentSummary) +} + +func buildRealtimeAssistantLogMessage(rawMessage []byte, contentSummary string) *schemas.ChatMessage { + 
contentSummary = strings.TrimSpace(contentSummary) + var parsed realtimeResponseDoneEnvelope + if len(rawMessage) > 0 && sonic.Unmarshal(rawMessage, &parsed) == nil { + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if contentSummary == "" { + contentSummary = extractRealtimeResponseDoneAssistantText(parsed.Response.Output) + } + if contentSummary != "" { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)} + } + + toolCalls := extractRealtimeResponseDoneToolCalls(parsed.Response.Output) + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ + ToolCalls: toolCalls, + } + } + + if message.Content != nil || message.ChatAssistantMessage != nil { + return message + } + } + + if contentSummary == "" { + return nil + } + + return &schemas.ChatMessage{ + Role: schemas.ChatMessageRoleAssistant, + Content: &schemas.ChatMessageContent{ContentStr: schemas.Ptr(contentSummary)}, + } +} + +func extractRealtimeResponseDoneAssistantText(outputs []realtimeResponseDoneOutput) string { + var parts []string + for _, output := range outputs { + if output.Type != "message" { + continue + } + for _, block := range output.Content { + switch { + case strings.TrimSpace(block.Text) != "": + parts = append(parts, strings.TrimSpace(block.Text)) + case strings.TrimSpace(block.Transcript) != "": + parts = append(parts, strings.TrimSpace(block.Transcript)) + case strings.TrimSpace(block.Refusal) != "": + parts = append(parts, strings.TrimSpace(block.Refusal)) + } + } + } + return strings.Join(parts, " ") +} + +func extractRealtimeResponseDoneToolCalls(outputs []realtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall { + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, output := range outputs { + if output.Type != "function_call" { + continue + } + + name := strings.TrimSpace(output.Name) + if name == "" { + continue + } + + toolType := "function" + id := 
strings.TrimSpace(output.CallID) + if id == "" { + id = strings.TrimSpace(output.ID) + } + + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr(name), + Arguments: output.Arguments, + }, + } + if id != "" { + toolCall.ID = schemas.Ptr(id) + } + + toolCalls = append(toolCalls, toolCall) + } + return toolCalls +} + +func extractRealtimeResponseDoneUsage(rawMessage []byte) *schemas.BifrostLLMUsage { + if len(rawMessage) == 0 { + return nil + } + + var parsed realtimeResponseDoneEnvelope + if err := sonic.Unmarshal(rawMessage, &parsed); err != nil || parsed.Response.Usage == nil { + return nil + } + + totalTokens := parsed.Response.Usage.TotalTokens + if totalTokens == 0 && (parsed.Response.Usage.InputTokens > 0 || parsed.Response.Usage.OutputTokens > 0) { + totalTokens = parsed.Response.Usage.InputTokens + parsed.Response.Usage.OutputTokens + } + + usage := &schemas.BifrostLLMUsage{ + PromptTokens: parsed.Response.Usage.InputTokens, + CompletionTokens: parsed.Response.Usage.OutputTokens, + TotalTokens: totalTokens, + } + + if parsed.Response.Usage.InputTokenDetails != nil { + usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens, + ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens, + CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens, + } + } + + if parsed.Response.Usage.OutputTokenDetails != nil { + usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens, + ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens, + ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens, + CitationTokens: 
parsed.Response.Usage.OutputTokenDetails.CitationTokens, + NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries, + AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens, + RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens, + } + } + + return usage +} diff --git a/transports/bifrost-http/handlers/realtime_logging_test.go b/transports/bifrost-http/handlers/realtime_logging_test.go new file mode 100644 index 0000000000..054f2ea0e9 --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_logging_test.go @@ -0,0 +1,435 @@ +package handlers + +import ( + "encoding/json" + "testing" + "time" + + "github.com/maximhq/bifrost/core/providers/openai" + "github.com/maximhq/bifrost/core/schemas" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" +) + +func TestShouldAccumulateRealtimeOutput(t *testing.T) { + provider := &openai.OpenAIProvider{} + if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseTextDelta) { + t.Fatal("expected response.text.delta to accumulate output text") + } + if !provider.ShouldAccumulateRealtimeOutput(schemas.RTEventResponseAudioTransDelta) { + t.Fatal("expected response.audio_transcript.delta to accumulate output transcript") + } + if provider.ShouldAccumulateRealtimeOutput(schemas.RTEventInputAudioTransDelta) { + t.Fatal("did not expect input audio transcription delta to accumulate assistant output") + } +} + +func TestExtractRealtimeTurnSummary(t *testing.T) { + event := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Content: []byte(`[{"type":"input_text","text":"hello from realtime"}]`), + }, + } + + got := extractRealtimeTurnSummary(event, "") + if got != "hello from realtime" { + t.Fatalf("extractRealtimeTurnSummary() = %q, want %q", got, "hello from realtime") + } +} + +func TestFinalizedRealtimeInputSummary(t *testing.T) { + userCreate := 
&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from browser"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userCreate); got != "hello from browser" { + t.Fatalf("finalizedRealtimeInputSummary(user create) = %q, want %q", got, "hello from browser") + } + + userRetrieved := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from retrieved item"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userRetrieved); got != "hello from retrieved item" { + t.Fatalf("finalizedRealtimeInputSummary(user retrieved) = %q, want %q", got, "hello from retrieved item") + } + + userCreated := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreated, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from provider created item"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userCreated); got != "hello from provider created item" { + t.Fatalf("finalizedRealtimeInputSummary(user created) = %q, want %q", got, "hello from provider created item") + } + + userAdded := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemAdded, + Item: &schemas.RealtimeItem{ + Role: "user", + Content: []byte(`[{"type":"input_text","text":"hello from provider added item"}]`), + }, + } + if got := finalizedRealtimeInputSummary(userAdded); got != "hello from provider added item" { + t.Fatalf("finalizedRealtimeInputSummary(user added) = %q, want %q", got, "hello from provider added item") + } + + userCreatedWithoutTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreated, + Item: &schemas.RealtimeItem{ + Role: "user", + Type: "message", + Content: 
[]byte(`[{"type":"input_audio","audio":null,"transcript":null}]`), + }, + RawData: []byte(`{"type":"conversation.item.created","item":{"type":"message","role":"user","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`), + } + if got := finalizedRealtimeInputSummary(userCreatedWithoutTranscript); got != "" { + t.Fatalf("finalizedRealtimeInputSummary(user created without transcript) = %q, want empty", got) + } + + userDoneWithoutTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemDone, + Item: &schemas.RealtimeItem{ + Role: "user", + Type: "message", + Status: "completed", + Content: []byte(`[{"type":"input_audio","audio":null,"transcript":null}]`), + }, + RawData: []byte(`{"type":"conversation.item.done","item":{"type":"message","role":"user","status":"completed","content":[{"type":"input_audio","audio":null,"transcript":null}]}}`), + } + if got := finalizedRealtimeInputSummary(userDoneWithoutTranscript); got != realtimeMissingTranscriptText { + t.Fatalf("finalizedRealtimeInputSummary(user done without transcript) = %q, want %q", got, realtimeMissingTranscriptText) + } + + inputTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + ExtraParams: map[string]json.RawMessage{ + "transcript": json.RawMessage(`"spoken user turn"`), + }, + } + if got := finalizedRealtimeInputSummary(inputTranscript); got != "spoken user turn" { + t.Fatalf("finalizedRealtimeInputSummary(input transcript) = %q, want %q", got, "spoken user turn") + } + + emptyInputTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + ExtraParams: map[string]json.RawMessage{ + "transcript": json.RawMessage(`""`), + }, + RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","transcript":"","usage":{"total_tokens":11}}`), + } + if got := finalizedRealtimeInputSummary(emptyInputTranscript); got != realtimeMissingTranscriptText { + 
t.Fatalf("finalizedRealtimeInputSummary(empty input transcript) = %q, want %q", got, realtimeMissingTranscriptText) + } + + missingInputTranscript := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + RawData: []byte(`{"type":"conversation.item.input_audio_transcription.completed","usage":{"total_tokens":11}}`), + } + if got := finalizedRealtimeInputSummary(missingInputTranscript); got != realtimeMissingTranscriptText { + t.Fatalf("finalizedRealtimeInputSummary(missing input transcript) = %q, want %q", got, realtimeMissingTranscriptText) + } + + assistantCreate := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Role: "assistant", + Content: []byte(`[{"type":"text","text":"assistant text"}]`), + }, + } + if got := finalizedRealtimeInputSummary(assistantCreate); got != "" { + t.Fatalf("finalizedRealtimeInputSummary(assistant create) = %q, want empty", got) + } +} + +func TestFinalizedRealtimeToolOutputSummary(t *testing.T) { + event := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemCreate, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(event); got != `{"nextResponse":"tool result"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary() = %q, want %q", got, `{"nextResponse":"tool result"}`) + } + + retrieved := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result from retrieved"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(retrieved); got != `{"nextResponse":"tool result from retrieved"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary(retrieved) = %q, want %q", got, `{"nextResponse":"tool result from retrieved"}`) + } + + created := &schemas.BifrostRealtimeEvent{ + Type: 
schemas.RTEventConversationItemCreated, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result from created"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(created); got != `{"nextResponse":"tool result from created"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary(created) = %q, want %q", got, `{"nextResponse":"tool result from created"}`) + } + + added := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemAdded, + Item: &schemas.RealtimeItem{ + Type: "function_call_output", + Output: `{"nextResponse":"tool result from added"}`, + }, + } + if got := finalizedRealtimeToolOutputSummary(added); got != `{"nextResponse":"tool result from added"}` { + t.Fatalf("finalizedRealtimeToolOutputSummary(added) = %q, want %q", got, `{"nextResponse":"tool result from added"}`) + } +} + +func TestPendingRealtimeInputUpdate(t *testing.T) { + t.Parallel() + + transcriptEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventInputAudioTransCompleted, + ExtraParams: map[string]json.RawMessage{ + "item_id": json.RawMessage(`"item_123"`), + "transcript": json.RawMessage(`"Hello."`), + }, + } + itemID, summary := pendingRealtimeInputUpdate(transcriptEvent) + if itemID != "item_123" || summary != "Hello." 
{ + t.Fatalf("pendingRealtimeInputUpdate(transcript) = (%q, %q), want (%q, %q)", itemID, summary, "item_123", "Hello.") + } + + retrievedEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + ID: "item_123", + Role: "user", + Content: []byte(`[{"type":"input_text","text":"historical hello"}]`), + }, + } + itemID, summary = pendingRealtimeInputUpdate(retrievedEvent) + if itemID != "" || summary != "" { + t.Fatalf("pendingRealtimeInputUpdate(retrieved) = (%q, %q), want empty", itemID, summary) + } +} + +func TestPendingRealtimeToolOutputUpdate(t *testing.T) { + t.Parallel() + + toolOutputEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemDone, + Item: &schemas.RealtimeItem{ + ID: "item_tool_123", + Type: "function_call_output", + Output: `{"nextResponse":"tool result"}`, + }, + } + itemID, summary := pendingRealtimeToolOutputUpdate(toolOutputEvent) + if itemID != "item_tool_123" || summary != `{"nextResponse":"tool result"}` { + t.Fatalf("pendingRealtimeToolOutputUpdate(done) = (%q, %q), want (%q, %q)", itemID, summary, "item_tool_123", `{"nextResponse":"tool result"}`) + } + + retrievedToolOutputEvent := &schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventConversationItemRetrieved, + Item: &schemas.RealtimeItem{ + ID: "item_tool_123", + Type: "function_call_output", + Output: `{"nextResponse":"historical tool result"}`, + }, + } + itemID, summary = pendingRealtimeToolOutputUpdate(retrievedToolOutputEvent) + if itemID != "" || summary != "" { + t.Fatalf("pendingRealtimeToolOutputUpdate(retrieved) = (%q, %q), want empty", itemID, summary) + } +} + +func TestBuildRealtimeTurnPostResponseUsesFullResponseDonePayload(t *testing.T) { + rawRequest := `{"type":"conversation.item.input_audio_transcription.completed","transcript":""}` + rawResponse := []byte(`{ + "type":"response.done", + "response":{ + "output":[ + { + "id":"item_message_123", + "type":"message", + "content":[ + 
{ + "type":"audio", + "transcript":"assistant turn text" + } + ] + } + ], + "usage":{ + "total_tokens":26, + "input_tokens":17, + "output_tokens":9, + "input_token_details":{ + "text_tokens":12, + "audio_tokens":5, + "image_tokens":0, + "cached_tokens":4 + }, + "output_token_details":{ + "text_tokens":7, + "audio_tokens":2 + } + } + } + }`) + + resp := buildRealtimeTurnPostResponse(&openai.OpenAIProvider{}, schemas.OpenAI, "gpt-4o-realtime-preview-2025-06-03", rawRequest, rawResponse, "", 4321) + if resp == nil || resp.ResponsesResponse == nil { + t.Fatal("expected realtime post response to be built") + } + if resp.ResponsesResponse.ExtraFields.Latency != 4321 { + t.Fatalf("Latency = %d, want %d", resp.ResponsesResponse.ExtraFields.Latency, 4321) + } + if resp.ResponsesResponse.Usage == nil || resp.ResponsesResponse.Usage.InputTokens != 17 || resp.ResponsesResponse.Usage.OutputTokens != 9 || resp.ResponsesResponse.Usage.TotalTokens != 26 { + t.Fatalf("Usage = %+v, want input=17 output=9 total=26", resp.ResponsesResponse.Usage) + } + if len(resp.ResponsesResponse.Output) != 1 { + t.Fatalf("len(Output) = %d, want 1", len(resp.ResponsesResponse.Output)) + } + if resp.ResponsesResponse.Output[0].Content == nil || resp.ResponsesResponse.Output[0].Content.ContentStr == nil || *resp.ResponsesResponse.Output[0].Content.ContentStr != "assistant turn text" { + t.Fatalf("Output[0].Content = %+v, want assistant turn text", resp.ResponsesResponse.Output[0].Content) + } + if got, ok := resp.ResponsesResponse.ExtraFields.RawRequest.(string); !ok || got != rawRequest { + t.Fatalf("RawRequest = %#v, want %q", resp.ResponsesResponse.ExtraFields.RawRequest, rawRequest) + } + if got, ok := resp.ResponsesResponse.ExtraFields.RawResponse.(string); !ok || got == "" { + t.Fatalf("RawResponse = %#v, want raw response string", resp.ResponsesResponse.ExtraFields.RawResponse) + } +} + +func TestFinalizeRealtimeTurnHooksWithErrorCompletesActiveHooks(t *testing.T) { + t.Parallel() + + session 
:= bfws.NewSession(nil) + session.SetProviderSessionID("sess_provider_123") + session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`) + session.AppendRealtimeOutputText("partial assistant output") + + var ( + capturedResp *schemas.BifrostResponse + capturedErr *schemas.BifrostError + cleanedUp bool + ) + session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{ + RequestID: "req_realtime_123", + StartedAt: time.Now().Add(-time.Second), + PreHookValues: map[any]any{ + schemas.BifrostContextKeyGovernanceVirtualKeyID: "vk_123", + }, + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + capturedResp = result + capturedErr = err + return result, nil + }, + Cleanup: func() { + cleanedUp = true + }, + }) + + rawResponse := []byte(`{"type":"error","error":{"type":"server_error","message":"Virtual key is required."}}`) + postErr := finalizeRealtimeTurnHooksWithError( + nil, + nil, + session, + schemas.OpenAI, + "gpt-realtime", + nil, + schemas.RTEventError, + rawResponse, + newRealtimeWireBifrostError(401, "server_error", "Virtual key is required."), + ) + if postErr != nil { + t.Fatalf("finalizeRealtimeTurnHooksWithError() post error = %v, want nil", postErr) + } + if capturedResp != nil { + t.Fatalf("captured response = %#v, want nil", capturedResp) + } + if capturedErr == nil { + t.Fatal("expected captured error") + } + if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest { + t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest) + } + if capturedErr.ExtraFields.Provider != schemas.OpenAI { + t.Fatalf("provider = %q, want %q", capturedErr.ExtraFields.Provider, schemas.OpenAI) + } + if capturedErr.ExtraFields.OriginalModelRequested != "gpt-realtime" { + t.Fatalf("model requested = %q, want %q", capturedErr.ExtraFields.OriginalModelRequested, "gpt-realtime") + } + rawRequest, 
ok := capturedErr.ExtraFields.RawRequest.(string) + if !ok || rawRequest == "" { + t.Fatalf("raw request = %#v, want non-empty string", capturedErr.ExtraFields.RawRequest) + } + rawResp, ok := capturedErr.ExtraFields.RawResponse.(json.RawMessage) + if !ok || string(rawResp) != string(rawResponse) { + t.Fatalf("raw response = %#v, want %s", capturedErr.ExtraFields.RawResponse, string(rawResponse)) + } + if session.PeekRealtimeTurnHooks() != nil { + t.Fatal("expected active hooks to be cleared") + } + if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 { + t.Fatalf("remaining turn inputs = %d, want 0", len(got)) + } + if got := session.ConsumeRealtimeOutputText(); got != "" { + t.Fatalf("remaining output text = %q, want empty", got) + } + if !cleanedUp { + t.Fatal("expected realtime hook cleanup to run") + } +} + +func TestNewBifrostErrorFromRealtimeErrorCarriesRealtimeMetadata(t *testing.T) { + t.Parallel() + + rawResponse := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request_error","message":"bad request","param":"session.type"}}`) + bifrostErr := newBifrostErrorFromRealtimeError( + schemas.OpenAI, + "gpt-realtime", + rawResponse, + &schemas.RealtimeError{ + Type: "invalid_request_error", + Code: "invalid_request_error", + Message: "bad request", + Param: "session.type", + }, + ) + if bifrostErr == nil { + t.Fatal("expected bifrost error") + } + if bifrostErr.StatusCode == nil || *bifrostErr.StatusCode != 400 { + t.Fatalf("status code = %#v, want 400", bifrostErr.StatusCode) + } + if bifrostErr.ExtraFields.RequestType != schemas.RealtimeRequest { + t.Fatalf("request type = %q, want %q", bifrostErr.ExtraFields.RequestType, schemas.RealtimeRequest) + } + if bifrostErr.ExtraFields.Provider != schemas.OpenAI { + t.Fatalf("provider = %q, want %q", bifrostErr.ExtraFields.Provider, schemas.OpenAI) + } + if bifrostErr.ExtraFields.OriginalModelRequested != "gpt-realtime" { + t.Fatalf("model requested = %q, want %q", 
bifrostErr.ExtraFields.OriginalModelRequested, "gpt-realtime") + } + rawResp, ok := bifrostErr.ExtraFields.RawResponse.(json.RawMessage) + if !ok || string(rawResp) != string(rawResponse) { + t.Fatalf("raw response = %#v, want %s", bifrostErr.ExtraFields.RawResponse, string(rawResponse)) + } + if bifrostErr.Error == nil || bifrostErr.Error.Param != "session.type" { + t.Fatalf("error param = %#v, want session.type", bifrostErr.Error) + } +} diff --git a/transports/bifrost-http/handlers/realtime_turn_pipeline.go b/transports/bifrost-http/handlers/realtime_turn_pipeline.go new file mode 100644 index 0000000000..91095e5843 --- /dev/null +++ b/transports/bifrost-http/handlers/realtime_turn_pipeline.go @@ -0,0 +1,798 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" +) + +func newRealtimeTurnContext( + baseCtx *schemas.BifrostContext, + requestID string, + sessionID string, + providerSessionID string, + source realtimeTurnSource, + eventType schemas.RealtimeEventType, + key *schemas.Key, +) *schemas.BifrostContext { + ctx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + if baseCtx != nil { + // Realtime post-hook contexts must preserve plugin-private values written in + // pre-hooks (for example telemetry start timestamps), not just public keys. 
+ for ctxKey, value := range baseCtx.GetUserValues() { + if value != nil { + ctx.SetValue(ctxKey, value) + } + } + } + + ctx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if requestID == "" { + requestID = uuid.NewString() + } + ctx.SetValue(schemas.BifrostContextKeyRequestID, requestID) + resolvedSessionID := strings.TrimSpace(providerSessionID) + if resolvedSessionID == "" { + resolvedSessionID = strings.TrimSpace(sessionID) + } + if baseCtx != nil { + if externalSessionID, ok := baseCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); ok && strings.TrimSpace(externalSessionID) != "" { + resolvedSessionID = strings.TrimSpace(externalSessionID) + } + } + if resolvedSessionID != "" { + ctx.SetValue(schemas.BifrostContextKeyParentRequestID, resolvedSessionID) + } + if strings.TrimSpace(providerSessionID) != "" { + ctx.SetValue(schemas.BifrostContextKeyRealtimeSessionID, providerSessionID) + ctx.SetValue(schemas.BifrostContextKeyRealtimeProviderSessionID, providerSessionID) + } + if source != "" { + ctx.SetValue(schemas.BifrostContextKeyRealtimeSource, string(source)) + } + if eventType != "" { + ctx.SetValue(schemas.BifrostContextKeyRealtimeEventType, string(eventType)) + } + if key != nil { + if strings.TrimSpace(key.ID) != "" { + ctx.SetValue(schemas.BifrostContextKeySelectedKeyID, key.ID) + } + if strings.TrimSpace(key.Name) != "" { + ctx.SetValue(schemas.BifrostContextKeySelectedKeyName, key.Name) + } + } + return ctx +} + +func applyRealtimeTurnContextValues(ctx *schemas.BifrostContext, values map[any]any) { + if ctx == nil || len(values) == 0 { + return + } + for ctxKey, value := range values { + switch ctxKey { + case schemas.BifrostContextKeyRequestID, + schemas.BifrostContextKeyParentRequestID, + schemas.BifrostContextKeyRealtimeSessionID, + schemas.BifrostContextKeyRealtimeProviderSessionID, + schemas.BifrostContextKeyRealtimeSource, + schemas.BifrostContextKeyRealtimeEventType, + 
schemas.BifrostContextKeyStreamStartTime, + schemas.BifrostContextKeyStreamEndIndicator: + continue + } + if value != nil { + ctx.SetValue(ctxKey, value) + } + } +} + +func setRealtimeTurnStreamContext(ctx *schemas.BifrostContext, startedAt time.Time, isFinal bool) { + if ctx == nil { + return + } + if startedAt.IsZero() { + startedAt = time.Now() + } + ctx.SetValue(schemas.BifrostContextKeyStreamStartTime, startedAt) + if isFinal { + ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true) + } +} + +func buildRealtimeTurnPreRequest(provider schemas.ModelProvider, model string, turnInputs []bfws.RealtimeTurnInput) *schemas.BifrostRequest { + input := make([]schemas.ResponsesMessage, 0, len(turnInputs)) + for _, turnInput := range turnInputs { + summary := strings.TrimSpace(turnInput.Summary) + if summary == "" { + continue + } + switch turnInput.Role { + case string(schemas.ChatMessageRoleTool): + itemType := schemas.ResponsesMessageTypeFunctionCallOutput + output := &schemas.ResponsesToolMessageOutputStruct{ + ResponsesToolCallOutputStr: schemas.Ptr(summary), + } + input = append(input, schemas.ResponsesMessage{ + Type: &itemType, + ResponsesToolMessage: &schemas.ResponsesToolMessage{Output: output}, + }) + case string(schemas.ChatMessageRoleUser): + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleUser + input = append(input, schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(summary)}, + }) + } + } + + return &schemas.BifrostRequest{ + RequestType: schemas.RealtimeRequest, + ResponsesRequest: &schemas.BifrostResponsesRequest{ + Provider: provider, + Model: model, + Input: input, + }, + } +} + +func buildRealtimeTurnPostResponse( + rtProvider schemas.RealtimeProvider, + provider schemas.ModelProvider, + model string, + rawRequest string, + rawResponse []byte, + contentOverride string, + latency int64, +) *schemas.BifrostResponse { + output 
:= buildRealtimeTurnOutputMessages(rtProvider, rawResponse, contentOverride) + resp := &schemas.BifrostResponsesResponse{ + Object: "response", + Model: model, + Output: output, + ExtraFields: schemas.BifrostResponseExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: provider, + OriginalModelRequested: model, + Latency: latency, + }, + } + if usage := extractRealtimeTurnUsage(rtProvider, rawResponse); usage != nil { + resp.Usage = buildRealtimeResponsesUsage(usage) + } + if strings.TrimSpace(rawRequest) != "" { + resp.ExtraFields.RawRequest = rawRequest + } + if len(rawResponse) > 0 { + resp.ExtraFields.RawResponse = string(rawResponse) + } + + return &schemas.BifrostResponse{ResponsesResponse: resp} +} + +func buildRealtimeTurnOutputMessages(rtProvider schemas.RealtimeProvider, rawResponse []byte, contentOverride string) []schemas.ResponsesMessage { + outputs := make([]schemas.ResponsesMessage, 0) + if outputMessage := extractRealtimeTurnOutputMessage(rtProvider, rawResponse, contentOverride); outputMessage != nil { + outputs = append(outputs, buildRealtimeResponsesMessagesFromChat(outputMessage, contentOverride)...) 
+ } + + if len(outputs) > 0 { + return outputs + } + + var parsed realtimeResponseDoneEnvelope + if len(rawResponse) > 0 && schemas.Unmarshal(rawResponse, &parsed) == nil { + for _, item := range parsed.Response.Output { + switch item.Type { + case "message": + content := strings.TrimSpace(contentOverride) + if content == "" { + content = extractRealtimeResponseDoneContentText(item.Content) + } + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + msg := schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Status: schemas.Ptr("completed"), + } + if strings.TrimSpace(item.ID) != "" { + msg.ID = schemas.Ptr(strings.TrimSpace(item.ID)) + } + if content != "" { + msg.Content = &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)} + } + outputs = append(outputs, msg) + case "function_call": + itemType := schemas.ResponsesMessageTypeFunctionCall + msg := schemas.ResponsesMessage{ + Type: &itemType, + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Name: schemas.Ptr(strings.TrimSpace(item.Name)), + Arguments: schemas.Ptr(item.Arguments), + }, + } + if strings.TrimSpace(item.ID) != "" { + msg.ID = schemas.Ptr(strings.TrimSpace(item.ID)) + } + if strings.TrimSpace(item.CallID) != "" { + msg.CallID = schemas.Ptr(strings.TrimSpace(item.CallID)) + } + outputs = append(outputs, msg) + } + } + } + + if len(outputs) == 0 && strings.TrimSpace(contentOverride) != "" { + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + outputs = append(outputs, schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Status: schemas.Ptr("completed"), + Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(strings.TrimSpace(contentOverride))}, + }) + } + + return outputs +} + +func buildRealtimeResponsesMessagesFromChat(message *schemas.ChatMessage, contentOverride string) []schemas.ResponsesMessage { + if message == 
nil { + return nil + } + + outputs := make([]schemas.ResponsesMessage, 0, 1) + content := strings.TrimSpace(contentOverride) + if content == "" && message.Content != nil && message.Content.ContentStr != nil { + content = strings.TrimSpace(*message.Content.ContentStr) + } + if content != "" { + itemType := schemas.ResponsesMessageTypeMessage + role := schemas.ResponsesInputMessageRoleAssistant + outputs = append(outputs, schemas.ResponsesMessage{ + Type: &itemType, + Role: &role, + Status: schemas.Ptr("completed"), + Content: &schemas.ResponsesMessageContent{ContentStr: schemas.Ptr(content)}, + }) + } + + if message.ChatAssistantMessage == nil { + return outputs + } + + for _, toolCall := range message.ChatAssistantMessage.ToolCalls { + itemType := schemas.ResponsesMessageTypeFunctionCall + msg := schemas.ResponsesMessage{ + Type: &itemType, + Status: schemas.Ptr("completed"), + ResponsesToolMessage: &schemas.ResponsesToolMessage{ + Arguments: schemas.Ptr(toolCall.Function.Arguments), + }, + } + if toolCall.Function.Name != nil { + msg.ResponsesToolMessage.Name = schemas.Ptr(strings.TrimSpace(*toolCall.Function.Name)) + } + if toolCall.ID != nil { + msg.CallID = schemas.Ptr(strings.TrimSpace(*toolCall.ID)) + msg.ID = schemas.Ptr(strings.TrimSpace(*toolCall.ID)) + } + outputs = append(outputs, msg) + } + + return outputs +} + +func extractRealtimeResponseDoneContentText(content []realtimeResponseDoneContent) string { + for _, block := range content { + switch { + case strings.TrimSpace(block.Text) != "": + return strings.TrimSpace(block.Text) + case strings.TrimSpace(block.Transcript) != "": + return strings.TrimSpace(block.Transcript) + case strings.TrimSpace(block.Refusal) != "": + return strings.TrimSpace(block.Refusal) + } + } + return "" +} + +func buildRealtimeResponsesUsage(usage *schemas.BifrostLLMUsage) *schemas.ResponsesResponseUsage { + if usage == nil { + return nil + } + result := &schemas.ResponsesResponseUsage{ + InputTokens: usage.PromptTokens, + 
OutputTokens: usage.CompletionTokens, + TotalTokens: usage.TotalTokens, + } + if usage.PromptTokensDetails != nil { + result.InputTokensDetails = &schemas.ResponsesResponseInputTokens{ + TextTokens: usage.PromptTokensDetails.TextTokens, + AudioTokens: usage.PromptTokensDetails.AudioTokens, + ImageTokens: usage.PromptTokensDetails.ImageTokens, + CachedReadTokens: usage.PromptTokensDetails.CachedReadTokens, + CachedWriteTokens: usage.PromptTokensDetails.CachedWriteTokens, + } + } + if usage.CompletionTokensDetails != nil { + result.OutputTokensDetails = &schemas.ResponsesResponseOutputTokens{ + TextTokens: usage.CompletionTokensDetails.TextTokens, + AcceptedPredictionTokens: usage.CompletionTokensDetails.AcceptedPredictionTokens, + AudioTokens: usage.CompletionTokensDetails.AudioTokens, + ImageTokens: usage.CompletionTokensDetails.ImageTokens, + ReasoningTokens: usage.CompletionTokensDetails.ReasoningTokens, + RejectedPredictionTokens: usage.CompletionTokensDetails.RejectedPredictionTokens, + CitationTokens: usage.CompletionTokensDetails.CitationTokens, + NumSearchQueries: usage.CompletionTokensDetails.NumSearchQueries, + } + } + return result +} + +func newRealtimeTurnErrorEventPayload(bifrostErr *schemas.BifrostError) []byte { + if bifrostErr == nil { + return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`) + } + + errorType, errorCode, errorMessage, errorParam := mapRealtimeWireErrorFields(bifrostErr) + payload := schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventError, + Error: &schemas.RealtimeError{ + Type: errorType, + Code: errorCode, + Message: errorMessage, + Param: errorParam, + }, + } + if data, err := schemas.Marshal(payload); err == nil { + return data + } + return []byte(`{"type":"error","error":{"type":"server_error","message":"internal server error"}}`) +} + +// isBudgetOrBillingError returns true if the lowercased value indicates a budget or billing exhaustion error. 
+// Quota/rate-limit patterns (quota_exceeded, quota exceeded, etc.) are already covered by bifrost.IsRateLimitErrorMessage. +func isBudgetOrBillingError(lower string) bool { + return strings.Contains(lower, "budget_exceeded") || + strings.Contains(lower, "budget exceeded") || + strings.Contains(lower, "insufficient_quota") || + strings.Contains(lower, "hard limit reached") || + strings.Contains(lower, "billing hard limit") +} + +func mapRealtimeWireErrorFields(bifrostErr *schemas.BifrostError) (string, string, string, string) { + errorType := "server_error" + errorCode := "server_error" + errorMessage := "internal server error" + errorParam := "" + + if bifrostErr == nil { + return errorType, errorCode, errorMessage, errorParam + } + + var values []string + if bifrostErr.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Type)) + } + if bifrostErr.Error != nil { + if bifrostErr.Error.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Type)) + } + if bifrostErr.Error.Code != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Code)) + } + if strings.TrimSpace(bifrostErr.Error.Message) != "" { + errorMessage = strings.TrimSpace(bifrostErr.Error.Message) + values = append(values, errorMessage) + } + if bifrostErr.Error.Param != nil { + errorParam = strings.TrimSpace(fmt.Sprint(bifrostErr.Error.Param)) + } + } + + for _, value := range values { + lower := strings.ToLower(value) + switch { + case lower == "": + continue + case strings.Contains(lower, "invalid_request_error"): + return "invalid_request_error", "invalid_request_error", errorMessage, errorParam + case isBudgetOrBillingError(lower): + return "insufficient_quota", "insufficient_quota", errorMessage, errorParam + case bifrost.IsRateLimitErrorMessage(lower): + return "rate_limit_exceeded", "rate_limit_exceeded", errorMessage, errorParam + } + } + + return errorType, errorCode, errorMessage, errorParam +} + +func 
shouldGracefullyDisconnectRealtime(bifrostErr *schemas.BifrostError) bool { + if bifrostErr == nil { + return false + } + + var values []string + if bifrostErr.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Type)) + } + if bifrostErr.Error != nil { + if bifrostErr.Error.Type != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Type)) + } + if bifrostErr.Error.Code != nil { + values = append(values, strings.TrimSpace(*bifrostErr.Error.Code)) + } + values = append(values, strings.TrimSpace(bifrostErr.Error.Message)) + } + + for _, value := range values { + lower := strings.ToLower(value) + if lower == "" { + continue + } + if isBudgetOrBillingError(lower) || bifrost.IsRateLimitErrorMessage(lower) { + return true + } + } + + return false +} + +func startRealtimeTurnHooks( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + rtProvider schemas.RealtimeProvider, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + startEventType schemas.RealtimeEventType, +) *schemas.BifrostError { + if client == nil || session == nil { + return &schemas.BifrostError{ + Type: schemas.Ptr("server_error"), + StatusCode: schemas.Ptr(500), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("server_error"), + Message: "realtime turn pipeline is unavailable", + }, + } + } + if !session.TryBeginRealtimeTurnHooks() { + return &schemas.BifrostError{ + Type: schemas.Ptr("invalid_request_error"), + StatusCode: schemas.Ptr(400), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("invalid_request_error"), + Message: "Conversation already has an active response in progress.", + }, + } + } + committed := false + defer func() { + if !committed { + session.AbortRealtimeTurnHooks() + } + }() + + startedAt := time.Now() + turnCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, startEventType, key) + setRealtimeTurnStreamContext(turnCtx, startedAt, false) + 
req := buildRealtimeTurnPreRequest(provider, model, session.PeekRealtimeTurnInputs()) + hooks, bifrostErr := client.RunRealtimeTurnPreHooks(turnCtx, req) + if bifrostErr != nil { + // RunRealtimeTurnPreHooks already executed post-hooks and flushed the trace + // for this turn-start failure. Clear buffered turn state so transport-close + // fallback finalization does not emit the same error a second time. + session.ConsumeRealtimeTurnInputs() + session.ConsumeRealtimeOutputText() + return bifrostErr + } + + requestID, _ := turnCtx.Value(schemas.BifrostContextKeyRequestID).(string) + session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{ + PostHookRunner: hooks.PostHookRunner, + Cleanup: hooks.Cleanup, + RequestID: requestID, + StartedAt: startedAt, + PreHookValues: turnCtx.GetUserValues(), + }) + committed = true + return nil +} + +func finalizeRealtimeTurnHooks( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + rtProvider schemas.RealtimeProvider, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + rawResponse []byte, + contentOverride string, +) *schemas.BifrostError { + if client == nil || session == nil { + return nil + } + + turnInputs := session.ConsumeRealtimeTurnInputs() + rawRequest := combineRealtimeInputRaw(turnInputs) + + if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil { + defer func() { + if activeHooks.Cleanup != nil { + activeHooks.Cleanup() + } + }() + postResponse := buildRealtimeTurnPostResponse( + rtProvider, + provider, + model, + rawRequest, + rawResponse, + contentOverride, + time.Since(activeHooks.StartedAt).Milliseconds(), + ) + postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key) + applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues) + setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true) + _, bifrostErr := 
activeHooks.PostHookRunner(postCtx, postResponse, nil) + completeRealtimeTurnTrace(postCtx) + return bifrostErr + } + + startedAt := time.Now() + preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key) + setRealtimeTurnStreamContext(preCtx, startedAt, false) + preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs) + hooks, bifrostErr := client.RunRealtimeTurnPreHooks(preCtx, preReq) + if bifrostErr != nil { + return bifrostErr + } + if hooks.Cleanup != nil { + defer hooks.Cleanup() + } + + requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string) + postResponse := buildRealtimeTurnPostResponse( + rtProvider, + provider, + model, + rawRequest, + rawResponse, + contentOverride, + time.Since(startedAt).Milliseconds(), + ) + postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, rtProvider.RealtimeTurnFinalEvent(), key) + applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues()) + setRealtimeTurnStreamContext(postCtx, startedAt, true) + _, bifrostErr = hooks.PostHookRunner(postCtx, postResponse, nil) + completeRealtimeTurnTrace(postCtx) + return bifrostErr +} + +func finalizeRealtimeTurnHooksWithError( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + eventType schemas.RealtimeEventType, + rawResponse []byte, + bifrostErr *schemas.BifrostError, +) *schemas.BifrostError { + if session == nil || bifrostErr == nil { + return nil + } + + turnInputs := session.ConsumeRealtimeTurnInputs() + rawRequest := combineRealtimeInputRaw(turnInputs) + session.ConsumeRealtimeOutputText() + + if activeHooks := session.ConsumeRealtimeTurnHooks(); activeHooks != nil { + defer func() { + if activeHooks.Cleanup != nil { + activeHooks.Cleanup() + } + }() + postErr := buildRealtimeTurnPostError( + provider, + model, + 
rawRequest, + rawResponse, + bifrostErr, + ) + postCtx := newRealtimeTurnContext(baseCtx, activeHooks.RequestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key) + applyRealtimeTurnContextValues(postCtx, activeHooks.PreHookValues) + setRealtimeTurnStreamContext(postCtx, activeHooks.StartedAt, true) + _, hookErr := activeHooks.PostHookRunner(postCtx, nil, postErr) + completeRealtimeTurnTrace(postCtx) + return hookErr + } + + if len(turnInputs) == 0 { + return nil + } + + if client == nil { + return nil + } + + startedAt := time.Now() + preCtx := newRealtimeTurnContext(baseCtx, "", session.ID(), session.ProviderSessionID(), realtimeTurnSourceEI, "", key) + setRealtimeTurnStreamContext(preCtx, startedAt, false) + preReq := buildRealtimeTurnPreRequest(provider, model, turnInputs) + hooks, hookPreErr := client.RunRealtimeTurnPreHooks(preCtx, preReq) + if hookPreErr != nil { + return hookPreErr + } + if hooks.Cleanup != nil { + defer hooks.Cleanup() + } + + requestID, _ := preCtx.Value(schemas.BifrostContextKeyRequestID).(string) + postErr := buildRealtimeTurnPostError( + provider, + model, + rawRequest, + rawResponse, + bifrostErr, + ) + postCtx := newRealtimeTurnContext(baseCtx, requestID, session.ID(), session.ProviderSessionID(), realtimeTurnSourceLM, eventType, key) + applyRealtimeTurnContextValues(postCtx, preCtx.GetUserValues()) + setRealtimeTurnStreamContext(postCtx, startedAt, true) + _, hookErr := hooks.PostHookRunner(postCtx, nil, postErr) + completeRealtimeTurnTrace(postCtx) + return hookErr +} + +func buildRealtimeTurnPostError( + provider schemas.ModelProvider, + model string, + rawRequest string, + rawResponse []byte, + bifrostErr *schemas.BifrostError, +) *schemas.BifrostError { + if bifrostErr == nil { + return nil + } + + copied := *bifrostErr + copied.ExtraFields = bifrostErr.ExtraFields + if bifrostErr.Error != nil { + errorCopy := *bifrostErr.Error + copied.Error = &errorCopy + } + copied.ExtraFields.RequestType = 
schemas.RealtimeRequest + if copied.ExtraFields.Provider == "" { + copied.ExtraFields.Provider = provider + } + if strings.TrimSpace(copied.ExtraFields.OriginalModelRequested) == "" { + copied.ExtraFields.OriginalModelRequested = model + } + if strings.TrimSpace(rawRequest) != "" && copied.ExtraFields.RawRequest == nil { + copied.ExtraFields.RawRequest = rawRequest + } + if len(rawResponse) > 0 && copied.ExtraFields.RawResponse == nil { + copied.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...)) + } + return &copied +} + +func newBifrostErrorFromRealtimeError( + provider schemas.ModelProvider, + model string, + rawResponse []byte, + realtimeErr *schemas.RealtimeError, +) *schemas.BifrostError { + if realtimeErr == nil { + return nil + } + + statusCode := 500 + values := []string{ + strings.TrimSpace(realtimeErr.Type), + strings.TrimSpace(realtimeErr.Code), + strings.TrimSpace(realtimeErr.Message), + } + for _, value := range values { + lower := strings.ToLower(value) + switch { + case lower == "": + continue + case strings.Contains(lower, "invalid_request_error"): + statusCode = 400 + case isBudgetOrBillingError(lower), bifrost.IsRateLimitErrorMessage(lower): + statusCode = 429 + } + } + + errType := strings.TrimSpace(realtimeErr.Type) + if errType == "" { + errType = "server_error" + } + errCode := strings.TrimSpace(realtimeErr.Code) + if errCode == "" { + errCode = errType + } + message := strings.TrimSpace(realtimeErr.Message) + if message == "" { + message = "realtime turn failed" + } + + bifrostErr := &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(statusCode), + Type: schemas.Ptr(errType), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errType), + Code: schemas.Ptr(errCode), + Message: message, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + Provider: provider, + OriginalModelRequested: model, + RequestType: schemas.RealtimeRequest, + }, + } + if strings.TrimSpace(realtimeErr.Param) != "" { + 
bifrostErr.Error.Param = realtimeErr.Param + } + if len(rawResponse) > 0 { + bifrostErr.ExtraFields.RawResponse = json.RawMessage(append([]byte(nil), rawResponse...)) + } + return bifrostErr +} + +func completeRealtimeTurnTrace(ctx *schemas.BifrostContext) { + if ctx == nil { + return + } + traceID, _ := ctx.Value(schemas.BifrostContextKeyTraceID).(string) + if strings.TrimSpace(traceID) == "" { + return + } + tracer, _ := ctx.Value(schemas.BifrostContextKeyTracer).(schemas.Tracer) + if tracer == nil { + return + } + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) +} + +func finalizeRealtimeTurnHooksOnTransportError( + client *bifrost.Bifrost, + baseCtx *schemas.BifrostContext, + session *bfws.Session, + provider schemas.ModelProvider, + model string, + key *schemas.Key, + status int, + code string, + message string, +) *schemas.BifrostError { + return finalizeRealtimeTurnHooksWithError( + client, + baseCtx, + session, + provider, + model, + key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(status, code, message), + ) +} diff --git a/transports/bifrost-http/handlers/webrtc_realtime.go b/transports/bifrost-http/handlers/webrtc_realtime.go new file mode 100644 index 0000000000..644dbc593f --- /dev/null +++ b/transports/bifrost-http/handlers/webrtc_realtime.go @@ -0,0 +1,1215 @@ +package handlers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/fasthttp/router" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "github.com/valyala/fasthttp" +) + +const ( + webrtcRealtimeHandshakeTimeout = 10 * time.Second + webrtcRealtimeICEGatherTimeout = 3 * time.Second + 
webrtcRealtimeMaxPendingMessages = 1000 +) + +var defaultAudioCodec = webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "minptime=10;useinbandfec=1", +} + +var realtimeSDPMaxMessageSizePattern = regexp.MustCompile(`(?m)^a=max-message-size:(\d+)\s*$`) + +type WebRTCRealtimeHandler struct { + client *bifrost.Bifrost + config *lib.Config + handlerStore lib.HandlerStore + mu sync.Mutex + relays map[string]*webrtcRealtimeRelay + legacyRoutes map[string]schemas.ModelProvider // path → default provider (legacy raw-SDP routes) +} + +func NewWebRTCRealtimeHandler(client *bifrost.Bifrost, config *lib.Config) *WebRTCRealtimeHandler { + return &WebRTCRealtimeHandler{ + client: client, + config: config, + handlerStore: config, + relays: make(map[string]*webrtcRealtimeRelay), + legacyRoutes: make(map[string]schemas.ModelProvider), + } +} + +func (h *WebRTCRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + handler := lib.ChainMiddlewares(h.handleRequest, middlewares...) 
+ + // Base bifrost route — GA /calls format (multipart sdp + session) + r.POST("/v1/realtime/calls", handler) + + // OpenAI integration routes — /calls variants (GA format) + for _, path := range integrations.OpenAIRealtimeWebRTCCallsPaths("/openai") { + r.POST(path, handler) + } + + // OpenAI integration routes — legacy variants (raw SDP, beta format) + for _, path := range integrations.OpenAIRealtimePaths("/openai") { + h.legacyRoutes[path] = schemas.OpenAI + r.POST(path, handler) + } +} + +func (h *WebRTCRealtimeHandler) Close() { + if h == nil { + return + } + + h.mu.Lock() + relays := make([]*webrtcRealtimeRelay, 0, len(h.relays)) + for _, relay := range h.relays { + relays = append(relays, relay) + } + h.mu.Unlock() + + for _, relay := range relays { + relay.closeWithShutdownSignal() + } +} + +func (h *WebRTCRealtimeHandler) handleRequest(ctx *fasthttp.RequestCtx) { + if defaultProvider, isLegacy := h.legacyRoutes[string(ctx.Path())]; isLegacy { + h.handleLegacyRequest(ctx, defaultProvider) + } else { + h.handleCallsRequest(ctx) + } +} + +// handleCallsRequest handles the GA /realtime/calls format. +// Multipart bodies strictly require both "sdp" and "session" form fields — +// the model is read from session.model, not from a ?model= query param. +// Raw SDP bodies (application/sdp) fall back to ?model= for the legacy +// raw-SDP path only; the multipart contract has no ?model= fallback. 
+func (h *WebRTCRealtimeHandler) handleCallsRequest(ctx *fasthttp.RequestCtx) { + sdpOffer, providerKey, model, normalizedSession, bifrostErr := parseCallsWebRTCRequest(ctx) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + rtProvider, bifrostErr := h.resolveWebRTCProvider(providerKey) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + exchangeSDP := func(rCtx *schemas.BifrostContext, key schemas.Key, upstreamOffer string) (string, *schemas.BifrostError) { + return rtProvider.ExchangeRealtimeWebRTCSDP(rCtx, key, model, upstreamOffer, normalizedSession) + } + + h.runWebRTCRelay(ctx, rtProvider, providerKey, model, sdpOffer, exchangeSDP) +} + +func parseCallsWebRTCRequest(ctx *fasthttp.RequestCtx) (string, schemas.ModelProvider, string, []byte, *schemas.BifrostError) { + contentType := strings.ToLower(string(ctx.Request.Header.ContentType())) + path := string(ctx.Path()) + if strings.HasPrefix(contentType, "multipart/form-data") { + form, err := ctx.MultipartForm() + if err != nil { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "failed to parse multipart form", err) + } + + sdpOffer := firstMultipartValue(form.Value, "sdp") + if strings.TrimSpace(sdpOffer) == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "sdp form field is required", nil) + } + + sessionField := firstMultipartValue(form.Value, "session") + if strings.TrimSpace(sessionField) == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session form field is required", nil) + } + providerKey, model, normalizedSession, bifrostErr := resolveRealtimeSDPTarget(path, []byte(sessionField)) + if bifrostErr != nil { + return "", "", "", nil, bifrostErr + } + return sdpOffer, providerKey, model, normalizedSession, nil + } + + sdpOffer := string(ctx.Request.Body()) + if strings.TrimSpace(sdpOffer) 
== "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "SDP is required", nil) + } + + rawModel := strings.TrimSpace(string(ctx.QueryArgs().Peek("model"))) + if rawModel == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "model query param is required", nil) + } + + providerKey, model := schemas.ParseModelString(rawModel, realtimeDefaultProviderForPath(path)) + if providerKey == "" || strings.TrimSpace(model) == "" { + if realtimeDefaultProviderForPath(path) == "" { + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "model must use provider/model on /v1 realtime routes", nil) + } + return "", "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "invalid model: "+rawModel, nil) + } + + return sdpOffer, providerKey, model, nil, nil +} + +// handleLegacyRequest handles the beta /realtime endpoint. +// Accepts both multipart (sdp + session) and raw SDP (application/sdp) from clients. 
+func (h *WebRTCRealtimeHandler) handleLegacyRequest(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) { + sdpOffer, rawModel, sessionJSON, bifrostErr := parseLegacyWebRTCRequest(ctx, defaultProvider) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + providerKey, model := schemas.ParseModelString(rawModel, defaultProvider) + if providerKey == "" || model == "" { + SendBifrostError(ctx, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "invalid model: "+rawModel, nil)) + return + } + + rtProvider, bifrostErr := h.resolveWebRTCProvider(providerKey) + if bifrostErr != nil { + SendBifrostError(ctx, bifrostErr) + return + } + + legacyProvider, ok := rtProvider.(schemas.RealtimeLegacyWebRTCProvider) + if !ok { + SendBifrostError(ctx, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider does not support legacy realtime WebRTC: "+string(providerKey), nil)) + return + } + + exchangeSDP := func(rCtx *schemas.BifrostContext, key schemas.Key, upstreamOffer string) (string, *schemas.BifrostError) { + return legacyProvider.ExchangeLegacyRealtimeWebRTCSDP(rCtx, key, upstreamOffer, sessionJSON, model) + } + + h.runWebRTCRelay(ctx, rtProvider, providerKey, model, sdpOffer, exchangeSDP) +} + +// parseLegacyWebRTCRequest extracts SDP, model, and optional session from a legacy request. +// Handles both multipart (sdp + session fields) and raw SDP (body + ?model= query param). 
+func parseLegacyWebRTCRequest(ctx *fasthttp.RequestCtx, defaultProvider schemas.ModelProvider) (sdpOffer, rawModel string, sessionJSON json.RawMessage, err *schemas.BifrostError) { + if strings.HasPrefix(strings.ToLower(string(ctx.Request.Header.ContentType())), "multipart/form-data") { + form, formErr := ctx.MultipartForm() + if formErr != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "failed to parse multipart form", formErr) + } + sdpOffer = firstMultipartValue(form.Value, "sdp") + if sessionField := firstMultipartValue(form.Value, "session"); sessionField != "" { + sessionJSON = json.RawMessage(sessionField) + if root, parseErr := schemas.ParseRealtimeClientSecretBody(sessionJSON); parseErr == nil { + if modelJSON, ok := root["model"]; ok { + var m string + if json.Unmarshal(modelJSON, &m) == nil { + rawModel = m + } + } + } + } + } else { + sdpOffer = string(ctx.Request.Body()) + } + + if strings.TrimSpace(sdpOffer) == "" { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "SDP is required", nil) + } + + // Query param model takes precedence + if queryModel := strings.TrimSpace(string(ctx.QueryArgs().Peek("model"))); queryModel != "" { + rawModel = queryModel + } + if rawModel == "" { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "model is required (query param or session field)", nil) + } + + return sdpOffer, rawModel, sessionJSON, nil +} + +// runWebRTCRelay is the shared relay setup: creates bifrost context, selects key, establishes relay. 
+func (h *WebRTCRealtimeHandler) runWebRTCRelay( + ctx *fasthttp.RequestCtx, + rtProvider schemas.RealtimeProvider, + providerKey schemas.ModelProvider, + model string, + sdpOffer string, + exchangeSDP func(ctx *schemas.BifrostContext, key schemas.Key, upstreamOffer string) (string, *schemas.BifrostError), +) { + bifrostCtx, cancel := lib.ConvertToBifrostContext( + ctx, + h.handlerStore.ShouldAllowDirectKeys(), + h.config.GetHeaderMatcher(), + h.config.GetMCPHeaderCombinedAllowlist(), + ) + defer cancel() + bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if strings.HasPrefix(string(ctx.Path()), "/openai") { + bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + } + + authKey, selectedKey, err := h.resolveRealtimeWebRTCKeys(ctx, bifrostCtx, providerKey, model) + if err != nil { + SendBifrostError(ctx, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", err.Error(), nil)) + return + } + + // Resolve model alias so the provider receives the actual model identifier. 
+ if selectedKey != nil { + model = selectedKey.Aliases.Resolve(model) + } else { + model = authKey.Aliases.Resolve(model) + } + + boundExchange := func(rCtx *schemas.BifrostContext, upstreamOffer string) (string, *schemas.BifrostError) { + return exchangeSDP(rCtx, authKey, upstreamOffer) + } + + relayCtx, relayCancel := newRealtimeRelayContext(bifrostCtx) + session := bfws.NewSession(nil) + browserAnswer, relayErr := h.establishRelay(relayCtx, relayCancel, session, rtProvider, providerKey, model, selectedKey, sdpOffer, boundExchange) + if relayErr != nil { + relayCancel() + SendBifrostError(ctx, relayErr) + return + } + + ctx.SetStatusCode(fasthttp.StatusOK) + ctx.SetContentType("application/sdp") + ctx.SetBodyString(browserAnswer) +} + +func (h *WebRTCRealtimeHandler) resolveRealtimeWebRTCKeys( + ctx *fasthttp.RequestCtx, + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, +) (schemas.Key, *schemas.Key, error) { + inboundToken := extractRealtimeBearerToken(ctx) + mapping, mapped := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), inboundToken) + if mapped { + applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping) + } + if isRealtimeEphemeralToken(inboundToken) && !mapped { + bifrostCtx.ClearValue(schemas.BifrostContextKeyDirectKey) + bifrostCtx.ClearValue(schemas.BifrostContextKeyAPIKeyID) + bifrostCtx.ClearValue(schemas.BifrostContextKeyAPIKeyName) + bifrostCtx.ClearValue(schemas.BifrostContextKeySelectedKeyID) + bifrostCtx.ClearValue(schemas.BifrostContextKeySelectedKeyName) + authKey := schemas.Key{Value: *schemas.NewEnvVar(inboundToken)} + return authKey, nil, nil + } + + selectedKey, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, model) + if err != nil && mapped && mapping.KeyID != "" { + bifrostCtx.ClearValue(schemas.BifrostContextKeyAPIKeyID) + selectedKey, err = h.client.SelectKeyForProviderRequestType(bifrostCtx, schemas.RealtimeRequest, providerKey, 
model) + } + if err != nil { + return schemas.Key{}, nil, err + } + + authKey := selectedKey + if mapped && inboundToken != "" { + authKey.Value = *schemas.NewEnvVar(inboundToken) + } + return authKey, &selectedKey, nil +} + +func lookupRealtimeEphemeralKeyMapping(kv schemas.KVStore, token string) (realtimeEphemeralKeyMapping, bool) { + if kv == nil || strings.TrimSpace(token) == "" { + return realtimeEphemeralKeyMapping{}, false + } + + raw, err := kv.Get(buildRealtimeEphemeralKeyMappingKey(token)) + if err != nil { + return realtimeEphemeralKeyMapping{}, false + } + + switch value := raw.(type) { + case string: + return parseRealtimeEphemeralKeyMappingValue([]byte(value)) + case []byte: + return parseRealtimeEphemeralKeyMappingValue(value) + default: + return realtimeEphemeralKeyMapping{}, false + } +} + +func parseRealtimeEphemeralKeyMappingValue(raw []byte) (realtimeEphemeralKeyMapping, bool) { + raw = []byte(strings.TrimSpace(string(raw))) + if len(raw) == 0 { + return realtimeEphemeralKeyMapping{}, false + } + + var mapping realtimeEphemeralKeyMapping + if err := json.Unmarshal(raw, &mapping); err == nil { + mapping.KeyID = strings.TrimSpace(mapping.KeyID) + mapping.VirtualKey = strings.TrimSpace(mapping.VirtualKey) + if mapping.KeyID != "" || mapping.VirtualKey != "" { + return mapping, true + } + } + + var keyID string + if err := json.Unmarshal(raw, &keyID); err == nil { + keyID = strings.TrimSpace(keyID) + if keyID != "" { + return realtimeEphemeralKeyMapping{KeyID: keyID}, true + } + } + + keyID = strings.TrimSpace(string(raw)) + if keyID == "" { + return realtimeEphemeralKeyMapping{}, false + } + return realtimeEphemeralKeyMapping{KeyID: keyID}, true +} + +func applyRealtimeEphemeralKeyMapping(bifrostCtx *schemas.BifrostContext, mapping realtimeEphemeralKeyMapping) { + if bifrostCtx == nil { + return + } + if mapping.VirtualKey != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyVirtualKey, mapping.VirtualKey) + } + if mapping.KeyID != "" { + 
bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, mapping.KeyID) + } +} + +func extractRealtimeBearerToken(ctx *fasthttp.RequestCtx) string { + if ctx == nil { + return "" + } + return extractRealtimeBearerTokenFromHeader(string(ctx.Request.Header.Peek("Authorization"))) +} + +func extractRealtimeBearerTokenFromHeader(authHeader string) string { + authHeader = strings.TrimSpace(authHeader) + if len(authHeader) < len("Bearer ")+1 || !strings.HasPrefix(strings.ToLower(authHeader), "bearer ") { + return "" + } + return strings.TrimSpace(authHeader[7:]) +} + +func isRealtimeEphemeralToken(token string) bool { + return strings.HasPrefix(strings.TrimSpace(token), "ek_") +} + +// resolveWebRTCProvider validates and returns a RealtimeProvider that supports WebRTC. +func (h *WebRTCRealtimeHandler) resolveWebRTCProvider(providerKey schemas.ModelProvider) (schemas.RealtimeProvider, *schemas.BifrostError) { + provider := h.client.GetProviderByKey(providerKey) + if provider == nil { + return nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider not found: "+string(providerKey), nil) + } + + rtProvider, ok := provider.(schemas.RealtimeProvider) + if !ok || !rtProvider.SupportsRealtimeAPI() { + return nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider does not support realtime: "+string(providerKey), nil) + } + + if !rtProvider.SupportsRealtimeWebRTC() { + return nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "provider does not support realtime WebRTC: "+string(providerKey), nil) + } + + return rtProvider, nil +} + +// establishRelay sets up the bidirectional WebRTC relay between the browser and the upstream provider. +// exchangeSDP is called with the upstream peer connection's SDP offer and must return the provider's +// SDP answer. This allows the handler to plug in different exchange strategies (GA calls vs legacy). 
+func (h *WebRTCRealtimeHandler) establishRelay( + relayCtx *schemas.BifrostContext, + relayCancel context.CancelFunc, + session *bfws.Session, + provider schemas.RealtimeProvider, + providerKey schemas.ModelProvider, + model string, + key *schemas.Key, + browserOffer string, + exchangeSDP func(ctx *schemas.BifrostContext, upstreamOffer string) (string, *schemas.BifrostError), +) (string, *schemas.BifrostError) { + downstreamPC, err := newRealtimePeerConnection() + if err != nil { + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create browser peer connection", err) + } + upstreamPC, err := newRealtimePeerConnection() + if err != nil { + _ = downstreamPC.Close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create upstream peer connection", err) + } + + relay := &webrtcRealtimeRelay{ + client: h.client, + downstreamPC: downstreamPC, + upstreamPC: upstreamPC, + session: session, + bifrostCtx: relayCtx, + cancel: relayCancel, + provider: provider, + providerKey: providerKey, + model: model, + key: key, + } + relay.onClose = func() { + h.unregisterRelay(session.ID()) + } + relay.installCloseHandlers() + h.registerRelay(session.ID(), relay) + + // Downstream local audio track carries provider audio back to the browser. 
+ providerToBrowserTrack, err := webrtc.NewTrackLocalStaticRTP(defaultAudioCodec, "audio", "bifrost-provider-audio") + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create browser audio track", err) + } + providerToBrowserSender, err := downstreamPC.AddTrack(providerToBrowserTrack) + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to attach browser audio track", err) + } + relay.providerToBrowserTrack = providerToBrowserTrack + go relay.forwardRTCP(providerToBrowserSender, upstreamPC) + + // Upstream local audio track carries browser audio to the provider. + browserToProviderTrack, err := webrtc.NewTrackLocalStaticRTP(defaultAudioCodec, "audio", "bifrost-browser-audio") + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create provider audio track", err) + } + browserToProviderSender, err := upstreamPC.AddTrack(browserToProviderTrack) + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to attach provider audio track", err) + } + relay.browserToProviderTrack = browserToProviderTrack + go relay.forwardRTCP(browserToProviderSender, downstreamPC) + + relay.installTrackForwarders() + if err := relay.installDataChannelRelay(); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create upstream realtime data channel", err) + } + + if err := downstreamPC.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: browserOffer, + }); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "invalid browser SDP offer", err) + } + + upstreamOffer, err := relay.createOffer(upstreamPC) + if err != nil 
{ + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create upstream SDP offer", err) + } + upstreamOffer = constrainRealtimeSDPMaxMessageSize(upstreamOffer, browserOffer) + + upstreamAnswer, exchangeErr := exchangeSDP(relayCtx, upstreamOffer) + if exchangeErr != nil { + relay.close() + return "", exchangeErr + } + + if err := upstreamPC.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: upstreamAnswer, + }); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusBadGateway, "upstream_connection_error", "invalid upstream SDP answer", err) + } + + waitCtx, waitCancel := context.WithTimeout(relayCtx, webrtcRealtimeHandshakeTimeout) + defer waitCancel() + + if err := relay.waitForUpstream(waitCtx); err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusBadGateway, "upstream_connection_error", "upstream realtime WebRTC connection failed", err) + } + + browserAnswer, err := relay.createAnswer(downstreamPC) + if err != nil { + relay.close() + return "", newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to create browser SDP answer", err) + } + + return browserAnswer, nil +} + +type webrtcRealtimeRelay struct { + client *bifrost.Bifrost + downstreamPC *webrtc.PeerConnection + upstreamPC *webrtc.PeerConnection + + downstreamChannel *webrtc.DataChannel + upstreamChannel *webrtc.DataChannel + + providerToBrowserTrack *webrtc.TrackLocalStaticRTP + browserToProviderTrack *webrtc.TrackLocalStaticRTP + + session *bfws.Session + bifrostCtx *schemas.BifrostContext + cancel context.CancelFunc + provider schemas.RealtimeProvider + providerKey schemas.ModelProvider + model string + key *schemas.Key + onClose func() + + closeOnce sync.Once + + channelMu sync.Mutex + pendingToUpstream []queuedDataChannelMessage + pendingToDownstream []queuedDataChannelMessage + upstreamConnectedOrError chan error +} + +type 
queuedDataChannelMessage struct { + payload []byte + isString bool +} + +func (r *webrtcRealtimeRelay) installCloseHandlers() { + r.upstreamConnectedOrError = make(chan error, 1) + + handleState := func(name string, pc *webrtc.PeerConnection) { + pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { + switch state { + case webrtc.PeerConnectionStateConnected: + if name == "upstream" { + select { + case r.upstreamConnectedOrError <- nil: + default: + } + } + case webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed: + if name == "upstream" { + select { + case r.upstreamConnectedOrError <- fmt.Errorf("peer connection state %s", state.String()): + default: + } + } + r.close() + case webrtc.PeerConnectionStateDisconnected: + r.close() + } + }) + } + + handleState("downstream", r.downstreamPC) + handleState("upstream", r.upstreamPC) +} + +func (r *webrtcRealtimeRelay) installTrackForwarders() { + r.downstreamPC.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + if track.Kind() != webrtc.RTPCodecTypeAudio { + return + } + r.forwardRTPTrack(track, r.browserToProviderTrack) + }) + + r.upstreamPC.OnTrack(func(track *webrtc.TrackRemote, _ *webrtc.RTPReceiver) { + if track.Kind() != webrtc.RTPCodecTypeAudio { + return + } + r.forwardRTPTrack(track, r.providerToBrowserTrack) + }) +} + +func (r *webrtcRealtimeRelay) installDataChannelRelay() error { + label := strings.TrimSpace(r.provider.RealtimeWebRTCDataChannelLabel()) + if label == "" { + return nil + } + upstreamDC, err := r.upstreamPC.CreateDataChannel(label, nil) + if err != nil { + return err + } + r.bindUpstreamChannel(upstreamDC) + + r.downstreamPC.OnDataChannel(func(dc *webrtc.DataChannel) { + r.bindDownstreamChannel(dc) + }) + return nil +} + +func (r *webrtcRealtimeRelay) bindUpstreamChannel(dc *webrtc.DataChannel) { + r.channelMu.Lock() + r.upstreamChannel = dc + r.channelMu.Unlock() + + dc.OnOpen(func() { + r.flushPending() + }) + dc.OnMessage(func(msg 
webrtc.DataChannelMessage) { + r.handleUpstreamMessage(msg) + }) + dc.OnClose(func() { r.close() }) + dc.OnError(func(err error) { + logger.Warn("upstream realtime data channel error: %v", err) + r.close() + }) +} + +func (r *webrtcRealtimeRelay) bindDownstreamChannel(dc *webrtc.DataChannel) { + r.channelMu.Lock() + if r.downstreamChannel != nil { + r.channelMu.Unlock() + _ = dc.Close() + return + } + r.downstreamChannel = dc + r.channelMu.Unlock() + + dc.OnOpen(func() { + r.flushPending() + }) + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + r.handleDownstreamMessage(msg) + }) + dc.OnClose(func() { r.close() }) + dc.OnError(func(err error) { + logger.Warn("browser realtime data channel error: %v", err) + r.close() + }) +} + +func (r *webrtcRealtimeRelay) createOffer(pc *webrtc.PeerConnection) (string, error) { + offer, err := pc.CreateOffer(nil) + if err != nil { + return "", err + } + gatherComplete := webrtc.GatheringCompletePromise(pc) + if err := pc.SetLocalDescription(offer); err != nil { + return "", err + } + select { + case <-gatherComplete: + case <-time.After(webrtcRealtimeICEGatherTimeout): + } + if pc.LocalDescription() == nil { + return "", errors.New("local description not set") + } + return pc.LocalDescription().SDP, nil +} + +func (r *webrtcRealtimeRelay) createAnswer(pc *webrtc.PeerConnection) (string, error) { + answer, err := pc.CreateAnswer(nil) + if err != nil { + return "", err + } + gatherComplete := webrtc.GatheringCompletePromise(pc) + if err := pc.SetLocalDescription(answer); err != nil { + return "", err + } + select { + case <-gatherComplete: + case <-time.After(webrtcRealtimeICEGatherTimeout): + } + if pc.LocalDescription() == nil { + return "", errors.New("local description not set") + } + return pc.LocalDescription().SDP, nil +} + +func (r *webrtcRealtimeRelay) waitForUpstream(ctx context.Context) error { + select { + case err := <-r.upstreamConnectedOrError: + return err + case <-ctx.Done(): + return ctx.Err() + } +} + +func 
(r *webrtcRealtimeRelay) forwardRTPTrack(track *webrtc.TrackRemote, target *webrtc.TrackLocalStaticRTP) { + for { + packet, _, err := track.ReadRTP() + if err != nil { + return + } + if err := target.WriteRTP(packet); err != nil { + return + } + } +} + +func (r *webrtcRealtimeRelay) forwardRTCP(sender *webrtc.RTPSender, target *webrtc.PeerConnection) { + if sender == nil || target == nil { + return + } + buf := make([]byte, 1500) + for { + n, _, readErr := sender.Read(buf) + if readErr != nil { + return + } + pkts, parseErr := rtcp.Unmarshal(buf[:n]) + if parseErr != nil { + continue + } + if writeErr := target.WriteRTCP(pkts); writeErr != nil { + return + } + } +} + +func (r *webrtcRealtimeRelay) handleDownstreamMessage(msg webrtc.DataChannelMessage) { + event, err := schemas.ParseRealtimeEvent(msg.Data) + if err != nil { + logger.Warn("failed to parse browser realtime event: %v", err) + r.sendUpstream(msg.Data, msg.IsString) + return + } + toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event) + if toolSummary != "" { + r.session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(msg.Data)) + } + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + if inputSummary != "" { + r.session.RecordRealtimeInput(inputItemID, inputSummary, string(msg.Data)) + } + startsTurn := r.provider.ShouldStartRealtimeTurn(event) + if startsTurn { + if r.session.PeekRealtimeTurnHooks() != nil { + r.sendDownstream(newRealtimeTurnErrorEventPayload(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress.")), true) + return + } + if bifrostErr := startRealtimeTurnHooks(r.client, r.bifrostCtx, r.session, r.provider, r.providerKey, r.model, r.key, event.Type); bifrostErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(bifrostErr)) + return + } + } + + providerEvent, err := r.provider.ToProviderRealtimeEvent(event) + if err != nil { + if startsTurn { + if finalizeErr := 
finalizeRealtimeTurnHooksOnTransportError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + 400, + "invalid_request_error", + err.Error(), + ); finalizeErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(finalizeErr)) + return + } + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()))) + return + } + logger.Warn("failed to translate browser realtime event: %v", err) + r.sendUpstream(msg.Data, msg.IsString) + return + } + r.sendUpstream(providerEvent, msg.IsString) +} + +func (r *webrtcRealtimeRelay) handleUpstreamMessage(msg webrtc.DataChannelMessage) { + event, err := r.provider.ToBifrostRealtimeEvent(msg.Data) + if err != nil { + if finalizeErr := finalizeRealtimeTurnHooksOnTransportError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + 502, + "server_error", + "failed to translate upstream realtime event", + ); finalizeErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(finalizeErr)) + return + } + logger.Warn("failed to translate upstream realtime event: %v", err) + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"))) + return + } + if event != nil { + if event.Session != nil && event.Session.ID != "" { + r.session.SetProviderSessionID(event.Session.ID) + } + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + if inputSummary != "" { + r.session.RecordRealtimeInput(inputItemID, inputSummary, string(msg.Data)) + } + if event.Delta != nil && r.provider.ShouldAccumulateRealtimeOutput(event.Type) { + r.session.AppendRealtimeOutputText(event.Delta.Text) + r.session.AppendRealtimeOutputText(event.Delta.Transcript) + } + if r.provider.ShouldStartRealtimeTurn(event) && r.session.PeekRealtimeTurnHooks() == nil { + if bifrostErr := startRealtimeTurnHooks(r.client, r.bifrostCtx, r.session, 
r.provider, r.providerKey, r.model, r.key, event.Type); bifrostErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(bifrostErr)) + return + } + } + } + if event != nil { + if !r.provider.ShouldForwardRealtimeEvent(event) { + return + } + if event.Type == r.provider.RealtimeTurnFinalEvent() { + contentOverride := r.session.ConsumeRealtimeOutputText() + if bifrostErr := finalizeRealtimeTurnHooks(r.client, r.bifrostCtx, r.session, r.provider, r.providerKey, r.model, r.key, msg.Data, contentOverride); bifrostErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(bifrostErr)) + return + } + } else if event.Error != nil { + if finalizeErr := finalizeRealtimeTurnHooksWithError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + event.Type, + msg.Data, + newBifrostErrorFromRealtimeError(r.providerKey, r.model, msg.Data, event.Error), + ); finalizeErr != nil { + r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload(finalizeErr)) + return + } + } + msg.Data, err = r.provider.ToProviderRealtimeEvent(event) + if err != nil { + logger.Warn("failed to encode translated realtime event: %v", err) + // Lifecycle events (response.done / error) must reach the client so it + // can transition turn state — if encoding fails after the turn was + // finalized server-side, swallowing this would leave the client hung. 
+ r.closeWithErrorEvent(newRealtimeTurnErrorEventPayload( + newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event: "+err.Error()), + )) + return + } + } + + r.sendDownstream(msg.Data, msg.IsString) +} + +func (r *webrtcRealtimeRelay) sendUpstream(payload []byte, isString bool) { + r.channelMu.Lock() + defer r.channelMu.Unlock() + if isDataChannelOpen(r.upstreamChannel) { + sendDataChannelMessage(r.upstreamChannel, payload, isString) + return + } + if len(r.pendingToUpstream) >= webrtcRealtimeMaxPendingMessages { + logger.Warn("upstream pending buffer exceeded %d messages, closing relay", webrtcRealtimeMaxPendingMessages) + go r.close() + return + } + r.pendingToUpstream = append(r.pendingToUpstream, queuedDataChannelMessage{payload: append([]byte(nil), payload...), isString: isString}) +} + +func (r *webrtcRealtimeRelay) sendDownstream(payload []byte, isString bool) { + r.channelMu.Lock() + defer r.channelMu.Unlock() + if isDataChannelOpen(r.downstreamChannel) { + sendDataChannelMessage(r.downstreamChannel, payload, isString) + return + } + if len(r.pendingToDownstream) >= webrtcRealtimeMaxPendingMessages { + logger.Warn("downstream pending buffer exceeded %d messages, closing relay", webrtcRealtimeMaxPendingMessages) + go r.close() + return + } + r.pendingToDownstream = append(r.pendingToDownstream, queuedDataChannelMessage{payload: append([]byte(nil), payload...), isString: isString}) +} + +func (r *webrtcRealtimeRelay) flushPending() { + r.channelMu.Lock() + defer r.channelMu.Unlock() + + if isDataChannelOpen(r.upstreamChannel) && len(r.pendingToUpstream) > 0 { + for _, msg := range r.pendingToUpstream { + sendDataChannelMessage(r.upstreamChannel, msg.payload, msg.isString) + } + r.pendingToUpstream = nil + } + if isDataChannelOpen(r.downstreamChannel) && len(r.pendingToDownstream) > 0 { + for _, msg := range r.pendingToDownstream { + sendDataChannelMessage(r.downstreamChannel, msg.payload, msg.isString) + } + 
r.pendingToDownstream = nil + } +} + +func (r *webrtcRealtimeRelay) close() { + r.closeOnce.Do(func() { + if r.session != nil { + _ = finalizeRealtimeTurnHooksOnTransportError( + r.client, + r.bifrostCtx, + r.session, + r.providerKey, + r.model, + r.key, + 502, + "connection_closed", + "realtime WebRTC session closed before turn completed", + ) + r.session.ClearRealtimeTurnHooks() + } + + if r.onClose != nil { + r.onClose() + } + if r.cancel != nil { + r.cancel() + } + + r.channelMu.Lock() + if r.downstreamChannel != nil { + _ = r.downstreamChannel.Close() + } + if r.upstreamChannel != nil { + _ = r.upstreamChannel.Close() + } + r.channelMu.Unlock() + + if r.downstreamPC != nil { + _ = r.downstreamPC.Close() + } + if r.upstreamPC != nil { + _ = r.upstreamPC.Close() + } + }) +} + +func (r *webrtcRealtimeRelay) closeWithShutdownSignal() { + r.close() +} + +func (r *webrtcRealtimeRelay) closeWithErrorEvent(payload []byte) { + r.channelMu.Lock() + dc := r.downstreamChannel + r.channelMu.Unlock() + + if isDataChannelOpen(dc) && len(payload) > 0 { + sendDataChannelMessage(dc, payload, true) + go func() { + time.Sleep(100 * time.Millisecond) + r.close() + }() + return + } + + r.close() +} + +func (h *WebRTCRealtimeHandler) registerRelay(sessionID string, relay *webrtcRealtimeRelay) { + if strings.TrimSpace(sessionID) == "" || relay == nil { + return + } + h.mu.Lock() + defer h.mu.Unlock() + h.relays[sessionID] = relay +} + +func (h *WebRTCRealtimeHandler) unregisterRelay(sessionID string) { + if strings.TrimSpace(sessionID) == "" { + return + } + h.mu.Lock() + defer h.mu.Unlock() + delete(h.relays, sessionID) +} + +func newRealtimeRelayContext(requestCtx *schemas.BifrostContext) (*schemas.BifrostContext, context.CancelFunc) { + relayCtx, cancel := schemas.NewBifrostContextWithCancel(context.Background()) + if requestCtx == nil { + return relayCtx, cancel + } + + for _, key := range []any{ + schemas.BifrostContextKeyRequestID, + schemas.BifrostContextKeyHTTPRequestType, + 
schemas.BifrostContextKeyIntegrationType, + schemas.BifrostContextKeyParentRequestID, + schemas.BifrostContextKeyVirtualKey, + schemas.BifrostContextKeyAPIKeyName, + schemas.BifrostContextKeyAPIKeyID, + schemas.BifrostContextKeyDirectKey, + schemas.BifrostContextKeyExtraHeaders, + schemas.BifrostContextKeyRequestHeaders, + schemas.BifrostContextKeyUserAgent, + schemas.BifrostContextKeyGovernanceVirtualKeyID, + schemas.BifrostContextKeyGovernanceVirtualKeyName, + schemas.BifrostContextKeyGovernanceRoutingRuleID, + schemas.BifrostContextKeyGovernanceRoutingRuleName, + schemas.BifrostContextKeyGovernanceCustomerID, + schemas.BifrostContextKeyGovernanceCustomerName, + schemas.BifrostContextKeyGovernanceTeamID, + schemas.BifrostContextKeyGovernanceTeamName, + schemas.BifrostContextKeyGovernanceUserID, + schemas.BifrostContextKeyGovernanceIncludeOnlyKeys, + schemas.BifrostContextKeyGovernancePluginName, + schemas.BifrostContextKeySelectedKeyID, + schemas.BifrostContextKeySelectedKeyName, + schemas.BifrostContextKeyIsEnterprise, + } { + if value := requestCtx.Value(key); value != nil { + relayCtx.SetValue(key, value) + } + } + + return relayCtx, cancel +} + +func newRealtimePeerConnection() (*webrtc.PeerConnection, error) { + return webrtc.NewPeerConnection(webrtc.Configuration{}) +} + +func isDataChannelOpen(dc *webrtc.DataChannel) bool { + return dc != nil && dc.ReadyState() == webrtc.DataChannelStateOpen +} + +func realtimeEventTypeFromPayload(payload []byte) string { + var envelope struct { + Type string `json:"type"` + } + if err := json.Unmarshal(payload, &envelope); err != nil { + return "" + } + return strings.TrimSpace(envelope.Type) +} + +func parseRealtimeSDPMaxMessageSize(sdp string) (int64, bool) { + matches := realtimeSDPMaxMessageSizePattern.FindStringSubmatch(sdp) + if len(matches) < 2 { + return 0, false + } + size, err := strconv.ParseInt(matches[1], 10, 64) + if err != nil || size <= 0 { + return 0, false + } + return size, true +} + +func 
setRealtimeSDPMaxMessageSize(sdp string, maxMessageSize int64) string { + line := "a=max-message-size:" + strconv.FormatInt(maxMessageSize, 10) + if realtimeSDPMaxMessageSizePattern.MatchString(sdp) { + return realtimeSDPMaxMessageSizePattern.ReplaceAllString(sdp, line) + } + if strings.Contains(sdp, "\r\nm=application ") { + return strings.Replace(sdp, "\r\nm=application ", "\r\n"+line+"\r\nm=application ", 1) + } + if strings.Contains(sdp, "\nm=application ") { + return strings.Replace(sdp, "\nm=application ", "\n"+line+"\nm=application ", 1) + } + return sdp +} + +func constrainRealtimeSDPMaxMessageSize(upstreamOffer string, browserOffer string) string { + browserMax, ok := parseRealtimeSDPMaxMessageSize(browserOffer) + if !ok { + return upstreamOffer + } + + upstreamMax, ok := parseRealtimeSDPMaxMessageSize(upstreamOffer) + if ok && upstreamMax <= browserMax { + return upstreamOffer + } + + return setRealtimeSDPMaxMessageSize(upstreamOffer, browserMax) +} + +func sendDataChannelMessage(dc *webrtc.DataChannel, payload []byte, isString bool) { + if dc == nil { + return + } + var err error + if isString { + err = dc.SendText(string(payload)) + } else { + err = dc.Send(payload) + } + if err != nil { + eventType := realtimeEventTypeFromPayload(payload) + if eventType != "" { + logger.Warn("failed to send realtime data channel message: type=%s size=%d bytes err=%v", eventType, len(payload), err) + return + } + logger.Warn("failed to send realtime data channel message: size=%d bytes err=%v", len(payload), err) + } +} + +func resolveRealtimeSDPTarget(path string, sessionJSON []byte) (schemas.ModelProvider, string, []byte, *schemas.BifrostError) { + root, err := schemas.ParseRealtimeClientSecretBody(sessionJSON) + if err != nil { + return "", "", nil, err + } + + modelJSON, ok := root["model"] + if !ok { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + + var rawModel string + if err 
:= json.Unmarshal(modelJSON, &rawModel); err != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model must be a string", err) + } + + providerKey, model := schemas.ParseModelString(strings.TrimSpace(rawModel), realtimeDefaultProviderForPath(path)) + if providerKey == "" || strings.TrimSpace(model) == "" { + if realtimeDefaultProviderForPath(path) == "" { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model must use provider/model on /v1 realtime routes", nil) + } + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + + normalizedModel, marshalErr := json.Marshal(model) + if marshalErr != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized session model", marshalErr) + } + root["model"] = normalizedModel + normalizedSession, marshalErr := json.Marshal(root) + if marshalErr != nil { + return "", "", nil, newRealtimeWebRTCError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized realtime session", marshalErr) + } + + return providerKey, strings.TrimSpace(model), normalizedSession, nil +} + +func firstMultipartValue(values map[string][]string, key string) string { + if len(values[key]) == 0 { + return "" + } + return values[key][0] +} + +func newRealtimeWebRTCError(status int, errorType, message string, err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + }, + } +} diff --git a/transports/bifrost-http/handlers/webrtc_realtime_test.go b/transports/bifrost-http/handlers/webrtc_realtime_test.go new file 
mode 100644 index 0000000000..a0c0d72c1a --- /dev/null +++ b/transports/bifrost-http/handlers/webrtc_realtime_test.go @@ -0,0 +1,346 @@ +package handlers + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/kvstore" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" + "github.com/valyala/fasthttp" +) + +type testHandlerStore struct { + kv *kvstore.Store +} + +func (s testHandlerStore) ShouldAllowDirectKeys() bool { return true } +func (s testHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { return nil } +func (s testHandlerStore) GetAvailableProviders() []schemas.ModelProvider { return nil } +func (s testHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor { + return nil +} +func (s testHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { return nil } +func (s testHandlerStore) GetAsyncJobResultTTL() int { return 0 } +func (s testHandlerStore) GetKVStore() *kvstore.Store { return s.kv } +func (s testHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { return nil } + +func TestResolveRealtimeSDPTarget_BaseRouteRequiresProviderPrefix(t *testing.T) { + _, _, _, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`)) + if err == nil { + t.Fatal("expected provider/model validation error") + } + if err.Error == nil || err.Error.Message != "session.model must use provider/model on /v1 realtime routes" { + t.Fatalf("unexpected error: %#v", err) + } +} + +func TestResolveRealtimeSDPTarget_BaseRouteNormalizesModel(t *testing.T) { + provider, model, normalized, err := resolveRealtimeSDPTarget("/v1/realtime", []byte(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if provider != schemas.OpenAI { + 
t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("unexpected normalized model: %s", model) + } + + var root map[string]json.RawMessage + if unmarshalErr := json.Unmarshal(normalized, &root); unmarshalErr != nil { + t.Fatalf("failed to unmarshal normalized session: %v", unmarshalErr) + } + var sessionModel string + if unmarshalErr := json.Unmarshal(root["model"], &sessionModel); unmarshalErr != nil { + t.Fatalf("failed to unmarshal model: %v", unmarshalErr) + } + if sessionModel != "gpt-4o-realtime-preview" { + t.Fatalf("unexpected marshaled model: %s", sessionModel) + } +} + +func TestResolveRealtimeSDPTarget_OpenAIRouteDefaultsProvider(t *testing.T) { + provider, model, _, err := resolveRealtimeSDPTarget("/openai/v1/realtime", []byte(`{"model":"gpt-4o-realtime-preview"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if provider != schemas.OpenAI { + t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("unexpected model: %s", model) + } +} + +func TestParseCallsWebRTCRequest_RawSDPKeepsGARoute(t *testing.T) { + var ctx fasthttp.RequestCtx + ctx.Request.Header.SetMethod(fasthttp.MethodPost) + ctx.Request.SetRequestURI("/openai/v1/realtime/calls?model=gpt-realtime") + ctx.Request.Header.SetContentType("application/sdp") + ctx.Request.SetBodyString("v=0\r\n") + + sdpOffer, provider, model, session, err := parseCallsWebRTCRequest(&ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if sdpOffer != "v=0\r\n" { + t.Fatalf("unexpected sdp offer: %q", sdpOffer) + } + if provider != schemas.OpenAI { + t.Fatalf("expected provider %s, got %s", schemas.OpenAI, provider) + } + if model != "gpt-realtime" { + t.Fatalf("unexpected model: %s", model) + } + if session != nil { + t.Fatalf("expected nil session for raw SDP /calls request, got %s", string(session)) + } +} + +func 
TestNewRealtimeRelayContextCopiesValuesWithoutRequestCancellation(t *testing.T) { + requestCtx, requestCancel := schemas.NewBifrostContextWithCancel(context.Background()) + requestCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + requestCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + requestCtx.SetValue(schemas.BifrostContextKeyGovernanceVirtualKeyID, "vk_test") + + relayCtx, relayCancel := newRealtimeRelayContext(requestCtx) + defer relayCancel() + + requestCancel() + + select { + case <-requestCtx.Done(): + case <-time.After(time.Second): + t.Fatal("expected request context to be cancelled") + } + + select { + case <-relayCtx.Done(): + t.Fatal("relay context should outlive cancelled request context") + default: + } + + if got := relayCtx.Value(schemas.BifrostContextKeyHTTPRequestType); got != schemas.RealtimeRequest { + t.Fatalf("request type = %v, want %v", got, schemas.RealtimeRequest) + } + if got := relayCtx.Value(schemas.BifrostContextKeyIntegrationType); got != "openai" { + t.Fatalf("integration type = %v, want %q", got, "openai") + } + if got := relayCtx.Value(schemas.BifrostContextKeyGovernanceVirtualKeyID); got != "vk_test" { + t.Fatalf("virtual key id = %v, want %q", got, "vk_test") + } +} + +func TestParseRealtimeEventPreservesExtraParams(t *testing.T) { + event, err := schemas.ParseRealtimeEvent([]byte(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)) + if err != nil { + t.Fatalf("ParseRealtimeEvent() error = %v", err) + } + + var itemID string + if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil { + t.Fatalf("json.Unmarshal(item_id) error = %v", err) + } + if itemID != "item_123" { + t.Fatalf("item_id = %q, want %q", itemID, "item_123") + } + + var contentIndex int + if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil { + t.Fatalf("json.Unmarshal(content_index) error = %v", err) + } + if 
contentIndex != 0 { + t.Fatalf("content_index = %d, want 0", contentIndex) + } +} + +func TestExtractRealtimeBearerToken(t *testing.T) { + var ctx fasthttp.RequestCtx + ctx.Request.Header.Set("Authorization", "Bearer ek_test_123") + + if got := extractRealtimeBearerToken(&ctx); got != "ek_test_123" { + t.Fatalf("extractRealtimeBearerToken() = %q, want %q", got, "ek_test_123") + } +} + +func TestLookupRealtimeEphemeralKeyMappingKeepsEntryUntilTTLExpiry(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + payload, err := json.Marshal(realtimeEphemeralKeyMapping{KeyID: "key_123", VirtualKey: "sk-bf-test"}) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_123"), payload, time.Minute); err != nil { + t.Fatalf("store.SetWithTTL() error = %v", err) + } + + mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_123") + if !ok { + t.Fatal("expected mapping to be consumed") + } + if mapping.KeyID != "key_123" { + t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_123") + } + if mapping.VirtualKey != "sk-bf-test" { + t.Fatalf("mapping.VirtualKey = %q, want %q", mapping.VirtualKey, "sk-bf-test") + } + + raw, err := store.Get(buildRealtimeEphemeralKeyMappingKey("ek_test_123")) + if err != nil { + t.Fatalf("expected mapping to remain until TTL expiry: %v", err) + } + if raw == nil { + t.Fatal("expected mapping to remain in KV store") + } +} + +func TestLookupRealtimeEphemeralKeyMapping_BackwardsCompatibleStringValue(t *testing.T) { + t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + if err := store.SetWithTTL(buildRealtimeEphemeralKeyMappingKey("ek_test_legacy"), "key_legacy", time.Minute); err != nil { + t.Fatalf("store.SetWithTTL() error = %v", err) + 
} + + mapping, ok := lookupRealtimeEphemeralKeyMapping(store, "ek_test_legacy") + if !ok { + t.Fatal("expected legacy mapping to be consumed") + } + if mapping.KeyID != "key_legacy" { + t.Fatalf("mapping.KeyID = %q, want %q", mapping.KeyID, "key_legacy") + } + if mapping.VirtualKey != "" { + t.Fatalf("mapping.VirtualKey = %q, want empty", mapping.VirtualKey) + } +} + +func TestWebRTCRealtimeRelayCloseFinalizesActiveTurnHooks(t *testing.T) { + t.Parallel() + + session := bfws.NewSession(nil) + session.SetProviderSessionID("sess_provider_123") + session.AddRealtimeInput("hello from user", `{"type":"conversation.item.added"}`) + + var ( + capturedErr *schemas.BifrostError + cleanedUp bool + ) + session.SetRealtimeTurnHooks(&bfws.RealtimeTurnPluginState{ + RequestID: "req_realtime_123", + StartedAt: time.Now().Add(-time.Second), + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + capturedErr = err + return result, nil + }, + Cleanup: func() { + cleanedUp = true + }, + }) + + relay := &webrtcRealtimeRelay{ + session: session, + providerKey: schemas.OpenAI, + model: "gpt-realtime", + } + + relay.close() + + if capturedErr == nil { + t.Fatal("expected active turn to be finalized with an error on close") + } + if capturedErr.ExtraFields.RequestType != schemas.RealtimeRequest { + t.Fatalf("request type = %q, want %q", capturedErr.ExtraFields.RequestType, schemas.RealtimeRequest) + } + if capturedErr.Error == nil || capturedErr.Error.Message != "realtime WebRTC session closed before turn completed" { + t.Fatalf("error message = %#v, want realtime close message", capturedErr.Error) + } + if session.PeekRealtimeTurnHooks() != nil { + t.Fatal("expected active realtime turn hooks to be cleared") + } + if !cleanedUp { + t.Fatal("expected realtime hook cleanup to run") + } +} + +func TestResolveRealtimeWebRTCKeys_UnmappedEphemeralTokenStaysAnonymous(t *testing.T) { + 
t.Parallel() + + store, err := kvstore.New(kvstore.Config{}) + if err != nil { + t.Fatalf("kvstore.New() error = %v", err) + } + defer store.Close() + + handler := &WebRTCRealtimeHandler{ + handlerStore: testHandlerStore{kv: store}, + } + + var ctx fasthttp.RequestCtx + ctx.Request.Header.Set("Authorization", "Bearer ek_test_unmapped") + + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + bifrostCtx.SetValue(schemas.BifrostContextKeyDirectKey, schemas.Key{ID: "header-provided"}) + bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyID, "selected") + bifrostCtx.SetValue(schemas.BifrostContextKeySelectedKeyName, "selected-name") + bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyID, "mapped-id") + bifrostCtx.SetValue(schemas.BifrostContextKeyAPIKeyName, "mapped-name") + + authKey, selectedKey, err := handler.resolveRealtimeWebRTCKeys(&ctx, bifrostCtx, schemas.OpenAI, "gpt-realtime") + if err != nil { + t.Fatalf("resolveRealtimeWebRTCKeys() error = %v", err) + } + if got := authKey.Value.GetValue(); got != "ek_test_unmapped" { + t.Fatalf("auth key value = %q, want %q", got, "ek_test_unmapped") + } + if selectedKey != nil { + t.Fatalf("selectedKey = %#v, want nil", selectedKey) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyDirectKey); got != nil { + t.Fatalf("direct key context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyID); got != nil { + t.Fatalf("selected key id context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeySelectedKeyName); got != nil { + t.Fatalf("selected key name context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != nil { + t.Fatalf("api key id context = %#v, want nil", got) + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyName); got != nil { + t.Fatalf("api key name context = %#v, want nil", got) + } +} + +func 
TestApplyRealtimeEphemeralKeyMapping_RestoresVirtualKeyAndKeyID(t *testing.T) { + t.Parallel() + + bifrostCtx := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) + applyRealtimeEphemeralKeyMapping(bifrostCtx, realtimeEphemeralKeyMapping{ + KeyID: "key_123", + VirtualKey: "sk-bf-test", + }) + + if got := bifrostCtx.Value(schemas.BifrostContextKeyVirtualKey); got != "sk-bf-test" { + t.Fatalf("virtual key context = %#v, want %q", got, "sk-bf-test") + } + if got := bifrostCtx.Value(schemas.BifrostContextKeyAPIKeyID); got != "key_123" { + t.Fatalf("api key id context = %#v, want %q", got, "key_123") + } +} diff --git a/transports/bifrost-http/handlers/websocket.go b/transports/bifrost-http/handlers/websocket.go index 3f83cfdc64..93259291c8 100644 --- a/transports/bifrost-http/handlers/websocket.go +++ b/transports/bifrost-http/handlers/websocket.go @@ -180,26 +180,29 @@ func (h *WebSocketHandler) BroadcastLogUpdate(logEntry *logstore.Log) { operationType = "create" } - // Trim payload for table view: keep only the last input message and nil out - // large output fields that the table never renders. - if len(logEntry.InputHistoryParsed) > 1 { - logEntry.InputHistoryParsed = logEntry.InputHistoryParsed[len(logEntry.InputHistoryParsed)-1:] - } - if len(logEntry.ResponsesInputHistoryParsed) > 1 { - logEntry.ResponsesInputHistoryParsed = logEntry.ResponsesInputHistoryParsed[len(logEntry.ResponsesInputHistoryParsed)-1:] + // Trim payload for table view to keep websocket updates lightweight, but keep + // full realtime turns so the live table/detail sheet can still render the + // combined tool/user/assistant turn shape without waiting for a refresh. 
+ if logEntry.Object != "realtime.turn" { + if len(logEntry.InputHistoryParsed) > 1 { + logEntry.InputHistoryParsed = logEntry.InputHistoryParsed[len(logEntry.InputHistoryParsed)-1:] + } + if len(logEntry.ResponsesInputHistoryParsed) > 1 { + logEntry.ResponsesInputHistoryParsed = logEntry.ResponsesInputHistoryParsed[len(logEntry.ResponsesInputHistoryParsed)-1:] + } + logEntry.OutputMessageParsed = nil + logEntry.ResponsesOutputParsed = nil + logEntry.EmbeddingOutputParsed = nil + logEntry.RerankOutputParsed = nil + logEntry.ParamsParsed = nil + logEntry.ToolsParsed = nil + logEntry.ToolCallsParsed = nil + logEntry.SpeechOutputParsed = nil + logEntry.TranscriptionOutputParsed = nil + logEntry.ImageGenerationOutputParsed = nil + logEntry.ListModelsOutputParsed = nil + logEntry.CacheDebugParsed = nil } - logEntry.OutputMessageParsed = nil - logEntry.ResponsesOutputParsed = nil - logEntry.EmbeddingOutputParsed = nil - logEntry.RerankOutputParsed = nil - logEntry.ParamsParsed = nil - logEntry.ToolsParsed = nil - logEntry.ToolCallsParsed = nil - logEntry.SpeechOutputParsed = nil - logEntry.TranscriptionOutputParsed = nil - logEntry.ImageGenerationOutputParsed = nil - logEntry.ListModelsOutputParsed = nil - logEntry.CacheDebugParsed = nil message := struct { Type string `json:"type"` diff --git a/transports/bifrost-http/handlers/wsrealtime.go b/transports/bifrost-http/handlers/wsrealtime.go new file mode 100644 index 0000000000..1d31c589e9 --- /dev/null +++ b/transports/bifrost-http/handlers/wsrealtime.go @@ -0,0 +1,666 @@ +package handlers + +import ( + "errors" + "io" + "net" + "strings" + "sync" + "time" + + "github.com/fasthttp/router" + ws "github.com/fasthttp/websocket" + bifrost "github.com/maximhq/bifrost/core" + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/transports/bifrost-http/integrations" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" + bfws "github.com/maximhq/bifrost/transports/bifrost-http/websocket" + 
"github.com/valyala/fasthttp" +) + +const ( + realtimeWSPingInterval = 15 * time.Second + realtimeWSPongTimeout = 45 * time.Second + realtimeWSPingWriteTimeout = 10 * time.Second + realtimeWSWriteTimeout = 30 * time.Second +) + +// WSRealtimeHandler handles bidirectional WebSocket proxying for the Realtime API. +type WSRealtimeHandler struct { + client *bifrost.Bifrost + config *lib.Config + handlerStore lib.HandlerStore + pool *bfws.Pool + sessions *bfws.SessionManager +} + +// NewWSRealtimeHandler creates a new Realtime WebSocket handler. +func NewWSRealtimeHandler(client *bifrost.Bifrost, config *lib.Config, pool *bfws.Pool) *WSRealtimeHandler { + maxConns := config.WebSocketConfig.MaxConnections + + return &WSRealtimeHandler{ + client: client, + config: config, + handlerStore: config, + pool: pool, + sessions: bfws.NewSessionManager(maxConns), + } +} + +// RegisterRoutes registers the Realtime WebSocket endpoint at the base path and OpenAI integration paths. +func (h *WSRealtimeHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { + handler := lib.ChainMiddlewares(h.handleUpgrade, middlewares...) + r.GET("/v1/realtime", handler) + for _, path := range integrations.OpenAIRealtimePaths("/openai") { + r.GET(path, handler) + } +} + +func (h *WSRealtimeHandler) Close() { + if h == nil || h.sessions == nil { + return + } + h.sessions.CloseAll() +} + +func (h *WSRealtimeHandler) handleUpgrade(ctx *fasthttp.RequestCtx) { + path := string(ctx.Path()) + modelParam := string(ctx.QueryArgs().Peek("model")) + deploymentParam := string(ctx.QueryArgs().Peek("deployment")) + auth := captureAuthHeaders(ctx) + // OpenAI's SDK sends the API key via WebSocket subprotocol: "openai-insecure-api-key.". + // Extract it into the auth headers so downstream processing recognizes it. 
+ if auth.authorization == "" { + if token := extractRealtimeSubprotocolAPIKey(ctx); token != "" { + auth.authorization = "Bearer " + token + } + } + + providerKey, model, err := resolveRealtimeTarget(path, modelParam, deploymentParam) + if err != nil { + upgrader := h.websocketUpgrader("") + upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) { + defer conn.Close() + clientConn := newRealtimeClientConn(conn) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error())) + }) + if upgradeErr != nil { + logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr) + } + return + } + + provider := h.client.GetProviderByKey(providerKey) + rtProvider, ok := provider.(schemas.RealtimeProvider) + if provider == nil || !ok || !rtProvider.SupportsRealtimeAPI() { + upgrader := h.websocketUpgrader("") + upgradeErr := upgrader.Upgrade(ctx, func(conn *ws.Conn) { + defer conn.Close() + clientConn := newRealtimeClientConn(conn) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey))) + }) + if upgradeErr != nil { + logger.Warn("websocket upgrade failed for %s: %v", path, upgradeErr) + } + return + } + + upgrader := h.websocketUpgrader(rtProvider.RealtimeWebSocketSubprotocol()) + err = upgrader.Upgrade(ctx, func(conn *ws.Conn) { + defer conn.Close() + clientConn := newRealtimeClientConn(conn) + + session, sessionErr := h.sessions.Create(conn) + if sessionErr != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(429, "rate_limit_exceeded", sessionErr.Error())) + return + } + defer h.sessions.Remove(conn) + + h.runRealtimeSession(clientConn, session, auth, path, providerKey, model) + }) + if err != nil { + logger.Warn("websocket upgrade failed for %s: %v", path, err) + } +} + +func (h *WSRealtimeHandler) websocketUpgrader(subprotocol string) ws.FastHTTPUpgrader { + upgrader := ws.FastHTTPUpgrader{ + ReadBufferSize: 4096, + 
WriteBufferSize: 4096, + CheckOrigin: func(ctx *fasthttp.RequestCtx) bool { + origin := string(ctx.Request.Header.Peek("Origin")) + if origin == "" { + return true + } + return IsOriginAllowed(origin, h.config.ClientConfig.AllowedOrigins) + }, + } + if strings.TrimSpace(subprotocol) != "" { + upgrader.Subprotocols = []string{subprotocol} + } + return upgrader +} + +func (h *WSRealtimeHandler) runRealtimeSession( + clientConn *realtimeClientConn, + session *bfws.Session, + auth *authHeaders, + path string, + providerKey schemas.ModelProvider, + model string, +) { + clientConn.startHeartbeat() + defer clientConn.stopHeartbeat() + + bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth) + if bifrostCtx == nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(500, "server_error", "failed to create request context")) + return + } + defer cancel() + + // Resolve ephemeral key mapping to restore virtual key context. + token := extractRealtimeBearerTokenFromHeader(auth.authorization) + if isRealtimeEphemeralToken(token) { + mapping, ok := lookupRealtimeEphemeralKeyMapping(h.handlerStore.GetKVStore(), token) + if ok { + applyRealtimeEphemeralKeyMapping(bifrostCtx, mapping) + } + } + + bifrostCtx.SetValue(schemas.BifrostContextKeyHTTPRequestType, schemas.RealtimeRequest) + if strings.HasPrefix(path, "/openai") { + bifrostCtx.SetValue(schemas.BifrostContextKeyIntegrationType, "openai") + } + + provider := h.client.GetProviderByKey(providerKey) + if provider == nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider not found: "+string(providerKey))) + return + } + + rtProvider, ok := provider.(schemas.RealtimeProvider) + if !ok || !rtProvider.SupportsRealtimeAPI() { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "provider does not support realtime: "+string(providerKey))) + return + } + + key, err := h.client.SelectKeyForProviderRequestType(bifrostCtx, 
schemas.RealtimeRequest, providerKey, model) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error())) + return + } + + // Resolve model alias so the provider receives the actual model identifier. + model = key.Aliases.Resolve(model) + + wsURL := rtProvider.RealtimeWebSocketURL(key, model) + upstream, err := h.pool.Get(bfws.PoolKey{ + Provider: providerKey, + KeyID: key.ID, + Endpoint: wsURL, + }, rtProvider.RealtimeHeaders(key)) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", err.Error())) + return + } + defer h.pool.Discard(upstream) + + errCh := make(chan error, 2) + go func() { + errCh <- h.relayClientToRealtimeProvider(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key) + }() + go func() { + errCh <- h.relayRealtimeProviderToClient(clientConn, session, upstream, rtProvider, bifrostCtx, providerKey, model, key) + }() + + firstErr := <-errCh + _ = upstream.Close() + _ = clientConn.Close() + secondErr := <-errCh + + if logErr := selectRealtimeRelayError(firstErr, secondErr); logErr != nil { + logger.Warn("realtime websocket relay ended for %s/%s on %s: %v", providerKey, model, path, logErr) + } +} + +func (h *WSRealtimeHandler) relayClientToRealtimeProvider( + clientConn *realtimeClientConn, + session *bfws.Session, + upstream *bfws.UpstreamConn, + provider schemas.RealtimeProvider, + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, + key schemas.Key, +) error { + for { + messageType, message, err := clientConn.ReadMessage() + if err != nil { + finalizeRealtimeTurnHooksOnTransportError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + 499, + "client_closed_request", + "client realtime websocket disconnected before turn completed", + ) + if isNormalWebSocketClosure(err) { + return nil + } + return err + } + if messageType != ws.TextMessage { + 
clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "realtime websocket only accepts text messages")) + return nil + } + + event, err := schemas.ParseRealtimeEvent(message) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "failed to parse realtime event JSON")) + continue + } + // Extract pending tool/input summaries but defer recording until the event + // passes validation — rejected events must not pollute session state. + toolItemID, toolSummary := pendingRealtimeToolOutputUpdate(event) + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + + startsTurn := provider.ShouldStartRealtimeTurn(event) + if startsTurn { + if session.PeekRealtimeTurnHooks() != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", "Conversation already has an active response in progress.")) + continue + } + if toolSummary != "" { + session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message)) + } + if inputSummary != "" { + session.RecordRealtimeInput(inputItemID, inputSummary, string(message)) + } + if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil { + clientConn.writeRealtimeError(bifrostErr) + return nil + } + } + + providerEvent, err := provider.ToProviderRealtimeEvent(event) + if err != nil { + if startsTurn { + if finalizeErr := finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(400, "invalid_request_error", err.Error()), + ); finalizeErr != nil { + clientConn.writeRealtimeError(finalizeErr) + return nil + } + } + clientConn.writeRealtimeError(newRealtimeWireBifrostError(400, "invalid_request_error", err.Error())) + continue + } + + // Record tool output / input only after the event passed validation. 
+ if !startsTurn { + if toolSummary != "" { + session.RecordRealtimeToolOutput(toolItemID, toolSummary, string(message)) + } + if inputSummary != "" { + session.RecordRealtimeInput(inputItemID, inputSummary, string(message)) + } + } + + if err := upstream.WriteMessage(ws.TextMessage, providerEvent); err != nil { + finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream"), + ) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to write realtime event upstream")) + return err + } + } +} + +func (h *WSRealtimeHandler) relayRealtimeProviderToClient( + clientConn *realtimeClientConn, + session *bfws.Session, + upstream *bfws.UpstreamConn, + provider schemas.RealtimeProvider, + bifrostCtx *schemas.BifrostContext, + providerKey schemas.ModelProvider, + model string, + key schemas.Key, +) error { + for { + disconnectAfterWrite := false + messageType, message, err := upstream.ReadMessage() + if err != nil { + finalizeRealtimeTurnHooksOnTransportError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + 502, + "upstream_connection_error", + "upstream realtime websocket closed before turn completed", + ) + if isNormalWebSocketClosure(err) { + return nil + } + finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + schemas.RTEventError, + nil, + newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted"), + ) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "upstream realtime websocket stream interrupted")) + return err + } + + if messageType == ws.TextMessage { + event, err := provider.ToBifrostRealtimeEvent(message) + if err != nil { + finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, 
+ &key, + schemas.RTEventError, + message, + newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event"), + ) + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to translate upstream realtime event")) + return err + } + if event != nil { + if event.Session != nil && event.Session.ID != "" { + session.SetProviderSessionID(event.Session.ID) + } + if event.Delta != nil && provider.ShouldAccumulateRealtimeOutput(event.Type) { + session.AppendRealtimeOutputText(event.Delta.Text) + session.AppendRealtimeOutputText(event.Delta.Transcript) + } + if provider.ShouldStartRealtimeTurn(event) && session.PeekRealtimeTurnHooks() == nil { + if bifrostErr := startRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, event.Type); bifrostErr != nil { + clientConn.writeRealtimeError(bifrostErr) + return nil + } + } + } + if event != nil { + inputItemID, inputSummary := pendingRealtimeInputUpdate(event) + if !provider.ShouldForwardRealtimeEvent(event) { + continue + } + if event.Type == provider.RealtimeTurnFinalEvent() { + contentOverride := session.ConsumeRealtimeOutputText() + if bifrostErr := finalizeRealtimeTurnHooks(h.client, bifrostCtx, session, provider, providerKey, model, &key, message, contentOverride); bifrostErr != nil { + clientConn.writeRealtimeError(bifrostErr) + return nil + } + } else if event.Error != nil { + turnErr := newBifrostErrorFromRealtimeError(providerKey, model, message, event.Error) + finalizeErr := finalizeRealtimeTurnHooksWithError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + event.Type, + message, + turnErr, + ) + if finalizeErr != nil { + clientConn.writeRealtimeError(finalizeErr) + return nil + } + // Defer the disconnect so the normal translated-write path + // below still runs — otherwise terminal errors from translated + // providers would reach the client in provider-native format. 
+ disconnectAfterWrite = shouldGracefullyDisconnectRealtime(turnErr) + } else if inputSummary != "" { + session.RecordRealtimeInput(inputItemID, inputSummary, string(message)) + } + if len(event.RawData) == 0 { + message, err = provider.ToProviderRealtimeEvent(event) + if err != nil { + clientConn.writeRealtimeError(newRealtimeWireBifrostError(502, "server_error", "failed to encode translated realtime event")) + return err + } + } + } + } + + if err := clientConn.WriteMessage(messageType, message); err != nil { + finalizeRealtimeTurnHooksOnTransportError( + h.client, + bifrostCtx, + session, + providerKey, + model, + &key, + 499, + "client_closed_request", + "client realtime websocket disconnected before turn completed", + ) + if isNormalWebSocketClosure(err) { + return nil + } + return err + } + if disconnectAfterWrite { + return nil + } + } +} + +func resolveRealtimeTarget(path, modelParam, deploymentParam string) (schemas.ModelProvider, string, error) { + defaultProvider := realtimeDefaultProviderForPath(path) + + switch { + case strings.TrimSpace(modelParam) != "": + provider, model := schemas.ParseModelString(strings.TrimSpace(modelParam), defaultProvider) + if provider == "" || strings.TrimSpace(model) == "" { + return "", "", errRealtimeModelFormat + } + return provider, strings.TrimSpace(model), nil + case strings.TrimSpace(deploymentParam) != "": + provider, model := schemas.ParseModelString(strings.TrimSpace(deploymentParam), defaultProvider) + if provider == "" || strings.TrimSpace(model) == "" { + return "", "", errRealtimeDeploymentFormat + } + return provider, strings.TrimSpace(model), nil + default: + return "", "", errRealtimeModelRequired + } +} + +func realtimeDefaultProviderForPath(path string) schemas.ModelProvider { + if strings.HasPrefix(path, "/openai/") { + return schemas.OpenAI + } + return "" +} + +func isNormalWebSocketClosure(err error) bool { + return ws.IsCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, 
ws.CloseNoStatusReceived) +} + +func isExpectedRealtimeRelayShutdown(err error) bool { + if err == nil { + return true + } + if isNormalWebSocketClosure(err) || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + return true + } + // Relay teardown closes the opposite socket after the first side exits, which can + // surface as a plain network-close read error instead of a websocket close frame. + return strings.Contains(err.Error(), "use of closed network connection") +} + +func selectRealtimeRelayError(errs ...error) error { + for _, err := range errs { + if err != nil && !isExpectedRealtimeRelayShutdown(err) { + return err + } + } + return nil +} + +var ( + errRealtimeModelRequired = errorf("model or deployment query parameter is required for realtime websocket") + errRealtimeModelFormat = errorf("model query parameter must resolve to provider/model for realtime websocket") + errRealtimeDeploymentFormat = errorf("deployment query parameter must resolve to provider/model for realtime websocket") +) + +type realtimeClientConn struct { + conn *ws.Conn + writeMu sync.Mutex + closeOnce sync.Once + done chan struct{} +} + +func newRealtimeClientConn(conn *ws.Conn) *realtimeClientConn { + return &realtimeClientConn{ + conn: conn, + done: make(chan struct{}), + } +} + +func (c *realtimeClientConn) ReadMessage() (messageType int, p []byte, err error) { + messageType, p, err = c.conn.ReadMessage() + if err == nil { + c.refreshReadDeadline() + } + return messageType, p, err +} + +func (c *realtimeClientConn) WriteMessage(messageType int, data []byte) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSWriteTimeout)); err != nil { + return err + } + if err := c.conn.WriteMessage(messageType, data); err != nil { + return err + } + return c.conn.SetWriteDeadline(time.Time{}) +} + +func (c *realtimeClientConn) startHeartbeat() { + c.installPongHandler() + c.refreshReadDeadline() + + go func() { + ticker := 
time.NewTicker(realtimeWSPingInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := c.writePing(); err != nil { + _ = c.Close() + return + } + case <-c.done: + return + } + } + }() +} + +func (c *realtimeClientConn) stopHeartbeat() { + c.closeDone() +} + +func (c *realtimeClientConn) installPongHandler() { + c.conn.SetPongHandler(func(string) error { + return c.refreshReadDeadline() + }) +} + +func (c *realtimeClientConn) refreshReadDeadline() error { + return c.conn.SetReadDeadline(time.Now().Add(realtimeWSPongTimeout)) +} + +func (c *realtimeClientConn) writePing() error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if err := c.conn.SetWriteDeadline(time.Now().Add(realtimeWSPingWriteTimeout)); err != nil { + return err + } + if err := c.conn.WriteMessage(ws.PingMessage, nil); err != nil { + return err + } + return c.conn.SetWriteDeadline(time.Time{}) +} + +func (c *realtimeClientConn) closeDone() { + c.closeOnce.Do(func() { + close(c.done) + }) +} + +func (c *realtimeClientConn) writeRealtimeError(bifrostErr *schemas.BifrostError) { + payload := newRealtimeTurnErrorEventPayload(bifrostErr) + _ = c.WriteMessage(ws.TextMessage, payload) +} + +func (c *realtimeClientConn) Close() error { + c.closeDone() + return c.conn.Close() +} + +const realtimeSubprotocolAPIKeyPrefix = "openai-insecure-api-key." + +// extractRealtimeSubprotocolAPIKey extracts an API key from the Sec-WebSocket-Protocol +// header. The OpenAI SDK sends: "realtime, openai-insecure-api-key.". 
+func extractRealtimeSubprotocolAPIKey(ctx *fasthttp.RequestCtx) string { + header := string(ctx.Request.Header.Peek("Sec-WebSocket-Protocol")) + for _, proto := range strings.Split(header, ",") { + proto = strings.TrimSpace(proto) + if strings.HasPrefix(proto, realtimeSubprotocolAPIKeyPrefix) { + return strings.TrimPrefix(proto, realtimeSubprotocolAPIKeyPrefix) + } + } + return "" +} + +func newRealtimeWireBifrostError(status int, code, message string) *schemas.BifrostError { + errType := code + return &schemas.BifrostError{ + StatusCode: &status, + Type: &errType, + Error: &schemas.ErrorField{ + Type: &errType, + Code: &errType, + Message: message, + }, + } +} diff --git a/transports/bifrost-http/handlers/wsresponses.go b/transports/bifrost-http/handlers/wsresponses.go index 18a0377f41..ca293a116e 100644 --- a/transports/bifrost-http/handlers/wsresponses.go +++ b/transports/bifrost-http/handlers/wsresponses.go @@ -58,6 +58,14 @@ func NewWSResponsesHandler(client *bifrost.Bifrost, config *lib.Config, pool *bf } } +// Close gracefully shuts down all active WebSocket responses sessions. +func (h *WSResponsesHandler) Close() { + if h == nil || h.sessions == nil { + return + } + h.sessions.CloseAll() +} + // RegisterRoutes registers the WebSocket Responses endpoint at the base path // and all OpenAI integration paths. 
func (h *WSResponsesHandler) RegisterRoutes(r *router.Router, middlewares ...schemas.BifrostHTTPMiddleware) { @@ -98,6 +106,7 @@ type authHeaders struct { virtualKey string apiKey string googAPIKey string + baggage string extraHeaders map[string]string } @@ -108,6 +117,7 @@ func captureAuthHeaders(ctx *fasthttp.RequestCtx) *authHeaders { virtualKey: string(ctx.Request.Header.Peek("x-bf-vk")), apiKey: string(ctx.Request.Header.Peek("x-api-key")), googAPIKey: string(ctx.Request.Header.Peek("x-goog-api-key")), + baggage: string(ctx.Request.Header.Peek("baggage")), extraHeaders: make(map[string]string), } @@ -192,7 +202,7 @@ func (h *WSResponsesHandler) handleResponseCreate(session *bfws.Session, auth *a bifrostReq.Params.ExtraParams = extraParams } - bifrostCtx, cancel := h.createBifrostContext(auth) + bifrostCtx, cancel := createBifrostContextFromAuth(h.handlerStore, auth) if bifrostCtx == nil { writeWSError(session, 500, "server_error", "failed to create request context") return @@ -227,9 +237,10 @@ func (h *WSResponsesHandler) tryNativeWSUpstream( return false } - key, err := h.client.SelectKeyForProvider(ctx, req.Provider, req.Model) + key, err := h.client.SelectKeyForProviderRequestType(ctx, schemas.WebSocketResponsesRequest, req.Provider, req.Model) if err != nil { - return false + writeWSError(session, 400, "invalid_request_error", err.Error()) + return true } wsURL := wsProvider.WebSocketResponsesURL(key) @@ -495,10 +506,14 @@ func (h *WSResponsesHandler) convertEventToRequest(event *schemas.WebSocketRespo }, nil } -// createBifrostContext builds a BifrostContext from the auth headers captured during upgrade. -func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) { +// createBifrostContextFromAuth builds a BifrostContext from the auth headers captured during upgrade. 
+func createBifrostContextFromAuth(handlerStore lib.HandlerStore, auth *authHeaders) (*schemas.BifrostContext, context.CancelFunc) { ctx, cancel := schemas.NewBifrostContextWithCancel(context.Background()) + if sessionID := lib.ParseSessionIDFromBaggage(auth.baggage); sessionID != "" { + ctx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID) + } + if auth.virtualKey != "" { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, auth.virtualKey) } @@ -508,8 +523,8 @@ func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.B if strings.HasPrefix(auth.authorization, "Bearer ") { token := strings.TrimPrefix(auth.authorization, "Bearer ") if strings.HasPrefix(token, "sk-bf-") { - ctx.SetValue(schemas.BifrostContextKeyVirtualKey, token) - } else if h.handlerStore.ShouldAllowDirectKeys() { + ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(token, "sk-bf-")) + } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(token), @@ -523,7 +538,7 @@ func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.B if auth.apiKey != "" { if strings.HasPrefix(auth.apiKey, "sk-bf-") { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.apiKey, "sk-bf-")) - } else if h.handlerStore.ShouldAllowDirectKeys() { + } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(auth.apiKey), @@ -536,7 +551,7 @@ func (h *WSResponsesHandler) createBifrostContext(auth *authHeaders) (*schemas.B if auth.googAPIKey != "" { if strings.HasPrefix(auth.googAPIKey, "sk-bf-") { ctx.SetValue(schemas.BifrostContextKeyVirtualKey, strings.TrimPrefix(auth.googAPIKey, "sk-bf-")) - } else if h.handlerStore.ShouldAllowDirectKeys() { + } else if handlerStore.ShouldAllowDirectKeys() { key := schemas.Key{ ID: "header-provided", Value: *schemas.NewEnvVar(auth.googAPIKey), diff --git 
a/transports/bifrost-http/handlers/wsresponses_test.go b/transports/bifrost-http/handlers/wsresponses_test.go new file mode 100644 index 0000000000..aad3b15e9c --- /dev/null +++ b/transports/bifrost-http/handlers/wsresponses_test.go @@ -0,0 +1,68 @@ +package handlers + +import ( + "testing" + + "github.com/maximhq/bifrost/core/schemas" + "github.com/maximhq/bifrost/framework/kvstore" + "github.com/maximhq/bifrost/framework/logstore" + "github.com/maximhq/bifrost/transports/bifrost-http/lib" +) + +type testWSHandlerStore struct { + allowDirectKeys bool +} + +func (s testWSHandlerStore) ShouldAllowDirectKeys() bool { + return s.allowDirectKeys +} + +func (s testWSHandlerStore) GetHeaderMatcher() *lib.HeaderMatcher { + return nil +} + +func (s testWSHandlerStore) GetAvailableProviders() []schemas.ModelProvider { + return nil +} + +func (s testWSHandlerStore) GetStreamChunkInterceptor() lib.StreamChunkInterceptor { + return nil +} + +func (s testWSHandlerStore) GetAsyncJobExecutor() *logstore.AsyncJobExecutor { + return nil +} + +func (s testWSHandlerStore) GetAsyncJobResultTTL() int { + return 0 +} + +func (s testWSHandlerStore) GetKVStore() *kvstore.Store { + return nil +} + +func (s testWSHandlerStore) GetMCPHeaderCombinedAllowlist() schemas.WhiteList { + return nil +} + +func TestCreateBifrostContextFromAuth_BaggageSessionIDSetsGrouping(t *testing.T) { + ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{ + baggage: "foo=bar, session-id=rt-ws-123, baz=qux", + }) + defer cancel() + + if got, _ := ctx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-ws-123" { + t.Fatalf("parent request id = %q, want %q", got, "rt-ws-123") + } +} + +func TestCreateBifrostContextFromAuth_EmptyBaggageSessionIDIgnored(t *testing.T) { + ctx, cancel := createBifrostContextFromAuth(testWSHandlerStore{}, &authHeaders{ + baggage: "session-id= ", + }) + defer cancel() + + if got := ctx.Value(schemas.BifrostContextKeyParentRequestID); got != nil { 
+ t.Fatalf("parent request id should be unset, got %#v", got) + } +} diff --git a/transports/bifrost-http/integrations/openai.go b/transports/bifrost-http/integrations/openai.go index abe9adf0a2..08959c8d24 100644 --- a/transports/bifrost-http/integrations/openai.go +++ b/transports/bifrost-http/integrations/openai.go @@ -2802,6 +2802,35 @@ func OpenAIRealtimePaths(pathPrefix string) []string { return paths } +// OpenAIRealtimeWebRTCCallsPaths returns HTTP POST paths for the GA /realtime/calls +// WebRTC SDP exchange endpoint (multipart sdp + session format). +func OpenAIRealtimeWebRTCCallsPaths(pathPrefix string) []string { + basePaths := []string{ + "/v1/realtime/calls", + "/realtime/calls", + "/openai/realtime/calls", + } + paths := make([]string, 0, len(basePaths)) + for _, p := range basePaths { + paths = append(paths, pathPrefix+p) + } + return paths +} + +// OpenAIRealtimeClientSecretPaths returns HTTP POST paths for OpenAI-compatible +// realtime client secret creation aliases. +func OpenAIRealtimeClientSecretPaths(pathPrefix string) []string { + basePaths := []string{ + "/v1/realtime/client_secrets", + "/v1/realtime/sessions", + } + paths := make([]string, 0, len(basePaths)) + for _, p := range basePaths { + paths = append(paths, pathPrefix+p) + } + return paths +} + // NewOpenAIRouter creates a new OpenAIRouter with the given bifrost client. func NewOpenAIRouter(client *bifrost.Bifrost, handlerStore lib.HandlerStore, logger schemas.Logger) *OpenAIRouter { routes := CreateOpenAIRouteConfigs("/openai", handlerStore) diff --git a/transports/bifrost-http/lib/ctx.go b/transports/bifrost-http/lib/ctx.go index f715d82cef..876e487249 100644 --- a/transports/bifrost-http/lib/ctx.go +++ b/transports/bifrost-http/lib/ctx.go @@ -32,6 +32,37 @@ const ( FastHTTPUserValueLargeResponseMode = "__bifrost_large_response_mode" ) +// ParseSessionIDFromBaggage extracts the session-id baggage member value. +// It supports simple W3C baggage parsing sufficient for log grouping. 
+func ParseSessionIDFromBaggage(header string) string { + for _, member := range strings.Split(header, ",") { + member = strings.TrimSpace(member) + if member == "" { + continue + } + + parts := strings.SplitN(member, ";", 2) + kv := strings.SplitN(strings.TrimSpace(parts[0]), "=", 2) + if len(kv) != 2 { + continue + } + + key := strings.ToLower(strings.TrimSpace(kv[0])) + value := strings.TrimSpace(kv[1]) + if key != "session-id" || value == "" { + continue + } + if len(value) > 255 { + if logger != nil { + logger.Warn("session-id exceeds 255 chars, ignoring: length=%d, prefix=%s", len(value), value[:255]) + } + continue + } + return value + } + return "" +} + // ConvertToBifrostContext converts a FastHTTP RequestCtx to a Bifrost context, // preserving important header values for monitoring and tracing purposes. // @@ -174,6 +205,12 @@ func ConvertToBifrostContext(ctx *fasthttp.RequestCtx, allowDirectKeys bool, mat // Then process other headers ctx.Request.Header.All()(func(key, value []byte) bool { keyStr := strings.ToLower(string(key)) + if keyStr == "baggage" { + if sessionID := ParseSessionIDFromBaggage(string(value)); sessionID != "" { + bifrostCtx.SetValue(schemas.BifrostContextKeyParentRequestID, sessionID) + } + return true + } if labelName, ok := strings.CutPrefix(keyStr, "x-bf-prom-"); ok { bifrostCtx.SetValue(schemas.BifrostContextKey(labelName), string(value)) return true diff --git a/transports/bifrost-http/lib/ctx_test.go b/transports/bifrost-http/lib/ctx_test.go index 396f7a57f8..abc0620883 100644 --- a/transports/bifrost-http/lib/ctx_test.go +++ b/transports/bifrost-http/lib/ctx_test.go @@ -10,6 +10,29 @@ import ( "github.com/valyala/fasthttp" ) +func TestParseSessionIDFromBaggage(t *testing.T) { + tests := []struct { + name string + header string + want string + }{ + {name: "single member", header: "session-id=abc", want: "abc"}, + {name: "multiple members", header: "foo=bar, session-id=abc, baz=qux", want: "abc"}, + {name: "member with 
properties", header: "session-id=abc;ttl=60", want: "abc"}, + {name: "spaces preserved around parsing", header: " foo=bar , session-id = abc123 ;ttl=60 ", want: "abc123"}, + {name: "missing member", header: "foo=bar", want: ""}, + {name: "malformed ignored", header: "session-id, foo=bar", want: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ParseSessionIDFromBaggage(tt.header); got != tt.want { + t.Fatalf("ParseSessionIDFromBaggage(%q) = %q, want %q", tt.header, got, tt.want) + } + }) + } +} + func TestConvertToBifrostContext_ReusesSharedContext(t *testing.T) { ctx := &fasthttp.RequestCtx{} base := schemas.NewBifrostContext(context.Background(), schemas.NoDeadline) @@ -226,3 +249,27 @@ func TestConvertToBifrostContext_NilMatcher(t *testing.T) { t.Error("expected custom-header to be forwarded with nil matcher") } } + +func TestConvertToBifrostContext_BaggageSessionIDSetsGrouping(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.Set("baggage", "foo=bar, session-id=rt-123, baz=qux") + + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + defer cancel() + + if got, _ := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID).(string); got != "rt-123" { + t.Fatalf("parent request id = %q, want %q", got, "rt-123") + } +} + +func TestConvertToBifrostContext_EmptyBaggageSessionIDIgnored(t *testing.T) { + ctx := &fasthttp.RequestCtx{} + ctx.Request.Header.Set("baggage", "session-id= ") + + bifrostCtx, cancel := ConvertToBifrostContext(ctx, false, nil, schemas.WhiteList{}) + defer cancel() + + if got := bifrostCtx.Value(schemas.BifrostContextKeyParentRequestID); got != nil { + t.Fatalf("parent request id should be unset, got %#v", got) + } +} diff --git a/transports/bifrost-http/server/server.go b/transports/bifrost-http/server/server.go index faff3ae70d..7e6ba75b60 100644 --- a/transports/bifrost-http/server/server.go +++ b/transports/bifrost-http/server/server.go @@ -1004,9 
+1004,12 @@ func (s *BifrostHTTPServer) RegisterInferenceRoutes(ctx context.Context, middlew // Initialize WebSocket pool and handler before integrations so it can be wired through s.wsPool = bfws.NewPool(s.Config.WebSocketConfig.Pool) wsResponsesHandler := handlers.NewWSResponsesHandler(s.Client, s.Config, s.wsPool) + wsRealtimeHandler := handlers.NewWSRealtimeHandler(s.Client, s.Config, s.wsPool) + webrtcRealtimeHandler := handlers.NewWebRTCRealtimeHandler(s.Client, s.Config) + realtimeClientSecretsHandler := handlers.NewRealtimeClientSecretsHandler(s.Client, s.Config) inferenceHandler := handlers.NewInferenceHandler(s.Client, s.Config) - s.IntegrationHandler = handlers.NewIntegrationHandler(s.Client, s.Config, wsResponsesHandler) + s.IntegrationHandler = handlers.NewIntegrationHandler(s.Client, s.Config, wsResponsesHandler, wsRealtimeHandler, webrtcRealtimeHandler, realtimeClientSecretsHandler) mcpInferenceHandler := handlers.NewMCPInferenceHandler(s.Client, s.Config) mcpServerHandler, err := handlers.NewMCPServerHandler(ctx, s.Config, s) if err != nil { @@ -1405,8 +1408,9 @@ func (s *BifrostHTTPServer) Bootstrap(ctx context.Context) error { // Initializing tracer with embedded streaming accumulator traceStore := tracing.NewTraceStore(60*time.Minute, logger) tracer := tracing.NewTracer(traceStore, s.Config.ModelCatalog, logger) + tracer.SetObservabilityPlugins(observabilityPlugins) s.Client.SetTracer(tracer) - s.TracingMiddleware = handlers.NewTracingMiddleware(tracer, observabilityPlugins) + s.TracingMiddleware = handlers.NewTracingMiddleware(tracer) // TransportInterceptor must be inside TracingMiddleware so that the tracing defer // runs AFTER transport post-hooks (capturing HTTPTransportPostHook plugin logs). 
// Order: Tracing.pre → TransportInterceptor.pre → handler → TransportInterceptor.post → Tracing.defer @@ -1460,6 +1464,10 @@ func (s *BifrostHTTPServer) Start() error { select { case sig := <-sigChan: logger.Info("received signal %v, initiating graceful shutdown...", sig) + if s.IntegrationHandler != nil { + logger.Info("closing realtime transport sessions...") + s.IntegrationHandler.Close() + } // Create shutdown context with timeout shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -1516,6 +1524,9 @@ func (s *BifrostHTTPServer) Start() error { } case err := <-errChan: + if s.IntegrationHandler != nil { + s.IntegrationHandler.Close() + } if s.wsPool != nil { s.wsPool.Close() } diff --git a/transports/bifrost-http/websocket/session.go b/transports/bifrost-http/websocket/session.go index 314e2fd7bc..0f75b4b6a7 100644 --- a/transports/bifrost-http/websocket/session.go +++ b/transports/bifrost-http/websocket/session.go @@ -1,9 +1,13 @@ package websocket import ( + "strings" "sync" + "time" ws "github.com/fasthttp/websocket" + "github.com/google/uuid" + "github.com/maximhq/bifrost/core/schemas" ) // Session tracks the binding between a client WebSocket connection and its upstream state. @@ -12,6 +16,8 @@ type Session struct { mu sync.RWMutex writeMu sync.Mutex // serializes all WriteMessage calls to clientConn + id string + // Client connection clientConn *ws.Conn @@ -22,16 +28,64 @@ type Session struct { // LastResponseID tracks the most recent response ID for previous_response_id chaining. lastResponseID string + // providerSessionID tracks the upstream provider's session identifier when exposed. + providerSessionID string + + // realtimeOutputText accumulates assistant/provider turn text until the terminal event. 
+ realtimeOutputText string + + // realtimeTurnInputs accumulates finalized user/tool inputs in arrival order so the + // completed assistant turn can persist the full turn history instead of only the + // latest finalized input event. + realtimeTurnInputs []RealtimeTurnInput + + // realtimeConsumedTurnItemIDs tracks finalized item IDs that have already been + // attached to a persisted turn, so late transcript updates do not pollute later turns. + realtimeConsumedTurnItemIDs map[string]struct{} + + // realtimeTurnHooks tracks the active turn-scoped plugin pipeline between + // response.create and response.done. + realtimeTurnHooks *RealtimeTurnPluginState + realtimeTurnBusy bool + closed bool } +type RealtimeToolOutput struct { + Summary string + Raw string +} + +type RealtimeTurnInput struct { + ItemID string + Role string + Summary string + Raw string +} + +type RealtimeTurnPluginState struct { + PostHookRunner schemas.PostHookRunner + Cleanup func() + RequestID string + StartedAt time.Time + PreHookValues map[any]any +} + // NewSession creates a new session for a client WebSocket connection. func NewSession(clientConn *ws.Conn) *Session { return &Session{ + id: uuid.NewString(), clientConn: clientConn, } } +// ID returns the stable Bifrost session identifier for this websocket session. +func (s *Session) ID() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.id +} + // ClientConn returns the client's WebSocket connection. func (s *Session) ClientConn() *ws.Conn { return s.clientConn @@ -83,6 +137,212 @@ func (s *Session) LastResponseID() string { return s.lastResponseID } +// SetProviderSessionID stores the upstream provider session identifier when available. +func (s *Session) SetProviderSessionID(id string) { + s.mu.Lock() + defer s.mu.Unlock() + s.providerSessionID = id +} + +// ProviderSessionID returns the upstream provider session identifier when known. 
+func (s *Session) ProviderSessionID() string { + s.mu.RLock() + defer s.mu.RUnlock() + return s.providerSessionID +} + +// AppendRealtimeOutputText appends provider output content for the current realtime turn. +func (s *Session) AppendRealtimeOutputText(text string) { + if text == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeOutputText += text +} + +// ConsumeRealtimeOutputText returns the accumulated provider output and clears it. +func (s *Session) ConsumeRealtimeOutputText() string { + s.mu.Lock() + defer s.mu.Unlock() + text := s.realtimeOutputText + s.realtimeOutputText = "" + return text +} + +// AddRealtimeInput stores a finalized user turn event in arrival order. +func (s *Session) AddRealtimeInput(summary, raw string) { + if summary == "" && raw == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{ + Role: string(schemas.ChatMessageRoleUser), + Summary: summary, + Raw: raw, + }) +} + +// RecordRealtimeInput stores or updates a finalized user turn event keyed by item ID. +// Late updates for items already attached to a completed turn are ignored. +func (s *Session) RecordRealtimeInput(itemID, summary, raw string) { + s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleUser), summary, raw) +} + +// AddRealtimeToolOutput stores a pending tool result for the next assistant turn. +func (s *Session) AddRealtimeToolOutput(summary, raw string) { + if summary == "" && raw == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{ + Role: string(schemas.ChatMessageRoleTool), + Summary: summary, + Raw: raw, + }) +} + +// RecordRealtimeToolOutput stores or updates a finalized tool result keyed by item ID. +// Late updates for items already attached to a completed turn are ignored. 
+func (s *Session) RecordRealtimeToolOutput(itemID, summary, raw string) { + s.recordRealtimeTurnInput(itemID, string(schemas.ChatMessageRoleTool), summary, raw) +} + +func (s *Session) recordRealtimeTurnInput(itemID, role, summary, raw string) { + if summary == "" && raw == "" { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + + itemID = strings.TrimSpace(itemID) + if itemID != "" { + if _, consumed := s.realtimeConsumedTurnItemIDs[itemID]; consumed { + return + } + for idx := range s.realtimeTurnInputs { + if s.realtimeTurnInputs[idx].ItemID != itemID || s.realtimeTurnInputs[idx].Role != role { + continue + } + if strings.TrimSpace(summary) != "" { + s.realtimeTurnInputs[idx].Summary = summary + } + if strings.TrimSpace(raw) != "" { + existingRaw := strings.TrimSpace(s.realtimeTurnInputs[idx].Raw) + incomingRaw := strings.TrimSpace(raw) + switch { + case existingRaw == "": + s.realtimeTurnInputs[idx].Raw = raw + case incomingRaw == "" || existingRaw == incomingRaw: + default: + s.realtimeTurnInputs[idx].Raw = existingRaw + "\n\n" + incomingRaw + } + } + return + } + } + + s.realtimeTurnInputs = append(s.realtimeTurnInputs, RealtimeTurnInput{ + ItemID: itemID, + Role: role, + Summary: summary, + Raw: raw, + }) +} + +// ConsumeRealtimeTurnInputs returns pending realtime turn inputs and clears them. +func (s *Session) ConsumeRealtimeTurnInputs() []RealtimeTurnInput { + s.mu.Lock() + defer s.mu.Unlock() + inputs := append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...) + if len(inputs) > 0 { + if s.realtimeConsumedTurnItemIDs == nil { + s.realtimeConsumedTurnItemIDs = make(map[string]struct{}, len(inputs)) + } + for _, input := range inputs { + if strings.TrimSpace(input.ItemID) != "" { + s.realtimeConsumedTurnItemIDs[input.ItemID] = struct{}{} + } + } + } + s.realtimeTurnInputs = nil + return inputs +} + +// PeekRealtimeTurnInputs returns pending realtime turn inputs without clearing them. 
+func (s *Session) PeekRealtimeTurnInputs() []RealtimeTurnInput { + s.mu.RLock() + defer s.mu.RUnlock() + return append([]RealtimeTurnInput(nil), s.realtimeTurnInputs...) +} + +// SetRealtimeTurnHooks stores the active turn-scoped plugin pipeline. +func (s *Session) SetRealtimeTurnHooks(state *RealtimeTurnPluginState) { + s.mu.Lock() + defer s.mu.Unlock() + if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil { + s.realtimeTurnHooks.Cleanup() + } + s.realtimeTurnBusy = false + if s.closed { + if state != nil && state.Cleanup != nil { + state.Cleanup() + } + s.realtimeTurnHooks = nil + return + } + s.realtimeTurnHooks = state +} + +// TryBeginRealtimeTurnHooks reserves the single active turn slot. +func (s *Session) TryBeginRealtimeTurnHooks() bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed || s.realtimeTurnBusy || s.realtimeTurnHooks != nil { + return false + } + s.realtimeTurnBusy = true + return true +} + +// AbortRealtimeTurnHooks releases a reserved turn slot without installing hooks. +func (s *Session) AbortRealtimeTurnHooks() { + s.mu.Lock() + defer s.mu.Unlock() + s.realtimeTurnBusy = false +} + +// PeekRealtimeTurnHooks returns the active turn-scoped plugin pipeline without clearing it. +func (s *Session) PeekRealtimeTurnHooks() *RealtimeTurnPluginState { + s.mu.RLock() + defer s.mu.RUnlock() + return s.realtimeTurnHooks +} + +// ConsumeRealtimeTurnHooks returns the active turn-scoped plugin pipeline and clears it. +func (s *Session) ConsumeRealtimeTurnHooks() *RealtimeTurnPluginState { + s.mu.Lock() + defer s.mu.Unlock() + state := s.realtimeTurnHooks + s.realtimeTurnHooks = nil + s.realtimeTurnBusy = false + return state +} + +// ClearRealtimeTurnHooks cleans up and clears any active turn-scoped plugin pipeline. 
+func (s *Session) ClearRealtimeTurnHooks() { + s.mu.Lock() + defer s.mu.Unlock() + if s.realtimeTurnHooks != nil && s.realtimeTurnHooks.Cleanup != nil { + s.realtimeTurnHooks.Cleanup() + } + s.realtimeTurnHooks = nil + s.realtimeTurnBusy = false +} + // Close closes the session and its upstream connection if pinned. func (s *Session) Close() { s.mu.Lock() @@ -91,6 +351,16 @@ func (s *Session) Close() { return } s.closed = true + if s.realtimeTurnHooks != nil { + if s.realtimeTurnHooks.Cleanup != nil { + s.realtimeTurnHooks.Cleanup() + } + s.realtimeTurnHooks = nil + } + s.realtimeTurnBusy = false + if s.clientConn != nil { + _ = s.clientConn.Close() + } if s.upstream != nil { s.upstream.Close() s.upstream = nil @@ -166,3 +436,15 @@ func (m *SessionManager) CloseAll() { session.Close() } } + +// Snapshot returns a copy of the currently tracked sessions. +func (m *SessionManager) Snapshot() []*Session { + m.mu.RLock() + defer m.mu.RUnlock() + + sessions := make([]*Session, 0, len(m.sessions)) + for _, session := range m.sessions { + sessions = append(sessions, session) + } + return sessions +} diff --git a/transports/bifrost-http/websocket/session_test.go b/transports/bifrost-http/websocket/session_test.go index 8c7a6ebb1f..148e6fe1d5 100644 --- a/transports/bifrost-http/websocket/session_test.go +++ b/transports/bifrost-http/websocket/session_test.go @@ -1,136 +1,156 @@ package websocket import ( - "net/http" - "net/http/httptest" - "strings" "testing" ws "github.com/fasthttp/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func dialTestWS(t *testing.T, server *httptest.Server) *ws.Conn { - t.Helper() - wsURL := "ws" + strings.TrimPrefix(server.URL, "http") - conn, _, err := ws.DefaultDialer.Dial(wsURL, nil) - require.NoError(t, err) - return conn -} +func TestSessionManagerCreateAndGet(t *testing.T) { + manager := NewSessionManager(2) + conn := newTestConn() -func startEchoServer(t *testing.T) *httptest.Server { - 
t.Helper() - upgrader := ws.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, - } - return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - mt, msg, err := conn.ReadMessage() - if err != nil { - break - } - conn.WriteMessage(mt, msg) - } - })) + session, err := manager.Create(conn) + if err != nil { + t.Fatalf("Create() unexpected error: %v", err) + } + if session == nil { + t.Fatal("Create() returned nil session") + } + if got := manager.Get(conn); got != session { + t.Fatal("Get() did not return the created session") + } + if got := manager.Count(); got != 1 { + t.Fatalf("Count() = %d, want 1", got) + } } -func TestSessionManager_CreateAndGet(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionManagerConnectionLimit(t *testing.T) { + manager := NewSessionManager(1) - sm := NewSessionManager(10) - - conn := dialTestWS(t, server) - defer conn.Close() - - session, err := sm.Create(conn) - require.NoError(t, err) - require.NotNil(t, session) - - got := sm.Get(conn) - assert.Equal(t, session, got) - assert.Equal(t, 1, sm.Count()) + if _, err := manager.Create(newTestConn()); err != nil { + t.Fatalf("first Create() unexpected error: %v", err) + } + if _, err := manager.Create(newTestConn()); err != ErrConnectionLimitReached { + t.Fatalf("second Create() error = %v, want %v", err, ErrConnectionLimitReached) + } } -func TestSessionManager_ConnectionLimit(t *testing.T) { - server := startEchoServer(t) - defer server.Close() - - sm := NewSessionManager(2) +func TestSessionManagerRemove(t *testing.T) { + manager := NewSessionManager(2) + conn := newTestConn() - conn1 := dialTestWS(t, server) - defer conn1.Close() - conn2 := dialTestWS(t, server) - defer conn2.Close() - conn3 := dialTestWS(t, server) - defer conn3.Close() + session, err := manager.Create(conn) + if err != nil { + 
t.Fatalf("Create() unexpected error: %v", err) + } - _, err := sm.Create(conn1) - require.NoError(t, err) - _, err = sm.Create(conn2) - require.NoError(t, err) + manager.Remove(conn) - // Third should fail - _, err = sm.Create(conn3) - assert.ErrorIs(t, err, ErrConnectionLimitReached) - assert.Equal(t, 2, sm.Count()) + if got := manager.Get(conn); got != nil { + t.Fatal("Get() should return nil after Remove()") + } + if got := manager.Count(); got != 0 { + t.Fatalf("Count() = %d, want 0", got) + } + if !session.closed { + t.Fatal("expected removed session to be closed") + } } -func TestSessionManager_Remove(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionLastResponseID(t *testing.T) { + session := NewSession(newTestConn()) + session.SetLastResponseID("resp-123") - sm := NewSessionManager(10) + if got := session.LastResponseID(); got != "resp-123" { + t.Fatalf("LastResponseID() = %q, want %q", got, "resp-123") + } +} - conn := dialTestWS(t, server) - defer conn.Close() +func TestSessionManagerCloseAll(t *testing.T) { + manager := NewSessionManager(4) + connA := newTestConn() + connB := newTestConn() - _, err := sm.Create(conn) - require.NoError(t, err) - assert.Equal(t, 1, sm.Count()) + sessionA, err := manager.Create(connA) + if err != nil { + t.Fatalf("Create(connA) unexpected error: %v", err) + } + sessionB, err := manager.Create(connB) + if err != nil { + t.Fatalf("Create(connB) unexpected error: %v", err) + } + + manager.CloseAll() - sm.Remove(conn) - assert.Equal(t, 0, sm.Count()) - assert.Nil(t, sm.Get(conn)) + if got := manager.Count(); got != 0 { + t.Fatalf("Count() = %d, want 0", got) + } + if !sessionA.closed || !sessionB.closed { + t.Fatal("expected all sessions to be closed") + } } -func TestSession_LastResponseID(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionRealtimeState(t *testing.T) { + session := NewSession(newTestConn()) + if session.ID() == "" { + t.Fatal("expected 
session ID to be populated") + } - conn := dialTestWS(t, server) - defer conn.Close() + session.SetProviderSessionID("provider-session-1") + if got := session.ProviderSessionID(); got != "provider-session-1" { + t.Fatalf("ProviderSessionID() = %q, want %q", got, "provider-session-1") + } - session := NewSession(conn) - assert.Equal(t, "", session.LastResponseID()) + session.AppendRealtimeOutputText("hello") + session.AppendRealtimeOutputText(" world") + if got := session.ConsumeRealtimeOutputText(); got != "hello world" { + t.Fatalf("ConsumeRealtimeOutputText() = %q, want %q", got, "hello world") + } + if got := session.ConsumeRealtimeOutputText(); got != "" { + t.Fatalf("ConsumeRealtimeOutputText() after clear = %q, want empty string", got) + } - session.SetLastResponseID("resp_123") - assert.Equal(t, "resp_123", session.LastResponseID()) + session.AddRealtimeInput("hello", `{"type":"conversation.item.create","item":{"role":"user"}}`) + session.AddRealtimeToolOutput("tool result", `{"type":"conversation.item.create","item":{"type":"function_call_output"}}`) + turnInputs := session.ConsumeRealtimeTurnInputs() + if len(turnInputs) != 2 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 2", len(turnInputs)) + } + if turnInputs[0].Role != "user" || turnInputs[0].Summary != "hello" { + t.Fatalf("turnInputs[0] = %+v, want user hello", turnInputs[0]) + } + if turnInputs[1].Role != "tool" || turnInputs[1].Summary != "tool result" { + t.Fatalf("turnInputs[1] = %+v, want tool result", turnInputs[1]) + } + if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) after clear = %d, want 0", len(got)) + } } -func TestSessionManager_CloseAll(t *testing.T) { - server := startEchoServer(t) - defer server.Close() +func TestSessionRecordRealtimeInputUpdatesPendingItemAndIgnoresConsumedLateUpdate(t *testing.T) { + session := NewSession(newTestConn()) - sm := NewSessionManager(10) + session.RecordRealtimeInput("item_1", "[Audio 
transcription unavailable]", `{"type":"conversation.item.done","item":{"id":"item_1"}}`) + session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`) - conn1 := dialTestWS(t, server) - defer conn1.Close() - conn2 := dialTestWS(t, server) - defer conn2.Close() + turnInputs := session.ConsumeRealtimeTurnInputs() + if len(turnInputs) != 1 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) = %d, want 1", len(turnInputs)) + } + if turnInputs[0].ItemID != "item_1" { + t.Fatalf("turnInputs[0].ItemID = %q, want %q", turnInputs[0].ItemID, "item_1") + } + if turnInputs[0].Summary != "Hello." { + t.Fatalf("turnInputs[0].Summary = %q, want %q", turnInputs[0].Summary, "Hello.") + } - _, err := sm.Create(conn1) - assert.NoError(t, err) - _, err = sm.Create(conn2) - assert.NoError(t, err) - assert.Equal(t, 2, sm.Count()) + session.RecordRealtimeInput("item_1", "Hello.", `{"type":"conversation.item.input_audio_transcription.completed","item_id":"item_1","transcript":"Hello."}`) + if got := session.ConsumeRealtimeTurnInputs(); len(got) != 0 { + t.Fatalf("len(ConsumeRealtimeTurnInputs()) after late consumed update = %d, want 0", len(got)) + } +} - sm.CloseAll() - assert.Equal(t, 0, sm.Count()) +func newTestConn() *ws.Conn { + return &ws.Conn{} } diff --git a/transports/go.mod b/transports/go.mod index 1a02163c77..55c00cb50a 100644 --- a/transports/go.mod +++ b/transports/go.mod @@ -21,6 +21,8 @@ require ( github.com/maximhq/bifrost/plugins/otel v1.2.0 github.com/maximhq/bifrost/plugins/semanticcache v1.5.0 github.com/maximhq/bifrost/plugins/telemetry v1.5.0 + github.com/pion/rtcp v1.2.16 + github.com/pion/webrtc/v4 v4.2.9 github.com/prometheus/client_golang v1.23.2 github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 github.com/stretchr/testify v1.11.1 @@ -117,6 +119,20 @@ require ( github.com/oapi-codegen/runtime v1.1.1 // indirect github.com/oklog/ulid v1.3.1 // indirect 
github.com/pinecone-io/go-pinecone/v5 v5.3.0 // indirect + github.com/pion/datachannel v1.6.0 // indirect + github.com/pion/dtls/v3 v3.1.2 // indirect + github.com/pion/ice/v4 v4.2.1 // indirect + github.com/pion/interceptor v0.1.44 // indirect + github.com/pion/logging v0.2.4 // indirect + github.com/pion/mdns/v2 v2.1.0 // indirect + github.com/pion/randutil v0.1.0 // indirect + github.com/pion/rtp v1.10.1 // indirect + github.com/pion/sctp v1.9.2 // indirect + github.com/pion/sdp/v3 v3.0.18 // indirect + github.com/pion/srtp/v3 v3.0.10 // indirect + github.com/pion/stun/v3 v3.1.1 // indirect + github.com/pion/transport/v4 v4.0.1 // indirect + github.com/pion/turn/v4 v4.1.4 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -138,6 +154,7 @@ require ( github.com/weaviate/weaviate v1.36.5 // indirect github.com/weaviate/weaviate-go-client/v5 v5.7.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + github.com/wlynxg/anet v0.0.5 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.mongodb.org/mongo-driver v1.17.6 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect @@ -159,6 +176,7 @@ require ( golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect + golang.org/x/time v0.14.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect google.golang.org/grpc v1.79.3 // indirect diff --git a/transports/go.sum b/transports/go.sum index 6d69d2ac18..a70bc9ed4b 100644 --- a/transports/go.sum +++ b/transports/go.sum @@ -243,6 +243,39 @@ github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= 
github.com/pinecone-io/go-pinecone/v5 v5.3.0 h1:0YQlEtmXGWK/I8ztkOVM6PuBYgFJZhjSdb0ddU+bHPE= github.com/pinecone-io/go-pinecone/v5 v5.3.0/go.mod h1:6Fg85fcyvMUQFf9KW7zniN81kelSYvsjF+KPLdc1MGA= +github.com/pion/datachannel v1.6.0 h1:XecBlj+cvsxhAMZWFfFcPyUaDZtd7IJvrXqlXD/53i0= +github.com/pion/datachannel v1.6.0/go.mod h1:ur+wzYF8mWdC+Mkis5Thosk+u/VOL287apDNEbFpsIk= +github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc= +github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo= +github.com/pion/ice/v4 v4.2.1 h1:XPRYXaLiFq3LFDG7a7bMrmr3mFr27G/gtXN3v/TVfxY= +github.com/pion/ice/v4 v4.2.1/go.mod h1:2quLV1S5v1tAx3VvAJaH//KGitRXvo4RKlX6D3tnN+c= +github.com/pion/interceptor v0.1.44 h1:sNlZwM8dWXU9JQAkJh8xrarC0Etn8Oolcniukmuy0/I= +github.com/pion/interceptor v0.1.44/go.mod h1:4atVlBkcgXuUP+ykQF0qOCGU2j7pQzX2ofvPRFsY5RY= +github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= +github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= +github.com/pion/mdns/v2 v2.1.0 h1:3IJ9+Xio6tWYjhN6WwuY142P/1jA0D5ERaIqawg/fOY= +github.com/pion/mdns/v2 v2.1.0/go.mod h1:pcez23GdynwcfRU1977qKU0mDxSeucttSHbCSfFOd9A= +github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= +github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= +github.com/pion/rtcp v1.2.16 h1:fk1B1dNW4hsI78XUCljZJlC4kZOPk67mNRuQ0fcEkSo= +github.com/pion/rtcp v1.2.16/go.mod h1:/as7VKfYbs5NIb4h6muQ35kQF/J0ZVNz2Z3xKoCBYOo= +github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA= +github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM= +github.com/pion/sctp v1.9.2 h1:HxsOzEV9pWoeggv7T5kewVkstFNcGvhMPx0GvUOUQXo= +github.com/pion/sctp v1.9.2/go.mod h1:OTOlsQ5EDQ6mQ0z4MUGXt2CgQmKyafBEXhUVqLRB6G8= +github.com/pion/sdp/v3 v3.0.18 h1:l0bAXazKHpepazVdp+tPYnrsy9dfh7ZbT8DxesH5ZnI= +github.com/pion/sdp/v3 v3.0.18/go.mod 
h1:ZREGo6A9ZygQ9XkqAj5xYCQtQpif0i6Pa81HOiAdqQ8= +github.com/pion/srtp/v3 v3.0.10 h1:tFirkpBb3XccP5VEXLi50GqXhv5SKPxqrdlhDCJlZrQ= +github.com/pion/srtp/v3 v3.0.10/go.mod h1:3mOTIB0cq9qlbn59V4ozvv9ClW/BSEbRp4cY0VtaR7M= +github.com/pion/stun/v3 v3.1.1 h1:CkQxveJ4xGQjulGSROXbXq94TAWu8gIX2dT+ePhUkqw= +github.com/pion/stun/v3 v3.1.1/go.mod h1:qC1DfmcCTQjl9PBaMa5wSn3x9IPmKxSdcCsxBcDBndM= +github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= +github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= +github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= +github.com/pion/turn/v4 v4.1.4 h1:EU11yMXKIsK43FhcUnjLlrhE4nboHZq+TXBIi3QpcxQ= +github.com/pion/turn/v4 v4.1.4/go.mod h1:ES1DXVFKnOhuDkqn9hn5VJlSWmZPaRJLyBXoOeO/BmQ= +github.com/pion/webrtc/v4 v4.2.9 h1:DZIh1HAhPIL3RvwEDFsmL5hfPSLEpxsQk9/Jir2vkJE= +github.com/pion/webrtc/v4 v4.2.9/go.mod h1:9EmLZve0H76eTzf8v2FmchZ6tcBXtDgpfTEu+drW6SY= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -314,6 +347,8 @@ github.com/weaviate/weaviate-go-client/v5 v5.7.1 h1:vEMxh486QqRqWaq58UEe/TiTbGbo github.com/weaviate/weaviate-go-client/v5 v5.7.1/go.mod h1:T/JDErjN074GrnYIa0AgK1TGUGP/6A/8vqXNPlv4c6E= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= +github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod 
h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= @@ -368,6 +403,8 @@ golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/api v0.0.0-20260203192932-546029d2fa20 h1:7ei4lp52gK1uSejlA8AZl5AJjeLUOHBQscRQZUgAcu0=