Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 89 additions & 10 deletions core/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ type PluginPipeline struct {
postHookTimings map[string]*pluginTimingAccumulator // keyed by plugin name
postHookPluginOrder []string // order in which post-hooks ran (for nested span creation)
chunkCount int

// Plugin logging: cached scoped contexts for streaming post-hooks (reused across chunks)
streamScopedCtxs map[string]*schemas.BifrostContext
}

// pluginTimingAccumulator accumulates timing information for a plugin across streaming chunks
Expand Down Expand Up @@ -3801,6 +3804,7 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche
if shortCircuit != nil {
if shortCircuit.Error != nil {
_, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
drainAndAttachPluginLogs(ctx)
cleanup()
if bifrostErr != nil {
return nil, bifrostErr
Expand All @@ -3809,6 +3813,7 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche
}
if shortCircuit.Response != nil {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount)
drainAndAttachPluginLogs(ctx)
cleanup()
if bifrostErr != nil {
return nil, bifrostErr
Expand All @@ -3821,7 +3826,11 @@ func (bifrost *Bifrost) RunStreamPreHooks(ctx *schemas.BifrostContext, req *sche
}

postHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return pipeline.RunPostLLMHooks(ctx, result, err, preCount)
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount)
if IsFinalChunk(ctx) {
drainAndAttachPluginLogs(ctx)
}
return resp, bifrostErr
}

return &WSStreamHooks{
Expand Down Expand Up @@ -4294,6 +4303,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif
// Handle short-circuit with response (success case)
if shortCircuit.Response != nil {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount)
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -4302,6 +4312,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif
// Handle short-circuit with error
if shortCircuit.Error != nil {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -4409,6 +4420,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif
select {
case result = <-msg.Response:
resp, bifrostErr := pipeline.RunPostLLMHooks(msg.Context, result, nil, pluginCount)
drainAndAttachPluginLogs(msg.Context)
if bifrostErr != nil {
bifrost.releaseChannelMessage(msg)
return nil, bifrostErr
Expand All @@ -4425,6 +4437,7 @@ func (bifrost *Bifrost) tryRequest(ctx *schemas.BifrostContext, req *schemas.Bif
case bifrostErrVal := <-msg.Err:
bifrostErrPtr := &bifrostErrVal
resp, bifrostErrPtr = pipeline.RunPostLLMHooks(msg.Context, nil, bifrostErrPtr, pluginCount)
drainAndAttachPluginLogs(msg.Context)
bifrost.releaseChannelMessage(msg)
// Drop raw request/response on error path too
if drop, ok := ctx.Value(schemas.BifrostContextKeyRawRequestResponseForLogging).(bool); ok && drop {
Expand Down Expand Up @@ -4505,13 +4518,19 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
}

pipeline := bifrost.getPluginPipeline()
defer bifrost.releasePluginPipeline(pipeline)
releasePipeline := true
defer func() {
if releasePipeline {
bifrost.releasePluginPipeline(pipeline)
}
}()

preReq, shortCircuit, preCount := pipeline.RunLLMPreHooks(ctx, req)
if shortCircuit != nil {
// Handle short-circuit with response (success case)
if shortCircuit.Response != nil {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, shortCircuit.Response, nil, preCount)
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -4520,13 +4539,23 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
// Handle short-circuit with stream
if shortCircuit.Stream != nil {
outputStream := make(chan *schemas.BifrostStreamChunk)
releasePipeline = false // pipeline is released inside the goroutine after stream drains

// Create a post hook runner cause pipeline object is put back in the pool on defer
pipelinePostHookRunner := func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
return pipeline.RunPostLLMHooks(ctx, result, err, preCount)
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, preCount)
if IsFinalChunk(ctx) {
drainAndAttachPluginLogs(ctx)
}
return resp, bifrostErr
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

go func() {
defer func() {
drainAndAttachPluginLogs(ctx) // ensure logs are drained even if stream closes without a final chunk
pipeline.FinalizeStreamingPostHookSpans(ctx)
bifrost.releasePluginPipeline(pipeline)
}()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
defer close(outputStream)

for streamMsg := range shortCircuit.Stream {
Expand Down Expand Up @@ -4574,6 +4603,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
// Handle short-circuit with error
if shortCircuit.Error != nil {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, nil, shortCircuit.Error, preCount)
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -4689,6 +4719,7 @@ func (bifrost *Bifrost) tryStreamRequest(ctx *schemas.BifrostContext, req *schem
ctx.SetValue(schemas.BifrostContextKeyStreamEndIndicator, true)
// On error we will complete post-hooks
recoveredResp, recoveredErr := pipeline.RunPostLLMHooks(ctx, nil, &bifrostErrVal, len(*bifrost.llmPlugins.Load()))
drainAndAttachPluginLogs(ctx)
bifrost.releaseChannelMessage(msg)
if recoveredErr != nil {
return nil, recoveredErr
Expand Down Expand Up @@ -5022,6 +5053,9 @@ func (bifrost *Bifrost) requestWorker(provider schemas.Provider, config *schemas
pipeline = bifrost.getPluginPipeline()
postHookRunner = func(ctx *schemas.BifrostContext, result *schemas.BifrostResponse, err *schemas.BifrostError) (*schemas.BifrostResponse, *schemas.BifrostError) {
resp, bifrostErr := pipeline.RunPostLLMHooks(ctx, result, err, len(*bifrost.llmPlugins.Load()))
if IsFinalChunk(ctx) {
drainAndAttachPluginLogs(ctx)
}
if bifrostErr != nil {
return nil, bifrostErr
}
Expand Down Expand Up @@ -5445,6 +5479,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR
// Handle short-circuit with response (success case)
if shortCircuit.Response != nil {
finalMcpResp, bifrostErr := pipeline.RunMCPPostHooks(ctx, shortCircuit.Response, nil, preCount)
drainAndAttachPluginLogs(ctx)
if bifrostErr != nil {
return nil, bifrostErr
}
Expand All @@ -5454,6 +5489,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR
if shortCircuit.Error != nil {
// Capture post-hook results to respect transformations or recovery
finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, nil, shortCircuit.Error, preCount)
drainAndAttachPluginLogs(ctx)
// Return post-hook error if present (post-hook may have transformed the error)
if finalErr != nil {
return nil, finalErr
Expand Down Expand Up @@ -5513,6 +5549,7 @@ func (bifrost *Bifrost) handleMCPToolExecution(ctx *schemas.BifrostContext, mcpR

// Run post-hooks
finalResp, finalErr := pipeline.RunMCPPostHooks(ctx, mcpResp, bifrostErr, preCount)
drainAndAttachPluginLogs(ctx)

if finalErr != nil {
return nil, finalErr
Expand Down Expand Up @@ -5577,7 +5614,9 @@ func (p *PluginPipeline) RunLLMPreHooks(ctx *schemas.BifrostContext, req *schema
}
}

req, shortCircuit, err = plugin.PreLLMHook(ctx, req)
pluginCtx := ctx.WithPluginScope(&pluginName)
req, shortCircuit, err = plugin.PreLLMHook(pluginCtx, req)
pluginCtx.ReleasePluginScope()

// End span with appropriate status
if err != nil {
Expand Down Expand Up @@ -5628,8 +5667,17 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche
p.logger.Debug("running post-hook for plugin %s", pluginName)
if isStreaming {
// For streaming: accumulate timing, don't create individual spans per chunk
// Lazily create cached scoped contexts on first chunk (reused across all chunks)
if p.streamScopedCtxs == nil {
p.streamScopedCtxs = make(map[string]*schemas.BifrostContext, len(p.llmPlugins))
for _, pl := range p.llmPlugins {
name := pl.GetName()
p.streamScopedCtxs[name] = ctx.WithPluginScope(&name)
}
}
pluginCtx := p.streamScopedCtxs[pluginName]
start := time.Now()
resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr)
resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr)
duration := time.Since(start)

p.accumulatePluginTiming(pluginName, duration, err != nil)
Expand All @@ -5646,7 +5694,9 @@ func (p *PluginPipeline) RunPostLLMHooks(ctx *schemas.BifrostContext, resp *sche
ctx.SetValue(schemas.BifrostContextKeySpanID, spanID)
}
}
resp, bifrostErr, err = plugin.PostLLMHook(ctx, resp, bifrostErr)
pluginCtx := ctx.WithPluginScope(&pluginName)
resp, bifrostErr, err = plugin.PostLLMHook(pluginCtx, resp, bifrostErr)
pluginCtx.ReleasePluginScope()
// End span with appropriate status
if err != nil {
p.tracer.SetAttribute(handle, "error", err.Error())
Expand Down Expand Up @@ -5700,7 +5750,9 @@ func (p *PluginPipeline) RunMCPPreHooks(ctx *schemas.BifrostContext, req *schema
}
}

req, shortCircuit, err = plugin.PreMCPHook(ctx, req)
pluginCtx := ctx.WithPluginScope(&pluginName)
req, shortCircuit, err = plugin.PreMCPHook(pluginCtx, req)
pluginCtx.ReleasePluginScope()

// End span with appropriate status
if err != nil {
Expand Down Expand Up @@ -5755,7 +5807,9 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s
}
}

mcpResp, bifrostErr, err = plugin.PostMCPHook(ctx, mcpResp, bifrostErr)
pluginCtx := ctx.WithPluginScope(&pluginName)
mcpResp, bifrostErr, err = plugin.PostMCPHook(pluginCtx, mcpResp, bifrostErr)
pluginCtx.ReleasePluginScope()

// End span with appropriate status
if err != nil {
Expand All @@ -5781,7 +5835,11 @@ func (p *PluginPipeline) RunMCPPostHooks(ctx *schemas.BifrostContext, mcpResp *s
return mcpResp, nil
}

// resetPluginPipeline resets a PluginPipeline instance for reuse
// resetPluginPipeline resets a PluginPipeline instance for reuse.
// IMPORTANT: drainAndAttachPluginLogs must be called on the root BifrostContext
// BEFORE this method, because it calls ReleasePluginScope on cached scoped contexts
// which nils out their pluginLogs pointer. The drain reads from the shared store
// on the root context, so it must happen while the store is still referenced.
func (p *PluginPipeline) resetPluginPipeline() {
p.executedPreHooks = 0
p.preHookErrors = p.preHookErrors[:0]
Expand All @@ -5792,6 +5850,25 @@ func (p *PluginPipeline) resetPluginPipeline() {
clear(p.postHookTimings)
}
p.postHookPluginOrder = p.postHookPluginOrder[:0]
// Release cached scoped contexts for streaming
for _, scopedCtx := range p.streamScopedCtxs {
scopedCtx.ReleasePluginScope()
}
p.streamScopedCtxs = nil
}

// drainAndAttachPluginLogs drains accumulated plugin logs from the BifrostContext
// and attaches them to the trace for later retrieval by observability plugins.
func drainAndAttachPluginLogs(ctx *schemas.BifrostContext) {
tracer, traceID, err := GetTracerFromContext(ctx)
if err != nil || tracer == nil || traceID == "" {
return
}
logs := ctx.DrainPluginLogs()
if len(logs) == 0 {
return
}
tracer.AttachPluginLogs(traceID, logs)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

// accumulatePluginTiming accumulates timing for a plugin during streaming
Expand Down Expand Up @@ -5883,7 +5960,9 @@ func (bifrost *Bifrost) getPluginPipeline() *PluginPipeline {
return pipeline
}

// releasePluginPipeline returns a PluginPipeline to the pool
// releasePluginPipeline returns a PluginPipeline to the pool.
// Caller must ensure drainAndAttachPluginLogs has already been called on the
// associated BifrostContext before calling this method.
func (bifrost *Bifrost) releasePluginPipeline(pipeline *PluginPipeline) {
pipeline.resetPluginPipeline()
bifrost.pluginPipelinePool.Put(pipeline)
Expand Down
23 changes: 23 additions & 0 deletions core/schemas/bifrost.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ const (
BifrostContextKeyPassthroughExtraParams BifrostContextKey = "bifrost-passthrough-extra-params" // bool
BifrostContextKeyRoutingEnginesUsed BifrostContextKey = "bifrost-routing-engines-used" // []string (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engines used ("routing-rule", "governance", "loadbalancing", etc.)
BifrostContextKeyRoutingEngineLogs BifrostContextKey = "bifrost-routing-engine-logs" // []RoutingEngineLogEntry (set by bifrost - DO NOT SET THIS MANUALLY) - list of routing engine log entries
BifrostContextKeyTransportPluginLogs BifrostContextKey = "bifrost-transport-plugin-logs" // []PluginLogEntry (transport-layer plugin logs accumulated during HTTP transport hooks)
BifrostContextKeyTransportPostHookCompleter BifrostContextKey = "bifrost-transport-posthook-completer" // func() (callback to run HTTPTransportPostHook after streaming - set by transport interceptor middleware)
BifrostContextKeySkipPluginPipeline BifrostContextKey = "bifrost-skip-plugin-pipeline" // bool - skip plugin pipeline for the request
BifrostIsAsyncRequest BifrostContextKey = "bifrost-is-async-request" // bool (set by bifrost - DO NOT SET THIS MANUALLY)) - whether the request is an async request (only used in gateway)
BifrostContextKeyRequestHeaders BifrostContextKey = "bifrost-request-headers" // map[string]string (all request headers with lowercased keys)
Expand Down Expand Up @@ -276,6 +278,27 @@ type RoutingEngineLogEntry struct {
Timestamp int64 // Unix milliseconds
}

// PluginLogEntry represents a structured log entry emitted by a plugin via ctx.Log().
type PluginLogEntry struct {
PluginName string `json:"plugin_name"`
Level LogLevel `json:"level"`
Message string `json:"message"`
Timestamp int64 `json:"timestamp"` // Unix milliseconds
}

// GroupPluginLogsByName groups a flat slice of plugin log entries by plugin name.
// Returns nil if the input is empty.
func GroupPluginLogsByName(logs []PluginLogEntry) map[string][]PluginLogEntry {
if len(logs) == 0 {
return nil
}
grouped := make(map[string][]PluginLogEntry, min(len(logs), 4))
for _, entry := range logs {
grouped[entry.PluginName] = append(grouped[entry.PluginName], entry)
}
return grouped
}

// NOTE: for custom plugin implementation dealing with streaming short circuit,
// make sure to mark BifrostContextKeyStreamEndIndicator as true at the end of the stream.

Expand Down
Loading