diff --git a/core/bifrost.go b/core/bifrost.go index db562495b0..bea3ae846c 100644 --- a/core/bifrost.go +++ b/core/bifrost.go @@ -3821,9 +3821,10 @@ func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) sche return bifrost.getProviderByKey(providerKey) } -// SelectKeyForProvider selects an API key for the given provider and model. -// Used by WebSocket handlers that need a key for upstream connections. -func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { +// SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model. +// Used by WebSocket handlers that need a key for upstream connections while honoring request-specific +// AllowedRequests gates such as realtime-only support. +func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) { if ctx == nil { ctx = bifrost.ctx } @@ -3832,7 +3833,7 @@ func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, provid config.CustomProviderConfig != nil && config.CustomProviderConfig.BaseProviderType != "" { baseProvider = config.CustomProviderConfig.BaseProviderType } - return bifrost.selectKeyFromProviderForModel(ctx, schemas.WebSocketResponsesRequest, providerKey, model, baseProvider) + return bifrost.selectKeyFromProviderForModel(ctx, requestType, providerKey, model, baseProvider) } // WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks. @@ -3846,6 +3847,13 @@ type WSStreamHooks struct { ShortCircuitResponse *schemas.BifrostResponse } +// RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a +// single realtime turn rather than one long-lived transport connection. +type RealtimeTurnHooks struct { + PostHookRunner schemas.PostHookRunner + Cleanup func() +} + // RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks, // and returns a PostHookRunner for per-chunk post-processing. // Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks. @@ -3884,13 +3892,22 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() - return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + return nil, bifrostErr } if shortCircuit != nil { if shortCircuit.Error != nil { _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3900,6 +3917,9 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche if shortCircuit.Response != nil { resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount) drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } cleanup() if bifrostErr != nil { return nil, bifrostErr @@ -3934,6 +3954,94 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche }, nil } +// RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for +// a single realtime turn. Unlike generic stream hooks, realtime turns do not +// support short-circuit responses in v1 because the transports cannot yet emit a +// fully synthetic assistant turn without an upstream generation. +func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) { + if req == nil { + bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil") + bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest + return nil, bifrostErr + } + if ctx == nil { + ctx = bifrost.ctx + } + + if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok { + ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String()) + } + + tracer := bifrost.getTracer() + ctx.SetValue(schemas.BifrostContextKeyTracer, tracer) + + if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok { + traceID := tracer.CreateTrace("") + if traceID != "" { + ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID) + } + } + + pipeline := bifrost.getPluginPipeline() + cleanup := func() { + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" { + tracer.CleanupStreamAccumulator(traceID) + } + bifrost.releasePluginPipeline(pipeline) + } + provider, model, _ := req.GetRequestFields() + + preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req) + if preReq == nil && shortCircuit == nil { + bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + if shortCircuit != nil { + if shortCircuit.Error != nil { + shortCircuit.Error.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + if bifrostErr != nil { + return nil, bifrostErr + } + return nil, shortCircuit.Error + } + if shortCircuit.Response != nil { + // Short-circuit responses are not supported for realtime turns (v1). + // Treat this like an error turn so plugins can close pending state cleanly. + bifrostErr := newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported") + bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model) + _, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount) + drainAndAttachPluginLogs(ctx) + if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" { + tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID)) + } + cleanup() + return nil, bifrostErr + } + } + + return &RealtimeTurnHooks{ + PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) { + resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount) + drainAndAttachPluginLogs(ctx) + return resp, bifrostErr + }, + Cleanup: cleanup, + }, nil +} + // getProviderByKey retrieves a provider instance from the providers array by its provider key. // Returns the provider if found, or nil if no provider with the given key exists. func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider { @@ -5664,8 +5772,10 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche if runFrom > len(p.llmPlugins) { runFrom = len(p.llmPlugins) } - // Detect streaming mode - if StreamStartTime is set, we're in a streaming context - isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil + requestType, _, _, _ := GetResponseFields(resp, bifrostErr) + // Realtime turns carry StreamStartTime for plugin latency/final-chunk context, + // but they are finalized as one completed turn, not chunk-by-chunk stream output. + isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest ctx.BlockRestrictedWrites() defer ctx.UnblockRestrictedWrites() var err error diff --git a/core/internal/llmtests/realtime.go b/core/internal/llmtests/realtime.go index 821aeba9eb..400f5f9cda 100644 --- a/core/internal/llmtests/realtime.go +++ b/core/internal/llmtests/realtime.go @@ -43,7 +43,7 @@ func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context, bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.RealtimeModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/internal/llmtests/websocket_responses.go b/core/internal/llmtests/websocket_responses.go index 420a049fb7..966463dade 100644 --- a/core/internal/llmtests/websocket_responses.go +++ b/core/internal/llmtests/websocket_responses.go @@ -38,7 +38,7 @@ func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx contex bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline) defer bfCtx.Cancel() - key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.ChatModel) + key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel) if err != nil { t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err) } diff --git a/core/providers/elevenlabs/realtime.go b/core/providers/elevenlabs/realtime.go index f124f58339..a18e1cd514 100644 --- a/core/providers/elevenlabs/realtime.go +++ b/core/providers/elevenlabs/realtime.go @@ -39,6 +39,44 @@ func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string] return headers } +// SupportsRealtimeWebRTC returns false — ElevenLabs WebRTC SDP exchange is not yet implemented. +func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool { + return false +} + +// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs. +func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) { + return "", &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(400), + Error: &schemas.ErrorField{Type: schemas.Ptr("invalid_request_error"), Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs"}, + } +} + +func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool { + return false +} + +func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType { + return schemas.RTEventResponseDone +} + +func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string { + return "" +} + +func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string { + return "" +} + +func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool { + return true +} + +func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool { + return eventType == schemas.RTEventResponseDone +} + // ElevenLabs Conversational AI WebSocket event types const ( elConversationInitMetadata = "conversation_initiation_metadata" @@ -50,8 +88,8 @@ const ( elInterruption = "interruption" elClientToolCall = "client_tool_call" - elUserAudioChunk = "user_audio_chunk" - elPong = "pong" + elUserAudioChunk = "user_audio_chunk" + elPong = "pong" elClientToolResult = "client_tool_result" elContextualUpdate = "contextual_update" ) @@ -134,7 +172,7 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra } case elAgentResponse: - event.Type = schemas.RTEventResponseTextDone + event.Type = schemas.RTEventResponseDone if raw.AgentResponse != nil { var agentResp elevenlabsTranscriptEvent if err := json.Unmarshal(raw.AgentResponse, &agentResp); err == nil { @@ -194,10 +232,6 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra // ToProviderRealtimeEvent converts a unified Bifrost Realtime event to ElevenLabs' native JSON. func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) { - if bifrostEvent.RawData != nil { - return bifrostEvent.RawData, nil - } - switch bifrostEvent.Type { case schemas.RTEventInputAudioAppend: if bifrostEvent.Delta == nil { diff --git a/core/providers/openai/realtime.go b/core/providers/openai/realtime.go index b73db4ea24..8c88382297 100644 --- a/core/providers/openai/realtime.go +++ b/core/providers/openai/realtime.go @@ -1,13 +1,17 @@ package openai import ( + "bytes" "encoding/json" "fmt" + "mime/multipart" + "net/http" "net/url" "strings" providerUtils "github.com/maximhq/bifrost/core/providers/utils" "github.com/maximhq/bifrost/core/schemas" + "github.com/valyala/fasthttp" ) // SupportsRealtimeAPI returns true since OpenAI natively supports the Realtime API. @@ -28,7 +32,6 @@ func (provider *OpenAIProvider) RealtimeWebSocketURL(key schemas.Key, model stri func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]string { headers := map[string]string{ "Authorization": "Bearer " + key.Value.GetValue(), - "OpenAI-Beta": "realtime=v1", } for k, v := range provider.networkConfig.ExtraHeaders { headers[k] = v @@ -36,6 +39,380 @@ func (provider *OpenAIProvider) RealtimeHeaders(key schemas.Key) map[string]stri return headers } +// SupportsRealtimeWebRTC reports that OpenAI supports WebRTC SDP exchange. +func (provider *OpenAIProvider) SupportsRealtimeWebRTC() bool { + return true +} + +// ExchangeRealtimeWebRTCSDP performs the GA SDP exchange via multipart POST to /v1/realtime/calls. +func (provider *OpenAIProvider) ExchangeRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + model string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + path := "/v1/realtime/calls" + if session == nil && strings.TrimSpace(model) != "" { + path += "?model=" + url.QueryEscape(model) + } + return provider.exchangeWebRTCSDP(ctx, key, path, sdp, session) +} + +// ExchangeLegacyRealtimeWebRTCSDP performs the beta SDP exchange via multipart POST to /v1/realtime. +// Same multipart format but targets the legacy endpoint with model in the URL. +func (provider *OpenAIProvider) ExchangeLegacyRealtimeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + sdp string, + session json.RawMessage, + model string, +) (string, *schemas.BifrostError) { + return provider.exchangeWebRTCSDP(ctx, key, "/v1/realtime?model="+url.QueryEscape(model), sdp, session) +} + +// exchangeWebRTCSDP is the shared multipart SDP exchange implementation. +// Builds a multipart body with sdp + optional session, POSTs to the given path. +func (provider *OpenAIProvider) exchangeWebRTCSDP( + ctx *schemas.BifrostContext, + key schemas.Key, + path string, + sdp string, + session json.RawMessage, +) (string, *schemas.BifrostError) { + bodyBuf := &bytes.Buffer{} + writer := multipart.NewWriter(bodyBuf) + if err := writer.WriteField("sdp", sdp); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream SDP body", err) + } + if session != nil { + if err := writer.WriteField("session", string(session)); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to encode upstream session body", err) + } + } + if err := writer.Close(); err != nil { + return "", newRealtimeWebRTCSDPError(fasthttp.StatusInternalServerError, "server_error", "failed to finalize upstream SDP body", err) + } + + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, path, schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType(writer.FormDataContentType()) + req.Header.Set("Authorization", "Bearer "+key.Value.GetValue()) + for k, v := range provider.networkConfig.ExtraHeaders { + req.Header.Set(k, v) + } + if headers, _ := ctx.Value(schemas.BifrostContextKeyRequestHeaders).(map[string]string); headers != nil { + if agentsSDK := headers["x-openai-agents-sdk"]; agentsSDK != "" { + req.Header.Set("X-OpenAI-Agents-SDK", agentsSDK) + } + } + req.SetBody(bodyBuf.Bytes()) + + _, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return "", bifrostErr + } + + answerBody := resp.Body() + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices { + return "", provider.realtimeWebRTCUpstreamError(ctx, resp.StatusCode(), answerBody) + } + + return string(answerBody), nil +} + +func (provider *OpenAIProvider) realtimeWebRTCUpstreamError(ctx *schemas.BifrostContext, statusCode int, body []byte) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(fasthttp.StatusBadGateway), + Error: &schemas.ErrorField{ + Type: schemas.Ptr("upstream_connection_error"), + Message: fmt.Sprintf("upstream realtime WebRTC handshake failed for %s", provider.GetProviderKey()), + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: provider.GetProviderKey(), + }, + } + if providerUtils.ShouldSendBackRawResponse(ctx, provider.sendBackRawResponse) { + bifrostErr.ExtraFields.RawResponse = map[string]any{ + "status": statusCode, + "body": string(body), + } + } + return bifrostErr +} + +func newRealtimeWebRTCSDPError(status int, errorType, message string, err error) *schemas.BifrostError { + bifrostErr := &schemas.BifrostError{ + IsBifrostError: true, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + }, + } + if err != nil { + bifrostErr.Error.Error = err + } + return bifrostErr +} + +func (provider *OpenAIProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool { + if event == nil { + return false + } + switch event.Type { + case schemas.RTEventResponseCreate, schemas.RTEventInputAudioBufferCommitted: + return true + default: + return false + } +} + +func (provider *OpenAIProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType { + return schemas.RTEventResponseDone +} + +func (provider *OpenAIProvider) RealtimeWebRTCDataChannelLabel() string { + return "oai-events" +} + +func (provider *OpenAIProvider) RealtimeWebSocketSubprotocol() string { + return "realtime" +} + +func (provider *OpenAIProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool { + return true +} + +func (provider *OpenAIProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool { + switch eventType { + case schemas.RTEventResponseTextDelta, + schemas.RTEventResponseAudioTransDelta, + schemas.RealtimeEventType("response.output_text.delta"), + schemas.RealtimeEventType("response.output_audio_transcript.delta"): + return true + default: + return false + } +} + +// CreateRealtimeClientSecret mints an OpenAI Realtime client secret and returns +// the native OpenAI response body unchanged. +func (provider *OpenAIProvider) CreateRealtimeClientSecret( + ctx *schemas.BifrostContext, + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, + rawRequest json.RawMessage, +) (*schemas.BifrostPassthroughResponse, *schemas.BifrostError) { + if err := providerUtils.CheckOperationAllowed(schemas.OpenAI, provider.customProviderConfig, schemas.RealtimeRequest); err != nil { + return nil, err + } + + normalizedBody, requestedModel, bifrostErr := normalizeRealtimeClientSecretRequest(rawRequest, provider.GetProviderKey(), endpointType) + if bifrostErr != nil { + return nil, bifrostErr + } + req := fasthttp.AcquireRequest() + resp := fasthttp.AcquireResponse() + defer fasthttp.ReleaseRequest(req) + defer fasthttp.ReleaseResponse(resp) + + req.SetRequestURI(provider.buildRequestURL(ctx, realtimeSessionUpstreamPath(endpointType), schemas.RealtimeRequest)) + req.Header.SetMethod(http.MethodPost) + req.Header.SetContentType("application/json") + for k, v := range provider.realtimeSessionHeaders(key, endpointType) { + req.Header.Set(k, v) + } + req.SetBody(normalizedBody) + + latency, bifrostErr, wait := providerUtils.MakeRequestWithContext(ctx, provider.client, req, resp) + defer wait() + if bifrostErr != nil { + return nil, bifrostErr + } + + headers := providerUtils.ExtractProviderResponseHeaders(resp) + ctx.SetValue(schemas.BifrostContextKeyProviderResponseHeaders, headers) + + if resp.StatusCode() < fasthttp.StatusOK || resp.StatusCode() >= fasthttp.StatusMultipleChoices { + return nil, ParseOpenAIError(resp) + } + + body, err := providerUtils.CheckAndDecodeBody(resp) + if err != nil { + return nil, providerUtils.NewBifrostOperationError("failed to decode response body", err) + } + for k := range headers { + if strings.EqualFold(k, "Content-Encoding") || strings.EqualFold(k, "Content-Length") { + delete(headers, k) + } + } + + out := &schemas.BifrostPassthroughResponse{ + StatusCode: resp.StatusCode(), + Headers: headers, + Body: body, + } + out.ExtraFields.Provider = provider.GetProviderKey() + out.ExtraFields.OriginalModelRequested = requestedModel + out.ExtraFields.RequestType = schemas.RealtimeRequest + out.ExtraFields.Latency = latency.Milliseconds() + if providerUtils.ShouldSendBackRawRequest(ctx, provider.sendBackRawRequest) { + providerUtils.ParseAndSetRawRequestIfJSON(req, &out.ExtraFields) + } + + return out, nil +} + +func normalizeRealtimeClientSecretRequest( + rawRequest json.RawMessage, + defaultProvider schemas.ModelProvider, + endpointType schemas.RealtimeSessionEndpointType, +) ([]byte, string, *schemas.BifrostError) { + root, bifrostErr := schemas.ParseRealtimeClientSecretBody(rawRequest) + if bifrostErr != nil { + return nil, "", bifrostErr + } + + modelValue, bifrostErr := schemas.ExtractRealtimeClientSecretModel(root) + if bifrostErr != nil { + return nil, "", bifrostErr + } + providerKey, normalizedModel := schemas.ParseModelString(modelValue, defaultProvider) + if normalizedModel == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session.model is required", nil) + } + if providerKey == "" { + providerKey = defaultProvider + } + if providerKey == "" { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "unable to determine provider from model", nil) + } + + if endpointType == schemas.RealtimeSessionEndpointSessions { + return normalizeRealtimeSessionsRequest(root, normalizedModel) + } + + return normalizeRealtimeClientSecretsRequest(root, normalizedModel) +} + +func normalizeRealtimeClientSecretsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + session := map[string]json.RawMessage{} + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) { + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + session["model"] = modelJSON + if _, ok := session["type"]; !ok { + typeJSON, marshalErr := json.Marshal("realtime") + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session type", marshalErr) + } + session["type"] = typeJSON + } + delete(root, "model") + + sessionJSON, marshalErr := json.Marshal(session) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime session", marshalErr) + } + root["session"] = sessionJSON + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func normalizeRealtimeSessionsRequest( + root map[string]json.RawMessage, + normalizedModel string, +) ([]byte, string, *schemas.BifrostError) { + if existingSession, ok := root["session"]; ok && len(existingSession) > 0 && !bytes.Equal(existingSession, []byte("null")) { + session := map[string]json.RawMessage{} + if err := json.Unmarshal(existingSession, &session); err != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusBadRequest, "invalid_request_error", "session must be an object", err) + } + for key, value := range session { + if _, exists := root[key]; !exists { + root[key] = value + } + } + } + + modelJSON, marshalErr := json.Marshal(normalizedModel) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode normalized model", marshalErr) + } + root["model"] = modelJSON + delete(root, "session") + + normalizedBody, marshalErr := json.Marshal(root) + if marshalErr != nil { + return nil, "", newRealtimeClientSecretError(fasthttp.StatusInternalServerError, "server_error", "failed to encode realtime request", marshalErr) + } + + return normalizedBody, normalizedModel, nil +} + +func (provider *OpenAIProvider) realtimeSessionHeaders( + key schemas.Key, + endpointType schemas.RealtimeSessionEndpointType, +) map[string]string { + headers := map[string]string{ + "Authorization": "Bearer " + key.Value.GetValue(), + } + if endpointType == schemas.RealtimeSessionEndpointSessions { + headers["OpenAI-Beta"] = "realtime=v1" + } + for k, v := range provider.networkConfig.ExtraHeaders { + headers[k] = v + } + return headers +} + +func realtimeSessionUpstreamPath(endpointType schemas.RealtimeSessionEndpointType) string { + if endpointType == schemas.RealtimeSessionEndpointSessions { + return "/v1/realtime/sessions" + } + return "/v1/realtime/client_secrets" +} + +func newRealtimeClientSecretError(status int, errorType, message string, err error) *schemas.BifrostError { + return &schemas.BifrostError{ + IsBifrostError: false, + StatusCode: schemas.Ptr(status), + Error: &schemas.ErrorField{ + Type: schemas.Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: schemas.BifrostErrorExtraFields{ + RequestType: schemas.RealtimeRequest, + Provider: schemas.OpenAI, + }, + } +} + // openAIRealtimeEvent is the raw shape of an OpenAI Realtime protocol event. type openAIRealtimeEvent struct { Type string `json:"type"` @@ -44,15 +421,17 @@ type openAIRealtimeEvent struct { Conversation json.RawMessage `json:"conversation,omitempty"` Item json.RawMessage `json:"item,omitempty"` Response json.RawMessage `json:"response,omitempty"` + Part json.RawMessage `json:"part,omitempty"` Delta string `json:"delta,omitempty"` Audio string `json:"audio,omitempty"` Transcript string `json:"transcript,omitempty"` Text string `json:"text,omitempty"` Error json.RawMessage `json:"error,omitempty"` ItemID string `json:"item_id,omitempty"` - OutputIndex int `json:"output_index,omitempty"` - ContentIndex int `json:"content_index,omitempty"` + OutputIndex *int `json:"output_index,omitempty"` + ContentIndex *int `json:"content_index,omitempty"` ResponseID string `json:"response_id,omitempty"` + AudioEndMS *int `json:"audio_end_ms,omitempty"` PreviousItemID string `json:"previous_item_id,omitempty"` } @@ -105,6 +484,17 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes EventID: raw.EventID, RawData: providerEvent, } + setRealtimeExtraParam(event, "item_id", raw.ItemID) + setRealtimeExtraParam(event, "previous_item_id", raw.PreviousItemID) + setRealtimeExtraParam(event, "output_index", raw.OutputIndex) + setRealtimeExtraParam(event, "content_index", raw.ContentIndex) + setRealtimeExtraParam(event, "response_id", raw.ResponseID) + setRealtimeExtraParam(event, "audio_end_ms", raw.AudioEndMS) + setRealtimeExtraParam(event, "transcript", raw.Transcript) + setRealtimeExtraParam(event, "text", raw.Text) + setRealtimeExtraParam(event, "conversation", raw.Conversation) + setRealtimeExtraParam(event, "response", raw.Response) + setRealtimeExtraParam(event, "part", raw.Part) switch { case raw.Session != nil: @@ -123,8 +513,10 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes OutputAudioType: sess.OutputAudioType, Tools: sess.Tools, } + if extra := extractRealtimeNestedParams(raw.Session, "id", "model", "modalities", "instructions", "voice", "temperature", "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools"); len(extra) > 0 { + event.Session.ExtraParams = extra + } } - case raw.Item != nil: var item openAIRealtimeItem if err := json.Unmarshal(raw.Item, &item); err == nil { @@ -139,6 +531,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Arguments: item.Arguments, Output: item.Output, } + if extra := extractRealtimeNestedParams(raw.Item, "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output"); len(extra) > 0 { + event.Item.ExtraParams = extra + } } case raw.Error != nil: @@ -150,6 +545,9 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Message: rtErr.Message, Param: rtErr.Param, } + if extra := extractRealtimeNestedParams(raw.Error, "type", "code", "message", "param"); len(extra) > 0 { + event.Error.ExtraParams = extra + } } } @@ -159,8 +557,8 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes Audio: raw.Audio, Transcript: raw.Transcript, ItemID: raw.ItemID, - OutputIdx: &raw.OutputIndex, - ContentIdx: &raw.ContentIndex, + OutputIdx: raw.OutputIndex, + ContentIdx: raw.ContentIndex, ResponseID: raw.ResponseID, } if raw.Delta != "" { @@ -175,19 +573,19 @@ func (provider *OpenAIProvider) ToBifrostRealtimeEvent(providerEvent json.RawMes // ToProviderRealtimeEvent converts a unified Bifrost Realtime event back to OpenAI's native JSON. func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) { - if bifrostEvent.RawData != nil { - return bifrostEvent.RawData, nil - } - out := map[string]interface{}{ "type": string(bifrostEvent.Type), } if bifrostEvent.EventID != "" { out["event_id"] = bifrostEvent.EventID } + mergeRealtimeExtraParams(out, bifrostEvent.ExtraParams) if bifrostEvent.Session != nil { sess := map[string]interface{}{} + if bifrostEvent.Session.ID != "" && bifrostEvent.Type != schemas.RTEventSessionUpdate { + sess["id"] = bifrostEvent.Session.ID + } if bifrostEvent.Session.Model != "" { sess["model"] = bifrostEvent.Session.Model } @@ -218,6 +616,7 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Session.Tools != nil { sess["tools"] = bifrostEvent.Session.Tools } + mergeRealtimeSessionExtraParams(sess, bifrostEvent.Session.ExtraParams, bifrostEvent.Type) out["session"] = sess } @@ -231,6 +630,9 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Role != "" { item["role"] = bifrostEvent.Item.Role } + if bifrostEvent.Item.Status != "" { + item["status"] = bifrostEvent.Item.Status + } if bifrostEvent.Item.Content != nil { item["content"] = bifrostEvent.Item.Content } @@ -246,9 +648,28 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Item.Output != "" { item["output"] = bifrostEvent.Item.Output } + mergeRealtimeExtraParams(item, bifrostEvent.Item.ExtraParams) out["item"] = item } + if bifrostEvent.Error != nil { + rtErr := map[string]interface{}{} + if bifrostEvent.Error.Type != "" { + rtErr["type"] = bifrostEvent.Error.Type + } + if bifrostEvent.Error.Code != "" { + rtErr["code"] = bifrostEvent.Error.Code + } + if bifrostEvent.Error.Message != "" { + rtErr["message"] = bifrostEvent.Error.Message + } + if bifrostEvent.Error.Param != "" { + rtErr["param"] = bifrostEvent.Error.Param + } + mergeRealtimeExtraParams(rtErr, bifrostEvent.Error.ExtraParams) + out["error"] = rtErr + } + if bifrostEvent.Delta != nil { if bifrostEvent.Delta.Text != "" { out["delta"] = bifrostEvent.Delta.Text @@ -259,16 +680,16 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi if bifrostEvent.Delta.Transcript != "" { out["transcript"] = bifrostEvent.Delta.Transcript } - if bifrostEvent.Delta.ItemID != "" { + if bifrostEvent.Delta.ItemID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "item_id") { out["item_id"] = bifrostEvent.Delta.ItemID } - if bifrostEvent.Delta.OutputIdx != nil { + if bifrostEvent.Delta.OutputIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "output_index") { out["output_index"] = *bifrostEvent.Delta.OutputIdx } - if bifrostEvent.Delta.ContentIdx != nil { + if bifrostEvent.Delta.ContentIdx != nil && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "content_index") { out["content_index"] = *bifrostEvent.Delta.ContentIdx } - if bifrostEvent.Delta.ResponseID != "" { + if bifrostEvent.Delta.ResponseID != "" && !hasRealtimeExtraParam(bifrostEvent.ExtraParams, "response_id") { out["response_id"] = bifrostEvent.Delta.ResponseID } } @@ -276,11 +697,269 @@ func (provider *OpenAIProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.Bi return providerUtils.MarshalSorted(out) } +func mergeRealtimeSessionExtraParams(out map[string]interface{}, params map[string]json.RawMessage, eventType schemas.RealtimeEventType) { + filtered := params + if eventType == schemas.RTEventSessionUpdate && len(params) > 0 { + filtered = make(map[string]json.RawMessage, len(params)) + for key, value := range params { + switch key { + case "id", "object", "expires_at", "client_secret": + continue + default: + filtered[key] = value + } + } + } + mergeRealtimeExtraParams(out, filtered) +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnUsage(terminalEventRaw []byte) *schemas.BifrostLLMUsage { + if len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil || parsed.Response.Usage == nil { + return nil + } + + usage := &schemas.BifrostLLMUsage{ + PromptTokens: parsed.Response.Usage.InputTokens, + CompletionTokens: parsed.Response.Usage.OutputTokens, + TotalTokens: parsed.Response.Usage.TotalTokens, + } + + if parsed.Response.Usage.InputTokenDetails != nil { + usage.PromptTokensDetails = &schemas.ChatPromptTokensDetails{ + TextTokens: parsed.Response.Usage.InputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.InputTokenDetails.AudioTokens, + ImageTokens: parsed.Response.Usage.InputTokenDetails.ImageTokens, + CachedReadTokens: parsed.Response.Usage.InputTokenDetails.CachedTokens, + } + } + + if parsed.Response.Usage.OutputTokenDetails != nil { + usage.CompletionTokensDetails = &schemas.ChatCompletionTokensDetails{ + TextTokens: parsed.Response.Usage.OutputTokenDetails.TextTokens, + AudioTokens: parsed.Response.Usage.OutputTokenDetails.AudioTokens, + ReasoningTokens: parsed.Response.Usage.OutputTokenDetails.ReasoningTokens, + ImageTokens: parsed.Response.Usage.OutputTokenDetails.ImageTokens, + CitationTokens: parsed.Response.Usage.OutputTokenDetails.CitationTokens, + NumSearchQueries: parsed.Response.Usage.OutputTokenDetails.NumSearchQueries, + AcceptedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.AcceptedPredictionTokens, + RejectedPredictionTokens: parsed.Response.Usage.OutputTokenDetails.RejectedPredictionTokens, + } + } + + return usage +} + +func (provider *OpenAIProvider) ExtractRealtimeTurnOutput(terminalEventRaw []byte) *schemas.ChatMessage { + if len(terminalEventRaw) == 0 { + return nil + } + + var parsed openAIRealtimeResponseDoneEnvelope + if err := json.Unmarshal(terminalEventRaw, &parsed); err != nil { + return nil + } + + content := extractOpenAIRealtimeResponseDoneAssistantText(parsed.Response.Output) + toolCalls := extractOpenAIRealtimeResponseDoneToolCalls(parsed.Response.Output) + if content == "" && len(toolCalls) == 0 { + return nil + } + + message := &schemas.ChatMessage{Role: schemas.ChatMessageRoleAssistant} + if content != "" { + message.Content = &schemas.ChatMessageContent{ContentStr: schemas.Ptr(content)} + } + if len(toolCalls) > 0 { + message.ChatAssistantMessage = &schemas.ChatAssistantMessage{ToolCalls: toolCalls} + } + + return message +} + +type openAIRealtimeResponseDoneEnvelope struct { + Response struct { + Output []openAIRealtimeResponseDoneOutput `json:"output"` + Usage *openAIRealtimeResponseDoneUsage `json:"usage"` + } `json:"response"` +} + +type openAIRealtimeResponseDoneOutput struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + CallID string `json:"call_id"` + Arguments string `json:"arguments"` + Content []openAIRealtimeResponseDoneBlock `json:"content"` +} + +type openAIRealtimeResponseDoneBlock struct { + Text string `json:"text"` + Transcript string `json:"transcript"` + Refusal string `json:"refusal"` +} + +type openAIRealtimeResponseDoneUsage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails *openAIRealtimeResponseDoneInputTokenUsage `json:"input_token_details"` + OutputTokenDetails *openAIRealtimeResponseDoneOutputTokenUsage `json:"output_token_details"` +} + +type openAIRealtimeResponseDoneInputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ImageTokens int `json:"image_tokens"` + CachedTokens int `json:"cached_tokens"` +} + +type openAIRealtimeResponseDoneOutputTokenUsage struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` + ReasoningTokens int `json:"reasoning_tokens"` + ImageTokens *int `json:"image_tokens"` + CitationTokens *int `json:"citation_tokens"` + NumSearchQueries *int `json:"num_search_queries"` + AcceptedPredictionTokens int `json:"accepted_prediction_tokens"` + RejectedPredictionTokens int `json:"rejected_prediction_tokens"` +} + +func extractOpenAIRealtimeResponseDoneAssistantText(outputs []openAIRealtimeResponseDoneOutput) string { + var sb strings.Builder + for _, output := range outputs { + if output.Type != "message" { + continue + } + for _, block := range output.Content { + switch { + case strings.TrimSpace(block.Text) != "": + sb.WriteString(block.Text) + case strings.TrimSpace(block.Transcript) != "": + sb.WriteString(block.Transcript) + case strings.TrimSpace(block.Refusal) != "": + sb.WriteString(block.Refusal) + } + } + } + return strings.TrimSpace(sb.String()) +} + +func extractOpenAIRealtimeResponseDoneToolCalls(outputs []openAIRealtimeResponseDoneOutput) []schemas.ChatAssistantMessageToolCall { + toolCalls := make([]schemas.ChatAssistantMessageToolCall, 0) + for _, output := range outputs { + if output.Type != "function_call" { + continue + } + + name := strings.TrimSpace(output.Name) + if name == "" { + continue + } + + toolType := "function" + id := strings.TrimSpace(output.CallID) + if id == "" { + id = strings.TrimSpace(output.ID) + } + + toolCall := schemas.ChatAssistantMessageToolCall{ + Index: uint16(len(toolCalls)), + Type: &toolType, + Function: schemas.ChatAssistantMessageToolCallFunction{ + Name: schemas.Ptr(name), + Arguments: output.Arguments, + }, + } + if id != "" { + toolCall.ID = schemas.Ptr(id) + } + + toolCalls = append(toolCalls, toolCall) + } + return toolCalls +} + +func setRealtimeExtraParam(event *schemas.BifrostRealtimeEvent, key string, value any) { + if event == nil || key == "" || value == nil { + return + } + + switch v := value.(type) { + case string: + if v == "" { + return + } + case *int: + if v == nil { + return + } + case json.RawMessage: + if len(v) == 0 || string(v) == "null" { + return + } + } + + raw, err := json.Marshal(value) + if err != nil { + return + } + if event.ExtraParams == nil { + event.ExtraParams = make(map[string]json.RawMessage) + } + event.ExtraParams[key] = raw +} + +func mergeRealtimeExtraParams(out map[string]interface{}, params map[string]json.RawMessage) { + for key, raw := range params { + if len(raw) == 0 { + continue + } + var value any + if err := json.Unmarshal(raw, &value); err != nil { + continue + } + out[key] = value + } +} + +func hasRealtimeExtraParam(params map[string]json.RawMessage, key string) bool { + if params == nil { + return false + } + raw, ok := params[key] + return ok && len(raw) > 0 +} + +func extractRealtimeNestedParams(raw json.RawMessage, knownKeys ...string) map[string]json.RawMessage { + if len(raw) == 0 { + return nil + } + root := map[string]json.RawMessage{} + if err := json.Unmarshal(raw, &root); err != nil { + return nil + } + for _, key := range knownKeys { + delete(root, key) + } + if len(root) == 0 { + return nil + } + return root +} + func isRealtimeDeltaEvent(eventType string) bool { switch eventType { case "response.text.delta", + "response.output_text.delta", "response.audio.delta", + "response.output_audio.delta", "response.audio_transcript.delta", + "response.output_audio_transcript.delta", "conversation.item.input_audio_transcription.delta": return true } diff --git a/core/providers/openai/realtime_test.go b/core/providers/openai/realtime_test.go new file mode 100644 index 0000000000..6b7f76f98f --- /dev/null +++ b/core/providers/openai/realtime_test.go @@ -0,0 +1,561 @@ +package openai + +import ( + "encoding/json" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestNormalizeRealtimeClientSecretRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["model"]; ok { + t.Fatal("top-level model should be removed after normalization") + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeClientSecretRequestUsesDefaultProvider(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"gpt-4o-realtime-preview"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointClientSecrets, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]json.RawMessage + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + + var session map[string]any + if err := json.Unmarshal(payload["session"], &session); err != nil { + t.Fatalf("failed to unmarshal session: %v", err) + } + if session["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("session.model = %v, want %q", session["model"], "gpt-4o-realtime-preview") + } + if session["type"] != "realtime" { + t.Fatalf("session.type = %v, want %q", session["type"], "realtime") + } +} + +func TestNormalizeRealtimeSessionsRequest(t *testing.T) { + t.Parallel() + + body, model, bifrostErr := normalizeRealtimeClientSecretRequest( + json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview","voice":"alloy"}}`), + schemas.OpenAI, + schemas.RealtimeSessionEndpointSessions, + ) + if bifrostErr != nil { + t.Fatalf("normalizeRealtimeClientSecretRequest() error = %v", bifrostErr) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } + + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("failed to unmarshal normalized body: %v", err) + } + if _, ok := payload["session"]; ok { + t.Fatal("legacy sessions endpoint should not forward nested session object") + } + if payload["model"] != "gpt-4o-realtime-preview" { + t.Fatalf("model = %v, want %q", payload["model"], "gpt-4o-realtime-preview") + } + if payload["voice"] != "alloy" { + t.Fatalf("voice = %v, want %q", payload["voice"], "alloy") + } +} + +func TestToProviderRealtimeEventSerializesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + contentIndex, err := json.Marshal(0) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + audioEndMS, err := json.Marshal(640) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.truncate"), + ExtraParams: map[string]json.RawMessage{ + "item_id": json.RawMessage(`"item_123"`), + "content_index": contentIndex, + "audio_end_ms": audioEndMS, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "conversation.item.truncate" { + t.Fatalf("type = %v, want %q", payload["type"], "conversation.item.truncate") + } + if payload["item_id"] != "item_123" { + t.Fatalf("item_id = %v, want %q", payload["item_id"], "item_123") + } + if payload["content_index"] != float64(0) { + t.Fatalf("content_index = %v, want 0", payload["content_index"]) + } + if payload["audio_end_ms"] != float64(640) { + t.Fatalf("audio_end_ms = %v, want 640", payload["audio_end_ms"]) + } +} + +func TestToBifrostRealtimeEventParsesTopLevelClientFields(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.truncate","item_id":"item_123","content_index":0,"audio_end_ms":640}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + var itemID string + if err := json.Unmarshal(event.ExtraParams["item_id"], &itemID); err != nil { + t.Fatalf("json.Unmarshal(item_id) error = %v", err) + } + if itemID != "item_123" { + t.Fatalf("item_id = %q, want %q", itemID, "item_123") + } + var contentIndex int + if err := json.Unmarshal(event.ExtraParams["content_index"], &contentIndex); err != nil { + t.Fatalf("json.Unmarshal(content_index) error = %v", err) + } + if contentIndex != 0 { + t.Fatalf("content_index = %d, want 0", contentIndex) + } + var audioEndMS int + if err := json.Unmarshal(event.ExtraParams["audio_end_ms"], &audioEndMS); err != nil { + t.Fatalf("json.Unmarshal(audio_end_ms) error = %v", err) + } + if audioEndMS != 640 { + t.Fatalf("audio_end_ms = %d, want 640", audioEndMS) + } +} + +func TestToBifrostRealtimeEventParsesCompletedInputAudioTranscript(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{"type":"conversation.item.input_audio_transcription.completed","event_id":"evt_123","item_id":"item_123","content_index":0,"transcript":"Who are you?"}`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var transcript string + if err := json.Unmarshal(event.ExtraParams["transcript"], &transcript); err != nil { + t.Fatalf("json.Unmarshal(transcript) error = %v", err) + } + if transcript != "Who are you?" { + t.Fatalf("transcript = %q, want %q", transcript, "Who are you?") + } +} + +func TestToBifrostRealtimeEventParsesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.output_text.delta", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "response_id":"resp_123", + "delta":"hello" + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + if event.Delta == nil || event.Delta.Text != "hello" { + t.Fatalf("Delta = %+v, want text=hello", event.Delta) + } +} + +func TestShouldStartRealtimeTurn(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + tests := []struct { + name string + event *schemas.BifrostRealtimeEvent + want bool + }{ + { + name: "response create starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseCreate}, + want: true, + }, + { + name: "audio buffer committed starts turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventInputAudioBufferCommitted}, + want: true, + }, + { + name: "response done does not start turn", + event: &schemas.BifrostRealtimeEvent{Type: schemas.RTEventResponseDone}, + want: false, + }, + { + name: "nil event does not start turn", + event: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := provider.ShouldStartRealtimeTurn(tt.event); got != tt.want { + t.Fatalf("ShouldStartRealtimeTurn() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestToProviderRealtimeEventSerializesModernOutputTextDelta(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + outputIndex := 0 + contentIndex := 0 + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("response.output_text.delta"), + Delta: &schemas.RealtimeDelta{ + Text: "hello", + ItemID: "item_123", + OutputIdx: &outputIndex, + ContentIdx: &contentIndex, + ResponseID: "resp_123", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload["type"] != "response.output_text.delta" { + t.Fatalf("type = %v, want response.output_text.delta", payload["type"]) + } + if payload["delta"] != "hello" { + t.Fatalf("delta = %v, want hello", payload["delta"]) + } +} + +func TestToProviderRealtimeEventSerializesSessionID(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionCreated, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + session, ok := payload["session"].(map[string]any) + if !ok { + t.Fatalf("session = %T, want object", payload["session"]) + } + if session["id"] != "sess_123" { + t.Fatalf("session.id = %v, want sess_123", session["id"]) + } +} + +func TestToProviderRealtimeEventSerializesMessageItemStatus(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + content := json.RawMessage(`[{"type":"input_audio","transcript":"hello"}]`) + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RealtimeEventType("conversation.item.retrieved"), + Item: &schemas.RealtimeItem{ + ID: "item_123", + Type: "message", + Role: "user", + Status: "completed", + Content: content, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + item, ok := payload["item"].(map[string]any) + if !ok { + t.Fatalf("item = %T, want object", payload["item"]) + } + if item["status"] != "completed" { + t.Fatalf("item.status = %v, want completed", item["status"]) + } +} + +func TestToBifrostRealtimeEventPreservesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.done", + "event_id":"evt_123", + "response":{ + "id":"resp_123", + "output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}] + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var response map[string]any + if err := json.Unmarshal(event.ExtraParams["response"], &response); err != nil { + t.Fatalf("json.Unmarshal(response) error = %v", err) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelResponsePayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseDone, + ExtraParams: map[string]json.RawMessage{ + "response": json.RawMessage(`{"id":"resp_123","output":[{"type":"message","content":[{"type":"output_text","text":"hello"}]}]}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + response, ok := payload["response"].(map[string]any) + if !ok { + t.Fatalf("response = %T, want object", payload["response"]) + } + if response["id"] != "resp_123" { + t.Fatalf("response.id = %v, want resp_123", response["id"]) + } +} + +func TestToBifrostRealtimeEventPreservesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + event, err := provider.ToBifrostRealtimeEvent(json.RawMessage(`{ + "type":"response.content_part.added", + "event_id":"evt_123", + "item_id":"item_123", + "output_index":0, + "content_index":0, + "part":{ + "type":"text", + "text":"hello" + } + }`)) + if err != nil { + t.Fatalf("ToBifrostRealtimeEvent() error = %v", err) + } + + var part map[string]any + if err := json.Unmarshal(event.ExtraParams["part"], &part); err != nil { + t.Fatalf("json.Unmarshal(part) error = %v", err) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestToProviderRealtimeEventSerializesTopLevelPartPayload(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventResponseContentPartAdded, + ExtraParams: map[string]json.RawMessage{ + "part": json.RawMessage(`{"type":"text","text":"hello"}`), + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload map[string]any + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + part, ok := payload["part"].(map[string]any) + if !ok { + t.Fatalf("part = %T, want object", payload["part"]) + } + if part["type"] != "text" { + t.Fatalf("part.type = %v, want text", part["type"]) + } +} + +func TestParseRealtimeEventPreservesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + event, err := schemas.ParseRealtimeEvent([]byte(`{ + "type":"session.update", + "session":{ + "type":"realtime", + "model":"gpt-4o-realtime-preview", + "output_modalities":["text"] + } + }`)) + if err != nil { + t.Fatalf("ParseRealtimeEvent() error = %v", err) + } + if event.Session == nil { + t.Fatal("expected session to be parsed") + } + var outputModalities []string + if err := json.Unmarshal(event.Session.ExtraParams["output_modalities"], &outputModalities); err != nil { + t.Fatalf("json.Unmarshal(output_modalities) error = %v", err) + } + if len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("output_modalities = %v, want [text]", outputModalities) + } +} + +func TestToProviderRealtimeEventSerializesNestedSessionExtraParams(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + Model: "gpt-4o-realtime-preview", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "output_modalities": json.RawMessage(`["text"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Type string `json:"type"` + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + if payload.Type != "session.update" { + t.Fatalf("type = %q, want %q", payload.Type, "session.update") + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + outputModalities, ok := payload.Session["output_modalities"].([]any) + if !ok || len(outputModalities) != 1 || outputModalities[0] != "text" { + t.Fatalf("session.output_modalities = %v, want [text]", payload.Session["output_modalities"]) + } +} + +func TestToProviderRealtimeEventOmitsReadOnlySessionFieldsOnSessionUpdate(t *testing.T) { + t.Parallel() + + provider := &OpenAIProvider{} + out, err := provider.ToProviderRealtimeEvent(&schemas.BifrostRealtimeEvent{ + Type: schemas.RTEventSessionUpdate, + Session: &schemas.RealtimeSession{ + ID: "sess_123", + Model: "gpt-realtime", + ExtraParams: map[string]json.RawMessage{ + "type": json.RawMessage(`"realtime"`), + "object": json.RawMessage(`"realtime.session"`), + "expires_at": json.RawMessage(`1774614381`), + "client_secret": json.RawMessage(`{"value":"secret"}`), + "modalities": json.RawMessage(`["text","audio"]`), + }, + }, + }) + if err != nil { + t.Fatalf("ToProviderRealtimeEvent() error = %v", err) + } + + var payload struct { + Session map[string]any `json:"session"` + } + if err := json.Unmarshal(out, &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + for _, key := range []string{"id", "object", "expires_at", "client_secret"} { + if _, ok := payload.Session[key]; ok { + t.Fatalf("session.%s unexpectedly present in session.update payload", key) + } + } + if payload.Session["type"] != "realtime" { + t.Fatalf("session.type = %v, want realtime", payload.Session["type"]) + } + if payload.Session["model"] != "gpt-realtime" { + t.Fatalf("session.model = %v, want gpt-realtime", payload.Session["model"]) + } +} diff --git a/core/schemas/bifrost.go b/core/schemas/bifrost.go index 34f120432c..f50eed4210 100644 --- a/core/schemas/bifrost.go +++ b/core/schemas/bifrost.go @@ -228,6 +228,11 @@ const ( BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks) BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware) BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request + BifrostContextKeyParentRequestID BifrostContextKey = "bifrost-parent-request-id" // string (parent linkage for grouped request logs like realtime turns) + BifrostContextKeyRealtimeSessionID BifrostContextKey = "bifrost-realtime-session-id" // string + BifrostContextKeyRealtimeProviderSessionID BifrostContextKey = "bifrost-realtime-provider-session-id" // string + BifrostContextKeyRealtimeSource BifrostContextKey = "bifrost-realtime-source" // string ("ei" or "lm") + BifrostContextKeyRealtimeEventType BifrostContextKey = "bifrost-realtime-event-type" // string BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway) BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys) BifrostContextKeySkipListModelsGovernanceFiltering BifrostContextKey = "bifrost-skip-list-models-governance-filtering" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) diff --git a/core/schemas/realtime.go b/core/schemas/realtime.go index e1e20d7bf4..ec4fd6789d 100644 --- a/core/schemas/realtime.go +++ b/core/schemas/realtime.go @@ -19,33 +19,75 @@ const ( // Server-to-client event types (received from the provider, forwarded to client) const ( - RTEventSessionCreated RealtimeEventType = "session.created" - RTEventSessionUpdated RealtimeEventType = "session.updated" - RTEventConversationCreated RealtimeEventType = "conversation.created" - RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" - RTEventConversationItemDone RealtimeEventType = "conversation.item.done" - RTEventResponseCreated RealtimeEventType = "response.created" - RTEventResponseDone RealtimeEventType = "response.done" - RTEventResponseTextDelta RealtimeEventType = "response.text.delta" - RTEventResponseTextDone RealtimeEventType = "response.text.done" - RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" - RTEventResponseAudioDone RealtimeEventType = "response.audio.done" - RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" - RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" - RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" - RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" - RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" - RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" - RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" - RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" - RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" - RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" - RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" - RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" - RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" - RTEventError RealtimeEventType = "error" + RTEventSessionCreated RealtimeEventType = "session.created" + RTEventSessionUpdated RealtimeEventType = "session.updated" + RTEventConversationCreated RealtimeEventType = "conversation.created" + RTEventConversationItemAdded RealtimeEventType = "conversation.item.added" + RTEventConversationItemCreated RealtimeEventType = "conversation.item.created" + RTEventConversationItemRetrieved RealtimeEventType = "conversation.item.retrieved" + RTEventConversationItemDone RealtimeEventType = "conversation.item.done" + RTEventResponseCreated RealtimeEventType = "response.created" + RTEventResponseDone RealtimeEventType = "response.done" + RTEventResponseTextDelta RealtimeEventType = "response.text.delta" + RTEventResponseTextDone RealtimeEventType = "response.text.done" + RTEventResponseAudioDelta RealtimeEventType = "response.audio.delta" + RTEventResponseAudioDone RealtimeEventType = "response.audio.done" + RTEventResponseAudioTransDelta RealtimeEventType = "response.audio_transcript.delta" + RTEventResponseAudioTransDone RealtimeEventType = "response.audio_transcript.done" + RTEventResponseOutputItemAdded RealtimeEventType = "response.output_item.added" + RTEventResponseOutputItemDone RealtimeEventType = "response.output_item.done" + RTEventResponseContentPartAdded RealtimeEventType = "response.content_part.added" + RTEventResponseContentPartDone RealtimeEventType = "response.content_part.done" + RTEventRateLimitsUpdated RealtimeEventType = "rate_limits.updated" + RTEventInputAudioTransCompleted RealtimeEventType = "conversation.item.input_audio_transcription.completed" + RTEventInputAudioTransDelta RealtimeEventType = "conversation.item.input_audio_transcription.delta" + RTEventInputAudioTransFailed RealtimeEventType = "conversation.item.input_audio_transcription.failed" + RTEventInputAudioBufferCommitted RealtimeEventType = "input_audio_buffer.committed" + RTEventInputAudioBufferCleared RealtimeEventType = "input_audio_buffer.cleared" + RTEventInputAudioSpeechStarted RealtimeEventType = "input_audio_buffer.speech_started" + RTEventInputAudioSpeechStopped RealtimeEventType = "input_audio_buffer.speech_stopped" + RTEventError RealtimeEventType = "error" ) +// IsRealtimeConversationItemEventType reports whether the event carries a +// canonical conversation item payload after provider translation. +func IsRealtimeConversationItemEventType(eventType RealtimeEventType) bool { + switch eventType { + case RTEventConversationItemCreate, + RTEventConversationItemAdded, + RTEventConversationItemCreated, + RTEventConversationItemRetrieved, + RTEventConversationItemDone: + return true + default: + return false + } +} + +// IsRealtimeUserInputEvent reports whether the event represents a finalized +// user input item in the canonical Bifrost realtime schema. +func IsRealtimeUserInputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Role == "user" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeToolOutputEvent reports whether the event represents a finalized +// tool output item in the canonical Bifrost realtime schema. +func IsRealtimeToolOutputEvent(event *BifrostRealtimeEvent) bool { + return event != nil && + event.Item != nil && + event.Item.Type == "function_call_output" && + IsRealtimeConversationItemEventType(event.Type) +} + +// IsRealtimeInputTranscriptEvent reports whether the event carries a finalized +// input-audio transcript in the canonical Bifrost realtime schema. +func IsRealtimeInputTranscriptEvent(event *BifrostRealtimeEvent) bool { + return event != nil && event.Type == RTEventInputAudioTransCompleted +} + // BifrostRealtimeEvent is the unified Bifrost envelope for all Realtime events. // Provider converters translate between this format and the provider-native protocol. type BifrostRealtimeEvent struct { @@ -58,36 +100,42 @@ type BifrostRealtimeEvent struct { Audio []byte `json:"audio,omitempty"` Error *RealtimeError `json:"error,omitempty"` + // ExtraParams preserves provider-specific top-level event fields that are not + // promoted into the common Bifrost schema. + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` + // RawData preserves the original provider event for pass-through or debugging. RawData json.RawMessage `json:"raw_data,omitempty"` } // RealtimeSession describes session configuration for the Realtime connection. type RealtimeSession struct { - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Modalities []string `json:"modalities,omitempty"` - Instructions string `json:"instructions,omitempty"` - Voice string `json:"voice,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` - MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` - TurnDetection json.RawMessage `json:"turn_detection,omitempty"` - InputAudioFormat string `json:"input_audio_format,omitempty"` - OutputAudioType string `json:"output_audio_type,omitempty"` - Tools json.RawMessage `json:"tools,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Modalities []string `json:"modalities,omitempty"` + Instructions string `json:"instructions,omitempty"` + Voice string `json:"voice,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + MaxOutputTokens json.RawMessage `json:"max_output_tokens,omitempty"` + TurnDetection json.RawMessage `json:"turn_detection,omitempty"` + InputAudioFormat string `json:"input_audio_format,omitempty"` + OutputAudioType string `json:"output_audio_type,omitempty"` + Tools json.RawMessage `json:"tools,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeItem represents a conversation item in the Realtime protocol. type RealtimeItem struct { - ID string `json:"id,omitempty"` - Type string `json:"type,omitempty"` - Role string `json:"role,omitempty"` - Status string `json:"status,omitempty"` - Content json.RawMessage `json:"content,omitempty"` - Name string `json:"name,omitempty"` - CallID string `json:"call_id,omitempty"` - Arguments string `json:"arguments,omitempty"` - Output string `json:"output,omitempty"` + ID string `json:"id,omitempty"` + Type string `json:"type,omitempty"` + Role string `json:"role,omitempty"` + Status string `json:"status,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + Name string `json:"name,omitempty"` + CallID string `json:"call_id,omitempty"` + Arguments string `json:"arguments,omitempty"` + Output string `json:"output,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` } // RealtimeDelta carries incremental content for streaming events. @@ -103,10 +151,28 @@ type RealtimeDelta struct { // RealtimeError describes an error from the Realtime API. type RealtimeError struct { - Type string `json:"type,omitempty"` - Code string `json:"code,omitempty"` - Message string `json:"message,omitempty"` - Param string `json:"param,omitempty"` + Type string `json:"type,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Param string `json:"param,omitempty"` + ExtraParams map[string]json.RawMessage `json:"extra_params,omitempty"` +} + +// RealtimeSessionEndpointType identifies the public ephemeral-token endpoint +// shape the client called so providers can preserve versioned behavior. +type RealtimeSessionEndpointType string + +const ( + RealtimeSessionEndpointClientSecrets RealtimeSessionEndpointType = "client_secrets" + RealtimeSessionEndpointSessions RealtimeSessionEndpointType = "sessions" +) + +// RealtimeSessionRoute describes a provider-registered public route for +// ephemeral-token creation. +type RealtimeSessionRoute struct { + Path string + EndpointType RealtimeSessionEndpointType + DefaultProvider ModelProvider } // RealtimeProvider is an optional interface that providers can implement to @@ -116,6 +182,129 @@ type RealtimeProvider interface { SupportsRealtimeAPI() bool RealtimeWebSocketURL(key Key, model string) string RealtimeHeaders(key Key) map[string]string + // SupportsRealtimeWebRTC reports whether the provider supports WebRTC SDP exchange. + SupportsRealtimeWebRTC() bool + // ExchangeRealtimeWebRTCSDP performs the provider-specific SDP signaling exchange. + // The provider owns the HTTP specifics (URL, headers, body format). + // session may be nil if the signaling format doesn't include session config. + ExchangeRealtimeWebRTCSDP(ctx *BifrostContext, key Key, model string, sdp string, session json.RawMessage) (string, *BifrostError) ToBifrostRealtimeEvent(providerEvent json.RawMessage) (*BifrostRealtimeEvent, error) ToProviderRealtimeEvent(bifrostEvent *BifrostRealtimeEvent) (json.RawMessage, error) + // ShouldStartRealtimeTurn reports whether the canonical client-side event + // should start pre-hooks. Providers without an explicit turn-start signal + // return false and rely on finalize-time fallback hooks. + ShouldStartRealtimeTurn(event *BifrostRealtimeEvent) bool + // RealtimeTurnFinalEvent returns the canonical provider event that completes + // a turn and should trigger post-hooks. + RealtimeTurnFinalEvent() RealtimeEventType + RealtimeWebRTCDataChannelLabel() string + RealtimeWebSocketSubprotocol() string + ShouldForwardRealtimeEvent(event *BifrostRealtimeEvent) bool + ShouldAccumulateRealtimeOutput(eventType RealtimeEventType) bool +} + +// RealtimeLegacyWebRTCProvider is an optional interface for providers that +// support the beta WebRTC handshake (e.g., OpenAI's /v1/realtime). +// Only checked for legacy integration routes via type assertion. +// Takes SDP offer + optional session JSON, same as ExchangeRealtimeWebRTCSDP +// but targets the provider's legacy/beta endpoint. +type RealtimeLegacyWebRTCProvider interface { + ExchangeLegacyRealtimeWebRTCSDP(ctx *BifrostContext, key Key, sdp string, session json.RawMessage, model string) (string, *BifrostError) +} + +// RealtimeUsageExtractor lets providers parse terminal-turn usage/output from +// their native wire payloads without coupling handlers to a specific protocol. +type RealtimeUsageExtractor interface { + ExtractRealtimeTurnUsage(terminalEventRaw []byte) *BifrostLLMUsage + ExtractRealtimeTurnOutput(terminalEventRaw []byte) *ChatMessage +} + +// RealtimeSessionProvider is an optional interface for providers that can mint +// short-lived client secrets for browser/client-side Realtime connections. +// Checked via type assertion: provider.(RealtimeSessionProvider). +type RealtimeSessionProvider interface { + CreateRealtimeClientSecret(ctx *BifrostContext, key Key, endpointType RealtimeSessionEndpointType, rawRequest json.RawMessage) (*BifrostPassthroughResponse, *BifrostError) +} + +// ParseRealtimeEvent decodes a client/provider realtime event while preserving +// unknown top-level fields in ExtraParams for provider-specific round-tripping. +func ParseRealtimeEvent(raw []byte) (*BifrostRealtimeEvent, error) { + type realtimeEventAlias struct { + Type RealtimeEventType `json:"type"` + EventID string `json:"event_id,omitempty"` + Session *RealtimeSession `json:"session,omitempty"` + Item *RealtimeItem `json:"item,omitempty"` + Delta *RealtimeDelta `json:"delta,omitempty"` + Audio []byte `json:"audio,omitempty"` + Error *RealtimeError `json:"error,omitempty"` + } + + var alias realtimeEventAlias + if err := Unmarshal(raw, &alias); err != nil { + return nil, err + } + + event := &BifrostRealtimeEvent{ + Type: alias.Type, + EventID: alias.EventID, + Session: alias.Session, + Item: alias.Item, + Delta: alias.Delta, + Audio: alias.Audio, + Error: alias.Error, + } + + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, err + } + savedSession := root["session"] + savedItem := root["item"] + savedError := root["error"] + for _, key := range []string{"type", "event_id", "session", "item", "delta", "audio", "error", "raw_data"} { + delete(root, key) + } + if len(root) > 0 { + event.ExtraParams = root + } + if event.Session != nil { + var sessionRoot map[string]json.RawMessage + if len(savedSession) > 0 && Unmarshal(savedSession, &sessionRoot) == nil { + for _, key := range []string{ + "id", "model", "modalities", "instructions", "voice", "temperature", + "max_output_tokens", "turn_detection", "input_audio_format", "output_audio_type", "tools", + } { + delete(sessionRoot, key) + } + if len(sessionRoot) > 0 { + event.Session.ExtraParams = sessionRoot + } + } + } + if event.Item != nil { + var itemRoot map[string]json.RawMessage + if len(savedItem) > 0 && Unmarshal(savedItem, &itemRoot) == nil { + for _, key := range []string{ + "id", "type", "role", "status", "content", "name", "call_id", "arguments", "output", + } { + delete(itemRoot, key) + } + if len(itemRoot) > 0 { + event.Item.ExtraParams = itemRoot + } + } + } + if event.Error != nil { + var errorRoot map[string]json.RawMessage + if len(savedError) > 0 && Unmarshal(savedError, &errorRoot) == nil { + for _, key := range []string{"type", "code", "message", "param"} { + delete(errorRoot, key) + } + if len(errorRoot) > 0 { + event.Error.ExtraParams = errorRoot + } + } + } + + return event, nil } diff --git a/core/schemas/realtime_client_secrets.go b/core/schemas/realtime_client_secrets.go new file mode 100644 index 0000000000..ae97b573a1 --- /dev/null +++ b/core/schemas/realtime_client_secrets.go @@ -0,0 +1,66 @@ +package schemas + +import ( + "bytes" + "encoding/json" + "strings" +) + +// ParseRealtimeClientSecretBody parses a realtime client-secret request body +// into a mutable raw JSON map while preserving unknown fields. +func ParseRealtimeClientSecretBody(raw json.RawMessage) (map[string]json.RawMessage, *BifrostError) { + var root map[string]json.RawMessage + if err := Unmarshal(raw, &root); err != nil { + return nil, NewRealtimeClientSecretBodyError(400, "invalid_request_error", "invalid JSON body", err) + } + return root, nil +} + +// ExtractRealtimeClientSecretModel extracts the model from either session.model +// or the legacy top-level model field. +func ExtractRealtimeClientSecretModel(root map[string]json.RawMessage) (string, *BifrostError) { + if sessionJSON, ok := root["session"]; ok && len(sessionJSON) > 0 && !bytes.Equal(sessionJSON, []byte("null")) { + var session map[string]json.RawMessage + if err := Unmarshal(sessionJSON, &session); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session must be an object", err) + } + if modelJSON, ok := session["model"]; ok { + var sessionModel string + if err := Unmarshal(modelJSON, &sessionModel); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model must be a string", err) + } + if strings.TrimSpace(sessionModel) != "" { + return strings.TrimSpace(sessionModel), nil + } + } + } + + if modelJSON, ok := root["model"]; ok { + var model string + if err := Unmarshal(modelJSON, &model); err != nil { + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "model must be a string", err) + } + if strings.TrimSpace(model) != "" { + return strings.TrimSpace(model), nil + } + } + + return "", NewRealtimeClientSecretBodyError(400, "invalid_request_error", "session.model or model is required", nil) +} + +// NewRealtimeClientSecretBodyError builds a standard invalid-request style error +// for HTTP realtime client-secret request parsing/validation. +func NewRealtimeClientSecretBodyError(status int, errorType, message string, err error) *BifrostError { + return &BifrostError{ + IsBifrostError: false, + StatusCode: Ptr(status), + Error: &ErrorField{ + Type: Ptr(errorType), + Message: message, + Error: err, + }, + ExtraFields: BifrostErrorExtraFields{ + RequestType: RealtimeRequest, + }, + } +} diff --git a/core/schemas/realtime_client_secrets_test.go b/core/schemas/realtime_client_secrets_test.go new file mode 100644 index 0000000000..dfd8f8b1d3 --- /dev/null +++ b/core/schemas/realtime_client_secrets_test.go @@ -0,0 +1,40 @@ +package schemas + +import ( + "encoding/json" + "testing" +) + +func TestExtractRealtimeClientSecretModel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"session":{"model":"openai/gpt-4o-realtime-preview"}}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "openai/gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "openai/gpt-4o-realtime-preview") + } +} + +func TestExtractRealtimeClientSecretModelFallbackTopLevel(t *testing.T) { + t.Parallel() + + root, err := ParseRealtimeClientSecretBody(json.RawMessage(`{"model":"gpt-4o-realtime-preview"}`)) + if err != nil { + t.Fatalf("ParseRealtimeClientSecretBody() error = %v", err) + } + + model, err := ExtractRealtimeClientSecretModel(root) + if err != nil { + t.Fatalf("ExtractRealtimeClientSecretModel() error = %v", err) + } + if model != "gpt-4o-realtime-preview" { + t.Fatalf("model = %q, want %q", model, "gpt-4o-realtime-preview") + } +} diff --git a/core/schemas/realtime_test.go b/core/schemas/realtime_test.go new file mode 100644 index 0000000000..69e9e403c8 --- /dev/null +++ b/core/schemas/realtime_test.go @@ -0,0 +1,68 @@ +package schemas + +import "testing" + +func TestIsRealtimeConversationItemEventType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + eventType RealtimeEventType + want bool + }{ + {name: "create", eventType: RTEventConversationItemCreate, want: true}, + {name: "added", eventType: RTEventConversationItemAdded, want: true}, + {name: "created", eventType: RTEventConversationItemCreated, want: true}, + {name: "retrieved", eventType: RTEventConversationItemRetrieved, want: true}, + {name: "done", eventType: RTEventConversationItemDone, want: true}, + {name: "response done", eventType: RTEventResponseDone, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := IsRealtimeConversationItemEventType(tt.eventType); got != tt.want { + t.Fatalf("IsRealtimeConversationItemEventType(%q) = %v, want %v", tt.eventType, got, tt.want) + } + }) + } +} + +func TestRealtimeCanonicalEventClassifiers(t *testing.T) { + t.Parallel() + + userEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemAdded, + Item: &RealtimeItem{ + Role: "user", + Type: "message", + }, + } + if !IsRealtimeUserInputEvent(userEvent) { + t.Fatal("expected conversation.item.added user event to be classified as realtime user input") + } + if IsRealtimeToolOutputEvent(userEvent) { + t.Fatal("did not expect conversation.item.added user event to be classified as realtime tool output") + } + + toolEvent := &BifrostRealtimeEvent{ + Type: RTEventConversationItemRetrieved, + Item: &RealtimeItem{ + Type: "function_call_output", + }, + } + if !IsRealtimeToolOutputEvent(toolEvent) { + t.Fatal("expected function_call_output item to be classified as realtime tool output") + } + if IsRealtimeUserInputEvent(toolEvent) { + t.Fatal("did not expect function_call_output item to be classified as realtime user input") + } + + transcriptEvent := &BifrostRealtimeEvent{Type: RTEventInputAudioTransCompleted} + if !IsRealtimeInputTranscriptEvent(transcriptEvent) { + t.Fatal("expected input audio transcription completion to be classified as transcript event") + } + if IsRealtimeInputTranscriptEvent(&BifrostRealtimeEvent{Type: RTEventInputAudioTransDelta}) { + t.Fatal("did not expect input audio transcription delta to be classified as transcript event") + } +} diff --git a/core/schemas/tracer.go b/core/schemas/tracer.go index 3e5bcbf46f..23c5d4cc4c 100644 --- a/core/schemas/tracer.go +++ b/core/schemas/tracer.go @@ -117,6 +117,10 @@ type Tracer interface { // Thread-safe. Should be called after plugin hooks complete, before trace completion. AttachPluginLogs(traceID string, logs []PluginLogEntry) + // CompleteAndFlushTrace ends a trace, exports it to observability plugins, and + // releases the trace resources. Used by transports that bypass normal HTTP trace completion. + CompleteAndFlushTrace(traceID string) + // Stop releases resources associated with the tracer. // Should be called during shutdown to stop background goroutines. Stop() @@ -185,6 +189,9 @@ func (n *NoOpTracer) ProcessStreamingChunk(_ string, _ bool, _ *BifrostResponse, // AttachPluginLogs does nothing. func (n *NoOpTracer) AttachPluginLogs(_ string, _ []PluginLogEntry) {} +// CompleteAndFlushTrace does nothing. +func (n *NoOpTracer) CompleteAndFlushTrace(_ string) {} + // Stop does nothing. func (n *NoOpTracer) Stop() {}