Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 117 additions & 7 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -3821,9 +3821,10 @@ func (bifrost *Bifrost) GetProviderByKey(providerKey schemas.ModelProvider) sche
return bifrost.getProviderByKey(providerKey)
}

// SelectKeyForProvider selects an API key for the given provider and model.
// Used by WebSocket handlers that need a key for upstream connections.
func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
// SelectKeyForProviderRequestType selects an API key for the given provider, request type, and model.
// Used by WebSocket handlers that need a key for upstream connections while honoring request-specific
// AllowedRequests gates such as realtime-only support.
func (bifrost *Bifrost) SelectKeyForProviderRequestType(ctx *schemas.BifrostContext, requestType schemas.RequestType, providerKey schemas.ModelProvider, model string) (schemas.Key, error) {
if ctx == nil {
ctx = bifrost.ctx
}
Expand All @@ -3832,7 +3833,7 @@ func (bifrost *Bifrost) SelectKeyForProvider(ctx *schemas.BifrostContext, provid
config.CustomProviderConfig != nil && config.CustomProviderConfig.BaseProviderType != "" {
baseProvider = config.CustomProviderConfig.BaseProviderType
}
return bifrost.selectKeyFromProviderForModel(ctx, schemas.WebSocketResponsesRequest, providerKey, model, baseProvider)
return bifrost.selectKeyFromProviderForModel(ctx, requestType, providerKey, model, baseProvider)
}

// WSStreamHooks holds the post-hook runner and cleanup function returned by RunStreamPreHooks.
Expand All @@ -3846,6 +3847,13 @@ type WSStreamHooks struct {
ShortCircuitResponse *schemas.BifrostResponse
}

// RealtimeTurnHooks mirrors RunStreamPreHooks but is explicitly scoped to a
// single realtime turn rather than one long-lived transport connection.
type RealtimeTurnHooks struct {
PostHookRunner schemas.PostHookRunner
Cleanup func()
}

// RunStreamPreHooks acquires a plugin pipeline, sets up tracing context, runs PreLLMHooks,
// and returns a PostHookRunner for per-chunk post-processing.
// Used by WebSocket handlers that bypass the normal inference path but still need plugin hooks.
Expand Down Expand Up @@ -3884,13 +3892,22 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche

preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req)
if preReq == nil && shortCircuit == nil {
bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
_, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount)
drainAndAttachPluginLogs(ctx)
if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
cleanup()
return nil, newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
return nil, bifrostErr
Comment thread
danpiths marked this conversation as resolved.
}
if shortCircuit != nil {
if shortCircuit.Error != nil {
_, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
drainAndAttachPluginLogs(ctx)
if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
}
cleanup()
if bifrostErr != nil {
return nil, bifrostErr
Expand All @@ -3900,6 +3917,9 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche
if shortCircuit.Response != nil {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount)
drainAndAttachPluginLogs(ctx)
if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
}
cleanup()
if bifrostErr != nil {
return nil, bifrostErr
Expand Down Expand Up @@ -3934,6 +3954,94 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche
}, nil
}

// RunRealtimeTurnPreHooks acquires a plugin pipeline and runs LLM pre-hooks for
// a single realtime turn. Unlike generic stream hooks, realtime turns do not
// support short-circuit responses in v1 because the transports cannot yet emit a
// fully synthetic assistant turn without an upstream generation.
func (bifrost *Bifrost) RunRealtimeTurnPreHooks(ctx *schemas.BifrostContext, req *schemas.BifrostRequest) (*RealtimeTurnHooks, *schemas.BifrostError) {
if req == nil {
bifrostErr := newBifrostErrorFromMsg("realtime turn request is nil")
bifrostErr.ExtraFields.RequestType = schemas.RealtimeRequest
return nil, bifrostErr
}
if ctx == nil {
ctx = bifrost.ctx
}

if _, ok := ctx.Value(schemas.BifrostContextKeyRequestID).(string); !ok {
ctx.SetValue(schemas.BifrostContextKeyRequestID, uuid.New().String())
}

tracer := bifrost.getTracer()
ctx.SetValue(schemas.BifrostContextKeyTracer, tracer)

if _, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); !ok {
traceID := tracer.CreateTrace("")
if traceID != "" {
ctx.SetValue(schemas.BifrostContextKeyTraceID, traceID)
}
}

pipeline := bifrost.getPluginPipeline()
cleanup := func() {
if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && traceID != "" {
tracer.CleanupStreamAccumulator(traceID)
}
bifrost.releasePluginPipeline(pipeline)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
provider, model, _ := req.GetRequestFields()

Comment thread
coderabbitai[bot] marked this conversation as resolved.
preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req)
if preReq == nil && shortCircuit == nil {
bifrostErr := newBifrostErrorFromMsg("bifrost request after plugin hooks cannot be nil")
bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
_, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount)
drainAndAttachPluginLogs(ctx)
if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
}
cleanup()
return nil, bifrostErr
}
if shortCircuit != nil {
if shortCircuit.Error != nil {
shortCircuit.Error.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
_, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
drainAndAttachPluginLogs(ctx)
if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
}
cleanup()
if bifrostErr != nil {
return nil, bifrostErr
}
return nil, shortCircuit.Error
}
if shortCircuit.Response != nil {
// Short-circuit responses are not supported for realtime turns (v1).
// Treat this like an error turn so plugins can close pending state cleanly.
bifrostErr := newBifrostErrorFromMsg("realtime turn short-circuit responses are not supported")
bifrostErr.PopulateExtraFields(schemas.RealtimeRequest, provider, model, model)
_, bifrostErr = pipeline.RunPostLLMHooks(ctx, nil, bifrostErr, preCount)
drainAndAttachPluginLogs(ctx)
if traceID, ok := ctx.Value(schemas.BifrostContextKeyTraceID).(string); ok && strings.TrimSpace(traceID) != "" {
tracer.CompleteAndFlushTrace(strings.TrimSpace(traceID))
}
cleanup()
return nil, bifrostErr
}
}

return &RealtimeTurnHooks{
PostHookRunner: func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount)
drainAndAttachPluginLogs(ctx)
return resp, bifrostErr
Comment thread
danpiths marked this conversation as resolved.
},
Cleanup: cleanup,
Comment thread
danpiths marked this conversation as resolved.
}, nil
}

// getProviderByKey retrieves a provider instance from the providers array by its provider key.
// Returns the provider if found, or nil if no provider with the given key exists.
func (bifrost *Bifrost) getProviderByKey(providerKey schemas.ModelProvider) schemas.Provider {
Expand Down Expand Up @@ -5664,8 +5772,10 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche
if runFrom > len(p.llmPlugins) {
runFrom = len(p.llmPlugins)
}
// Detect streaming mode - if StreamStartTime is set, we're in a streaming context
isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil
requestType, _, _, _ := GetResponseFields(resp, bifrostErr)
// Realtime turns carry StreamStartTime for plugin latency/final-chunk context,
// but they are finalized as one completed turn, not chunk-by-chunk stream output.
isStreaming := ctx.Value(schemas.BifrostContextKeyStreamStartTime) != nil && requestType != schemas.RealtimeRequest
ctx.BlockRestrictedWrites()
defer ctx.UnblockRestrictedWrites()
var err error
Expand Down
2 changes: 1 addition & 1 deletion core/internal/llmtests/realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func RunRealtimeTest(t *testing.T, client *bifrost.Bifrost, ctx context.Context,

bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
defer bfCtx.Cancel()
key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.RealtimeModel)
key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.RealtimeRequest, testConfig.Provider, testConfig.RealtimeModel)
if err != nil {
t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
}
Expand Down
2 changes: 1 addition & 1 deletion core/internal/llmtests/websocket_responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func RunWebSocketResponsesTest(t *testing.T, client *bifrost.Bifrost, ctx contex

bfCtx := schemas.NewBifrostContext(ctx, schemas.NoDeadline)
defer bfCtx.Cancel()
key, err := client.SelectKeyForProvider(bfCtx, testConfig.Provider, testConfig.ChatModel)
key, err := client.SelectKeyForProviderRequestType(bfCtx, schemas.WebSocketResponsesRequest, testConfig.Provider, testConfig.ChatModel)
if err != nil {
t.Fatalf("failed to select key for provider %s: %v", testConfig.Provider, err)
}
Expand Down
48 changes: 41 additions & 7 deletions core/providers/elevenlabs/realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,44 @@ func (provider *ElevenlabsProvider) RealtimeHeaders(key schemas.Key) map[string]
return headers
}

// SupportsRealtimeWebRTC returns false — ElevenLabs WebRTC SDP exchange is not yet implemented.
func (provider *ElevenlabsProvider) SupportsRealtimeWebRTC() bool {
return false
}

// ExchangeRealtimeWebRTCSDP is not yet implemented for ElevenLabs.
func (provider *ElevenlabsProvider) ExchangeRealtimeWebRTCSDP(_ *schemas.BifrostContext, _ schemas.Key, _ string, _ string, _ json.RawMessage) (string, *schemas.BifrostError) {
return "", &schemas.BifrostError{
IsBifrostError: true,
StatusCode: schemas.Ptr(400),
Error: &schemas.ErrorField{Type: schemas.Ptr("invalid_request_error"), Message: "WebRTC SDP exchange is not yet implemented for ElevenLabs"},
}
Comment thread
danpiths marked this conversation as resolved.
}

func (provider *ElevenlabsProvider) ShouldStartRealtimeTurn(event *schemas.BifrostRealtimeEvent) bool {
return false
}

func (provider *ElevenlabsProvider) RealtimeTurnFinalEvent() schemas.RealtimeEventType {
return schemas.RTEventResponseDone
}

func (provider *ElevenlabsProvider) RealtimeWebRTCDataChannelLabel() string {
return ""
}

func (provider *ElevenlabsProvider) RealtimeWebSocketSubprotocol() string {
return ""
}

func (provider *ElevenlabsProvider) ShouldForwardRealtimeEvent(event *schemas.BifrostRealtimeEvent) bool {
return true
}

func (provider *ElevenlabsProvider) ShouldAccumulateRealtimeOutput(eventType schemas.RealtimeEventType) bool {
return eventType == schemas.RTEventResponseDone
}

// ElevenLabs Conversational AI WebSocket event types
const (
elConversationInitMetadata = "conversation_initiation_metadata"
Expand All @@ -50,8 +88,8 @@ const (
elInterruption = "interruption"
elClientToolCall = "client_tool_call"

elUserAudioChunk = "user_audio_chunk"
elPong = "pong"
elUserAudioChunk = "user_audio_chunk"
elPong = "pong"
elClientToolResult = "client_tool_result"
elContextualUpdate = "contextual_update"
)
Expand Down Expand Up @@ -134,7 +172,7 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra
}

case elAgentResponse:
event.Type = schemas.RTEventResponseTextDone
event.Type = schemas.RTEventResponseDone
if raw.AgentResponse != nil {
var agentResp elevenlabsTranscriptEvent
if err := json.Unmarshal(raw.AgentResponse, &agentResp); err == nil {
Expand Down Expand Up @@ -194,10 +232,6 @@ func (provider *ElevenlabsProvider) ToBifrostRealtimeEvent(providerEvent json.Ra

// ToProviderRealtimeEvent converts a unified Bifrost Realtime event to ElevenLabs' native JSON.
func (provider *ElevenlabsProvider) ToProviderRealtimeEvent(bifrostEvent *schemas.BifrostRealtimeEvent) (json.RawMessage, error) {
if bifrostEvent.RawData != nil {
return bifrostEvent.RawData, nil
}

switch bifrostEvent.Type {
case schemas.RTEventInputAudioAppend:
if bifrostEvent.Delta == nil {
Expand Down
Loading
Loading